├── .DS_Store
├── README.md
├── animatediff
│ ├── .DS_Store
│ ├── data
│ │ ├── __pycache__
│ │ │ └── dataset.cpython-310.pyc
│ │ ├── dataset.py
│ │ └── processor.py
│ ├── models
│ │ ├── __pycache__
│ │ │ ├── attention.cpython-310.pyc
│ │ │ ├── controlnet.cpython-310.pyc
│ │ │ ├── motion_module.cpython-310.pyc
│ │ │ ├── resnet.cpython-310.pyc
│ │ │ ├── unet.cpython-310.pyc
│ │ │ └── unet_blocks.cpython-310.pyc
│ │ ├── attention.py
│ │ ├── controlnet.py
│ │ ├── motion_module.py
│ │ ├── resnet.py
│ │ ├── unet.py
│ │ └── unet_blocks.py
│ ├── pipelines
│ │ ├── __pycache__
│ │ │ └── pipeline_animation.cpython-310.pyc
│ │ └── pipeline_animation.py
│ └── utils
│   ├── __pycache__
│   │ ├── convert_from_ckpt.cpython-310.pyc
│   │ ├── convert_lora_safetensor_to_diffusers.cpython-310.pyc
│   │ └── util.cpython-310.pyc
│   ├── convert_from_ckpt.py
│   ├── convert_lora_safetensor_to_diffusers.py
│   └── util.py
├── animatetest.py
├── configs
│ ├── inference
│ │ ├── inference-v1.yaml
│ │ └── inference-v2.yaml
│ ├── prompts
│ │ └── v2
│ │   └── 5-RealisticVision.yaml
│ └── training
│   ├── image_finetune.yaml
│   └── training.yaml
├── download_data.py
├── imgs
│ ├── .DS_Store
│ ├── 0.gif
│ ├── 1.gif
│ ├── 2.gif
│ ├── 3.gif
│ ├── 4.gif
│ └── 5.gif
├── init_images
│ ├── .DS_Store
│ ├── 0.jpg
│ ├── 1.jpg
│ ├── 2.jpg
│ ├── 3.jpg
│ ├── 4.jpg
│ └── 5.jpg
├── newanimate.yaml
├── requirements.txt
├── scripts
│ ├── __pycache__
│ │ └── animate.cpython-310.pyc
│ └── animate.py
└── train.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Animatediff with controlnet
2 | ### Description: Add a ControlNet to AnimateDiff to animate a given image.
3 |
4 |
5 |
6 | (Example results: the initial images in ./init_images paired with the animated GIFs in ./imgs.)
26 |
27 |
28 |
29 | [Animatediff](https://github.com/guoyww/AnimateDiff) is a recent animation project based on SD, which produces excellent results. This repository aims to enhance Animatediff in two ways:
30 |
31 | 1. Animating a specific image: starting from a given image and using ControlNet, it preserves the appearance of the image while animating it.
32 |
33 | 2. Upgrading the diffusers version of the previous code: the previous code used diffusers 0.11.1, and the upgraded version now uses diffusers 0.21.4. This makes it easier to extend AnimateDiff with more recent diffusers features, such as ControlNet.
34 |
35 | #### TODO:
36 |
37 | - [x] Release the training and inference code
38 | - [x] Release the controlnet [checkpoint](https://huggingface.co/crishhh/animatediff_controlnet)
39 | - [ ] Reduce the GPU memory usage of controlnet in the code
40 | - [ ] Others
41 |
42 | #### How to start (inference)
43 |
44 | 1. Prepare the environment
45 |
46 | ```bash
47 | conda env create -f newanimate.yaml
48 | # Or
49 | conda create --name newanimate python=3.10
50 | pip install -r requirements.txt
51 | ```
52 |
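  | A quick sanity check (not part of the repo) that the environment matches what the upgraded code targets:
  |
  | ```python
  | # Sanity check: this repository targets diffusers 0.21.4 (see the description above).
  | import torch
  | import diffusers
  |
  | print("diffusers:", diffusers.__version__)        # expected: 0.21.4
  | print("CUDA available:", torch.cuda.is_available())
  | ```
  |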
53 | 2. Download the base models as described in [AnimateDiff](https://github.com/guoyww/AnimateDiff) and put them in `./models`. Download the ControlNet [checkpoint](https://huggingface.co/crishhh/animatediff_controlnet) and put it in `./checkpoints`.
54 |
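  | If you prefer scripting the checkpoint download, here is a minimal sketch using `huggingface_hub` (assuming it is installed; the exact files inside the repo are whatever the Hugging Face page lists):
  |
  | ```python
  | # Minimal sketch: pull the whole ControlNet checkpoint repo into ./checkpoints.
  | # Assumes `pip install huggingface_hub`; the base SD and motion-module weights
  | # still need to be downloaded as described in the AnimateDiff instructions.
  | from huggingface_hub import snapshot_download
  |
  | snapshot_download(repo_id="crishhh/animatediff_controlnet", local_dir="./checkpoints")
  | ```
  |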
55 | 3. Prepare the prompts and the initial image
56 |
57 |    Note that the prompts are important for the animation. Here I use [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4) to write them; the instruction given to MiniGPT-4 is "Please output the perfect description prompt of this picture into the StableDiffusion model, and separate the description into multiple keywords with commas".
58 |
59 | 4. Modify the YAML file (location: ./configs/prompts/v2/5-RealisticVision.yaml)
60 |
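  | To see which keys the prompt config exposes before editing it, a small sketch using OmegaConf (assumed to be available, as it is the usual loader for AnimateDiff-style YAML configs):
  |
  | ```python
  | # Minimal sketch: load and print the prompt config so you know which fields to modify.
  | from omegaconf import OmegaConf
  |
  | config = OmegaConf.load("configs/prompts/v2/5-RealisticVision.yaml")
  | print(OmegaConf.to_yaml(config))
  | ```
  |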
61 | 5. Run the demo
62 |
63 | ```bash
64 | python animatetest.py
65 | ```
66 |
67 | #### How to train
68 |
69 | 1. Download the datasets (WebVid-10M)
70 |
71 | ```bash
72 | python download_data.py
73 | ```
74 |
75 | 2. Run the training
76 |
77 | ```bash
78 | python train.py
79 | ```
80 |
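  | Before launching training, it can help to confirm the dataset wiring. The `WebVid10M` class in `animatediff/data/dataset.py` expects a CSV with `videoid`, `name`, and `page_dir` columns, a folder of `<videoid>.mp4` files, and an optical-flow folder (see `animatediff/data/processor.py`). A minimal sketch with placeholder paths:
  |
  | ```python
  | # Minimal sketch (all paths are placeholders): instantiate the dataset and inspect one batch.
  | import torch
  | from animatediff.data.dataset import WebVid10M
  |
  | dataset = WebVid10M(
  |     csv_path="/path/to/webvid/results.csv",        # placeholder
  |     video_folder="/path/to/webvid/videos",         # placeholder
  |     opticalflow_folder="/path/to/optical_flow",    # placeholder
  |     sample_size=256, sample_stride=4, sample_n_frames=16,
  | )
  | loader = torch.utils.data.DataLoader(dataset, batch_size=2, num_workers=4)
  | batch = next(iter(loader))
  | print(batch["pixel_values"].shape)   # expected: [2, 16, 3, 256, 256]
  | print(batch["image"].shape, batch["text"][0])
  | ```
  |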
81 | #### Limitations
82 |
83 | 1. The current ControlNet version has been trained on a subset of WebVid-10M, comprising approximately 5,000 video-caption pairs. As a result, its performance is not very satisfactory, and work is underway to train ControlNet on larger datasets.
84 | 2. Some images are challenging to animate effectively, even with matching prompts. The same difficulties appear when animating them with AnimateDiff alone, without ControlNet.
85 | 3. It is preferable for the image and its corresponding prompts to have a stronger alignment for better results.
86 |
87 | #### Future
88 |
89 | 1. Currently, the ControlNet in use operates at the 2D level; our plan is to expand it to 3D and incorporate the motion module into the ControlNet.
90 | 2. We aim to add a trajectory encoder to the ControlNet branch to control the motion module. Although this might conflict with the existing motion module, we still want to give it a try.
91 |
92 | #### Some Failed Attempts (Possibly Due to Missteps):
93 |
94 | 1. Injecting the VAE-encoded image into the initial latent space doesn't seem to work; it generates videos with similar styles but inconsistent appearances.
95 | 2. Performing DDIM inversion on the image to obtain noise and then denoising it, although inspired by common image-editing methods, doesn't yield effective results in our observations.
96 |
97 | The code in this repository is intended solely as an experimental demo. If you have any feedback or questions, please feel free to open an issue or contact me via email at crystallee0418@gmail.com.
98 |
99 | The code in this repository is derived from Animatediff and Diffusers.
100 |
--------------------------------------------------------------------------------
/animatediff/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/animatediff/.DS_Store
--------------------------------------------------------------------------------
/animatediff/data/__pycache__/dataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/animatediff/data/__pycache__/dataset.cpython-310.pyc
--------------------------------------------------------------------------------
/animatediff/data/dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, io, csv, math, random
3 | import numpy as np
4 | from einops import rearrange
5 | from decord import VideoReader
6 |
7 |
8 | import torch
9 | import torchvision.transforms as transforms
10 | from torch.utils.data.dataset import Dataset
11 | from animatediff.utils.util import zero_rank_print
12 |
13 |
14 |
15 |
16 | class WebVid10M(Dataset):
17 | def __init__(
18 | self,
19 | csv_path, video_folder, opticalflow_folder,
20 | sample_size=256, sample_stride=4, sample_n_frames=16,
21 | is_image=False,
22 | ):
23 | zero_rank_print(f"loading annotations from {csv_path} ...")
24 | with open(csv_path, 'r') as csvfile:
25 | self.dataset = list(csv.DictReader(csvfile))
26 | self.length = len(self.dataset)
27 | zero_rank_print(f"data scale: {self.length}")
28 |
29 | self.video_folder = video_folder
30 | self.opticalflow_folder = opticalflow_folder
31 | self.sample_stride = sample_stride
32 | self.sample_n_frames = sample_n_frames
33 | self.is_image = is_image
34 |
35 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
36 | self.pixel_transforms = transforms.Compose([
37 | transforms.RandomHorizontalFlip(),
38 | transforms.Resize(sample_size[0]),
39 | transforms.CenterCrop(sample_size),
40 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
41 | ])
42 |
43 |
44 |
45 | def get_batch(self, idx):
46 | video_dict = self.dataset[idx]
47 | videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
48 |
49 | video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
50 | video_reader = VideoReader(video_dir)
51 | video_length = len(video_reader)
52 |
53 | if not self.is_image:
54 | clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
55 | start_idx = random.randint(0, video_length - clip_length)
56 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
57 | else:
58 | batch_index = [random.randint(0, video_length - 1)]
59 |
60 | pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
61 | pixel_values = pixel_values / 255.
62 | del video_reader
63 |
64 | if self.is_image:
65 | pixel_values = pixel_values[0]
66 |
67 | return pixel_values, pixel_values[0], name
68 |
69 | def __len__(self):
70 | return self.length
71 |
72 | def __getitem__(self, idx):
73 | while True:
74 | try:
75 | pixel_values, image, name = self.get_batch(idx)
76 | break
77 |
78 | except Exception as e:
79 | idx = random.randint(0, self.length-1)
80 |
81 | pixel_values = self.pixel_transforms(pixel_values) # shape [16,3,256,256]
82 | sample = dict(pixel_values = pixel_values, image = pixel_values[0, :, :, :], text = name)
83 | return sample
84 |
85 |
86 |
87 | if __name__ == "__main__":
88 | from animatediff.utils.util import save_videos_grid
89 |
90 | dataset = WebVid10M(
91 | csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_train.csv",
92 | video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
93 | sample_size=256,
94 | sample_stride=4, sample_n_frames=16,
95 | is_image=True,
96 | )
97 | import pdb
98 | pdb.set_trace()
99 |
100 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16,)
101 | for idx, batch in enumerate(dataloader):
102 | print(batch["pixel_values"].shape, len(batch["text"]))
103 | # for i in range(batch["pixel_values"].shape[0]):
104 | # save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)
105 |
--------------------------------------------------------------------------------
/animatediff/data/processor.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from tqdm import tqdm
4 | sys.path.append('./core')
5 | from raft import RAFT
6 | from utils import flow_viz
7 | from utils.utils import InputPadder
8 |
9 | import argparse
10 | import os, io, csv, math, random
11 | import numpy as np
12 | from einops import rearrange
13 | from decord import VideoReader
14 | import torch
15 | import torchvision.transforms as transforms
16 | from decord._ffi.base import DECORDError
17 |
18 |
19 |
20 | def extract_optical_flow(csv_path, output_dir, video_folder, sample_stride, sample_n_frames, sample_size):
21 |
22 |     parser = argparse.ArgumentParser()
   |     # RAFT reads these flags from its args namespace (as in the official RAFT demo),
   |     # so declare them before constructing the model.
   |     parser.add_argument('--small', action='store_true')
   |     parser.add_argument('--mixed_precision', action='store_true')
   |     parser.add_argument('--alternate_corr', action='store_true')
23 |     args = parser.parse_args()
24 | model = torch.nn.DataParallel(RAFT(args))
25 | state_dict = torch.load("/root/lh/RAFT-master/models/raft-things.pth")
26 | model.load_state_dict(state_dict)
27 |
28 | model = model.module
29 | model.to("cuda")
30 | model.eval()
31 |
32 | with open(csv_path, 'r') as csvfile:
33 | dataset = list(csv.DictReader(csvfile))
34 | length = len(dataset)
35 | video_folder = video_folder
36 | sample_stride = sample_stride
37 | sample_n_frames = sample_n_frames
38 |
39 | pixel_transforms = transforms.Compose([
40 | transforms.RandomHorizontalFlip(),
41 | transforms.Resize((sample_size, sample_size)),
42 | transforms.CenterCrop(sample_size),
43 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
44 | ])
45 |
46 | with tqdm(total=length) as pbar:
47 | pbar.set_description("Steps")
48 | for idx in range(length):
49 | video_dict = dataset[idx]
50 | videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
51 | output_path = output_dir+f"/{videoid}.npy"
52 | video_dir = os.path.join(video_folder, f"{videoid}.mp4")
53 | if os.path.exists(output_path):
54 | print(f"{output_path} already exists, continue")
55 | pbar.update(1)
56 | continue
57 | try:
58 | video_reader = VideoReader(video_dir)
59 | except Exception as e:
60 | print(f"Error reading video at {video_dir}, error: {e}")
61 | pbar.update(1)
62 | continue
63 | video_length = len(video_reader)
64 |
65 | # if not os.path.exists(output_path):
66 | # os.mkdir(output_path)
67 |
68 | clip_length = min(video_length, (sample_n_frames - 1) * sample_stride + 1)
69 | start_idx = random.randint(0, video_length - clip_length)
70 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, sample_n_frames, dtype=int)
71 |
72 | pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
73 | pixel_values = pixel_values / 255.
74 | del video_reader
75 |
76 |
77 |
78 | pixel_values = pixel_transforms(pixel_values) # shape [16,3,256,256]
79 | #----------------------------------------------
80 | flow_ls = []
81 | with torch.no_grad():
82 | padder = InputPadder(pixel_values[0].shape)
83 | for j in range(pixel_values.shape[0]-1):
84 |
85 | image1, image2 = padder.pad(pixel_values[j], pixel_values[j+1])
86 | image1 = image1.unsqueeze(0).cuda()
87 | image2 = image2.unsqueeze(0).cuda()
88 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
89 | # extra_channel = torch.ones((1, 1, flow_up.shape[2],flow_up.shape[3]))
90 | # flow = torch.concatenate([flow_up.cpu(),extra_channel], dim=1).squeeze(0)
91 | flow = flow_up.cpu().squeeze(0) # shape [2, 256, 256]
92 |
93 | flow_ls.append(flow)
94 | flow_ls = np.array(flow_ls) # shape [15, 2, 256, 256]
95 | np.save(output_path, flow_ls)
96 | pbar.update(1)
97 |
98 |
99 | extract_optical_flow("/root/lh/AnimateDiff-main/results_2M_val.csv",\
100 | "/root/lh/AnimateDiff-main/dataset_opticalflow" ,\
101 | "/root/lh/AnimateDiff-main/datasets", 4, 16, 256)
--------------------------------------------------------------------------------
/animatediff/models/__pycache__/attention.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/animatediff/models/__pycache__/attention.cpython-310.pyc
--------------------------------------------------------------------------------
/animatediff/models/__pycache__/controlnet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/animatediff/models/__pycache__/controlnet.cpython-310.pyc
--------------------------------------------------------------------------------
/animatediff/models/__pycache__/motion_module.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/animatediff/models/__pycache__/motion_module.cpython-310.pyc
--------------------------------------------------------------------------------
/animatediff/models/__pycache__/resnet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/animatediff/models/__pycache__/resnet.cpython-310.pyc
--------------------------------------------------------------------------------
/animatediff/models/__pycache__/unet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/animatediff/models/__pycache__/unet.cpython-310.pyc
--------------------------------------------------------------------------------
/animatediff/models/__pycache__/unet_blocks.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/animatediff/models/__pycache__/unet_blocks.cpython-310.pyc
--------------------------------------------------------------------------------
/animatediff/models/attention.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2 |
3 | from dataclasses import dataclass
4 | from typing import Optional
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import nn
9 |
10 | from diffusers.configuration_utils import ConfigMixin, register_to_config
11 | # from diffusers.modeling_utils import ModelMixin
12 | from diffusers import ModelMixin
13 |
14 | from diffusers.utils import BaseOutput
15 | from diffusers.utils.import_utils import is_xformers_available
16 | # from diffusers.models.attention import CrossAttention
17 | from diffusers.models.attention_processor import Attention
18 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
19 |
20 |
21 | from einops import rearrange, repeat
22 | import pdb
23 |
24 | @dataclass
25 | class Transformer3DModelOutput(BaseOutput):
26 | sample: torch.FloatTensor
27 |
28 |
29 | if is_xformers_available():
30 | import xformers
31 | import xformers.ops
32 | else:
33 | xformers = None
34 |
35 |
36 | class Transformer3DModel(ModelMixin, ConfigMixin):
37 | @register_to_config
38 | def __init__(
39 | self,
40 | num_attention_heads: int = 16,
41 | attention_head_dim: int = 88,
42 | in_channels: Optional[int] = None,
43 | num_layers: int = 1,
44 | dropout: float = 0.0,
45 | norm_num_groups: int = 32,
46 | cross_attention_dim: Optional[int] = None,
47 | attention_bias: bool = False,
48 | activation_fn: str = "geglu",
49 | num_embeds_ada_norm: Optional[int] = None,
50 | use_linear_projection: bool = False,
51 | only_cross_attention: bool = False,
52 | upcast_attention: bool = False,
53 |
54 | unet_use_cross_frame_attention=None,
55 | unet_use_temporal_attention=None,
56 | ):
57 | super().__init__()
58 | self.use_linear_projection = use_linear_projection
59 | self.num_attention_heads = num_attention_heads
60 | self.attention_head_dim = attention_head_dim
61 | inner_dim = num_attention_heads * attention_head_dim
62 |
63 | # Define input layers
64 | self.in_channels = in_channels
65 |
66 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
67 | if use_linear_projection:
68 | self.proj_in = nn.Linear(in_channels, inner_dim)
69 | else:
70 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
71 |
72 | # Define transformers blocks
73 | self.transformer_blocks = nn.ModuleList(
74 | [
75 | BasicTransformerBlock(
76 | inner_dim,
77 | num_attention_heads,
78 | attention_head_dim,
79 | dropout=dropout,
80 | cross_attention_dim=cross_attention_dim,
81 | activation_fn=activation_fn,
82 | num_embeds_ada_norm=num_embeds_ada_norm,
83 | attention_bias=attention_bias,
84 | only_cross_attention=only_cross_attention,
85 | upcast_attention=upcast_attention,
86 |
87 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
88 | unet_use_temporal_attention=unet_use_temporal_attention,
89 | )
90 | for d in range(num_layers)
91 | ]
92 | )
93 |
94 | # 4. Define output layers
95 | if use_linear_projection:
96 |             self.proj_out = nn.Linear(inner_dim, in_channels)  # maps inner_dim back to in_channels, mirroring the conv branch
97 | else:
98 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
99 |
100 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
101 | # Input
102 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
103 | video_length = hidden_states.shape[2]
104 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
105 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
106 |
107 | batch, channel, height, weight = hidden_states.shape
108 | residual = hidden_states
109 |
110 | hidden_states = self.norm(hidden_states)
111 | if not self.use_linear_projection:
112 | hidden_states = self.proj_in(hidden_states)
113 | inner_dim = hidden_states.shape[1]
114 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
115 | else:
116 | inner_dim = hidden_states.shape[1]
117 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
118 | hidden_states = self.proj_in(hidden_states)
119 |
120 | # Blocks
121 | for block in self.transformer_blocks:
122 | hidden_states = block(
123 | hidden_states,
124 | encoder_hidden_states=encoder_hidden_states,
125 | timestep=timestep,
126 | # video_length=video_length
127 | )
128 |
129 | # Output
130 | if not self.use_linear_projection:
131 | hidden_states = (
132 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
133 | )
134 | hidden_states = self.proj_out(hidden_states)
135 | else:
136 | hidden_states = self.proj_out(hidden_states)
137 | hidden_states = (
138 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
139 | )
140 |
141 | output = hidden_states + residual
142 |
143 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
144 | if not return_dict:
145 | return (output,)
146 |
147 | return Transformer3DModelOutput(sample=output)
148 |
149 |
150 | class AdaLayerNorm(nn.Module):
151 | """
152 | Norm layer modified to incorporate timestep embeddings.
153 | """
154 |
155 | def __init__(self, embedding_dim, num_embeddings):
156 | super().__init__()
157 | self.emb = nn.Embedding(num_embeddings, embedding_dim)
158 | self.silu = nn.SiLU()
159 | self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
160 | self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
161 |
162 | def forward(self, x, timestep):
163 | emb = self.linear(self.silu(self.emb(timestep)))
164 | scale, shift = torch.chunk(emb, 2)
165 | x = self.norm(x) * (1 + scale) + shift
166 | return x
167 |
168 | class GEGLU(nn.Module):
169 |
170 | def __init__(self, dim_in: int, dim_out: int):
171 | super().__init__()
172 | self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
173 |
174 | def gelu(self, gate):
175 | if gate.device.type != "mps":
176 | return F.gelu(gate)
177 | # mps: gelu is not implemented for float16
178 | return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
179 |
180 | def forward(self, hidden_states, scale: float = 1.0):
181 | hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1)
182 | return hidden_states * self.gelu(gate)
183 |
184 |
185 | class GELU(nn.Module):
186 |
187 | def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
188 | super().__init__()
189 | self.proj = nn.Linear(dim_in, dim_out)
190 | self.approximate = approximate
191 |
192 | def gelu(self, gate):
193 | if gate.device.type != "mps":
194 | return F.gelu(gate, approximate=self.approximate)
195 | # mps: gelu is not implemented for float16
196 | return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
197 |
198 | def forward(self, hidden_states):
199 | hidden_states = self.proj(hidden_states)
200 | hidden_states = self.gelu(hidden_states)
201 | return hidden_states
202 |
203 |
204 | class FeedForward(nn.Module):
205 |
206 | def __init__(
207 | self,
208 | dim: int,
209 | dim_out: Optional[int] = None,
210 | mult: int = 4,
211 | dropout: float = 0.0,
212 | activation_fn: str = "geglu",
213 | final_dropout: bool = False,
214 | ):
215 | super().__init__()
216 | inner_dim = int(dim * mult)
217 | dim_out = dim_out if dim_out is not None else dim
218 |
219 | if activation_fn == "gelu":
220 | act_fn = GELU(dim, inner_dim)
221 | if activation_fn == "gelu-approximate":
222 | act_fn = GELU(dim, inner_dim, approximate="tanh")
223 | elif activation_fn == "geglu":
224 | act_fn = GEGLU(dim, inner_dim)
225 | elif activation_fn == "geglu-approximate":
226 | act_fn = ApproximateGELU(dim, inner_dim)
227 |
228 | self.net = nn.ModuleList([])
229 | # project in
230 | self.net.append(act_fn)
231 | # project dropout
232 | self.net.append(nn.Dropout(dropout))
233 | # project out
234 | self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
235 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
236 | if final_dropout:
237 | self.net.append(nn.Dropout(dropout))
238 |
239 | def forward(self, hidden_states, scale: float = 1.0):
240 | for module in self.net:
241 | if isinstance(module, (LoRACompatibleLinear, GEGLU)):
242 | hidden_states = module(hidden_states, scale)
243 | else:
244 | hidden_states = module(hidden_states)
245 | return hidden_states
246 |
247 |
248 | class BasicTransformerBlock(nn.Module):
249 |
250 | def __init__(
251 | self,
252 | dim: int,
253 | num_attention_heads: int,
254 | attention_head_dim: int,
255 | dropout=0.0,
256 | cross_attention_dim: Optional[int] = None,
257 | activation_fn: str = "geglu",
258 | num_embeds_ada_norm: Optional[int] = None,
259 | attention_bias: bool = False,
260 | only_cross_attention: bool = False,
261 | double_self_attention: bool = False,
262 | upcast_attention: bool = False,
263 | norm_elementwise_affine: bool = True,
264 | norm_type: str = "layer_norm",
265 | final_dropout: bool = False,
266 | attention_type: str = "default",
267 | unet_use_cross_frame_attention: bool = False,
268 | unet_use_temporal_attention: bool = False
269 | ):
270 | super().__init__()
271 | self.only_cross_attention = only_cross_attention
272 |
273 | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
274 | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
275 |
276 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
277 | raise ValueError(
278 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
279 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
280 | )
281 |
282 | # Define 3 blocks. Each block has its own normalization layer.
283 | # 1. Self-Attn
284 | if self.use_ada_layer_norm:
285 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
286 | elif self.use_ada_layer_norm_zero:
287 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
288 | else:
289 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
290 | self.attn1 = Attention(
291 | query_dim=dim,
292 | heads=num_attention_heads,
293 | dim_head=attention_head_dim,
294 | dropout=dropout,
295 | bias=attention_bias,
296 | cross_attention_dim=cross_attention_dim if only_cross_attention else None,
297 | upcast_attention=upcast_attention,
298 | )
299 |
300 | # 2. Cross-Attn
301 | if cross_attention_dim is not None or double_self_attention:
302 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
303 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
304 | # the second cross attention block.
305 | self.norm2 = (
306 | AdaLayerNorm(dim, num_embeds_ada_norm)
307 | if self.use_ada_layer_norm
308 | else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
309 | )
310 | self.attn2 = Attention(
311 | query_dim=dim,
312 | cross_attention_dim=cross_attention_dim if not double_self_attention else None,
313 | heads=num_attention_heads,
314 | dim_head=attention_head_dim,
315 | dropout=dropout,
316 | bias=attention_bias,
317 | upcast_attention=upcast_attention,
318 | ) # is self-attn if encoder_hidden_states is none
319 | else:
320 | self.norm2 = None
321 | self.attn2 = None
322 |
323 | # 3. Feed-forward
324 | self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
325 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
326 | # let chunk size default to None
327 | self._chunk_size = None
328 | self._chunk_dim = 0
329 |
330 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
331 | # Sets chunk feed-forward
332 | self._chunk_size = chunk_size
333 | self._chunk_dim = dim
334 |
335 | def forward(
336 | self,
337 | hidden_states: torch.FloatTensor,
338 | attention_mask: Optional[torch.FloatTensor] = None,
339 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
340 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
341 | timestep: Optional[torch.LongTensor] = None,
342 | class_labels: Optional[torch.LongTensor] = None,
343 | ):
344 | # Notice that normalization is always applied before the real computation in the following blocks.
345 | # 0. Self-Attention
346 | cross_attention_kwargs = None
347 | if self.use_ada_layer_norm:
348 | norm_hidden_states = self.norm1(hidden_states, timestep)
349 | elif self.use_ada_layer_norm_zero:
350 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
351 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
352 | )
353 | else:
354 | norm_hidden_states = self.norm1(hidden_states)
355 |
356 | # 1. Retrieve lora scale.
357 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
358 |
359 | # 2. Prepare GLIGEN inputs
360 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
361 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
362 |
363 | attn_output = self.attn1(
364 | norm_hidden_states,
365 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
366 | attention_mask=attention_mask,
367 | **cross_attention_kwargs,
368 | )
369 | if self.use_ada_layer_norm_zero:
370 | attn_output = gate_msa.unsqueeze(1) * attn_output
371 | hidden_states = attn_output + hidden_states
372 |
373 | # 2.5 GLIGEN Control
374 | if gligen_kwargs is not None:
375 | hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
376 | # 2.5 ends
377 |
378 | # 3. Cross-Attention
379 | if self.attn2 is not None:
380 | norm_hidden_states = (
381 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
382 | )
383 |
384 | attn_output = self.attn2(
385 | norm_hidden_states,
386 | encoder_hidden_states=encoder_hidden_states,
387 | attention_mask=encoder_attention_mask,
388 | **cross_attention_kwargs,
389 | )
390 | hidden_states = attn_output + hidden_states
391 |
392 | # 4. Feed-forward
393 | norm_hidden_states = self.norm3(hidden_states)
394 |
395 | if self.use_ada_layer_norm_zero:
396 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
397 |
398 | if self._chunk_size is not None:
399 | # "feed_forward_chunk_size" can be used to save memory
400 | if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
401 | raise ValueError(
402 | f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
403 | )
404 |
405 | num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
406 | ff_output = torch.cat(
407 | [
408 | self.ff(hid_slice, scale=lora_scale)
409 | for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
410 | ],
411 | dim=self._chunk_dim,
412 | )
413 | else:
414 | ff_output = self.ff(norm_hidden_states, scale=lora_scale)
415 |
416 | if self.use_ada_layer_norm_zero:
417 | ff_output = gate_mlp.unsqueeze(1) * ff_output
418 |
419 | hidden_states = ff_output + hidden_states
420 |
421 | return hidden_states
422 |
--------------------------------------------------------------------------------
/animatediff/models/motion_module.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional, Tuple, Union
3 |
4 | import torch
5 | import numpy as np
6 | import torch.nn.functional as F
7 | from torch import nn
8 | import torchvision
9 |
10 | from diffusers.configuration_utils import ConfigMixin, register_to_config
11 | from diffusers import ModelMixin
12 | # from diffusers.modeling_utils import ModelMixin
13 | from diffusers.utils import BaseOutput
14 | from diffusers.utils.import_utils import is_xformers_available
15 | from diffusers.models.attention_processor import Attention
16 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
17 |
18 |
19 | from einops import rearrange, repeat
20 | import math
21 |
22 |
23 | def zero_module(module):
24 | # Zero out the parameters of a module and return it.
25 | for p in module.parameters():
26 | p.detach().zero_()
27 | return module
28 |
29 |
30 | @dataclass
31 | class TemporalTransformer3DModelOutput(BaseOutput):
32 | sample: torch.FloatTensor
33 |
34 |
35 | if is_xformers_available():
36 | import xformers
37 | import xformers.ops
38 | else:
39 | xformers = None
40 |
41 |
42 | def get_motion_module(
43 | in_channels,
44 | motion_module_kwargs: dict,
45 | motion_module_type: str = "Vanilla",
46 |
47 | ):
48 | return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
49 |
50 |
51 | class VanillaTemporalModule(nn.Module):
52 | def __init__(
53 | self,
54 | in_channels,
55 | num_attention_heads = 8,
56 | num_transformer_block = 2,
57 | attention_block_types =( "Temporal_Self", "Temporal_Self" ),
58 | cross_frame_attention_mode = None,
59 | temporal_position_encoding = False,
60 | temporal_position_encoding_max_len = 24,
61 | temporal_attention_dim_div = 1,
62 | zero_initialize = True,
63 | ):
64 | super().__init__()
65 |
66 | self.temporal_transformer = TemporalTransformer3DModel(
67 | in_channels=in_channels,
68 | num_attention_heads=num_attention_heads,
69 | attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
70 | num_layers=num_transformer_block,
71 | attention_block_types=attention_block_types,
72 | cross_frame_attention_mode=cross_frame_attention_mode,
73 | temporal_position_encoding=temporal_position_encoding,
74 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
75 | )
76 |
77 | if zero_initialize:
78 | self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
79 |
80 | def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
81 | hidden_states = input_tensor
82 | hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
83 |
84 | output = hidden_states
85 | return output
86 |
87 |
88 | class TemporalTransformer3DModel(nn.Module):
89 | def __init__(
90 | self,
91 | in_channels,
92 | num_attention_heads,
93 | attention_head_dim,
94 |
95 | num_layers,
96 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
97 | dropout = 0.0,
98 | norm_num_groups = 32,
99 | cross_attention_dim = 768,
100 | activation_fn = "geglu",
101 | attention_bias = False,
102 | upcast_attention = False,
103 |
104 | cross_frame_attention_mode = None,
105 | temporal_position_encoding = False,
106 | temporal_position_encoding_max_len = 24,
107 | ):
108 | super().__init__()
109 |
110 | inner_dim = num_attention_heads * attention_head_dim
111 |
112 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
113 | self.proj_in = nn.Linear(in_channels, inner_dim)
114 |
115 | self.transformer_blocks = nn.ModuleList(
116 | [
117 | TemporalTransformerBlock(
118 | dim=inner_dim,
119 | num_attention_heads=num_attention_heads,
120 | attention_head_dim=attention_head_dim,
121 | attention_block_types=attention_block_types,
122 | dropout=dropout,
123 | norm_num_groups=norm_num_groups,
124 | cross_attention_dim=cross_attention_dim,
125 | activation_fn=activation_fn,
126 | attention_bias=attention_bias,
127 | upcast_attention=upcast_attention,
128 | cross_frame_attention_mode=cross_frame_attention_mode,
129 | temporal_position_encoding=temporal_position_encoding,
130 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
131 | )
132 | for d in range(num_layers)
133 | ]
134 | )
135 | self.proj_out = nn.Linear(inner_dim, in_channels)
136 |
137 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
138 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
139 | video_length = hidden_states.shape[2]
140 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
141 |
142 | batch, channel, height, weight = hidden_states.shape
143 | residual = hidden_states
144 |
145 | hidden_states = self.norm(hidden_states)
146 | inner_dim = hidden_states.shape[1]
147 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
148 | hidden_states = self.proj_in(hidden_states)
149 |
150 | # Transformer Blocks
151 | for block in self.transformer_blocks:
152 | hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
153 |
154 | # output
155 | hidden_states = self.proj_out(hidden_states)
156 | hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
157 |
158 | output = hidden_states + residual
159 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
160 |
161 | return output
162 |
163 |
164 | class FeedForward(nn.Module):
165 |
166 | def __init__(
167 | self,
168 | dim: int,
169 | dim_out: Optional[int] = None,
170 | mult: int = 4,
171 | dropout: float = 0.0,
172 | activation_fn: str = "geglu",
173 | final_dropout: bool = False,
174 | ):
175 | super().__init__()
176 | inner_dim = int(dim * mult)
177 | dim_out = dim_out if dim_out is not None else dim
178 |
179 | if activation_fn == "gelu":
180 | act_fn = GELU(dim, inner_dim)
181 | if activation_fn == "gelu-approximate":
182 | act_fn = GELU(dim, inner_dim, approximate="tanh")
183 | elif activation_fn == "geglu":
184 | act_fn = GEGLU(dim, inner_dim)
185 | elif activation_fn == "geglu-approximate":
186 | act_fn = ApproximateGELU(dim, inner_dim)
187 |
188 | self.net = nn.ModuleList([])
189 | # project in
190 | self.net.append(act_fn)
191 | # project dropout
192 | self.net.append(nn.Dropout(dropout))
193 | # project out
194 | self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
195 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
196 | if final_dropout:
197 | self.net.append(nn.Dropout(dropout))
198 |
199 | def forward(self, hidden_states, scale: float = 1.0):
200 | for module in self.net:
201 | if isinstance(module, (LoRACompatibleLinear, GEGLU)):
202 | hidden_states = module(hidden_states, scale)
203 | else:
204 | hidden_states = module(hidden_states)
205 | return hidden_states
206 |
207 |
208 | class TemporalTransformerBlock(nn.Module):
209 | def __init__(
210 | self,
211 | dim,
212 | num_attention_heads,
213 | attention_head_dim,
214 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
215 | dropout = 0.0,
216 | norm_num_groups = 32,
217 | cross_attention_dim = 768,
218 | activation_fn = "geglu",
219 | attention_bias = False,
220 | upcast_attention = False,
221 | cross_frame_attention_mode = None,
222 | temporal_position_encoding = False,
223 | temporal_position_encoding_max_len = 24,
224 | ):
225 | super().__init__()
226 |
227 | attention_blocks = []
228 | norms = []
229 |
230 | for block_name in attention_block_types:
231 | attention_blocks.append(
232 | VersatileAttention(
233 | attention_mode=block_name.split("_")[0],
234 | cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
235 |
236 | query_dim=dim,
237 | heads=num_attention_heads,
238 | dim_head=attention_head_dim,
239 | dropout=dropout,
240 | bias=attention_bias,
241 | upcast_attention=upcast_attention,
242 |
243 | cross_frame_attention_mode=cross_frame_attention_mode,
244 | temporal_position_encoding=temporal_position_encoding,
245 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
246 | )
247 | )
248 | norms.append(nn.LayerNorm(dim))
249 |
250 | self.attention_blocks = nn.ModuleList(attention_blocks)
251 | self.norms = nn.ModuleList(norms)
252 |
253 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
254 | self.ff_norm = nn.LayerNorm(dim)
255 |
256 |
257 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
258 | for attention_block, norm in zip(self.attention_blocks, self.norms):
259 | norm_hidden_states = norm(hidden_states)
260 | hidden_states = attention_block(
261 | norm_hidden_states,
262 | encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
263 | video_length=video_length,
264 | ) + hidden_states
265 |
266 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
267 |
268 | output = hidden_states
269 | return output
270 |
271 |
272 | class PositionalEncoding(nn.Module):
273 | def __init__(
274 | self,
275 | d_model,
276 | dropout = 0.,
277 | max_len = 24
278 | ):
279 | super().__init__()
280 | self.dropout = nn.Dropout(p=dropout)
281 | position = torch.arange(max_len).unsqueeze(1)
282 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
283 | pe = torch.zeros(1, max_len, d_model)
284 | pe[0, :, 0::2] = torch.sin(position * div_term)
285 | pe[0, :, 1::2] = torch.cos(position * div_term)
286 | self.register_buffer('pe', pe)
287 |
288 | def forward(self, x):
289 | x = x + self.pe[:, :x.size(1)]
290 | return self.dropout(x)
291 |
292 |
293 | class VersatileAttention(Attention):
294 | def __init__(
295 | self,
296 | attention_mode = None,
297 | cross_frame_attention_mode = None,
298 | temporal_position_encoding = False,
299 | temporal_position_encoding_max_len = 24,
300 | *args, **kwargs
301 | ):
302 | super().__init__(*args, **kwargs)
303 | assert attention_mode == "Temporal"
304 |
305 | self.attention_mode = attention_mode
306 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None
307 |
308 | self.pos_encoder = PositionalEncoding(
309 | kwargs["query_dim"],
310 | dropout=0.,
311 | max_len=temporal_position_encoding_max_len
312 | ) if (temporal_position_encoding and attention_mode == "Temporal") else None
313 |
314 | def extra_repr(self):
315 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
316 |
317 | def _attention(self, query, key, value, attention_mask=None):
318 | # if self.upcast_attention:
319 | # query = query.float()
320 | # key = key.float()
321 |
322 | attention_scores = torch.baddbmm(
323 | torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
324 | query,
325 | key.transpose(-1, -2),
326 | beta=0,
327 | alpha=self.scale,
328 | )
329 |
330 | if attention_mask is not None:
331 | attention_scores = attention_scores + attention_mask
332 |
333 | # if self.upcast_softmax:
334 | # attention_scores = attention_scores.float()
335 |
336 | attention_probs = attention_scores.softmax(dim=-1)
337 |
338 | # cast back to the original dtype
339 | attention_probs = attention_probs.to(value.dtype)
340 |
341 | # compute attention output
342 | hidden_states = torch.bmm(attention_probs, value)
343 |
344 | # reshape hidden_states
345 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
346 | return hidden_states
347 |
348 | def reshape_batch_dim_to_heads(self, tensor):
349 | batch_size, seq_len, dim = tensor.shape
350 | head_size = self.heads
351 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
352 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
353 | return tensor
354 |
355 | def reshape_heads_to_batch_dim(self, tensor):
356 | batch_size, seq_len, dim = tensor.shape # 4096 16 320
357 | head_size = self.heads
358 | # head_size = 8
359 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
360 | # [4096, 16, 8, 40]
361 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
362 | return tensor # [32768, 16, 40]
363 |
364 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
365 | batch_size, sequence_length, _ = hidden_states.shape
366 |
367 | if self.attention_mode == "Temporal":
368 | d = hidden_states.shape[1]
369 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
370 |
371 | if self.pos_encoder is not None:
372 | hidden_states = self.pos_encoder(hidden_states)
373 |
374 | encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
375 | else:
376 | raise NotImplementedError
377 |
378 | encoder_hidden_states = encoder_hidden_states
379 |
380 | if self.group_norm is not None:
381 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
382 |
383 | query = self.to_q(hidden_states)
384 | dim = query.shape[-1] # [4096, 16, 320]
385 | query = self.reshape_heads_to_batch_dim(query)
386 |
387 | if self.added_kv_proj_dim is not None:
388 | raise NotImplementedError
389 |
390 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
391 | key = self.to_k(encoder_hidden_states)
392 | value = self.to_v(encoder_hidden_states)
393 |
394 | key = self.reshape_heads_to_batch_dim(key)
395 | value = self.reshape_heads_to_batch_dim(value)
396 |
397 | if attention_mask is not None:
398 | if attention_mask.shape[-1] != query.shape[1]:
399 | target_length = query.shape[1]
400 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
401 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
402 |
403 | # attention, what we cannot get enough of
404 | # if self.set_use_memory_efficient_attention_xformers:
405 | # hidden_states = self.set_use_memory_efficient_attention_xformers(query, key, value)
406 | # # self.set_use_memory_efficient_attention_xformers()
407 | # # Some versions of xformers return output in fp32, cast it back to the dtype of the input
408 | # hidden_states = hidden_states.to(query.dtype)
409 | # else:
410 | # if self._slice_size is None or query.shape[0] // self._slice_size == 1:
411 | hidden_states = self._attention(query, key, value, attention_mask)
412 | # else:
413 | # hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
414 |
415 | # linear proj
416 | hidden_states = self.to_out[0](hidden_states)
417 |
418 | # dropout
419 | hidden_states = self.to_out[1](hidden_states)
420 |
421 | if self.attention_mode == "Temporal":
422 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
423 |
424 | return hidden_states
425 |
426 | class GEGLU(nn.Module):
427 |
428 | def __init__(self, dim_in: int, dim_out: int):
429 | super().__init__()
430 | self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
431 |
432 | def gelu(self, gate):
433 | if gate.device.type != "mps":
434 | return F.gelu(gate)
435 | # mps: gelu is not implemented for float16
436 | return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
437 |
438 | def forward(self, hidden_states, scale: float = 1.0):
439 | hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1)
440 | return hidden_states * self.gelu(gate)
--------------------------------------------------------------------------------
/animatediff/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from functools import partial
17 | from typing import Optional
18 |
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | from einops import rearrange
23 |
24 | from diffusers.models.activations import get_activation
25 | from diffusers.models.attention import AdaGroupNorm
26 | from diffusers.models.attention_processor import SpatialNorm
27 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
28 |
29 |
30 | class InflatedConv3d(nn.Conv2d):
31 | def forward(self, x):
32 | video_length = x.shape[2]
33 |
34 | x = rearrange(x, "b c f h w -> (b f) c h w")
35 | x = super().forward(x)
36 | x = rearrange(x, "(b f) c h w -> b c f h w", f = video_length)
37 |
38 | return x
39 |
40 | class InflatedGroupNorm(nn.GroupNorm):
41 |     def forward(self, x):
42 | video_length = x.shape[2]
43 |
44 | x = rearrange(x, "b c f h w -> (b f) c h w")
45 | x = super().forward(x)
46 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
47 |
48 | return x
49 |
50 |
51 | class Upsample3D(nn.Module):
52 |     """A 3D (frame-inflated) upsampling layer with an optional convolution.
53 |
54 | Parameters:
55 | channels (`int`):
56 | number of channels in the inputs and outputs.
57 | use_conv (`bool`, default `False`):
58 | option to use a convolution.
59 | use_conv_transpose (`bool`, default `False`):
60 | option to use a convolution transpose.
61 | out_channels (`int`, optional):
62 | number of output channels. Defaults to `channels`.
63 | """
64 |
65 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
66 | super().__init__()
67 | self.channels = channels
68 | self.out_channels = out_channels or channels
69 | self.use_conv = use_conv
70 | self.use_conv_transpose = use_conv_transpose
71 | self.name = name
72 |
73 | conv = None
74 | if use_conv_transpose:
75 | conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
76 | elif use_conv:
77 | # conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)
78 | conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
79 |
80 | # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
81 | if name == "conv":
82 | self.conv = conv
83 | else:
84 | self.Conv2d_0 = conv
85 |
86 | def forward(self, hidden_states, output_size=None, scale: float = 1.0):
87 | assert hidden_states.shape[1] == self.channels
88 |
89 | if self.use_conv_transpose:
90 | return self.conv(hidden_states)
91 |
92 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
93 | # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
94 | # https://github.com/pytorch/pytorch/issues/86679
95 | dtype = hidden_states.dtype
96 | if dtype == torch.bfloat16:
97 | hidden_states = hidden_states.to(torch.float32)
98 |
99 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
100 | if hidden_states.shape[0] >= 64:
101 | hidden_states = hidden_states.contiguous()
102 |
103 | # if `output_size` is passed we force the interpolation output
104 | # size and do not make use of `scale_factor=2`
105 | if output_size is None:
106 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
107 | else:
108 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
109 |
110 | # If the input is bfloat16, we cast back to bfloat16
111 | if dtype == torch.bfloat16:
112 | hidden_states = hidden_states.to(dtype)
113 |
114 | # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
115 | if self.use_conv:
116 | if self.name == "conv":
117 | if isinstance(self.conv, LoRACompatibleConv):
118 | hidden_states = self.conv(hidden_states, scale)
119 | else:
120 | hidden_states = self.conv(hidden_states)
121 | else:
122 | if isinstance(self.Conv2d_0, LoRACompatibleConv):
123 | hidden_states = self.Conv2d_0(hidden_states, scale)
124 | else:
125 | hidden_states = self.Conv2d_0(hidden_states)
126 |
127 | return hidden_states
128 |
129 |
130 | class Downsample3D(nn.Module):
131 |     """A 3D (frame-inflated) downsampling layer with an optional convolution.
132 |
133 | Parameters:
134 | channels (`int`):
135 | number of channels in the inputs and outputs.
136 | use_conv (`bool`, default `False`):
137 | option to use a convolution.
138 | out_channels (`int`, optional):
139 | number of output channels. Defaults to `channels`.
140 | padding (`int`, default `1`):
141 | padding for the convolution.
142 | """
143 |
144 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
145 | super().__init__()
146 | self.channels = channels
147 | self.out_channels = out_channels or channels
148 | self.use_conv = use_conv
149 | self.padding = padding
150 | stride = 2
151 | self.name = name
152 |
153 | if use_conv:
154 | # conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
155 | conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
156 |
157 | else:
158 | assert self.channels == self.out_channels
159 | conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
160 |
161 | # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
162 | if name == "conv":
163 | self.Conv2d_0 = conv
164 | self.conv = conv
165 | elif name == "Conv2d_0":
166 | self.conv = conv
167 | else:
168 | self.conv = conv
169 |
170 | def forward(self, hidden_states, scale: float = 1.0):
171 | assert hidden_states.shape[1] == self.channels
172 | if self.use_conv and self.padding == 0:
173 | pad = (0, 1, 0, 1)
174 | hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
175 |
176 | assert hidden_states.shape[1] == self.channels
177 | if isinstance(self.conv, LoRACompatibleConv):
178 | hidden_states = self.conv(hidden_states, scale)
179 | else:
180 | hidden_states = self.conv(hidden_states)
181 |
182 | return hidden_states
183 |
184 |
185 | class ResnetBlock3D(nn.Module):
186 |
187 | def __init__(
188 | self,
189 | *,
190 | in_channels,
191 | out_channels=None,
192 | conv_shortcut=False,
193 | dropout=0.0,
194 | temb_channels=512,
195 | groups=32,
196 | groups_out=None,
197 | pre_norm=True,
198 | eps=1e-6,
199 | non_linearity="swish",
200 | skip_time_act=False,
201 | time_embedding_norm="default", # default, scale_shift, ada_group, spatial
202 | kernel=None,
203 | output_scale_factor=1.0,
204 | use_in_shortcut=None,
205 | up=False,
206 | down=False,
207 | conv_shortcut_bias: bool = True,
208 | conv_2d_out_channels: Optional[int] = None,
209 | use_inflated_groupnorm: bool = True,
210 | ):
211 | super().__init__()
212 | self.pre_norm = pre_norm
213 | self.pre_norm = True
214 | self.in_channels = in_channels
215 | out_channels = in_channels if out_channels is None else out_channels
216 | self.out_channels = out_channels
217 | self.use_conv_shortcut = conv_shortcut
218 | self.up = up
219 | self.down = down
220 | self.output_scale_factor = output_scale_factor
221 | self.time_embedding_norm = time_embedding_norm
222 | self.skip_time_act = skip_time_act
223 |
224 | if groups_out is None:
225 | groups_out = groups
226 |
227 | if self.time_embedding_norm == "ada_group":
228 | self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
229 | elif self.time_embedding_norm == "spatial":
230 | self.norm1 = SpatialNorm(in_channels, temb_channels)
231 | else:
232 | self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
233 |
234 | # self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
235 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
236 |
237 | if temb_channels is not None:
238 | if self.time_embedding_norm == "default":
239 | # self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
240 | self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
241 | elif self.time_embedding_norm == "scale_shift":
242 | self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels)
243 | elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
244 | self.time_emb_proj = None
245 | else:
246 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
247 | else:
248 | self.time_emb_proj = None
249 |
250 | if self.time_embedding_norm == "ada_group":
251 | self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
252 | elif self.time_embedding_norm == "spatial":
253 | self.norm2 = SpatialNorm(out_channels, temb_channels)
254 | else:
255 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
256 |
257 | self.dropout = torch.nn.Dropout(dropout)
258 | conv_2d_out_channels = conv_2d_out_channels or out_channels
259 | # self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
260 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
261 |
262 | self.nonlinearity = get_activation(non_linearity)
263 |
264 | self.upsample = self.downsample = None
265 | if self.up:
266 | if kernel == "fir":
267 | fir_kernel = (1, 3, 3, 1)
268 | self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
269 | elif kernel == "sde_vp":
270 | self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
271 | else:
272 | self.upsample = Upsample2D(in_channels, use_conv=False)
273 | elif self.down:
274 | if kernel == "fir":
275 | fir_kernel = (1, 3, 3, 1)
276 | self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
277 | elif kernel == "sde_vp":
278 | self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
279 | else:
280 | self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
281 |
282 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
283 |
284 | self.conv_shortcut = None
285 | if self.use_in_shortcut:
286 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
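    |             # The 1x1 InflatedConv3d on the skip path matches channel counts so the residual
    |             # addition in forward() is valid when in_channels != out_channels.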
287 |
288 |
289 |
290 | def forward(self, input_tensor, temb, scale: float = 1.0):
291 | hidden_states = input_tensor
292 |
293 | if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
294 | hidden_states = self.norm1(hidden_states, temb)
295 | else:
296 | hidden_states = self.norm1(hidden_states)
297 |
298 | hidden_states = self.nonlinearity(hidden_states)
299 |
300 | # if self.upsample is not None:
301 | # # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
302 | # if hidden_states.shape[0] >= 64:
303 | # input_tensor = input_tensor.contiguous()
304 | # hidden_states = hidden_states.contiguous()
305 | # input_tensor = (
306 | # self.upsample(input_tensor, scale=scale)
307 | # if isinstance(self.upsample, Upsample2D)
308 | # else self.upsample(input_tensor)
309 | # )
310 | # hidden_states = (
311 | # self.upsample(hidden_states, scale=scale)
312 | # if isinstance(self.upsample, Upsample2D)
313 | # else self.upsample(hidden_states)
314 | # )
315 | # elif self.downsample is not None:
316 | # input_tensor = (
317 | # self.downsample(input_tensor, scale=scale)
318 | # if isinstance(self.downsample, Downsample2D)
319 | # else self.downsample(input_tensor)
320 | # )
321 | # hidden_states = (
322 | # self.downsample(hidden_states, scale=scale)
323 | # if isinstance(self.downsample, Downsample2D)
324 | # else self.downsample(hidden_states)
325 | # )
326 |
327 | hidden_states = self.conv1(hidden_states)
328 |
329 | if self.time_emb_proj is not None:
330 | if not self.skip_time_act:
331 | temb = self.nonlinearity(temb)
332 | # temb = self.time_emb_proj(temb, scale)[:, :, None, None, None]
333 | temb = self.time_emb_proj(temb)[:, :, None, None, None]
334 |
335 |
336 | if temb is not None and self.time_embedding_norm == "default":
337 | hidden_states = hidden_states + temb
338 |
339 | if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
340 | hidden_states = self.norm2(hidden_states, temb)
341 | else:
342 | hidden_states = self.norm2(hidden_states)
343 |
344 | if temb is not None and self.time_embedding_norm == "scale_shift":
345 | scale, shift = torch.chunk(temb, 2, dim=1)
346 | hidden_states = hidden_states * (1 + scale) + shift
347 |
348 | hidden_states = self.nonlinearity(hidden_states)
349 |
350 | hidden_states = self.dropout(hidden_states)
351 | hidden_states = self.conv2(hidden_states)
352 |
353 | if self.conv_shortcut is not None:
354 | input_tensor = self.conv_shortcut(input_tensor)
355 |
356 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
357 |
358 | return output_tensor
359 |
360 |
361 |
--------------------------------------------------------------------------------
/animatediff/models/unet.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("/root/autodl-tmp/code/animatediff/modelshigh")
3 | from dataclasses import dataclass
4 | import os
5 | from typing import Any, Dict, List, Optional, Tuple, Union
6 | import json
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.utils.checkpoint
11 |
12 | from diffusers.configuration_utils import ConfigMixin, register_to_config
13 | from diffusers.loaders import UNet2DConditionLoadersMixin
14 | from diffusers.utils import BaseOutput, logging
15 | from diffusers.models.attention_processor import (
16 | ADDED_KV_ATTENTION_PROCESSORS,
17 | CROSS_ATTENTION_PROCESSORS,
18 | AttentionProcessor,
19 | AttnAddedKVProcessor,
20 | AttnProcessor,
21 | )
22 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps
23 | from diffusers.models.modeling_utils import ModelMixin
24 | # from diffusers.models.transformer_temporal import TransformerTemporalModel
25 | from animatediff.models.unet_blocks import (
26 | CrossAttnDownBlock3D,
27 | CrossAttnUpBlock3D,
28 | DownBlock3D,
29 | UNetMidBlock3DCrossAttn,
30 | UpBlock3D,
31 | get_down_block,
32 | get_up_block,
33 | )
34 |
35 | from .resnet import InflatedConv3d, InflatedGroupNorm
36 |
37 |
38 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39 |
40 |
41 | @dataclass
42 | class UNet3DConditionOutput(BaseOutput):
43 | """
44 | The output of [`UNet3DConditionModel`].
45 |
46 | Args:
47 |         sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
48 | The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
49 | """
50 |
51 | sample: torch.FloatTensor
52 |
53 |
54 | class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
55 |
56 | _supports_gradient_checkpointing = False
57 |
58 | @register_to_config
59 | def __init__(
60 | self,
61 | sample_size: Optional[int] = None,
62 | in_channels: int = 4,
63 | out_channels: int = 4,
64 | down_block_types: Tuple[str] = (
65 | "CrossAttnDownBlock3D",
66 | "CrossAttnDownBlock3D",
67 | "CrossAttnDownBlock3D",
68 | "DownBlock3D",
69 | ),
70 | #-----
71 |         mid_block_type: str = "UNetMidBlock3DCrossAttn",
72 | #-----
73 | up_block_types: Tuple[str] = (
74 | "UpBlock3D",
75 | "CrossAttnUpBlock3D",
76 | "CrossAttnUpBlock3D",
77 | "CrossAttnUpBlock3D"
78 | ),
79 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
80 | layers_per_block: int = 2,
81 | downsample_padding: int = 1,
82 | mid_block_scale_factor: float = 1,
83 | act_fn: str = "silu",
84 | norm_num_groups: Optional[int] = 32,
85 | norm_eps: float = 1e-5,
86 | # cross_attention_dim: int = 1024,
87 | cross_attention_dim: int = 1280,
88 | # attention_head_dim: Union[int, Tuple[int]] = 64,
89 | attention_head_dim: Union[int, Tuple[int]] = 8,
90 | num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
91 |
92 | use_inflated_groupnorm=False,
93 | # Additional
94 | use_motion_module = False,
95 | motion_module_resolutions = ( 1,2,4,8 ),
96 | motion_module_mid_block = False,
97 | motion_module_decoder_only = False,
98 | motion_module_type = None,
99 | motion_module_kwargs = {},
100 | unet_use_cross_frame_attention = None,
101 | unet_use_temporal_attention = None,
102 | ):
103 | super().__init__()
104 |
105 | self.sample_size = sample_size
106 | # time_embed_dim = block_out_channels[0] * 4
107 |
108 | if num_attention_heads is not None:
109 | raise NotImplementedError(
110 | "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
111 | )
112 |
113 | # If `num_attention_heads` is not defined (which is the case for most models)
114 | # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
115 | # The reason for this behavior is to correct for incorrectly named variables that were introduced
116 | # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
117 | # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
118 | # which is why we correct for the naming here.
119 | num_attention_heads = num_attention_heads or attention_head_dim
120 |
121 | # Check inputs
122 | if len(down_block_types) != len(up_block_types):
123 | raise ValueError(
124 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
125 | )
126 |
127 | if len(block_out_channels) != len(down_block_types):
128 | raise ValueError(
129 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
130 | )
131 |
132 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
133 | raise ValueError(
134 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
135 | )
136 |
137 | # input
138 | conv_in_kernel = 3
139 | conv_out_kernel = 3
140 | conv_in_padding = (conv_in_kernel - 1) // 2
141 | # self.conv_in = nn.Conv2d(
142 | # in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
143 | # )
144 | self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
145 |
146 | # time
147 | time_embed_dim = block_out_channels[0] * 4
148 | self.time_proj = Timesteps(block_out_channels[0], True, 0)
149 | timestep_input_dim = block_out_channels[0]
150 |
151 | self.time_embedding = TimestepEmbedding(
152 | timestep_input_dim,
153 | time_embed_dim,
154 | act_fn=act_fn,
155 | )
156 |
157 | # self.transformer_in = TransformerTemporalModel(
158 | # num_attention_heads=8,
159 | # attention_head_dim=attention_head_dim,
160 | # in_channels=block_out_channels[0],
161 | # num_layers=1,
162 | # )
163 |
164 | # class embedding
165 |
166 |
167 | self.down_blocks = nn.ModuleList([])
168 | self.mid_block = None
169 | self.up_blocks = nn.ModuleList([])
170 |
171 | if isinstance(num_attention_heads, int):
172 | num_attention_heads = (num_attention_heads,) * len(down_block_types)
173 |
174 | # down
175 | output_channel = block_out_channels[0]
176 | for i, down_block_type in enumerate(down_block_types):
177 | res = 2 ** i
178 | input_channel = output_channel
179 | output_channel = block_out_channels[i]
180 | is_final_block = i == len(block_out_channels) - 1
181 |
182 | down_block = get_down_block(
183 | down_block_type,
184 | num_layers=layers_per_block,
185 | in_channels=input_channel,
186 | out_channels=output_channel,
187 | temb_channels=time_embed_dim,
188 | add_downsample=not is_final_block,
189 | resnet_eps=norm_eps,
190 | resnet_act_fn=act_fn,
191 | resnet_groups=norm_num_groups,
192 | cross_attention_dim=cross_attention_dim,
193 | num_attention_heads=num_attention_heads[i],
194 | downsample_padding=downsample_padding,
195 |
196 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
197 | unet_use_temporal_attention=unet_use_temporal_attention,
198 | use_inflated_groupnorm=use_inflated_groupnorm,
199 |
200 | use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
201 | motion_module_type=motion_module_type,
202 | motion_module_kwargs=motion_module_kwargs,
203 | )
204 | self.down_blocks.append(down_block)
205 |
206 | # mid
207 | self.mid_block = UNetMidBlock3DCrossAttn(
208 | in_channels=block_out_channels[-1],
209 | temb_channels=time_embed_dim,
210 | resnet_eps=norm_eps,
211 | resnet_act_fn=act_fn,
212 | output_scale_factor=mid_block_scale_factor,
213 | cross_attention_dim=cross_attention_dim,
214 | num_attention_heads=num_attention_heads[-1],
215 | resnet_groups=norm_num_groups,
216 |
217 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
218 | unet_use_temporal_attention=unet_use_temporal_attention,
219 | use_inflated_groupnorm=use_inflated_groupnorm,
220 |
221 | use_motion_module=use_motion_module and motion_module_mid_block,
222 | motion_module_type=motion_module_type,
223 | motion_module_kwargs=motion_module_kwargs,
224 | )
225 |
226 | # count how many layers upsample the images
227 | self.num_upsamplers = 0
228 |
229 | # up
230 | reversed_block_out_channels = list(reversed(block_out_channels))
231 | reversed_num_attention_heads = list(reversed(num_attention_heads))
232 |
233 | output_channel = reversed_block_out_channels[0]
234 | for i, up_block_type in enumerate(up_block_types):
235 | res = 2 ** (3 - i)
236 | is_final_block = i == len(block_out_channels) - 1
237 |
238 | prev_output_channel = output_channel
239 | output_channel = reversed_block_out_channels[i]
240 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
241 |
242 | # add upsample block for all BUT final layer
243 | if not is_final_block:
244 | add_upsample = True
245 | self.num_upsamplers += 1
246 | else:
247 | add_upsample = False
248 |
249 | up_block = get_up_block(
250 | up_block_type,
251 | num_layers=layers_per_block + 1,
252 | in_channels=input_channel,
253 | out_channels=output_channel,
254 | prev_output_channel=prev_output_channel,
255 | temb_channels=time_embed_dim,
256 | add_upsample=add_upsample,
257 | resnet_eps=norm_eps,
258 | resnet_act_fn=act_fn,
259 | resnet_groups=norm_num_groups,
260 | cross_attention_dim=cross_attention_dim,
261 | num_attention_heads=reversed_num_attention_heads[i],
262 |
263 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
264 | unet_use_temporal_attention=unet_use_temporal_attention,
265 | use_inflated_groupnorm=use_inflated_groupnorm,
266 |
267 | use_motion_module=use_motion_module and (res in motion_module_resolutions),
268 | motion_module_type=motion_module_type,
269 | motion_module_kwargs=motion_module_kwargs,
270 | )
271 | self.up_blocks.append(up_block)
272 | prev_output_channel = output_channel
273 |
274 | # out
275 | if norm_num_groups is not None:
276 | self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
277 | else:
278 |             self.conv_norm_out = None
279 | self.conv_act = nn.SiLU()
280 |
281 | conv_out_padding = (conv_out_kernel - 1) // 2
282 |
283 | self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
284 |
285 |
286 | @property
287 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
288 | def attn_processors(self) -> Dict[str, AttentionProcessor]:
289 | r"""
290 | Returns:
291 |             `dict` of attention processors: A dictionary containing all attention processors used in the model,
292 |             indexed by their weight names.
293 | """
294 | # set recursively
295 | processors = {}
296 |
297 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
298 | if hasattr(module, "get_processor"):
299 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
300 |
301 | for sub_name, child in module.named_children():
302 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
303 |
304 | return processors
305 |
306 | for name, module in self.named_children():
307 | fn_recursive_add_processors(name, module, processors)
308 |
309 | return processors
310 |
311 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
312 | def set_attention_slice(self, slice_size):
313 | r"""
314 | Enable sliced attention computation.
315 |
316 | When this option is enabled, the attention module splits the input tensor in slices to compute attention in
317 | several steps. This is useful for saving some memory in exchange for a small decrease in speed.
318 |
319 | Args:
320 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
321 | When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
322 | `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
323 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
324 | must be a multiple of `slice_size`.
325 | """
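    |         # Example: unet.set_attention_slice("auto") computes attention in two slices per
    |         # sliceable layer, trading a little speed for lower peak memory.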
326 | sliceable_head_dims = []
327 |
328 | def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
329 | if hasattr(module, "set_attention_slice"):
330 | sliceable_head_dims.append(module.sliceable_head_dim)
331 |
332 | for child in module.children():
333 | fn_recursive_retrieve_sliceable_dims(child)
334 |
335 | # retrieve number of attention layers
336 | for module in self.children():
337 | fn_recursive_retrieve_sliceable_dims(module)
338 |
339 | num_sliceable_layers = len(sliceable_head_dims)
340 |
341 | if slice_size == "auto":
342 | # half the attention head size is usually a good trade-off between
343 | # speed and memory
344 | slice_size = [dim // 2 for dim in sliceable_head_dims]
345 | elif slice_size == "max":
346 | # make smallest slice possible
347 | slice_size = num_sliceable_layers * [1]
348 |
349 | slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
350 |
351 | if len(slice_size) != len(sliceable_head_dims):
352 | raise ValueError(
353 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
354 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
355 | )
356 |
357 | for i in range(len(slice_size)):
358 | size = slice_size[i]
359 | dim = sliceable_head_dims[i]
360 | if size is not None and size > dim:
361 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
362 |
363 | # Recursively walk through all the children.
364 |         # Any child that exposes the set_attention_slice method
365 |         # receives the message
366 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
367 | if hasattr(module, "set_attention_slice"):
368 | module.set_attention_slice(slice_size.pop())
369 |
370 | for child in module.children():
371 | fn_recursive_set_attention_slice(child, slice_size)
372 |
373 | reversed_slice_size = list(reversed(slice_size))
374 | for module in self.children():
375 | fn_recursive_set_attention_slice(module, reversed_slice_size)
376 |
377 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
378 | def set_attn_processor(
379 | self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
380 | ):
381 | r"""
382 | Sets the attention processor to use to compute attention.
383 |
384 | Parameters:
385 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
386 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
387 | for **all** `Attention` layers.
388 |
389 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention
390 | processor. This is strongly recommended when setting trainable attention processors.
391 |
392 | """
393 | count = len(self.attn_processors.keys())
394 |
395 | if isinstance(processor, dict) and len(processor) != count:
396 | raise ValueError(
397 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
398 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
399 | )
400 |
401 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
402 | if hasattr(module, "set_processor"):
403 | if not isinstance(processor, dict):
404 | module.set_processor(processor, _remove_lora=_remove_lora)
405 | else:
406 | module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
407 |
408 | for sub_name, child in module.named_children():
409 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
410 |
411 | for name, module in self.named_children():
412 | fn_recursive_attn_processor(name, module, processor)
413 |
414 | def enable_forward_chunking(self, chunk_size=None, dim=0):
415 | """
416 |         Enables [feed forward
417 |         chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
418 |
419 | Parameters:
420 | chunk_size (`int`, *optional*):
421 |                 The chunk size of the feed-forward layers. If not specified, the feed-forward layer is run
422 |                 individually over each tensor of dim=`dim`.
423 | dim (`int`, *optional*, defaults to `0`):
424 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
425 | or dim=1 (sequence length).
426 | """
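    |         # Example: unet.enable_forward_chunking(chunk_size=1, dim=1) runs the feed-forward
    |         # layers chunk by chunk over the sequence dimension to reduce peak memory.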
427 | if dim not in [0, 1]:
428 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
429 |
430 | # By default chunk size is 1
431 | chunk_size = chunk_size or 1
432 |
433 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
434 | if hasattr(module, "set_chunk_feed_forward"):
435 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
436 |
437 | for child in module.children():
438 | fn_recursive_feed_forward(child, chunk_size, dim)
439 |
440 | for module in self.children():
441 | fn_recursive_feed_forward(module, chunk_size, dim)
442 |
443 | def disable_forward_chunking(self):
444 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
445 | if hasattr(module, "set_chunk_feed_forward"):
446 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
447 |
448 | for child in module.children():
449 | fn_recursive_feed_forward(child, chunk_size, dim)
450 |
451 | for module in self.children():
452 | fn_recursive_feed_forward(module, None, 0)
453 |
454 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
455 | def set_default_attn_processor(self):
456 | """
457 | Disables custom attention processors and sets the default attention implementation.
458 | """
459 | if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
460 | processor = AttnAddedKVProcessor()
461 | elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
462 | processor = AttnProcessor()
463 | else:
464 | raise ValueError(
465 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
466 | )
467 |
468 | self.set_attn_processor(processor, _remove_lora=True)
469 |
470 | def _set_gradient_checkpointing(self, module, value=False):
471 | if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
472 | module.gradient_checkpointing = value
473 |
474 | def forward(
475 | self,
476 | sample: torch.FloatTensor,
477 | timestep: Union[torch.Tensor, float, int],
478 | encoder_hidden_states: torch.Tensor,
479 | class_labels: Optional[torch.Tensor] = None,
480 | timestep_cond: Optional[torch.Tensor] = None,
481 | attention_mask: Optional[torch.Tensor] = None,
482 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
483 | mid_block_additional_residual: Optional[torch.Tensor] = None,
484 | return_dict: bool = True,
485 | ) -> Union[UNet3DConditionOutput, Tuple]:
486 | r"""
487 | The [`UNet3DConditionModel`] forward method.
488 |
489 | Args:
490 | sample (`torch.FloatTensor`):
491 |                 The noisy input tensor with the following shape `(batch, channel, num_frames, height, width)`.
492 | timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
493 | encoder_hidden_states (`torch.FloatTensor`):
494 | The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
495 | return_dict (`bool`, *optional*, defaults to `True`):
496 | Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
497 | tuple.
500 |
501 | Returns:
502 | [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
503 | If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
504 | a `tuple` is returned where the first element is the sample tensor.
505 | """
506 |         # By default samples have to be at least a multiple of the overall upsampling factor.
507 |         # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
508 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
509 | # on the fly if necessary.
510 | default_overall_up_factor = 2**self.num_upsamplers
511 |
512 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
513 | forward_upsample_size = False
514 | upsample_size = None
515 |
516 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
517 | logger.info("Forward upsample size to force interpolation output size.")
518 | forward_upsample_size = True
519 |
520 | # prepare attention_mask
521 | if attention_mask is not None:
522 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
523 | attention_mask = attention_mask.unsqueeze(1)
524 |
525 | # 1. time
526 | timesteps = timestep
527 | if not torch.is_tensor(timesteps):
528 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
529 | # This would be a good case for the `match` statement (Python 3.10+)
530 | is_mps = sample.device.type == "mps"
531 | if isinstance(timestep, float):
532 | dtype = torch.float32 if is_mps else torch.float64
533 | else:
534 | dtype = torch.int32 if is_mps else torch.int64
535 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
536 | elif len(timesteps.shape) == 0:
537 | timesteps = timesteps[None].to(sample.device)
538 |
539 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
540 | num_frames = sample.shape[2]
541 | timesteps = timesteps.expand(sample.shape[0])
542 |
543 | t_emb = self.time_proj(timesteps)
544 |
545 | # timesteps does not contain any weights and will always return f32 tensors
546 | # but time_embedding might actually be running in fp16. so we need to cast here.
547 | # there might be better ways to encapsulate this.
548 | t_emb = t_emb.to(dtype=self.dtype)
549 |
550 | emb = self.time_embedding(t_emb, timestep_cond)
551 | # emb = emb.repeat_interleave(repeats=num_frames, dim=0)
552 | # encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
553 |
554 | # 2. pre-process
555 | # sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
556 | sample = self.conv_in(sample)
557 |
558 | # sample = self.transformer_in(
559 | # sample,
560 | # num_frames=num_frames,
561 | # cross_attention_kwargs=cross_attention_kwargs,
562 | # return_dict=False,
563 | # )[0]
564 |
565 | # 3. down
566 | down_block_res_samples = (sample,)
567 | for downsample_block in self.down_blocks:
568 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
569 | sample, res_samples = downsample_block(
570 | hidden_states=sample,
571 | temb=emb,
572 | encoder_hidden_states=encoder_hidden_states,
573 | attention_mask=attention_mask,
574 | num_frames=num_frames,
575 | )
576 | else:
577 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
578 |
579 | down_block_res_samples += res_samples
580 |
581 | if down_block_additional_residuals is not None:
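    |                 # The ControlNet residuals are 4D (B, C, H, W), computed from the conditioning image;
    |                 # add a temporal axis and repeat across all frames before summing with the 3D UNet features.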
582 | new_down_block_res_samples = ()
583 |
584 | for down_block_res_sample, down_block_additional_residual in zip(
585 | down_block_res_samples, down_block_additional_residuals
586 | ):
587 |                     down_block_additional_residual = down_block_additional_residual.unsqueeze(2).repeat(1, 1, num_frames, 1, 1)
588 | down_block_res_sample = down_block_res_sample + down_block_additional_residual * 1.0
589 | new_down_block_res_samples += (down_block_res_sample,)
590 |
591 | down_block_res_samples = new_down_block_res_samples
592 |
593 | # 4. mid
594 | if self.mid_block is not None:
595 | sample = self.mid_block(
596 | sample,
597 | emb,
598 | encoder_hidden_states=encoder_hidden_states,
599 | attention_mask=attention_mask,
600 | )
601 |
602 | if mid_block_additional_residual is not None:
603 |             mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2).repeat(1, 1, num_frames, 1, 1)
604 | sample = sample + mid_block_additional_residual * 1.0
605 |
606 | # 5. up
607 | for i, upsample_block in enumerate(self.up_blocks):
608 | is_final_block = i == len(self.up_blocks) - 1
609 |
610 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
611 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
612 |
613 | # if we have not reached the final block and need to forward the
614 | # upsample size, we do it here
615 | if not is_final_block and forward_upsample_size:
616 | upsample_size = down_block_res_samples[-1].shape[2:]
617 |
618 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
619 | sample = upsample_block(
620 | hidden_states=sample,
621 | temb=emb,
622 | res_hidden_states_tuple=res_samples,
623 | encoder_hidden_states=encoder_hidden_states,
624 | upsample_size=upsample_size,
625 | attention_mask=attention_mask,
626 | )
627 | else:
628 | sample = upsample_block(
629 | hidden_states=sample,
630 | temb=emb,
631 | res_hidden_states_tuple=res_samples,
632 | upsample_size=upsample_size,
633 | )
634 |
635 | # 6. post-process
636 | if self.conv_norm_out:
637 | sample = self.conv_norm_out(sample)
638 | sample = self.conv_act(sample)
639 |
640 | sample = self.conv_out(sample)
641 |
642 |         # reshape to (batch, channel, num_frames, height, width)
643 | # sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)
644 |
645 | if not return_dict:
646 | return (sample,)
647 |
648 | return UNet3DConditionOutput(sample=sample)
649 |
650 |
651 | @classmethod
652 | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
653 | if subfolder is not None:
654 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
655 | print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")
656 |
657 | config_file = os.path.join(pretrained_model_path, 'config.json')
658 | if not os.path.isfile(config_file):
659 | raise RuntimeError(f"{config_file} does not exist")
660 | with open(config_file, "r") as f:
661 | config = json.load(f)
662 | config["_class_name"] = cls.__name__
663 | config["down_block_types"] = [
664 | "CrossAttnDownBlock3D",
665 | "CrossAttnDownBlock3D",
666 | "CrossAttnDownBlock3D",
667 | "DownBlock3D"
668 | ]
669 | config["up_block_types"] = [
670 | "UpBlock3D",
671 | "CrossAttnUpBlock3D",
672 | "CrossAttnUpBlock3D",
673 | "CrossAttnUpBlock3D"
674 | ]
675 |
676 | from diffusers.utils import WEIGHTS_NAME
677 | model = cls.from_config(config, **unet_additional_kwargs)
678 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
679 | if not os.path.isfile(model_file):
680 | raise RuntimeError(f"{model_file} does not exist")
681 | state_dict = torch.load(model_file, map_location="cpu")
682 |
683 | m, u = model.load_state_dict(state_dict, strict=False)
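    |         # strict=False because the 2D checkpoint contains no temporal/motion-module weights;
    |         # those parameters keep their initialization (the motion module is typically loaded separately).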
684 | # print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
685 | # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
686 |
687 | params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
688 | print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
689 |
690 | return model
691 |
692 |
--------------------------------------------------------------------------------
/animatediff/models/unet_blocks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | from torch import nn
17 |
18 | from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
19 | from .attention import Transformer3DModel
20 | from .motion_module import get_motion_module
21 |
22 |
23 | def get_down_block(
24 | down_block_type,
25 | num_layers,
26 | in_channels,
27 | out_channels,
28 | temb_channels,
29 | add_downsample,
30 | resnet_eps,
31 | resnet_act_fn,
32 | num_attention_heads,
33 | resnet_groups=None,
34 | cross_attention_dim=None,
35 | downsample_padding=None,
36 | dual_cross_attention=False,
37 | use_linear_projection=True,
38 | only_cross_attention=False,
39 | upcast_attention=False,
40 | resnet_time_scale_shift="default",
41 |
42 | unet_use_cross_frame_attention=None,
43 | unet_use_temporal_attention=None,
44 | use_inflated_groupnorm=None,
45 |
46 | use_motion_module=None,
47 | motion_module_type=None,
48 | motion_module_kwargs=None,
49 | ):
50 | down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
51 |
52 | if down_block_type == "DownBlock3D":
53 | return DownBlock3D(
54 | num_layers=num_layers,
55 | in_channels=in_channels,
56 | out_channels=out_channels,
57 | temb_channels=temb_channels,
58 | add_downsample=add_downsample,
59 | resnet_eps=resnet_eps,
60 | resnet_act_fn=resnet_act_fn,
61 | resnet_groups=resnet_groups,
62 | downsample_padding=downsample_padding,
63 | resnet_time_scale_shift=resnet_time_scale_shift,
64 |
65 | use_inflated_groupnorm=use_inflated_groupnorm,
66 |
67 | use_motion_module=use_motion_module,
68 | motion_module_type=motion_module_type,
69 | motion_module_kwargs=motion_module_kwargs,
70 | )
71 | elif down_block_type == "CrossAttnDownBlock3D":
72 | if cross_attention_dim is None:
73 | raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
74 | return CrossAttnDownBlock3D(
75 | num_layers=num_layers,
76 | in_channels=in_channels,
77 | out_channels=out_channels,
78 | temb_channels=temb_channels,
79 | add_downsample=add_downsample,
80 | resnet_eps=resnet_eps,
81 | resnet_act_fn=resnet_act_fn,
82 | resnet_groups=resnet_groups,
83 | downsample_padding=downsample_padding,
84 | cross_attention_dim=cross_attention_dim,
85 | num_attention_heads=num_attention_heads,
86 | dual_cross_attention=dual_cross_attention,
87 | use_linear_projection=use_linear_projection,
88 | only_cross_attention=only_cross_attention,
89 | upcast_attention=upcast_attention,
90 | resnet_time_scale_shift=resnet_time_scale_shift,
91 | 
    |             unet_use_cross_frame_attention=unet_use_cross_frame_attention,
    |             unet_use_temporal_attention=unet_use_temporal_attention,
92 |             use_inflated_groupnorm=use_inflated_groupnorm,
93 |
94 | use_motion_module=use_motion_module,
95 | motion_module_type=motion_module_type,
96 | motion_module_kwargs=motion_module_kwargs,
97 | )
98 | raise ValueError(f"{down_block_type} does not exist.")
99 |
100 |
101 | def get_up_block(
102 | up_block_type,
103 | num_layers,
104 | in_channels,
105 | out_channels,
106 | prev_output_channel,
107 | temb_channels,
108 | add_upsample,
109 | resnet_eps,
110 | resnet_act_fn,
111 | num_attention_heads,
112 | use_motion_module,
113 | motion_module_type,
114 | motion_module_kwargs,
115 | resnet_groups=None,
116 | cross_attention_dim=None,
117 | unet_use_cross_frame_attention=False,
118 | unet_use_temporal_attention=False,
119 | dual_cross_attention=False,
120 | use_linear_projection=True,
121 | only_cross_attention=False,
122 | upcast_attention=False,
123 | resnet_time_scale_shift="default",
124 |
125 | use_inflated_groupnorm=False,
126 |
127 | ):
128 | if up_block_type == "UpBlock3D":
129 | return UpBlock3D(
130 | num_layers=num_layers,
131 | in_channels=in_channels,
132 | out_channels=out_channels,
133 | prev_output_channel=prev_output_channel,
134 | temb_channels=temb_channels,
135 | add_upsample=add_upsample,
136 | resnet_eps=resnet_eps,
137 | resnet_act_fn=resnet_act_fn,
138 | resnet_groups=resnet_groups,
139 | resnet_time_scale_shift=resnet_time_scale_shift,
140 |
141 | use_inflated_groupnorm=use_inflated_groupnorm,
142 |
143 | use_motion_module=use_motion_module,
144 | motion_module_type=motion_module_type,
145 | motion_module_kwargs=motion_module_kwargs,
146 | )
147 | elif up_block_type == "CrossAttnUpBlock3D":
148 | if cross_attention_dim is None:
149 | raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
150 | return CrossAttnUpBlock3D(
151 | num_layers=num_layers,
152 | in_channels=in_channels,
153 | out_channels=out_channels,
154 | prev_output_channel=prev_output_channel,
155 | temb_channels=temb_channels,
156 | add_upsample=add_upsample,
157 | resnet_eps=resnet_eps,
158 | resnet_act_fn=resnet_act_fn,
159 | resnet_groups=resnet_groups,
160 | cross_attention_dim=cross_attention_dim,
161 | num_attention_heads=num_attention_heads,
162 | dual_cross_attention=dual_cross_attention,
163 | use_linear_projection=use_linear_projection,
164 | only_cross_attention=only_cross_attention,
165 | upcast_attention=upcast_attention,
166 | resnet_time_scale_shift=resnet_time_scale_shift,
167 |
168 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
169 | unet_use_temporal_attention=unet_use_temporal_attention,
170 | use_inflated_groupnorm=use_inflated_groupnorm,
171 |
172 | use_motion_module=use_motion_module,
173 | motion_module_type=motion_module_type,
174 | motion_module_kwargs=motion_module_kwargs,
175 | )
176 | raise ValueError(f"{up_block_type} does not exist.")
177 |
178 |
179 | class UNetMidBlock3DCrossAttn(nn.Module):
180 | def __init__(
181 | self,
182 | in_channels: int,
183 | temb_channels: int,
184 | dropout: float = 0.0,
185 | num_layers: int = 1,
186 | resnet_eps: float = 1e-6,
187 | resnet_time_scale_shift: str = "default",
188 | resnet_act_fn: str = "swish",
189 | resnet_groups: int = 32,
190 | resnet_pre_norm: bool = True,
191 | num_attention_heads=1,
192 | output_scale_factor=1.0,
193 | cross_attention_dim=1280,
194 | dual_cross_attention=False,
195 | use_linear_projection=True,
196 | upcast_attention=False,
197 |
198 | unet_use_cross_frame_attention=None,
199 | unet_use_temporal_attention=None,
200 | use_inflated_groupnorm=None,
201 |
202 | use_motion_module=None,
203 |
204 | motion_module_type=None,
205 | motion_module_kwargs=None,
206 | ):
207 | super().__init__()
208 |
209 | self.has_cross_attention = True
210 | self.num_attention_heads = num_attention_heads
211 | resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
212 |
213 | # there is always at least one resnet
214 | resnets = [
215 | ResnetBlock3D(
216 | in_channels=in_channels,
217 | out_channels=in_channels,
218 | temb_channels=temb_channels,
219 | eps=resnet_eps,
220 | groups=resnet_groups,
221 | dropout=dropout,
222 | time_embedding_norm=resnet_time_scale_shift,
223 | non_linearity=resnet_act_fn,
224 | output_scale_factor=output_scale_factor,
225 | pre_norm=resnet_pre_norm,
226 | use_inflated_groupnorm=use_inflated_groupnorm,
227 |
228 | )
229 | ]
230 |
231 | attentions = []
232 | motion_modules = []
233 |
234 | for _ in range(num_layers):
235 | attentions.append(
236 | Transformer3DModel(
237 | in_channels // num_attention_heads,
238 | num_attention_heads,
239 | in_channels=in_channels,
240 | num_layers=1,
241 | cross_attention_dim=cross_attention_dim,
242 | norm_num_groups=resnet_groups,
243 | # use_linear_projection=use_linear_projection,
244 | use_linear_projection=False,
245 | upcast_attention=upcast_attention,
246 |
247 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
248 | unet_use_temporal_attention=unet_use_temporal_attention,
249 | )
250 | )
251 | motion_modules.append(
252 | get_motion_module(
253 | in_channels=in_channels,
254 | motion_module_type=motion_module_type,
255 | motion_module_kwargs=motion_module_kwargs,
256 | ) if use_motion_module else None
257 | )
258 |
259 | resnets.append(
260 | ResnetBlock3D(
261 | in_channels=in_channels,
262 | out_channels=in_channels,
263 | temb_channels=temb_channels,
264 | eps=resnet_eps,
265 | groups=resnet_groups,
266 | dropout=dropout,
267 | time_embedding_norm=resnet_time_scale_shift,
268 | non_linearity=resnet_act_fn,
269 | output_scale_factor=output_scale_factor,
270 | pre_norm=resnet_pre_norm,
271 |
272 | use_inflated_groupnorm=use_inflated_groupnorm,
273 |
274 | )
275 | )
276 |
277 | self.attentions = nn.ModuleList(attentions)
278 | self.resnets = nn.ModuleList(resnets)
279 | self.motion_modules = nn.ModuleList(motion_modules)
280 |
281 | def forward(
282 | self,
283 | hidden_states,
284 | temb=None,
285 | encoder_hidden_states=None,
286 | attention_mask=None,
287 | ):
288 | hidden_states = self.resnets[0](hidden_states, temb)
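    |         # After the initial ResNet, each layer applies spatial cross-attention, then the
    |         # optional temporal motion module, then another ResNet block.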
289 | for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
290 | hidden_states = attn(
291 | hidden_states,
292 | encoder_hidden_states=encoder_hidden_states,
293 | return_dict=False,
294 | )[0]
295 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
296 | hidden_states = resnet(hidden_states, temb)
297 |
298 | return hidden_states
299 |
300 |
301 | class CrossAttnDownBlock3D(nn.Module):
302 | def __init__(
303 | self,
304 | in_channels: int,
305 | out_channels: int,
306 | temb_channels: int,
307 | dropout: float = 0.0,
308 | num_layers: int = 1,
309 | resnet_eps: float = 1e-6,
310 | resnet_time_scale_shift: str = "default",
311 | resnet_act_fn: str = "swish",
312 | resnet_groups: int = 32,
313 | resnet_pre_norm: bool = True,
314 | num_attention_heads=1,
315 | cross_attention_dim=1280,
316 | output_scale_factor=1.0,
317 | downsample_padding=1,
318 | add_downsample=True,
319 | dual_cross_attention=False,
320 | use_linear_projection=False,
321 | only_cross_attention=False,
322 | upcast_attention=False,
323 |
324 | unet_use_cross_frame_attention=None,
325 | unet_use_temporal_attention=None,
326 | use_inflated_groupnorm=None,
327 |
328 | use_motion_module=None,
329 |
330 | motion_module_type=None,
331 | motion_module_kwargs=None,
332 | ):
333 | super().__init__()
334 | resnets = []
335 | attentions = []
336 | motion_modules = []
337 |
338 | self.has_cross_attention = True
339 | self.attn_num_attention_heads = num_attention_heads
340 |
341 | for i in range(num_layers):
342 | in_channels = in_channels if i == 0 else out_channels
343 | resnets.append(
344 | ResnetBlock3D(
345 | in_channels=in_channels,
346 | out_channels=out_channels,
347 | temb_channels=temb_channels,
348 | eps=resnet_eps,
349 | groups=resnet_groups,
350 | dropout=dropout,
351 | time_embedding_norm=resnet_time_scale_shift,
352 | non_linearity=resnet_act_fn,
353 | output_scale_factor=output_scale_factor,
354 | pre_norm=resnet_pre_norm,
355 |
356 | use_inflated_groupnorm=use_inflated_groupnorm,
357 | )
358 | )
359 | if dual_cross_attention:
360 | raise NotImplementedError
361 |
362 | attentions.append(
363 | Transformer3DModel(
364 | self.attn_num_attention_heads,
365 | out_channels // self.attn_num_attention_heads,
366 | in_channels=out_channels,
367 | num_layers=1,
368 | cross_attention_dim=cross_attention_dim,
369 | norm_num_groups=resnet_groups,
370 | # use_linear_projection=use_linear_projection,
371 | use_linear_projection=False,
372 | only_cross_attention=only_cross_attention,
373 | upcast_attention=upcast_attention,
374 |
375 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
376 | unet_use_temporal_attention=unet_use_temporal_attention,
377 | )
378 | )
379 |
380 | motion_modules.append(
381 | get_motion_module(
382 | in_channels=out_channels,
383 | motion_module_type=motion_module_type,
384 | motion_module_kwargs=motion_module_kwargs,
385 | )
386 | )
387 |
388 | self.resnets = nn.ModuleList(resnets)
389 | self.attentions = nn.ModuleList(attentions)
390 | self.motion_modules = nn.ModuleList(motion_modules)
391 |
392 |
393 | if add_downsample:
394 | self.downsamplers = nn.ModuleList(
395 | [
396 | Downsample3D(
397 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
398 | )
399 | ]
400 | )
401 | else:
402 | self.downsamplers = None
403 |
404 | self.gradient_checkpointing = False
405 |
406 | def forward(
407 | self,
408 | hidden_states,
409 | temb=None,
410 | encoder_hidden_states=None,
411 | attention_mask=None,
412 | num_frames=1,
413 | ):
414 | # TODO(Patrick, William) - attention mask is not used
415 | output_states = ()
416 |
417 | for resnet, attn, motion_module in zip(
418 | self.resnets, self.attentions, self.motion_modules
419 | ):
420 | hidden_states = resnet(hidden_states, temb)
421 | hidden_states = attn(
422 | hidden_states,
423 | encoder_hidden_states=encoder_hidden_states,
424 | return_dict=False,
425 | )[0]
426 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
427 |
428 |
429 | output_states += (hidden_states,)
430 |
431 | if self.downsamplers is not None:
432 | for downsampler in self.downsamplers:
433 | hidden_states = downsampler(hidden_states)
434 |
435 | output_states += (hidden_states,)
436 |
437 | return hidden_states, output_states
438 |
439 |
440 | class DownBlock3D(nn.Module):
441 | def __init__(
442 | self,
443 | in_channels: int,
444 | out_channels: int,
445 | temb_channels: int,
446 | dropout: float = 0.0,
447 | num_layers: int = 1,
448 | resnet_eps: float = 1e-6,
449 | resnet_time_scale_shift: str = "default",
450 | resnet_act_fn: str = "swish",
451 | resnet_groups: int = 32,
452 | resnet_pre_norm: bool = True,
453 | output_scale_factor=1.0,
454 | add_downsample=True,
455 | downsample_padding=1,
456 | use_inflated_groupnorm=False,
457 | use_motion_module=True,
458 | motion_module_type=None,
459 | motion_module_kwargs=None,
460 | ):
461 | super().__init__()
462 | resnets = []
463 | motion_modules = []
464 |
465 | for i in range(num_layers):
466 | in_channels = in_channels if i == 0 else out_channels
467 | resnets.append(
468 | ResnetBlock3D(
469 | in_channels=in_channels,
470 | out_channels=out_channels,
471 | temb_channels=temb_channels,
472 | eps=resnet_eps,
473 | groups=resnet_groups,
474 | dropout=dropout,
475 | time_embedding_norm=resnet_time_scale_shift,
476 | non_linearity=resnet_act_fn,
477 | output_scale_factor=output_scale_factor,
478 | pre_norm=resnet_pre_norm,
479 |
480 | use_inflated_groupnorm=use_inflated_groupnorm,
481 | )
482 | )
483 | motion_modules.append(
484 | get_motion_module(
485 | in_channels=out_channels,
486 | motion_module_type=motion_module_type,
487 | motion_module_kwargs=motion_module_kwargs,
488 | ) if use_motion_module else None
489 | )
490 |
491 | self.resnets = nn.ModuleList(resnets)
492 | self.motion_modules = nn.ModuleList(motion_modules)
493 |
494 | if add_downsample:
495 | self.downsamplers = nn.ModuleList(
496 | [
497 | Downsample3D(
498 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
499 | )
500 | ]
501 | )
502 | else:
503 | self.downsamplers = None
504 |
505 | self.gradient_checkpointing = False
506 |
507 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
508 | output_states = ()
509 |
510 | for resnet, motion_module in zip(self.resnets, self.motion_modules):
511 | hidden_states = resnet(hidden_states, temb)
512 | # hidden_states = temp_conv(hidden_states, num_frames=num_frames)
513 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
514 |
515 | output_states += (hidden_states,)
516 |
517 | if self.downsamplers is not None:
518 | for downsampler in self.downsamplers:
519 | hidden_states = downsampler(hidden_states)
520 |
521 | output_states += (hidden_states,)
522 |
523 | return hidden_states, output_states
524 |
525 |
526 | class CrossAttnUpBlock3D(nn.Module):
527 | def __init__(
528 | self,
529 | in_channels: int,
530 | out_channels: int,
531 | prev_output_channel: int,
532 | temb_channels: int,
533 | dropout: float = 0.0,
534 | num_layers: int = 1,
535 | resnet_eps: float = 1e-6,
536 | resnet_time_scale_shift: str = "default",
537 | resnet_act_fn: str = "swish",
538 | resnet_groups: int = 32,
539 | resnet_pre_norm: bool = True,
540 | num_attention_heads=1,
541 | cross_attention_dim=1280,
542 | output_scale_factor=1.0,
543 | add_upsample=True,
544 | dual_cross_attention=False,
545 | use_linear_projection=False,
546 | only_cross_attention=False,
547 | upcast_attention=False,
548 |
549 | unet_use_cross_frame_attention=None,
550 | unet_use_temporal_attention=None,
551 | use_inflated_groupnorm=None,
552 |
553 | use_motion_module=None,
554 |
555 | motion_module_type=None,
556 | motion_module_kwargs=None,
557 | ):
558 | super().__init__()
559 | resnets = []
560 | attentions = []
561 | motion_modules = []
562 |
563 | self.has_cross_attention = True
564 | self.attn_num_attention_heads = num_attention_heads
565 |
566 | for i in range(num_layers):
567 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
568 | resnet_in_channels = prev_output_channel if i == 0 else out_channels
569 |
570 | resnets.append(
571 | ResnetBlock3D(
572 | in_channels=resnet_in_channels + res_skip_channels,
573 | out_channels=out_channels,
574 | temb_channels=temb_channels,
575 | eps=resnet_eps,
576 | groups=resnet_groups,
577 | dropout=dropout,
578 | time_embedding_norm=resnet_time_scale_shift,
579 | non_linearity=resnet_act_fn,
580 | output_scale_factor=output_scale_factor,
581 | pre_norm=resnet_pre_norm,
582 |
583 | use_inflated_groupnorm=use_inflated_groupnorm,
584 | )
585 | )
586 | attentions.append(
587 | Transformer3DModel(
588 | self.attn_num_attention_heads,
589 | out_channels // self.attn_num_attention_heads,
590 | in_channels=out_channels,
591 | num_layers=1,
592 | cross_attention_dim=cross_attention_dim,
593 | norm_num_groups=resnet_groups,
594 | # use_linear_projection=use_linear_projection,
595 | use_linear_projection=False,
596 | only_cross_attention=only_cross_attention,
597 | upcast_attention=upcast_attention,
598 |
599 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
600 | unet_use_temporal_attention=unet_use_temporal_attention,
601 | )
602 | )
603 | motion_modules.append(
604 | get_motion_module(
605 | in_channels=out_channels,
606 | motion_module_type=motion_module_type,
607 | motion_module_kwargs=motion_module_kwargs,
608 | ) if use_motion_module else None
609 | )
610 | self.resnets = nn.ModuleList(resnets)
611 | self.attentions = nn.ModuleList(attentions)
612 | self.motion_modules = nn.ModuleList(motion_modules)
613 |
614 | if add_upsample:
615 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
616 | else:
617 | self.upsamplers = None
618 |
619 | self.gradient_checkpointing = False
620 |
621 | def forward(
622 | self,
623 | hidden_states,
624 | res_hidden_states_tuple,
625 | temb=None,
626 | encoder_hidden_states=None,
627 | upsample_size=None,
628 | attention_mask=None,
629 | ):
630 | # TODO(Patrick, William) - attention mask is not used
631 | for resnet, attn, motion_module in zip(
632 | self.resnets, self.attentions, self.motion_modules
633 | ):
634 | # pop res hidden states
635 | res_hidden_states = res_hidden_states_tuple[-1]
636 | res_hidden_states_tuple = res_hidden_states_tuple[:-1]
637 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
638 |
639 | hidden_states = resnet(hidden_states, temb)
640 | hidden_states = attn(
641 | hidden_states,
642 | encoder_hidden_states=encoder_hidden_states,
643 | return_dict=False,
644 | )[0]
645 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
646 |
647 |
648 | if self.upsamplers is not None:
649 | for upsampler in self.upsamplers:
650 | hidden_states = upsampler(hidden_states, upsample_size)
651 |
652 | return hidden_states
653 |
654 |
655 | class UpBlock3D(nn.Module):
656 | def __init__(
657 | self,
658 | in_channels: int,
659 | prev_output_channel: int,
660 | out_channels: int,
661 | temb_channels: int,
662 | dropout: float = 0.0,
663 | num_layers: int = 1,
664 | resnet_eps: float = 1e-6,
665 | resnet_time_scale_shift: str = "default",
666 | resnet_act_fn: str = "swish",
667 | resnet_groups: int = 32,
668 | resnet_pre_norm: bool = True,
669 | output_scale_factor=1.0,
670 | add_upsample=True,
671 |
672 | use_inflated_groupnorm=None,
673 |
674 | use_motion_module=None,
675 | motion_module_type=None,
676 | motion_module_kwargs=None,
677 | ):
678 | super().__init__()
679 | resnets = []
680 | motion_modules = []
681 |
682 | for i in range(num_layers):
683 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
684 | resnet_in_channels = prev_output_channel if i == 0 else out_channels
685 |
686 | resnets.append(
687 | ResnetBlock3D(
688 | in_channels=resnet_in_channels + res_skip_channels,
689 | out_channels=out_channels,
690 | temb_channels=temb_channels,
691 | eps=resnet_eps,
692 | groups=resnet_groups,
693 | dropout=dropout,
694 | time_embedding_norm=resnet_time_scale_shift,
695 | non_linearity=resnet_act_fn,
696 | output_scale_factor=output_scale_factor,
697 | pre_norm=resnet_pre_norm,
698 |
699 | use_inflated_groupnorm=use_inflated_groupnorm,
700 |
701 | )
702 | )
703 | motion_modules.append(
704 | get_motion_module(
705 | in_channels=out_channels,
706 | motion_module_type=motion_module_type,
707 | motion_module_kwargs=motion_module_kwargs,
708 | ) if use_motion_module else None
709 | )
710 |
711 | self.resnets = nn.ModuleList(resnets)
712 | self.motion_modules = nn.ModuleList(motion_modules)
713 |
714 | if add_upsample:
715 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
716 | else:
717 | self.upsamplers = None
718 |
719 | self.gradient_checkpointing = False
720 |
721 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None):
722 | for resnet, motion_module in zip(self.resnets, self.motion_modules):
723 | # pop res hidden states
724 | res_hidden_states = res_hidden_states_tuple[-1]
725 | res_hidden_states_tuple = res_hidden_states_tuple[:-1]
726 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
727 |
728 | hidden_states = resnet(hidden_states, temb)
729 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
730 |
731 |
732 | if self.upsamplers is not None:
733 | for upsampler in self.upsamplers:
734 | hidden_states = upsampler(hidden_states, upsample_size)
735 |
736 | return hidden_states
737 |
--------------------------------------------------------------------------------
/animatediff/pipelines/__pycache__/pipeline_animation.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/animatediff/pipelines/__pycache__/pipeline_animation.cpython-310.pyc
--------------------------------------------------------------------------------
/animatediff/pipelines/pipeline_animation.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2 |
3 | import inspect
4 | from typing import Callable, List, Optional, Union
5 | from dataclasses import dataclass
6 | import cv2
7 |
8 | from PIL import Image
9 | import numpy as np
10 | import torch
11 | from tqdm import tqdm
12 |
13 | from diffusers.utils import is_accelerate_available
14 | from packaging import version
15 | from transformers import CLIPTextModel, CLIPTokenizer
16 |
17 | from diffusers.configuration_utils import FrozenDict
18 | from diffusers.models import AutoencoderKL
19 | from diffusers.pipeline_utils import DiffusionPipeline
20 | from diffusers.schedulers import (
21 | DDIMScheduler,
22 | DPMSolverMultistepScheduler,
23 | EulerAncestralDiscreteScheduler,
24 | EulerDiscreteScheduler,
25 | LMSDiscreteScheduler,
26 | PNDMScheduler,
27 | )
28 | from diffusers.models import ControlNetModel
29 | from diffusers.image_processor import VaeImageProcessor
30 | from diffusers.utils import deprecate, logging, BaseOutput
31 |
32 | from einops import rearrange
33 |
34 | from ..models.unet import UNet3DConditionModel
35 |
36 |
37 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38 |
39 |
40 |
41 | def prepare_image(
42 | image,
43 | width,
44 | height,
45 | batch_size,
46 | num_images_per_prompt,
47 | device,
48 | dtype,
49 | do_classifier_free_guidance=False,
50 | guess_mode=False,
51 | ):
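    |     # preprocess the conditioning image to the target size, repeat it to the batch size,
    |     # and duplicate it when classifier-free guidance needs conditional and unconditional passes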
52 | control_image_processor = VaeImageProcessor()
53 | image = control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
54 | image_batch_size = batch_size
55 |
56 | if image_batch_size == 1:
57 | repeat_by = batch_size
58 | else:
59 | # image batch size is the same as prompt batch size
60 | repeat_by = num_images_per_prompt
61 |
62 | image = image.repeat_interleave(repeat_by, dim=0)
63 |
64 | image = image.to(device=device, dtype=dtype)
65 |
66 | if do_classifier_free_guidance and not guess_mode:
67 | image = torch.cat([image] * 2)
68 |
69 | return image
70 |
71 |
72 | @dataclass
73 | class AnimationPipelineOutput(BaseOutput):
74 | videos: Union[torch.Tensor, np.ndarray]
75 |
76 |
77 | class AnimationPipeline(DiffusionPipeline):
78 | _optional_components = []
79 |
80 | def __init__(
81 | self,
82 | vae: AutoencoderKL,
83 | text_encoder: CLIPTextModel,
84 | tokenizer: CLIPTokenizer,
85 | unet: UNet3DConditionModel,
86 | scheduler: Union[
87 | DDIMScheduler,
88 | PNDMScheduler,
89 | LMSDiscreteScheduler,
90 | EulerDiscreteScheduler,
91 | EulerAncestralDiscreteScheduler,
92 | DPMSolverMultistepScheduler,
93 | ],
94 | controlnet: ControlNetModel,
95 | ):
96 | super().__init__()
97 |
98 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
99 | deprecation_message = (
100 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
101 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
102 |                 "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
103 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
104 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
105 | " file"
106 | )
107 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
108 | new_config = dict(scheduler.config)
109 | new_config["steps_offset"] = 1
110 | scheduler._internal_dict = FrozenDict(new_config)
111 |
112 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
113 | deprecation_message = (
114 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
115 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
116 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
117 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
118 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
119 | )
120 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
121 | new_config = dict(scheduler.config)
122 | new_config["clip_sample"] = False
123 | scheduler._internal_dict = FrozenDict(new_config)
124 |
125 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
126 | version.parse(unet.config._diffusers_version).base_version
127 | ) < version.parse("0.9.0.dev0")
128 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
129 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
130 | deprecation_message = (
131 | "The configuration file of the unet has set the default `sample_size` to smaller than"
132 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
133 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
134 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
135 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
136 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
137 | " in the config might lead to incorrect results in future versions. If you have downloaded this"
138 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
139 | " the `unet/config.json` file"
140 | )
141 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
142 | new_config = dict(unet.config)
143 | new_config["sample_size"] = 64
144 | unet._internal_dict = FrozenDict(new_config)
145 |
146 | self.register_modules(
147 | vae=vae,
148 | text_encoder=text_encoder,
149 | tokenizer=tokenizer,
150 | unet=unet,
151 | scheduler=scheduler,
152 | controlnet=controlnet
153 | )
154 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
155 |
156 | def enable_vae_slicing(self):
157 | self.vae.enable_slicing()
158 |
159 | def disable_vae_slicing(self):
160 | self.vae.disable_slicing()
161 |
162 | def enable_sequential_cpu_offload(self, gpu_id=0):
163 | if is_accelerate_available():
164 | from accelerate import cpu_offload
165 | else:
166 | raise ImportError("Please install accelerate via `pip install accelerate`")
167 |
168 | device = torch.device(f"cuda:{gpu_id}")
169 |
170 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
171 | if cpu_offloaded_model is not None:
172 | cpu_offload(cpu_offloaded_model, device)
173 |
174 |
175 | @property
176 | def _execution_device(self):
177 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
178 | return self.device
179 | for module in self.unet.modules():
180 | if (
181 | hasattr(module, "_hf_hook")
182 | and hasattr(module._hf_hook, "execution_device")
183 | and module._hf_hook.execution_device is not None
184 | ):
185 | return torch.device(module._hf_hook.execution_device)
186 | return self.device
187 |
188 | def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
189 | batch_size = len(prompt) if isinstance(prompt, list) else 1
190 |
191 | text_inputs = self.tokenizer(
192 | prompt,
193 | padding="max_length",
194 | max_length=self.tokenizer.model_max_length,
195 | truncation=True,
196 | return_tensors="pt",
197 | )
198 | text_input_ids = text_inputs.input_ids
199 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
200 |
201 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
202 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
203 | logger.warning(
204 | "The following part of your input was truncated because CLIP can only handle sequences up to"
205 | f" {self.tokenizer.model_max_length} tokens: {removed_text}"
206 | )
207 |
208 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
209 | attention_mask = text_inputs.attention_mask.to(device)
210 | else:
211 | attention_mask = None
212 |
213 | text_embeddings = self.text_encoder(
214 | text_input_ids.to(device),
215 | attention_mask=attention_mask,
216 | )
217 | text_embeddings = text_embeddings[0]
218 |
219 | # duplicate text embeddings for each generation per prompt, using mps friendly method
220 | bs_embed, seq_len, _ = text_embeddings.shape
221 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
222 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
223 |
224 | # get unconditional embeddings for classifier free guidance
225 | if do_classifier_free_guidance:
226 | uncond_tokens: List[str]
227 | if negative_prompt is None:
228 | uncond_tokens = [""] * batch_size
229 | elif type(prompt) is not type(negative_prompt):
230 | raise TypeError(
231 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
232 | f" {type(prompt)}."
233 | )
234 | elif isinstance(negative_prompt, str):
235 | uncond_tokens = [negative_prompt]
236 | elif batch_size != len(negative_prompt):
237 | raise ValueError(
238 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
239 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
240 | " the batch size of `prompt`."
241 | )
242 | else:
243 | uncond_tokens = negative_prompt
244 |
245 | max_length = text_input_ids.shape[-1]
246 | uncond_input = self.tokenizer(
247 | uncond_tokens,
248 | padding="max_length",
249 | max_length=max_length,
250 | truncation=True,
251 | return_tensors="pt",
252 | )
253 |
254 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
255 | attention_mask = uncond_input.attention_mask.to(device)
256 | else:
257 | attention_mask = None
258 |
259 | uncond_embeddings = self.text_encoder(
260 | uncond_input.input_ids.to(device),
261 | attention_mask=attention_mask,
262 | )
263 | uncond_embeddings = uncond_embeddings[0]
264 |
265 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
266 | seq_len = uncond_embeddings.shape[1]
267 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
268 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
269 |
270 | # For classifier free guidance, we need to do two forward passes.
271 | # Here we concatenate the unconditional and text embeddings into a single batch
272 | # to avoid doing two forward passes
273 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
274 |
275 | return text_embeddings
276 |
277 | def decode_latents(self, latents):
278 | video_length = latents.shape[2]
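    |         # 0.18215 is the Stable Diffusion v1 VAE scaling factor; undo it before decoding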
279 | latents = 1 / 0.18215 * latents
280 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
281 | # video = self.vae.decode(latents).sample
282 | video = []
283 | for frame_idx in tqdm(range(latents.shape[0])):
284 | video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
285 | video = torch.cat(video)
286 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
287 | video = (video / 2 + 0.5).clamp(0, 1)
288 |         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
289 | video = video.cpu().float().numpy()
290 | return video
291 |
292 | def prepare_extra_step_kwargs(self, generator, eta):
293 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
294 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
295 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
296 | # and should be between [0, 1]
297 |
298 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
299 | extra_step_kwargs = {}
300 | if accepts_eta:
301 | extra_step_kwargs["eta"] = eta
302 |
303 | # check if the scheduler accepts generator
304 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
305 | if accepts_generator:
306 | extra_step_kwargs["generator"] = generator
307 | return extra_step_kwargs
308 |
309 | def check_inputs(self, prompt, height, width, callback_steps):
310 | if not isinstance(prompt, str) and not isinstance(prompt, list):
311 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
312 |
313 | if height % 8 != 0 or width % 8 != 0:
314 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
315 |
316 | if (callback_steps is None) or (
317 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
318 | ):
319 | raise ValueError(
320 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
321 | f" {type(callback_steps)}."
322 | )
323 |
324 | def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
325 | shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
326 | if isinstance(generator, list) and len(generator) != batch_size:
327 | raise ValueError(
328 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
329 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
330 | )
331 | if latents is None:
332 | rand_device = "cpu" if device.type == "mps" else device
333 |
334 | if isinstance(generator, list):
335 | shape = shape
336 | # shape = (1,) + shape[1:]
337 | latents = [
338 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
339 | for i in range(batch_size)
340 | ]
341 | latents = torch.cat(latents, dim=0).to(device)
342 | else:
343 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
344 | else:
345 | rand_device = "cpu" if device.type == "mps" else device
346 | if latents.shape != shape:
347 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
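    |             # note: provided latents are only shape-checked here; fresh noise is sampled and used instead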
348 | noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
349 | latents = noise
350 | latents = latents.to(device)
351 |
352 | # scale the initial noise by the standard deviation required by the scheduler
353 | latents = latents * self.scheduler.init_noise_sigma
354 | return latents
355 |
356 | @torch.no_grad()
357 | def __call__(
358 | self,
359 | prompt: Union[str, List[str]],
360 | video_length: Optional[int],
361 | height: Optional[int] = None,
362 | width: Optional[int] = None,
363 | num_inference_steps: int = 50,
364 | guidance_scale: float = 7.5,
365 | negative_prompt: Optional[Union[str, List[str]]] = None,
366 | num_videos_per_prompt: Optional[int] = 1,
367 | eta: float = 0.0,
368 | controlnet_image: torch.FloatTensor = None,
369 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
370 | latents: Optional[torch.FloatTensor] = None,
371 | output_type: Optional[str] = "tensor",
372 | return_dict: bool = True,
373 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
374 | callback_steps: Optional[int] = 1,
375 | **kwargs,
376 | ):
377 | # Default height and width to unet
378 | height = height or self.unet.config.sample_size * self.vae_scale_factor
379 | width = width or self.unet.config.sample_size * self.vae_scale_factor
380 |
381 | # Check inputs. Raise error if not correct
382 | self.check_inputs(prompt, height, width, callback_steps)
383 |
384 | # Define call parameters
385 | # batch_size = 1 if isinstance(prompt, str) else len(prompt)
386 | batch_size = 1
387 | if latents is not None:
388 | batch_size = latents.shape[0]
389 | if isinstance(prompt, list):
390 | batch_size = len(prompt)
391 |
392 | device = self._execution_device
393 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
394 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
395 | # corresponds to doing no classifier free guidance.
396 | do_classifier_free_guidance = guidance_scale > 1.0
397 |
398 | # Encode input prompt
399 | prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
400 | if negative_prompt is not None:
401 | negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
402 | text_embeddings = self._encode_prompt(
403 | prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
404 | )
405 |
406 | # Prepare timesteps
407 | self.scheduler.set_timesteps(num_inference_steps, device=device)
408 | timesteps = self.scheduler.timesteps
409 |
410 | # Prepare latent variables
411 | num_channels_latents = self.unet.in_channels
412 |
413 | latents = self.prepare_latents(
414 | batch_size * num_videos_per_prompt,
415 | num_channels_latents,
416 | video_length,
417 | height,
418 | width,
419 | text_embeddings.dtype,
420 | device,
421 | generator,
422 | latents,
423 | )
424 | latents_dtype = latents.dtype
425 |
426 | #---------------
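    |         # preprocess the initial image once for ControlNet conditioning; it is reused at every denoising step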
427 | image = prepare_image(
428 | image=controlnet_image,
429 | width=controlnet_image.shape[-1],
430 | height=controlnet_image.shape[-2],
431 | batch_size=1,
432 | num_images_per_prompt=1,
433 | device="cuda",
434 | dtype=self.controlnet.dtype,
435 | )
436 |
437 | #---------------
438 |
439 | # Prepare extra step kwargs.
440 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
441 |
442 | # Denoising loop
443 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
444 | with self.progress_bar(total=num_inference_steps) as progress_bar:
445 | for i, t in enumerate(timesteps):
446 | # expand the latents if we are doing classifier free guidance
447 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
448 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
449 |
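    |                 # run the 2D ControlNet on the first-frame latent slice only ([:, :, 0, :, :]),
    |                 # conditioned on the preprocessed initial image; the resulting residuals are
    |                 # passed to the 3D UNet below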
450 | down_block_res_samples, mid_block_res_sample = self.controlnet(
451 | sample=latent_model_input[:,:,0,:,:],
452 | timestep=t,
453 | encoder_hidden_states=text_embeddings, # [4,77,768]
454 | controlnet_cond=image,
455 | return_dict=False,
456 | )
457 |
458 | # predict the noise residual
459 | noise_pred = self.unet(latent_model_input,
460 | t,
461 | encoder_hidden_states=text_embeddings,
462 | down_block_additional_residuals=down_block_res_samples,
463 | mid_block_additional_residual=mid_block_res_sample,
464 | ).sample.to(dtype=latents_dtype)
465 |
466 | # perform guidance
467 | if do_classifier_free_guidance:
468 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
469 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
470 |
471 | # compute the previous noisy sample x_t -> x_t-1
472 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
473 |
474 | # call the callback, if provided
475 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
476 | progress_bar.update()
477 | if callback is not None and i % callback_steps == 0:
478 | callback(i, t, latents)
479 |
480 | # Post-processing
481 | video = self.decode_latents(latents)
482 |
483 | # Convert to tensor
484 | if output_type == "tensor":
485 | video = torch.from_numpy(video)
486 |
487 | if not return_dict:
488 | return video
489 |
490 | return AnimationPipelineOutput(videos=video)
491 |
492 |
--------------------------------------------------------------------------------
/animatediff/utils/__pycache__/convert_from_ckpt.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/animatediff/utils/__pycache__/convert_from_ckpt.cpython-310.pyc
--------------------------------------------------------------------------------
/animatediff/utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/animatediff/utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-310.pyc
--------------------------------------------------------------------------------
/animatediff/utils/__pycache__/util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/animatediff/utils/__pycache__/util.cpython-310.pyc
--------------------------------------------------------------------------------
/animatediff/utils/convert_lora_safetensor_to_diffusers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """ Conversion script for the LoRA's safetensors checkpoints. """
17 |
18 | import argparse
19 |
20 | import torch
21 | from safetensors.torch import load_file
22 |
23 | from diffusers import StableDiffusionPipeline
24 | import pdb
25 |
26 | def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
27 | # load base model
28 | # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
29 |
30 | # load LoRA weight from .safetensors
31 | # state_dict = load_file(checkpoint_path)
32 |
33 | visited = []
34 |
35 | # directly update weight in diffusers model
36 | for key in state_dict:
37 | # it is suggested to print out the key, it usually will be something like below
38 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
39 |
40 | # as we have set the alpha beforehand, so just skip
41 | if ".alpha" in key or key in visited:
42 | continue
43 |
44 | if "text" in key:
45 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
46 | curr_layer = pipeline.text_encoder
47 | else:
48 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
49 | curr_layer = pipeline.unet
50 |
51 | # find the target layer
52 | temp_name = layer_infos.pop(0)
53 | while len(layer_infos) > -1:
54 | try:
55 | curr_layer = curr_layer.__getattr__(temp_name)
56 | if len(layer_infos) > 0:
57 | temp_name = layer_infos.pop(0)
58 | elif len(layer_infos) == 0:
59 | break
60 | except Exception:
61 | if len(temp_name) > 0:
62 | temp_name += "_" + layer_infos.pop(0)
63 | else:
64 | temp_name = layer_infos.pop(0)
65 |
66 | pair_keys = []
67 | if "lora_down" in key:
68 | pair_keys.append(key.replace("lora_down", "lora_up"))
69 | pair_keys.append(key)
70 | else:
71 | pair_keys.append(key)
72 | pair_keys.append(key.replace("lora_up", "lora_down"))
73 |
74 | # update weight
75 | if len(state_dict[pair_keys[0]].shape) == 4:
76 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
77 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
78 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
79 | else:
80 | weight_up = state_dict[pair_keys[0]].to(torch.float32)
81 | weight_down = state_dict[pair_keys[1]].to(torch.float32)
82 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
83 |
84 | # update visited list
85 | for item in pair_keys:
86 | visited.append(item)
87 |
88 | return pipeline
89 |
90 |
91 | if __name__ == "__main__":
92 | parser = argparse.ArgumentParser()
93 |
94 | parser.add_argument(
95 | "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
96 | )
97 | parser.add_argument(
98 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
99 | )
100 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
101 | parser.add_argument(
102 | "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
103 | )
104 | parser.add_argument(
105 | "--lora_prefix_text_encoder",
106 | default="lora_te",
107 | type=str,
108 | help="The prefix of text encoder weight in safetensors",
109 | )
110 | parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
111 | parser.add_argument(
112 | "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
113 | )
114 | parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
115 |
116 | args = parser.parse_args()
117 |
118 | base_model_path = args.base_model_path
119 | checkpoint_path = args.checkpoint_path
120 | dump_path = args.dump_path
121 | lora_prefix_unet = args.lora_prefix_unet
122 | lora_prefix_text_encoder = args.lora_prefix_text_encoder
123 | alpha = args.alpha
124 |
125 |     pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
    |     state_dict = load_file(checkpoint_path)
    |     pipe = convert_lora(pipe, state_dict, lora_prefix_unet, lora_prefix_text_encoder, alpha)
126 |
127 | pipe = pipe.to(args.device)
128 | pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
129 |
--------------------------------------------------------------------------------
/animatediff/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import imageio
3 | import numpy as np
4 | from typing import Union
5 |
6 | import torch
7 | import torchvision
8 | import torch.distributed as dist
9 |
10 | from tqdm import tqdm
11 | from einops import rearrange
12 |
13 |
14 | def zero_rank_print(s):
15 |     if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
16 |
17 |
18 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
19 | videos = rearrange(videos, "b c t h w -> t b c h w")
20 | outputs = []
21 | for x in videos:
22 | x = torchvision.utils.make_grid(x, nrow=n_rows)
23 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
24 | if rescale:
25 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
26 | x = (x * 255).numpy().astype(np.uint8)
27 | outputs.append(x)
28 |
29 | os.makedirs(os.path.dirname(path), exist_ok=True)
30 | imageio.mimsave(path, outputs, fps=fps)
31 |
32 |
33 | # DDIM Inversion
34 | @torch.no_grad()
35 | def init_prompt(prompt, pipeline):
36 | uncond_input = pipeline.tokenizer(
37 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
38 | return_tensors="pt"
39 | )
40 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
41 | text_input = pipeline.tokenizer(
42 | [prompt],
43 | padding="max_length",
44 | max_length=pipeline.tokenizer.model_max_length,
45 | truncation=True,
46 | return_tensors="pt",
47 | )
48 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
49 | context = torch.cat([uncond_embeddings, text_embeddings])
50 |
51 | return context
52 |
53 |
54 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
55 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
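    |     # one reverse DDIM step for inversion: predict x0 from the current sample and noise estimate,
    |     # then re-noise it to the next (larger) timestep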
56 | timestep, next_timestep = min(
57 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
58 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
59 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
60 | beta_prod_t = 1 - alpha_prod_t
61 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
62 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
63 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
64 | return next_sample
65 |
66 |
67 | def get_noise_pred_single(latents, t, context, unet):
68 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
69 | return noise_pred
70 |
71 |
72 | @torch.no_grad()
73 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
74 | context = init_prompt(prompt, pipeline)
75 | uncond_embeddings, cond_embeddings = context.chunk(2)
76 | all_latent = [latent]
77 | latent = latent.clone().detach()
78 | for i in tqdm(range(num_inv_steps)):
79 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
80 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
81 | latent = next_step(noise_pred, t, latent, ddim_scheduler)
82 | all_latent.append(latent)
83 | return all_latent
84 |
85 |
86 | @torch.no_grad()
87 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
88 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
89 | return ddim_latents
90 |
--------------------------------------------------------------------------------
/animatetest.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import inspect
4 | import os
5 | from omegaconf import OmegaConf
6 | from PIL import Image
7 | import numpy as np
8 |
9 | import torch
10 | from torchvision import models
11 | from torch.nn import functional as F
12 | import torchvision.transforms as transforms
13 |
14 | import diffusers
15 | from diffusers import AutoencoderKL, DDIMScheduler
16 | import pickle
17 |
18 | from tqdm.auto import tqdm
19 | from transformers import CLIPTextModel, CLIPTokenizer
20 |
21 | # import sys
22 | # sys.path.append("/root/AnimateDiffcontrolnet-main/")
23 |
24 | from animatediff.models.unet import UNet3DConditionModel
25 | from animatediff.models.controlnet import ControlNetModel
26 | from animatediff.pipelines.pipeline_animation import AnimationPipeline
27 | from animatediff.utils.util import save_videos_grid
28 | from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
29 | from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
30 | from diffusers.utils.import_utils import is_xformers_available
31 |
32 | from einops import rearrange, repeat
33 |
34 | import csv, pdb, glob
35 | from safetensors import safe_open
36 | import math
37 | from pathlib import Path
38 |
39 |
40 | def main(args):
41 | *_, func_args = inspect.getargvalues(inspect.currentframe())
42 | func_args = dict(func_args)
43 |
44 | time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
45 | savedir = f"samples/{Path(args.config).stem}-{time_str}"
46 | os.makedirs(savedir)
47 |
48 | config = OmegaConf.load(args.config)
49 | samples = []
50 |
51 | sample_idx = 0
52 | for model_idx, (config_key, model_config) in enumerate(list(config.items())):
53 |
54 | motion_modules = model_config.motion_module
55 | motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
56 | for motion_module in motion_modules:
57 | inference_config = OmegaConf.load(model_config.get("inference_config", args.inference_config))
58 |
59 | ### >>> create validation pipeline >>> ###
60 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
61 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
62 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
63 | unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
64 | controlnet = ControlNetModel()
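    |             # ControlNet is built with its default config; weights are loaded from the checkpoint below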
65 |
66 |
67 | # if is_xformers_available(): unet.enable_xformers_memory_efficient_attention()
68 | # else: assert False
69 |
70 | pipeline = AnimationPipeline(
71 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet,
72 | scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
73 | ).to("cuda")
74 |
75 | # 0. controlnet ckpt
76 | controlnet_state_dict = torch.load("./checkpoints/controlnet_checkpoint-epoch-30.ckpt", map_location="cpu")
77 | missing, unexpected = pipeline.controlnet.load_state_dict(controlnet_state_dict["state_dict"], strict=False)
78 | assert len(unexpected) == 0
79 |
80 |
81 | # 1. unet ckpt
82 | # 1.1 motion module
83 | motion_module_state_dict = torch.load(motion_module, map_location="cpu")
84 | if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
85 | missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
86 | assert len(unexpected) == 0
87 |
88 |             # 1.2 other fine-tuned T2I base models
89 | if model_config.path != "":
90 | if model_config.path.endswith(".ckpt"):
91 | state_dict = torch.load(model_config.path)
92 | pipeline.unet.load_state_dict(state_dict)
93 |
94 | elif model_config.path.endswith(".safetensors"):
95 | state_dict = {}
96 | with safe_open(model_config.path, framework="pt", device="cpu") as f:
97 | for key in f.keys():
98 | state_dict[key] = f.get_tensor(key)
99 |
100 | is_lora = all("lora" in k for k in state_dict.keys())
101 | if not is_lora:
102 | base_state_dict = state_dict
103 | else:
104 | base_state_dict = {}
105 | with safe_open(model_config.base, framework="pt", device="cpu") as f:
106 | for key in f.keys():
107 | base_state_dict[key] = f.get_tensor(key)
108 |
109 | # vae
110 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
111 | pipeline.vae.load_state_dict(converted_vae_checkpoint)
112 | # unet
113 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
114 | pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
115 | # text_model
116 | pipeline.text_encoder = convert_ldm_clip_checkpoint(pipeline.text_encoder, base_state_dict)
117 |
118 | # import pdb
119 | # pdb.set_trace()
120 | if is_lora:
121 | pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)
122 |
123 | pipeline.to("cuda")
124 | ### <<< create validation pipeline <<< ###
125 |
126 | prompts = model_config.prompt
127 | n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt
128 |
129 | random_seeds = model_config.get("seed", [-1])
130 | random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
131 | random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
132 |
133 | config[config_key].random_seed = []
134 |
135 | #------------------------------------------------
136 | pixel_transforms = transforms.Compose([
137 | # transforms.RandomHorizontalFlip(),
138 | transforms.Resize(512),
139 | transforms.CenterCrop(512),
140 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
141 | ])
142 | # --------------------------------------------------
143 |
144 | for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
145 | init_image = model_config.init_image[prompt_idx]
146 |
147 | pixel_values = Image.open(init_image)
148 | pixel_values = np.array(pixel_values)
149 | pixel_values = torch.from_numpy(pixel_values).permute(2,0,1).unsqueeze(0)
150 | pixel_values = pixel_values / 255.
151 | pixel_values = pixel_transforms(pixel_values)
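    |                 # Normalize maps values to [-1, 1]; map back to [0, 1] before passing the image to the pipeline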
152 | pixel_values = (pixel_values / 2 + 0.5).clamp(0, 1)
153 |
154 | # manually set random seed for reproduction
155 | if random_seed != -1: torch.manual_seed(random_seed)
156 | else: torch.seed()
157 | config[config_key].random_seed.append(torch.initial_seed())
158 |
159 | print(f"current seed: {torch.initial_seed()}")
160 | print(f"sampling {prompt} ...")
161 | sample = pipeline(
162 | prompt,
163 | negative_prompt = n_prompt,
164 | num_inference_steps = model_config.steps,
165 | guidance_scale = model_config.guidance_scale,
166 | width = args.W,
167 | height = args.H,
168 | video_length = args.L,
169 | controlnet_image = pixel_values,
170 | ).videos
171 | samples.append(sample)
172 |
173 | prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
174 | save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
175 |                 print(f"save to {savedir}/sample/{sample_idx}-{prompt}.gif")
176 |
177 | sample_idx += 1
178 |
179 | samples = torch.concat(samples)
180 | save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)
181 |
182 | OmegaConf.save(config, f"{savedir}/config.yaml")
183 |
184 |
185 | if __name__ == "__main__":
186 | parser = argparse.ArgumentParser()
187 | parser.add_argument("--pretrained_model_path", type=str, default="./models/stable-diffusion-v1-5",)
188 | parser.add_argument("--inference_config", type=str, default="configs/inference/inference-v2.yaml")
189 | parser.add_argument("--config", type=str, default="configs/prompts/v2/5-RealisticVision2.yaml")
190 |
191 | parser.add_argument("--L", type=int, default=16 )
192 | parser.add_argument("--W", type=int, default=512)
193 | parser.add_argument("--H", type=int, default=512)
194 |
195 | args = parser.parse_args()
196 | main(args)
197 |
--------------------------------------------------------------------------------
/configs/inference/inference-v1.yaml:
--------------------------------------------------------------------------------
1 | unet_additional_kwargs:
2 | unet_use_cross_frame_attention: false
3 | unet_use_temporal_attention: false
4 | use_motion_module: true
5 | motion_module_resolutions:
6 | - 1
7 | - 2
8 | - 4
9 | - 8
10 | motion_module_mid_block: false
11 | motion_module_decoder_only: false
12 | motion_module_type: Vanilla
13 | motion_module_kwargs:
14 | num_attention_heads: 8
15 | num_transformer_block: 1
16 | attention_block_types:
17 | - Temporal_Self
18 | - Temporal_Self
19 | temporal_position_encoding: true
20 | temporal_position_encoding_max_len: 24
21 | temporal_attention_dim_div: 1
22 |
23 | noise_scheduler_kwargs:
24 | beta_start: 0.00085
25 | beta_end: 0.012
26 | beta_schedule: "linear"
27 |
--------------------------------------------------------------------------------
/configs/inference/inference-v2.yaml:
--------------------------------------------------------------------------------
1 | unet_additional_kwargs:
2 | use_inflated_groupnorm: true
3 | unet_use_cross_frame_attention: false
4 | unet_use_temporal_attention: false
5 | use_motion_module: true
6 | motion_module_resolutions:
7 | - 1
8 | - 2
9 | - 4
10 | - 8
11 | motion_module_mid_block: true
12 | motion_module_decoder_only: false
13 | motion_module_type: Vanilla
14 | motion_module_kwargs:
15 | num_attention_heads: 8
16 | num_transformer_block: 1
17 | attention_block_types:
18 | - Temporal_Self
19 | - Temporal_Self
20 | temporal_position_encoding: true
21 | temporal_position_encoding_max_len: 32
22 | temporal_attention_dim_div: 1
23 |
24 | noise_scheduler_kwargs:
25 | beta_start: 0.00085
26 | beta_end: 0.012
27 | beta_schedule: "linear"
28 |
--------------------------------------------------------------------------------
/configs/prompts/v2/5-RealisticVision.yaml:
--------------------------------------------------------------------------------
1 | RealisticVision:
2 | base: ""
3 | path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
4 |
5 | inference_config: "configs/inference/inference-v2.yaml"
6 | motion_module:
7 | - "models/Motion_Module/mm_sd_v15_v2.ckpt"
8 |
9 | seed: [0]
10 | steps: 25
11 | guidance_scale: 7.5
12 |
13 | prompt:
14 | - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
15 | - "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
16 | - "beach, large rocks, waves, cloudy sky, dark clouds, flowing water"
17 | - "scuba diving, coral reef, fish, sea anemones, starfish, sea turtles, clear water, sunlight, underwater world"
18 | - "fireworks, new year, 2023, night sky, stars, ring of fire"
19 | - "bird, small, brown, back, wings, white, chest, belly, black, beak, eyes, tree, deciduous, shrub, green, leaves, sky, clouds"
20 |
21 | n_prompt:
22 | - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
23 | - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
24 | - ""
25 | - ""
26 | - ""
27 | - ""
28 |
29 | init_image:
30 | - "/root/lh/AnimateDiffcontrolnet-main/init_images/0.jpg"
31 | - "/root/lh/AnimateDiffcontrolnet-main/init_images/1.jpg"
32 | - "/root/lh/AnimateDiffcontrolnet-main/init_images/2.jpg"
33 | - "/root/lh/AnimateDiffcontrolnet-main/init_images/3.jpg"
34 | - "/root/lh/AnimateDiffcontrolnet-main/init_images/4.jpg"
35 | - "/root/lh/AnimateDiffcontrolnet-main/init_images/5.jpg"
36 |
37 |
--------------------------------------------------------------------------------
/configs/training/image_finetune.yaml:
--------------------------------------------------------------------------------
1 | image_finetune: true
2 |
3 | output_dir: "outputs"
4 | pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5"
5 |
6 | noise_scheduler_kwargs:
7 | num_train_timesteps: 1000
8 | beta_start: 0.00085
9 | beta_end: 0.012
10 | beta_schedule: "scaled_linear"
11 | steps_offset: 1
12 | clip_sample: false
13 |
14 | train_data:
15 | csv_path: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv"
16 | video_folder: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val"
17 | sample_size: 256
18 |
19 | validation_data:
20 | prompts:
21 | - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons."
22 | - "A drone view of celebration with Christma tree and fireworks, starry sky - background."
23 | - "Robot dancing in times square."
24 | - "Pacific coast, carmel by the sea ocean and waves."
25 | num_inference_steps: 25
26 | guidance_scale: 8.
27 |
28 | trainable_modules:
29 | - "."
30 |
31 | unet_checkpoint_path: ""
32 |
33 | learning_rate: 1.e-5
34 | train_batch_size: 50
35 |
36 | max_train_epoch: -1
37 | max_train_steps: 100
38 | checkpointing_epochs: -1
39 | checkpointing_steps: 60
40 |
41 | validation_steps: 5000
42 | validation_steps_tuple: [2, 50]
43 |
44 | global_seed: 42
45 | mixed_precision_training: true
46 | enable_xformers_memory_efficient_attention: True
47 |
48 | is_debug: False
49 |
--------------------------------------------------------------------------------
/configs/training/training.yaml:
--------------------------------------------------------------------------------
1 | image_finetune: false
2 |
3 | output_dir: "outputs"
4 | pretrained_model_path: "models/stable-diffusion-v1-5"
5 |
6 | unet_additional_kwargs:
7 | use_motion_module : true
8 | motion_module_resolutions : [ 1,2,4,8 ]
9 | unet_use_cross_frame_attention : false
10 | unet_use_temporal_attention : false
11 |
12 | motion_module_type: Vanilla
13 | motion_module_kwargs:
14 | num_attention_heads : 8
15 | num_transformer_block : 1
16 | attention_block_types : [ "Temporal_Self", "Temporal_Self" ]
17 | temporal_position_encoding : true
18 | # temporal_position_encoding_max_len : 24
19 | temporal_position_encoding_max_len : 32
20 | temporal_attention_dim_div : 1
21 | zero_initialize : true
22 |
23 | noise_scheduler_kwargs:
24 | num_train_timesteps: 1000
25 | beta_start: 0.00085
26 | beta_end: 0.012
27 | beta_schedule: "linear"
28 | steps_offset: 1
29 | clip_sample: false
30 |
31 | train_data:
32 | csv_path: "./results_2M_train_new.csv"
33 | video_folder: "./datasets_train"
34 | sample_size: 256
35 | sample_stride: 4
36 | sample_n_frames: 16
37 |
38 | validation_data:
39 | prompts:
40 | - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons."
41 | - "A drone view of celebration with Christma tree and fireworks, starry sky - background."
42 | - "Robot dancing in times square."
43 | - "Pacific coast, carmel by the sea ocean and waves."
44 | num_inference_steps: 25
45 | guidance_scale: 8.
46 |
47 | trainable_modules:
48 | - "motion_modules."
49 |
50 | # unet_checkpoint_path: ""
51 | unet_checkpoint_path: "models/Motion_Module/mm_sd_v15_v2.ckpt"
52 |
53 |
54 | learning_rate: 1.e-4
55 | train_batch_size: 12
56 |
57 | max_train_epoch: 30
58 | max_train_steps: -1
59 | checkpointing_epochs: -1
60 | # number of steps between checkpoint saves
61 | checkpointing_steps: 2000
62 |
63 | validation_steps: 5000
64 | validation_steps_tuple: [1,1000, 5000, 10000]
65 |
66 | global_seed: 42
67 | mixed_precision_training: true
68 | enable_xformers_memory_efficient_attention: True
69 |
70 | is_debug: False
71 |
--------------------------------------------------------------------------------
/download_data.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import requests
3 | import concurrent.futures
4 | import os
5 |
6 | def download_video(row):
7 | video_url = row['contentUrl']
8 | video_name = row['videoid']
9 |     folder_path = './datasets_train' # replace this path with the path to your dataset folder
10 | video_file = os.path.join(folder_path, f'{video_name}.mp4')
11 |
12 |     # skip the download if the file already exists
13 | if os.path.isfile(video_file):
14 | return
15 |
16 | response = requests.get(video_url)
17 |
18 | if response.status_code == 200:
19 | with open(video_file, 'wb') as f:
20 | f.write(response.content)
21 | else:
22 | print(f"Failed to download video {video_name} from url {video_url}")
23 |
24 | df = pd.read_csv('results_2M_train.csv')
25 |
26 | rows = df.to_dict('records')
27 |
28 | with concurrent.futures.ThreadPoolExecutor() as executor:
29 | for row in rows:
30 | executor.submit(download_video, row)
--------------------------------------------------------------------------------
/imgs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/imgs/.DS_Store
--------------------------------------------------------------------------------
/imgs/0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/imgs/0.gif
--------------------------------------------------------------------------------
/imgs/1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/imgs/1.gif
--------------------------------------------------------------------------------
/imgs/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/imgs/2.gif
--------------------------------------------------------------------------------
/imgs/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/imgs/3.gif
--------------------------------------------------------------------------------
/imgs/4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/imgs/4.gif
--------------------------------------------------------------------------------
/imgs/5.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/imgs/5.gif
--------------------------------------------------------------------------------
/init_images/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/init_images/.DS_Store
--------------------------------------------------------------------------------
/init_images/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/init_images/0.jpg
--------------------------------------------------------------------------------
/init_images/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/init_images/1.jpg
--------------------------------------------------------------------------------
/init_images/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/init_images/2.jpg
--------------------------------------------------------------------------------
/init_images/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/init_images/3.jpg
--------------------------------------------------------------------------------
/init_images/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/init_images/4.jpg
--------------------------------------------------------------------------------
/init_images/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crystallee-ai/animatediff-controlnet/332b9310b2b395038c565524648270d977f60dc0/init_images/5.jpg
--------------------------------------------------------------------------------
/newanimate.yaml:
--------------------------------------------------------------------------------
1 | name: newanimate
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=5.1=1_gnu
7 | - bzip2=1.0.8=h7b6447c_0
8 | - ca-certificates=2023.08.22=h06a4308_0
9 | - ld_impl_linux-64=2.38=h1181459_1
10 | - libffi=3.4.4=h6a678d5_0
11 | - libgcc-ng=11.2.0=h1234567_1
12 | - libgomp=11.2.0=h1234567_1
13 | - libstdcxx-ng=11.2.0=h1234567_1
14 | - libuuid=1.41.5=h5eee18b_0
15 | - ncurses=6.4=h6a678d5_0
16 | - openssl=3.0.10=h7f8727e_2
17 | - pip=23.2.1=py310h06a4308_0
18 | - python=3.10.13=h955ad1f_0
19 | - readline=8.2=h5eee18b_0
20 | - setuptools=68.0.0=py310h06a4308_0
21 | - sqlite=3.41.2=h5eee18b_0
22 | - tk=8.6.12=h1ccaba5_0
23 | - wheel=0.38.4=py310h06a4308_0
24 | - xz=5.4.2=h5eee18b_0
25 | - zlib=1.2.13=h5eee18b_0
26 | - pip:
27 | - accelerate==0.23.0
28 | - aiofiles==23.2.1
29 | - altair==5.1.1
30 | - annotated-types==0.5.0
31 | - antlr4-python3-runtime==4.9.3
32 | - anyio==3.7.1
33 | - appdirs==1.4.4
34 | - attrs==23.1.0
35 | - beautifulsoup4==4.12.2
36 | - certifi==2023.7.22
37 | - charset-normalizer==3.2.0
38 | - click==8.1.7
39 | - contourpy==1.1.1
40 | - cycler==0.11.0
41 | - decord==0.6.0
42 | - diffusers==0.21.4
43 | - docker-pycreds==0.4.0
44 | - einops==0.6.1
45 | - exceptiongroup==1.1.3
46 | - fastapi==0.103.1
47 | - ffmpy==0.3.1
48 | - filelock==3.12.4
49 | - fonttools==4.42.1
50 | - fsspec==2023.9.1
51 | - gdown==4.7.1
52 | - gitdb==4.0.10
53 | - gitpython==3.1.37
54 | - gradio==3.44.3
55 | - gradio-client==0.5.0
56 | - h11==0.14.0
57 | - httpcore==0.18.0
58 | - httpx==0.25.0
59 | - huggingface-hub==0.17.2
60 | - idna==3.4
61 | - imageio==2.27.0
62 | - importlib-metadata==6.8.0
63 | - importlib-resources==6.0.1
64 | - jinja2==3.1.2
65 | - jsonschema==4.19.0
66 | - jsonschema-specifications==2023.7.1
67 | - kiwisolver==1.4.5
68 | - markupsafe==2.1.3
69 | - matplotlib==3.8.0
70 | - mypy-extensions==1.0.0
71 | - numpy==1.26.0
72 | - nvidia-cublas-cu11==11.10.3.66
73 | - nvidia-cuda-nvrtc-cu11==11.7.99
74 | - nvidia-cuda-runtime-cu11==11.7.99
75 | - nvidia-cudnn-cu11==8.5.0.96
76 | - omegaconf==2.3.0
77 | - opencv-python==4.8.0.76
78 | - orjson==3.9.7
79 | - packaging==23.1
80 | - pandas==2.1.0
81 | - pathtools==0.1.2
82 | - pillow==10.0.1
83 | - protobuf==4.24.4
84 | - psutil==5.9.5
85 | - pydantic==2.3.0
86 | - pydantic-core==2.6.3
87 | - pydub==0.25.1
88 | - pyparsing==3.1.1
89 | - pyre-extensions==0.0.23
90 | - pysocks==1.7.1
91 | - python-dateutil==2.8.2
92 | - python-multipart==0.0.6
93 | - pytz==2023.3.post1
94 | - pyyaml==6.0.1
95 | - referencing==0.30.2
96 | - regex==2023.8.8
97 | - requests==2.31.0
98 | - rpds-py==0.10.3
99 | - safetensors==0.3.3
100 | - semantic-version==2.10.0
101 | - sentry-sdk==1.32.0
102 | - setproctitle==1.3.3
103 | - six==1.16.0
104 | - smmap==5.0.1
105 | - sniffio==1.3.0
106 | - soupsieve==2.5
107 | - starlette==0.27.0
108 | - tokenizers==0.13.3
109 | - toolz==0.12.0
110 | - torch==1.13.1
111 | - torchaudio==0.13.1
112 | - torchvision==0.14.1
113 | - tqdm==4.66.1
114 | - transformers==4.33.2
115 | - triton==2.1.0
116 | - typing-extensions==4.8.0
117 | - typing-inspect==0.9.0
118 | - tzdata==2023.3
119 | - urllib3==2.0.4
120 | - uvicorn==0.23.2
121 | - wandb==0.15.12
122 | - websockets==11.0.3
123 | - xformers==0.0.16
124 | - zipp==3.17.0
125 | prefix: /root/anaconda3/envs/newanimate
126 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.23.0
2 | aiofiles==23.2.1
3 | altair==5.1.1
4 | annotated-types==0.5.0
5 | antlr4-python3-runtime==4.9.3
6 | anyio==3.7.1
7 | appdirs==1.4.4
8 | attrs==23.1.0
9 | beautifulsoup4==4.12.2
10 | certifi==2023.7.22
11 | charset-normalizer==3.2.0
12 | click==8.1.7
13 | contourpy==1.1.1
14 | cycler==0.11.0
15 | decord==0.6.0
16 | diffusers==0.21.4
17 | docker-pycreds==0.4.0
18 | einops==0.6.1
19 | exceptiongroup==1.1.3
20 | fastapi==0.103.1
21 | ffmpy==0.3.1
22 | filelock==3.12.4
23 | fonttools==4.42.1
24 | fsspec==2023.9.1
25 | gdown==4.7.1
26 | gitdb==4.0.10
27 | GitPython==3.1.37
28 | gradio==3.44.3
29 | gradio_client==0.5.0
30 | h11==0.14.0
31 | httpcore==0.18.0
32 | httpx==0.25.0
33 | huggingface-hub==0.17.2
34 | idna==3.4
35 | imageio==2.27.0
36 | importlib-metadata==6.8.0
37 | importlib-resources==6.0.1
38 | Jinja2==3.1.2
39 | jsonschema==4.19.0
40 | jsonschema-specifications==2023.7.1
41 | kiwisolver==1.4.5
42 | MarkupSafe==2.1.3
43 | matplotlib==3.8.0
44 | mypy-extensions==1.0.0
45 | numpy==1.26.0
46 | nvidia-cublas-cu11==11.10.3.66
47 | nvidia-cuda-nvrtc-cu11==11.7.99
48 | nvidia-cuda-runtime-cu11==11.7.99
49 | nvidia-cudnn-cu11==8.5.0.96
50 | omegaconf==2.3.0
51 | opencv-python==4.8.0.76
52 | orjson==3.9.7
53 | packaging==23.1
54 | pandas==2.1.0
55 | pathtools==0.1.2
56 | Pillow==10.0.1
57 | protobuf==4.24.4
58 | psutil==5.9.5
59 | pydantic==2.3.0
60 | pydantic_core==2.6.3
61 | pydub==0.25.1
62 | pyparsing==3.1.1
63 | pyre-extensions==0.0.23
64 | PySocks==1.7.1
65 | python-dateutil==2.8.2
66 | python-multipart==0.0.6
67 | pytz==2023.3.post1
68 | PyYAML==6.0.1
69 | referencing==0.30.2
70 | regex==2023.8.8
71 | requests==2.31.0
72 | rpds-py==0.10.3
73 | safetensors==0.3.3
74 | semantic-version==2.10.0
75 | sentry-sdk==1.32.0
76 | setproctitle==1.3.3
77 | six==1.16.0
78 | smmap==5.0.1
79 | sniffio==1.3.0
80 | soupsieve==2.5
81 | starlette==0.27.0
82 | tokenizers==0.13.3
83 | toolz==0.12.0
84 | torch==1.13.1
85 | torchaudio==0.13.1
86 | torchvision==0.14.1
87 | tqdm==4.66.1
88 | transformers==4.33.2
89 | triton==2.1.0
90 | typing-inspect==0.9.0
91 | typing_extensions==4.8.0
92 | tzdata==2023.3
93 | urllib3==2.0.4
94 | uvicorn==0.23.2
95 | wandb==0.15.12
96 | websockets==11.0.3
97 | xformers==0.0.16
98 | zipp==3.17.0
99 |
--------------------------------------------------------------------------------
/scripts/animate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import inspect
4 | import os
5 | from omegaconf import OmegaConf
6 | from PIL import Image
7 | import numpy as np
8 |
9 | import torch
10 | from torchvision import models
11 | from torch.nn import functional as F
12 | import torchvision.transforms as transforms
13 |
14 | import diffusers
15 | from diffusers import AutoencoderKL, DDIMScheduler
16 | import pickle
17 |
18 | from tqdm.auto import tqdm
19 | from transformers import CLIPTextModel, CLIPTokenizer
20 |
21 | import sys
22 | sys.path.append("/root/lh/AnimateDiff-main/")
23 |
24 | from animatediff.models.unet import UNet3DConditionModel
25 | from animatediff.pipelines.pipeline_animation import AnimationPipeline
26 | from animatediff.utils.util import save_videos_grid
27 | from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
28 | from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
29 | from diffusers.utils.import_utils import is_xformers_available
30 |
31 | from einops import rearrange, repeat
32 |
33 | import csv, pdb, glob
34 | from safetensors import safe_open
35 | import math
36 | from pathlib import Path
37 |
38 |
39 | def main(args):
40 | *_, func_args = inspect.getargvalues(inspect.currentframe())
41 | func_args = dict(func_args)
42 |
43 | time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
44 | savedir = f"samples/{Path(args.config).stem}-{time_str}"
45 | os.makedirs(savedir)
46 |
47 | config = OmegaConf.load(args.config)
48 | samples = []
49 |
50 | sample_idx = 0
51 | for model_idx, (config_key, model_config) in enumerate(list(config.items())):
52 |
53 | motion_modules = model_config.motion_module
54 | motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
55 | for motion_module in motion_modules:
56 | inference_config = OmegaConf.load(model_config.get("inference_config", args.inference_config))
57 |
58 | ### >>> create validation pipeline >>> ###
59 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
60 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
61 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
62 | unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
63 |
64 | # if is_xformers_available(): unet.enable_xformers_memory_efficient_attention()
65 | # else: assert False
66 |
67 | pipeline = AnimationPipeline(
68 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
69 | scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
70 | ).to("cuda")
71 |
72 | # 1. unet ckpt
73 | # 1.1 motion module
74 | motion_module_state_dict = torch.load(motion_module, map_location="cpu")
75 | if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
76 | missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
77 | assert len(unexpected) == 0
78 |
79 | # 1.2 T2I
80 | if model_config.path != "":
81 | if model_config.path.endswith(".ckpt"):
82 | state_dict = torch.load(model_config.path)
83 | pipeline.unet.load_state_dict(state_dict)
84 |
85 | elif model_config.path.endswith(".safetensors"):
86 | state_dict = {}
87 | with safe_open(model_config.path, framework="pt", device="cpu") as f:
88 | for key in f.keys():
89 | state_dict[key] = f.get_tensor(key)
90 |
91 | is_lora = all("lora" in k for k in state_dict.keys())
92 | if not is_lora:
93 | base_state_dict = state_dict
94 | else:
95 | base_state_dict = {}
96 | with safe_open(model_config.base, framework="pt", device="cpu") as f:
97 | for key in f.keys():
98 | base_state_dict[key] = f.get_tensor(key)
99 |
100 | # vae
101 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
102 | pipeline.vae.load_state_dict(converted_vae_checkpoint)
103 | # unet
104 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
105 | pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
106 | # text_model
107 | pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)
108 |
109 | # import pdb
110 | # pdb.set_trace()
111 | if is_lora:
112 | pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)
113 |
114 | pipeline.to("cuda")
115 | ### <<< create validation pipeline <<< ###
116 |
117 | prompts = model_config.prompt
118 | n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt
119 |
120 | random_seeds = model_config.get("seed", [-1])
121 | random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
122 | random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
123 |
124 | config[config_key].random_seed = []
125 |
126 | #------------------------------------------------
127 | pixel_transforms = transforms.Compose([
128 | transforms.RandomHorizontalFlip(),
129 | transforms.Resize(512),
130 | transforms.CenterCrop(512),
131 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
132 | ])
133 | pixel_values = Image.open("/root/lh/AnimateDiff-main/sample.jpg")
134 | pixel_values = np.array(pixel_values)
135 | pixel_values = torch.from_numpy(pixel_values).permute(2,0,1).unsqueeze(0)
136 | pixel_values = pixel_values / 255.
137 | pixel_values = pixel_transforms(pixel_values).cuda()
138 | # latents = pipeline.vae.encode(pixel_values).latent_dist
139 | # latents = latents.sample()
140 |
141 | # latents = latents * 0.18215
142 | # latents = latents.unsqueeze(2).repeat(1,1,16,1,1)
143 |
144 |
145 | # --------------------------------------------------
146 |
147 | for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
148 |
149 | # manually set random seed for reproduction
150 | if random_seed != -1: torch.manual_seed(random_seed)
151 | else: torch.seed()
152 | config[config_key].random_seed.append(torch.initial_seed())
153 |
154 | print(f"current seed: {torch.initial_seed()}")
155 | print(f"sampling {prompt} ...")
156 | sample = pipeline(
157 | prompt,
158 | negative_prompt = n_prompt,
159 | num_inference_steps = model_config.steps,
160 | guidance_scale = model_config.guidance_scale,
161 | width = args.W,
162 | height = args.H,
163 | video_length = args.L,
164 | # latents = pixel_values
165 | ).videos
166 | samples.append(sample)
167 |
168 | prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
169 | save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
170 | print(f"save to {savedir}/sample/{sample_idx}-{prompt}.gif")
171 |
172 | sample_idx += 1
173 |
174 | samples = torch.concat(samples)
175 | save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)
176 |
177 | OmegaConf.save(config, f"{savedir}/config.yaml")
178 |
179 |
180 | if __name__ == "__main__":
181 | parser = argparse.ArgumentParser()
182 | parser.add_argument("--pretrained_model_path", type=str, default="/root/lh/stable-diffusion-v1-5",)
183 | parser.add_argument("--inference_config", type=str, default="configs/inference/inference-v2.yaml")
184 | parser.add_argument("--config", type=str, default="configs/prompts/v2/5-RealisticVision.yaml")
185 |
186 | parser.add_argument("--L", type=int, default=16 )
187 | parser.add_argument("--W", type=int, default=512)
188 | parser.add_argument("--H", type=int, default=512)
189 |
190 | args = parser.parse_args()
191 | main(args)
192 |
--------------------------------------------------------------------------------
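Note: each top-level entry in the prompt config passed to `scripts/animate.py` supplies the `model_config.*` fields read above (`motion_module`, `inference_config`, `path`, `prompt`, `n_prompt`, `seed`, `steps`, `guidance_scale`, plus `base`/`lora_alpha` when `path` points at LoRA weights). Below is a minimal sketch of such an entry, built with OmegaConf purely for illustration; the `MyModel` key, prompts, and output filename are placeholders, not files shipped with this repository.

```python
# Illustrative only: mirrors the keys that scripts/animate.py reads from each
# top-level entry of the prompt YAML.
from omegaconf import OmegaConf

example_entry = OmegaConf.create({
    "MyModel": {                                   # placeholder config_key; one block per model
        "motion_module": "models/Motion_Module/mm_sd_v15_v2.ckpt",
        "inference_config": "configs/inference/inference-v2.yaml",
        "path": "",                                # optional personalized .ckpt / .safetensors weights
        # "base": "...",                           # base .safetensors, only needed when `path` is a LoRA
        # "lora_alpha": 0.8,                       # LoRA merge strength, only used for LoRA weights
        "prompt": ["a corgi running on the beach, sunny day"],
        "n_prompt": ["low quality, blurry, watermark"],
        "seed": [42],                              # -1 draws a fresh random seed per prompt
        "steps": 25,
        "guidance_scale": 7.5,
    }
})
OmegaConf.save(example_entry, "configs/prompts/v2/example.yaml")
```

The resulting YAML can then be passed with `--config`, in the same way as the shipped `configs/prompts/v2/5-RealisticVision.yaml`.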
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import imageio
4 | import numpy as np
5 | import wandb
6 | import random
7 | import logging
8 | import inspect
9 | import argparse
10 | import datetime
11 | import subprocess
12 | import multiprocessing as mp
13 |
14 |
15 | from pathlib import Path
16 | from tqdm.auto import tqdm
17 | from einops import rearrange
18 | from omegaconf import OmegaConf
19 | from safetensors import safe_open
20 | from typing import Dict, Optional, Tuple
21 |
22 | import torch
23 | import torchvision
24 | import torch.nn.functional as F
25 | import torch.distributed as dist
26 | from torch.optim.swa_utils import AveragedModel
27 | from torch.utils.data.distributed import DistributedSampler
28 | from torch.nn.parallel import DistributedDataParallel as DDP
29 |
30 | import diffusers
31 | from diffusers import AutoencoderKL, DDIMScheduler
32 | from diffusers.models import UNet2DConditionModel
33 | from diffusers.pipelines import StableDiffusionPipeline
34 | from diffusers.optimization import get_scheduler
35 | from diffusers.utils import check_min_version
36 | from diffusers.utils.import_utils import is_xformers_available
37 | from diffusers.image_processor import VaeImageProcessor
38 |
39 | import transformers
40 | from transformers import CLIPTextModel, CLIPTokenizer
41 |
42 | from animatediff.data.dataset import WebVid10M
43 | from animatediff.models.unet import UNet3DConditionModel
44 | from animatediff.models.controlnet import ControlNetModel
45 | from animatediff.pipelines.pipeline_animation import AnimationPipeline
46 | from animatediff.utils.util import save_videos_grid, zero_rank_print
47 |
48 |
49 | def prepare_image(
50 | image,
51 | width,
52 | height,
53 | batch_size,
54 | num_images_per_prompt,
55 | device,
56 | dtype,
57 | do_classifier_free_guidance=False,
58 | guess_mode=False,
59 | ):
60 | control_image_processor = VaeImageProcessor()
61 | image = control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
62 | image_batch_size = image.shape[0]
63 |
64 | if image_batch_size == 1:
65 | repeat_by = batch_size
66 | else:
67 | # image batch size is the same as prompt batch size
68 | repeat_by = num_images_per_prompt
69 |
70 | image = image.repeat_interleave(repeat_by, dim=0)
71 |
72 | image = image.to(device=device, dtype=dtype)
73 |
74 | if do_classifier_free_guidance and not guess_mode:
75 | image = torch.cat([image] * 2)
76 |
77 | return image
78 |
79 |
80 |
81 |
82 |
83 | def main(
84 | image_finetune: bool,
85 |
86 | name: str,
87 | use_wandb: bool,
88 | launcher: str,
89 |
90 | output_dir: str,
91 | pretrained_model_path: str,
92 |
93 | train_data: Dict,
94 | validation_data: Dict,
95 | cfg_random_null_text: bool = True,
96 | cfg_random_null_text_ratio: float = 0.1,
97 |
98 | unet_checkpoint_path: str = "",
99 | unet_additional_kwargs: Dict = {},
100 | ema_decay: float = 0.9999,
101 | noise_scheduler_kwargs = None,
102 |
103 | max_train_epoch: int = -1,
104 | max_train_steps: int = 100,
105 | validation_steps: int = 100,
106 | validation_steps_tuple: Tuple = (-1,),
107 |
108 | learning_rate: float = 3e-5,
109 | scale_lr: bool = False,
110 | lr_warmup_steps: int = 0,
111 | lr_scheduler: str = "constant",
112 |
113 | trainable_modules: Tuple[str] = (None, ),
114 | num_workers: int = 32,
115 | train_batch_size: int = 1,
116 | adam_beta1: float = 0.9,
117 | adam_beta2: float = 0.999,
118 | adam_weight_decay: float = 1e-2,
119 | adam_epsilon: float = 1e-08,
120 | max_grad_norm: float = 1.0,
121 | gradient_accumulation_steps: int = 1,
122 | gradient_checkpointing: bool = False,
123 | checkpointing_epochs: int = 5,
124 | checkpointing_steps: int = -1,
125 |
126 | mixed_precision_training: bool = True,
127 | enable_xformers_memory_efficient_attention: bool = True,
128 |
129 | global_seed: int = 42,
130 | is_debug: bool = False,
131 | ):
132 | check_min_version("0.10.0.dev0")
133 |
134 | # Initialize distributed training
135 | # local_rank = init_dist(launcher=launcher)
136 | # local_rank = 1
137 | # global_rank = dist.get_rank()
138 | # num_processes = dist.get_world_size()
139 | # is_main_process = global_rank == 0
140 | is_main_process = True
141 |
142 | # seed = global_seed + global_rank
143 | seed = 42
144 | torch.manual_seed(seed)
145 |
146 | # Logging folder
147 | folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S")
148 | output_dir = os.path.join(output_dir, folder_name)
149 | if is_debug and os.path.exists(output_dir):
150 | os.system(f"rm -rf {output_dir}")
151 |
152 | *_, config = inspect.getargvalues(inspect.currentframe())
153 |
154 | # Make one log on every process with the configuration for debugging.
155 | logging.basicConfig(
156 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
157 | datefmt="%m/%d/%Y %H:%M:%S",
158 | level=logging.INFO,
159 | )
160 |
161 | # A wandb account must be configured for logging
162 | if is_main_process and (not is_debug) and use_wandb:
163 | run = wandb.init(project="animatediff_pics_controlnetonly", name=folder_name, config=config)
164 |
165 | # Handle the output folder creation
166 | if is_main_process:
167 | os.makedirs(output_dir, exist_ok=True)
168 | os.makedirs(f"{output_dir}/samples", exist_ok=True)
169 | os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
170 | os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
171 | OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
172 |
173 | #-----------------------------------------------------------------------------------------------
174 | # Load scheduler, tokenizer and models.
175 | noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
176 |
177 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
178 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
179 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
180 | unet2d = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")
181 | # controlnet = ControlNetModel.from_unet(unet2d)
182 | controlnet = ControlNetModel()
183 | # unet = UNet3DConditionModel()
184 | if not image_finetune:
185 | unet = UNet3DConditionModel.from_pretrained_2d(
186 | pretrained_model_path, subfolder="unet",
187 | unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs)
188 | )
189 | else:
190 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")
191 |
192 | # Load pretrained unet weights
193 | # if unet_checkpoint_path != "":
194 | # zero_rank_print(f"from checkpoint: {unet_checkpoint_path}")
195 | # unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
196 | # if "global_step" in unet_checkpoint_path: zero_rank_print(f"global_step: {unet_checkpoint_path['global_step']}")
197 | # state_dict = unet_checkpoint_path["state_dict"] if "state_dict" in unet_checkpoint_path else unet_checkpoint_path
198 |
199 | # m, u = unet.load_state_dict(state_dict, strict=False)
200 | # zero_rank_print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
201 | # assert len(u) == 0
202 | motion_module_state_dict = torch.load("models/Motion_Module/mm_sd_v15_v2.ckpt", map_location="cpu")
203 | # # print(motion_module_state_dict)
204 | # # if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
205 | missing, unexpected = unet.load_state_dict(motion_module_state_dict, strict=False)
206 | # print(f"### missing keys: {len(missing)}; \n### unexpected keys: {len(unexpected)};")
207 | # print(f"### missing keys:\n{missing}\n### unexpected keys:\n{unexpected}\n")
208 | assert len(unexpected) == 0
209 |
210 | controlnet_state_dict = torch.load("/root/lh/AnimateDiffcontrolnet-main/outputs/1/checkpoints/controlnet_checkpoint-epoch-30.ckpt", map_location="cpu")
211 | missing, unexpected = controlnet.load_state_dict(controlnet_state_dict["state_dict"], strict=False)
212 | assert len(unexpected) == 0
213 | #-----------------------------------------------------------------------------------------------
214 |
215 | # Freeze vae and text_encoder
216 | vae.requires_grad_(False)
217 | text_encoder.requires_grad_(False)
218 | # controlnet.requires_grad_(False)
219 | # for name, param in controlnet.named_parameters():
220 | # print(name, ": ", param.requires_grad )
221 | # print("---------------------------------------")
222 |
223 |
224 | # Set a breakpoint here to inspect the unet structure
225 | # Set unet trainable parameters
226 | # print(unet)
227 | unet.requires_grad_(False)
228 | # unet.requires_grad_(True)
229 | # for name, param in unet.named_parameters():
230 | # for trainable_module_name in trainable_modules:
231 | # if trainable_module_name in name:
232 | # param.requires_grad = True
233 | # break
234 | # trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters()))
235 | trainable_params = []
236 | for name, param in controlnet.named_parameters():
237 | if (param.requires_grad):
238 | trainable_params.append(param)
239 | # print(name, ": ", param.requires_grad )
240 | # trainable_params.append(list(filter(lambda p: p.requires_grad, controlnet.parameters())))
241 | optimizer = torch.optim.AdamW(
242 | trainable_params,
243 | lr=learning_rate,
244 | betas=(adam_beta1, adam_beta2),
245 | weight_decay=adam_weight_decay,
246 | eps=adam_epsilon,
247 | )
248 | #-----------------------------------------------------------------------------------------------
249 |
250 | if is_main_process:
251 | # zero_rank_print(f"trainable params number: {len(trainable_params)}")
252 | # zero_rank_print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")
253 | print(f"trainable params number: {len(trainable_params)}")
254 | print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")
255 | #-----------------------------------------------------------------------------------------------
256 |
257 | # Enable xformers
258 | if enable_xformers_memory_efficient_attention:
259 | if is_xformers_available():
260 | unet.enable_xformers_memory_efficient_attention()
261 | else:
262 | raise ValueError("xformers is not available. Make sure it is installed correctly")
263 |
264 | # Enable gradient checkpointing
265 | if gradient_checkpointing:
266 | unet.enable_gradient_checkpointing()
267 |
268 | # Move models to GPU
269 | # vae.to(local_rank)
270 | # text_encoder.to(local_rank)
271 | vae.to("cuda")
272 | controlnet.to("cuda")
273 | text_encoder.to("cuda")
274 | unet.to("cuda")
275 | #-----------------------------------------------------------------------------------------------
276 |
277 | # Get the training dataset
278 | train_dataset = WebVid10M(**train_data, is_image=image_finetune)
279 | # distributed_sampler = DistributedSampler(
280 | # train_dataset,
281 | # num_replicas=1,
282 | # rank=0,
283 | # shuffle=True,
284 | # seed=global_seed,
285 | # )
286 |
287 | # DataLoaders creation:
288 | train_dataloader = torch.utils.data.DataLoader(
289 | train_dataset,
290 | batch_size=train_batch_size,
291 | shuffle=False,
292 | # sampler=distributed_sampler,
293 | num_workers=0,
294 | pin_memory=True,
295 | drop_last=True,
296 | )
297 | #-----------------------------------------------------------------------------------------------
298 |
299 | # Get the training iteration
300 | if max_train_steps == -1:
301 | assert max_train_epoch != -1
302 | max_train_steps = max_train_epoch * len(train_dataloader)
303 |
304 | if checkpointing_steps == -1:
305 | assert checkpointing_epochs != -1
306 | checkpointing_steps = checkpointing_epochs * len(train_dataloader)
307 |
308 | if scale_lr:
309 | learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size )
310 |
311 | # Scheduler
312 | lr_scheduler = get_scheduler(
313 | lr_scheduler,
314 | optimizer=optimizer,
315 | num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
316 | num_training_steps=max_train_steps * gradient_accumulation_steps,
317 | )
318 | #-----------------------------------------------------------------------------------------------
319 |
320 | # Validation pipeline
321 | if not image_finetune:
322 | validation_pipeline = AnimationPipeline(
323 | unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, controlnet=controlnet,
324 | ).to("cuda")
325 | else:
326 | validation_pipeline = StableDiffusionPipeline.from_pretrained(
327 | pretrained_model_path,
328 | unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, safety_checker=None,
329 | )
330 | validation_pipeline.enable_vae_slicing()
331 | #-----------------------------------------------------------------------------------------------
332 |
333 | # DDP wrapper
334 | # unet = DDP(unet, device_ids=["cuda:0"], output_device="cuda:0")
335 |
336 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
337 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
338 | # Afterwards we recalculate our number of training epochs
339 | num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
340 |
341 | # Train!
342 | total_batch_size = train_batch_size * gradient_accumulation_steps
343 |
344 | if is_main_process:
345 | logging.info("***** Running training *****")
346 | logging.info(f" Num examples = {len(train_dataset)}")
347 | logging.info(f" Num Epochs = {num_train_epochs}")
348 | logging.info(f" Instantaneous batch size per device = {train_batch_size}")
349 | logging.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
350 | logging.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
351 | logging.info(f" Total optimization steps = {max_train_steps}")
352 | global_step = 0
353 | first_epoch = 0
354 |
355 | # Only show the progress bar once on each machine.
356 | progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
357 | progress_bar.set_description("Steps")
358 |
359 | # Support mixed-precision training
360 | scaler = torch.cuda.amp.GradScaler() if mixed_precision_training else None
361 | # mp.set_start_method('spawn')
362 | for epoch in range(first_epoch, num_train_epochs):
363 | # train_dataloader.sampler.set_epoch(epoch)
364 | unet.train()
365 | # mp.set_start_method('spawn')
366 | for step, batch in enumerate(train_dataloader):
367 | if cfg_random_null_text:
368 | batch['text'] = [name if random.random() > cfg_random_null_text_ratio else "" for name in batch['text']]
369 |
370 | # Data batch sanity check
371 | if epoch == first_epoch and step == 0:
372 | pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
373 | if not image_finetune:
374 | pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
375 | for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
376 | pixel_value = pixel_value[None, ...]
377 | # save_videos_grid(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_seed}-{idx}'}.gif", rescale=True)
378 | else:
379 | for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
380 | pixel_value = pixel_value / 2. + 0.5
381 | torchvision.utils.save_image(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_seed}-{idx}'}.png")
382 |
383 | ### >>>> Training >>>> ###
384 |
385 | # Convert videos to latent space
386 | orig_img = batch['image'].squeeze(1)
387 | orig_img = orig_img / 2. + 0.5
388 | pixel_values = batch["pixel_values"].to("cuda")
389 | video_length = pixel_values.shape[1]
390 | with torch.no_grad():
391 | if not image_finetune:
392 | pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
393 | latents = vae.encode(pixel_values).latent_dist
394 | latents = latents.sample()
395 | latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
396 | else:
397 | latents = vae.encode(pixel_values).latent_dist
398 | latents = latents.sample()
399 |
400 | latents = latents * 0.18215
401 |
402 |
403 | #--------------------------------------
404 | image = prepare_image(
405 | image=orig_img,
406 | width=orig_img.shape[-1],
407 | height=orig_img.shape[-2],
408 | batch_size=orig_img.shape[0],
409 | num_images_per_prompt=1,
410 | device="cuda",
411 | dtype=controlnet.dtype,
412 | )
413 |
414 | #--------------------------------------
415 |
416 |
417 | # Sample noise that we'll add to the latents (if an initial-image signal is to be injected, this is where it would be added)
418 | noise = torch.randn_like(latents) # latents shape is [4, 4, 16, 32, 32]
419 | bsz = latents.shape[0]
420 |
421 | # Sample a random timestep for each video
422 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) # shape [4]
423 | timesteps = timesteps.long()
424 |
425 | # Add noise to the latents according to the noise magnitude at each timestep
426 | # (this is the forward diffusion process)
427 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # shape[4, 4, 16, 32, 32]
428 |
429 | # Get the text embedding for conditioning
430 | with torch.no_grad():
431 | prompt_ids = tokenizer(
432 | batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
433 | ).input_ids.to(latents.device) # shape [4, 77]
434 | encoder_hidden_states = text_encoder(prompt_ids)[0] # shape [4, 77, 768]
435 |
436 | # Get the target for loss depending on the prediction type
437 | if noise_scheduler.config.prediction_type == "epsilon":
438 | target = noise
439 | elif noise_scheduler.config.prediction_type == "v_prediction":
440 | raise NotImplementedError
441 | else:
442 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
443 |
444 |
445 | #------------------------------------------------
446 | down_block_res_samples, mid_block_res_sample = controlnet(
447 | sample=noisy_latents[:,:,0,:,:],
448 | timestep=timesteps,
449 | encoder_hidden_states=encoder_hidden_states, # [4,77,768]
450 | controlnet_cond=image,
451 | return_dict=False,
452 | )
453 |
454 | # down_block_additional_residuals
455 | # mid_block_additional_residual
456 | #------------------------------------------------
457 |
458 | # Predict the noise residual and compute loss
459 | # Mixed-precision training
460 | with torch.cuda.amp.autocast(enabled=mixed_precision_training):
461 | # noisy_latents shape [4, 4, 16, 32, 32]
462 | # encoder_hidden_states [4, 77, 768]
463 | model_pred = unet(sample=noisy_latents,
464 | timestep=timesteps,
465 | encoder_hidden_states=encoder_hidden_states,
466 | down_block_additional_residuals=down_block_res_samples,
467 | mid_block_additional_residual=mid_block_res_sample,
468 | ).sample
469 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
470 |
471 | optimizer.zero_grad()
472 |
473 | # Backpropagate
474 | if mixed_precision_training:
475 | scaler.scale(loss).backward()
476 | """ >>> gradient clipping >>> """
477 | scaler.unscale_(optimizer)
478 | torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)  # clip the ControlNet params (the unet is frozen here)
479 | """ <<< gradient clipping <<< """
480 | scaler.step(optimizer)
481 | scaler.update()
482 | else:
483 | loss.backward()
484 | """ >>> gradient clipping >>> """
485 | torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)  # clip the ControlNet params (the unet is frozen here)
486 | """ <<< gradient clipping <<< """
487 | optimizer.step()
488 |
489 | lr_scheduler.step()
490 | progress_bar.update(1)
491 | global_step += 1
492 |
493 | ### <<<< Training <<<< ###
494 |
495 | # Wandb logging
496 | if is_main_process and (not is_debug) and use_wandb:
497 | wandb.log({"train_loss": loss.item()}, step=global_step)
498 |
499 | # Save checkpoint
500 | if is_main_process and (global_step % checkpointing_steps == 0 or step == len(train_dataloader) - 1):
501 | save_path = os.path.join(output_dir, f"checkpoints")
502 | controlnet_state_dict = {
503 | "epoch": epoch,
504 | "global_step": global_step,
505 | "state_dict": controlnet.state_dict(),
506 | }
507 | if step == len(train_dataloader) - 1:
508 | torch.save(controlnet_state_dict, os.path.join(save_path, f"controlnet_checkpoint-epoch-{epoch+1}.ckpt"))
509 | else:
510 | torch.save(controlnet_state_dict, os.path.join(save_path, f"controlnet_checkpoint.ckpt"))
511 | logging.info(f"Saved state to {save_path} (global_step: {global_step})")
512 |
513 | # Periodic validation
514 | if is_main_process and (global_step % validation_steps == 0 or global_step in validation_steps_tuple):
515 | samples = []
516 |
517 | generator = torch.Generator(device=latents.device)
518 | generator.manual_seed(global_seed)
519 |
520 | height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size
521 | width = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size
522 |
523 | # prompts = validation_data.prompts[:2] if global_step < 1000 and (not image_finetune) else validation_data.prompts
524 | prompts = batch['text']
525 |
526 | init_images = batch['image'].squeeze(1) # [b, 1, c, h, w] -> [b, c, h, w]
527 | controlnet_images = init_images / 2. + 0.5
528 | init_images = init_images.permute(0,2,3,1) # [b, c, h, w] -> [b, h, w, c]
529 | init_images = np.array(init_images.cpu())
530 | for idx, prompt in enumerate(prompts):
531 | if not image_finetune:
532 | controlnet_image = controlnet_images[idx, :, :, :]
533 | sample = validation_pipeline(
534 | prompt,
535 | generator = generator,
536 | video_length = train_data.sample_n_frames,
537 | height = height,
538 | width = width,
539 | controlnet_image=controlnet_image,
540 | **validation_data,
541 | ).videos
542 | init_image = init_images[idx, :, :, :]
543 | save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif")
544 | imageio.imsave(f"{output_dir}/samples/sample-{global_step}/{idx}.jpg", init_image)
545 | samples.append(sample)
546 |
547 | else:
548 | sample = validation_pipeline(
549 | prompt,
550 | generator = generator,
551 | height = height,
552 | width = width,
553 | num_inference_steps = validation_data.get("num_inference_steps", 25),
554 | guidance_scale = validation_data.get("guidance_scale", 8.),
555 | ).images[0]
556 | sample = torchvision.transforms.functional.to_tensor(sample)
557 | samples.append(sample)
558 |
559 | if not image_finetune:
560 | samples = torch.concat(samples)
561 | save_path = f"{output_dir}/samples/sample-{global_step}.gif"
562 | save_videos_grid(samples, save_path)
563 |
564 | else:
565 | samples = torch.stack(samples)
566 | save_path = f"{output_dir}/samples/sample-{global_step}.png"
567 | torchvision.utils.save_image(samples, save_path, nrow=4)
568 |
569 | logging.info(f"Saved samples to {save_path}")
570 |
571 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
572 | progress_bar.set_postfix(**logs)
573 |
574 | if global_step >= max_train_steps:
575 | break
576 |
577 | # dist.destroy_process_group()
578 |
579 |
580 |
581 | if __name__ == "__main__":
582 | parser = argparse.ArgumentParser()
583 | parser.add_argument("--config", type=str, default="./configs/training/training.yaml")
584 | parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch")
585 | # parser.add_argument("--wandb", action="store_true")
586 | parser.add_argument("--wandb", type=bool, default=True)  # note: argparse's type=bool treats any non-empty string as True
587 |
588 | args = parser.parse_args()
589 |
590 | name = Path(args.config).stem
591 | config = OmegaConf.load(args.config)
592 |
593 | main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config)
594 |
--------------------------------------------------------------------------------
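To make the conditioning step in the training loop above easier to follow, here is a small self-contained sketch of the tensor shapes involved. The batch size, frame count, and 256x256 frame size are assumptions chosen to match the `[4, 4, 16, 32, 32]` latent shape noted in the comments; no model weights are required and all tensors are random stand-ins.

```python
# Sketch only: reproduces the shape flow of the training step with random tensors.
import torch
from einops import rearrange

b, f, c, h, w = 4, 16, 3, 256, 256                 # assumed batch, frames, frame size
pixel_values = torch.randn(b, f, c, h, w)          # stand-in for batch["pixel_values"]

# The VAE encodes frames folded into the batch dimension, then the latents are
# reshaped back to (b, c, f, h, w); the VAE downsamples spatially by a factor of 8.
flat = rearrange(pixel_values, "b f c h w -> (b f) c h w")       # [64, 3, 256, 256]
latents = torch.randn(flat.shape[0], 4, h // 8, w // 8)          # stand-in for vae.encode(...).sample()
latents = rearrange(latents, "(b f) c h w -> b c f h w", f=f)    # [4, 4, 16, 32, 32]

noise = torch.randn_like(latents)
noisy_latents = latents + noise                                  # stand-in for noise_scheduler.add_noise(...)

# Only frame 0 of the noisy latents is fed to the ControlNet as `sample`; the
# returned residuals are passed to the 3D UNet as down_block_additional_residuals
# and mid_block_additional_residual while it denoises all 16 frames.
controlnet_sample = noisy_latents[:, :, 0, :, :]                 # [4, 4, 32, 32]
print(flat.shape, latents.shape, controlnet_sample.shape)
```

The design choice this illustrates is that the ControlNet sees only a single conditioning frame, while its residuals condition the full clip denoised by the motion-module UNet.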