├── .gitignore ├── README.md ├── configs └── example.yaml ├── main.py ├── model ├── __init__.py ├── model.py ├── modules │ ├── __init__.py │ ├── attention.py │ ├── modules.py │ ├── resnet.py │ ├── unet.py │ ├── unet_blocks.py │ └── utils.py ├── pipeline.py └── utils.py ├── requirements.txt └── vae ├── __init__.py ├── modules.py └── vae.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reuse and Diffuse: Iterative Denoising for Text-to-Video Generation 2 | 3 | [Website](https://anonymous0x233.github.io/ReuseAndDiffuse) • [Paper](https://arxiv.org/abs/2309.03549) • [Code](https://github.com/anonymous0x233/ReuseAndDiffuse) 4 | 5 | ## Model preparation 6 | 7 | 1. **VidRD LDM model**: [GoogleDrive](https://drive.google.com/file/d/1rdT9cnMjjoggFBsu3LKFFJBl3b_gXa-N/view?usp=drive_link) 8 | 2. **VidRD Fine-tuned VAE**: [GoogleDrive](https://drive.google.com/file/d/1HfhpI4zy4kBmRSy0G600UDnDgJh6bAQp/view?usp=drive_link) 9 | 3. **Stable Diffusion 2.1**: [HuggingFace](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) 10 | 11 | Below is an example directory layout for these model files. 12 | 13 | ``` 14 | assets/ 15 | ├── ModelT2V.pth 16 | ├── vae_finetuned/ 17 | │ ├── diffusion_pytorch_model.bin 18 | │ └── config.json 19 | └── stable-diffusion-2-1-base/ 20 | ├── scheduler/... 21 | ├── text_encoder/... 22 | ├── tokenizer/... 23 | ├── unet/... 24 | ├── vae/... 25 | ├── ... 26 | └── README.md 27 | ``` 28 | 29 | ## Environment setup 30 | 31 | Python >= 3.10 is required. 32 | 33 | ```bash 34 | pip install -r requirements.txt 35 | ``` 36 | 37 | ## Model inference 38 | 39 | Inference is configured in `configs/example.yaml`, which also contains the text prompts for video generation.
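Judging from `SDVideoModelEvaluator` in `model/model.py`, `evaluator.prompts` can also be given as a path to a `.txt` file (one prompt per line) or a `.json` file instead of the inline list in the config. A hypothetical override might look like the following (the prompt file name is only an illustration, and the model paths are assumed to sit at the defaults from `configs/example.yaml`):

```bash
python main.py --config-name="example" \
    ++evaluator.prompts="my_prompts.txt"
```

With the model files placed as in the layout above, the full command with explicit model paths is: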
40 | 41 | ```bash 42 | python main.py --config-name="example" \ 43 | ++model.ckpt_path="assets/ModelT2V.pth" \ 44 | ++model.temporal_vae_path="assets/vae_finetuned/" \ 45 | ++model.pretrained_model_path="assets/stable-diffusion-2-1-base/" 46 | ``` 47 | 48 | ## BibTex 49 | 50 | ``` 51 | @article{reuse2023, 52 | title = {Reuse and Diffuse: Iterative Denoising for Text-to-Video Generation}, 53 | journal = {arXiv preprint arXiv:2309.03549}, 54 | year = {2023} 55 | } 56 | ``` 57 | -------------------------------------------------------------------------------- /configs/example.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | hydra: 4 | job: 5 | chdir: true 6 | name: vidrd_${now:%Y%m%d}_${now:%H%M%S} 7 | 8 | model: 9 | _target_: model.SDVideoModel 10 | pretrained_model_path: "assets/stable-diffusion-2-1-base/" 11 | ckpt_path: "assets/ModelT2V.pth" 12 | temporal_vae_path: "assets/vae_finetuned/" 13 | guidance_scale: 10.0 14 | num_inference_steps: 50 15 | resolution: 256 16 | add_temp_embed: true 17 | add_temp_conv: true 18 | 19 | evaluator: 20 | _target_: model.SDVideoModelEvaluator 21 | seed: 42 22 | num_generated_clips: 1 23 | extend_overlap_frames: 4 24 | extend_noise_recycle: 4 25 | extend_denoise_guidance: 0.4 26 | batch_size: 4 27 | prompts: 28 | - "wood on fire" 29 | - "a boat is sailing in a lake" 30 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import hydra 4 | from hydra.utils import instantiate 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | @hydra.main(config_path="configs", config_name="example", version_base=None) 10 | def main(config): 11 | model = instantiate(config.model) 12 | model.setup(stage="test") 13 | evaluator = instantiate(config.evaluator) 14 | evaluator(model) 15 | logger.info("finished.") 16 | 17 | 18 | if __name__ == "__main__": 19 | main() 20 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import SDVideoModel, SDVideoModelEvaluator 2 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import random 4 | from datetime import datetime 5 | from pathlib import Path 6 | from typing import Union 7 | 8 | import pandas as pd 9 | import pytorch_lightning as pl 10 | import torch 11 | import torch.utils.checkpoint 12 | from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler 13 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPProcessor 14 | 15 | from vae import TemporalAutoencoderKL 16 | from .modules.unet import UNet3DConditionModel 17 | from .pipeline import SDVideoPipeline 18 | from .utils import save_videos_grid, compute_clip_score 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class SDVideoModel(pl.LightningModule): 24 | def __init__(self, pretrained_model_path, **kwargs): 25 | super().__init__() 26 | self.save_hyperparameters(ignore=["pretrained_model_path"], logger=False) 27 | # main training module 28 | self.unet: Union[str, UNet3DConditionModel] = Path( 29 | pretrained_model_path, "unet" 30 | ).as_posix() 31 | # components for training 32 | self.noise_scheduler_dir = 
Path(pretrained_model_path, "scheduler").as_posix() 33 | self.vae = Path(pretrained_model_path, "vae").as_posix() 34 | self.text_encoder = Path(pretrained_model_path, "text_encoder").as_posix() 35 | self.tokenizer: Union[str, CLIPTokenizer] = Path( 36 | pretrained_model_path, "tokenizer" 37 | ).as_posix() 38 | # clip model for metric 39 | self.clip = Path(pretrained_model_path, "clip").as_posix() 40 | self.clip_processor = Path(pretrained_model_path, "clip").as_posix() 41 | # define pipeline for inference 42 | self.val_pipeline = None 43 | # video frame resolution 44 | self.resolution = kwargs.get("resolution", 512) 45 | # use temporal_vae 46 | self.temporal_vae_path = kwargs.get("temporal_vae_path", None) 47 | 48 | def setup(self, stage: str) -> None: 49 | # build modules 50 | self.noise_scheduler = DDPMScheduler.from_pretrained(self.noise_scheduler_dir) 51 | self.tokenizer = CLIPTokenizer.from_pretrained(self.tokenizer) 52 | 53 | if self.temporal_vae_path: 54 | self.vae = TemporalAutoencoderKL.from_pretrained(self.temporal_vae_path) 55 | else: 56 | self.vae = AutoencoderKL.from_pretrained(self.vae) 57 | self.text_encoder = CLIPTextModel.from_pretrained(self.text_encoder) 58 | self.unet = UNet3DConditionModel.from_pretrained_2d( 59 | self.unet, 60 | sample_size=self.resolution 61 | // (2 ** (len(self.vae.config.block_out_channels) - 1)), 62 | add_temp_transformer=self.hparams.get("add_temp_transformer", False), 63 | add_temp_attn_only_on_upblocks=self.hparams.get( 64 | "add_temp_attn_only_on_upblocks", False 65 | ), 66 | prepend_first_frame=self.hparams.get("prepend_first_frame", False), 67 | add_temp_embed=self.hparams.get("add_temp_embed", False), 68 | add_temp_conv=self.hparams.get("add_temp_conv", False), 69 | ) 70 | 71 | # load previously trained components for resumed training 72 | ckpt_path = self.hparams.get("ckpt_path", None) 73 | if ckpt_path is not None: 74 | state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] 75 | mod_list = ( 76 | ["unet", "text_encoder"] 77 | if self.temporal_vae_path 78 | else ["unet", "text_encoder", "vae"] 79 | ) 80 | for mod in mod_list: 81 | if any(filter(lambda x: x.startswith(mod), state_dict.keys())): 82 | mod_instance = getattr(self, mod) 83 | mod_instance.load_state_dict( 84 | { 85 | k[len(mod) + 1 :]: v 86 | for k, v in state_dict.items() 87 | if k.startswith(mod) 88 | } 89 | ) 90 | 91 | # null text for classifier-free guidance 92 | self.null_text_token_ids = self.tokenizer( # noqa 93 | "", 94 | max_length=self.tokenizer.model_max_length, 95 | padding="max_length", 96 | truncation=True, 97 | return_tensors="pt", 98 | ).input_ids[0] 99 | 100 | # load clip modules for evaluation 101 | self.clip = CLIPModel.from_pretrained(self.clip) 102 | self.clip_processor = CLIPProcessor.from_pretrained(self.clip_processor) 103 | # prepare modules 104 | for component in [self.vae, self.text_encoder, self.clip]: 105 | if not isinstance(component, CLIPTextModel) or self.hparams.get( 106 | "freeze_text_encoder", False 107 | ): 108 | component.requires_grad_(False).eval() 109 | if stage != "test" and self.trainer.precision.startswith("16"): 110 | component.to(dtype=torch.float16) 111 | # use gradient checkpointing 112 | if self.hparams.get("enable_gradient_checkpointing", True): 113 | if not self.hparams.get("freeze_text_encoder", False): 114 | self.text_encoder.gradient_checkpointing_enable() 115 | self.unet.enable_gradient_checkpointing() 116 | 117 | # construct pipeline for inference 118 | self.val_pipeline = SDVideoPipeline( 119 | vae=self.vae, 
120 | text_encoder=self.text_encoder, 121 | tokenizer=self.tokenizer, 122 | unet=self.unet, 123 | scheduler=DDIMScheduler.from_pretrained(self.noise_scheduler_dir), 124 | ) 125 | 126 | 127 | class SDVideoModelEvaluator: 128 | def __init__(self, **kwargs): 129 | torch.multiprocessing.set_start_method("spawn", force=True) 130 | torch.multiprocessing.set_sharing_strategy("file_system") 131 | 132 | self.seed = kwargs.pop("seed", 42) 133 | self.prompts = kwargs.pop("prompts", None) 134 | if self.prompts is None: 135 | raise ValueError(f"No prompts provided.") 136 | elif isinstance(self.prompts, str) and not Path(self.prompts).exists(): 137 | raise FileNotFoundError(f"Prompt file not found: {self.prompts}") 138 | elif isinstance(self.prompts, str): 139 | if self.prompts.endswith(".txt"): 140 | with open(self.prompts, "r", encoding="utf-8") as f: 141 | self.prompts = [x.strip() for x in f.readlines() if x.strip()] 142 | elif self.prompts.endswith(".json"): 143 | with open(self.prompts, "r", encoding="utf-8") as f: 144 | self.prompts = sorted( 145 | [ 146 | random.choice(x) if isinstance(x, list) else x 147 | for x in json.load(f).values() 148 | ] 149 | ) 150 | self.add_file_logger(logger, kwargs.pop("log_file", None)) 151 | self.output_file = kwargs.pop("output_file", "results.csv") 152 | self.batch_size = kwargs.pop("batch_size", 4) 153 | self.val_params = kwargs 154 | 155 | @staticmethod 156 | def add_file_logger(logger, log_file=None, log_level=logging.INFO): 157 | if log_file is not None: 158 | log_handler = logging.FileHandler(log_file, "w") 159 | log_handler.setFormatter( 160 | logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s") 161 | ) 162 | log_handler.setLevel(log_level) 163 | logger.addHandler(log_handler) 164 | 165 | @staticmethod 166 | def infer(rank, model, model_params, q_input, q_output, seed=42): 167 | device = torch.device(f"cuda:{rank}") 168 | model.to(device) 169 | generator = torch.Generator(device=device) 170 | generator.manual_seed(seed + rank) 171 | output_video_dir = Path("output_videos") 172 | output_video_dir.mkdir(parents=True, exist_ok=True) 173 | while True: 174 | inputs = q_input.get() 175 | if inputs is None: # check for sentinel value 176 | print(f"[{datetime.now()}] Process #{rank} ended.") 177 | break 178 | start_idx, prompts = inputs 179 | videos = model.val_pipeline( 180 | prompts, 181 | generator=generator, 182 | negative_prompt=["watermark"] * len(prompts), 183 | **model_params, 184 | ).videos 185 | for idx, prompt in enumerate(prompts): 186 | gif_file = output_video_dir.joinpath(f"{start_idx + idx}_{prompt}.gif") 187 | save_videos_grid(videos[idx : idx + 1, ...], gif_file) 188 | print( 189 | f'[{datetime.now()}] Sample is saved #{start_idx + idx}: "{prompt}"' 190 | ) 191 | clip_scores = compute_clip_score( 192 | model=model.clip, 193 | model_processor=model.clip_processor, 194 | images=videos, 195 | texts=prompts, 196 | rescale=False, 197 | ) 198 | q_output.put((prompts, clip_scores.cpu().tolist())) 199 | return None 200 | 201 | def __call__(self, model): 202 | model.eval() 203 | 204 | if not torch.cuda.is_available(): 205 | raise NotImplementedError(f"No GPU found.") 206 | 207 | self.val_params.setdefault( 208 | "num_inference_steps", model.hparams.get("num_inference_steps", 50) 209 | ) 210 | self.val_params.setdefault( 211 | "guidance_scale", model.hparams.get("guidance_scale", 7.5) 212 | ) 213 | self.val_params.setdefault("noise_alpha", model.hparams.get("noise_alpha", 0.0)) 214 | logger.info(f"val_params: {self.val_params}") 215 | 
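# One inference worker per visible GPU: each worker pulls (start_idx, prompt batch)
# tuples from q_input, renders clips with val_pipeline, saves them as GIFs, computes
# CLIP scores, and pushes (prompts, clip_scores) onto q_output; a None sentinel per
# worker (see infer() above) signals shutdown.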
216 | q_input = torch.multiprocessing.Queue() 217 | q_output = torch.multiprocessing.Queue() 218 | processes = [] 219 | for rank in range(torch.cuda.device_count()): 220 | p = torch.multiprocessing.Process( 221 | target=self.infer, 222 | args=(rank, model, self.val_params, q_input, q_output, self.seed), 223 | ) 224 | p.start() 225 | processes.append(p) 226 | # send model inputs to queue 227 | result_num = 0 228 | for start_idx in range(0, len(self.prompts), self.batch_size): 229 | result_num += 1 230 | q_input.put( 231 | (start_idx, self.prompts[start_idx : start_idx + self.batch_size]) 232 | ) 233 | for _ in processes: 234 | q_input.put(None) # sentinel value to signal subprocesses to exit 235 | # The result queue has to be processed before joining the processes. 236 | results = [q_output.get() for _ in range(result_num)] 237 | # joining the processes 238 | for p in processes: 239 | p.join() # wait for all subprocesses to finish 240 | all_prompts, all_clip_scores = [], [] 241 | for prompts, clip_scores in results: 242 | all_prompts.extend(prompts) 243 | all_clip_scores.extend(clip_scores) 244 | output_df = pd.DataFrame({"prompt": all_prompts, "clip_score": all_clip_scores}) 245 | output_df.to_csv(self.output_file, index=False) 246 | logger.info(f"--- Metrics ---") 247 | logger.info(f"Mean CLIP_SCORE: {sum(all_clip_scores) / len(all_clip_scores)}") 248 | logger.info(f"Test results saved in: {self.output_file}") 249 | -------------------------------------------------------------------------------- /model/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anonymous0x233/ReuseAndDiffuse/3976e9431ca9445bac5ec16f738ffec33e6f188f/model/modules/__init__.py -------------------------------------------------------------------------------- /model/modules/attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py 2 | from dataclasses import dataclass 3 | from typing import Callable, Optional 4 | 5 | import torch 6 | from diffusers.configuration_utils import ConfigMixin, register_to_config 7 | from diffusers.models import ModelMixin 8 | from diffusers.models.attention import FeedForward, AdaLayerNorm 9 | from diffusers.models.cross_attention import CrossAttention 10 | from diffusers.utils import BaseOutput 11 | from diffusers.utils.import_utils import is_xformers_available 12 | from einops import rearrange, repeat 13 | from torch import nn 14 | 15 | from .modules import get_sin_pos_embedding 16 | from .utils import zero_module 17 | 18 | if is_xformers_available(): 19 | import xformers 20 | import xformers.ops 21 | else: 22 | xformers = None 23 | 24 | 25 | class BasicTransformerBlock(nn.Module): 26 | r""" 27 | A basic Transformer block. 28 | Parameters: 29 | dim (`int`): The number of channels in the input and output. 30 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 31 | attention_head_dim (`int`): The number of channels in each head. 32 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 33 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 34 | only_cross_attention (`bool`, *optional*): 35 | Whether to use only cross-attention layers. In this case two cross attention layers are used. 
36 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 37 | num_embeds_ada_norm (: 38 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 39 | attention_bias (: 40 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 41 | """ 42 | 43 | def __init__( 44 | self, 45 | dim: int, 46 | num_attention_heads: int, 47 | attention_head_dim: int, 48 | dropout=0.0, 49 | cross_attention_dim: Optional[int] = None, 50 | activation_fn: str = "geglu", 51 | num_embeds_ada_norm: Optional[int] = None, 52 | attention_bias: bool = False, 53 | only_cross_attention: bool = False, 54 | upcast_attention: bool = False, 55 | norm_elementwise_affine: bool = True, 56 | final_dropout: bool = False, 57 | add_temp_attn: bool = False, 58 | prepend_first_frame: bool = False, 59 | add_temp_embed: bool = False, 60 | ): 61 | super().__init__() 62 | self.only_cross_attention = only_cross_attention 63 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 64 | 65 | # temporal embedding 66 | self.add_temp_embed = add_temp_embed 67 | 68 | if add_temp_attn: 69 | if prepend_first_frame: 70 | # SC-Attn 71 | self.attn1 = SparseCausalAttention( 72 | query_dim=dim, 73 | cross_attention_dim=cross_attention_dim 74 | if only_cross_attention 75 | else None, 76 | heads=num_attention_heads, 77 | dim_head=attention_head_dim, 78 | dropout=dropout, 79 | bias=attention_bias, 80 | upcast_attention=upcast_attention, 81 | ) 82 | else: 83 | # Normal CrossAttn 84 | self.attn1 = CrossAttention( 85 | query_dim=dim, 86 | cross_attention_dim=cross_attention_dim 87 | if only_cross_attention 88 | else None, 89 | heads=num_attention_heads, 90 | dim_head=attention_head_dim, 91 | dropout=dropout, 92 | bias=attention_bias, 93 | upcast_attention=upcast_attention, 94 | ) 95 | 96 | # Temp-Attn 97 | self.temp_norm = ( 98 | AdaLayerNorm(dim, num_embeds_ada_norm) 99 | if self.use_ada_layer_norm 100 | else nn.LayerNorm(dim) 101 | ) 102 | self.temp_attn = CrossAttention( 103 | query_dim=dim, 104 | heads=num_attention_heads, 105 | dim_head=attention_head_dim, 106 | dropout=dropout, 107 | bias=attention_bias, 108 | upcast_attention=upcast_attention, 109 | ) 110 | zero_module(self.temp_attn.to_out) 111 | else: 112 | # Normal Attention 113 | self.attn1 = CrossAttention( 114 | query_dim=dim, 115 | cross_attention_dim=cross_attention_dim 116 | if only_cross_attention 117 | else None, 118 | heads=num_attention_heads, 119 | dim_head=attention_head_dim, 120 | dropout=dropout, 121 | bias=attention_bias, 122 | upcast_attention=upcast_attention, 123 | ) 124 | self.temp_attn = None 125 | 126 | self.norm1 = ( 127 | AdaLayerNorm(dim, num_embeds_ada_norm) 128 | if self.use_ada_layer_norm 129 | else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 130 | ) 131 | 132 | # Cross-Attn 133 | if cross_attention_dim is not None: 134 | self.attn2 = CrossAttention( 135 | query_dim=dim, 136 | cross_attention_dim=cross_attention_dim, 137 | heads=num_attention_heads, 138 | dim_head=attention_head_dim, 139 | dropout=dropout, 140 | bias=attention_bias, 141 | upcast_attention=upcast_attention, 142 | ) 143 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 144 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 145 | # the second cross attention block. 
146 | self.norm2 = ( 147 | AdaLayerNorm(dim, num_embeds_ada_norm) 148 | if self.use_ada_layer_norm 149 | else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 150 | ) 151 | else: 152 | self.attn2 = None 153 | self.norm2 = None 154 | 155 | # Feed-forward 156 | self.ff = FeedForward( 157 | dim, 158 | dropout=dropout, 159 | activation_fn=activation_fn, 160 | final_dropout=final_dropout, 161 | ) 162 | self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 163 | 164 | def set_use_memory_efficient_attention_xformers( 165 | self, 166 | use_memory_efficient_attention_xformers: bool, 167 | attention_op: Optional[Callable] = None, 168 | ): 169 | if not is_xformers_available(): 170 | print("Here is how to install it") 171 | raise ModuleNotFoundError( 172 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 173 | " xformers", 174 | name="xformers", 175 | ) 176 | elif not torch.cuda.is_available(): 177 | raise ValueError( 178 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" 179 | " available for GPU " 180 | ) 181 | else: 182 | try: 183 | # Make sure we can run the memory efficient attention 184 | xformers.ops.memory_efficient_attention( 185 | torch.randn((1, 2, 40), device="cuda"), 186 | torch.randn((1, 2, 40), device="cuda"), 187 | torch.randn((1, 2, 40), device="cuda"), 188 | ) 189 | except Exception as e: 190 | raise e 191 | self.attn1.set_use_memory_efficient_attention_xformers( 192 | use_memory_efficient_attention_xformers, attention_op=attention_op 193 | ) 194 | if self.attn2 is not None: 195 | self.attn2.set_use_memory_efficient_attention_xformers( 196 | use_memory_efficient_attention_xformers, attention_op=attention_op 197 | ) 198 | if self.temp_attn is not None: 199 | self.temp_attn.set_use_memory_efficient_attention_xformers( 200 | use_memory_efficient_attention_xformers, attention_op=attention_op 201 | ) 202 | 203 | def forward( 204 | self, 205 | hidden_states, 206 | encoder_hidden_states=None, 207 | timestep=None, 208 | attention_mask=None, 209 | video_length=None, 210 | ): 211 | # SparseCausal-Attention 212 | norm_hidden_states = ( 213 | self.norm1(hidden_states, timestep) 214 | if self.use_ada_layer_norm 215 | else self.norm1(hidden_states) 216 | ) 217 | 218 | attn1_args = dict( 219 | hidden_states=norm_hidden_states, attention_mask=attention_mask 220 | ) 221 | if self.temp_attn is not None and isinstance(self.attn1, SparseCausalAttention): 222 | attn1_args.update({"video_length": video_length}) 223 | # Self-/Sparse-Attention 224 | if self.only_cross_attention: 225 | hidden_states = ( 226 | self.attn1(**attn1_args, encoder_hidden_states=encoder_hidden_states) 227 | + hidden_states 228 | ) 229 | else: 230 | hidden_states = self.attn1(**attn1_args) + hidden_states 231 | 232 | if self.attn2 is not None: 233 | # Cross-Attention 234 | norm_hidden_states = ( 235 | self.norm2(hidden_states, timestep) 236 | if self.use_ada_layer_norm 237 | else self.norm2(hidden_states) 238 | ) 239 | hidden_states = ( 240 | self.attn2( 241 | norm_hidden_states, 242 | encoder_hidden_states=encoder_hidden_states, 243 | attention_mask=attention_mask, 244 | ) 245 | + hidden_states 246 | ) 247 | 248 | if self.temp_attn is not None: 249 | identity = hidden_states 250 | d = hidden_states.shape[1] 251 | # add temporal embedding 252 | if self.add_temp_embed: 253 | temp_emb = get_sin_pos_embedding( 254 | hidden_states.shape[-1], video_length 255 | ).to(hidden_states) 256 | hidden_states = 
rearrange( 257 | hidden_states, "(b f) d c -> b d f c", f=video_length 258 | ) 259 | hidden_states += temp_emb 260 | hidden_states = rearrange(hidden_states, "b d f c -> (b f) d c") 261 | # normalization 262 | hidden_states = rearrange( 263 | hidden_states, "(b f) d c -> (b d) f c", f=video_length 264 | ) 265 | norm_hidden_states = ( 266 | self.temp_norm(hidden_states, timestep) 267 | if self.use_ada_layer_norm 268 | else self.temp_norm(hidden_states) 269 | ) 270 | # apply temporal attention 271 | hidden_states = self.temp_attn(norm_hidden_states) + hidden_states 272 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 273 | # ignore effects of temporal layers on image inputs 274 | if video_length <= 1: 275 | hidden_states = identity + 0.0 * hidden_states 276 | 277 | # Feed-forward 278 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 279 | 280 | return hidden_states 281 | 282 | 283 | @dataclass 284 | class Transformer3DModelOutput(BaseOutput): 285 | sample: torch.FloatTensor 286 | 287 | 288 | class Transformer3DModel(ModelMixin, ConfigMixin): 289 | @register_to_config 290 | def __init__( 291 | self, 292 | num_attention_heads: int = 16, 293 | attention_head_dim: int = 88, 294 | in_channels: Optional[int] = None, 295 | num_layers: int = 1, 296 | dropout: float = 0.0, 297 | norm_num_groups: int = 32, 298 | cross_attention_dim: Optional[int] = None, 299 | attention_bias: bool = False, 300 | activation_fn: str = "geglu", 301 | num_embeds_ada_norm: Optional[int] = None, 302 | use_linear_projection: bool = False, 303 | only_cross_attention: bool = False, 304 | upcast_attention: bool = False, 305 | add_temp_attn: bool = False, 306 | prepend_first_frame: bool = False, 307 | add_temp_embed: bool = False, 308 | ): 309 | super().__init__() 310 | self.use_linear_projection = use_linear_projection 311 | self.num_attention_heads = num_attention_heads 312 | self.attention_head_dim = attention_head_dim 313 | inner_dim = num_attention_heads * attention_head_dim 314 | 315 | # Define input layers 316 | self.in_channels = in_channels 317 | 318 | self.norm = torch.nn.GroupNorm( 319 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 320 | ) 321 | if use_linear_projection: 322 | self.proj_in = nn.Linear(in_channels, inner_dim) 323 | else: 324 | self.proj_in = nn.Conv2d( 325 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 326 | ) 327 | 328 | # Define transformers blocks 329 | self.transformer_blocks = nn.ModuleList( 330 | [ 331 | BasicTransformerBlock( 332 | inner_dim, 333 | num_attention_heads, 334 | attention_head_dim, 335 | dropout=dropout, 336 | cross_attention_dim=cross_attention_dim, 337 | activation_fn=activation_fn, 338 | num_embeds_ada_norm=num_embeds_ada_norm, 339 | attention_bias=attention_bias, 340 | only_cross_attention=only_cross_attention, 341 | upcast_attention=upcast_attention, 342 | add_temp_attn=add_temp_attn, 343 | prepend_first_frame=prepend_first_frame, 344 | add_temp_embed=add_temp_embed, 345 | ) 346 | for _ in range(num_layers) 347 | ] 348 | ) 349 | 350 | # 4. 
Define output layers 351 | if use_linear_projection: 352 | self.proj_out = nn.Linear(in_channels, inner_dim) 353 | else: 354 | self.proj_out = nn.Conv2d( 355 | inner_dim, in_channels, kernel_size=1, stride=1, padding=0 356 | ) 357 | 358 | def forward( 359 | self, 360 | hidden_states, 361 | encoder_hidden_states=None, 362 | timestep=None, 363 | return_dict=False, 364 | ): 365 | # Input 366 | assert ( 367 | hidden_states.dim() == 5 368 | ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 369 | video_length = hidden_states.shape[2] 370 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 371 | encoder_hidden_states = repeat( 372 | encoder_hidden_states, "b n c -> (b f) n c", f=video_length 373 | ) 374 | 375 | batch, channel, height, weight = hidden_states.shape 376 | residual = hidden_states 377 | 378 | hidden_states = self.norm(hidden_states) 379 | if not self.use_linear_projection: 380 | hidden_states = self.proj_in(hidden_states) 381 | inner_dim = hidden_states.shape[1] 382 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 383 | batch, height * weight, inner_dim 384 | ) 385 | else: 386 | inner_dim = hidden_states.shape[1] 387 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( 388 | batch, height * weight, inner_dim 389 | ) 390 | hidden_states = self.proj_in(hidden_states) 391 | 392 | # Blocks 393 | for block in self.transformer_blocks: 394 | hidden_states = block( 395 | hidden_states, 396 | encoder_hidden_states=encoder_hidden_states, 397 | timestep=timestep, 398 | video_length=video_length, 399 | ) 400 | 401 | # Output 402 | if not self.use_linear_projection: 403 | hidden_states = ( 404 | hidden_states.reshape(batch, height, weight, inner_dim) 405 | .permute(0, 3, 1, 2) 406 | .contiguous() 407 | ) 408 | hidden_states = self.proj_out(hidden_states) 409 | else: 410 | hidden_states = self.proj_out(hidden_states) 411 | hidden_states = ( 412 | hidden_states.reshape(batch, height, weight, inner_dim) 413 | .permute(0, 3, 1, 2) 414 | .contiguous() 415 | ) 416 | 417 | output = hidden_states + residual 418 | 419 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 420 | if not return_dict: 421 | return (output,) 422 | 423 | return Transformer3DModelOutput(sample=output) 424 | 425 | 426 | @dataclass 427 | class TransformerTemporalModelOutput(BaseOutput): 428 | """ 429 | Args: 430 | sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`) 431 | Hidden states conditioned on `encoder_hidden_states` input. 432 | """ 433 | 434 | sample: torch.FloatTensor 435 | 436 | 437 | class TransformerTemporalModel(ModelMixin, ConfigMixin): 438 | """ 439 | Transformer model for video-like data. 440 | 441 | Parameters: 442 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 443 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 444 | in_channels (`int`, *optional*): 445 | Pass if the input is continuous. The number of channels in the input and output. 446 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 447 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 448 | cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. 449 | sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. 
450 | Note that this is fixed at training time as it is used for learning a number of position embeddings. See 451 | `ImagePositionalEmbeddings`. 452 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 453 | attention_bias (`bool`, *optional*): 454 | Configure if the TransformerBlocks' attention should contain a bias parameter. 455 | """ 456 | 457 | @register_to_config 458 | def __init__( 459 | self, 460 | num_attention_heads: int = 16, 461 | attention_head_dim: int = 88, 462 | in_channels: Optional[int] = None, 463 | num_layers: int = 1, 464 | dropout: float = 0.0, 465 | norm_num_groups: int = 32, 466 | cross_attention_dim: Optional[int] = None, 467 | attention_bias: bool = False, 468 | activation_fn: str = "geglu", 469 | norm_elementwise_affine: bool = True, 470 | add_temp_embed: bool = False, 471 | ): 472 | super().__init__() 473 | self.num_attention_heads = num_attention_heads 474 | self.attention_head_dim = attention_head_dim 475 | inner_dim = num_attention_heads * attention_head_dim 476 | 477 | self.in_channels = in_channels 478 | 479 | self.norm = torch.nn.GroupNorm( 480 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True 481 | ) 482 | self.proj_in = nn.Linear(in_channels, inner_dim) 483 | 484 | # 3. Define transformers blocks 485 | self.transformer_blocks = nn.ModuleList( 486 | [ 487 | BasicTransformerBlock( 488 | inner_dim, 489 | num_attention_heads, 490 | attention_head_dim, 491 | dropout=dropout, 492 | cross_attention_dim=cross_attention_dim, 493 | activation_fn=activation_fn, 494 | attention_bias=attention_bias, 495 | norm_elementwise_affine=norm_elementwise_affine, 496 | add_temp_embed=add_temp_embed, 497 | ) 498 | for _ in range(num_layers) 499 | ] 500 | ) 501 | 502 | self.proj_out = nn.Linear(inner_dim, in_channels) 503 | self.proj_out = zero_module(self.proj_out) 504 | 505 | def forward( 506 | self, 507 | hidden_states, 508 | encoder_hidden_states=None, 509 | timestep=None, 510 | return_dict: bool = True, 511 | ): 512 | """ 513 | Args: 514 | hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. 515 | When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input 516 | hidden_states 517 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): 518 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 519 | self-attention. 520 | timestep ( `torch.long`, *optional*): 521 | Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. 522 | return_dict (`bool`, *optional*, defaults to `True`): 523 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 524 | 525 | Returns: 526 | [`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`: 527 | [`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`. 528 | When returning a tuple, the first element is the sample tensor. 529 | """ 530 | # 1. Input 531 | batch_size, channel, num_frames, height, width = hidden_states.shape 532 | 533 | residual = hidden_states 534 | 535 | hidden_states = self.norm(hidden_states) 536 | hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape( 537 | batch_size * height * width, num_frames, channel 538 | ) 539 | hidden_states = self.proj_in(hidden_states) 540 | 541 | # 2. 
Blocks 542 | for block in self.transformer_blocks: 543 | hidden_states = block( 544 | hidden_states, 545 | encoder_hidden_states=encoder_hidden_states, 546 | timestep=timestep, 547 | video_length=num_frames, 548 | ) 549 | 550 | # 3. Output 551 | hidden_states = self.proj_out(hidden_states) 552 | hidden_states = ( 553 | hidden_states[None, None, :] 554 | .reshape(batch_size, height, width, channel, num_frames) 555 | .permute(0, 3, 4, 1, 2) 556 | .contiguous() 557 | ) 558 | output = hidden_states + residual 559 | 560 | if not return_dict: 561 | return (output,) 562 | 563 | return TransformerTemporalModelOutput(sample=output) 564 | 565 | 566 | class SparseCausalAttention(CrossAttention): 567 | def forward( 568 | self, 569 | hidden_states, 570 | encoder_hidden_states=None, 571 | attention_mask=None, 572 | **cross_attention_kwargs, 573 | ): 574 | batch_size, sequence_length, _ = hidden_states.shape 575 | video_length = cross_attention_kwargs.get("video_length", 8) 576 | attention_mask = self.prepare_attention_mask( 577 | attention_mask, sequence_length, batch_size 578 | ) 579 | query = self.to_q(hidden_states) 580 | dim = query.shape[-1] 581 | 582 | if self.added_kv_proj_dim is not None: 583 | raise NotImplementedError 584 | 585 | if encoder_hidden_states is None: 586 | encoder_hidden_states = hidden_states 587 | elif self.cross_attention_norm: 588 | encoder_hidden_states = self.norm_cross(encoder_hidden_states) 589 | 590 | key = self.to_k(encoder_hidden_states) 591 | value = self.to_v(encoder_hidden_states) 592 | 593 | former_frame_index = torch.arange(video_length) - 1 594 | former_frame_index[0] = 0 595 | 596 | key = rearrange(key, "(b f) d c -> b f d c", f=video_length) 597 | if video_length > 1: 598 | key = torch.cat( 599 | [key[:, [0] * video_length], key[:, former_frame_index]], dim=2 600 | ) 601 | key = rearrange(key, "b f d c -> (b f) d c") 602 | 603 | value = rearrange(value, "(b f) d c -> b f d c", f=video_length) 604 | if video_length > 1: 605 | value = torch.cat( 606 | [value[:, [0] * video_length], value[:, former_frame_index]], dim=2 607 | ) 608 | value = rearrange(value, "b f d c -> (b f) d c") 609 | 610 | query = self.head_to_batch_dim(query) 611 | key = self.head_to_batch_dim(key) 612 | value = self.head_to_batch_dim(value) 613 | 614 | # attention, what we cannot get enough of 615 | if hasattr(self.processor, "attention_op"): 616 | hidden_states = xformers.ops.memory_efficient_attention( 617 | query, 618 | key, 619 | value, 620 | attn_bias=attention_mask, 621 | op=self.processor.attention_op, 622 | ) 623 | hidden_states = hidden_states.to(query.dtype) 624 | elif hasattr(self.processor, "slice_size"): 625 | batch_size_attention = query.shape[0] 626 | hidden_states = torch.zeros( 627 | (batch_size_attention, sequence_length, dim // self.heads), 628 | device=query.device, 629 | dtype=query.dtype, 630 | ) 631 | for i in range(hidden_states.shape[0] // self.processor.slice_size): 632 | start_idx = i * self.slice_size 633 | end_idx = (i + 1) * self.slice_size 634 | query_slice = query[start_idx:end_idx] 635 | key_slice = key[start_idx:end_idx] 636 | attn_mask_slice = ( 637 | attention_mask[start_idx:end_idx] 638 | if attention_mask is not None 639 | else None 640 | ) 641 | attn_slice = self.get_attention_scores( 642 | query_slice, key_slice, attn_mask_slice 643 | ) 644 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) 645 | hidden_states[start_idx:end_idx] = attn_slice 646 | else: 647 | attention_probs = self.get_attention_scores(query, key, attention_mask) 648 | 
hidden_states = torch.bmm(attention_probs, value) 649 | hidden_states = self.batch_to_head_dim(hidden_states) 650 | 651 | # linear proj 652 | hidden_states = self.to_out[0](hidden_states) 653 | 654 | # dropout 655 | hidden_states = self.to_out[1](hidden_states) 656 | return hidden_states 657 | -------------------------------------------------------------------------------- /model/modules/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | 6 | def get_sin_pos_embedding(embed_dim, seq_len): 7 | """ 8 | :param embed_dim: dimension of the model 9 | :param seq_len: length of positions 10 | :return: [length, embed_dim] position matrix 11 | """ 12 | if embed_dim % 2 != 0: 13 | raise ValueError( 14 | "Cannot use sin/cos positional encoding with " 15 | "odd dim (got dim={:d})".format(embed_dim) 16 | ) 17 | pe = torch.zeros(seq_len, embed_dim) 18 | position = torch.arange(0, seq_len).unsqueeze(1) 19 | div_term = torch.exp( 20 | torch.arange(0, embed_dim, 2, dtype=torch.float) 21 | * -(math.log(10000.0) / embed_dim) 22 | ) 23 | pe[:, 0::2] = torch.sin(position.float() * div_term) 24 | pe[:, 1::2] = torch.cos(position.float() * div_term) 25 | 26 | return pe 27 | -------------------------------------------------------------------------------- /model/modules/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange 8 | 9 | 10 | class InflatedConv3d(nn.Conv2d): 11 | def forward(self, x): 12 | video_length = x.shape[2] 13 | 14 | x = rearrange(x, "b c f h w -> (b f) c h w") 15 | x = super().forward(x) 16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 17 | return x 18 | 19 | 20 | class Upsample3D(nn.Module): 21 | def __init__( 22 | self, 23 | channels, 24 | use_conv=False, 25 | use_conv_transpose=False, 26 | out_channels=None, 27 | name="conv", 28 | ): 29 | super().__init__() 30 | self.channels = channels 31 | self.out_channels = out_channels or channels 32 | self.use_conv = use_conv 33 | self.use_conv_transpose = use_conv_transpose 34 | self.name = name 35 | 36 | conv = None 37 | if use_conv_transpose: 38 | raise NotImplementedError 39 | elif use_conv: 40 | conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 41 | 42 | if name == "conv": 43 | self.conv = conv 44 | else: 45 | self.Conv2d_0 = conv 46 | 47 | def forward(self, hidden_states, output_size=None): 48 | assert hidden_states.shape[1] == self.channels 49 | 50 | if self.use_conv_transpose: 51 | raise NotImplementedError 52 | 53 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 54 | dtype = hidden_states.dtype 55 | if dtype == torch.bfloat16: 56 | hidden_states = hidden_states.to(torch.float32) 57 | 58 | # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 59 | if hidden_states.shape[0] >= 64: 60 | hidden_states = hidden_states.contiguous() 61 | 62 | # if `output_size` is passed we force the interpolation output 63 | # size and do not make use of `scale_factor=2` 64 | if output_size is None: 65 | hidden_states = F.interpolate( 66 | hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest" 67 | ) 68 | else: 69 | hidden_states = F.interpolate( 70 | hidden_states, size=output_size, mode="nearest" 71 | ) 72 | 73 | # If the input is bfloat16, we cast back to bfloat16 74 | if dtype == torch.bfloat16: 75 | hidden_states = hidden_states.to(dtype) 76 | 77 | if self.use_conv: 78 | if self.name == "conv": 79 | hidden_states = self.conv(hidden_states) 80 | else: 81 | hidden_states = self.Conv2d_0(hidden_states) 82 | 83 | return hidden_states 84 | 85 | 86 | class Downsample3D(nn.Module): 87 | def __init__( 88 | self, channels, use_conv=False, out_channels=None, padding=1, name="conv" 89 | ): 90 | super().__init__() 91 | self.channels = channels 92 | self.out_channels = out_channels or channels 93 | self.use_conv = use_conv 94 | self.padding = padding 95 | stride = 2 96 | self.name = name 97 | 98 | if use_conv: 99 | conv = InflatedConv3d( 100 | self.channels, self.out_channels, 3, stride=stride, padding=padding 101 | ) 102 | else: 103 | raise NotImplementedError 104 | 105 | if name == "conv": 106 | self.Conv2d_0 = conv 107 | self.conv = conv 108 | elif name == "Conv2d_0": 109 | self.conv = conv 110 | else: 111 | self.conv = conv 112 | 113 | def forward(self, hidden_states): 114 | assert hidden_states.shape[1] == self.channels 115 | if self.use_conv and self.padding == 0: 116 | raise NotImplementedError 117 | 118 | assert hidden_states.shape[1] == self.channels 119 | hidden_states = self.conv(hidden_states) 120 | 121 | return hidden_states 122 | 123 | 124 | class ResnetBlock3D(nn.Module): 125 | def __init__( 126 | self, 127 | *, 128 | in_channels, 129 | out_channels=None, 130 | conv_shortcut=False, 131 | dropout=0.0, 132 | temb_channels=512, 133 | groups=32, 134 | groups_out=None, 135 | pre_norm=True, 136 | eps=1e-6, 137 | non_linearity="swish", 138 | time_embedding_norm="default", 139 | output_scale_factor=1.0, 140 | use_in_shortcut=None, 141 | ): 142 | super().__init__() 143 | self.pre_norm = pre_norm 144 | self.pre_norm = True 145 | self.in_channels = in_channels 146 | out_channels = in_channels if out_channels is None else out_channels 147 | self.out_channels = out_channels 148 | self.use_conv_shortcut = conv_shortcut 149 | self.time_embedding_norm = time_embedding_norm 150 | self.output_scale_factor = output_scale_factor 151 | 152 | if groups_out is None: 153 | groups_out = groups 154 | 155 | self.norm1 = torch.nn.GroupNorm( 156 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True 157 | ) 158 | 159 | self.conv1 = InflatedConv3d( 160 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 161 | ) 162 | 163 | if temb_channels is not None: 164 | if self.time_embedding_norm == "default": 165 | time_emb_proj_out_channels = out_channels 166 | elif self.time_embedding_norm == "scale_shift": 167 | time_emb_proj_out_channels = out_channels * 2 168 | else: 169 | raise ValueError( 170 | f"unknown time_embedding_norm : {self.time_embedding_norm} " 171 | ) 172 | 173 | self.time_emb_proj = torch.nn.Linear( 174 | temb_channels, time_emb_proj_out_channels 175 | ) 176 | else: 177 | self.time_emb_proj = None 178 | 179 | self.norm2 = torch.nn.GroupNorm( 180 | num_groups=groups_out, 
num_channels=out_channels, eps=eps, affine=True 181 | ) 182 | self.dropout = torch.nn.Dropout(dropout) 183 | self.conv2 = InflatedConv3d( 184 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 185 | ) 186 | 187 | if non_linearity == "swish": 188 | self.nonlinearity = lambda x: F.silu(x) 189 | elif non_linearity == "mish": 190 | self.nonlinearity = Mish() 191 | elif non_linearity == "silu": 192 | self.nonlinearity = nn.SiLU() 193 | 194 | self.use_in_shortcut = ( 195 | self.in_channels != self.out_channels 196 | if use_in_shortcut is None 197 | else use_in_shortcut 198 | ) 199 | 200 | self.conv_shortcut = None 201 | if self.use_in_shortcut: 202 | self.conv_shortcut = InflatedConv3d( 203 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 204 | ) 205 | 206 | def forward(self, input_tensor, temb): 207 | hidden_states = input_tensor 208 | 209 | hidden_states = self.norm1(hidden_states) 210 | hidden_states = self.nonlinearity(hidden_states) 211 | 212 | hidden_states = self.conv1(hidden_states) 213 | 214 | if temb is not None: 215 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 216 | 217 | if temb is not None and self.time_embedding_norm == "default": 218 | hidden_states = hidden_states + temb 219 | 220 | hidden_states = self.norm2(hidden_states) 221 | 222 | if temb is not None and self.time_embedding_norm == "scale_shift": 223 | scale, shift = torch.chunk(temb, 2, dim=1) 224 | hidden_states = hidden_states * (1 + scale) + shift 225 | 226 | hidden_states = self.nonlinearity(hidden_states) 227 | 228 | hidden_states = self.dropout(hidden_states) 229 | hidden_states = self.conv2(hidden_states) 230 | 231 | if self.conv_shortcut is not None: 232 | input_tensor = self.conv_shortcut(input_tensor) 233 | 234 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 235 | 236 | return output_tensor 237 | 238 | 239 | class Mish(torch.nn.Module): 240 | def forward(self, hidden_states): 241 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 242 | -------------------------------------------------------------------------------- /model/modules/unet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py 2 | 3 | import json 4 | import os 5 | from dataclasses import dataclass 6 | from typing import List, Optional, Tuple, Union 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.checkpoint 11 | from diffusers.configuration_utils import ConfigMixin, register_to_config 12 | from diffusers.models import ModelMixin 13 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 14 | from diffusers.utils import BaseOutput, logging 15 | 16 | from .attention import TransformerTemporalModel 17 | from .resnet import InflatedConv3d 18 | from .unet_blocks import ( 19 | CrossAttnDownBlock3D, 20 | CrossAttnUpBlock3D, 21 | DownBlock3D, 22 | UNetMidBlock3DCrossAttn, 23 | UpBlock3D, 24 | get_down_block, 25 | get_up_block, 26 | ) 27 | 28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 29 | 30 | 31 | @dataclass 32 | class UNet3DConditionOutput(BaseOutput): 33 | sample: torch.FloatTensor 34 | 35 | 36 | class UNet3DConditionModel(ModelMixin, ConfigMixin): 37 | _supports_gradient_checkpointing = True 38 | 39 | @register_to_config 40 | def __init__( 41 | self, 42 | sample_size: Optional[int] = None, 43 | in_channels: int = 4, 44 | out_channels: 
int = 4, 45 | center_input_sample: bool = False, 46 | flip_sin_to_cos: bool = True, 47 | freq_shift: int = 0, 48 | down_block_types: Tuple[str] = ( 49 | "CrossAttnDownBlock3D", 50 | "CrossAttnDownBlock3D", 51 | "CrossAttnDownBlock3D", 52 | "DownBlock3D", 53 | ), 54 | mid_block_type: str = "UNetMidBlock3DCrossAttn", 55 | up_block_types: Tuple[str] = ( 56 | "UpBlock3D", 57 | "CrossAttnUpBlock3D", 58 | "CrossAttnUpBlock3D", 59 | "CrossAttnUpBlock3D", 60 | ), 61 | only_cross_attention: Union[bool, Tuple[bool]] = False, 62 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 63 | layers_per_block: int = 2, 64 | downsample_padding: int = 1, 65 | mid_block_scale_factor: float = 1, 66 | act_fn: str = "silu", 67 | norm_num_groups: int = 32, 68 | norm_eps: float = 1e-5, 69 | cross_attention_dim: int = 1280, 70 | attention_head_dim: Union[int, Tuple[int]] = 8, 71 | dual_cross_attention: bool = False, 72 | use_linear_projection: bool = False, 73 | class_embed_type: Optional[str] = None, 74 | num_class_embeds: Optional[int] = None, 75 | upcast_attention: bool = False, 76 | resnet_time_scale_shift: str = "default", 77 | add_temp_transformer: bool = False, 78 | add_temp_attn_only_on_upblocks: bool = False, 79 | prepend_first_frame: bool = False, 80 | add_temp_embed: bool = False, 81 | add_temp_conv: bool = False, 82 | ): 83 | super().__init__() 84 | 85 | self.sample_size = sample_size 86 | time_embed_dim = block_out_channels[0] * 4 87 | 88 | # input 89 | self.conv_in = InflatedConv3d( 90 | in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1) 91 | ) 92 | 93 | self.temp_transformer = ( 94 | TransformerTemporalModel( 95 | num_attention_heads=8, 96 | attention_head_dim=64, 97 | in_channels=block_out_channels[0], 98 | num_layers=1, 99 | add_temp_embed=add_temp_embed, 100 | ) 101 | if add_temp_transformer 102 | else None 103 | ) 104 | 105 | # time 106 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 107 | timestep_input_dim = block_out_channels[0] 108 | 109 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 110 | 111 | # class embedding 112 | if class_embed_type is None and num_class_embeds is not None: 113 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 114 | elif class_embed_type == "timestep": 115 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 116 | elif class_embed_type == "identity": 117 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 118 | else: 119 | self.class_embedding = None 120 | 121 | self.down_blocks = nn.ModuleList([]) 122 | self.mid_block = None 123 | self.up_blocks = nn.ModuleList([]) 124 | 125 | if isinstance(only_cross_attention, bool): 126 | only_cross_attention = [only_cross_attention] * len(down_block_types) 127 | 128 | if isinstance(attention_head_dim, int): 129 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 130 | 131 | # down 132 | output_channel = block_out_channels[0] 133 | for i, down_block_type in enumerate(down_block_types): 134 | input_channel = output_channel 135 | output_channel = block_out_channels[i] 136 | is_final_block = i == len(block_out_channels) - 1 137 | 138 | down_block = get_down_block( 139 | down_block_type, 140 | num_layers=layers_per_block, 141 | in_channels=input_channel, 142 | out_channels=output_channel, 143 | temb_channels=time_embed_dim, 144 | add_downsample=not is_final_block, 145 | resnet_eps=norm_eps, 146 | resnet_act_fn=act_fn, 147 | resnet_groups=norm_num_groups, 148 | 
cross_attention_dim=cross_attention_dim, 149 | attn_num_head_channels=attention_head_dim[i], 150 | downsample_padding=downsample_padding, 151 | dual_cross_attention=dual_cross_attention, 152 | use_linear_projection=use_linear_projection, 153 | only_cross_attention=only_cross_attention[i], 154 | upcast_attention=upcast_attention, 155 | resnet_time_scale_shift=resnet_time_scale_shift, 156 | add_temp_attn=not add_temp_attn_only_on_upblocks, 157 | prepend_first_frame=prepend_first_frame, 158 | add_temp_embed=add_temp_embed, 159 | add_temp_conv=add_temp_conv, 160 | ) 161 | self.down_blocks.append(down_block) 162 | 163 | # mid 164 | if mid_block_type == "UNetMidBlock3DCrossAttn": 165 | self.mid_block = UNetMidBlock3DCrossAttn( 166 | in_channels=block_out_channels[-1], 167 | temb_channels=time_embed_dim, 168 | resnet_eps=norm_eps, 169 | resnet_act_fn=act_fn, 170 | output_scale_factor=mid_block_scale_factor, 171 | resnet_time_scale_shift=resnet_time_scale_shift, 172 | cross_attention_dim=cross_attention_dim, 173 | attn_num_head_channels=attention_head_dim[-1], 174 | resnet_groups=norm_num_groups, 175 | dual_cross_attention=dual_cross_attention, 176 | use_linear_projection=use_linear_projection, 177 | upcast_attention=upcast_attention, 178 | add_temp_attn=not add_temp_attn_only_on_upblocks, 179 | prepend_first_frame=prepend_first_frame, 180 | add_temp_embed=add_temp_embed, 181 | add_temp_conv=add_temp_conv, 182 | ) 183 | else: 184 | raise ValueError(f"unknown mid_block_type : {mid_block_type}") 185 | 186 | # count how many layers upsample the videos 187 | self.num_upsamplers = 0 188 | 189 | # up 190 | reversed_block_out_channels = list(reversed(block_out_channels)) 191 | reversed_attention_head_dim = list(reversed(attention_head_dim)) 192 | only_cross_attention = list(reversed(only_cross_attention)) 193 | output_channel = reversed_block_out_channels[0] 194 | for i, up_block_type in enumerate(up_block_types): 195 | is_final_block = i == len(block_out_channels) - 1 196 | 197 | prev_output_channel = output_channel 198 | output_channel = reversed_block_out_channels[i] 199 | input_channel = reversed_block_out_channels[ 200 | min(i + 1, len(block_out_channels) - 1) 201 | ] 202 | 203 | # add upsample block for all BUT final layer 204 | if not is_final_block: 205 | add_upsample = True 206 | self.num_upsamplers += 1 207 | else: 208 | add_upsample = False 209 | 210 | up_block = get_up_block( 211 | up_block_type, 212 | num_layers=layers_per_block + 1, 213 | in_channels=input_channel, 214 | out_channels=output_channel, 215 | prev_output_channel=prev_output_channel, 216 | temb_channels=time_embed_dim, 217 | add_upsample=add_upsample, 218 | resnet_eps=norm_eps, 219 | resnet_act_fn=act_fn, 220 | resnet_groups=norm_num_groups, 221 | cross_attention_dim=cross_attention_dim, 222 | attn_num_head_channels=reversed_attention_head_dim[i], 223 | dual_cross_attention=dual_cross_attention, 224 | use_linear_projection=use_linear_projection, 225 | only_cross_attention=only_cross_attention[i], 226 | upcast_attention=upcast_attention, 227 | resnet_time_scale_shift=resnet_time_scale_shift, 228 | add_temp_attn=True, 229 | prepend_first_frame=prepend_first_frame, 230 | add_temp_embed=add_temp_embed, 231 | add_temp_conv=add_temp_conv, 232 | ) 233 | self.up_blocks.append(up_block) 234 | prev_output_channel = output_channel 235 | 236 | # out 237 | self.conv_norm_out = nn.GroupNorm( 238 | num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps 239 | ) 240 | self.conv_act = nn.SiLU() 241 | self.conv_out = 
InflatedConv3d( 242 | block_out_channels[0], out_channels, kernel_size=3, padding=1 243 | ) 244 | 245 | def set_attention_slice(self, slice_size): 246 | r""" 247 | Enable sliced attention computation. 248 | 249 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 250 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 251 | 252 | Args: 253 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 254 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If 255 | `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is 256 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 257 | must be a multiple of `slice_size`. 258 | """ 259 | sliceable_head_dims = [] 260 | 261 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): 262 | if hasattr(module, "set_attention_slice"): 263 | sliceable_head_dims.append(module.sliceable_head_dim) 264 | 265 | for child in module.children(): 266 | fn_recursive_retrieve_slicable_dims(child) 267 | 268 | # retrieve number of attention layers 269 | for module in self.children(): 270 | fn_recursive_retrieve_slicable_dims(module) 271 | 272 | num_slicable_layers = len(sliceable_head_dims) 273 | 274 | if slice_size == "auto": 275 | # half the attention head size is usually a good trade-off between 276 | # speed and memory 277 | slice_size = [dim // 2 for dim in sliceable_head_dims] 278 | elif slice_size == "max": 279 | # make smallest slice possible 280 | slice_size = num_slicable_layers * [1] 281 | 282 | slice_size = ( 283 | num_slicable_layers * [slice_size] 284 | if not isinstance(slice_size, list) 285 | else slice_size 286 | ) 287 | 288 | if len(slice_size) != len(sliceable_head_dims): 289 | raise ValueError( 290 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 291 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 292 | ) 293 | 294 | for i in range(len(slice_size)): 295 | size = slice_size[i] 296 | dim = sliceable_head_dims[i] 297 | if size is not None and size > dim: 298 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 299 | 300 | # Recursively walk through all the children. 
301 | # Any children which exposes the set_attention_slice method 302 | # gets the message 303 | def fn_recursive_set_attention_slice( 304 | module: torch.nn.Module, slice_size: List[int] 305 | ): 306 | if hasattr(module, "set_attention_slice"): 307 | module.set_attention_slice(slice_size.pop()) 308 | 309 | for child in module.children(): 310 | fn_recursive_set_attention_slice(child, slice_size) 311 | 312 | reversed_slice_size = list(reversed(slice_size)) 313 | for module in self.children(): 314 | fn_recursive_set_attention_slice(module, reversed_slice_size) 315 | 316 | def _set_gradient_checkpointing(self, module, value=False): 317 | if isinstance( 318 | module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D) 319 | ): 320 | module.gradient_checkpointing = value 321 | 322 | def forward( 323 | self, 324 | sample: torch.FloatTensor, 325 | timestep: Union[torch.Tensor, float, int], 326 | encoder_hidden_states: torch.Tensor, 327 | class_labels: Optional[torch.Tensor] = None, 328 | attention_mask: Optional[torch.Tensor] = None, 329 | return_dict: bool = True, 330 | ) -> Union[UNet3DConditionOutput, Tuple]: 331 | r""" 332 | Args: 333 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 334 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 335 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 336 | return_dict (`bool`, *optional*, defaults to `True`): 337 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 338 | 339 | Returns: 340 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 341 | [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When 342 | returning a tuple, the first element is the sample tensor. 343 | """ 344 | # By default samples have to be AT least a multiple of the overall upsampling factor. 345 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 346 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 347 | # on the fly if necessary. 
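# e.g. the default 4-level UNet here has 3 upsamplers, so the latent height/width
# must be divisible by 2**3 = 8; assuming the stock 8x-downsampling SD VAE, a
# 256x256 input yields a 32x32 latent, which satisfies this.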
348 | default_overall_up_factor = 2**self.num_upsamplers 349 | 350 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 351 | forward_upsample_size = False 352 | upsample_size = None 353 | 354 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 355 | logger.info("Forward upsample size to force interpolation output size.") 356 | forward_upsample_size = True 357 | 358 | # prepare attention_mask 359 | if attention_mask is not None: 360 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 361 | attention_mask = attention_mask.unsqueeze(1) 362 | 363 | # center input if necessary 364 | if self.config.center_input_sample: 365 | sample = 2 * sample - 1.0 366 | 367 | # time 368 | timesteps = timestep 369 | if not torch.is_tensor(timesteps): 370 | # This would be a good case for the `match` statement (Python 3.10+) 371 | is_mps = sample.device.type == "mps" 372 | if isinstance(timestep, float): 373 | dtype = torch.float32 if is_mps else torch.float64 374 | else: 375 | dtype = torch.int32 if is_mps else torch.int64 376 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 377 | elif len(timesteps.shape) == 0: 378 | timesteps = timesteps[None].to(sample.device) 379 | 380 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 381 | timesteps = timesteps.expand(sample.shape[0]) 382 | 383 | t_emb = self.time_proj(timesteps) 384 | 385 | # timesteps does not contain any weights and will always return f32 tensors 386 | # but time_embedding might actually be running in fp16. so we need to cast here. 387 | # there might be better ways to encapsulate this. 388 | t_emb = t_emb.to(dtype=self.dtype) 389 | emb = self.time_embedding(t_emb) 390 | 391 | if self.class_embedding is not None: 392 | if class_labels is None: 393 | raise ValueError( 394 | "class_labels should be provided when num_class_embeds > 0" 395 | ) 396 | 397 | if self.config.class_embed_type == "timestep": 398 | class_labels = self.time_proj(class_labels) 399 | 400 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 401 | emb = emb + class_emb 402 | 403 | # fp16: cast to model dtype 404 | sample = sample.to(self.dtype) 405 | encoder_hidden_states = encoder_hidden_states.to(self.dtype) 406 | 407 | # pre-process 408 | sample = self.conv_in(sample) 409 | if self.temp_transformer is not None: 410 | sample_new = self.temp_transformer(sample).sample 411 | sample = sample_new if sample.shape[2] > 1 else sample + 0.0 * sample_new 412 | 413 | # down 414 | down_block_res_samples = (sample,) 415 | for downsample_block in self.down_blocks: 416 | if ( 417 | hasattr(downsample_block, "has_cross_attention") 418 | and downsample_block.has_cross_attention 419 | ): 420 | sample, res_samples = downsample_block( 421 | hidden_states=sample, 422 | temb=emb, 423 | encoder_hidden_states=encoder_hidden_states, 424 | attention_mask=attention_mask, 425 | ) 426 | else: 427 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 428 | down_block_res_samples += res_samples 429 | 430 | # mid 431 | sample = self.mid_block( 432 | sample, 433 | emb, 434 | encoder_hidden_states=encoder_hidden_states, 435 | attention_mask=attention_mask, 436 | ) 437 | 438 | # up 439 | for i, upsample_block in enumerate(self.up_blocks): 440 | is_final_block = i == len(self.up_blocks) - 1 441 | 442 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 443 | down_block_res_samples = down_block_res_samples[ 444 | : 
-len(upsample_block.resnets) 445 | ] 446 | 447 | # if we have not reached the final block and need to forward the 448 | # upsample size, we do it here 449 | if not is_final_block and forward_upsample_size: 450 | upsample_size = down_block_res_samples[-1].shape[2:] 451 | 452 | if ( 453 | hasattr(upsample_block, "has_cross_attention") 454 | and upsample_block.has_cross_attention 455 | ): 456 | sample = upsample_block( 457 | hidden_states=sample, 458 | temb=emb, 459 | res_hidden_states_tuple=res_samples, 460 | encoder_hidden_states=encoder_hidden_states, 461 | upsample_size=upsample_size, 462 | attention_mask=attention_mask, 463 | ) 464 | else: 465 | sample = upsample_block( 466 | hidden_states=sample, 467 | temb=emb, 468 | res_hidden_states_tuple=res_samples, 469 | upsample_size=upsample_size, 470 | ) 471 | # post-process 472 | sample = self.conv_norm_out(sample) 473 | sample = self.conv_act(sample) 474 | sample = self.conv_out(sample) 475 | 476 | if not return_dict: 477 | return (sample,) 478 | 479 | return UNet3DConditionOutput(sample=sample) 480 | 481 | @classmethod 482 | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, **kwargs): 483 | if subfolder is not None: 484 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder) 485 | 486 | config_file = os.path.join(pretrained_model_path, "config.json") 487 | if not os.path.isfile(config_file): 488 | raise RuntimeError(f"{config_file} does not exist") 489 | with open(config_file, "r") as f: 490 | config = json.load(f) 491 | config["_class_name"] = cls.__name__ 492 | config["down_block_types"] = [ 493 | "CrossAttnDownBlock3D", 494 | "CrossAttnDownBlock3D", 495 | "CrossAttnDownBlock3D", 496 | "DownBlock3D", 497 | ] 498 | config["up_block_types"] = [ 499 | "UpBlock3D", 500 | "CrossAttnUpBlock3D", 501 | "CrossAttnUpBlock3D", 502 | "CrossAttnUpBlock3D", 503 | ] 504 | if "sample_size" in kwargs and "sample_size" in config: 505 | config["sample_size"] = kwargs.get("sample_size") 506 | 507 | from diffusers.utils import WEIGHTS_NAME 508 | 509 | model = cls.from_config( 510 | config, 511 | add_temp_transformer=kwargs.get("add_temp_transformer", False), 512 | add_temp_attn_only_on_upblocks=kwargs.get( 513 | "add_temp_attn_only_on_upblocks", False 514 | ), 515 | prepend_first_frame=kwargs.get("prepend_first_frame", False), 516 | add_temp_embed=kwargs.get("add_temp_embed", False), 517 | add_temp_conv=kwargs.get("add_temp_conv", False), 518 | ) 519 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) 520 | if not os.path.isfile(model_file): 521 | raise RuntimeError(f"{model_file} does not exist") 522 | state_dict = torch.load(model_file, map_location="cpu") 523 | for k, v in model.state_dict().items(): 524 | if "temp_" in k and k not in state_dict: 525 | state_dict.update({k: v}) 526 | model.load_state_dict(state_dict) 527 | 528 | return model 529 | -------------------------------------------------------------------------------- /model/modules/unet_blocks.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from .attention import Transformer3DModel 7 | from .resnet import Downsample3D, ResnetBlock3D, Upsample3D 8 | from .utils import checkpoint, zero_module 9 | 10 | 11 | def get_down_block( 12 | down_block_type, 13 | num_layers, 14 | in_channels, 15 | out_channels, 16 | temb_channels, 17 | add_downsample, 18 | resnet_eps, 
19 | resnet_act_fn, 20 | attn_num_head_channels, 21 | resnet_groups=None, 22 | cross_attention_dim=None, 23 | downsample_padding=None, 24 | dual_cross_attention=False, 25 | use_linear_projection=False, 26 | only_cross_attention=False, 27 | upcast_attention=False, 28 | resnet_time_scale_shift="default", 29 | add_temp_attn=False, 30 | prepend_first_frame=False, 31 | add_temp_embed=False, 32 | add_temp_conv=False, 33 | ): 34 | down_block_type = ( 35 | down_block_type[7:] 36 | if down_block_type.startswith("UNetRes") 37 | else down_block_type 38 | ) 39 | if down_block_type == "DownBlock3D": 40 | return DownBlock3D( 41 | num_layers=num_layers, 42 | in_channels=in_channels, 43 | out_channels=out_channels, 44 | temb_channels=temb_channels, 45 | add_downsample=add_downsample, 46 | resnet_eps=resnet_eps, 47 | resnet_act_fn=resnet_act_fn, 48 | resnet_groups=resnet_groups, 49 | downsample_padding=downsample_padding, 50 | resnet_time_scale_shift=resnet_time_scale_shift, 51 | add_temp_conv=add_temp_conv, 52 | ) 53 | elif down_block_type == "CrossAttnDownBlock3D": 54 | if cross_attention_dim is None: 55 | raise ValueError( 56 | "cross_attention_dim must be specified for CrossAttnDownBlock3D" 57 | ) 58 | return CrossAttnDownBlock3D( 59 | num_layers=num_layers, 60 | in_channels=in_channels, 61 | out_channels=out_channels, 62 | temb_channels=temb_channels, 63 | add_downsample=add_downsample, 64 | resnet_eps=resnet_eps, 65 | resnet_act_fn=resnet_act_fn, 66 | resnet_groups=resnet_groups, 67 | downsample_padding=downsample_padding, 68 | cross_attention_dim=cross_attention_dim, 69 | attn_num_head_channels=attn_num_head_channels, 70 | dual_cross_attention=dual_cross_attention, 71 | use_linear_projection=use_linear_projection, 72 | only_cross_attention=only_cross_attention, 73 | upcast_attention=upcast_attention, 74 | resnet_time_scale_shift=resnet_time_scale_shift, 75 | add_temp_attn=add_temp_attn, 76 | prepend_first_frame=prepend_first_frame, 77 | add_temp_embed=add_temp_embed, 78 | add_temp_conv=add_temp_conv, 79 | ) 80 | raise ValueError(f"{down_block_type} does not exist.") 81 | 82 | 83 | def get_up_block( 84 | up_block_type, 85 | num_layers, 86 | in_channels, 87 | out_channels, 88 | prev_output_channel, 89 | temb_channels, 90 | add_upsample, 91 | resnet_eps, 92 | resnet_act_fn, 93 | attn_num_head_channels, 94 | resnet_groups=None, 95 | cross_attention_dim=None, 96 | dual_cross_attention=False, 97 | use_linear_projection=False, 98 | only_cross_attention=False, 99 | upcast_attention=False, 100 | resnet_time_scale_shift="default", 101 | add_temp_attn=False, 102 | prepend_first_frame=False, 103 | add_temp_embed=False, 104 | add_temp_conv=False, 105 | ): 106 | up_block_type = ( 107 | up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type 108 | ) 109 | if up_block_type == "UpBlock3D": 110 | return UpBlock3D( 111 | num_layers=num_layers, 112 | in_channels=in_channels, 113 | out_channels=out_channels, 114 | prev_output_channel=prev_output_channel, 115 | temb_channels=temb_channels, 116 | add_upsample=add_upsample, 117 | resnet_eps=resnet_eps, 118 | resnet_act_fn=resnet_act_fn, 119 | resnet_groups=resnet_groups, 120 | resnet_time_scale_shift=resnet_time_scale_shift, 121 | add_temp_conv=add_temp_conv, 122 | ) 123 | elif up_block_type == "CrossAttnUpBlock3D": 124 | if cross_attention_dim is None: 125 | raise ValueError( 126 | "cross_attention_dim must be specified for CrossAttnUpBlock3D" 127 | ) 128 | return CrossAttnUpBlock3D( 129 | num_layers=num_layers, 130 | in_channels=in_channels, 131 | 
out_channels=out_channels, 132 | prev_output_channel=prev_output_channel, 133 | temb_channels=temb_channels, 134 | add_upsample=add_upsample, 135 | resnet_eps=resnet_eps, 136 | resnet_act_fn=resnet_act_fn, 137 | resnet_groups=resnet_groups, 138 | cross_attention_dim=cross_attention_dim, 139 | attn_num_head_channels=attn_num_head_channels, 140 | dual_cross_attention=dual_cross_attention, 141 | use_linear_projection=use_linear_projection, 142 | only_cross_attention=only_cross_attention, 143 | upcast_attention=upcast_attention, 144 | resnet_time_scale_shift=resnet_time_scale_shift, 145 | add_temp_attn=add_temp_attn, 146 | prepend_first_frame=prepend_first_frame, 147 | add_temp_embed=add_temp_embed, 148 | add_temp_conv=add_temp_conv, 149 | ) 150 | raise ValueError(f"{up_block_type} does not exist.") 151 | 152 | 153 | class TemporalConvLayer(nn.Module): 154 | """ 155 | Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: 156 | https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 157 | """ 158 | 159 | def __init__(self, in_dim, out_dim=None, num_layers=4, dropout=0.0): 160 | super().__init__() 161 | out_dim = out_dim or in_dim 162 | 163 | # conv layers 164 | convs = [] 165 | prev_dim, next_dim = in_dim, out_dim 166 | for i in range(num_layers): 167 | if i == num_layers - 1: 168 | next_dim = out_dim 169 | convs.extend( 170 | [ 171 | nn.GroupNorm(32, prev_dim), 172 | nn.SiLU(), 173 | nn.Dropout(dropout), 174 | nn.Conv3d(prev_dim, next_dim, (3, 1, 1), padding=(1, 0, 0)), 175 | ] 176 | ) 177 | prev_dim, next_dim = next_dim, prev_dim 178 | self.convs = nn.ModuleList(convs) 179 | 180 | def forward(self, hidden_states): 181 | video_length = hidden_states.shape[2] 182 | 183 | identity = hidden_states 184 | for conv in self.convs: 185 | hidden_states = conv(hidden_states) 186 | 187 | # ignore effects of temporal layers on image inputs 188 | hidden_states = ( 189 | identity + hidden_states 190 | if video_length > 1 191 | else identity + 0.0 * hidden_states 192 | ) 193 | 194 | return hidden_states 195 | 196 | 197 | class UNetMidBlock3DCrossAttn(nn.Module): 198 | def __init__( 199 | self, 200 | in_channels: int, 201 | temb_channels: int, 202 | dropout: float = 0.0, 203 | num_layers: int = 1, 204 | resnet_eps: float = 1e-6, 205 | resnet_time_scale_shift: str = "default", 206 | resnet_act_fn: str = "swish", 207 | resnet_groups: int = 32, 208 | resnet_pre_norm: bool = True, 209 | attn_num_head_channels=1, 210 | output_scale_factor=1.0, 211 | cross_attention_dim=1280, 212 | dual_cross_attention=False, 213 | use_linear_projection=False, 214 | upcast_attention=False, 215 | add_temp_attn=False, 216 | prepend_first_frame=False, 217 | add_temp_embed=False, 218 | add_temp_conv=False, 219 | ): 220 | super().__init__() 221 | 222 | self.has_cross_attention = True 223 | self.attn_num_head_channels = attn_num_head_channels 224 | resnet_groups = ( 225 | resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) 226 | ) 227 | 228 | # there is always at least one resnet 229 | resnets = [ 230 | ResnetBlock3D( 231 | in_channels=in_channels, 232 | out_channels=in_channels, 233 | temb_channels=temb_channels, 234 | eps=resnet_eps, 235 | groups=resnet_groups, 236 | dropout=dropout, 237 | time_embedding_norm=resnet_time_scale_shift, 238 | non_linearity=resnet_act_fn, 239 | output_scale_factor=output_scale_factor, 240 | pre_norm=resnet_pre_norm, 241 | ) 242 | ] 243 | 
attentions = [] 244 | if add_temp_conv: 245 | self.temp_convs = None 246 | temp_convs = [TemporalConvLayer(in_channels, in_channels, dropout=0.1)] 247 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1]) 248 | 249 | for _ in range(num_layers): 250 | if dual_cross_attention: 251 | raise NotImplementedError 252 | attentions.append( 253 | Transformer3DModel( 254 | attn_num_head_channels, 255 | in_channels // attn_num_head_channels, 256 | in_channels=in_channels, 257 | num_layers=1, 258 | cross_attention_dim=cross_attention_dim, 259 | norm_num_groups=resnet_groups, 260 | use_linear_projection=use_linear_projection, 261 | upcast_attention=upcast_attention, 262 | add_temp_attn=add_temp_attn, 263 | prepend_first_frame=prepend_first_frame, 264 | add_temp_embed=add_temp_embed, 265 | ) 266 | ) 267 | resnets.append( 268 | ResnetBlock3D( 269 | in_channels=in_channels, 270 | out_channels=in_channels, 271 | temb_channels=temb_channels, 272 | eps=resnet_eps, 273 | groups=resnet_groups, 274 | dropout=dropout, 275 | time_embedding_norm=resnet_time_scale_shift, 276 | non_linearity=resnet_act_fn, 277 | output_scale_factor=output_scale_factor, 278 | pre_norm=resnet_pre_norm, 279 | ) 280 | ) 281 | if hasattr(self, "temp_convs"): 282 | temp_convs.append( 283 | TemporalConvLayer(in_channels, in_channels, dropout=0.1) 284 | ) 285 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1]) 286 | 287 | self.attentions = nn.ModuleList(attentions) 288 | self.resnets = nn.ModuleList(resnets) 289 | if hasattr(self, "temp_convs"): 290 | self.temp_convs = nn.ModuleList(temp_convs) 291 | 292 | def forward( 293 | self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None 294 | ): 295 | hidden_states = self.resnets[0](hidden_states, temb) 296 | if hasattr(self, "temp_convs"): 297 | hidden_states = self.temp_convs[0](hidden_states) 298 | for layer_idx in range(len(self.attentions)): 299 | attn = self.attentions[layer_idx] 300 | resnet = self.resnets[layer_idx + 1] 301 | hidden_states = attn( 302 | hidden_states, encoder_hidden_states=encoder_hidden_states 303 | )[0] 304 | hidden_states = resnet(hidden_states, temb) 305 | if hasattr(self, "temp_convs"): 306 | temp_conv = self.temp_convs[layer_idx + 1] 307 | hidden_states = temp_conv(hidden_states) 308 | 309 | return hidden_states 310 | 311 | 312 | class CrossAttnDownBlock3D(nn.Module): 313 | def __init__( 314 | self, 315 | in_channels: int, 316 | out_channels: int, 317 | temb_channels: int, 318 | dropout: float = 0.0, 319 | num_layers: int = 1, 320 | resnet_eps: float = 1e-6, 321 | resnet_time_scale_shift: str = "default", 322 | resnet_act_fn: str = "swish", 323 | resnet_groups: int = 32, 324 | resnet_pre_norm: bool = True, 325 | attn_num_head_channels=1, 326 | cross_attention_dim=1280, 327 | output_scale_factor=1.0, 328 | downsample_padding=1, 329 | add_downsample=True, 330 | dual_cross_attention=False, 331 | use_linear_projection=False, 332 | only_cross_attention=False, 333 | upcast_attention=False, 334 | add_temp_attn=False, 335 | prepend_first_frame=False, 336 | add_temp_embed=False, 337 | add_temp_conv=False, 338 | ): 339 | super().__init__() 340 | resnets = [] 341 | attentions = [] 342 | if add_temp_conv: 343 | self.temp_convs = None 344 | temp_convs = [] 345 | 346 | self.has_cross_attention = True 347 | self.attn_num_head_channels = attn_num_head_channels 348 | 349 | for i in range(num_layers): 350 | in_channels = in_channels if i == 0 else out_channels 351 | resnets.append( 352 | ResnetBlock3D( 353 | in_channels=in_channels, 
354 | out_channels=out_channels, 355 | temb_channels=temb_channels, 356 | eps=resnet_eps, 357 | groups=resnet_groups, 358 | dropout=dropout, 359 | time_embedding_norm=resnet_time_scale_shift, 360 | non_linearity=resnet_act_fn, 361 | output_scale_factor=output_scale_factor, 362 | pre_norm=resnet_pre_norm, 363 | ) 364 | ) 365 | if hasattr(self, "temp_convs"): 366 | temp_convs.append( 367 | TemporalConvLayer(out_channels, out_channels, dropout=0.1) 368 | ) 369 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1]) 370 | if dual_cross_attention: 371 | raise NotImplementedError 372 | attentions.append( 373 | Transformer3DModel( 374 | attn_num_head_channels, 375 | out_channels // attn_num_head_channels, 376 | in_channels=out_channels, 377 | num_layers=1, 378 | cross_attention_dim=cross_attention_dim, 379 | norm_num_groups=resnet_groups, 380 | use_linear_projection=use_linear_projection, 381 | only_cross_attention=only_cross_attention, 382 | upcast_attention=upcast_attention, 383 | add_temp_attn=add_temp_attn, 384 | prepend_first_frame=prepend_first_frame, 385 | add_temp_embed=add_temp_embed, 386 | ) 387 | ) 388 | self.attentions = nn.ModuleList(attentions) 389 | self.resnets = nn.ModuleList(resnets) 390 | if hasattr(self, "temp_convs"): 391 | self.temp_convs = nn.ModuleList(temp_convs) 392 | 393 | if add_downsample: 394 | self.downsamplers = nn.ModuleList( 395 | [ 396 | Downsample3D( 397 | out_channels, 398 | use_conv=True, 399 | out_channels=out_channels, 400 | padding=downsample_padding, 401 | name="op", 402 | ) 403 | ] 404 | ) 405 | else: 406 | self.downsamplers = None 407 | 408 | self.gradient_checkpointing = False 409 | 410 | def forward( 411 | self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None 412 | ): 413 | output_states = () 414 | 415 | for layer_idx in range(len(self.resnets)): 416 | resnet, attn = self.resnets[layer_idx], self.attentions[layer_idx] 417 | is_checkpointing = self.training and self.gradient_checkpointing 418 | hidden_states = checkpoint( 419 | func=resnet, inputs=(hidden_states, temb), flag=is_checkpointing 420 | ) 421 | if hasattr(self, "temp_convs"): 422 | temp_conv = self.temp_convs[layer_idx] 423 | hidden_states = checkpoint( 424 | func=temp_conv, inputs=(hidden_states,), flag=is_checkpointing 425 | ) 426 | hidden_states = checkpoint( 427 | func=attn, 428 | inputs=(hidden_states, encoder_hidden_states), 429 | flag=is_checkpointing, 430 | )[0] 431 | 432 | output_states += (hidden_states,) 433 | 434 | if self.downsamplers is not None: 435 | for downsampler in self.downsamplers: 436 | hidden_states = downsampler(hidden_states) 437 | 438 | output_states += (hidden_states,) 439 | 440 | return hidden_states, output_states 441 | 442 | 443 | class DownBlock3D(nn.Module): 444 | def __init__( 445 | self, 446 | in_channels: int, 447 | out_channels: int, 448 | temb_channels: int, 449 | dropout: float = 0.0, 450 | num_layers: int = 1, 451 | resnet_eps: float = 1e-6, 452 | resnet_time_scale_shift: str = "default", 453 | resnet_act_fn: str = "swish", 454 | resnet_groups: int = 32, 455 | resnet_pre_norm: bool = True, 456 | output_scale_factor=1.0, 457 | add_downsample=True, 458 | downsample_padding=1, 459 | add_temp_conv=False, 460 | ): 461 | super().__init__() 462 | resnets = [] 463 | if add_temp_conv: 464 | self.temp_convs = None 465 | temp_convs = [] 466 | for i in range(num_layers): 467 | in_channels = in_channels if i == 0 else out_channels 468 | resnets.append( 469 | ResnetBlock3D( 470 | in_channels=in_channels, 471 | 
out_channels=out_channels, 472 | temb_channels=temb_channels, 473 | eps=resnet_eps, 474 | groups=resnet_groups, 475 | dropout=dropout, 476 | time_embedding_norm=resnet_time_scale_shift, 477 | non_linearity=resnet_act_fn, 478 | output_scale_factor=output_scale_factor, 479 | pre_norm=resnet_pre_norm, 480 | ) 481 | ) 482 | if hasattr(self, "temp_convs"): 483 | temp_convs.append( 484 | TemporalConvLayer(out_channels, out_channels, dropout=0.1) 485 | ) 486 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1]) 487 | 488 | self.resnets = nn.ModuleList(resnets) 489 | if hasattr(self, "temp_convs"): 490 | self.temp_convs = nn.ModuleList(temp_convs) 491 | 492 | if add_downsample: 493 | self.downsamplers = nn.ModuleList( 494 | [ 495 | Downsample3D( 496 | out_channels, 497 | use_conv=True, 498 | out_channels=out_channels, 499 | padding=downsample_padding, 500 | name="op", 501 | ) 502 | ] 503 | ) 504 | else: 505 | self.downsamplers = None 506 | 507 | self.gradient_checkpointing = False 508 | 509 | def forward(self, hidden_states, temb=None): 510 | output_states = () 511 | 512 | for layer_idx in range(len(self.resnets)): 513 | resnet = self.resnets[layer_idx] 514 | is_checkpointing = self.training and self.gradient_checkpointing 515 | hidden_states = checkpoint( 516 | func=resnet, inputs=(hidden_states, temb), flag=is_checkpointing 517 | ) 518 | if hasattr(self, "temp_convs"): 519 | temp_conv = self.temp_convs[layer_idx] 520 | hidden_states = checkpoint( 521 | func=temp_conv, inputs=(hidden_states,), flag=is_checkpointing 522 | ) 523 | 524 | output_states += (hidden_states,) 525 | 526 | if self.downsamplers is not None: 527 | for downsampler in self.downsamplers: 528 | hidden_states = downsampler(hidden_states) 529 | 530 | output_states += (hidden_states,) 531 | 532 | return hidden_states, output_states 533 | 534 | 535 | class CrossAttnUpBlock3D(nn.Module): 536 | def __init__( 537 | self, 538 | in_channels: int, 539 | out_channels: int, 540 | prev_output_channel: int, 541 | temb_channels: int, 542 | dropout: float = 0.0, 543 | num_layers: int = 1, 544 | resnet_eps: float = 1e-6, 545 | resnet_time_scale_shift: str = "default", 546 | resnet_act_fn: str = "swish", 547 | resnet_groups: int = 32, 548 | resnet_pre_norm: bool = True, 549 | attn_num_head_channels=1, 550 | cross_attention_dim=1280, 551 | output_scale_factor=1.0, 552 | add_upsample=True, 553 | dual_cross_attention=False, 554 | use_linear_projection=False, 555 | only_cross_attention=False, 556 | upcast_attention=False, 557 | add_temp_attn=False, 558 | prepend_first_frame=False, 559 | add_temp_embed=False, 560 | add_temp_conv=False, 561 | ): 562 | super().__init__() 563 | resnets = [] 564 | attentions = [] 565 | if add_temp_conv: 566 | self.temp_convs = None 567 | temp_convs = [] 568 | 569 | self.has_cross_attention = True 570 | self.attn_num_head_channels = attn_num_head_channels 571 | 572 | for i in range(num_layers): 573 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 574 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 575 | 576 | resnets.append( 577 | ResnetBlock3D( 578 | in_channels=resnet_in_channels + res_skip_channels, 579 | out_channels=out_channels, 580 | temb_channels=temb_channels, 581 | eps=resnet_eps, 582 | groups=resnet_groups, 583 | dropout=dropout, 584 | time_embedding_norm=resnet_time_scale_shift, 585 | non_linearity=resnet_act_fn, 586 | output_scale_factor=output_scale_factor, 587 | pre_norm=resnet_pre_norm, 588 | ) 589 | ) 590 | if dual_cross_attention: 591 | raise 
NotImplementedError 592 | attentions.append( 593 | Transformer3DModel( 594 | attn_num_head_channels, 595 | out_channels // attn_num_head_channels, 596 | in_channels=out_channels, 597 | num_layers=1, 598 | cross_attention_dim=cross_attention_dim, 599 | norm_num_groups=resnet_groups, 600 | use_linear_projection=use_linear_projection, 601 | only_cross_attention=only_cross_attention, 602 | upcast_attention=upcast_attention, 603 | add_temp_attn=add_temp_attn, 604 | prepend_first_frame=prepend_first_frame, 605 | add_temp_embed=add_temp_embed, 606 | ) 607 | ) 608 | if hasattr(self, "temp_convs"): 609 | temp_convs.append( 610 | TemporalConvLayer(out_channels, out_channels, dropout=0.1) 611 | ) 612 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1]) 613 | 614 | self.attentions = nn.ModuleList(attentions) 615 | self.resnets = nn.ModuleList(resnets) 616 | if hasattr(self, "temp_convs"): 617 | self.temp_convs = nn.ModuleList(temp_convs) 618 | 619 | if add_upsample: 620 | self.upsamplers = nn.ModuleList( 621 | [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)] 622 | ) 623 | else: 624 | self.upsamplers = None 625 | 626 | self.gradient_checkpointing = False 627 | 628 | def forward( 629 | self, 630 | hidden_states, 631 | res_hidden_states_tuple, 632 | temb=None, 633 | encoder_hidden_states=None, 634 | upsample_size=None, 635 | attention_mask=None, 636 | ): 637 | for layer_idx in range(len(self.resnets)): 638 | resnet, attn = self.resnets[layer_idx], self.attentions[layer_idx] 639 | # pop res hidden states 640 | res_hidden_states = res_hidden_states_tuple[-1] 641 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 642 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 643 | 644 | is_checkpointing = self.training and self.gradient_checkpointing 645 | hidden_states = checkpoint( 646 | func=resnet, inputs=(hidden_states, temb), flag=is_checkpointing 647 | ) 648 | if hasattr(self, "temp_convs"): 649 | temp_conv = self.temp_convs[layer_idx] 650 | hidden_states = checkpoint( 651 | func=temp_conv, inputs=(hidden_states,), flag=is_checkpointing 652 | ) 653 | hidden_states = checkpoint( 654 | func=attn, 655 | inputs=(hidden_states, encoder_hidden_states), 656 | flag=is_checkpointing, 657 | )[0] 658 | 659 | if self.upsamplers is not None: 660 | for upsampler in self.upsamplers: 661 | hidden_states = upsampler(hidden_states, upsample_size) 662 | 663 | return hidden_states 664 | 665 | 666 | class UpBlock3D(nn.Module): 667 | def __init__( 668 | self, 669 | in_channels: int, 670 | prev_output_channel: int, 671 | out_channels: int, 672 | temb_channels: int, 673 | dropout: float = 0.0, 674 | num_layers: int = 1, 675 | resnet_eps: float = 1e-6, 676 | resnet_time_scale_shift: str = "default", 677 | resnet_act_fn: str = "swish", 678 | resnet_groups: int = 32, 679 | resnet_pre_norm: bool = True, 680 | output_scale_factor=1.0, 681 | add_upsample=True, 682 | add_temp_conv=False, 683 | ): 684 | super().__init__() 685 | resnets = [] 686 | 687 | if add_temp_conv: 688 | self.temp_convs = None 689 | temp_convs = [] 690 | 691 | for i in range(num_layers): 692 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 693 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 694 | 695 | resnets.append( 696 | ResnetBlock3D( 697 | in_channels=resnet_in_channels + res_skip_channels, 698 | out_channels=out_channels, 699 | temb_channels=temb_channels, 700 | eps=resnet_eps, 701 | groups=resnet_groups, 702 | dropout=dropout, 703 | 
time_embedding_norm=resnet_time_scale_shift, 704 | non_linearity=resnet_act_fn, 705 | output_scale_factor=output_scale_factor, 706 | pre_norm=resnet_pre_norm, 707 | ) 708 | ) 709 | if hasattr(self, "temp_convs"): 710 | temp_convs.append( 711 | TemporalConvLayer(out_channels, out_channels, dropout=0.1) 712 | ) 713 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1]) 714 | 715 | self.resnets = nn.ModuleList(resnets) 716 | if hasattr(self, "temp_convs"): 717 | self.temp_convs = nn.ModuleList(temp_convs) 718 | 719 | if add_upsample: 720 | self.upsamplers = nn.ModuleList( 721 | [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)] 722 | ) 723 | else: 724 | self.upsamplers = None 725 | 726 | self.gradient_checkpointing = False 727 | 728 | def forward( 729 | self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None 730 | ): 731 | for layer_idx in range(len(self.resnets)): 732 | resnet = self.resnets[layer_idx] 733 | # pop res hidden states 734 | res_hidden_states = res_hidden_states_tuple[-1] 735 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 736 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 737 | 738 | is_checkpointing = self.training and self.gradient_checkpointing 739 | hidden_states = checkpoint( 740 | func=resnet, inputs=(hidden_states, temb), flag=is_checkpointing 741 | ) 742 | if hasattr(self, "temp_convs"): 743 | temp_conv = self.temp_convs[layer_idx] 744 | hidden_states = checkpoint( 745 | func=temp_conv, inputs=(hidden_states,), flag=is_checkpointing 746 | ) 747 | 748 | if self.upsamplers is not None: 749 | for upsampler in self.upsamplers: 750 | hidden_states = upsampler(hidden_states, upsample_size) 751 | 752 | return hidden_states 753 | -------------------------------------------------------------------------------- /model/modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def checkpoint(func, inputs, flag): 5 | """ 6 | Evaluate a function without caching intermediate activations, allowing for 7 | reduced memory at the expense of extra compute in the backward pass. 8 | :param func: the function to evaluate. 9 | :param inputs: the argument sequence to pass to `func`. 10 | :param flag: if False, disable gradient checkpointing. 11 | """ 12 | if flag: 13 | return torch.utils.checkpoint.checkpoint(func, *inputs, use_reentrant=False) 14 | else: 15 | return func(*inputs) 16 | 17 | 18 | def zero_module(module): 19 | """ 20 | Zero out the parameters of a module and return it. 
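    In this repository, `zero_module` is applied to the last convolution of each
    newly added `TemporalConvLayer`, so the temporal residual branch starts as an
    identity-like mapping and the pretrained 2D behaviour is preserved at
    initialization. A minimal sanity check (illustrative only):

        >>> import torch
        >>> conv = zero_module(torch.nn.Conv3d(4, 4, (3, 1, 1), padding=(1, 0, 0)))
        >>> bool(conv.weight.abs().sum() == 0) and bool(conv.bias.abs().sum() == 0)
        True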
21 | """ 22 | for p in module.parameters(): 23 | p.detach().zero_() 24 | return module 25 | -------------------------------------------------------------------------------- /model/pipeline.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py 2 | 3 | 4 | import inspect 5 | from dataclasses import dataclass 6 | from math import sqrt 7 | from typing import Callable, Optional, Union 8 | from typing import List 9 | 10 | import numpy as np 11 | import torch 12 | from diffusers.models import AutoencoderKL 13 | from diffusers.pipeline_utils import DiffusionPipeline 14 | from diffusers.schedulers import ( 15 | DDIMScheduler, 16 | DPMSolverMultistepScheduler, 17 | EulerAncestralDiscreteScheduler, 18 | EulerDiscreteScheduler, 19 | LMSDiscreteScheduler, 20 | PNDMScheduler, 21 | ) 22 | from diffusers.utils import is_accelerate_available 23 | from diffusers.utils import is_accelerate_version, replace_example_docstring 24 | from diffusers.utils import logging, BaseOutput 25 | from einops import rearrange 26 | from transformers import CLIPTextModel, CLIPTokenizer 27 | 28 | from .modules.unet import UNet3DConditionModel 29 | from .utils import randn_progressive, randn_base 30 | 31 | logger = logging.get_logger(__name__) 32 | 33 | EXAMPLE_DOC_STRING = """ 34 | Examples: 35 | ```py 36 | >>> import torch 37 | >>> from diffusers import TextToVideoSDPipeline 38 | >>> from diffusers.utils import export_to_video 39 | 40 | >>> pipe = TextToVideoSDPipeline.from_pretrained( 41 | ... "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16" 42 | ... ) 43 | >>> pipe.enable_model_cpu_offload() 44 | 45 | >>> prompt = "Spiderman is surfing" 46 | >>> video_frames = pipe(prompt).frames 47 | >>> video_path = export_to_video(video_frames) 48 | >>> video_path 49 | ``` 50 | """ 51 | 52 | 53 | @dataclass 54 | class SDVideoPipelineOutput(BaseOutput): 55 | """ 56 | Output class for text to video pipelines. 57 | Args: 58 | frames (`List[np.ndarray]` or `torch.FloatTensor`) 59 | List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as 60 | a `torch` tensor. NumPy array present the denoised images of the diffusion pipeline. The length of the list 61 | denotes the video length i.e., the number of frames. 62 | """ 63 | 64 | videos: Union[torch.Tensor, np.ndarray] 65 | 66 | 67 | class SDVideoPipeline(DiffusionPipeline): 68 | r""" 69 | Pipeline for text-to-video generation. 70 | 71 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 72 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 73 | 74 | Args: 75 | vae ([`AutoencoderKL`]): 76 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 77 | text_encoder ([`CLIPTextModel`]): 78 | Frozen text-encoder. Same as Stable Diffusion 2. 79 | tokenizer (`CLIPTokenizer`): 80 | Tokenizer of class 81 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 82 | unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents. 83 | scheduler ([`SchedulerMixin`]): 84 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of 85 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 86 | """ 87 | 88 | def __init__( 89 | self, 90 | vae: AutoencoderKL, 91 | text_encoder: CLIPTextModel, 92 | tokenizer: CLIPTokenizer, 93 | unet: UNet3DConditionModel, 94 | scheduler: Union[ 95 | DDIMScheduler, 96 | PNDMScheduler, 97 | LMSDiscreteScheduler, 98 | EulerDiscreteScheduler, 99 | EulerAncestralDiscreteScheduler, 100 | DPMSolverMultistepScheduler, 101 | ], 102 | ): 103 | super().__init__() 104 | 105 | self.register_modules( 106 | vae=vae, 107 | text_encoder=text_encoder, 108 | tokenizer=tokenizer, 109 | unet=unet, 110 | scheduler=scheduler, 111 | ) 112 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 113 | 114 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing 115 | def enable_vae_slicing(self): 116 | r""" 117 | Enable sliced VAE decoding. 118 | 119 | When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several 120 | steps. This is useful to save some memory and allow larger batch sizes. 121 | """ 122 | self.vae.enable_slicing() 123 | 124 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing 125 | def disable_vae_slicing(self): 126 | r""" 127 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to 128 | computing decoding in one step. 129 | """ 130 | self.vae.disable_slicing() 131 | 132 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling 133 | def enable_vae_tiling(self): 134 | r""" 135 | Enable tiled VAE decoding. 136 | 137 | When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in 138 | several steps. This is useful to save a large amount of memory and to allow the processing of larger images. 139 | """ 140 | self.vae.enable_tiling() 141 | 142 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling 143 | def disable_vae_tiling(self): 144 | r""" 145 | Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to 146 | computing decoding in one step. 147 | """ 148 | self.vae.disable_tiling() 149 | 150 | def enable_sequential_cpu_offload(self, gpu_id=0): 151 | r""" 152 | Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, 153 | text_encoder, vae have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded 154 | to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a 155 | submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. 
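        Example (an illustrative sketch, assuming a fully constructed `SDVideoPipeline`
        instance named `pipe` and `accelerate>=0.14.0` installed):

            >>> pipe.enable_sequential_cpu_offload(gpu_id=0)  # doctest: +SKIP
            >>> videos = pipe("a dog running on the beach", num_frames=8).videos  # doctest: +SKIP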
156 | """ 157 | if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): 158 | from accelerate import cpu_offload 159 | else: 160 | raise ImportError( 161 | "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher" 162 | ) 163 | 164 | device = torch.device(f"cuda:{gpu_id}") 165 | 166 | if self.device.type != "cpu": 167 | self.to("cpu", silence_dtype_warnings=True) 168 | torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) 169 | 170 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 171 | cpu_offload(cpu_offloaded_model, device) 172 | 173 | def enable_model_cpu_offload(self, gpu_id=0): 174 | r""" 175 | Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared 176 | to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` 177 | method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with 178 | `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 179 | """ 180 | if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): 181 | from accelerate import cpu_offload_with_hook 182 | else: 183 | raise ImportError( 184 | "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher." 185 | ) 186 | 187 | device = torch.device(f"cuda:{gpu_id}") 188 | 189 | if self.device.type != "cpu": 190 | self.to("cpu", silence_dtype_warnings=True) 191 | torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) 192 | 193 | hook = None 194 | for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: 195 | _, hook = cpu_offload_with_hook( 196 | cpu_offloaded_model, device, prev_module_hook=hook 197 | ) 198 | 199 | # We'll offload the last model manually. 200 | self.final_offload_hook = hook 201 | 202 | @property 203 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device 204 | def _execution_device(self): 205 | r""" 206 | Returns the device on which the pipeline's models will be executed. After calling 207 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module 208 | hooks. 209 | """ 210 | if not hasattr(self.unet, "_hf_hook"): 211 | return self.device 212 | for module in self.unet.modules(): 213 | if ( 214 | hasattr(module, "_hf_hook") 215 | and hasattr(module._hf_hook, "execution_device") 216 | and module._hf_hook.execution_device is not None 217 | ): 218 | return torch.device(module._hf_hook.execution_device) 219 | return self.device 220 | 221 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt 222 | def _encode_prompt( 223 | self, 224 | prompt, 225 | device, 226 | num_images_per_prompt, 227 | do_classifier_free_guidance, 228 | negative_prompt=None, 229 | prompt_embeds: Optional[torch.FloatTensor] = None, 230 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 231 | ): 232 | r""" 233 | Encodes the prompt into text encoder hidden states. 
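        When classifier-free guidance is enabled, the unconditional (negative) and
        conditional embeddings are concatenated along the batch dimension, so the
        returned tensor has shape `(2 * batch_size * num_images_per_prompt,
        sequence_length, embed_dim)`; otherwise only the conditional embeddings of
        shape `(batch_size * num_images_per_prompt, sequence_length, embed_dim)`
        are returned.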
234 | 235 | Args: 236 | prompt (`str` or `List[str]`, *optional*): 237 | prompt to be encoded 238 | device: (`torch.device`): 239 | torch device 240 | num_images_per_prompt (`int`): 241 | number of images that should be generated per prompt 242 | do_classifier_free_guidance (`bool`): 243 | whether to use classifier free guidance or not 244 | negative_prompt (`str` or `List[str]`, *optional*): 245 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 246 | `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. 247 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 248 | prompt_embeds (`torch.FloatTensor`, *optional*): 249 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 250 | provided, text embeddings will be generated from `prompt` input argument. 251 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 252 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 253 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 254 | argument. 255 | """ 256 | if prompt is not None and isinstance(prompt, str): 257 | batch_size = 1 258 | elif prompt is not None and isinstance(prompt, list): 259 | batch_size = len(prompt) 260 | else: 261 | batch_size = prompt_embeds.shape[0] 262 | 263 | if prompt_embeds is None: 264 | text_inputs = self.tokenizer( 265 | prompt, 266 | padding="max_length", 267 | max_length=self.tokenizer.model_max_length, 268 | truncation=True, 269 | return_tensors="pt", 270 | ) 271 | text_input_ids = text_inputs.input_ids 272 | untruncated_ids = self.tokenizer( 273 | prompt, padding="longest", return_tensors="pt" 274 | ).input_ids 275 | 276 | if untruncated_ids.shape[-1] >= text_input_ids.shape[ 277 | -1 278 | ] and not torch.equal(text_input_ids, untruncated_ids): 279 | removed_text = self.tokenizer.batch_decode( 280 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 281 | ) 282 | logger.warning( 283 | "The following part of your input was truncated because CLIP can only handle sequences up to" 284 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 285 | ) 286 | 287 | if ( 288 | hasattr(self.text_encoder.config, "use_attention_mask") 289 | and self.text_encoder.config.use_attention_mask 290 | ): 291 | attention_mask = text_inputs.attention_mask.to(device) 292 | else: 293 | attention_mask = None 294 | 295 | prompt_embeds = self.text_encoder( 296 | text_input_ids.to(device), 297 | attention_mask=attention_mask, 298 | ) 299 | prompt_embeds = prompt_embeds[0] 300 | 301 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 302 | 303 | bs_embed, seq_len, _ = prompt_embeds.shape 304 | # duplicate text embeddings for each generation per prompt, using mps friendly method 305 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 306 | prompt_embeds = prompt_embeds.view( 307 | bs_embed * num_images_per_prompt, seq_len, -1 308 | ) 309 | 310 | # get unconditional embeddings for classifier free guidance 311 | if do_classifier_free_guidance and negative_prompt_embeds is None: 312 | uncond_tokens: List[str] 313 | if negative_prompt is None: 314 | uncond_tokens = [""] * batch_size 315 | elif type(prompt) is not type(negative_prompt): 316 | raise TypeError( 317 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 318 | f" 
{type(prompt)}." 319 | ) 320 | elif isinstance(negative_prompt, str): 321 | uncond_tokens = [negative_prompt] 322 | elif batch_size != len(negative_prompt): 323 | raise ValueError( 324 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 325 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 326 | " the batch size of `prompt`." 327 | ) 328 | else: 329 | uncond_tokens = negative_prompt 330 | 331 | max_length = prompt_embeds.shape[1] 332 | uncond_input = self.tokenizer( 333 | uncond_tokens, 334 | padding="max_length", 335 | max_length=max_length, 336 | truncation=True, 337 | return_tensors="pt", 338 | ) 339 | 340 | if ( 341 | hasattr(self.text_encoder.config, "use_attention_mask") 342 | and self.text_encoder.config.use_attention_mask 343 | ): 344 | attention_mask = uncond_input.attention_mask.to(device) 345 | else: 346 | attention_mask = None 347 | 348 | negative_prompt_embeds = self.text_encoder( 349 | uncond_input.input_ids.to(device), 350 | attention_mask=attention_mask, 351 | ) 352 | negative_prompt_embeds = negative_prompt_embeds[0] 353 | 354 | if do_classifier_free_guidance: 355 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 356 | seq_len = negative_prompt_embeds.shape[1] 357 | 358 | negative_prompt_embeds = negative_prompt_embeds.to( 359 | dtype=self.text_encoder.dtype, device=device 360 | ) 361 | 362 | negative_prompt_embeds = negative_prompt_embeds.repeat( 363 | 1, num_images_per_prompt, 1 364 | ) 365 | negative_prompt_embeds = negative_prompt_embeds.view( 366 | batch_size * num_images_per_prompt, seq_len, -1 367 | ) 368 | 369 | # For classifier free guidance, we need to do two forward passes. 370 | # Here we concatenate the unconditional and text embeddings into a single batch 371 | # to avoid doing two forward passes 372 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 373 | 374 | return prompt_embeds 375 | 376 | def decode_latents(self, latents): 377 | scaling_factor = self.vae.config.get("scaling_factor", 0.18215) 378 | video_length = latents.shape[2] 379 | latents = 1 / scaling_factor * latents 380 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 381 | video = self.vae.decode(latents).sample 382 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 383 | video = (video / 2 + 0.5).clamp(0, 1) 384 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 385 | video = video.float() 386 | return video 387 | 388 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 389 | def prepare_extra_step_kwargs(self, generator, eta): 390 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 391 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
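        # [illustrative note, not from the original source] e.g. `DDIMScheduler.step`
        # accepts an `eta` keyword, so it gets forwarded, while schedulers such as
        # `DPMSolverMultistepScheduler` do not, and the signature check below silently
        # drops it; the same check is applied to `generator` further down.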
392 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 393 | # and should be between [0, 1] 394 | 395 | accepts_eta = "eta" in set( 396 | inspect.signature(self.scheduler.step).parameters.keys() 397 | ) 398 | extra_step_kwargs = {} 399 | if accepts_eta: 400 | extra_step_kwargs["eta"] = eta 401 | 402 | # check if the scheduler accepts generator 403 | accepts_generator = "generator" in set( 404 | inspect.signature(self.scheduler.step).parameters.keys() 405 | ) 406 | if accepts_generator: 407 | extra_step_kwargs["generator"] = generator 408 | return extra_step_kwargs 409 | 410 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs 411 | def check_inputs( 412 | self, 413 | prompt, 414 | height, 415 | width, 416 | callback_steps, 417 | negative_prompt=None, 418 | prompt_embeds=None, 419 | negative_prompt_embeds=None, 420 | ): 421 | if height % 8 != 0 or width % 8 != 0: 422 | raise ValueError( 423 | f"`height` and `width` have to be divisible by 8 but are {height} and {width}." 424 | ) 425 | 426 | if (callback_steps is None) or ( 427 | callback_steps is not None 428 | and (not isinstance(callback_steps, int) or callback_steps <= 0) 429 | ): 430 | raise ValueError( 431 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 432 | f" {type(callback_steps)}." 433 | ) 434 | 435 | if prompt is not None and prompt_embeds is not None: 436 | raise ValueError( 437 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 438 | " only forward one of the two." 439 | ) 440 | elif prompt is None and prompt_embeds is None: 441 | raise ValueError( 442 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 443 | ) 444 | elif prompt is not None and ( 445 | not isinstance(prompt, str) and not isinstance(prompt, list) 446 | ): 447 | raise ValueError( 448 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" 449 | ) 450 | 451 | if negative_prompt is not None and negative_prompt_embeds is not None: 452 | raise ValueError( 453 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" 454 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 455 | ) 456 | 457 | if prompt_embeds is not None and negative_prompt_embeds is not None: 458 | if prompt_embeds.shape != negative_prompt_embeds.shape: 459 | raise ValueError( 460 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 461 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 462 | f" {negative_prompt_embeds.shape}." 463 | ) 464 | 465 | def prepare_latents( 466 | self, 467 | batch_size, 468 | num_channels_latents, 469 | num_frames, 470 | height, 471 | width, 472 | dtype, 473 | device, 474 | generator, 475 | noise_alpha=0.0, 476 | latents=None, 477 | ): 478 | shape = ( 479 | batch_size, 480 | num_channels_latents, 481 | num_frames, 482 | height // self.vae_scale_factor, 483 | width // self.vae_scale_factor, 484 | ) 485 | if isinstance(generator, list) and len(generator) != batch_size: 486 | raise ValueError( 487 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 488 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
489 | ) 490 | 491 | if latents is None: 492 | latents = randn_progressive( 493 | shape, 494 | dim=2, 495 | alpha=noise_alpha, 496 | generator=generator, 497 | device=device, 498 | dtype=dtype, 499 | ) 500 | else: 501 | latents = latents.to(device) 502 | 503 | # scale the initial noise by the standard deviation required by the scheduler 504 | latents = latents * self.scheduler.init_noise_sigma 505 | return latents 506 | 507 | @torch.no_grad() 508 | @replace_example_docstring(EXAMPLE_DOC_STRING) 509 | def __call__( 510 | self, 511 | prompt: Union[str, List[str]] = None, 512 | height: Optional[int] = None, 513 | width: Optional[int] = None, 514 | num_frames: int = 8, 515 | num_inference_steps: int = 50, 516 | num_generated_clips: int = 1, 517 | extend_overlap_frames: int = 0, 518 | extend_noise_recycle: int = 0, 519 | extend_denoise_guidance: float = 0.0, 520 | guidance_scale: float = 7.5, 521 | negative_prompt: Optional[Union[str, List[str]]] = None, 522 | eta: float = 0.0, 523 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 524 | noise_alpha: float = 0.0, 525 | latents: Optional[torch.FloatTensor] = None, 526 | prompt_embeds: Optional[torch.FloatTensor] = None, 527 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 528 | return_dict: bool = True, 529 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 530 | callback_steps: int = 1, 531 | ): 532 | r""" 533 | Function invoked when calling the pipeline for generation. 534 | 535 | Args: 536 | prompt (`str` or `List[str]`, *optional*): 537 | The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds` 538 | instead. 539 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 540 | The height in pixels of the generated video. 541 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 542 | The width in pixels of the generated video. 543 | num_frames (`int`, *optional*, defaults to 8): 544 | The number of video frames that are generated. Defaults to 8 frames, which at 8 frames per second 545 | amounts to 1 second of video. 546 | num_inference_steps (`int`, *optional*, defaults to 50): 547 | The number of denoising steps. More denoising steps usually lead to higher quality videos at the 548 | expense of slower inference. 549 | num_generated_clips (`int`, *optional*, defaults to 1): 550 | The total number of video clips to generate; clips after the first extend the base video clip. 551 | extend_overlap_frames (`int`, *optional*, defaults to 0): 552 | The number of overlapped frames between two clips in the range 553 | of `[0, N]` where N is the number of frames in a single clip. 554 | extend_noise_recycle (`int`, *optional*, defaults to 0): 555 | The strength of reusing the initial noise of the previous clip on 556 | the newly extended non-overlapped frames. The range of its value 557 | is `[0, +inf)` and smaller brings more randomness. Setting to 0 558 | means totally random noise. 559 | extend_denoise_guidance (`float`, *optional*, defaults to 0.0): 560 | The guidance strength of denoising on the overlapped frames. The 561 | range of its value is `[0, 1]`. Smaller means weaker guidance and 562 | 0 means no steps of guidance are used. 563 | Overlapped frames follow the previous clip's denoising trajectory for 564 | an early portion of the denoising schedule; larger values keep them locked for longer.
565 | guidance_scale (`float`, *optional*, defaults to 7.5): 566 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 567 | `guidance_scale` is defined as `w` of equation 2 of [Imagen 568 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 569 | 1`. A higher guidance scale encourages generating videos that are closely linked to the text `prompt`, 570 | usually at the expense of lower video quality. 571 | negative_prompt (`str` or `List[str]`, *optional*): 572 | The prompt or prompts not to guide the video generation. If not defined, one has to pass 573 | `negative_prompt_embeds` instead. 574 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 575 | eta (`float`, *optional*, defaults to 0.0): 576 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 577 | [`schedulers.DDIMScheduler`], will be ignored for others. 578 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 579 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 580 | to make generation deterministic. 581 | noise_alpha (`float`, *optional*, defaults to 0.0): 582 | Factor for generating initial noise. Ref: [A Noise Prior for Video Diffusion Models](https://arxiv.org/abs/2305.10474) 583 | latents (`torch.FloatTensor`, *optional*): 584 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video 585 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 586 | tensor will be generated by sampling using the supplied random `generator`. Latents should be of shape 587 | `(batch_size, num_channel, num_frames, height, width)`. 588 | prompt_embeds (`torch.FloatTensor`, *optional*): 589 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 590 | provided, text embeddings will be generated from `prompt` input argument. 591 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 592 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 593 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 594 | argument. 595 | return_dict (`bool`, *optional*, defaults to `True`): 596 | Whether or not to return a [`SDVideoPipelineOutput`] instead of a 597 | plain tuple. 598 | callback (`Callable`, *optional*): 599 | A function that will be called every `callback_steps` steps during inference. The function will be 600 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 601 | callback_steps (`int`, *optional*, defaults to 1): 602 | The frequency at which the `callback` function will be called. If not specified, the callback will be 603 | called at every step. 604 | 605 | Examples: 606 | 607 | Returns: 608 | [`SDVideoPipelineOutput`] or `tuple`: 609 | [`SDVideoPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. 610 | When returning a tuple, the first element is the generated video tensor. 611 | """ 612 | # 0. 
Default height and width to unet 613 | height = height or self.unet.config.sample_size * self.vae_scale_factor 614 | width = width or self.unet.config.sample_size * self.vae_scale_factor 615 | 616 | num_images_per_prompt = 1 617 | 618 | # 1. Check inputs. Raise error if not correct 619 | self.check_inputs( 620 | prompt, 621 | height, 622 | width, 623 | callback_steps, 624 | negative_prompt, 625 | prompt_embeds, 626 | negative_prompt_embeds, 627 | ) 628 | 629 | # 2. Define call parameters 630 | if prompt is not None and isinstance(prompt, str): 631 | batch_size = 1 632 | elif prompt is not None and isinstance(prompt, list): 633 | batch_size = len(prompt) 634 | else: 635 | batch_size = prompt_embeds.shape[0] 636 | 637 | device = self._execution_device 638 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 639 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 640 | # corresponds to doing no classifier free guidance. 641 | do_classifier_free_guidance = guidance_scale > 1.0 642 | 643 | # 3. Encode input prompt 644 | prompt_embeds = self._encode_prompt( 645 | prompt, 646 | device, 647 | num_images_per_prompt, 648 | do_classifier_free_guidance, 649 | negative_prompt, 650 | prompt_embeds=prompt_embeds, 651 | negative_prompt_embeds=negative_prompt_embeds, 652 | ) 653 | 654 | # 4. Prepare timesteps 655 | self.scheduler.set_timesteps(num_inference_steps, device=device) 656 | timesteps = self.scheduler.timesteps 657 | 658 | # 5. Prepare latent variables 659 | num_channels_latents = self.unet.in_channels 660 | latents = self.prepare_latents( 661 | batch_size * num_images_per_prompt, 662 | num_channels_latents, 663 | num_frames, 664 | height, 665 | width, 666 | prompt_embeds.dtype, 667 | device, 668 | generator, 669 | noise_alpha=noise_alpha, 670 | latents=latents, 671 | ) 672 | 673 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 674 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 675 | 676 | # 7. Denoising loop 677 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 678 | 679 | latents_past_rev = dict() 680 | for clip_idx in range(num_generated_clips): 681 | if clip_idx > 0: 682 | latents[:, :, :extend_overlap_frames, ...] = latents_past_rev[0][ 683 | :, :, :extend_overlap_frames, ... 684 | ] 685 | more_noise = randn_base( 686 | latents.shape, 687 | mean=0.0, 688 | std=1 / sqrt(1 + extend_noise_recycle**2), 689 | generator=generator, 690 | device=device, 691 | ) 692 | latents[:, :, extend_overlap_frames - num_frames :, ...] = ( 693 | extend_noise_recycle / sqrt(1 + extend_noise_recycle**2) 694 | ) * latents_past_rev[0][ 695 | :, :, extend_overlap_frames - num_frames :, ... 696 | ] + more_noise[ 697 | :, :, extend_overlap_frames - num_frames :, ... 698 | ] 699 | for i, t in enumerate(timesteps): 700 | # store latest latents 701 | if clip_idx > 0: 702 | for frm in range(extend_overlap_frames): 703 | if ( 704 | i 705 | < (extend_overlap_frames - frm) 706 | * (len(timesteps) / extend_overlap_frames) 707 | * extend_denoise_guidance 708 | ): 709 | latents[:, :, frm, ...] = latents_past_rev[i][ 710 | :, :, frm, ... 
711 | ] 712 | latents_past_rev[i] = latents.flip(dims=[2]) 713 | 714 | # expand the latents if we are doing classifier free guidance 715 | latent_model_input = ( 716 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents 717 | ) 718 | latent_model_input = self.scheduler.scale_model_input( 719 | latent_model_input, t 720 | ) 721 | 722 | # predict the noise residual 723 | noise_pred = self.unet( 724 | latent_model_input, 725 | t, 726 | encoder_hidden_states=prompt_embeds, 727 | ).sample 728 | 729 | # perform guidance 730 | if do_classifier_free_guidance: 731 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 732 | noise_pred = noise_pred_uncond + guidance_scale * ( 733 | noise_pred_text - noise_pred_uncond 734 | ) 735 | 736 | # reshape latents 737 | bsz, channel, frames, width, height = latents.shape 738 | latents = latents.permute(0, 2, 1, 3, 4).reshape( 739 | bsz * frames, channel, width, height 740 | ) 741 | noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape( 742 | bsz * frames, channel, width, height 743 | ) 744 | 745 | # compute the previous noisy sample x_t -> x_t-1 746 | latents = self.scheduler.step( 747 | noise_pred, t, latents, **extra_step_kwargs 748 | ).prev_sample 749 | 750 | # reshape latents back 751 | latents = ( 752 | latents[None, :] 753 | .reshape(bsz, frames, channel, width, height) 754 | .permute(0, 2, 1, 3, 4) 755 | ) 756 | 757 | # call the callback, if provided 758 | if i == len(timesteps) - 1 or ( 759 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 760 | ): 761 | if callback is not None and i % callback_steps == 0: 762 | callback(i, t, latents) 763 | video_sub = self.decode_latents(latents) 764 | video = torch.cat([video, video_sub], dim=2) if clip_idx > 0 else video_sub 765 | 766 | # Offload last model to CPU 767 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 768 | self.final_offload_hook.offload() 769 | 770 | if not return_dict: 771 | return (video,) 772 | 773 | return SDVideoPipelineOutput(videos=video) 774 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from math import sqrt 3 | from pathlib import Path 4 | from typing import Union, Tuple, List, Optional 5 | 6 | import imageio 7 | import torch 8 | import torchvision 9 | from einops import rearrange 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def randn_base( 15 | shape: Union[Tuple, List], 16 | mean: float = 0.0, 17 | std: float = 1.0, 18 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, 19 | device: Optional["torch.device"] = None, 20 | dtype: Optional["torch.dtype"] = None, 21 | ): 22 | if isinstance(generator, list): 23 | shape = (1,) + shape[1:] 24 | tensor = [ 25 | torch.normal( 26 | mean=mean, 27 | std=std, 28 | size=shape, 29 | generator=generator[i], 30 | device=device, 31 | dtype=dtype, 32 | ) 33 | for i in range(len(generator)) 34 | ] 35 | tensor = torch.cat(tensor, dim=0).to(device) 36 | else: 37 | tensor = torch.normal( 38 | mean=mean, 39 | std=std, 40 | size=shape, 41 | generator=generator, 42 | device=device, 43 | dtype=dtype, 44 | ) 45 | return tensor 46 | 47 | 48 | def randn_mixed( 49 | shape: Union[Tuple, List], 50 | dim: int, 51 | alpha: float = 0.0, 52 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, 53 | device: Optional["torch.device"] = None, 54 | dtype: Optional["torch.dtype"] 
= None, 55 | ): 56 | """Refer to Section 4 of Preserve Your Own Correlation: 57 | [A Noise Prior for Video Diffusion Models](https://arxiv.org/abs/2305.10474) 58 | """ 59 | shape_shared = shape[:dim] + (1,) + shape[dim + 1 :] 60 | 61 | # shared random tensor 62 | shared_std = alpha**2 / (1.0 + alpha**2) 63 | shared_tensor = randn_base( 64 | shape=shape_shared, 65 | mean=0.0, 66 | std=shared_std, 67 | generator=generator, 68 | device=device, 69 | dtype=dtype, 70 | ) 71 | 72 | # individual random tensor 73 | indv_std = 1.0 / (1.0 + alpha**2) 74 | indv_tensor = randn_base( 75 | shape=shape, 76 | mean=0.0, 77 | std=indv_std, 78 | generator=generator, 79 | device=device, 80 | dtype=dtype, 81 | ) 82 | 83 | return shared_tensor + indv_tensor 84 | 85 | 86 | def randn_progressive( 87 | shape: Union[Tuple, List], 88 | dim: int, 89 | alpha: float = 0.0, 90 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, 91 | device: Optional["torch.device"] = None, 92 | dtype: Optional["torch.dtype"] = None, 93 | ): 94 | """Refer to Section 4 of Preserve Your Own Correlation: 95 | [A Noise Prior for Video Diffusion Models](https://arxiv.org/abs/2305.10474) 96 | """ 97 | num_prog = shape[dim] 98 | shape_slice = shape[:dim] + (1,) + shape[dim + 1 :] 99 | tensors = [ 100 | randn_base( 101 | shape=shape_slice, 102 | mean=0.0, 103 | std=1.0, 104 | generator=generator, 105 | device=device, 106 | dtype=dtype, 107 | ) 108 | ] 109 | beta = alpha / sqrt(1.0 + alpha**2) 110 | std = 1.0 / (1.0 + alpha**2) 111 | for i in range(1, num_prog): 112 | tensor_i = beta * tensors[-1] + randn_base( 113 | shape=shape_slice, 114 | mean=0.0, 115 | std=std, 116 | generator=generator, 117 | device=device, 118 | dtype=dtype, 119 | ) 120 | tensors.append(tensor_i) 121 | tensors = torch.cat(tensors, dim=dim) 122 | return tensors 123 | 124 | 125 | def save_videos_grid(videos, path, rescale=False, n_rows=4, fps=4): 126 | if videos.dim() == 4: 127 | videos = videos.unsqueeze(0) 128 | videos = rearrange(videos, "b c t h w -> t b c h w") 129 | outputs = [] 130 | for x in videos: 131 | x = torchvision.utils.make_grid(x, nrow=n_rows) 132 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 133 | if rescale: 134 | x = (x + 1.0) / 2.0 # [-1, 1) -> [0, 1) 135 | x = (x * 255).to(dtype=torch.uint8, device="cpu") 136 | outputs.append(x) 137 | Path(path).parent.mkdir(parents=True, exist_ok=True) 138 | imageio.mimwrite(Path(path).as_posix(), outputs, duration=1000 / fps, loop=0) 139 | 140 | 141 | @torch.no_grad() 142 | def compute_clip_score( 143 | model, model_processor, images, texts, local_bs=32, rescale=False 144 | ): 145 | if rescale: 146 | images = (images + 1.0) / 2.0 # -1,1 -> 0,1 147 | images = (images * 255).to(torch.uint8) 148 | clip_scores = [] 149 | for start_idx in range(0, images.shape[0], local_bs): 150 | img_batch = images[start_idx : start_idx + local_bs] 151 | batch_size = img_batch.shape[0] # shape: [b c t h w] 152 | img_batch = rearrange(img_batch, "b c t h w -> (b t) c h w") 153 | outputs = [] 154 | for i in range(len(img_batch)): 155 | images_part = img_batch[i : i + 1] 156 | model_inputs = model_processor( 157 | text=texts, images=list(images_part), return_tensors="pt", padding=True 158 | ) 159 | model_inputs = { 160 | k: v.to(device=model.device, dtype=model.dtype) 161 | if k in ["pixel_values"] 162 | else v.to(device=model.device) 163 | for k, v in model_inputs.items() 164 | } 165 | logits = model(**model_inputs)["logits_per_image"] 166 | # For consistency with 
`torchmetrics.functional.multimodal.clip_score`. 167 | logits = logits / model.logit_scale.exp() 168 | outputs.append(logits) 169 | logits = torch.cat(outputs) 170 | logits = rearrange(logits, "(b t) p -> t b p", b=batch_size) 171 | frame_sims = [] 172 | for logit in logits: 173 | frame_sims.append(logit.diagonal()) 174 | frame_sims = torch.stack(frame_sims) # [t, b] 175 | clip_scores.append(frame_sims.mean(dim=0)) 176 | return torch.cat(clip_scores) 177 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 2 | Pillow 3 | decord 4 | diffusers[torch]==0.14.0 5 | einops 6 | ftfy 7 | hydra-core 8 | imageio 9 | omegaconf 10 | pandas 11 | pytorch-lightning==1.9.5 12 | scipy 13 | tensorboard 14 | torch 15 | torchvision 16 | transformers==4.30.2 17 | -------------------------------------------------------------------------------- /vae/__init__.py: -------------------------------------------------------------------------------- 1 | from .vae import TemporalAutoencoderKL 2 | -------------------------------------------------------------------------------- /vae/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from diffusers.models.unet_2d_blocks import ResnetBlock2D, Upsample2D 3 | 4 | from model.modules.utils import zero_module 5 | 6 | 7 | class TemporalConvLayer(nn.Module): 8 | """ 9 | Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: 10 | https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 11 | """ 12 | 13 | def __init__(self, in_dim, out_dim=None, num_layers=3, dropout=0.0): 14 | super().__init__() 15 | out_dim = out_dim or in_dim 16 | 17 | # conv layers 18 | convs = [] 19 | prev_dim, next_dim = in_dim, out_dim 20 | for i in range(num_layers): 21 | if i == num_layers - 1: 22 | next_dim = out_dim 23 | convs.extend( 24 | [ 25 | nn.GroupNorm(32, prev_dim), 26 | nn.SiLU(), 27 | nn.Dropout(dropout), 28 | nn.Conv3d(prev_dim, next_dim, (3, 1, 1), padding=(1, 0, 0)), 29 | ] 30 | ) 31 | prev_dim, next_dim = next_dim, prev_dim 32 | self.convs = nn.ModuleList(convs) 33 | 34 | def forward(self, hidden_states): 35 | video_length = hidden_states.shape[2] 36 | identity = hidden_states 37 | 38 | for conv in self.convs: 39 | hidden_states = conv(hidden_states) 40 | 41 | # ignore these convolution layers on image input 42 | hidden_states = identity + hidden_states if video_length > 1 else identity 43 | 44 | return hidden_states 45 | 46 | 47 | class UpDecoderBlock2D(nn.Module): 48 | def __init__( 49 | self, 50 | in_channels: int, 51 | out_channels: int, 52 | dropout: float = 0.0, 53 | num_layers: int = 1, 54 | resnet_eps: float = 1e-6, 55 | resnet_time_scale_shift: str = "default", 56 | resnet_act_fn: str = "swish", 57 | resnet_groups: int = 32, 58 | resnet_pre_norm: bool = True, 59 | output_scale_factor=1.0, 60 | add_upsample=True, 61 | add_temp_conv=False, 62 | ): 63 | super().__init__() 64 | resnets = [] 65 | if add_temp_conv: 66 | self.temp_convs = None 67 | temp_convs = [] 68 | 69 | for i in range(num_layers): 70 | input_channels = in_channels if i == 0 else out_channels 71 | 72 | resnets.append( 73 | ResnetBlock2D( 74 | in_channels=input_channels, 75 | out_channels=out_channels, 76 | temb_channels=None, 77 | eps=resnet_eps, 78 | groups=resnet_groups, 79 
| dropout=dropout, 80 | time_embedding_norm=resnet_time_scale_shift, 81 | non_linearity=resnet_act_fn, 82 | output_scale_factor=output_scale_factor, 83 | pre_norm=resnet_pre_norm, 84 | ) 85 | ) 86 | if add_temp_conv: 87 | temp_convs.append( 88 | TemporalConvLayer(out_channels, out_channels, dropout=0.1) 89 | ) 90 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1]) 91 | 92 | self.resnets = nn.ModuleList(resnets) 93 | if add_temp_conv: 94 | self.temp_convs = nn.ModuleList(temp_convs) 95 | 96 | if add_upsample: 97 | self.upsamplers = nn.ModuleList( 98 | [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] 99 | ) 100 | else: 101 | self.upsamplers = None 102 | 103 | def forward(self, hidden_states, num_frames=1): 104 | for layer_idx in range(len(self.resnets)): 105 | hidden_states = self.resnets[layer_idx](hidden_states, temb=None) 106 | if hasattr(self, "temp_convs"): 107 | hidden_states = hidden_states.reshape( 108 | hidden_states.shape[0] // num_frames, 109 | num_frames, 110 | *hidden_states.shape[1:], 111 | ) 112 | hidden_states = hidden_states.swapaxes(1, 2) 113 | hidden_states = self.temp_convs[layer_idx](hidden_states) 114 | hidden_states = hidden_states.swapaxes(1, 2) 115 | hidden_states = hidden_states.reshape( 116 | hidden_states.shape[0] * num_frames, *hidden_states.shape[2:] 117 | ) 118 | 119 | if self.upsamplers is not None: 120 | for upsampler in self.upsamplers: 121 | hidden_states = upsampler(hidden_states) 122 | 123 | return hidden_states 124 | 125 | 126 | def get_up_block( 127 | up_block_type, 128 | num_layers, 129 | in_channels, 130 | out_channels, 131 | prev_output_channel, 132 | temb_channels, 133 | add_upsample, 134 | resnet_eps, 135 | resnet_act_fn, 136 | attn_num_head_channels, 137 | resnet_groups=None, 138 | cross_attention_dim=None, 139 | dual_cross_attention=False, 140 | use_linear_projection=False, 141 | only_cross_attention=False, 142 | upcast_attention=False, 143 | resnet_time_scale_shift="default", 144 | add_temp_conv=False, 145 | ): 146 | up_block_type = ( 147 | up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type 148 | ) 149 | if up_block_type == "UpDecoderBlock2D": 150 | return UpDecoderBlock2D( 151 | num_layers=num_layers, 152 | in_channels=in_channels, 153 | out_channels=out_channels, 154 | add_upsample=add_upsample, 155 | resnet_eps=resnet_eps, 156 | resnet_act_fn=resnet_act_fn, 157 | resnet_groups=resnet_groups, 158 | resnet_time_scale_shift=resnet_time_scale_shift, 159 | add_temp_conv=add_temp_conv, 160 | ) 161 | 162 | raise ValueError( 163 | f"{up_block_type} does not exist. 
Please refer to: `from diffusers.models.vae import get_up_block'" 164 | ) 165 | -------------------------------------------------------------------------------- /vae/vae.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Union, Optional, Tuple 4 | 5 | import torch 6 | from diffusers.configuration_utils import register_to_config 7 | from diffusers.models import AutoencoderKL 8 | from diffusers.models.autoencoder_kl import AutoencoderKLOutput 9 | from diffusers.models.unet_2d_blocks import DownEncoderBlock2D, UpDecoderBlock2D 10 | from diffusers.models.vae import DecoderOutput, UNetMidBlock2D 11 | from torch import nn 12 | from diffusers.utils import apply_forward_hook 13 | from .modules import get_up_block 14 | 15 | 16 | class TemporalDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | in_channels=3, 20 | out_channels=3, 21 | up_block_types=("UpDecoderBlock2D",), 22 | block_out_channels=(64,), 23 | layers_per_block=2, 24 | norm_num_groups=32, 25 | act_fn="silu", 26 | add_temp_conv=False, 27 | ): 28 | super().__init__() 29 | self.layers_per_block = layers_per_block 30 | 31 | self.conv_in = nn.Conv2d( 32 | in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1 33 | ) 34 | 35 | self.mid_block = None 36 | self.up_blocks = nn.ModuleList([]) 37 | 38 | # mid 39 | self.mid_block = UNetMidBlock2D( 40 | in_channels=block_out_channels[-1], 41 | resnet_eps=1e-6, 42 | resnet_act_fn=act_fn, 43 | output_scale_factor=1, 44 | resnet_time_scale_shift="default", 45 | attn_num_head_channels=None, 46 | resnet_groups=norm_num_groups, 47 | temb_channels=None, 48 | ) 49 | 50 | # up 51 | reversed_block_out_channels = list(reversed(block_out_channels)) 52 | output_channel = reversed_block_out_channels[0] 53 | for i, up_block_type in enumerate(up_block_types): 54 | prev_output_channel = output_channel 55 | output_channel = reversed_block_out_channels[i] 56 | 57 | is_final_block = i == len(block_out_channels) - 1 58 | 59 | up_block = get_up_block( 60 | up_block_type, 61 | num_layers=self.layers_per_block + 1, 62 | in_channels=prev_output_channel, 63 | out_channels=output_channel, 64 | prev_output_channel=None, 65 | add_upsample=not is_final_block, 66 | resnet_eps=1e-6, 67 | resnet_act_fn=act_fn, 68 | resnet_groups=norm_num_groups, 69 | attn_num_head_channels=None, 70 | temb_channels=None, 71 | add_temp_conv=add_temp_conv, 72 | ) 73 | self.up_blocks.append(up_block) 74 | 75 | # out 76 | self.conv_norm_out = nn.GroupNorm( 77 | num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 78 | ) 79 | self.conv_act = nn.SiLU() 80 | self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) 81 | 82 | def forward(self, z, num_frames=1): 83 | sample = z 84 | sample = self.conv_in(sample) 85 | 86 | # middle 87 | sample = self.mid_block(sample) 88 | 89 | # up 90 | for up_block in self.up_blocks: 91 | sample = up_block(sample, num_frames) 92 | 93 | # post-process 94 | sample = self.conv_norm_out(sample) 95 | sample = self.conv_act(sample) 96 | sample = self.conv_out(sample) 97 | 98 | return sample 99 | 100 | 101 | class TemporalAutoencoderKL(AutoencoderKL): 102 | r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma 103 | and Max Welling. 104 | 105 | This model inherits from [`ModelMixin`]. 
Check the superclass documentation for the generic methods the library 106 | implements for all the model (such as downloading or saving, etc.) 107 | 108 | Parameters: 109 | in_channels (int, *optional*, defaults to 3): Number of channels in the input image. 110 | out_channels (int, *optional*, defaults to 3): Number of channels in the output. 111 | down_block_types (`Tuple[str]`, *optional*, defaults to : 112 | obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. 113 | up_block_types (`Tuple[str]`, *optional*, defaults to : 114 | obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. 115 | block_out_channels (`Tuple[int]`, *optional*, defaults to : 116 | obj:`(64,)`): Tuple of block output channels. 117 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 118 | latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. 119 | sample_size (`int`, *optional*, defaults to `32`): TODO 120 | scaling_factor (`float`, *optional*, defaults to 0.18215): 121 | The component-wise standard deviation of the trained latent space computed using the first batch of the 122 | training set. This is used to scale the latent space to have unit variance when training the diffusion 123 | model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the 124 | diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 125 | / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image 126 | Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. 127 | """ 128 | _supports_gradient_checkpointing = True 129 | 130 | @register_to_config 131 | def __init__( 132 | self, 133 | in_channels: int = 3, 134 | out_channels: int = 3, 135 | down_block_types: Tuple[str] = ("DownEncoderBlock2D",), 136 | up_block_types: Tuple[str] = ("UpDecoderBlock2D",), 137 | block_out_channels: Tuple[int] = (64,), 138 | layers_per_block: int = 1, 139 | act_fn: str = "silu", 140 | latent_channels: int = 4, 141 | norm_num_groups: int = 32, 142 | sample_size: int = 32, 143 | scaling_factor: float = 0.18215, 144 | add_temp_conv: bool = False, 145 | ): 146 | super().__init__( 147 | in_channels=in_channels, 148 | out_channels=out_channels, 149 | down_block_types=down_block_types, 150 | up_block_types=up_block_types, 151 | block_out_channels=block_out_channels, 152 | layers_per_block=layers_per_block, 153 | act_fn=act_fn, 154 | latent_channels=latent_channels, 155 | norm_num_groups=norm_num_groups, 156 | sample_size=sample_size, 157 | scaling_factor=scaling_factor, 158 | ) 159 | 160 | # pass init params to Decoder 161 | self.decoder = TemporalDecoder( 162 | in_channels=latent_channels, 163 | out_channels=out_channels, 164 | up_block_types=up_block_types, 165 | block_out_channels=block_out_channels, 166 | layers_per_block=layers_per_block, 167 | norm_num_groups=norm_num_groups, 168 | act_fn=act_fn, 169 | add_temp_conv=add_temp_conv, 170 | ) 171 | 172 | def get_last_layer(self): 173 | if ( 174 | hasattr(self.decoder.up_blocks[-1], "temp_convs") 175 | and not self.decoder.conv_out.weight.requires_grad 176 | ): 177 | return self.decoder.up_blocks[-1].temp_convs[-1].convs[-1].weight 178 | else: 179 | return self.decoder.conv_out.weight 180 | 181 | def _set_gradient_checkpointing(self, module, value=False): 182 | if isinstance(module, (DownEncoderBlock2D, UpDecoderBlock2D)): 183 | module.gradient_checkpointing = value 184 | 185 | def 
tiled_decode( 186 | self, z: torch.FloatTensor, num_frames: int = 1, return_dict: bool = True 187 | ) -> Union[DecoderOutput, torch.FloatTensor]: 188 | r"""Decode a batch of images using a tiled decoder. 189 | Args: 190 | When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several 191 | steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is: 192 | different from non-tiled decoding due to each tile using a different decoder. To avoid tiling artifacts, the 193 | tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the 194 | look of the output, but they should be much less noticeable. 195 | z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to 196 | `True`): 197 | Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 198 | """ 199 | overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) 200 | blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) 201 | row_limit = self.tile_sample_min_size - blend_extent 202 | 203 | # Split z into overlapping 64x64 tiles and decode them separately. 204 | # The tiles have an overlap to avoid seams between tiles. 205 | rows = [] 206 | for i in range(0, z.shape[2], overlap_size): 207 | row = [] 208 | for j in range(0, z.shape[3], overlap_size): 209 | tile = z[ 210 | :, 211 | :, 212 | i : i + self.tile_latent_min_size, 213 | j : j + self.tile_latent_min_size, 214 | ] 215 | tile = self.post_quant_conv(tile) 216 | decoded = self.decoder(tile, num_frames) 217 | row.append(decoded) 218 | rows.append(row) 219 | result_rows = [] 220 | for i, row in enumerate(rows): 221 | result_row = [] 222 | for j, tile in enumerate(row): 223 | # blend the above tile and the left tile 224 | # to the current tile and add the current tile to the result row 225 | if i > 0: 226 | tile = self.blend_v(rows[i - 1][j], tile, blend_extent) 227 | if j > 0: 228 | tile = self.blend_h(row[j - 1], tile, blend_extent) 229 | result_row.append(tile[:, :, :row_limit, :row_limit]) 230 | result_rows.append(torch.cat(result_row, dim=3)) 231 | 232 | dec = torch.cat(result_rows, dim=2) 233 | if not return_dict: 234 | return (dec,) 235 | 236 | return DecoderOutput(sample=dec) 237 | 238 | def _decode( 239 | self, z: torch.FloatTensor, num_frames: int = 1, return_dict: bool = True 240 | ) -> Union[DecoderOutput, torch.FloatTensor]: 241 | if self.use_tiling and ( 242 | z.shape[-1] > self.tile_latent_min_size 243 | or z.shape[-2] > self.tile_latent_min_size 244 | ): 245 | return self.tiled_decode(z, num_frames=num_frames, return_dict=return_dict) 246 | 247 | z = self.post_quant_conv(z) 248 | dec = self.decoder(z, num_frames=num_frames) 249 | 250 | if not return_dict: 251 | return (dec,) 252 | 253 | return DecoderOutput(sample=dec) 254 | 255 | @apply_forward_hook 256 | def decode( 257 | self, z: torch.FloatTensor, num_frames: int = 1, return_dict: bool = True 258 | ) -> Union[DecoderOutput, torch.FloatTensor]: 259 | if self.use_slicing and z.shape[0] > 1: 260 | decoded_slices = [ 261 | self._decode(z_slice, num_frames=num_frames).sample 262 | for z_slice in z.split(1) 263 | ] 264 | decoded = torch.cat(decoded_slices) 265 | else: 266 | decoded = self._decode(z, num_frames=num_frames).sample 267 | 268 | if not return_dict: 269 | return (decoded,) 270 | 271 | return DecoderOutput(sample=decoded) 272 | 273 | def forward( 274 | self, 275 | sample: 
torch.FloatTensor, 276 | num_frames: int, 277 | sample_posterior: bool = False, 278 | return_dict: bool = True, 279 | generator: Optional[torch.Generator] = None, 280 | ) -> ( 281 | Union[DecoderOutput, torch.FloatTensor], 282 | Union[AutoencoderKLOutput, torch.FloatTensor], 283 | ): 284 | r""" 285 | Args: 286 | sample (`torch.FloatTensor`): Input sample. 287 | sample_posterior (`bool`, *optional*, defaults to `False`): 288 | Whether to sample from the posterior. 289 | return_dict (`bool`, *optional*, defaults to `True`): 290 | Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 291 | """ 292 | x = sample 293 | posterior = self.encode(x).latent_dist 294 | if sample_posterior: 295 | z = posterior.sample(generator=generator) 296 | else: 297 | z = posterior.mode() 298 | dec = self.decode(z, num_frames=num_frames).sample 299 | 300 | if not return_dict: 301 | return dec, posterior 302 | 303 | return DecoderOutput(sample=dec) 304 | 305 | @classmethod 306 | def from_pretrained(cls, pretrained_model_path, subfolder=None, **kwargs): 307 | if subfolder is not None: 308 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder) 309 | 310 | config_file = os.path.join(pretrained_model_path, "config.json") 311 | if not os.path.isfile(config_file): 312 | raise RuntimeError(f"{config_file} does not exist") 313 | with open(config_file, "r") as f: 314 | config = json.load(f) 315 | config["_class_name"] = cls.__name__ 316 | 317 | from diffusers.utils import WEIGHTS_NAME 318 | 319 | model = cls.from_config(config, **kwargs) 320 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) 321 | if not os.path.isfile(model_file): 322 | raise RuntimeError(f"{model_file} does not exist") 323 | state_dict = torch.load(model_file, map_location="cpu") 324 | for k, v in model.state_dict().items(): 325 | if "temp_" in k and k not in state_dict: 326 | state_dict.update({k: v}) 327 | model.load_state_dict(state_dict) 328 | 329 | return model 330 | --------------------------------------------------------------------------------
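A note on the initial latents: `prepare_latents` in `model/pipeline.py` draws them with `randn_progressive` along the frame axis, the progressive variant of the noise prior referenced in `model/utils.py`. Below is a minimal usage sketch; the latent shape and the `alpha` value are illustrative, not taken from the repository's configs.

```python
import torch

from model.utils import randn_progressive

# Latents for one video clip: (batch, latent_channels, frames, h, w).
shape = (1, 4, 8, 32, 32)

# dim=2 correlates noise along the frame axis. alpha=0.0 recovers i.i.d.
# Gaussian noise; larger alpha makes consecutive frames share more noise.
noise = randn_progressive(
    shape,
    dim=2,
    alpha=2.0,
    generator=torch.Generator().manual_seed(0),
)
print(noise.shape)  # torch.Size([1, 4, 8, 32, 32])
```

In the pipeline, `alpha` is exposed as the `noise_alpha` argument, which defaults to 0.0, so plain Gaussian noise is used unless it is overridden.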
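The clip-extension loop in `SDVideoPipeline.__call__` reuses the previous clip's initial noise on the frames that follow the overlap, mixing it with fresh noise so the result stays unit-variance. A standalone sketch of that mixing step, assuming unit-variance inputs (`recycle_noise` and its argument names are illustrative, not part of the pipeline):

```python
from math import sqrt

import torch


def recycle_noise(prev_noise: torch.Tensor, k: float, generator=None) -> torch.Tensor:
    """Mix recycled and fresh noise: k/sqrt(1+k^2) * prev + N(0, 1/(1+k^2)).

    With unit-variance inputs the mixture keeps unit variance, since
    k^2/(1+k^2) + 1/(1+k^2) = 1. Here k plays the role of `extend_noise_recycle`;
    k=0 discards the previous clip's noise entirely.
    """
    fresh = torch.randn(
        prev_noise.shape,
        generator=generator,
        device=prev_noise.device,
        dtype=prev_noise.dtype,
    ) / sqrt(1 + k**2)
    return (k / sqrt(1 + k**2)) * prev_noise + fresh
```

In the pipeline this mixing is applied only to the non-overlapping frames of each new clip; the first `extend_overlap_frames` frames are taken directly from the previous clip's stored, frame-reversed latents.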
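Finally, `TemporalAutoencoderKL.from_pretrained` tolerates checkpoints that lack the newly added `temp_` parameters by filling them from the freshly initialized model. A minimal decoding sketch, assuming the checkpoint directory (the path below is illustrative) contains `config.json` and `diffusion_pytorch_model.bin`:

```python
import torch

from vae import TemporalAutoencoderKL

# Load the fine-tuned temporal VAE from a local directory.
vae = TemporalAutoencoderKL.from_pretrained("assets/vae_finetuned/")
vae.eval()

num_frames = 8
# The decoder consumes frame latents flattened into the batch dimension,
# i.e. (batch * num_frames, latent_channels, h, w); num_frames tells the
# temporal conv layers how to regroup frames belonging to the same video.
latents = torch.randn(num_frames, vae.config.latent_channels, 32, 32)
with torch.no_grad():
    frames = vae.decode(latents, num_frames=num_frames).sample
# frames: (batch * num_frames, 3, H, W) reconstructed RGB frames.
```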