├── .DS_Store
├── README.md
├── animatediff
│   ├── .DS_Store
│   ├── data
│   │   ├── __pycache__
│   │   │   └── dataset.cpython-310.pyc
│   │   ├── dataset.py
│   │   └── processor.py
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── attention.cpython-310.pyc
│   │   │   ├── controlnet.cpython-310.pyc
│   │   │   ├── motion_module.cpython-310.pyc
│   │   │   ├── resnet.cpython-310.pyc
│   │   │   ├── unet.cpython-310.pyc
│   │   │   └── unet_blocks.cpython-310.pyc
│   │   ├── attention.py
│   │   ├── controlnet.py
│   │   ├── motion_module.py
│   │   ├── resnet.py
│   │   ├── unet.py
│   │   └── unet_blocks.py
│   ├── pipelines
│   │   ├── __pycache__
│   │   │   └── pipeline_animation.cpython-310.pyc
│   │   └── pipeline_animation.py
│   └── utils
│       ├── __pycache__
│       │   ├── convert_from_ckpt.cpython-310.pyc
│       │   ├── convert_lora_safetensor_to_diffusers.cpython-310.pyc
│       │   └── util.cpython-310.pyc
│       ├── convert_from_ckpt.py
│       ├── convert_lora_safetensor_to_diffusers.py
│       └── util.py
├── animatetest.py
├── configs
│   ├── inference
│   │   ├── inference-v1.yaml
│   │   └── inference-v2.yaml
│   ├── prompts
│   │   └── v2
│   │       └── 5-RealisticVision.yaml
│   └── training
│       ├── image_finetune.yaml
│       └── training.yaml
├── download_data.py
├── imgs
│   ├── .DS_Store
│   ├── 0.gif
│   ├── 1.gif
│   ├── 2.gif
│   ├── 3.gif
│   ├── 4.gif
│   └── 5.gif
├── init_images
│   ├── .DS_Store
│   ├── 0.jpg
│   ├── 1.jpg
│   ├── 2.jpg
│   ├── 3.jpg
│   ├── 4.jpg
│   └── 5.jpg
├── newanimate.yaml
├── requirements.txt
├── scripts
│   ├── __pycache__
│   │   └── animate.cpython-310.pyc
│   └── animate.py
└── train.py

/README.md:
--------------------------------------------------------------------------------
# Animatediff with controlnet
### Description: Add a ControlNet to AnimateDiff to animate a given image.

[Animatediff](https://github.com/guoyww/AnimateDiff) is a recent animation project based on Stable Diffusion (SD) that produces excellent results. This repository extends Animatediff in two ways:

1. Animating a specific image: starting from a given image, a ControlNet preserves the image's appearance while it is being animated.

2. Upgrading the diffusers version: the previous code used diffusers 0.11.1, and this version uses diffusers 0.21.4. This makes more diffusers features, such as ControlNet, available to Animatediff.

#### TODO:

- [x] Release the train and inference code
- [x] Release the controlnet [checkpoint](https://huggingface.co/crishhh/animatediff_controlnet)
- [ ] Reduce the GPU memory usage of controlnet in the code
- [ ] Others

#### How to start (inference)

1. Prepare the environment

```bash
conda env create -f newanimate.yaml
# Or
conda create --name newanimate python=3.10
conda activate newanimate
pip install -r requirements.txt
```

2. Download the models as described in [AnimateDiff](https://github.com/guoyww/AnimateDiff) and put them in ./models. Download the ControlNet [checkpoint](https://huggingface.co/crishhh/animatediff_controlnet) and put it in ./checkpoints.

3. Prepare the prompts and the initial image

Note that the prompts matter a lot for the animation. Here I use [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4) to write them, with the following instruction: "Please output the perfect description prompt of this picture into the StableDiffusion model, and separate the description into multiple keywords with commas"

4. Modify the YAML file (location: ./configs/prompts/v2/5-RealisticVision.yaml)

5. Run the demo

```bash
python animatetest.py
```

#### How to train

1. Download the dataset (WebVid-10M)

```bash
python download_data.py
```

2. Run training

```bash
python train.py
```

#### Limitations

1. The current ControlNet has been trained on a subset of WebVid-10M of approximately 5,000 video-caption pairs, so its performance is still limited. Training on larger datasets is underway.
2. Some images are hard to animate well, even with matching prompts. The same difficulty occurs with plain Animatediff, without ControlNet.
3. Results are better when the prompt closely matches the content of the image.

#### Future

1. The ControlNet currently operates at the 2D level; we plan to extend it to 3D by incorporating the motion module into the ControlNet.
2. We aim to add a trajectory encoder to the ControlNet branch to control the motion module. Although this may conflict with the existing motion module, we still want to try it.

#### Some Failed Attempts (Possibly Due to Missteps):

1. Injecting the VAE-encoded image into the initial latents does not seem to work: it generates videos with a similar style but an inconsistent appearance.
2. Performing DDIM inversion on the image to obtain the initial noise and then denoising it, an approach borrowed from common image-editing methods, did not give good results in our experiments.

The code in this repository is intended solely as an experimental demo. If you have any feedback or questions, please feel free to open an issue or contact me by email at crystallee0418@gmail.com.
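
For reference, one common way to implement the first failed attempt is to noise the image's VAE latent with the forward diffusion process, `x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise`, repeat it over all frames, and start denoising from `x_t`. The NumPy sketch below is a hypothetical illustration, not the code used in this repository: the function names, the latent shape, and the Stable Diffusion "scaled linear" beta schedule (0.00085 to 0.012 over 1000 steps) are assumptions. Printing `sqrt(a_bar_t)` for a few timesteps shows how much of the image signal remains in the starting latents.

```python
import numpy as np


def scaled_linear_alphas_cumprod(num_steps=1000, beta_start=0.00085, beta_end=0.012):
    # Assumed schedule: betas linear in sqrt space (Stable Diffusion's "scaled_linear" setting).
    betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_steps) ** 2
    return np.cumprod(1.0 - betas)


def noised_initial_latents(image_latent, num_frames, alphas_cumprod, t, rng=None):
    """Hypothetical helper: sample x_t ~ q(x_t | x_0) with x_0 = the image latent on every frame.

    image_latent: array of shape (C, H, W), standing in for the VAE latent of the input image.
    Returns an array of shape (C, F, H, W) with independent noise on each frame.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Repeat the same latent along a new frame axis: (C, H, W) -> (C, F, H, W).
    x0 = np.repeat(image_latent[:, None, :, :], num_frames, axis=1)
    noise = rng.standard_normal(x0.shape)
    a_bar = alphas_cumprod[t]
    return np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * noise


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    alphas_cumprod = scaled_linear_alphas_cumprod()
    latent = rng.standard_normal((4, 8, 8))  # stand-in for a real VAE latent
    for t in (0, 500, 999):
        x_t = noised_initial_latents(latent, 16, alphas_cumprod, t, rng)
        # Scale factor applied to the image latent at this timestep.
        print(t, x_t.shape, round(float(np.sqrt(alphas_cumprod[t])), 3))
```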

The code in this repository is derived from Animatediff and Diffusers.

--------------------------------------------------------------------------------
/animatediff/data/dataset.py:
--------------------------------------------------------------------------------
import argparse
import os, io, csv, math, random
import numpy as np
from einops import rearrange
from decord import VideoReader


import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from animatediff.utils.util import zero_rank_print


class WebVid10M(Dataset):
    def __init__(
            self,
            csv_path, video_folder, opticalflow_folder,
            sample_size=256, sample_stride=4, sample_n_frames=16,
            is_image=False,
        ):
        zero_rank_print(f"loading annotations from {csv_path} ...")
        with open(csv_path, 'r') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        zero_rank_print(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.opticalflow_folder = opticalflow_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.is_image = is_image

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']

        video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
        video_reader = VideoReader(video_dir)
        video_length = len(video_reader)

        if not self.is_image:
            # Sample sample_n_frames indices spread over a clip of at most (n - 1) * stride + 1 frames.
            clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
        else:
            batch_index = [random.randint(0, video_length - 1)]

        pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        del video_reader

        if self.is_image:
            pixel_values = pixel_values[0]

        return pixel_values, name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Retry with a random index when a video cannot be read or decoded.
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break
            except Exception:
                idx = random.randint(0, self.length - 1)

        pixel_values = self.pixel_transforms(pixel_values)  # [16, 3, 256, 256] for a video, [3, 256, 256] for an image
        # The conditioning image is the first transformed frame (the sample itself in image mode).
        image = pixel_values if self.is_image else pixel_values[0]
        sample = dict(pixel_values=pixel_values, image=image, text=name)
        return sample


if __name__ == "__main__":
    from animatediff.utils.util import save_videos_grid

    dataset = WebVid10M(
        csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_train.csv",
        video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
        opticalflow_folder=None,  # required argument; it is only stored, not read by get_batch
        sample_size=256,
        sample_stride=4, sample_n_frames=16,
        is_image=True,
    )

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16,)
    for idx, batch in enumerate(dataloader):
        print(batch["pixel_values"].shape, len(batch["text"]))
        # for i in range(batch["pixel_values"].shape[0]):
        #     save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)

--------------------------------------------------------------------------------
/animatediff/data/processor.py:
--------------------------------------------------------------------------------
import sys

from tqdm import tqdm
# RAFT's ./core directory must be on the path for the imports below.
sys.path.append('./core')
from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder

import argparse
import os, io, csv, math, random
import numpy as np
from einops import rearrange
from decord import VideoReader
import torch
import torchvision.transforms as transforms
from decord._ffi.base import DECORDError


def extract_optical_flow(csv_path, output_dir, video_folder, sample_stride, sample_n_frames, sample_size):
    # RAFT reads args.small and args.mixed_precision (and args.alternate_corr), so an empty argparse
    # namespace raises AttributeError. These values match the full-size raft-things checkpoint.
    args = argparse.Namespace(small=False, mixed_precision=False, alternate_corr=False)
    model = torch.nn.DataParallel(RAFT(args))
    state_dict = torch.load("/root/lh/RAFT-master/models/raft-things.pth")
    model.load_state_dict(state_dict)

    model = model.module
    model.to("cuda")
    model.eval()

    with open(csv_path, 'r') as csvfile:
        dataset = list(csv.DictReader(csvfile))
    length = len(dataset)
    os.makedirs(output_dir, exist_ok=True)

    pixel_transforms = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Resize((sample_size, sample_size)),
        transforms.CenterCrop(sample_size),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])

    with tqdm(total=length) as pbar:
        pbar.set_description("Steps")
        for idx in range(length):
            video_dict = dataset[idx]
            videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
            output_path = os.path.join(output_dir, f"{videoid}.npy")
            video_dir = os.path.join(video_folder, f"{videoid}.mp4")
            if os.path.exists(output_path):
                print(f"{output_path} already exists, skipping")
                pbar.update(1)
                continue
            try:
                video_reader = VideoReader(video_dir)
            except Exception as e:
                print(f"Error reading video at {video_dir}, error: {e}")
                pbar.update(1)
                continue
            video_length = len(video_reader)

            # Same frame sampling as WebVid10M.get_batch.
            clip_length = min(video_length, (sample_n_frames - 1) * sample_stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, sample_n_frames, dtype=int)

            pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
            pixel_values = pixel_values / 255.
            del video_reader

            pixel_values = pixel_transforms(pixel_values)  # shape [16, 3, 256, 256], values in [-1, 1]
            # RAFT rescales its inputs internally with 2 * (x / 255) - 1, so it expects frames in [0, 255];
            # map the normalized frames back to that range before estimating flow.
            frames = (pixel_values + 1.0) * 127.5
            flow_ls = []
            with torch.no_grad():
                padder = InputPadder(frames[0].shape)
                for j in range(frames.shape[0] - 1):
                    # Flow from frame j to frame j + 1.
                    image1, image2 = padder.pad(frames[j], frames[j + 1])
                    image1 = image1.unsqueeze(0).cuda()
                    image2 = image2.unsqueeze(0).cuda()
                    flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
                    flow = padder.unpad(flow_up).cpu().squeeze(0)  # shape [2, 256, 256]
                    flow_ls.append(flow)
            flow_ls = torch.stack(flow_ls).numpy()  # shape [15, 2, 256, 256]
            np.save(output_path, flow_ls)
            pbar.update(1)


if __name__ == "__main__":
    extract_optical_flow("/root/lh/AnimateDiff-main/results_2M_val.csv",
                         "/root/lh/AnimateDiff-main/dataset_opticalflow",
                         "/root/lh/AnimateDiff-main/datasets", 4, 16, 256)

--------------------------------------------------------------------------------
/animatediff/models/attention.py:
--------------------------------------------------------------------------------
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
# from diffusers.modeling_utils import ModelMixin
from diffusers import ModelMixin

from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
# from diffusers.models.attention import CrossAttention
from diffusers.models.attention_processor import Attention
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear


from einops import rearrange, repeat


@dataclass
class Transformer3DModelOutput(BaseOutput):
    sample: torch.FloatTensor


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class Transformer3DModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,

        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # Define input layers
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        if use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Define transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,

                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        if use_linear_projection:
            # Map back from the transformer width (inner_dim) to the input channels.
            self.proj_out = nn.Linear(inner_dim, in_channels)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
        # Input
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        # Fold frames into the batch so the spatial transformer runs on every frame independently.
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
        if encoder_hidden_states is not None:
            # Give every frame its own copy of the text conditioning.
            encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)

        batch, channel, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        else:
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                # video_length=video_length
            )

        # Output
        if not self.use_linear_projection:
            hidden_states = (
                hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            )
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = (
                hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            )

        output = hidden_states + residual

        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
        if not return_dict:
            return (output,)

        return Transformer3DModelOutput(sample=output)


class AdaLayerNorm(nn.Module):
    """
    Norm layer modified to incorporate timestep embeddings.
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x, timestep):
        emb = self.linear(self.silu(self.emb(timestep)))
        scale, shift = torch.chunk(emb, 2)
        x = self.norm(x) * (1 + scale) + shift
        return x


class GEGLU(nn.Module):

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states, scale: float = 1.0):
        hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)


class GELU(nn.Module):

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate, approximate=self.approximate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states


class ApproximateGELU(nn.Module):
    # Sigmoid approximation of GELU, x * sigmoid(1.702 * x), as in diffusers' ApproximateGELU.

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)


class FeedForward(nn.Module):

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        else:
            raise ValueError(f"Unsupported activation_fn: {activation_fn}")

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states, scale: float = 1.0):
        for module in self.net:
            if isinstance(module, (LoRACompatibleLinear, GEGLU)):
                hidden_states = module(hidden_states, scale)
            else:
                hidden_states = module(hidden_states)
        return hidden_states


class BasicTransformerBlock(nn.Module):

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
        attention_type: str = "default",
        unet_use_cross_frame_attention: bool = False,
        unet_use_temporal_attention: bool = False
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            # AdaLayerNormZero is neither defined nor imported in this module.
            raise NotImplementedError("norm_type='ada_norm_zero' is not supported: AdaLayerNormZero is not available here.")
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        # cross_attention_kwargs is fixed to None here, so lora_scale is always 1.0 and the GLIGEN branch never runs.
        cross_attention_kwargs = None
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        # 1. Retrieve lora scale.
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        # 2. Prepare GLIGEN inputs
        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2.5 GLIGEN Control
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
        # 2.5 ends

        # 3. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [
                    self.ff(hid_slice, scale=lora_scale)
                    for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
                ],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states, scale=lora_scale)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states

--------------------------------------------------------------------------------
/animatediff/models/motion_module.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
import torchvision

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers import ModelMixin
# from diffusers.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention_processor import Attention
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear


from einops import rearrange, repeat
import math


def zero_module(module):
    # Zero out the parameters of a module and return it.
    for p in module.parameters():
        p.detach().zero_()
    return module


@dataclass
class TemporalTransformer3DModelOutput(BaseOutput):
    sample: torch.FloatTensor


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


def get_motion_module(
    in_channels,
    motion_module_kwargs: dict,
    motion_module_type: str = "Vanilla",
):
    if motion_module_type != "Vanilla":
        raise ValueError(f"Unsupported motion_module_type: {motion_module_type}")
    return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs)


class VanillaTemporalModule(nn.Module):
    def __init__(
        self,
        in_channels,
        num_attention_heads=8,
        num_transformer_block=2,
        attention_block_types=("Temporal_Self", "Temporal_Self"),
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
        temporal_attention_dim_div=1,
        zero_initialize=True,
    ):
        super().__init__()

        self.temporal_transformer = TemporalTransformer3DModel(
            in_channels=in_channels,
            num_attention_heads=num_attention_heads,
            attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
            num_layers=num_transformer_block,
            attention_block_types=attention_block_types,
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
        )

        # Zero-initialising the output projection makes the motion module start as an identity (residual) mapping.
        if zero_initialize:
            self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)

    def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
hidden_states = input_tensor 82 | hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask) 83 | 84 | output = hidden_states 85 | return output 86 | 87 | 88 | class TemporalTransformer3DModel(nn.Module): 89 | def __init__( 90 | self, 91 | in_channels, 92 | num_attention_heads, 93 | attention_head_dim, 94 | 95 | num_layers, 96 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ), 97 | dropout = 0.0, 98 | norm_num_groups = 32, 99 | cross_attention_dim = 768, 100 | activation_fn = "geglu", 101 | attention_bias = False, 102 | upcast_attention = False, 103 | 104 | cross_frame_attention_mode = None, 105 | temporal_position_encoding = False, 106 | temporal_position_encoding_max_len = 24, 107 | ): 108 | super().__init__() 109 | 110 | inner_dim = num_attention_heads * attention_head_dim 111 | 112 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 113 | self.proj_in = nn.Linear(in_channels, inner_dim) 114 | 115 | self.transformer_blocks = nn.ModuleList( 116 | [ 117 | TemporalTransformerBlock( 118 | dim=inner_dim, 119 | num_attention_heads=num_attention_heads, 120 | attention_head_dim=attention_head_dim, 121 | attention_block_types=attention_block_types, 122 | dropout=dropout, 123 | norm_num_groups=norm_num_groups, 124 | cross_attention_dim=cross_attention_dim, 125 | activation_fn=activation_fn, 126 | attention_bias=attention_bias, 127 | upcast_attention=upcast_attention, 128 | cross_frame_attention_mode=cross_frame_attention_mode, 129 | temporal_position_encoding=temporal_position_encoding, 130 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 131 | ) 132 | for d in range(num_layers) 133 | ] 134 | ) 135 | self.proj_out = nn.Linear(inner_dim, in_channels) 136 | 137 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): 138 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got 
ndim={hidden_states.dim()}." 139 | video_length = hidden_states.shape[2] 140 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 141 | 142 | batch, channel, height, width = hidden_states.shape 143 | residual = hidden_states 144 | 145 | hidden_states = self.norm(hidden_states) 146 | inner_dim = hidden_states.shape[1] 147 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 148 | hidden_states = self.proj_in(hidden_states) 149 | 150 | # Transformer Blocks 151 | for block in self.transformer_blocks: 152 | hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length) 153 | 154 | # output 155 | hidden_states = self.proj_out(hidden_states) 156 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 157 | 158 | output = hidden_states + residual 159 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 160 | 161 | return output 162 | 163 | 164 | class FeedForward(nn.Module): 165 | 166 | def __init__( 167 | self, 168 | dim: int, 169 | dim_out: Optional[int] = None, 170 | mult: int = 4, 171 | dropout: float = 0.0, 172 | activation_fn: str = "geglu", 173 | final_dropout: bool = False, 174 | ): 175 | super().__init__() 176 | inner_dim = int(dim * mult) 177 | dim_out = dim_out if dim_out is not None else dim 178 | 179 | if activation_fn == "gelu": 180 | act_fn = GELU(dim, inner_dim) 181 | elif activation_fn == "gelu-approximate": 182 | act_fn = GELU(dim, inner_dim, approximate="tanh") 183 | elif activation_fn == "geglu": 184 | act_fn = GEGLU(dim, inner_dim) 185 | elif activation_fn == "geglu-approximate": 186 | act_fn = ApproximateGELU(dim, inner_dim) 187 | 188 | self.net = nn.ModuleList([]) 189 | # project in 190 | self.net.append(act_fn) 191 | # project dropout 192 | self.net.append(nn.Dropout(dropout)) 193 | # project out 194 | self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) 195 | # FF as
used in Vision Transformer, MLP-Mixer, etc. have a final dropout 196 | if final_dropout: 197 | self.net.append(nn.Dropout(dropout)) 198 | 199 | def forward(self, hidden_states, scale: float = 1.0): 200 | for module in self.net: 201 | if isinstance(module, (LoRACompatibleLinear, GEGLU)): 202 | hidden_states = module(hidden_states, scale) 203 | else: 204 | hidden_states = module(hidden_states) 205 | return hidden_states 206 | 207 | 208 | class TemporalTransformerBlock(nn.Module): 209 | def __init__( 210 | self, 211 | dim, 212 | num_attention_heads, 213 | attention_head_dim, 214 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ), 215 | dropout = 0.0, 216 | norm_num_groups = 32, 217 | cross_attention_dim = 768, 218 | activation_fn = "geglu", 219 | attention_bias = False, 220 | upcast_attention = False, 221 | cross_frame_attention_mode = None, 222 | temporal_position_encoding = False, 223 | temporal_position_encoding_max_len = 24, 224 | ): 225 | super().__init__() 226 | 227 | attention_blocks = [] 228 | norms = [] 229 | 230 | for block_name in attention_block_types: 231 | attention_blocks.append( 232 | VersatileAttention( 233 | attention_mode=block_name.split("_")[0], 234 | cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, 235 | 236 | query_dim=dim, 237 | heads=num_attention_heads, 238 | dim_head=attention_head_dim, 239 | dropout=dropout, 240 | bias=attention_bias, 241 | upcast_attention=upcast_attention, 242 | 243 | cross_frame_attention_mode=cross_frame_attention_mode, 244 | temporal_position_encoding=temporal_position_encoding, 245 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 246 | ) 247 | ) 248 | norms.append(nn.LayerNorm(dim)) 249 | 250 | self.attention_blocks = nn.ModuleList(attention_blocks) 251 | self.norms = nn.ModuleList(norms) 252 | 253 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 254 | self.ff_norm = nn.LayerNorm(dim) 255 | 256 | 257 | def forward(self, 
hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 258 | for attention_block, norm in zip(self.attention_blocks, self.norms): 259 | norm_hidden_states = norm(hidden_states) 260 | hidden_states = attention_block( 261 | norm_hidden_states, 262 | encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, 263 | video_length=video_length, 264 | ) + hidden_states 265 | 266 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states 267 | 268 | output = hidden_states 269 | return output 270 | 271 | 272 | class PositionalEncoding(nn.Module): 273 | def __init__( 274 | self, 275 | d_model, 276 | dropout = 0., 277 | max_len = 24 278 | ): 279 | super().__init__() 280 | self.dropout = nn.Dropout(p=dropout) 281 | position = torch.arange(max_len).unsqueeze(1) 282 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 283 | pe = torch.zeros(1, max_len, d_model) 284 | pe[0, :, 0::2] = torch.sin(position * div_term) 285 | pe[0, :, 1::2] = torch.cos(position * div_term) 286 | self.register_buffer('pe', pe) 287 | 288 | def forward(self, x): 289 | x = x + self.pe[:, :x.size(1)] 290 | return self.dropout(x) 291 | 292 | 293 | class VersatileAttention(Attention): 294 | def __init__( 295 | self, 296 | attention_mode = None, 297 | cross_frame_attention_mode = None, 298 | temporal_position_encoding = False, 299 | temporal_position_encoding_max_len = 24, 300 | *args, **kwargs 301 | ): 302 | super().__init__(*args, **kwargs) 303 | assert attention_mode == "Temporal" 304 | 305 | self.attention_mode = attention_mode 306 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None 307 | 308 | self.pos_encoder = PositionalEncoding( 309 | kwargs["query_dim"], 310 | dropout=0., 311 | max_len=temporal_position_encoding_max_len 312 | ) if (temporal_position_encoding and attention_mode == "Temporal") else None 313 | 314 | def extra_repr(self): 315 | return f"(Module Info) Attention_Mode: 
{self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" 316 | 317 | def _attention(self, query, key, value, attention_mask=None): 318 | # if self.upcast_attention: 319 | # query = query.float() 320 | # key = key.float() 321 | 322 | attention_scores = torch.baddbmm( 323 | torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), 324 | query, 325 | key.transpose(-1, -2), 326 | beta=0, 327 | alpha=self.scale, 328 | ) 329 | 330 | if attention_mask is not None: 331 | attention_scores = attention_scores + attention_mask 332 | 333 | # if self.upcast_softmax: 334 | # attention_scores = attention_scores.float() 335 | 336 | attention_probs = attention_scores.softmax(dim=-1) 337 | 338 | # cast back to the original dtype 339 | attention_probs = attention_probs.to(value.dtype) 340 | 341 | # compute attention output 342 | hidden_states = torch.bmm(attention_probs, value) 343 | 344 | # reshape hidden_states 345 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 346 | return hidden_states 347 | 348 | def reshape_batch_dim_to_heads(self, tensor): 349 | batch_size, seq_len, dim = tensor.shape 350 | head_size = self.heads 351 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) 352 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) 353 | return tensor 354 | 355 | def reshape_heads_to_batch_dim(self, tensor): 356 | batch_size, seq_len, dim = tensor.shape # 4096 16 320 357 | head_size = self.heads 358 | # head_size = 8 359 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) 360 | # [4096, 16, 8, 40] 361 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) 362 | return tensor # [32768, 16, 40] 363 | 364 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 365 | batch_size, sequence_length, _ = hidden_states.shape 366 | 367 | if 
self.attention_mode == "Temporal": 368 | d = hidden_states.shape[1] 369 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 370 | 371 | if self.pos_encoder is not None: 372 | hidden_states = self.pos_encoder(hidden_states) 373 | 374 | encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states 375 | else: 376 | raise NotImplementedError 377 | 378 | encoder_hidden_states = encoder_hidden_states 379 | 380 | if self.group_norm is not None: 381 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 382 | 383 | query = self.to_q(hidden_states) 384 | dim = query.shape[-1] # [4096, 16, 320] 385 | query = self.reshape_heads_to_batch_dim(query) 386 | 387 | if self.added_kv_proj_dim is not None: 388 | raise NotImplementedError 389 | 390 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 391 | key = self.to_k(encoder_hidden_states) 392 | value = self.to_v(encoder_hidden_states) 393 | 394 | key = self.reshape_heads_to_batch_dim(key) 395 | value = self.reshape_heads_to_batch_dim(value) 396 | 397 | if attention_mask is not None: 398 | if attention_mask.shape[-1] != query.shape[1]: 399 | target_length = query.shape[1] 400 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 401 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 402 | 403 | # attention, what we cannot get enough of 404 | # if self.set_use_memory_efficient_attention_xformers: 405 | # hidden_states = self.set_use_memory_efficient_attention_xformers(query, key, value) 406 | # # self.set_use_memory_efficient_attention_xformers() 407 | # # Some versions of xformers return output in fp32, cast it back to the dtype of the input 408 | # hidden_states = hidden_states.to(query.dtype) 409 | # else: 410 | # if self._slice_size is None or query.shape[0] // self._slice_size == 1: 411 | 
hidden_states = self._attention(query, key, value, attention_mask) 412 | # else: 413 | # hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) 414 | 415 | # linear proj 416 | hidden_states = self.to_out[0](hidden_states) 417 | 418 | # dropout 419 | hidden_states = self.to_out[1](hidden_states) 420 | 421 | if self.attention_mode == "Temporal": 422 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 423 | 424 | return hidden_states 425 | 426 | class GEGLU(nn.Module): 427 | 428 | def __init__(self, dim_in: int, dim_out: int): 429 | super().__init__() 430 | self.proj = LoRACompatibleLinear(dim_in, dim_out * 2) 431 | 432 | def gelu(self, gate): 433 | if gate.device.type != "mps": 434 | return F.gelu(gate) 435 | # mps: gelu is not implemented for float16 436 | return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) 437 | 438 | def forward(self, hidden_states, scale: float = 1.0): 439 | hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1) 440 | return hidden_states * self.gelu(gate) -------------------------------------------------------------------------------- /animatediff/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from functools import partial 17 | from typing import Optional 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | from einops import rearrange 23 | 24 | from diffusers.models.activations import get_activation 25 | from diffusers.models.attention import AdaGroupNorm 26 | from diffusers.models.attention_processor import SpatialNorm 27 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear 28 | 29 | 30 | class InflatedConv3d(nn.Conv2d): 31 | def forward(self, x): 32 | video_length = x.shape[2] 33 | 34 | x = rearrange(x, "b c f h w -> (b f) c h w") 35 | x = super().forward(x) 36 | x = rearrange(x, "(b f) c h w -> b c f h w", f = video_length) 37 | 38 | return x 39 | 40 | class InflatedGroupNorm(nn.GroupNorm): 41 | def forward(self, x): 42 | video_length = x.shape[2] 43 | 44 | x = rearrange(x, "b c f h w -> (b f) c h w") 45 | x = super().forward(x) 46 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 47 | 48 | return x 49 | 50 | 51 | class Upsample3D(nn.Module): 52 | """A spatial upsampling layer for 5D `(batch, channels, frames, height, width)` video tensors, with an optional convolution. Height and width are doubled; the frame axis is left unchanged. 53 | 54 | Parameters: 55 | channels (`int`): 56 | number of channels in the inputs and outputs. 57 | use_conv (`bool`, default `False`): 58 | option to use a convolution. 59 | use_conv_transpose (`bool`, default `False`): 60 | option to use a convolution transpose. 61 | out_channels (`int`, optional): 62 | number of output channels. Defaults to `channels`.
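
Example (a standalone sketch in plain `torch`, not a call into this class; the helper `upsample_video` is hypothetical). As written, `forward` doubles only height and width through `F.interpolate(..., scale_factor=[1.0, 2.0, 2.0], mode="nearest")`, and the optional 3x3 convolution is an `InflatedConv3d`. That module folds frames into the batch, so a plain `nn.Conv2d` processes every frame independently. The sketch reproduces both steps under that reading of the code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def upsample_video(x: torch.Tensor, conv: nn.Conv2d) -> torch.Tensor:
    # x: (batch, channels, frames, height, width)
    b, c, f, h, w = x.shape
    # scale factor 1 on the frame axis, 2 on height and width
    x = F.interpolate(x, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
    # fold frames into the batch so the 2D conv sees each frame on its own
    x = x.permute(0, 2, 1, 3, 4).reshape(b * f, c, 2 * h, 2 * w)
    x = conv(x)
    # unfold back to (batch, out_channels, frames, 2*height, 2*width)
    return x.reshape(b, f, conv.out_channels, 2 * h, 2 * w).permute(0, 2, 1, 3, 4)


conv = nn.Conv2d(4, 8, kernel_size=3, padding=1)
video = torch.randn(2, 4, 16, 8, 8)
out = upsample_video(video, conv)
print(tuple(out.shape))  # (2, 8, 16, 16, 16)
```

Because each frame is convolved separately, running the sketch on a single frame gives the same values as slicing that frame out of the full output.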
63 | """ 64 | 65 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 66 | super().__init__() 67 | self.channels = channels 68 | self.out_channels = out_channels or channels 69 | self.use_conv = use_conv 70 | self.use_conv_transpose = use_conv_transpose 71 | self.name = name 72 | 73 | conv = None 74 | if use_conv_transpose: 75 | conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) 76 | elif use_conv: 77 | # conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1) 78 | conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 79 | 80 | # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed 81 | if name == "conv": 82 | self.conv = conv 83 | else: 84 | self.Conv2d_0 = conv 85 | 86 | def forward(self, hidden_states, output_size=None, scale: float = 1.0): 87 | assert hidden_states.shape[1] == self.channels 88 | 89 | if self.use_conv_transpose: 90 | return self.conv(hidden_states) 91 | 92 | # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16 93 | # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch 94 | # https://github.com/pytorch/pytorch/issues/86679 95 | dtype = hidden_states.dtype 96 | if dtype == torch.bfloat16: 97 | hidden_states = hidden_states.to(torch.float32) 98 | 99 | # upsample_nearest_nhwc fails with large batch sizes.
see https://github.com/huggingface/diffusers/issues/984 100 | if hidden_states.shape[0] >= 64: 101 | hidden_states = hidden_states.contiguous() 102 | 103 | # if `output_size` is passed we force the interpolation output 104 | # size and do not make use of `scale_factor=[1.0, 2.0, 2.0]` 105 | if output_size is None: 106 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 107 | else: 108 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 109 | 110 | # If the input is bfloat16, we cast back to bfloat16 111 | if dtype == torch.bfloat16: 112 | hidden_states = hidden_states.to(dtype) 113 | 114 | # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed 115 | if self.use_conv: 116 | if self.name == "conv": 117 | if isinstance(self.conv, LoRACompatibleConv): 118 | hidden_states = self.conv(hidden_states, scale) 119 | else: 120 | hidden_states = self.conv(hidden_states) 121 | else: 122 | if isinstance(self.Conv2d_0, LoRACompatibleConv): 123 | hidden_states = self.Conv2d_0(hidden_states, scale) 124 | else: 125 | hidden_states = self.Conv2d_0(hidden_states) 126 | 127 | return hidden_states 128 | 129 | 130 | class Downsample3D(nn.Module): 131 | """A spatial downsampling layer for 5D `(batch, channels, frames, height, width)` video tensors, with an optional convolution that is applied to each frame with stride 2. 132 | 133 | Parameters: 134 | channels (`int`): 135 | number of channels in the inputs and outputs. 136 | use_conv (`bool`, default `False`): 137 | option to use a convolution. 138 | out_channels (`int`, optional): 139 | number of output channels. Defaults to `channels`. 140 | padding (`int`, default `1`): 141 | padding for the convolution.
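
Example (a standalone sketch in plain `torch`; the helper `per_frame` is hypothetical and stands in for `InflatedConv3d`). With `padding=1` the stride-2 convolution pads symmetrically by itself. With `padding=0`, `forward` first pads one pixel on the right and bottom via `F.pad(hidden_states, (0, 1, 0, 1))`. For an even input size, both variants halve height and width and leave the frame axis untouched.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def per_frame(conv: nn.Conv2d, x: torch.Tensor) -> torch.Tensor:
    # apply a 2D module to every frame of a (b, c, f, h, w) tensor
    b, c, f, h, w = x.shape
    y = conv(x.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w))
    _, c_out, h_out, w_out = y.shape
    return y.reshape(b, f, c_out, h_out, w_out).permute(0, 2, 1, 3, 4)


video = torch.randn(1, 4, 16, 8, 8)  # (batch, channels, frames, height, width)

# padding=1: the stride-2 conv pads symmetrically on its own
conv_p1 = nn.Conv2d(4, 4, kernel_size=3, stride=2, padding=1)
out_p1 = per_frame(conv_p1, video)

# padding=0: pad right and bottom by one pixel first, as the forward pass does
conv_p0 = nn.Conv2d(4, 4, kernel_size=3, stride=2, padding=0)
out_p0 = per_frame(conv_p0, F.pad(video, (0, 1, 0, 1), mode="constant", value=0))

print(tuple(out_p1.shape), tuple(out_p0.shape))  # (1, 4, 16, 4, 4) (1, 4, 16, 4, 4)
```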
142 | """ 143 | 144 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 145 | super().__init__() 146 | self.channels = channels 147 | self.out_channels = out_channels or channels 148 | self.use_conv = use_conv 149 | self.padding = padding 150 | stride = 2 151 | self.name = name 152 | 153 | if use_conv: 154 | # conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding) 155 | conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 156 | 157 | else: 158 | assert self.channels == self.out_channels 159 | conv = nn.AvgPool2d(kernel_size=stride, stride=stride) 160 | 161 | # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed 162 | if name == "conv": 163 | self.Conv2d_0 = conv 164 | self.conv = conv 165 | elif name == "Conv2d_0": 166 | self.conv = conv 167 | else: 168 | self.conv = conv 169 | 170 | def forward(self, hidden_states, scale: float = 1.0): 171 | assert hidden_states.shape[1] == self.channels 172 | if self.use_conv and self.padding == 0: 173 | pad = (0, 1, 0, 1) 174 | hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) 175 | 176 | assert hidden_states.shape[1] == self.channels 177 | if isinstance(self.conv, LoRACompatibleConv): 178 | hidden_states = self.conv(hidden_states, scale) 179 | else: 180 | hidden_states = self.conv(hidden_states) 181 | 182 | return hidden_states 183 | 184 | 185 | class ResnetBlock3D(nn.Module): 186 | 187 | def __init__( 188 | self, 189 | *, 190 | in_channels, 191 | out_channels=None, 192 | conv_shortcut=False, 193 | dropout=0.0, 194 | temb_channels=512, 195 | groups=32, 196 | groups_out=None, 197 | pre_norm=True, 198 | eps=1e-6, 199 | non_linearity="swish", 200 | skip_time_act=False, 201 | time_embedding_norm="default", # default, scale_shift, ada_group, spatial 202 | kernel=None, 203 | output_scale_factor=1.0, 204 | use_in_shortcut=None, 205 | up=False, 206 | down=False, 207 | 
conv_shortcut_bias: bool = True, 208 | conv_2d_out_channels: Optional[int] = None, 209 | use_inflated_groupnorm: bool = True, 210 | ): 211 | super().__init__() 212 | self.pre_norm = pre_norm 213 | self.pre_norm = True 214 | self.in_channels = in_channels 215 | out_channels = in_channels if out_channels is None else out_channels 216 | self.out_channels = out_channels 217 | self.use_conv_shortcut = conv_shortcut 218 | self.up = up 219 | self.down = down 220 | self.output_scale_factor = output_scale_factor 221 | self.time_embedding_norm = time_embedding_norm 222 | self.skip_time_act = skip_time_act 223 | 224 | if groups_out is None: 225 | groups_out = groups 226 | 227 | if self.time_embedding_norm == "ada_group": 228 | self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) 229 | elif self.time_embedding_norm == "spatial": 230 | self.norm1 = SpatialNorm(in_channels, temb_channels) 231 | else: 232 | self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 233 | 234 | # self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 235 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 236 | 237 | if temb_channels is not None: 238 | if self.time_embedding_norm == "default": 239 | # self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels) 240 | self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) 241 | elif self.time_embedding_norm == "scale_shift": 242 | self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels) 243 | elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": 244 | self.time_emb_proj = None 245 | else: 246 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 247 | else: 248 | self.time_emb_proj = None 249 | 250 | if self.time_embedding_norm == "ada_group": 251 | self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, 
eps=eps) 252 | elif self.time_embedding_norm == "spatial": 253 | self.norm2 = SpatialNorm(out_channels, temb_channels) 254 | else: 255 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 256 | 257 | self.dropout = torch.nn.Dropout(dropout) 258 | conv_2d_out_channels = conv_2d_out_channels or out_channels 259 | # self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) 260 | self.conv2 = InflatedConv3d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) 261 | 262 | self.nonlinearity = get_activation(non_linearity) 263 | 264 | self.upsample = self.downsample = None 265 | if self.up: 266 | if kernel == "fir": 267 | fir_kernel = (1, 3, 3, 1) 268 | self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) 269 | elif kernel == "sde_vp": 270 | self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") 271 | else: 272 | self.upsample = Upsample2D(in_channels, use_conv=False) 273 | elif self.down: 274 | if kernel == "fir": 275 | fir_kernel = (1, 3, 3, 1) 276 | self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) 277 | elif kernel == "sde_vp": 278 | self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) 279 | else: 280 | self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") 281 | 282 | self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut 283 | 284 | self.conv_shortcut = None 285 | if self.use_in_shortcut: 286 | self.conv_shortcut = InflatedConv3d(in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0) 287 | 288 | 289 | 290 | def forward(self, input_tensor, temb, scale: float = 1.0): 291 | hidden_states = input_tensor 292 | 293 | if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": 294 | hidden_states = self.norm1(hidden_states, temb) 295 | else: 296 | hidden_states = self.norm1(hidden_states) 297 | 298 |
hidden_states = self.nonlinearity(hidden_states) 299 | 300 | # if self.upsample is not None: 301 | # # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 302 | # if hidden_states.shape[0] >= 64: 303 | # input_tensor = input_tensor.contiguous() 304 | # hidden_states = hidden_states.contiguous() 305 | # input_tensor = ( 306 | # self.upsample(input_tensor, scale=scale) 307 | # if isinstance(self.upsample, Upsample2D) 308 | # else self.upsample(input_tensor) 309 | # ) 310 | # hidden_states = ( 311 | # self.upsample(hidden_states, scale=scale) 312 | # if isinstance(self.upsample, Upsample2D) 313 | # else self.upsample(hidden_states) 314 | # ) 315 | # elif self.downsample is not None: 316 | # input_tensor = ( 317 | # self.downsample(input_tensor, scale=scale) 318 | # if isinstance(self.downsample, Downsample2D) 319 | # else self.downsample(input_tensor) 320 | # ) 321 | # hidden_states = ( 322 | # self.downsample(hidden_states, scale=scale) 323 | # if isinstance(self.downsample, Downsample2D) 324 | # else self.downsample(hidden_states) 325 | # ) 326 | 327 | hidden_states = self.conv1(hidden_states) 328 | 329 | if self.time_emb_proj is not None: 330 | if not self.skip_time_act: 331 | temb = self.nonlinearity(temb) 332 | # temb = self.time_emb_proj(temb, scale)[:, :, None, None, None] 333 | temb = self.time_emb_proj(temb)[:, :, None, None, None] 334 | 335 | 336 | if temb is not None and self.time_embedding_norm == "default": 337 | hidden_states = hidden_states + temb 338 | 339 | if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": 340 | hidden_states = self.norm2(hidden_states, temb) 341 | else: 342 | hidden_states = self.norm2(hidden_states) 343 | 344 | if temb is not None and self.time_embedding_norm == "scale_shift": 345 | scale, shift = torch.chunk(temb, 2, dim=1) 346 | hidden_states = hidden_states * (1 + scale) + shift 347 | 348 | hidden_states = self.nonlinearity(hidden_states) 
349 | 350 | hidden_states = self.dropout(hidden_states) 351 | hidden_states = self.conv2(hidden_states) 352 | 353 | if self.conv_shortcut is not None: 354 | input_tensor = self.conv_shortcut(input_tensor) 355 | 356 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 357 | 358 | return output_tensor 359 | 360 | 361 | -------------------------------------------------------------------------------- /animatediff/models/unet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("/root/autodl-tmp/code/animatediff/modelshigh") 3 | from dataclasses import dataclass 4 | import os 5 | from typing import Any, Dict, List, Optional, Tuple, Union 6 | import json 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.checkpoint 11 | 12 | from diffusers.configuration_utils import ConfigMixin, register_to_config 13 | from diffusers.loaders import UNet2DConditionLoadersMixin 14 | from diffusers.utils import BaseOutput, logging 15 | from diffusers.models.attention_processor import ( 16 | ADDED_KV_ATTENTION_PROCESSORS, 17 | CROSS_ATTENTION_PROCESSORS, 18 | AttentionProcessor, 19 | AttnAddedKVProcessor, 20 | AttnProcessor, 21 | ) 22 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 23 | from diffusers.models.modeling_utils import ModelMixin 24 | # from diffusers.models.transformer_temporal import TransformerTemporalModel 25 | from animatediff.models.unet_blocks import ( 26 | CrossAttnDownBlock3D, 27 | CrossAttnUpBlock3D, 28 | DownBlock3D, 29 | UNetMidBlock3DCrossAttn, 30 | UpBlock3D, 31 | get_down_block, 32 | get_up_block, 33 | ) 34 | 35 | from .resnet import InflatedConv3d, InflatedGroupNorm 36 | 37 | 38 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 39 | 40 | 41 | @dataclass 42 | class UNet3DConditionOutput(BaseOutput): 43 | """ 44 | The output of [`UNet3DConditionModel`]. 
45 | 46 | Args: 47 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`): 48 | The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. 49 | """ 50 | 51 | sample: torch.FloatTensor 52 | 53 | 54 | class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): 55 | 56 | _supports_gradient_checkpointing = False 57 | 58 | @register_to_config 59 | def __init__( 60 | self, 61 | sample_size: Optional[int] = None, 62 | in_channels: int = 4, 63 | out_channels: int = 4, 64 | down_block_types: Tuple[str] = ( 65 | "CrossAttnDownBlock3D", 66 | "CrossAttnDownBlock3D", 67 | "CrossAttnDownBlock3D", 68 | "DownBlock3D", 69 | ), 70 | #----- 71 | mid_block_type: str = "UnetMidBlock3DCrossAttn", 72 | #----- 73 | up_block_types: Tuple[str] = ( 74 | "UpBlock3D", 75 | "CrossAttnUpBlock3D", 76 | "CrossAttnUpBlock3D", 77 | "CrossAttnUpBlock3D" 78 | ), 79 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 80 | layers_per_block: int = 2, 81 | downsample_padding: int = 1, 82 | mid_block_scale_factor: float = 1, 83 | act_fn: str = "silu", 84 | norm_num_groups: Optional[int] = 32, 85 | norm_eps: float = 1e-5, 86 | # cross_attention_dim: int = 1024, 87 | cross_attention_dim: int = 1280, 88 | # attention_head_dim: Union[int, Tuple[int]] = 64, 89 | attention_head_dim: Union[int, Tuple[int]] = 8, 90 | num_attention_heads: Optional[Union[int, Tuple[int]]] = None, 91 | 92 | use_inflated_groupnorm=False, 93 | # Additional 94 | use_motion_module = False, 95 | motion_module_resolutions = ( 1,2,4,8 ), 96 | motion_module_mid_block = False, 97 | motion_module_decoder_only = False, 98 | motion_module_type = None, 99 | motion_module_kwargs = {}, 100 | unet_use_cross_frame_attention = None, 101 | unet_use_temporal_attention = None, 102 | ): 103 | super().__init__() 104 | 105 | self.sample_size = sample_size 106 | # time_embed_dim = block_out_channels[0] * 4 107 | 108 | if num_attention_heads is not
None: 109 | raise NotImplementedError( 110 | "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." 111 | ) 112 | 113 | # If `num_attention_heads` is not defined (which is the case for most models) 114 | # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. 115 | # The reason for this behavior is to correct for incorrectly named variables that were introduced 116 | # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 117 | # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking 118 | # which is why we correct for the naming here. 119 | num_attention_heads = num_attention_heads or attention_head_dim 120 | 121 | # Check inputs 122 | if len(down_block_types) != len(up_block_types): 123 | raise ValueError( 124 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." 125 | ) 126 | 127 | if len(block_out_channels) != len(down_block_types): 128 | raise ValueError( 129 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 130 | ) 131 | 132 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): 133 | raise ValueError( 134 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
135 | ) 136 | 137 | # input 138 | conv_in_kernel = 3 139 | conv_out_kernel = 3 140 | conv_in_padding = (conv_in_kernel - 1) // 2 141 | # self.conv_in = nn.Conv2d( 142 | # in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding 143 | # ) 144 | self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) 145 | 146 | # time 147 | time_embed_dim = block_out_channels[0] * 4 148 | self.time_proj = Timesteps(block_out_channels[0], True, 0) 149 | timestep_input_dim = block_out_channels[0] 150 | 151 | self.time_embedding = TimestepEmbedding( 152 | timestep_input_dim, 153 | time_embed_dim, 154 | act_fn=act_fn, 155 | ) 156 | 157 | # self.transformer_in = TransformerTemporalModel( 158 | # num_attention_heads=8, 159 | # attention_head_dim=attention_head_dim, 160 | # in_channels=block_out_channels[0], 161 | # num_layers=1, 162 | # ) 163 | 164 | # class embedding 165 | 166 | 167 | self.down_blocks = nn.ModuleList([]) 168 | self.mid_block = None 169 | self.up_blocks = nn.ModuleList([]) 170 | 171 | if isinstance(num_attention_heads, int): 172 | num_attention_heads = (num_attention_heads,) * len(down_block_types) 173 | 174 | # down 175 | output_channel = block_out_channels[0] 176 | for i, down_block_type in enumerate(down_block_types): 177 | res = 2 ** i 178 | input_channel = output_channel 179 | output_channel = block_out_channels[i] 180 | is_final_block = i == len(block_out_channels) - 1 181 | 182 | down_block = get_down_block( 183 | down_block_type, 184 | num_layers=layers_per_block, 185 | in_channels=input_channel, 186 | out_channels=output_channel, 187 | temb_channels=time_embed_dim, 188 | add_downsample=not is_final_block, 189 | resnet_eps=norm_eps, 190 | resnet_act_fn=act_fn, 191 | resnet_groups=norm_num_groups, 192 | cross_attention_dim=cross_attention_dim, 193 | num_attention_heads=num_attention_heads[i], 194 | downsample_padding=downsample_padding, 195 | 196 | 
unet_use_cross_frame_attention=unet_use_cross_frame_attention, 197 | unet_use_temporal_attention=unet_use_temporal_attention, 198 | use_inflated_groupnorm=use_inflated_groupnorm, 199 | 200 | use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), 201 | motion_module_type=motion_module_type, 202 | motion_module_kwargs=motion_module_kwargs, 203 | ) 204 | self.down_blocks.append(down_block) 205 | 206 | # mid 207 | self.mid_block = UNetMidBlock3DCrossAttn( 208 | in_channels=block_out_channels[-1], 209 | temb_channels=time_embed_dim, 210 | resnet_eps=norm_eps, 211 | resnet_act_fn=act_fn, 212 | output_scale_factor=mid_block_scale_factor, 213 | cross_attention_dim=cross_attention_dim, 214 | num_attention_heads=num_attention_heads[-1], 215 | resnet_groups=norm_num_groups, 216 | 217 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 218 | unet_use_temporal_attention=unet_use_temporal_attention, 219 | use_inflated_groupnorm=use_inflated_groupnorm, 220 | 221 | use_motion_module=use_motion_module and motion_module_mid_block, 222 | motion_module_type=motion_module_type, 223 | motion_module_kwargs=motion_module_kwargs, 224 | ) 225 | 226 | # count how many layers upsample the images 227 | self.num_upsamplers = 0 228 | 229 | # up 230 | reversed_block_out_channels = list(reversed(block_out_channels)) 231 | reversed_num_attention_heads = list(reversed(num_attention_heads)) 232 | 233 | output_channel = reversed_block_out_channels[0] 234 | for i, up_block_type in enumerate(up_block_types): 235 | res = 2 ** (3 - i) 236 | is_final_block = i == len(block_out_channels) - 1 237 | 238 | prev_output_channel = output_channel 239 | output_channel = reversed_block_out_channels[i] 240 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 241 | 242 | # add upsample block for all BUT final layer 243 | if not is_final_block: 244 | add_upsample = True 245 | self.num_upsamplers += 1 246 | else: 247 
| add_upsample = False 248 | 249 | up_block = get_up_block( 250 | up_block_type, 251 | num_layers=layers_per_block + 1, 252 | in_channels=input_channel, 253 | out_channels=output_channel, 254 | prev_output_channel=prev_output_channel, 255 | temb_channels=time_embed_dim, 256 | add_upsample=add_upsample, 257 | resnet_eps=norm_eps, 258 | resnet_act_fn=act_fn, 259 | resnet_groups=norm_num_groups, 260 | cross_attention_dim=cross_attention_dim, 261 | num_attention_heads=reversed_num_attention_heads[i], 262 | 263 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 264 | unet_use_temporal_attention=unet_use_temporal_attention, 265 | use_inflated_groupnorm=use_inflated_groupnorm, 266 | 267 | use_motion_module=use_motion_module and (res in motion_module_resolutions), 268 | motion_module_type=motion_module_type, 269 | motion_module_kwargs=motion_module_kwargs, 270 | ) 271 | self.up_blocks.append(up_block) 272 | prev_output_channel = output_channel 273 | 274 | # out 275 | if norm_num_groups is not None: 276 | self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) 277 | else: 278 | self.conv_norm_out = None # a GroupNorm cannot be built without `norm_num_groups`; forward() then skips the output norm and activation 279 | self.conv_act = nn.SiLU() 280 | 281 | conv_out_padding = (conv_out_kernel - 1) // 2 282 | 283 | self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) 284 | 285 | 286 | @property 287 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors 288 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 289 | r""" 290 | Returns: 291 | `dict` of attention processors: A dictionary containing all attention processors used in the model, 292 | indexed by its weight name.
293 | """ 294 | # set recursively 295 | processors = {} 296 | 297 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 298 | if hasattr(module, "get_processor"): 299 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) 300 | 301 | for sub_name, child in module.named_children(): 302 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 303 | 304 | return processors 305 | 306 | for name, module in self.named_children(): 307 | fn_recursive_add_processors(name, module, processors) 308 | 309 | return processors 310 | 311 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice 312 | def set_attention_slice(self, slice_size): 313 | r""" 314 | Enable sliced attention computation. 315 | 316 | When this option is enabled, the attention module splits the input tensor in slices to compute attention in 317 | several steps. This is useful for saving some memory in exchange for a small decrease in speed. 318 | 319 | Args: 320 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 321 | When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If 322 | `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is 323 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 324 | must be a multiple of `slice_size`. 
325 | """ 326 | sliceable_head_dims = [] 327 | 328 | def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): 329 | if hasattr(module, "set_attention_slice"): 330 | sliceable_head_dims.append(module.sliceable_head_dim) 331 | 332 | for child in module.children(): 333 | fn_recursive_retrieve_sliceable_dims(child) 334 | 335 | # retrieve number of attention layers 336 | for module in self.children(): 337 | fn_recursive_retrieve_sliceable_dims(module) 338 | 339 | num_sliceable_layers = len(sliceable_head_dims) 340 | 341 | if slice_size == "auto": 342 | # half the attention head size is usually a good trade-off between 343 | # speed and memory 344 | slice_size = [dim // 2 for dim in sliceable_head_dims] 345 | elif slice_size == "max": 346 | # make smallest slice possible 347 | slice_size = num_sliceable_layers * [1] 348 | 349 | slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 350 | 351 | if len(slice_size) != len(sliceable_head_dims): 352 | raise ValueError( 353 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 354 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 355 | ) 356 | 357 | for i in range(len(slice_size)): 358 | size = slice_size[i] 359 | dim = sliceable_head_dims[i] 360 | if size is not None and size > dim: 361 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 362 | 363 | # Recursively walk through all the children. 
364 | # Any child that exposes the set_attention_slice method 365 | # receives its slice size 366 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 367 | if hasattr(module, "set_attention_slice"): 368 | module.set_attention_slice(slice_size.pop()) 369 | 370 | for child in module.children(): 371 | fn_recursive_set_attention_slice(child, slice_size) 372 | 373 | reversed_slice_size = list(reversed(slice_size)) 374 | for module in self.children(): 375 | fn_recursive_set_attention_slice(module, reversed_slice_size) 376 | 377 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor 378 | def set_attn_processor( 379 | self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False 380 | ): 381 | r""" 382 | Sets the attention processor to use to compute attention. 383 | 384 | Parameters: 385 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 386 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 387 | for **all** `Attention` layers. 388 | 389 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 390 | processor. This is strongly recommended when setting trainable attention processors. 391 | 392 | """ 393 | count = len(self.attn_processors.keys()) 394 | 395 | if isinstance(processor, dict) and len(processor) != count: 396 | raise ValueError( 397 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 398 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
399 | ) 400 | 401 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 402 | if hasattr(module, "set_processor"): 403 | if not isinstance(processor, dict): 404 | module.set_processor(processor, _remove_lora=_remove_lora) 405 | else: 406 | module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) 407 | 408 | for sub_name, child in module.named_children(): 409 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 410 | 411 | for name, module in self.named_children(): 412 | fn_recursive_attn_processor(name, module, processor) 413 | 414 | def enable_forward_chunking(self, chunk_size=None, dim=0): 415 | """ 416 | Enables [feed forward 417 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) on every submodule that supports it. 418 | 419 | Parameters: 420 | chunk_size (`int`, *optional*): 421 | The chunk size of the feed-forward layers. If not specified, it defaults to 1, so the feed-forward layer runs 422 | separately on each slice along `dim`. 423 | dim (`int`, *optional*, defaults to `0`): 424 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) 425 | or dim=1 (sequence length).
426 | """ 427 | if dim not in [0, 1]: 428 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") 429 | 430 | # By default chunk size is 1 431 | chunk_size = chunk_size or 1 432 | 433 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 434 | if hasattr(module, "set_chunk_feed_forward"): 435 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 436 | 437 | for child in module.children(): 438 | fn_recursive_feed_forward(child, chunk_size, dim) 439 | 440 | for module in self.children(): 441 | fn_recursive_feed_forward(module, chunk_size, dim) 442 | 443 | def disable_forward_chunking(self): 444 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 445 | if hasattr(module, "set_chunk_feed_forward"): 446 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 447 | 448 | for child in module.children(): 449 | fn_recursive_feed_forward(child, chunk_size, dim) 450 | 451 | for module in self.children(): 452 | fn_recursive_feed_forward(module, None, 0) 453 | 454 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor 455 | def set_default_attn_processor(self): 456 | """ 457 | Disables custom attention processors and sets the default attention implementation. 
458 | """ 459 | if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 460 | processor = AttnAddedKVProcessor() 461 | elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 462 | processor = AttnProcessor() 463 | else: 464 | raise ValueError( 465 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" 466 | ) 467 | 468 | self.set_attn_processor(processor, _remove_lora=True) 469 | 470 | def _set_gradient_checkpointing(self, module, value=False): 471 | if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): 472 | module.gradient_checkpointing = value 473 | 474 | def forward( 475 | self, 476 | sample: torch.FloatTensor, 477 | timestep: Union[torch.Tensor, float, int], 478 | encoder_hidden_states: torch.Tensor, 479 | class_labels: Optional[torch.Tensor] = None, 480 | timestep_cond: Optional[torch.Tensor] = None, 481 | attention_mask: Optional[torch.Tensor] = None, 482 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 483 | mid_block_additional_residual: Optional[torch.Tensor] = None, 484 | return_dict: bool = True, 485 | ) -> Union[UNet3DConditionOutput, Tuple]: 486 | r""" 487 | The [`UNet3DConditionModel`] forward method. 488 | 489 | Args: 490 | sample (`torch.FloatTensor`): 491 | The noisy input tensor with the following shape `(batch, channel, num_frames, height, width)`. 492 | timestep (`torch.FloatTensor` or `float` or `int`): The current diffusion timestep of the noisy input. 493 | encoder_hidden_states (`torch.FloatTensor`): 494 | The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. 495 | return_dict (`bool`, *optional*, defaults to `True`): 496 | Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain 497 | tuple.
498 | down_block_additional_residuals (`tuple` of `torch.FloatTensor`, *optional*) / mid_block_additional_residual (`torch.FloatTensor`, *optional*): 499 | ControlNet residuals of shape `(batch, channel, height, width)`. Each one is repeated along the frame axis and added to the matching skip connection or to the mid-block output. 500 | 501 | Returns: 502 | [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: 503 | If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise 504 | a `tuple` is returned where the first element is the sample tensor. 505 | """ 506 | # By default samples have to be at least a multiple of the overall upsampling factor. 507 | # The overall upsampling factor is equal to 2 ** (number of upsampling layers). 508 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 509 | # on the fly if necessary. 510 | default_overall_up_factor = 2**self.num_upsamplers 511 | 512 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 513 | forward_upsample_size = False 514 | upsample_size = None 515 | 516 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 517 | logger.info("Forward upsample size to force interpolation output size.") 518 | forward_upsample_size = True 519 | 520 | # prepare attention_mask 521 | if attention_mask is not None: 522 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 523 | attention_mask = attention_mask.unsqueeze(1) 524 | 525 | # 1. time 526 | timesteps = timestep 527 | if not torch.is_tensor(timesteps): 528 | # TODO: this requires sync between CPU and GPU.
So try to pass timesteps as tensors if you can 529 | # This would be a good case for the `match` statement (Python 3.10+) 530 | is_mps = sample.device.type == "mps" 531 | if isinstance(timestep, float): 532 | dtype = torch.float32 if is_mps else torch.float64 533 | else: 534 | dtype = torch.int32 if is_mps else torch.int64 535 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 536 | elif len(timesteps.shape) == 0: 537 | timesteps = timesteps[None].to(sample.device) 538 | 539 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 540 | num_frames = sample.shape[2] 541 | timesteps = timesteps.expand(sample.shape[0]) 542 | 543 | t_emb = self.time_proj(timesteps) 544 | 545 | # timesteps does not contain any weights and will always return f32 tensors 546 | # but time_embedding might actually be running in fp16. so we need to cast here. 547 | # there might be better ways to encapsulate this. 548 | t_emb = t_emb.to(dtype=self.dtype) 549 | 550 | emb = self.time_embedding(t_emb, timestep_cond) 551 | # emb = emb.repeat_interleave(repeats=num_frames, dim=0) 552 | # encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) 553 | 554 | # 2. pre-process 555 | # sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) 556 | sample = self.conv_in(sample) 557 | 558 | # sample = self.transformer_in( 559 | # sample, 560 | # num_frames=num_frames, 561 | # cross_attention_kwargs=cross_attention_kwargs, 562 | # return_dict=False, 563 | # )[0] 564 | 565 | # 3. 
down 566 | down_block_res_samples = (sample,) 567 | for downsample_block in self.down_blocks: 568 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 569 | sample, res_samples = downsample_block( 570 | hidden_states=sample, 571 | temb=emb, 572 | encoder_hidden_states=encoder_hidden_states, 573 | attention_mask=attention_mask, 574 | num_frames=num_frames, 575 | ) 576 | else: 577 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) 578 | 579 | down_block_res_samples += res_samples 580 | 581 | if down_block_additional_residuals is not None: 582 | new_down_block_res_samples = () 583 | 584 | for down_block_res_sample, down_block_additional_residual in zip( 585 | down_block_res_samples, down_block_additional_residuals 586 | ): 587 | down_block_additional_residual = down_block_additional_residual.unsqueeze(2).repeat(1, 1, num_frames, 1, 1) # broadcast the per-image ControlNet residual across every frame instead of a hard-coded 16 588 | down_block_res_sample = down_block_res_sample + down_block_additional_residual * 1.0 589 | new_down_block_res_samples += (down_block_res_sample,) 590 | 591 | down_block_res_samples = new_down_block_res_samples 592 | 593 | # 4. mid 594 | if self.mid_block is not None: 595 | sample = self.mid_block( 596 | sample, 597 | emb, 598 | encoder_hidden_states=encoder_hidden_states, 599 | attention_mask=attention_mask, 600 | ) 601 | 602 | if mid_block_additional_residual is not None: 603 | mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2).repeat(1, 1, num_frames, 1, 1) # same frame-axis broadcast as above 604 | sample = sample + mid_block_additional_residual * 1.0 605 | # 5.
up 607 | for i, upsample_block in enumerate(self.up_blocks): 608 | is_final_block = i == len(self.up_blocks) - 1 609 | 610 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 611 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 612 | 613 | # if we have not reached the final block and need to forward the 614 | # upsample size, we do it here 615 | if not is_final_block and forward_upsample_size: 616 | upsample_size = down_block_res_samples[-1].shape[2:] 617 | 618 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 619 | sample = upsample_block( 620 | hidden_states=sample, 621 | temb=emb, 622 | res_hidden_states_tuple=res_samples, 623 | encoder_hidden_states=encoder_hidden_states, 624 | upsample_size=upsample_size, 625 | attention_mask=attention_mask, 626 | ) 627 | else: 628 | sample = upsample_block( 629 | hidden_states=sample, 630 | temb=emb, 631 | res_hidden_states_tuple=res_samples, 632 | upsample_size=upsample_size, 633 | ) 634 | 635 | # 6. 
post-process 636 | if self.conv_norm_out is not None: 637 | sample = self.conv_norm_out(sample) 638 | sample = self.conv_act(sample) 639 | 640 | sample = self.conv_out(sample) 641 | 642 | # reshape to (batch, channel, framerate, width, height) 643 | # sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) 644 | 645 | if not return_dict: 646 | return (sample,) 647 | 648 | return UNet3DConditionOutput(sample=sample) 649 | 650 | 651 | @classmethod 652 | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): 653 | if subfolder is not None: 654 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder) 655 | print(f"loading temporal unet's pretrained weights from {pretrained_model_path} ...") 656 | 657 | config_file = os.path.join(pretrained_model_path, 'config.json') 658 | if not os.path.isfile(config_file): 659 | raise RuntimeError(f"{config_file} does not exist") 660 | with open(config_file, "r") as f: 661 | config = json.load(f) 662 | config["_class_name"] = cls.__name__ 663 | config["down_block_types"] = [ 664 | "CrossAttnDownBlock3D", 665 | "CrossAttnDownBlock3D", 666 | "CrossAttnDownBlock3D", 667 | "DownBlock3D" 668 | ] 669 | config["up_block_types"] = [ 670 | "UpBlock3D", 671 | "CrossAttnUpBlock3D", 672 | "CrossAttnUpBlock3D", 673 | "CrossAttnUpBlock3D" 674 | ] 675 | 676 | from diffusers.utils import WEIGHTS_NAME 677 | model = cls.from_config(config, **(unet_additional_kwargs or {})) # the default of None would otherwise raise a TypeError 678 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) 679 | if not os.path.isfile(model_file): 680 | raise RuntimeError(f"{model_file} does not exist") 681 | state_dict = torch.load(model_file, map_location="cpu") 682 | 683 | m, u = model.load_state_dict(state_dict, strict=False) 684 | # print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") 685 | # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n") 686 | 687 | params = [p.numel() if "temporal" in n else 0 for n, p in
model.named_parameters()] 688 | print(f"### Temporal Module Parameters: {sum(params) / 1e6} M") 689 | 690 | return model 691 | 692 | -------------------------------------------------------------------------------- /animatediff/models/unet_blocks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | from torch import nn 17 | 18 | from .resnet import Downsample3D, ResnetBlock3D, Upsample3D 19 | from .attention import Transformer3DModel 20 | from .motion_module import get_motion_module 21 | 22 | 23 | def get_down_block( 24 | down_block_type, 25 | num_layers, 26 | in_channels, 27 | out_channels, 28 | temb_channels, 29 | add_downsample, 30 | resnet_eps, 31 | resnet_act_fn, 32 | num_attention_heads, 33 | resnet_groups=None, 34 | cross_attention_dim=None, 35 | downsample_padding=None, 36 | dual_cross_attention=False, 37 | use_linear_projection=True, 38 | only_cross_attention=False, 39 | upcast_attention=False, 40 | resnet_time_scale_shift="default", 41 | 42 | unet_use_cross_frame_attention=None, 43 | unet_use_temporal_attention=None, 44 | use_inflated_groupnorm=None, 45 | 46 | use_motion_module=None, 47 | motion_module_type=None, 48 | motion_module_kwargs=None, 49 | ): 50 | down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else 
down_block_type 51 | 52 | if down_block_type == "DownBlock3D": 53 | return DownBlock3D( 54 | num_layers=num_layers, 55 | in_channels=in_channels, 56 | out_channels=out_channels, 57 | temb_channels=temb_channels, 58 | add_downsample=add_downsample, 59 | resnet_eps=resnet_eps, 60 | resnet_act_fn=resnet_act_fn, 61 | resnet_groups=resnet_groups, 62 | downsample_padding=downsample_padding, 63 | resnet_time_scale_shift=resnet_time_scale_shift, 64 | 65 | use_inflated_groupnorm=use_inflated_groupnorm, 66 | 67 | use_motion_module=use_motion_module, 68 | motion_module_type=motion_module_type, 69 | motion_module_kwargs=motion_module_kwargs, 70 | ) 71 | elif down_block_type == "CrossAttnDownBlock3D": 72 | if cross_attention_dim is None: 73 | raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") 74 | return CrossAttnDownBlock3D( 75 | num_layers=num_layers, 76 | in_channels=in_channels, 77 | out_channels=out_channels, 78 | temb_channels=temb_channels, 79 | add_downsample=add_downsample, 80 | resnet_eps=resnet_eps, 81 | resnet_act_fn=resnet_act_fn, 82 | resnet_groups=resnet_groups, 83 | downsample_padding=downsample_padding, 84 | cross_attention_dim=cross_attention_dim, 85 | num_attention_heads=num_attention_heads, 86 | dual_cross_attention=dual_cross_attention, 87 | use_linear_projection=use_linear_projection, 88 | only_cross_attention=only_cross_attention, 89 | upcast_attention=upcast_attention, 90 | resnet_time_scale_shift=resnet_time_scale_shift, 91 | 92 | use_inflated_groupnorm=use_inflated_groupnorm, 93 | 94 | use_motion_module=use_motion_module, 95 | motion_module_type=motion_module_type, 96 | motion_module_kwargs=motion_module_kwargs, 97 | ) 98 | raise ValueError(f"{down_block_type} does not exist.") 99 | 100 | 101 | def get_up_block( 102 | up_block_type, 103 | num_layers, 104 | in_channels, 105 | out_channels, 106 | prev_output_channel, 107 | temb_channels, 108 | add_upsample, 109 | resnet_eps, 110 | resnet_act_fn, 111 | num_attention_heads, 
112 | use_motion_module, 113 | motion_module_type, 114 | motion_module_kwargs, 115 | resnet_groups=None, 116 | cross_attention_dim=None, 117 | unet_use_cross_frame_attention=False, 118 | unet_use_temporal_attention=False, 119 | dual_cross_attention=False, 120 | use_linear_projection=True, 121 | only_cross_attention=False, 122 | upcast_attention=False, 123 | resnet_time_scale_shift="default", 124 | 125 | use_inflated_groupnorm=False, 126 | 127 | ): 128 | if up_block_type == "UpBlock3D": 129 | return UpBlock3D( 130 | num_layers=num_layers, 131 | in_channels=in_channels, 132 | out_channels=out_channels, 133 | prev_output_channel=prev_output_channel, 134 | temb_channels=temb_channels, 135 | add_upsample=add_upsample, 136 | resnet_eps=resnet_eps, 137 | resnet_act_fn=resnet_act_fn, 138 | resnet_groups=resnet_groups, 139 | resnet_time_scale_shift=resnet_time_scale_shift, 140 | 141 | use_inflated_groupnorm=use_inflated_groupnorm, 142 | 143 | use_motion_module=use_motion_module, 144 | motion_module_type=motion_module_type, 145 | motion_module_kwargs=motion_module_kwargs, 146 | ) 147 | elif up_block_type == "CrossAttnUpBlock3D": 148 | if cross_attention_dim is None: 149 | raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") 150 | return CrossAttnUpBlock3D( 151 | num_layers=num_layers, 152 | in_channels=in_channels, 153 | out_channels=out_channels, 154 | prev_output_channel=prev_output_channel, 155 | temb_channels=temb_channels, 156 | add_upsample=add_upsample, 157 | resnet_eps=resnet_eps, 158 | resnet_act_fn=resnet_act_fn, 159 | resnet_groups=resnet_groups, 160 | cross_attention_dim=cross_attention_dim, 161 | num_attention_heads=num_attention_heads, 162 | dual_cross_attention=dual_cross_attention, 163 | use_linear_projection=use_linear_projection, 164 | only_cross_attention=only_cross_attention, 165 | upcast_attention=upcast_attention, 166 | resnet_time_scale_shift=resnet_time_scale_shift, 167 | 168 | 
unet_use_cross_frame_attention=unet_use_cross_frame_attention, 169 | unet_use_temporal_attention=unet_use_temporal_attention, 170 | use_inflated_groupnorm=use_inflated_groupnorm, 171 | 172 | use_motion_module=use_motion_module, 173 | motion_module_type=motion_module_type, 174 | motion_module_kwargs=motion_module_kwargs, 175 | ) 176 | raise ValueError(f"{up_block_type} does not exist.") 177 | 178 | 179 | class UNetMidBlock3DCrossAttn(nn.Module): 180 | def __init__( 181 | self, 182 | in_channels: int, 183 | temb_channels: int, 184 | dropout: float = 0.0, 185 | num_layers: int = 1, 186 | resnet_eps: float = 1e-6, 187 | resnet_time_scale_shift: str = "default", 188 | resnet_act_fn: str = "swish", 189 | resnet_groups: int = 32, 190 | resnet_pre_norm: bool = True, 191 | num_attention_heads=1, 192 | output_scale_factor=1.0, 193 | cross_attention_dim=1280, 194 | dual_cross_attention=False, 195 | use_linear_projection=True, 196 | upcast_attention=False, 197 | 198 | unet_use_cross_frame_attention=None, 199 | unet_use_temporal_attention=None, 200 | use_inflated_groupnorm=None, 201 | 202 | use_motion_module=None, 203 | 204 | motion_module_type=None, 205 | motion_module_kwargs=None, 206 | ): 207 | super().__init__() 208 | 209 | self.has_cross_attention = True 210 | self.num_attention_heads = num_attention_heads 211 | resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) 212 | 213 | # there is always at least one resnet 214 | resnets = [ 215 | ResnetBlock3D( 216 | in_channels=in_channels, 217 | out_channels=in_channels, 218 | temb_channels=temb_channels, 219 | eps=resnet_eps, 220 | groups=resnet_groups, 221 | dropout=dropout, 222 | time_embedding_norm=resnet_time_scale_shift, 223 | non_linearity=resnet_act_fn, 224 | output_scale_factor=output_scale_factor, 225 | pre_norm=resnet_pre_norm, 226 | use_inflated_groupnorm=use_inflated_groupnorm, 227 | 228 | ) 229 | ] 230 | 231 | attentions = [] 232 | motion_modules = [] 233 | 234 | for _ in 
range(num_layers): 235 | attentions.append( 236 | Transformer3DModel( 237 | num_attention_heads, 238 | in_channels // num_attention_heads, 239 | in_channels=in_channels, 240 | num_layers=1, 241 | cross_attention_dim=cross_attention_dim, 242 | norm_num_groups=resnet_groups, 243 | # use_linear_projection=use_linear_projection, 244 | use_linear_projection=False, 245 | upcast_attention=upcast_attention, 246 | 247 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 248 | unet_use_temporal_attention=unet_use_temporal_attention, 249 | ) 250 | ) 251 | motion_modules.append( 252 | get_motion_module( 253 | in_channels=in_channels, 254 | motion_module_type=motion_module_type, 255 | motion_module_kwargs=motion_module_kwargs, 256 | ) if use_motion_module else None 257 | ) 258 | 259 | resnets.append( 260 | ResnetBlock3D( 261 | in_channels=in_channels, 262 | out_channels=in_channels, 263 | temb_channels=temb_channels, 264 | eps=resnet_eps, 265 | groups=resnet_groups, 266 | dropout=dropout, 267 | time_embedding_norm=resnet_time_scale_shift, 268 | non_linearity=resnet_act_fn, 269 | output_scale_factor=output_scale_factor, 270 | pre_norm=resnet_pre_norm, 271 | 272 | use_inflated_groupnorm=use_inflated_groupnorm, 273 | 274 | ) 275 | ) 276 | 277 | self.attentions = nn.ModuleList(attentions) 278 | self.resnets = nn.ModuleList(resnets) 279 | self.motion_modules = nn.ModuleList(motion_modules) 280 | 281 | def forward( 282 | self, 283 | hidden_states, 284 | temb=None, 285 | encoder_hidden_states=None, 286 | attention_mask=None, 287 | ): 288 | hidden_states = self.resnets[0](hidden_states, temb) 289 | for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules): 290 | hidden_states = attn( 291 | hidden_states, 292 | encoder_hidden_states=encoder_hidden_states, 293 | return_dict=False, 294 | )[0] 295 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
296 | hidden_states = resnet(hidden_states, temb) 297 | 298 | return hidden_states 299 | 300 | 301 | class CrossAttnDownBlock3D(nn.Module): 302 | def __init__( 303 | self, 304 | in_channels: int, 305 | out_channels: int, 306 | temb_channels: int, 307 | dropout: float = 0.0, 308 | num_layers: int = 1, 309 | resnet_eps: float = 1e-6, 310 | resnet_time_scale_shift: str = "default", 311 | resnet_act_fn: str = "swish", 312 | resnet_groups: int = 32, 313 | resnet_pre_norm: bool = True, 314 | num_attention_heads=1, 315 | cross_attention_dim=1280, 316 | output_scale_factor=1.0, 317 | downsample_padding=1, 318 | add_downsample=True, 319 | dual_cross_attention=False, 320 | use_linear_projection=False, 321 | only_cross_attention=False, 322 | upcast_attention=False, 323 | 324 | unet_use_cross_frame_attention=None, 325 | unet_use_temporal_attention=None, 326 | use_inflated_groupnorm=None, 327 | 328 | use_motion_module=None, 329 | 330 | motion_module_type=None, 331 | motion_module_kwargs=None, 332 | ): 333 | super().__init__() 334 | resnets = [] 335 | attentions = [] 336 | motion_modules = [] 337 | 338 | self.has_cross_attention = True 339 | self.attn_num_attention_heads = num_attention_heads 340 | 341 | for i in range(num_layers): 342 | in_channels = in_channels if i == 0 else out_channels 343 | resnets.append( 344 | ResnetBlock3D( 345 | in_channels=in_channels, 346 | out_channels=out_channels, 347 | temb_channels=temb_channels, 348 | eps=resnet_eps, 349 | groups=resnet_groups, 350 | dropout=dropout, 351 | time_embedding_norm=resnet_time_scale_shift, 352 | non_linearity=resnet_act_fn, 353 | output_scale_factor=output_scale_factor, 354 | pre_norm=resnet_pre_norm, 355 | 356 | use_inflated_groupnorm=use_inflated_groupnorm, 357 | ) 358 | ) 359 | if dual_cross_attention: 360 | raise NotImplementedError 361 | 362 | attentions.append( 363 | Transformer3DModel( 364 | self.attn_num_attention_heads, 365 | out_channels // self.attn_num_attention_heads, 366 | in_channels=out_channels, 367 
| num_layers=1, 368 | cross_attention_dim=cross_attention_dim, 369 | norm_num_groups=resnet_groups, 370 | # use_linear_projection=use_linear_projection, 371 | use_linear_projection=False, 372 | only_cross_attention=only_cross_attention, 373 | upcast_attention=upcast_attention, 374 | 375 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 376 | unet_use_temporal_attention=unet_use_temporal_attention, 377 | ) 378 | ) 379 | 380 | motion_modules.append( 381 | get_motion_module( 382 | in_channels=out_channels, 383 | motion_module_type=motion_module_type, 384 | motion_module_kwargs=motion_module_kwargs, 385 | ) if use_motion_module else None 386 | ) 387 | 388 | self.resnets = nn.ModuleList(resnets) 389 | self.attentions = nn.ModuleList(attentions) 390 | self.motion_modules = nn.ModuleList(motion_modules) 391 | 392 | 393 | if add_downsample: 394 | self.downsamplers = nn.ModuleList( 395 | [ 396 | Downsample3D( 397 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 398 | ) 399 | ] 400 | ) 401 | else: 402 | self.downsamplers = None 403 | 404 | self.gradient_checkpointing = False 405 | 406 | def forward( 407 | self, 408 | hidden_states, 409 | temb=None, 410 | encoder_hidden_states=None, 411 | attention_mask=None, 412 | num_frames=1, 413 | ): 414 | # TODO(Patrick, William) - attention mask is not used 415 | output_states = () 416 | 417 | for resnet, attn, motion_module in zip( 418 | self.resnets, self.attentions, self.motion_modules 419 | ): 420 | hidden_states = resnet(hidden_states, temb) 421 | hidden_states = attn( 422 | hidden_states, 423 | encoder_hidden_states=encoder_hidden_states, 424 | return_dict=False, 425 | )[0] 426 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states 427 | 428 | 429 | output_states += (hidden_states,) 430 | 431 | if self.downsamplers is not None: 432 | for downsampler in self.downsamplers: 433 | hidden_states =
downsampler(hidden_states) 434 | 435 | output_states += (hidden_states,) 436 | 437 | return hidden_states, output_states 438 | 439 | 440 | class DownBlock3D(nn.Module): 441 | def __init__( 442 | self, 443 | in_channels: int, 444 | out_channels: int, 445 | temb_channels: int, 446 | dropout: float = 0.0, 447 | num_layers: int = 1, 448 | resnet_eps: float = 1e-6, 449 | resnet_time_scale_shift: str = "default", 450 | resnet_act_fn: str = "swish", 451 | resnet_groups: int = 32, 452 | resnet_pre_norm: bool = True, 453 | output_scale_factor=1.0, 454 | add_downsample=True, 455 | downsample_padding=1, 456 | use_inflated_groupnorm=False, 457 | use_motion_module=True, 458 | motion_module_type=None, 459 | motion_module_kwargs=None, 460 | ): 461 | super().__init__() 462 | resnets = [] 463 | motion_modules = [] 464 | 465 | for i in range(num_layers): 466 | in_channels = in_channels if i == 0 else out_channels 467 | resnets.append( 468 | ResnetBlock3D( 469 | in_channels=in_channels, 470 | out_channels=out_channels, 471 | temb_channels=temb_channels, 472 | eps=resnet_eps, 473 | groups=resnet_groups, 474 | dropout=dropout, 475 | time_embedding_norm=resnet_time_scale_shift, 476 | non_linearity=resnet_act_fn, 477 | output_scale_factor=output_scale_factor, 478 | pre_norm=resnet_pre_norm, 479 | 480 | use_inflated_groupnorm=use_inflated_groupnorm, 481 | ) 482 | ) 483 | motion_modules.append( 484 | get_motion_module( 485 | in_channels=out_channels, 486 | motion_module_type=motion_module_type, 487 | motion_module_kwargs=motion_module_kwargs, 488 | ) if use_motion_module else None 489 | ) 490 | 491 | self.resnets = nn.ModuleList(resnets) 492 | self.motion_modules = nn.ModuleList(motion_modules) 493 | 494 | if add_downsample: 495 | self.downsamplers = nn.ModuleList( 496 | [ 497 | Downsample3D( 498 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 499 | ) 500 | ] 501 | ) 502 | else: 503 | self.downsamplers = None 504 | 505 | 
self.gradient_checkpointing = False 506 | 507 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None): 508 | output_states = () 509 | 510 | for resnet, motion_module in zip(self.resnets, self.motion_modules): 511 | hidden_states = resnet(hidden_states, temb) 512 | # hidden_states = temp_conv(hidden_states, num_frames=num_frames) 513 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states 514 | 515 | output_states += (hidden_states,) 516 | 517 | if self.downsamplers is not None: 518 | for downsampler in self.downsamplers: 519 | hidden_states = downsampler(hidden_states) 520 | 521 | output_states += (hidden_states,) 522 | 523 | return hidden_states, output_states 524 | 525 | 526 | class CrossAttnUpBlock3D(nn.Module): 527 | def __init__( 528 | self, 529 | in_channels: int, 530 | out_channels: int, 531 | prev_output_channel: int, 532 | temb_channels: int, 533 | dropout: float = 0.0, 534 | num_layers: int = 1, 535 | resnet_eps: float = 1e-6, 536 | resnet_time_scale_shift: str = "default", 537 | resnet_act_fn: str = "swish", 538 | resnet_groups: int = 32, 539 | resnet_pre_norm: bool = True, 540 | num_attention_heads=1, 541 | cross_attention_dim=1280, 542 | output_scale_factor=1.0, 543 | add_upsample=True, 544 | dual_cross_attention=False, 545 | use_linear_projection=False, 546 | only_cross_attention=False, 547 | upcast_attention=False, 548 | 549 | unet_use_cross_frame_attention=None, 550 | unet_use_temporal_attention=None, 551 | use_inflated_groupnorm=None, 552 | 553 | use_motion_module=None, 554 | 555 | motion_module_type=None, 556 | motion_module_kwargs=None, 557 | ): 558 | super().__init__() 559 | resnets = [] 560 | attentions = [] 561 | motion_modules = [] 562 | 563 | self.has_cross_attention = True 564 | self.attn_num_attention_heads = num_attention_heads 565 | 566 | for i in range(num_layers): 567 | res_skip_channels = in_channels if (i == num_layers - 1) 
else out_channels 568 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 569 | 570 | resnets.append( 571 | ResnetBlock3D( 572 | in_channels=resnet_in_channels + res_skip_channels, 573 | out_channels=out_channels, 574 | temb_channels=temb_channels, 575 | eps=resnet_eps, 576 | groups=resnet_groups, 577 | dropout=dropout, 578 | time_embedding_norm=resnet_time_scale_shift, 579 | non_linearity=resnet_act_fn, 580 | output_scale_factor=output_scale_factor, 581 | pre_norm=resnet_pre_norm, 582 | 583 | use_inflated_groupnorm=use_inflated_groupnorm, 584 | ) 585 | ) 586 | attentions.append( 587 | Transformer3DModel( 588 | self.attn_num_attention_heads, 589 | out_channels // self.attn_num_attention_heads, 590 | in_channels=out_channels, 591 | num_layers=1, 592 | cross_attention_dim=cross_attention_dim, 593 | norm_num_groups=resnet_groups, 594 | # use_linear_projection=use_linear_projection, 595 | use_linear_projection=False, 596 | only_cross_attention=only_cross_attention, 597 | upcast_attention=upcast_attention, 598 | 599 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 600 | unet_use_temporal_attention=unet_use_temporal_attention, 601 | ) 602 | ) 603 | motion_modules.append( 604 | get_motion_module( 605 | in_channels=out_channels, 606 | motion_module_type=motion_module_type, 607 | motion_module_kwargs=motion_module_kwargs, 608 | ) if use_motion_module else None 609 | ) 610 | self.resnets = nn.ModuleList(resnets) 611 | self.attentions = nn.ModuleList(attentions) 612 | self.motion_modules = nn.ModuleList(motion_modules) 613 | 614 | if add_upsample: 615 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) 616 | else: 617 | self.upsamplers = None 618 | 619 | self.gradient_checkpointing = False 620 | 621 | def forward( 622 | self, 623 | hidden_states, 624 | res_hidden_states_tuple, 625 | temb=None, 626 | encoder_hidden_states=None, 627 | upsample_size=None, 628 | attention_mask=None, 629 | ): 
630 | # TODO(Patrick, William) - attention mask is not used 631 | for resnet, attn, motion_module in zip( 632 | self.resnets, self.attentions, self.motion_modules 633 | ): 634 | # pop res hidden states 635 | res_hidden_states = res_hidden_states_tuple[-1] 636 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 637 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 638 | 639 | hidden_states = resnet(hidden_states, temb) 640 | hidden_states = attn( 641 | hidden_states, 642 | encoder_hidden_states=encoder_hidden_states, 643 | return_dict=False, 644 | )[0] 645 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states 646 | 647 | 648 | if self.upsamplers is not None: 649 | for upsampler in self.upsamplers: 650 | hidden_states = upsampler(hidden_states, upsample_size) 651 | 652 | return hidden_states 653 | 654 | 655 | class UpBlock3D(nn.Module): 656 | def __init__( 657 | self, 658 | in_channels: int, 659 | prev_output_channel: int, 660 | out_channels: int, 661 | temb_channels: int, 662 | dropout: float = 0.0, 663 | num_layers: int = 1, 664 | resnet_eps: float = 1e-6, 665 | resnet_time_scale_shift: str = "default", 666 | resnet_act_fn: str = "swish", 667 | resnet_groups: int = 32, 668 | resnet_pre_norm: bool = True, 669 | output_scale_factor=1.0, 670 | add_upsample=True, 671 | 672 | use_inflated_groupnorm=None, 673 | 674 | use_motion_module=None, 675 | motion_module_type=None, 676 | motion_module_kwargs=None, 677 | ): 678 | super().__init__() 679 | resnets = [] 680 | motion_modules = [] 681 | 682 | for i in range(num_layers): 683 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 684 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 685 | 686 | resnets.append( 687 | ResnetBlock3D( 688 | in_channels=resnet_in_channels + res_skip_channels, 689 | out_channels=out_channels, 690 | temb_channels=temb_channels, 691 | 
eps=resnet_eps, 692 | groups=resnet_groups, 693 | dropout=dropout, 694 | time_embedding_norm=resnet_time_scale_shift, 695 | non_linearity=resnet_act_fn, 696 | output_scale_factor=output_scale_factor, 697 | pre_norm=resnet_pre_norm, 698 | 699 | use_inflated_groupnorm=use_inflated_groupnorm, 700 | 701 | ) 702 | ) 703 | motion_modules.append( 704 | get_motion_module( 705 | in_channels=out_channels, 706 | motion_module_type=motion_module_type, 707 | motion_module_kwargs=motion_module_kwargs, 708 | ) if use_motion_module else None 709 | ) 710 | 711 | self.resnets = nn.ModuleList(resnets) 712 | self.motion_modules = nn.ModuleList(motion_modules) 713 | 714 | if add_upsample: 715 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) 716 | else: 717 | self.upsamplers = None 718 | 719 | self.gradient_checkpointing = False 720 | 721 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None): 722 | for resnet, motion_module in zip(self.resnets, self.motion_modules): 723 | # pop res hidden states 724 | res_hidden_states = res_hidden_states_tuple[-1] 725 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 726 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 727 | 728 | hidden_states = resnet(hidden_states, temb) 729 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states 730 | 731 | 732 | if self.upsamplers is not None: 733 | for upsampler in self.upsamplers: 734 | hidden_states = upsampler(hidden_states, upsample_size) 735 | 736 | return hidden_states 737 | -------------------------------------------------------------------------------- /animatediff/pipelines/__pycache__/pipeline_animation.cpython-310.pyc: -------------------------------------------------------------------------------- 
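The skip-connection bookkeeping in the blocks above is easy to break when channel counts change: each down block pushes one tensor per layer plus one after `Downsample3D`, and each up block pops them from the end and concatenates along `dim=1` before its `ResnetBlock3D`. The sketch below replays that arithmetic in plain Python so a mismatch shows up without loading any weights. It assumes an SD-1.5-style layout borrowed from diffusers' `UNet2DConditionModel` (`block_out_channels=(320, 640, 1280, 1280)`, `layers_per_block=2`, up blocks with `layers_per_block + 1` layers, and the `conv_in` output pushed first). `unet.py` is not part of this excerpt, so treat those values and the helper names `down_skip_channels` / `up_resnet_inputs` as hypothetical.

```python
from typing import List, Sequence

# Assumed SD-1.5-style layout (hypothetical; unet.py is not shown in this excerpt).
BLOCK_OUT_CHANNELS = (320, 640, 1280, 1280)
LAYERS_PER_BLOCK = 2


def down_skip_channels(block_out: Sequence[int] = BLOCK_OUT_CHANNELS,
                       layers: int = LAYERS_PER_BLOCK) -> List[int]:
    """Channel count of every tensor pushed onto the skip stack, in push order."""
    skips = [block_out[0]]                    # conv_in output (assumed pushed first)
    for idx, out_ch in enumerate(block_out):
        skips += [out_ch] * layers            # one entry per resnet (+ attention/motion) layer
        if idx < len(block_out) - 1:          # every block but the last downsamples
            skips.append(out_ch)              # Downsample3D keeps the channel count
    return skips


def up_resnet_inputs(block_out: Sequence[int] = BLOCK_OUT_CHANNELS,
                     layers: int = LAYERS_PER_BLOCK) -> List[int]:
    """Replay the up path and return the in_channels of every up-block ResnetBlock3D."""
    skips = down_skip_channels(block_out, layers)
    reversed_out = list(reversed(block_out))
    prev_output_channel = reversed_out[0]     # mid-block output width
    hidden = prev_output_channel
    widths = []
    for i, out_channels in enumerate(reversed_out):
        in_channels = reversed_out[min(i + 1, len(block_out) - 1)]
        num_layers = layers + 1
        for j in range(num_layers):
            # Same formulas as CrossAttnUpBlock3D / UpBlock3D.__init__
            res_skip_channels = in_channels if j == num_layers - 1 else out_channels
            resnet_in_channels = prev_output_channel if j == 0 else out_channels
            popped = skips.pop()              # forward() pops from the end of the tuple
            assert popped == res_skip_channels, (i, j, popped, res_skip_channels)
            assert hidden + popped == resnet_in_channels + res_skip_channels
            widths.append(resnet_in_channels + res_skip_channels)
            hidden = out_channels             # the resnet maps to out_channels
        prev_output_channel = out_channels    # Upsample3D keeps the channel count
    assert not skips, "every skip should be consumed exactly once"
    return widths


if __name__ == "__main__":
    print(len(down_skip_channels()))          # 12
    print(up_resnet_inputs())                 # [2560, 2560, 2560, 2560, 2560, 1920, 1920, 1280, 960, 960, 640, 640]
```

Under these assumptions there are twelve skips, which is also the number of down-block residuals a diffusers `ControlNetModel` returns for this layout; in diffusers' `UNet2DConditionModel` those residuals are added to the skips pairwise, so the counts presumably have to agree here as well.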
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/animatediff/pipelines/__pycache__/pipeline_animation.cpython-310.pyc -------------------------------------------------------------------------------- /animatediff/pipelines/pipeline_animation.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py 2 | 3 | import inspect 4 | from typing import Callable, List, Optional, Union 5 | from dataclasses import dataclass 6 | import cv2 7 | 8 | from PIL import Image 9 | import numpy as np 10 | import torch 11 | from tqdm import tqdm 12 | 13 | from diffusers.utils import is_accelerate_available 14 | from packaging import version 15 | from transformers import CLIPTextModel, CLIPTokenizer 16 | 17 | from diffusers.configuration_utils import FrozenDict 18 | from diffusers.models import AutoencoderKL 19 | from diffusers.pipeline_utils import DiffusionPipeline 20 | from diffusers.schedulers import ( 21 | DDIMScheduler, 22 | DPMSolverMultistepScheduler, 23 | EulerAncestralDiscreteScheduler, 24 | EulerDiscreteScheduler, 25 | LMSDiscreteScheduler, 26 | PNDMScheduler, 27 | ) 28 | from diffusers.models import ControlNetModel 29 | from diffusers.image_processor import VaeImageProcessor 30 | from diffusers.utils import deprecate, logging, BaseOutput 31 | 32 | from einops import rearrange 33 | 34 | from ..models.unet import UNet3DConditionModel 35 | 36 | 37 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 38 | 39 | 40 | 41 | def prepare_image( 42 | image, 43 | width, 44 | height, 45 | batch_size, 46 | num_images_per_prompt, 47 | device, 48 | dtype, 49 | do_classifier_free_guidance=False, 50 | guess_mode=False, 51 | ): 52 | control_image_processor = VaeImageProcessor() 53 | image = control_image_processor.preprocess(image, height=height, 
width=width).to(dtype=torch.float32) 54 | image_batch_size = image.shape[0] 55 | 56 | if image_batch_size == 1: 57 | repeat_by = batch_size 58 | else: 59 | # image batch size is the same as prompt batch size 60 | repeat_by = num_images_per_prompt 61 | 62 | image = image.repeat_interleave(repeat_by, dim=0) 63 | 64 | image = image.to(device=device, dtype=dtype) 65 | 66 | if do_classifier_free_guidance and not guess_mode: 67 | image = torch.cat([image] * 2) 68 | 69 | return image 70 | 71 | 72 | @dataclass 73 | class AnimationPipelineOutput(BaseOutput): 74 | videos: Union[torch.Tensor, np.ndarray] 75 | 76 | 77 | class AnimationPipeline(DiffusionPipeline): 78 | _optional_components = [] 79 | 80 | def __init__( 81 | self, 82 | vae: AutoencoderKL, 83 | text_encoder: CLIPTextModel, 84 | tokenizer: CLIPTokenizer, 85 | unet: UNet3DConditionModel, 86 | scheduler: Union[ 87 | DDIMScheduler, 88 | PNDMScheduler, 89 | LMSDiscreteScheduler, 90 | EulerDiscreteScheduler, 91 | EulerAncestralDiscreteScheduler, 92 | DPMSolverMultistepScheduler, 93 | ], 94 | controlnet: ControlNetModel, 95 | ): 96 | super().__init__() 97 | 98 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 99 | deprecation_message = ( 100 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 101 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 102 | "to update the config accordingly as leaving `steps_offset` might lead to incorrect results" 103 | " in future versions.
If you have downloaded this checkpoint from the Hugging Face Hub," 104 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 105 | " file" 106 | ) 107 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) 108 | new_config = dict(scheduler.config) 109 | new_config["steps_offset"] = 1 110 | scheduler._internal_dict = FrozenDict(new_config) 111 | 112 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: 113 | deprecation_message = ( 114 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 115 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the" 116 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" 117 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" 118 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" 119 | ) 120 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) 121 | new_config = dict(scheduler.config) 122 | new_config["clip_sample"] = False 123 | scheduler._internal_dict = FrozenDict(new_config) 124 | 125 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( 126 | version.parse(unet.config._diffusers_version).base_version 127 | ) < version.parse("0.9.0.dev0") 128 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 129 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: 130 | deprecation_message = ( 131 | "The configuration file of the unet has set the default `sample_size` to smaller than" 132 | " 64 which seems highly unlikely. 
If your checkpoint is a fine-tuned version of any of the" 133 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" 134 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" 135 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" 136 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" 137 | " in the config might lead to incorrect results in future versions. If you have downloaded this" 138 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" 139 | " the `unet/config.json` file" 140 | ) 141 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) 142 | new_config = dict(unet.config) 143 | new_config["sample_size"] = 64 144 | unet._internal_dict = FrozenDict(new_config) 145 | 146 | self.register_modules( 147 | vae=vae, 148 | text_encoder=text_encoder, 149 | tokenizer=tokenizer, 150 | unet=unet, 151 | scheduler=scheduler, 152 | controlnet=controlnet 153 | ) 154 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 155 | 156 | def enable_vae_slicing(self): 157 | self.vae.enable_slicing() 158 | 159 | def disable_vae_slicing(self): 160 | self.vae.disable_slicing() 161 | 162 | def enable_sequential_cpu_offload(self, gpu_id=0): 163 | if is_accelerate_available(): 164 | from accelerate import cpu_offload 165 | else: 166 | raise ImportError("Please install accelerate via `pip install accelerate`") 167 | 168 | device = torch.device(f"cuda:{gpu_id}") 169 | 170 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 171 | if cpu_offloaded_model is not None: 172 | cpu_offload(cpu_offloaded_model, device) 173 | 174 | 175 | @property 176 | def _execution_device(self): 177 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 178 | return self.device 179 | for module in 
self.unet.modules(): 180 | if ( 181 | hasattr(module, "_hf_hook") 182 | and hasattr(module._hf_hook, "execution_device") 183 | and module._hf_hook.execution_device is not None 184 | ): 185 | return torch.device(module._hf_hook.execution_device) 186 | return self.device 187 | 188 | def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): 189 | batch_size = len(prompt) if isinstance(prompt, list) else 1 190 | 191 | text_inputs = self.tokenizer( 192 | prompt, 193 | padding="max_length", 194 | max_length=self.tokenizer.model_max_length, 195 | truncation=True, 196 | return_tensors="pt", 197 | ) 198 | text_input_ids = text_inputs.input_ids 199 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 200 | 201 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 202 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) 203 | logger.warning( 204 | "The following part of your input was truncated because CLIP can only handle sequences up to" 205 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 206 | ) 207 | 208 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 209 | attention_mask = text_inputs.attention_mask.to(device) 210 | else: 211 | attention_mask = None 212 | 213 | text_embeddings = self.text_encoder( 214 | text_input_ids.to(device), 215 | attention_mask=attention_mask, 216 | ) 217 | text_embeddings = text_embeddings[0] 218 | 219 | # duplicate text embeddings for each generation per prompt, using mps friendly method 220 | bs_embed, seq_len, _ = text_embeddings.shape 221 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) 222 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 223 | 224 | # get unconditional embeddings for classifier free 
guidance 225 | if do_classifier_free_guidance: 226 | uncond_tokens: List[str] 227 | if negative_prompt is None: 228 | uncond_tokens = [""] * batch_size 229 | elif type(prompt) is not type(negative_prompt): 230 | raise TypeError( 231 | f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !=" 232 | f" {type(prompt)}." 233 | ) 234 | elif isinstance(negative_prompt, str): 235 | uncond_tokens = [negative_prompt] 236 | elif batch_size != len(negative_prompt): 237 | raise ValueError( 238 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 239 | f" {prompt} has batch size {batch_size}. Please make sure that the passed `negative_prompt` matches" 240 | " the batch size of `prompt`." 241 | ) 242 | else: 243 | uncond_tokens = negative_prompt 244 | 245 | max_length = text_input_ids.shape[-1] 246 | uncond_input = self.tokenizer( 247 | uncond_tokens, 248 | padding="max_length", 249 | max_length=max_length, 250 | truncation=True, 251 | return_tensors="pt", 252 | ) 253 | 254 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 255 | attention_mask = uncond_input.attention_mask.to(device) 256 | else: 257 | attention_mask = None 258 | 259 | uncond_embeddings = self.text_encoder( 260 | uncond_input.input_ids.to(device), 261 | attention_mask=attention_mask, 262 | ) 263 | uncond_embeddings = uncond_embeddings[0] 264 | 265 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 266 | seq_len = uncond_embeddings.shape[1] 267 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) 268 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) 269 | 270 | # For classifier free guidance, we need to do two forward passes.
271 | # Here we concatenate the unconditional and text embeddings into a single batch 272 | # to avoid doing two forward passes 273 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 274 | 275 | return text_embeddings 276 | 277 | def decode_latents(self, latents): 278 | video_length = latents.shape[2] 279 | latents = 1 / 0.18215 * latents 280 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 281 | # video = self.vae.decode(latents).sample 282 | video = [] 283 | for frame_idx in tqdm(range(latents.shape[0])): 284 | video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) 285 | video = torch.cat(video) 286 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 287 | video = (video / 2 + 0.5).clamp(0, 1) 288 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 289 | video = video.cpu().float().numpy() 290 | return video 291 | 292 | def prepare_extra_step_kwargs(self, generator, eta): 293 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 294 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
295 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 296 | # and should be between [0, 1] 297 | 298 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 299 | extra_step_kwargs = {} 300 | if accepts_eta: 301 | extra_step_kwargs["eta"] = eta 302 | 303 | # check if the scheduler accepts generator 304 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 305 | if accepts_generator: 306 | extra_step_kwargs["generator"] = generator 307 | return extra_step_kwargs 308 | 309 | def check_inputs(self, prompt, height, width, callback_steps): 310 | if not isinstance(prompt, str) and not isinstance(prompt, list): 311 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 312 | 313 | if height % 8 != 0 or width % 8 != 0: 314 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 315 | 316 | if (callback_steps is None) or ( 317 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 318 | ): 319 | raise ValueError( 320 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 321 | f" {type(callback_steps)}." 322 | ) 323 | 324 | def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): 325 | shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) 326 | if isinstance(generator, list) and len(generator) != batch_size: 327 | raise ValueError( 328 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 329 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
330 | ) 331 | if latents is None: 332 | rand_device = "cpu" if device.type == "mps" else device 333 | 334 | if isinstance(generator, list): 335 | sample_shape = (1,) + shape[1:] # each generator draws a single sample; cat below restores batch_size 336 | 337 | latents = [ 338 | torch.randn(sample_shape, generator=generator[i], device=rand_device, dtype=dtype) 339 | for i in range(batch_size) 340 | ] 341 | latents = torch.cat(latents, dim=0).to(device) 342 | else: 343 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) 344 | else: 345 | rand_device = "cpu" if device.type == "mps" else device 346 | if latents.shape != shape: 347 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") 348 | noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) 349 | latents = noise 350 | latents = latents.to(device) 351 | 352 | # scale the initial noise by the standard deviation required by the scheduler 353 | latents = latents * self.scheduler.init_noise_sigma 354 | return latents 355 | 356 | @torch.no_grad() 357 | def __call__( 358 | self, 359 | prompt: Union[str, List[str]], 360 | video_length: Optional[int], 361 | height: Optional[int] = None, 362 | width: Optional[int] = None, 363 | num_inference_steps: int = 50, 364 | guidance_scale: float = 7.5, 365 | negative_prompt: Optional[Union[str, List[str]]] = None, 366 | num_videos_per_prompt: Optional[int] = 1, 367 | eta: float = 0.0, 368 | controlnet_image: torch.FloatTensor = None, 369 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 370 | latents: Optional[torch.FloatTensor] = None, 371 | output_type: Optional[str] = "tensor", 372 | return_dict: bool = True, 373 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 374 | callback_steps: Optional[int] = 1, 375 | **kwargs, 376 | ): 377 | # Default height and width to unet 378 | height = height or self.unet.config.sample_size * self.vae_scale_factor 379 | width = width or
self.unet.config.sample_size * self.vae_scale_factor 380 | 381 | # Check inputs. Raise error if not correct 382 | self.check_inputs(prompt, height, width, callback_steps) 383 | 384 | # Define call parameters 385 | # batch_size = 1 if isinstance(prompt, str) else len(prompt) 386 | batch_size = 1 387 | if latents is not None: 388 | batch_size = latents.shape[0] 389 | if isinstance(prompt, list): 390 | batch_size = len(prompt) 391 | 392 | device = self._execution_device 393 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 394 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 395 | # corresponds to doing no classifier free guidance. 396 | do_classifier_free_guidance = guidance_scale > 1.0 397 | 398 | # Encode input prompt 399 | prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size 400 | if negative_prompt is not None: 401 | negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size 402 | text_embeddings = self._encode_prompt( 403 | prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt 404 | ) 405 | 406 | # Prepare timesteps 407 | self.scheduler.set_timesteps(num_inference_steps, device=device) 408 | timesteps = self.scheduler.timesteps 409 | 410 | # Prepare latent variables 411 | num_channels_latents = self.unet.in_channels 412 | 413 | latents = self.prepare_latents( 414 | batch_size * num_videos_per_prompt, 415 | num_channels_latents, 416 | video_length, 417 | height, 418 | width, 419 | text_embeddings.dtype, 420 | device, 421 | generator, 422 | latents, 423 | ) 424 | latents_dtype = latents.dtype 425 | 426 | #--------------- 427 | image = prepare_image( 428 | image=controlnet_image, 429 | width=controlnet_image.shape[-1], 430 | height=controlnet_image.shape[-2], 431 | batch_size=1, 432 | num_images_per_prompt=1, 433 | device="cuda", 434 | dtype=self.controlnet.dtype, 435 | ) 436 | 437 | 
        #---------------

        # Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # the 2D ControlNet is run on the latents of the first frame only (frame index 0)
                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    sample=latent_model_input[:, :, 0, :, :],
                    timestep=t,
                    encoder_hidden_states=text_embeddings,  # e.g. (2 * batch, 77, 768) with classifier-free guidance on SD 1.5
                    controlnet_cond=image,
                    return_dict=False,
                )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                ).sample.to(dtype=latents_dtype)

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # Post-processing
        video = self.decode_latents(latents)

        # Convert to tensor
        if output_type == "tensor":
            video = torch.from_numpy(video)

        if not return_dict:
            return video

        return AnimationPipelineOutput(videos=video)


--------------------------------------------------------------------------------
/animatediff/utils/__pycache__/convert_from_ckpt.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/animatediff/utils/__pycache__/convert_from_ckpt.cpython-310.pyc
--------------------------------------------------------------------------------
/animatediff/utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/animatediff/utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-310.pyc
--------------------------------------------------------------------------------
/animatediff/utils/__pycache__/util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/animatediff/utils/__pycache__/util.cpython-310.pyc
--------------------------------------------------------------------------------
/animatediff/utils/convert_lora_safetensor_to_diffusers.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Conversion script for the LoRA's safetensors checkpoints. """

import argparse

import torch
from safetensors.torch import load_file

from diffusers import StableDiffusionPipeline
import pdb

def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
    # load base model
    # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)

    # load LoRA weight from .safetensors
    # state_dict = load_file(checkpoint_path)

    visited = []

    # directly update weight in diffusers model
    for key in state_dict:
        # printing the keys is recommended; they usually look like this:
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"

        # alpha is already given by the `alpha` argument, so skip the stored ".alpha" entries
        if ".alpha" in key or key in visited:
            continue

        if "text" in key:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # find the target layer
        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))

        # update weight
        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)

        # update visited list
        for item in pair_keys:
            visited.append(item)

    return pipeline


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
    )
    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument(
        "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
    )
    parser.add_argument(
        "--lora_prefix_text_encoder",
        default="lora_te",
        type=str,
        help="The prefix of text encoder weight in safetensors",
    )
    parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
    parser.add_argument(
        "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
    )
    parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")

    args = parser.parse_args()

    base_model_path = args.base_model_path
    checkpoint_path = args.checkpoint_path
    dump_path = args.dump_path
    lora_prefix_unet = args.lora_prefix_unet
    lora_prefix_text_encoder = args.lora_prefix_text_encoder
    alpha = args.alpha

    # `convert` is not defined in this module: load the base pipeline and the LoRA
    # weights here, then merge them with convert_lora(pipeline, state_dict, ...)
    pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
    state_dict = load_file(checkpoint_path)
    pipe = convert_lora(pipe, state_dict, lora_prefix_unet, lora_prefix_text_encoder, alpha)

    pipe = pipe.to(args.device)
    pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)

--------------------------------------------------------------------------------
/animatediff/utils/util.py:
--------------------------------------------------------------------------------
import os
import imageio
import numpy as np
from typing import Union

import torch
import torchvision
import torch.distributed as dist

from tqdm import tqdm
from einops import rearrange


def zero_rank_print(s):
    # print when not running distributed, or only on rank 0 when distributed
    # (the previous `and` made this condition always False)
    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)


def save_videos_grid(videos: torch.Tensor, path: str, rescale=False,
                     n_rows=6, fps=8):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)


# DDIM Inversion
@torch.no_grad()
def init_prompt(prompt, pipeline):
    uncond_input = pipeline.tokenizer(
        [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
        return_tensors="pt"
    )
    uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
    text_input = pipeline.tokenizer(
        [prompt],
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
    context = torch.cat([uncond_embeddings, text_embeddings])

    return context


def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
    timestep, next_timestep = min(
        timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
    beta_prod_t = 1 - alpha_prod_t
    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
    return next_sample


def get_noise_pred_single(latents, t, context, unet):
    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
    return noise_pred


@torch.no_grad()
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
    context = init_prompt(prompt, pipeline)
    uncond_embeddings, cond_embeddings = context.chunk(2)
    all_latent = [latent]
    latent = latent.clone().detach()
    for i in tqdm(range(num_inv_steps)):
        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
        latent = next_step(noise_pred, t, latent, ddim_scheduler)
        all_latent.append(latent)
    return all_latent


@torch.no_grad()
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
    return ddim_latents

--------------------------------------------------------------------------------
/animatetest.py:
--------------------------------------------------------------------------------
import argparse
import datetime
import inspect
import os
from omegaconf import OmegaConf
from PIL import Image
import numpy as np

import torch
from torchvision import models
from torch.nn import functional as F
import torchvision.transforms as transforms

import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
import pickle

from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

# import sys
# sys.path.append("/root/AnimateDiffcontrolnet-main/")

from animatediff.models.unet import UNet3DConditionModel
from animatediff.models.controlnet import ControlNetModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
from diffusers.utils.import_utils import is_xformers_available

from einops import rearrange, repeat

import csv, pdb, glob
from safetensors import safe_open
import math
from pathlib import Path


def main(args):
    *_, func_args = inspect.getargvalues(inspect.currentframe())
    func_args = dict(func_args)

    time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    savedir = f"samples/{Path(args.config).stem}-{time_str}"
    os.makedirs(savedir)

    config = OmegaConf.load(args.config)
    samples = []

    sample_idx = 0
    for model_idx, (config_key, model_config) in enumerate(list(config.items())):

        motion_modules = model_config.motion_module
        motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
        for motion_module in motion_modules:
            inference_config = OmegaConf.load(model_config.get("inference_config", args.inference_config))

            ### >>> create validation pipeline >>> ###
            tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
            text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
            vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
            unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
            controlnet = ControlNetModel()


            # if is_xformers_available(): unet.enable_xformers_memory_efficient_attention()
            # else: assert False

            pipeline = AnimationPipeline(
                vae=vae, text_encoder=text_encoder,
                tokenizer=tokenizer, unet=unet, controlnet=controlnet,
                scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
            ).to("cuda")

            # 0. controlnet ckpt
            controlnet_state_dict = torch.load("./checkpoints/controlnet_checkpoint-epoch-30.ckpt", map_location="cpu")
            missing, unexpected = pipeline.controlnet.load_state_dict(controlnet_state_dict["state_dict"], strict=False)
            assert len(unexpected) == 0


            # 1. unet ckpt
            # 1.1 motion module
            motion_module_state_dict = torch.load(motion_module, map_location="cpu")
            if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
            missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
            assert len(unexpected) == 0

            # 1.2 other fine-tuned models used for T2I
            if model_config.path != "":
                if model_config.path.endswith(".ckpt"):
                    state_dict = torch.load(model_config.path)
                    pipeline.unet.load_state_dict(state_dict)

                elif model_config.path.endswith(".safetensors"):
                    state_dict = {}
                    with safe_open(model_config.path, framework="pt", device="cpu") as f:
                        for key in f.keys():
                            state_dict[key] = f.get_tensor(key)

                    is_lora = all("lora" in k for k in state_dict.keys())
                    if not is_lora:
                        base_state_dict = state_dict
                    else:
                        base_state_dict = {}
                        with safe_open(model_config.base, framework="pt", device="cpu") as f:
                            for key in f.keys():
                                base_state_dict[key] = f.get_tensor(key)

                    # vae
                    converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
                    pipeline.vae.load_state_dict(converted_vae_checkpoint)
                    # unet
                    converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
                    pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
                    # text_model
                    pipeline.text_encoder = convert_ldm_clip_checkpoint(pipeline.text_encoder, base_state_dict)

                    # import pdb
                    # pdb.set_trace()
                    if is_lora:
                        pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)

            pipeline.to("cuda")
            ### <<< create validation pipeline <<< ###

            prompts = model_config.prompt
            n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt

            random_seeds = model_config.get("seed", [-1])
            random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
            random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds

            config[config_key].random_seed = []

            #------------------------------------------------
            pixel_transforms = transforms.Compose([
                # transforms.RandomHorizontalFlip(),
                transforms.Resize(512),
                transforms.CenterCrop(512),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ])
            # --------------------------------------------------

            for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
                init_image = model_config.init_image[prompt_idx]

                # convert to RGB so grayscale or RGBA inputs also yield a 3-channel (H, W, 3) array
                pixel_values = Image.open(init_image).convert("RGB")
                pixel_values = np.array(pixel_values)
                pixel_values = torch.from_numpy(pixel_values).permute(2, 0, 1).unsqueeze(0)
                pixel_values = pixel_values / 255.
                pixel_values = pixel_transforms(pixel_values)
                pixel_values = (pixel_values / 2 + 0.5).clamp(0, 1)

                # manually set random seed for reproduction
                if random_seed != -1: torch.manual_seed(random_seed)
                else: torch.seed()
                config[config_key].random_seed.append(torch.initial_seed())

                print(f"current seed: {torch.initial_seed()}")
                print(f"sampling {prompt} ...")
                sample = pipeline(
                    prompt,
                    negative_prompt=n_prompt,
                    num_inference_steps=model_config.steps,
                    guidance_scale=model_config.guidance_scale,
                    width=args.W,
                    height=args.H,
                    video_length=args.L,
                    controlnet_image=pixel_values,
                ).videos
                samples.append(sample)

                prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
                save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
                print(f"save to {savedir}/sample/{sample_idx}-{prompt}.gif")

                sample_idx += 1

    samples = torch.concat(samples)
    save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)

    OmegaConf.save(config, f"{savedir}/config.yaml")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_path", type=str, default="./models/stable-diffusion-v1-5",)
    parser.add_argument("--inference_config", type=str, default="configs/inference/inference-v2.yaml")
    parser.add_argument("--config", type=str, default="configs/prompts/v2/5-RealisticVision.yaml")

    parser.add_argument("--L", type=int, default=16)
    parser.add_argument("--W", type=int, default=512)
    parser.add_argument("--H", type=int, default=512)

    args = parser.parse_args()
    main(args)

--------------------------------------------------------------------------------
/configs/inference/inference-v1.yaml:
--------------------------------------------------------------------------------
unet_additional_kwargs:
  unet_use_cross_frame_attention: false
  unet_use_temporal_attention: false
  use_motion_module: true
  motion_module_resolutions:
  - 1
  - 2
  - 4
  - 8
  motion_module_mid_block: false
  motion_module_decoder_only: false
  motion_module_type: Vanilla
  motion_module_kwargs:
    num_attention_heads: 8
    num_transformer_block: 1
    attention_block_types:
    - Temporal_Self
    - Temporal_Self
    temporal_position_encoding: true
    temporal_position_encoding_max_len: 24
    temporal_attention_dim_div: 1

noise_scheduler_kwargs:
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "linear"

--------------------------------------------------------------------------------
/configs/inference/inference-v2.yaml:
--------------------------------------------------------------------------------
unet_additional_kwargs:
  use_inflated_groupnorm: true
  unet_use_cross_frame_attention: false
  unet_use_temporal_attention: false
  use_motion_module: true
  motion_module_resolutions:
  - 1
  - 2
  - 4
  - 8
  motion_module_mid_block: true
  motion_module_decoder_only: false
  motion_module_type: Vanilla
  motion_module_kwargs:
    num_attention_heads: 8
    num_transformer_block: 1
    attention_block_types:
    - Temporal_Self
    - Temporal_Self
    temporal_position_encoding: true
    temporal_position_encoding_max_len: 32
    temporal_attention_dim_div: 1

noise_scheduler_kwargs:
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "linear"

--------------------------------------------------------------------------------
/configs/prompts/v2/5-RealisticVision.yaml:
--------------------------------------------------------------------------------
RealisticVision:
  base: ""
  path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"

  inference_config: "configs/inference/inference-v2.yaml"
  motion_module:
    - "models/Motion_Module/mm_sd_v15_v2.ckpt"

  seed: [0]
  steps: 25
  guidance_scale: 7.5

  prompt:
    - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
    - "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
    - "beach, large rocks, waves, cloudy sky, dark clouds, flowing water"
    - "scuba diving, coral reef, fish, sea anemones, starfish, sea turtles, clear water, sunlight, underwater world"
    - "fireworks, new year, 2023, night sky, stars, ring of fire"
    - "bird, small, brown, back, wings, white, chest, belly, black, beak, eyes, tree, deciduous, shrub, green, leaves, sky, clouds"

  n_prompt:
    - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
    - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
    - ""
    - ""
    - ""
    - ""

  init_image:
    - "/root/lh/AnimateDiffcontrolnet-main/init_images/0.jpg"
    - "/root/lh/AnimateDiffcontrolnet-main/init_images/1.jpg"
    - "/root/lh/AnimateDiffcontrolnet-main/init_images/2.jpg"
    - "/root/lh/AnimateDiffcontrolnet-main/init_images/3.jpg"
    - "/root/lh/AnimateDiffcontrolnet-main/init_images/4.jpg"
    - "/root/lh/AnimateDiffcontrolnet-main/init_images/5.jpg"


--------------------------------------------------------------------------------
/configs/training/image_finetune.yaml:
--------------------------------------------------------------------------------
image_finetune: true

output_dir: "outputs"
pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5"

noise_scheduler_kwargs:
  num_train_timesteps: 1000
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "scaled_linear"
  steps_offset: 1
  clip_sample: false

train_data:
  csv_path: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv"
  video_folder: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val"
  sample_size: 256

validation_data:
  prompts:
    - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons."
    - "A drone view of celebration with Christma tree and fireworks, starry sky - background."
    - "Robot dancing in times square."
    - "Pacific coast, carmel by the sea ocean and waves."
  num_inference_steps: 25
  guidance_scale: 8.

trainable_modules:
  - "."

unet_checkpoint_path: ""

learning_rate: 1.e-5
train_batch_size: 50

max_train_epoch: -1
max_train_steps: 100
checkpointing_epochs: -1
checkpointing_steps: 60

validation_steps: 5000
validation_steps_tuple: [2, 50]

global_seed: 42
mixed_precision_training: true
enable_xformers_memory_efficient_attention: True

is_debug: False

--------------------------------------------------------------------------------
/configs/training/training.yaml:
--------------------------------------------------------------------------------
image_finetune: false

output_dir: "outputs"
pretrained_model_path: "models/stable-diffusion-v1-5"

unet_additional_kwargs:
  use_motion_module : true
  motion_module_resolutions : [ 1,2,4,8 ]
  unet_use_cross_frame_attention : false
  unet_use_temporal_attention : false

  motion_module_type: Vanilla
  motion_module_kwargs:
    num_attention_heads : 8
    num_transformer_block : 1
    attention_block_types : [ "Temporal_Self", "Temporal_Self" ]
    temporal_position_encoding : true
    # temporal_position_encoding_max_len : 24
    temporal_position_encoding_max_len : 32
    temporal_attention_dim_div : 1
    zero_initialize : true

noise_scheduler_kwargs:
  num_train_timesteps: 1000
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "linear"
  steps_offset: 1
  clip_sample: false

train_data:
  csv_path: "./results_2M_train_new.csv"
  video_folder: "./datasets_train"
  sample_size: 256
  sample_stride: 4
  sample_n_frames: 16

validation_data:
  prompts:
    - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons."
    - "A drone view of celebration with Christma tree and fireworks, starry sky - background."
    - "Robot dancing in times square."
    - "Pacific coast, carmel by the sea ocean and waves."
  num_inference_steps: 25
  guidance_scale: 8.

trainable_modules:
  - "motion_modules."

# unet_checkpoint_path: ""
unet_checkpoint_path: "models/Motion_Module/mm_sd_v15_v2.ckpt"


learning_rate: 1.e-4
train_batch_size: 12

max_train_epoch: 30
max_train_steps: -1
checkpointing_epochs: -1
# save a checkpoint every N steps
checkpointing_steps: 2000

validation_steps: 5000
validation_steps_tuple: [1,1000, 5000, 10000]

global_seed: 42
mixed_precision_training: true
enable_xformers_memory_efficient_attention: True

is_debug: False

--------------------------------------------------------------------------------
/download_data.py:
--------------------------------------------------------------------------------
import pandas as pd
import requests
import concurrent.futures
import os

def download_video(row):
    video_url = row['contentUrl']
    video_name = row['videoid']
    folder_path = './datasets_train'  # replace this with the path to your own folder
    # create the folder if needed; otherwise open() fails and the error is hidden inside the future
    os.makedirs(folder_path, exist_ok=True)
    video_file = os.path.join(folder_path, f'{video_name}.mp4')

    # skip the download if the file already exists
    if os.path.isfile(video_file):
        return

    response = requests.get(video_url)

    if response.status_code == 200:
        with open(video_file, 'wb') as f:
            f.write(response.content)
    else:
        print(f"Failed to download video {video_name} from url {video_url}")

df = pd.read_csv('results_2M_train.csv')

rows = df.to_dict('records')

with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [executor.submit(download_video, row) for row in rows]
    # report exceptions (e.g. network errors) that would otherwise be silently swallowed
    for future in concurrent.futures.as_completed(futures):
        if future.exception() is not None:
            print(f"Download failed: {future.exception()}")
--------------------------------------------------------------------------------
/imgs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/imgs/.DS_Store -------------------------------------------------------------------------------- /imgs/0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/imgs/0.gif -------------------------------------------------------------------------------- /imgs/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/imgs/1.gif -------------------------------------------------------------------------------- /imgs/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/imgs/2.gif -------------------------------------------------------------------------------- /imgs/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/imgs/3.gif -------------------------------------------------------------------------------- /imgs/4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/imgs/4.gif -------------------------------------------------------------------------------- /imgs/5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/imgs/5.gif -------------------------------------------------------------------------------- 
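For reference, `convert_lora` in `animatediff/utils/convert_lora_safetensor_to_diffusers.py` updates every targeted layer with W = W0 + alpha * (lora_up @ lora_down). This is the merge rule named in that script's `--alpha` help text. The sketch below reproduces only that arithmetic, in plain Python on small nested-list matrices, so the shapes are easy to follow. It rests on a few assumptions:

- The helper names `matmul` and `merge_lora` are hypothetical and not part of this repository.
- `lora_up` is assumed to be (out_features x rank) and `lora_down` (rank x in_features).
- The repository itself works on tensors with `torch.mm` and squeezes 1x1 conv weights to 2-D first.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive product of an (n x k) matrix and a (k x m) matrix."""
    if len(a[0]) != len(b):
        raise ValueError("inner dimensions must match")
    return [
        [sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def merge_lora(w0: Matrix, up: Matrix, down: Matrix, alpha: float) -> Matrix:
    """Hypothetical helper: return W0 + alpha * (up @ down).

    up is (out_features x rank), down is (rank x in_features),
    and w0 is (out_features x in_features), mirroring convert_lora's update.
    """
    delta = matmul(up, down)
    return [
        [w + alpha * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(w0, delta)
    ]


if __name__ == "__main__":
    w0 = [[1.0, 0.0], [0.0, 1.0]]  # 2x2 base weight
    up = [[1.0], [2.0]]            # 2x1, rank 1
    down = [[3.0, 4.0]]            # 1x2
    print(merge_lora(w0, up, down, alpha=0.5))  # prints [[2.5, 2.0], [3.0, 5.0]]
```

With `alpha=0` the base weight comes back unchanged. Larger values move the merged weight further along the rank-limited LoRA update.
--------------------------------------------------------------------------------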
/init_images/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/init_images/.DS_Store
--------------------------------------------------------------------------------
/init_images/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/init_images/0.jpg
--------------------------------------------------------------------------------
/init_images/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/init_images/1.jpg
--------------------------------------------------------------------------------
/init_images/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/init_images/2.jpg
--------------------------------------------------------------------------------
/init_images/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/init_images/3.jpg
--------------------------------------------------------------------------------
/init_images/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/init_images/4.jpg
--------------------------------------------------------------------------------
/init_images/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/init_images/5.jpg
--------------------------------------------------------------------------------
/newanimate.yaml:
--------------------------------------------------------------------------------
name: newanimate
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2023.08.22=h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - ncurses=6.4=h6a678d5_0
  - openssl=3.0.10=h7f8727e_2
  - pip=23.2.1=py310h06a4308_0
  - python=3.10.13=h955ad1f_0
  - readline=8.2=h5eee18b_0
  - setuptools=68.0.0=py310h06a4308_0
  - sqlite=3.41.2=h5eee18b_0
  - tk=8.6.12=h1ccaba5_0
  - wheel=0.38.4=py310h06a4308_0
  - xz=5.4.2=h5eee18b_0
  - zlib=1.2.13=h5eee18b_0
  - pip:
      - accelerate==0.23.0
      - aiofiles==23.2.1
      - altair==5.1.1
      - annotated-types==0.5.0
      - antlr4-python3-runtime==4.9.3
      - anyio==3.7.1
      - appdirs==1.4.4
      - attrs==23.1.0
      - beautifulsoup4==4.12.2
      - certifi==2023.7.22
      - charset-normalizer==3.2.0
      - click==8.1.7
      - contourpy==1.1.1
      - cycler==0.11.0
      - decord==0.6.0
      - diffusers==0.21.4
      - docker-pycreds==0.4.0
      - einops==0.6.1
      - exceptiongroup==1.1.3
      - fastapi==0.103.1
      - ffmpy==0.3.1
      - filelock==3.12.4
      - fonttools==4.42.1
      - fsspec==2023.9.1
      - gdown==4.7.1
      - gitdb==4.0.10
      - gitpython==3.1.37
      - gradio==3.44.3
      - gradio-client==0.5.0
      - h11==0.14.0
      - httpcore==0.18.0
      - httpx==0.25.0
      - huggingface-hub==0.17.2
      - idna==3.4
      - imageio==2.27.0
      - importlib-metadata==6.8.0
      - importlib-resources==6.0.1
      - jinja2==3.1.2
      - jsonschema==4.19.0
      - jsonschema-specifications==2023.7.1
      - kiwisolver==1.4.5
      - markupsafe==2.1.3
      - matplotlib==3.8.0
      - mypy-extensions==1.0.0
      - numpy==1.26.0
      - nvidia-cublas-cu11==11.10.3.66
      - nvidia-cuda-nvrtc-cu11==11.7.99
      - nvidia-cuda-runtime-cu11==11.7.99
      - nvidia-cudnn-cu11==8.5.0.96
      - omegaconf==2.3.0
      - opencv-python==4.8.0.76
      - orjson==3.9.7
      - packaging==23.1
      - pandas==2.1.0
      - pathtools==0.1.2
      - pillow==10.0.1
      - protobuf==4.24.4
      - psutil==5.9.5
      - pydantic==2.3.0
      - pydantic-core==2.6.3
      - pydub==0.25.1
      - pyparsing==3.1.1
      - pyre-extensions==0.0.23
      - pysocks==1.7.1
      - python-dateutil==2.8.2
      - python-multipart==0.0.6
      - pytz==2023.3.post1
      - pyyaml==6.0.1
      - referencing==0.30.2
      - regex==2023.8.8
      - requests==2.31.0
      - rpds-py==0.10.3
      - safetensors==0.3.3
      - semantic-version==2.10.0
      - sentry-sdk==1.32.0
      - setproctitle==1.3.3
      - six==1.16.0
      - smmap==5.0.1
      - sniffio==1.3.0
      - soupsieve==2.5
      - starlette==0.27.0
      - tokenizers==0.13.3
      - toolz==0.12.0
      - torch==1.13.1
      - torchaudio==0.13.1
      - torchvision==0.14.1
      - tqdm==4.66.1
      - transformers==4.33.2
      - triton==2.1.0
      - typing-extensions==4.8.0
      - typing-inspect==0.9.0
      - tzdata==2023.3
      - urllib3==2.0.4
      - uvicorn==0.23.2
      - wandb==0.15.12
      - websockets==11.0.3
      - xformers==0.0.16
      - zipp==3.17.0
prefix: /root/anaconda3/envs/newanimate

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
accelerate==0.23.0
aiofiles==23.2.1
altair==5.1.1
annotated-types==0.5.0
antlr4-python3-runtime==4.9.3
anyio==3.7.1
appdirs==1.4.4
attrs==23.1.0
beautifulsoup4==4.12.2
certifi==2023.7.22
charset-normalizer==3.2.0
click==8.1.7
contourpy==1.1.1
cycler==0.11.0
decord==0.6.0
diffusers==0.21.4
docker-pycreds==0.4.0
einops==0.6.1
exceptiongroup==1.1.3
fastapi==0.103.1
ffmpy==0.3.1
filelock==3.12.4
fonttools==4.42.1
fsspec==2023.9.1
gdown==4.7.1
gitdb==4.0.10
GitPython==3.1.37
gradio==3.44.3
gradio_client==0.5.0
h11==0.14.0
httpcore==0.18.0
httpx==0.25.0
huggingface-hub==0.17.2
idna==3.4
imageio==2.27.0
importlib-metadata==6.8.0
importlib-resources==6.0.1
Jinja2==3.1.2
jsonschema==4.19.0
jsonschema-specifications==2023.7.1
kiwisolver==1.4.5
MarkupSafe==2.1.3
matplotlib==3.8.0
mypy-extensions==1.0.0
numpy==1.26.0
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
omegaconf==2.3.0
opencv-python==4.8.0.76
orjson==3.9.7
packaging==23.1
pandas==2.1.0
pathtools==0.1.2
Pillow==10.0.1
protobuf==4.24.4
psutil==5.9.5
pydantic==2.3.0
pydantic_core==2.6.3
pydub==0.25.1
pyparsing==3.1.1
pyre-extensions==0.0.23
PySocks==1.7.1
python-dateutil==2.8.2
python-multipart==0.0.6
pytz==2023.3.post1
PyYAML==6.0.1
referencing==0.30.2
regex==2023.8.8
requests==2.31.0
rpds-py==0.10.3
safetensors==0.3.3
semantic-version==2.10.0
sentry-sdk==1.32.0
setproctitle==1.3.3
six==1.16.0
smmap==5.0.1
sniffio==1.3.0
soupsieve==2.5
starlette==0.27.0
tokenizers==0.13.3
toolz==0.12.0
torch==1.13.1
torchaudio==0.13.1
torchvision==0.14.1
tqdm==4.66.1
transformers==4.33.2
triton==2.1.0
typing-inspect==0.9.0
typing_extensions==4.8.0
tzdata==2023.3
urllib3==2.0.4
uvicorn==0.23.2
wandb==0.15.12
websockets==11.0.3
xformers==0.0.16
zipp==3.17.0

--------------------------------------------------------------------------------
/scripts/__pycache__/animate.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/scripts/__pycache__/animate.cpython-310.pyc
--------------------------------------------------------------------------------
/scripts/animate.py:
--------------------------------------------------------------------------------
import argparse
import datetime
import inspect
import os
from omegaconf import OmegaConf
from PIL import Image
import numpy as np

import torch
from torchvision import models
from torch.nn import functional as F
import torchvision.transforms as transforms

import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
import pickle

from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import sys
sys.path.append("/root/lh/AnimateDiff-main/")  # hard-coded local checkout; adjust to your repository location

from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
from diffusers.utils.import_utils import is_xformers_available

from einops import rearrange, repeat

import csv, pdb, glob
from safetensors import safe_open
import math
from pathlib import Path


def main(args):
    *_, func_args = inspect.getargvalues(inspect.currentframe())
    func_args = dict(func_args)

    time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    savedir = f"samples/{Path(args.config).stem}-{time_str}"
    os.makedirs(savedir)

    config = OmegaConf.load(args.config)
    samples = []

    sample_idx = 0
    for model_idx, (config_key, model_config) in enumerate(list(config.items())):

        motion_modules = model_config.motion_module
        motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
        for motion_module in motion_modules:
            inference_config = OmegaConf.load(model_config.get("inference_config", args.inference_config))

            ### >>> create validation pipeline >>> ###
            tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
            text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
            vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
            unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))

            # if is_xformers_available(): unet.enable_xformers_memory_efficient_attention()
            # else: assert False

            pipeline = AnimationPipeline(
                vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
                scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
            ).to("cuda")

            # 1. unet ckpt
            # 1.1 motion module
            motion_module_state_dict = torch.load(motion_module, map_location="cpu")
            if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
            missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
            assert len(unexpected) == 0

            # 1.2 T2I
            if model_config.path != "":
                if model_config.path.endswith(".ckpt"):
                    state_dict = torch.load(model_config.path)
                    pipeline.unet.load_state_dict(state_dict)

                elif model_config.path.endswith(".safetensors"):
                    state_dict = {}
                    with safe_open(model_config.path, framework="pt", device="cpu") as f:
                        for key in f.keys():
                            state_dict[key] = f.get_tensor(key)

                    is_lora = all("lora" in k for k in state_dict.keys())
                    if not is_lora:
                        base_state_dict = state_dict
                    else:
                        base_state_dict = {}
                        with safe_open(model_config.base, framework="pt", device="cpu") as f:
                            for key in f.keys():
                                base_state_dict[key] = f.get_tensor(key)

                    # vae
                    converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
                    pipeline.vae.load_state_dict(converted_vae_checkpoint)
                    # unet
                    converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
                    pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
                    # text_model
                    pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)

                    # import pdb
                    # pdb.set_trace()
                    if is_lora:
                        pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)

            pipeline.to("cuda")
            ### <<< create validation pipeline <<< ###

            prompts = model_config.prompt
            n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt

            random_seeds = model_config.get("seed", [-1])
            random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
            random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds

            config[config_key].random_seed = []

            #------------------------------------------------
            # NOTE: the image path is hard-coded, and pixel_values is not used yet because
            # the `latents` argument of the pipeline call below is commented out.
            pixel_transforms = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.Resize(512),
                transforms.CenterCrop(512),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ])
            pixel_values = Image.open("/root/lh/AnimateDiff-main/sample.jpg").convert("RGB")  # force 3 channels for the permute below
            pixel_values = np.array(pixel_values)
            pixel_values = torch.from_numpy(pixel_values).permute(2, 0, 1).unsqueeze(0)
            pixel_values = pixel_values / 255.
            pixel_values = pixel_transforms(pixel_values).cuda()
            # latents = pipeline.vae.encode(pixel_values).latent_dist
            # latents = latents.sample()

            # latents = latents * 0.18215
            # latents = latents.unsqueeze(2).repeat(1,1,16,1,1)


            # --------------------------------------------------

            for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):

                # manually set random seed for reproduction
                if random_seed != -1: torch.manual_seed(random_seed)
                else: torch.seed()
                config[config_key].random_seed.append(torch.initial_seed())

                print(f"current seed: {torch.initial_seed()}")
                print(f"sampling {prompt} ...")
                sample = pipeline(
                    prompt,
                    negative_prompt = n_prompt,
                    num_inference_steps = model_config.steps,
                    guidance_scale = model_config.guidance_scale,
                    width = args.W,
                    height = args.H,
                    video_length = args.L,
                    # latents = pixel_values
                ).videos
                samples.append(sample)

                prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
                save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
                print(f"save to {savedir}/sample/{sample_idx}-{prompt}.gif")  # same path as the file just written

                sample_idx += 1

    samples = torch.concat(samples)
    save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)

    OmegaConf.save(config, f"{savedir}/config.yaml")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_path", type=str, default="/root/lh/stable-diffusion-v1-5",)
    parser.add_argument("--inference_config", type=str, default="configs/inference/inference-v2.yaml")
    # default points at the prompt config that actually ships in configs/prompts/v2/
    parser.add_argument("--config", type=str, default="configs/prompts/v2/5-RealisticVision.yaml")

    parser.add_argument("--L", type=int, default=16)
    parser.add_argument("--W", type=int, default=512)
    parser.add_argument("--H", type=int, default=512)

    args = parser.parse_args()
    main(args)

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import math
import imageio
import numpy as np
import wandb
import random
import logging
import inspect
import argparse
import datetime
import subprocess
import multiprocessing as mp


from pathlib import Path
from tqdm.auto import tqdm
from einops import rearrange
from omegaconf import OmegaConf
from safetensors import safe_open
from typing import Dict, Optional, Tuple

import torch
import torchvision
import torch.nn.functional as F
import torch.distributed as dist
from torch.optim.swa_utils import AveragedModel
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.models import UNet2DConditionModel
from diffusers.pipelines import StableDiffusionPipeline
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from diffusers.image_processor import VaeImageProcessor

import transformers
from transformers import CLIPTextModel, CLIPTokenizer

from animatediff.data.dataset import WebVid10M
from animatediff.models.unet import UNet3DConditionModel
from animatediff.models.controlnet import ControlNetModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid, zero_rank_print


def prepare_image(
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    # Resize/normalize the conditioning image and repeat it to match the batch
    control_image_processor = VaeImageProcessor()
    image = control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
    image_batch_size = image.shape[0]

    if image_batch_size == 1:
        repeat_by = batch_size
    else:
        # image batch size is the same as prompt batch size
        repeat_by = num_images_per_prompt

    image = image.repeat_interleave(repeat_by, dim=0)

    image = image.to(device=device, dtype=dtype)

    if do_classifier_free_guidance and not guess_mode:
        image = torch.cat([image] * 2)

    return image


def main(
    image_finetune: bool,

    name: str,
    use_wandb: bool,
    launcher: str,

    output_dir: str,
    pretrained_model_path: str,

    train_data: Dict,
    validation_data: Dict,
    cfg_random_null_text: bool = True,
    cfg_random_null_text_ratio: float = 0.1,

    unet_checkpoint_path: str = "",
    unet_additional_kwargs: Dict = {},
    ema_decay: float = 0.9999,
    noise_scheduler_kwargs = None,

    max_train_epoch: int = -1,
    max_train_steps: int = 100,
    validation_steps: int = 100,
    validation_steps_tuple: Tuple = (-1,),

    learning_rate: float = 3e-5,
    scale_lr: bool = False,
    lr_warmup_steps: int = 0,
    lr_scheduler: str = "constant",

    trainable_modules: Tuple[str] = (None, ),
    num_workers: int = 32,
    train_batch_size: int = 1,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_weight_decay: float = 1e-2,
    adam_epsilon: float = 1e-08,
    max_grad_norm: float = 1.0,
    gradient_accumulation_steps: int = 1,
    gradient_checkpointing: bool = False,
    checkpointing_epochs: int = 5,
    checkpointing_steps: int = -1,

    mixed_precision_training: bool = True,
    enable_xformers_memory_efficient_attention: bool = True,

    global_seed: int = 42,
    is_debug: bool = False,
):
    check_min_version("0.10.0.dev0")

    # Initialize distributed training
    # local_rank = init_dist(launcher=launcher)
    # local_rank = 1
    # global_rank = dist.get_rank()
    # num_processes = dist.get_world_size()
    # is_main_process = global_rank == 0
    is_main_process = True

    # seed = global_seed + global_rank
    seed = global_seed  # single-process run, so no per-rank offset
    torch.manual_seed(seed)

    # Logging folder
    folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S")
    output_dir = os.path.join(output_dir, folder_name)
    if is_debug and os.path.exists(output_dir):
        os.system(f"rm -rf {output_dir}")

    *_, config = inspect.getargvalues(inspect.currentframe())

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Requires a configured wandb account
    if is_main_process and (not is_debug) and use_wandb:
        run = wandb.init(project="animatediff_pics_controlnetonly", name=folder_name, config=config)

    # Handle the output folder creation
    if is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/samples", exist_ok=True)
        os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))

    #-----------------------------------------------------------------------------------------------
    # Load scheduler, tokenizer and models.
    noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))

    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    unet2d = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")
    # controlnet = ControlNetModel.from_unet(unet2d)
    controlnet = ControlNetModel()
    # unet = UNet3DConditionModel()
    if not image_finetune:
        unet = UNet3DConditionModel.from_pretrained_2d(
            pretrained_model_path, subfolder="unet",
            unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs)
        )
    else:
        unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")

    # Load pretrained unet weights
    # if unet_checkpoint_path != "":
    #     zero_rank_print(f"from checkpoint: {unet_checkpoint_path}")
    #     unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
    #     if "global_step" in unet_checkpoint_path: zero_rank_print(f"global_step: {unet_checkpoint_path['global_step']}")
    #     state_dict = unet_checkpoint_path["state_dict"] if "state_dict" in unet_checkpoint_path else unet_checkpoint_path

    #     m, u = unet.load_state_dict(state_dict, strict=False)
    #     zero_rank_print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
    #     assert len(u) == 0
    # NOTE: the motion-module path is hard-coded here; unet_checkpoint_path from the config is not used
    motion_module_state_dict = torch.load("models/Motion_Module/mm_sd_v15_v2.ckpt", map_location="cpu")
    # # print(motion_module_state_dict)
    # # if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
    missing, unexpected = unet.load_state_dict(motion_module_state_dict, strict=False)
    # print(f"### missing keys: {len(missing)}; \n### unexpected keys: {len(unexpected)};")
    # print(f"### missing keys:\n{missing}\n### unexpected keys:\n{unexpected}\n")
    assert len(unexpected) == 0

    # Resume from a previously trained controlnet checkpoint (hard-coded local path; adjust to your setup)
    controlnet_state_dict = torch.load("/root/lh/AnimateDiffcontrolnet-main/outputs/1/checkpoints/controlnet_checkpoint-epoch-30.ckpt", map_location="cpu")
    missing, unexpected = controlnet.load_state_dict(controlnet_state_dict["state_dict"], strict=False)
    assert len(unexpected) == 0
    #-----------------------------------------------------------------------------------------------

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    # controlnet.requires_grad_(False)
    # for name, param in controlnet.named_parameters():
    #     print(name, ": ", param.requires_grad )
    # print("---------------------------------------")


    # Set a breakpoint here to inspect the unet structure
    # Set unet trainable parameters
    # print(unet)
    unet.requires_grad_(False)
    # unet.requires_grad_(True)
    # for name, param in unet.named_parameters():
    #     for trainable_module_name in trainable_modules:
    #         if trainable_module_name in name:
    #             param.requires_grad = True
    #             break
    # trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters()))

    # Only the controlnet is trained; the unet (including its motion modules) stays frozen.
    # The loop variable is param_name so it does not shadow the `name` argument.
    trainable_params = []
    for param_name, param in controlnet.named_parameters():
        if param.requires_grad:
            trainable_params.append(param)
            # print(param_name, ": ", param.requires_grad )
    # trainable_params.append(list(filter(lambda p: p.requires_grad, controlnet.parameters())))
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )
    #-----------------------------------------------------------------------------------------------

    if is_main_process:
        # zero_rank_print(f"trainable params number: {len(trainable_params)}")
        # zero_rank_print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")
        print(f"trainable params number: {len(trainable_params)}")
        print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")
    #-----------------------------------------------------------------------------------------------

    # Enable xformers
    if enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable gradient checkpointing
    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Move models to GPU
    # vae.to(local_rank)
    # text_encoder.to(local_rank)
    vae.to("cuda")
    controlnet.to("cuda")
    text_encoder.to("cuda")
    unet.to("cuda")
    #-----------------------------------------------------------------------------------------------

    # Get the training dataset
    train_dataset = WebVid10M(**train_data, is_image=image_finetune)
    # distributed_sampler = DistributedSampler(
    #     train_dataset,
    #     num_replicas=1,
    #     rank=0,
    #     shuffle=True,
    #     seed=global_seed,
    # )

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,  # the DistributedSampler that used to shuffle is disabled above
        # sampler=distributed_sampler,
        num_workers=0,
        pin_memory=True,
        drop_last=True,
    )
    #-----------------------------------------------------------------------------------------------

    # Get the training iteration
    if max_train_steps == -1:
        assert max_train_epoch != -1
        max_train_steps = max_train_epoch * len(train_dataloader)

    if checkpointing_steps == -1:
        assert checkpointing_epochs != -1
        checkpointing_steps = checkpointing_epochs * len(train_dataloader)

    if scale_lr:
        learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size)
        # The optimizer was built above with the unscaled rate, so push the new value into it
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate

    # Scheduler
    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )
    #-----------------------------------------------------------------------------------------------

    # Validation pipeline
    if not image_finetune:
        validation_pipeline = AnimationPipeline(
            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, controlnet=controlnet,
        ).to("cuda")
    else:
        validation_pipeline = StableDiffusionPipeline.from_pretrained(
            pretrained_model_path,
            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, safety_checker=None,
        )
    validation_pipeline.enable_vae_slicing()
    #-----------------------------------------------------------------------------------------------

    # DDP wrapper
    # unet = DDP(unet, device_ids=["cuda:0"], output_device="cuda:0")

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = train_batch_size * gradient_accumulation_steps

    if is_main_process:
        logging.info("***** Running training *****")
        logging.info(f"  Num examples = {len(train_dataset)}")
        logging.info(f"  Num Epochs = {num_train_epochs}")
        logging.info(f"  Instantaneous batch size per device = {train_batch_size}")
        logging.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logging.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
        logging.info(f"  Total optimization steps = {max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
    progress_bar.set_description("Steps")

    # Support mixed-precision training
    scaler = torch.cuda.amp.GradScaler() if mixed_precision_training else None
    # mp.set_start_method('spawn')
    for epoch in range(first_epoch, num_train_epochs):
        # train_dataloader.sampler.set_epoch(epoch)
        unet.train()
        controlnet.train()  # the controlnet holds the parameters being optimized
        # mp.set_start_method('spawn')
        for step, batch in enumerate(train_dataloader):
            if cfg_random_null_text:
                batch['text'] = [name if random.random() > cfg_random_null_text_ratio else "" for name in batch['text']]

            # Data batch sanity check
            if epoch == first_epoch and step == 0:
                pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
                if not image_finetune:
                    pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
                    for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
                        pixel_value = pixel_value[None, ...]
                        # save_videos_grid(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_seed}-{idx}'}.gif", rescale=True)
                else:
                    for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
                        pixel_value = pixel_value / 2. + 0.5
                        torchvision.utils.save_image(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_seed}-{idx}'}.png")

            ### >>>> Training >>>> ###

            # Convert videos to latent space
            orig_img = batch['image'].squeeze(1)
            orig_img = orig_img / 2. + 0.5
            pixel_values = batch["pixel_values"].to("cuda")
            video_length = pixel_values.shape[1]
            with torch.no_grad():
                if not image_finetune:
                    pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
                    latents = vae.encode(pixel_values).latent_dist
                    latents = latents.sample()
                    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
                else:
                    latents = vae.encode(pixel_values).latent_dist
                    latents = latents.sample()

                latents = latents * 0.18215


            #--------------------------------------
            image = prepare_image(
                image=orig_img,
                width=orig_img.shape[-1],
                height=orig_img.shape[-2],
                batch_size=orig_img.shape[0],
                num_images_per_prompt=1,
                device="cuda",
                dtype=controlnet.dtype,
            )

            #--------------------------------------


            # Sample noise that we'll add to the latents (this is where a signal from the original image could be injected)
            noise = torch.randn_like(latents)  # latents shape: [4, 4, 16, 32, 32]
            bsz = latents.shape[0]

            # Sample a random timestep for each video
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)  # shape [4]
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)  # shape [4, 4, 16, 32, 32]

            # Get the text embedding for conditioning
            with torch.no_grad():
                prompt_ids = tokenizer(
                    batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
                ).input_ids.to(latents.device)  # shape [4, 77]
                encoder_hidden_states = text_encoder(prompt_ids)[0]  # shape [4, 77, 768]

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                raise NotImplementedError
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")


            #------------------------------------------------
            # The controlnet sees only the first frame of the noisy clip plus the conditioning image
            down_block_res_samples, mid_block_res_sample = controlnet(
                sample=noisy_latents[:, :, 0, :, :],
                timestep=timesteps,
                encoder_hidden_states=encoder_hidden_states,  # [4, 77, 768]
                controlnet_cond=image,
                return_dict=False,
            )

            # down_block_additional_residuals
            # mid_block_additional_residual
            #------------------------------------------------

            # Predict the noise residual and compute loss
            # Mixed-precision training
            with torch.cuda.amp.autocast(enabled=mixed_precision_training):
                # noisy_latents shape [4, 4, 16, 32, 32]
                # encoder_hidden_states [4, 77, 768]
                model_pred = unet(sample=noisy_latents,
                                  timestep=timesteps,
                                  encoder_hidden_states=encoder_hidden_states,
                                  down_block_additional_residuals=down_block_res_samples,
                                  mid_block_additional_residual=mid_block_res_sample,
                                  ).sample
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            optimizer.zero_grad()

            # Backpropagate
            # Gradients are clipped on trainable_params (the controlnet); the frozen unet has no gradients to clip.
            if mixed_precision_training:
                scaler.scale(loss).backward()
                """ >>> gradient clipping >>> """
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
                """ <<< gradient clipping <<< """
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                """ >>> gradient clipping >>> """
                torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
                """ <<< gradient clipping <<< """
                optimizer.step()

            lr_scheduler.step()
            progress_bar.update(1)
            global_step += 1

| ### <<<< Training <<<< ### 494 | 495 | # Wandb logging 496 | if is_main_process and (not is_debug) and use_wandb: 497 | wandb.log({"train_loss": loss.item()}, step=global_step) 498 | 499 | # Save checkpoint 500 | if is_main_process and (global_step % checkpointing_steps == 0 or step == len(train_dataloader) - 1): 501 | save_path = os.path.join(output_dir, f"checkpoints") 502 | controlnet_state_dict = { 503 | "epoch": epoch, 504 | "global_step": global_step, 505 | "state_dict": controlnet.state_dict(), 506 | } 507 | if step == len(train_dataloader) - 1: 508 | torch.save(controlnet_state_dict, os.path.join(save_path, f"controlnet_checkpoint-epoch-{epoch+1}.ckpt")) 509 | else: 510 | torch.save(controlnet_state_dict, os.path.join(save_path, f"controlnet_checkpoint.ckpt")) 511 | logging.info(f"Saved state to {save_path} (global_step: {global_step})") 512 | 513 | # Periodically validation 514 | if is_main_process and (global_step % validation_steps == 0 or global_step in validation_steps_tuple): 515 | samples = [] 516 | 517 | generator = torch.Generator(device=latents.device) 518 | generator.manual_seed(global_seed) 519 | 520 | height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size 521 | width = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size 522 | 523 | # prompts = validation_data.prompts[:2] if global_step < 1000 and (not image_finetune) else validation_data.prompts 524 | prompts = batch['text'] 525 | 526 | init_images = batch['image'].squeeze(1) #[b, 1, c, h, w] 527 | controlnet_images = init_images / 2. 
+ 0.5 528 | init_images = init_images.permute(0,2,3,1) # [b, 1, h, w, c] 529 | init_images = np.array(init_images.cpu()) 530 | for idx, prompt in enumerate(prompts): 531 | if not image_finetune: 532 | controlnet_image = controlnet_images[idx, :, :, :] 533 | sample = validation_pipeline( 534 | prompt, 535 | generator = generator, 536 | video_length = train_data.sample_n_frames, 537 | height = height, 538 | width = width, 539 | controlnet_image=controlnet_image, 540 | **validation_data, 541 | ).videos 542 | init_image = init_images[idx, :, :, :] 543 | save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif") 544 | imageio.imsave(f"{output_dir}/samples/sample-{global_step}/{idx}.jpg", init_image) 545 | samples.append(sample) 546 | 547 | else: 548 | sample = validation_pipeline( 549 | prompt, 550 | generator = generator, 551 | height = height, 552 | width = width, 553 | num_inference_steps = validation_data.get("num_inference_steps", 25), 554 | guidance_scale = validation_data.get("guidance_scale", 8.), 555 | ).images[0] 556 | sample = torchvision.transforms.functional.to_tensor(sample) 557 | samples.append(sample) 558 | 559 | if not image_finetune: 560 | samples = torch.concat(samples) 561 | save_path = f"{output_dir}/samples/sample-{global_step}.gif" 562 | save_videos_grid(samples, save_path) 563 | 564 | else: 565 | samples = torch.stack(samples) 566 | save_path = f"{output_dir}/samples/sample-{global_step}.png" 567 | torchvision.utils.save_image(samples, save_path, nrow=4) 568 | 569 | logging.info(f"Saved samples to {save_path}") 570 | 571 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 572 | progress_bar.set_postfix(**logs) 573 | 574 | if global_step >= max_train_steps: 575 | break 576 | 577 | # dist.destroy_process_group() 578 | 579 | 580 | 581 | if __name__ == "__main__": 582 | parser = argparse.ArgumentParser() 583 | parser.add_argument("--config", type=str, default="./configs/training/training.yaml") 584 
| parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch") 585 | # parser.add_argument("--wandb", action="store_true") 586 | parser.add_argument("--wandb", type=bool, default=True) 587 | 588 | args = parser.parse_args() 589 | 590 | name = Path(args.config).stem 591 | config = OmegaConf.load(args.config) 592 | 593 | main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config) 594 | --------------------------------------------------------------------------------
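The objective that the training loop in `train.py` optimizes can be sketched without PyTorch. It has four parts: null-text dropout for classifier-free guidance (`cfg_random_null_text`), forward noising as done by `noise_scheduler.add_noise`, the `"epsilon"` target, and an MSE loss. This is a minimal pure-Python illustration, not the repository's code. It assumes a linear beta schedule with 1000 steps, `beta_start=0.00085` and `beta_end=0.012`; these values are not read from this repo's config. The helper names `linear_alphas_cumprod`, `add_noise`, `drop_text`, `mse` and `training_loss` are hypothetical. The noising formula x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps is the standard DDPM forward process.

```python
import math
import random
from typing import Callable, List, Sequence

# Hypothetical helpers illustrating the training objective; NOT the repository's code.
# Assumption: linear beta schedule, 1000 steps, beta_start=0.00085, beta_end=0.012.


def linear_alphas_cumprod(num_train_timesteps: int = 1000,
                          beta_start: float = 0.00085,
                          beta_end: float = 0.012) -> List[float]:
    """alpha_bar_t = prod_{s<=t} (1 - beta_s), with betas evenly spaced (assumed schedule)."""
    alphas_cumprod = []
    running = 1.0
    for t in range(num_train_timesteps):
        beta = beta_start + (beta_end - beta_start) * t / (num_train_timesteps - 1)
        running *= 1.0 - beta
        alphas_cumprod.append(running)
    return alphas_cumprod


def add_noise(x0: Sequence[float], eps: Sequence[float], t: int,
              alphas_cumprod: Sequence[float]) -> List[float]:
    """Forward diffusion: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    a = alphas_cumprod[t]
    return [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, eps)]


def drop_text(texts: Sequence[str], ratio: float, rng: random.Random) -> List[str]:
    """Replace each caption with "" with probability `ratio` (null-text dropout)."""
    return [text if rng.random() > ratio else "" for text in texts]


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error, like F.mse_loss(..., reduction="mean")."""
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(pred)


def training_loss(x0: Sequence[float],
                  predict_eps: Callable[[List[float], int], List[float]],
                  alphas_cumprod: Sequence[float],
                  rng: random.Random) -> float:
    """One "epsilon"-prediction step: noise x0 at a random t, then regress the noise."""
    t = rng.randrange(len(alphas_cumprod))       # random timestep
    eps = [rng.gauss(0.0, 1.0) for _ in x0]      # Gaussian noise, like torch.randn_like
    x_t = add_noise(x0, eps, t, alphas_cumprod)  # forward diffusion
    return mse(predict_eps(x_t, t), eps)         # target is the sampled noise
```

A predictor that recovers the true noise from `x_t` drives the loss to zero, which is a quick sanity check of the formula. In the real script the predictor is the UNet with ControlNet residuals, and the values are 5-D latent tensors rather than flat lists.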