├── LICENSE.txt
├── README.md
├── __assets__
│   ├── gifs
│   │   ├── dog.gif
│   │   ├── girl.gif
│   │   └── swan.gif
│   └── imgs
│       ├── comparison.png
│       ├── demo_figure.png
│       └── examples.png
├── animatelcm_sd15
│   ├── README.md
│   ├── animatelcm
│   │   ├── .DS_Store
│   │   ├── models
│   │   │   ├── adapter.py
│   │   │   ├── attention.py
│   │   │   ├── embeddings.py
│   │   │   ├── motion_module.py
│   │   │   ├── resnet.py
│   │   │   ├── unet.py
│   │   │   └── unet_blocks.py
│   │   ├── pipelines
│   │   │   └── pipeline_animation.py
│   │   ├── scheduler
│   │   │   └── lcm_scheduler.py
│   │   └── utils
│   │       ├── convert_from_ckpt.py
│   │       ├── convert_lora_safetensor_to_diffusers.py
│   │       ├── lcm_utils.py
│   │       └── util.py
│   ├── app-i2v.py
│   ├── app.py
│   ├── batch_inference.py
│   ├── configs
│   │   ├── batch_inference_example.yaml
│   │   ├── inference-i2v.yaml
│   │   └── inference-t2v.yaml
│   ├── models
│   │   ├── LCM_LoRA
│   │   │   └── put_spatial_lora_here.txt
│   │   ├── Motion_Module
│   │   │   └── put_motion_module_here.txt
│   │   ├── Personalized
│   │   │   └── put_personalized_weights_here.txt
│   │   └── StableDiffusion
│   │       └── put_stable_diffusion_v15_here.txt
│   ├── requirements.txt
│   └── test_imgs
│       ├── cloud.jpeg
│       ├── dog.jpg
│       ├── fire.jpg
│       ├── fox.jpg
│       ├── girl.png
│       ├── girl_flower.jpg
│       ├── lighter.jpg
│       └── snow_man_fire.jpg
├── animatelcm_svd
│   ├── README.md
│   ├── animate_lcm_utils.py
│   ├── animatelcm_scheduler.py
│   ├── app.py
│   ├── batch_inference.py
│   ├── dataset.py
│   ├── enviroment.yaml
│   ├── outputs_gradio
│   │   ├── 000000.mp4
│   │   ├── 000001.mp4
│   │   ├── 000002.mp4
│   │   ├── 000003.mp4
│   │   ├── 000004.mp4
│   │   └── 000005.mp4
│   ├── pipeline.py
│   ├── requirements.txt
│   ├── safetensors
│   │   ├── AnimateLCM-SVD-xt-1.1.safetensors
│   │   └── AnimateLCM-SVD-xt.safetensors
│   ├── test_imgs
│   │   ├── .DS_Store
│   │   ├── ai-generated-8411866_1280.jpg
│   │   ├── ai-generated-8463496_1280.jpg
│   │   ├── ai-generated-8476858_1280.png
│   │   ├── ai-generated-8479572_1280.jpg
│   │   ├── ai-generated-8481641_1280.jpg
│   │   ├── ai-generated-8496135_1280.jpg
│   │   ├── ai-generated-8496952_1280.jpg
│   │   ├── ai-generated-8498844_1280.jpg
│   │   ├── bird-7411270_1280.jpg
│   │   ├── bird-7586857_1280.jpg
│   │   ├── bird-8014191_1280.jpg
│   │   ├── couple-8019370_1280.jpg
│   │   ├── cupcakes-380178_1280.jpg
│   │   ├── dog-7330712_1280.jpg
│   │   ├── dog-7396912_1280.jpg
│   │   ├── girl-4898696_1280.jpg
│   │   ├── grey-capped-flycatcher-8071233_1280.jpg
│   │   ├── halloween-4585684_1280.jpg
│   │   ├── leaf-7260246_1280.jpg
│   │   ├── meerkat-7465819_1280.jpg
│   │   ├── mobile-phone-1875813_1280.jpg
│   │   ├── mother-8097324_1280.jpg
│   │   ├── plane-8145957_1280.jpg
│   │   ├── power-station-6579092_1280.jpg
│   │   ├── ship-7833921_1280.jpg
│   │   ├── sleep-7871915_1280.jpg
│   │   ├── squirrel-7985502_1280.jpg
│   │   ├── squirrel-8211238_1280.jpg
│   │   ├── training-8122941_1280.jpg
│   │   ├── violin-8405558_1280.jpg
│   │   ├── weight-8246973_1280.jpg
│   │   ├── woman-4549327_1280.jpg
│   │   ├── woman-4757707_1280.jpg
│   │   └── woman-5667299_1280.jpg
│   └── train_svd_lcm.py
└── metrics
    ├── UCF101_prompts.yaml
    ├── clip_score.py
    ├── fvd.py
    ├── i3d_pretrained_400.pt
    └── pytorch_i3d.py
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Fu-Yun Wang.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## ⚡️AnimateLCM: Computation-Efficient Personalized Style Video Generation without Personalized Video Data
4 |
5 | [[Paper]](https://arxiv.org/abs/2402.00769) [[Project Page ✨]](https://animatelcm.github.io/) [[Demo in 🤗Hugging Face]](https://huggingface.co/spaces/wangfuyun/AnimateLCM-SVD) [[Pre-trained Models]](https://huggingface.co/wangfuyun/AnimateLCM) [[Civitai]](https://civitai.com/models/290375/animatelcm-fast-video-generation) 
6 |
7 |
8 | by *[Fu-Yun Wang](https://g-u-n.github.io), Zhaoyang Huang📮, Weikang Bian, Xiaoyu Shi, Keqiang Sun, Guanglu Song, Yu Liu, Hongsheng Li📮*
9 |
10 |
11 |
12 | | Example 1 | Example 2 | Example 3 |
13 | |-----------------|-----------------|-----------------|
14 | |  |  |  |
15 |
16 | If you use any components of our work, please cite it.
17 |
18 | ```
19 | @article{wang2024animatelcm,
20 | title={AnimateLCM: Accelerating the Animation of Personalized Diffusion Models and Adapters with Decoupled Consistency Learning},
21 | author={Wang, Fu-Yun and Huang, Zhaoyang and Shi, Xiaoyu and Bian, Weikang and Song, Guanglu and Liu, Yu and Li, Hongsheng},
22 | journal={arXiv preprint arXiv:2402.00769},
23 | year={2024}
24 | }
25 |
26 | ```
27 | ### News
28 |
29 | - [2024.05]: 🔥🔥🔥 We release the [training script](https://github.com/G-U-N/AnimateLCM/blob/master/animatelcm_svd/train_svd_lcm.py) for accelerating Stable Video Diffusion.
30 | - [2024.03]: 😆😆😆 We release AnimateLCM-I2V and AnimateLCM-SVD for fast image animation.
31 | - [2024.02]: 🤗🤗🤗 We release the pretrained model weights and a Hugging Face demo.
32 | - [2024.02]: 💡💡💡 The technical report is available on arXiv.
33 |
34 |
35 | Here is a screen recording of usage. Prompt: "river reflecting mountain".
36 |
37 | 
38 |
39 |
40 | ### Introduction
41 |
42 |
43 | AnimateLCM is **a pioneering, exploratory work** on fast animation generation following consistency models, able to generate good-quality animations in just 4 inference steps.
44 |
45 | It relies on a **decoupled** learning paradigm: it first learns the image generation prior and then the temporal generation prior for fast sampling, which greatly boosts training efficiency.
46 |
47 | The high-level workflow of AnimateLCM is illustrated below.
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 | ### Demos
56 |
57 | We have **released many demo videos generated by AnimateLCM on the [Project Page](https://animatelcm.github.io/)**. Generally speaking, AnimateLCM supports fast text-to-video, control-to-video, image-to-video, video-to-video stylization, and longer video generation.
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 | ### Models
69 |
70 | So far, we have released three models:
71 |
72 | - [Animate-LCM-T2V](https://huggingface.co/wangfuyun/AnimateLCM): A spatial LoRA weight and a motion module for personalized video generation. Community experiments indicate that the motion module is also compatible with many personalized models tuned for LCM, for example [Dreamshaper-LCM](https://civitai.com/models/4384?modelVersionId=252914).
73 |
74 | - [AnimateLCM-SVD-xt](https://huggingface.co/wangfuyun/AnimateLCM-SVD-xt): I provide AnimateLCM-SVD-xt and AnimateLCM-SVD-xt 1.1, tuned from [SVD-xt](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt) and [SVD-xt 1.1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1) respectively. They perform high-resolution image animation of 25 frames in 1~8 steps. You can try it with the Hugging Face [Demo](https://huggingface.co/spaces/wangfuyun/AnimateLCM-SVD). Thanks to the Hugging Face team for providing the GPU grants.
75 |
76 | - [AnimateLCM-I2V](https://huggingface.co/wangfuyun/AnimateLCM-I2V): A spatial LoRA weight and a motion module with an additional image encoder for personalized image animation. It is our attempt to directly train an image animation model for fast sampling without any teacher model. It can animate a personalized image in 2~4 steps. However, because training resources were very limited, it is not as stable as I would like (like most I2V models built on Stable-Diffusion-v1-5, it is generally not very stable).
77 |
78 | ### Install & Usage Instructions
79 |
80 | We split animatelcm_sd15 and animatelcm_svd into two folders because they rely on different environments. Please refer to [README_animatelcm_sd15](./animatelcm_sd15/README.md) and [README_animatelcm_svd](./animatelcm_svd/README.md) for instructions.
81 |
82 | ### Usage Tips
83 |
84 | - **AnimateLCM-T2V**:
85 |   - 4 steps generally work well. For better quality, use 6~8 inference steps.
86 |   - The CFG scale should be set between 1 and 2. Setting CFG=1 halves the sampling cost, but I generally prefer CFG=1.5 together with proper negative prompts for better quality.
87 |   - Set the video length to 16 frames for sampling. This is the length the model was trained with.
88 |   - The models should work with IP-Adapter, ControlNet, and many other adapters tuned for Stable Diffusion in a zero-shot manner. If you want better results when combining them, you can tune them together with the teacher-free adaptation script I provide; it does not hurt the sampling speed. A minimal sketch of these settings is shown right after this list.
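
For reference, below is a minimal sketch of these settings via the `diffusers` AnimateDiffPipeline, assuming the weights on the wangfuyun/AnimateLCM Hugging Face repo load as a diffusers `MotionAdapter`; the personalized base model (`emilianJR/epiCRealism`) is only an example and can be swapped for any SD1.5-based checkpoint. The repo's own `app.py` remains the reference path.

```python
import torch
from diffusers import AnimateDiffPipeline, LCMScheduler, MotionAdapter
from diffusers.utils import export_to_gif

# Load the AnimateLCM motion module and an example personalized SD1.5 base model.
adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")

# Attach the spatial LCM LoRA released with AnimateLCM-T2V.
pipe.load_lora_weights(
    "wangfuyun/AnimateLCM",
    weight_name="AnimateLCM_sd15_t2v_lora.safetensors",
    adapter_name="lcm-lora",
)
pipe.set_adapters(["lcm-lora"], [0.8])
pipe.enable_vae_slicing()
pipe.enable_model_cpu_offload()

output = pipe(
    prompt="river reflecting mountain",
    negative_prompt="bad quality, worse quality, low resolution",
    num_frames=16,          # the length the model was trained with
    guidance_scale=1.5,     # CFG in [1, 2]; CFG=1 halves the sampling cost
    num_inference_steps=6,  # 4 steps work; 6~8 steps give better quality
    generator=torch.Generator("cpu").manual_seed(0),
)
export_to_gif(output.frames[0], "animatelcm_t2v.gif")
```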
89 |
90 | - **AnimateLCM-I2V**:
91 |   - 2~4 steps should work for personalized image animation.
92 |   - In most cases, the model does not need CFG. Just set CFG=1 to reduce the inference cost.
93 |   - I additionally provide a `motion scale` hyper-parameter. Use 0.8 as the default choice. Setting it to 0.0 always yields static animations. You can increase the motion scale for larger motions, but that sometimes causes generation failures.
94 |
95 | - The typical workflow is as follows (see the sketch after this list):
96 |   - Use your personalized image model to generate an image of good quality.
97 |   - Use the generated image as input and reuse the same prompt for image animation.
98 |   - You can even further apply AnimateLCM-T2V to refine the final motion quality.
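
As a concrete illustration of step 1, the sketch below generates the still image with a plain `diffusers` text-to-image call using the spatial LCM LoRA; the base model (`Lykon/dreamshaper-7`) and the output file name are placeholders. The resulting image and the same prompt are then fed to `app-i2v.py` (or `configs/inference-i2v.yaml` for batch inference).

```python
import torch
from diffusers import StableDiffusionPipeline, LCMScheduler

# Step 1: generate a good-quality still image with a personalized SD1.5 model.
pipe = StableDiffusionPipeline.from_pretrained(
    "Lykon/dreamshaper-7", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights(
    "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors"
)

prompt = "a girl holding a flower, best quality"
image = pipe(prompt, num_inference_steps=4, guidance_scale=1.5).images[0]
image.save("girl_flower_input.png")

# Step 2: pass girl_flower_input.png and the same prompt to the AnimateLCM-I2V
# demo (app-i2v.py), keeping the motion scale around the default 0.8.
```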
99 |
100 | - **AnimateLCM-SVD**:
101 |   - 1~4 steps should work.
102 |   - SVD requires two CFG values, `CFG_min` and `CFG_max`. By default, `CFG_min` is set to 1. Slightly adjusting `CFG_max` within [1, 1.5] gives good results. Again, you can simply set it to 1 to reduce the inference cost.
103 |   - For the other hyper-parameters of AnimateLCM-SVD-xt, just follow the original SVD design. A minimal sketch is shown below.
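
For reference, here is a minimal sketch of these settings, assuming the AnimateLCM-SVD-xt checkpoint on Hugging Face can be loaded directly as a `diffusers` StableVideoDiffusionPipeline; the repo's own `pipeline.py` and `animatelcm_scheduler.py` are the reference path and may differ in scheduler details.

```python
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import export_to_video, load_image

# Assumption: the wangfuyun/AnimateLCM-SVD-xt repo is in diffusers format.
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "wangfuyun/AnimateLCM-SVD-xt", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

image = load_image("animatelcm_svd/test_imgs/dog-7330712_1280.jpg")
image = image.resize((1024, 576))

frames = pipe(
    image,
    num_inference_steps=4,   # 1~4 steps should work
    min_guidance_scale=1.0,  # CFG_min defaults to 1
    max_guidance_scale=1.2,  # adjust CFG_max slightly within [1, 1.5]
    decode_chunk_size=8,
    generator=torch.Generator("cpu").manual_seed(0),
).frames[0]
export_to_video(frames, "animatelcm_svd.mp4", fps=7)
```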
104 |
105 | ### Related Notes
106 |
107 | - 🎉 Tutorial video of AnimateLCM on ComfyUI: [Tutorial Video](https://www.youtube.com/watch?v=HxlZHsd6xAk&feature=youtu.be)
108 | - 🎉 ComfyUI for AnimateLCM: [AnimateLCM-ComfyUI](https://github.com/dezi-ai/ComfyUI-AnimateLCM) & [ComfyUI-Reddit](https://www.reddit.com/r/comfyui/comments/1ajjp9v/animatelcm_support_just_dropped/)
109 |
110 |
111 | ### Comparison
112 |
113 | Screen recording of AnimateLCM-T2V. Prompt: "dog with sunglasses".
114 |
115 | 
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 | ### Contact & Collaboration
124 |
125 | I am open to collaboration, but not to a full-time internship. If you find some of my work interesting and hope for collaboration/discussion in any format, please do not hesitate to contact me.
126 |
127 | 📧 Email: fywang@link.cuhk.edu.hk
128 |
129 | ### Acknowledgements
130 |
131 | I would like to thank **[AK](https://twitter.com/_akhaliq)** for broadcasting our work, and the Hugging Face team for their help in building the Gradio demo and hosting the models. I would also like to thank [Dhruv Nair](https://twitter.com/_DhruvNair_) for his help with diffusers.
132 |
133 |
134 |
135 |
136 |
--------------------------------------------------------------------------------
/__assets__/gifs/dog.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/__assets__/gifs/dog.gif
--------------------------------------------------------------------------------
/__assets__/gifs/girl.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/__assets__/gifs/girl.gif
--------------------------------------------------------------------------------
/__assets__/gifs/swan.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/__assets__/gifs/swan.gif
--------------------------------------------------------------------------------
/__assets__/imgs/comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/__assets__/imgs/comparison.png
--------------------------------------------------------------------------------
/__assets__/imgs/demo_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/__assets__/imgs/demo_figure.png
--------------------------------------------------------------------------------
/__assets__/imgs/examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/__assets__/imgs/examples.png
--------------------------------------------------------------------------------
/animatelcm_sd15/README.md:
--------------------------------------------------------------------------------
1 | ## AnimateLCM SD15
2 |
3 |
4 | ### Environment
5 |
6 | ```
7 | conda create -n animatelcm python=3.9
8 | conda activate animatelcm
9 | pip install -r requirements.txt
10 | ```
11 |
12 | ### Models
13 |
14 | 1. stable diffusion
15 | ```
16 | cd models/StableDiffusion/
17 | git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
18 | ```
19 | 2. motion_module
20 | ```
21 | cd ..
22 | cd Motion_Module
23 | wget -c https://huggingface.co/wangfuyun/AnimateLCM/resolve/main/AnimateLCM_sd15_t2v.ckpt
24 | wget -c https://huggingface.co/wangfuyun/AnimateLCM-I2V/resolve/main/AnimateLCM_sd15_i2v.ckpt
25 | ```
26 |
27 | 3. spatial_lora
28 | ```
29 | cd ..
30 | cd LCM_LoRA
31 | wget -c https://huggingface.co/wangfuyun/AnimateLCM/resolve/main/AnimateLCM_sd15_t2v_lora.safetensors
32 | wget -c https://huggingface.co/wangfuyun/AnimateLCM-I2V/resolve/main/AnimateLCM_sd15_i2v_lora.safetensors
33 | ```
34 |
35 | 4. personalized models
36 |
37 | You can either download them from the Civitai page or use this [civitai downloader](https://github.com/ashleykleynhans/civitai-downloader). Then put the downloaded models in the `Personalized` folder.
38 |
39 |
40 | ### Inference
41 |
42 | ```
43 | python app.py
44 |
45 | python app-i2v.py
46 | ```
47 |
48 | ### Batch Inference
49 |
50 | ```
51 | python batch_inference.py --config=./configs/batch_inference_example.yaml
52 | ```
--------------------------------------------------------------------------------
/animatelcm_sd15/animatelcm/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/animatelcm/.DS_Store
--------------------------------------------------------------------------------
/animatelcm_sd15/animatelcm/models/adapter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from einops import rearrange
5 | import math
6 | import numpy as np
7 |
8 |
9 | def zero_module(module):
10 | # Zero out the parameters of a module and return it.
11 | for p in module.parameters():
12 | p.detach().zero_()
13 | return module
14 |
15 |
16 | def conv_nd(dims, in_channels, out_channels, kernel_size, **kwargs):
17 | """
18 | Create a 1D, 2D, or 3D convolution module.
19 | """
20 | if dims == 1:
21 | return nn.Conv1d(in_channels, out_channels, kernel_size, **kwargs)
22 | elif dims == 2:
23 | return nn.Conv2d(in_channels, out_channels, kernel_size, **kwargs)
24 | elif dims == 3:
25 | if isinstance(kernel_size, int):
26 | kernel_size = (1, *((kernel_size,) * 2))
27 | if 'stride' in kwargs.keys():
28 | if isinstance(kwargs['stride'], int):
29 | kwargs['stride'] = (1, *((kwargs['stride'],) * 2))
30 | if 'padding' in kwargs.keys():
31 | if isinstance(kwargs['padding'], int):
32 | kwargs['padding'] = (0, *((kwargs['padding'],) * 2))
33 | return nn.Conv3d(in_channels, out_channels, kernel_size, **kwargs)
34 | raise ValueError(f"unsupported dimensions: {dims}")
35 |
36 |
37 | def avg_pool_nd(dims, *args, **kwargs):
38 | """
39 | Create a 1D, 2D, or 3D average pooling module.
40 | """
41 | if dims == 1:
42 | return nn.AvgPool1d(*args, **kwargs)
43 | elif dims == 2:
44 | return nn.AvgPool2d(*args, **kwargs)
45 | elif dims == 3:
46 | return nn.AvgPool3d(*args, **kwargs)
47 | raise ValueError(f"unsupported dimensions: {dims}")
48 |
49 |
50 | def fixed_positional_embedding(t, d_model):
51 | position = torch.arange(0, t, dtype=torch.float).unsqueeze(1)
52 | div_term = torch.exp(torch.arange(0, d_model, 2).float()
53 | * (-np.log(10000.0) / d_model))
54 | pos_embedding = torch.zeros(t, d_model)
55 | pos_embedding[:, 0::2] = torch.sin(position * div_term)
56 | pos_embedding[:, 1::2] = torch.cos(position * div_term)
57 | return pos_embedding
58 |
59 |
60 | class Adapter(nn.Module):
61 | def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True):
62 | super(Adapter, self).__init__()
63 | self.channels = channels
64 | self.nums_rb = nums_rb
65 | self.body = []
66 | for i in range(len(channels)):
67 | for j in range(nums_rb):
68 | if (i != 0) and (j == 0):
69 | self.body.append(ResnetBlock(
70 | channels[i-1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
71 | else:
72 | self.body.append(ResnetBlock(
73 | channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
74 | self.body = nn.ModuleList(self.body)
75 | self.conv_in = zero_module(nn.Conv2d(cin, channels[0], 3, 1, 1))
76 | self.motion_scale = 0.8
77 | self.insertion_weights = [1., 1., 1., 1.]
78 |
79 | self.d_model = channels[0]
80 |
81 | def forward(self, x):
82 | b, c, t, h, w = x.shape
83 | x = rearrange(x, 'b c t h w -> (b t) c h w')
84 |
85 | features = []
86 | x = self.conv_in(x)
87 |
88 | pos_embedding = fixed_positional_embedding(
89 | t, self.d_model).to(x.device)
90 | pos_embedding = pos_embedding.unsqueeze(-1).unsqueeze(-1)
91 | pos_embedding = pos_embedding.expand(-1, -1, h, w)
92 |
93 | x_pos = pos_embedding.repeat(b, 1, 1, 1)
94 |
95 | x = self.motion_scale*x_pos + x
96 |
97 | for i in range(len(self.channels)):
98 | for j in range(self.nums_rb):
99 | idx = i*self.nums_rb + j
100 | x = self.body[idx](x)
101 | features.append(x)
102 | features = [weight*rearrange(fn, '(b t) c h w -> b c t h w', b=b, t=t)
103 | for fn, weight in zip(features, self.insertion_weights)]
104 | return features
105 |
106 |
107 | class ResnetBlock(nn.Module):
108 | def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
109 | super().__init__()
110 | ps = ksize//2
111 | if in_c != out_c or sk == False:
112 | self.in_conv = zero_module(nn.Conv2d(in_c, out_c, ksize, 1, ps))
113 | else:
114 | self.in_conv = None
115 | self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
116 | self.act = nn.ReLU()
117 | self.block2 = zero_module(nn.Conv2d(out_c, out_c, ksize, 1, ps))
118 | if sk == False:
119 | self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
120 | else:
121 | self.skep = None
122 |
123 | self.down = down
124 | if self.down == True:
125 | self.down_opt = Downsample(in_c, use_conv=use_conv)
126 |
127 | def forward(self, x):
128 | if self.down == True:
129 | x = self.down_opt(x)
130 |
131 | if self.in_conv is not None:
132 | x = self.in_conv(x)
133 |
134 | h = self.block1(x)
135 | h = self.act(h)
136 | h = self.block2(h)
137 |
138 | if self.skep is not None:
139 | return h + self.skep(x)
140 | else:
141 | return h + x
142 |
143 |
144 | class Downsample(nn.Module):
145 | """
146 | A downsampling layer with an optional convolution.
147 | :param channels: channels in the inputs and outputs.
148 | :param use_conv: a bool determining if a convolution is applied.
149 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
150 | downsampling occurs in the inner-two dimensions.
151 | """
152 |
153 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
154 | super().__init__()
155 | self.channels = channels
156 | self.out_channels = out_channels or channels
157 | self.use_conv = use_conv
158 | self.dims = dims
159 | stride = 2 if dims != 3 else (1, 2, 2)
160 | if use_conv:
161 | self.op = conv_nd(
162 | dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
163 | )
164 | else:
165 | assert self.channels == self.out_channels
166 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
167 |
168 | def forward(self, x):
169 | assert x.shape[1] == self.channels
170 |
171 | kernel_size = (2, 2)
172 |
173 | input_height, input_width = x.size(2), x.size(3)
174 |
175 | padding_height = (
176 | math.ceil(input_height / kernel_size[0]) * kernel_size[0]) - input_height
177 | padding_width = (
178 | math.ceil(input_width / kernel_size[1]) * kernel_size[1]) - input_width
179 |
180 | x = F.pad(x, (0, padding_width, 0, padding_height), mode='replicate')
181 |
182 | return self.op(x)
183 |
--------------------------------------------------------------------------------
/animatelcm_sd15/animatelcm/models/attention.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 |
8 | from diffusers.configuration_utils import ConfigMixin, register_to_config
9 | from diffusers.modeling_utils import ModelMixin
10 | from diffusers.utils import BaseOutput
11 | from diffusers.utils.import_utils import is_xformers_available
12 | from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
13 |
14 | from einops import rearrange, repeat
15 |
16 | @dataclass
17 | class Transformer3DModelOutput(BaseOutput):
18 | sample: torch.FloatTensor
19 |
20 |
21 | if is_xformers_available():
22 | import xformers
23 | import xformers.ops
24 | else:
25 | xformers = None
26 |
27 |
28 | class Transformer3DModel(ModelMixin, ConfigMixin):
29 | @register_to_config
30 | def __init__(
31 | self,
32 | num_attention_heads: int = 16,
33 | attention_head_dim: int = 88,
34 | in_channels: Optional[int] = None,
35 | num_layers: int = 1,
36 | dropout: float = 0.0,
37 | norm_num_groups: int = 32,
38 | cross_attention_dim: Optional[int] = None,
39 | attention_bias: bool = False,
40 | activation_fn: str = "geglu",
41 | num_embeds_ada_norm: Optional[int] = None,
42 | use_linear_projection: bool = False,
43 | only_cross_attention: bool = False,
44 | upcast_attention: bool = False,
45 |
46 | unet_use_cross_frame_attention=None,
47 | unet_use_temporal_attention=None,
48 | ):
49 | super().__init__()
50 | self.use_linear_projection = use_linear_projection
51 | self.num_attention_heads = num_attention_heads
52 | self.attention_head_dim = attention_head_dim
53 | inner_dim = num_attention_heads * attention_head_dim
54 |
55 | # Define input layers
56 | self.in_channels = in_channels
57 |
58 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
59 | if use_linear_projection:
60 | self.proj_in = nn.Linear(in_channels, inner_dim)
61 | else:
62 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
63 |
64 | # Define transformers blocks
65 | self.transformer_blocks = nn.ModuleList(
66 | [
67 | BasicTransformerBlock(
68 | inner_dim,
69 | num_attention_heads,
70 | attention_head_dim,
71 | dropout=dropout,
72 | cross_attention_dim=cross_attention_dim,
73 | activation_fn=activation_fn,
74 | num_embeds_ada_norm=num_embeds_ada_norm,
75 | attention_bias=attention_bias,
76 | only_cross_attention=only_cross_attention,
77 | upcast_attention=upcast_attention,
78 |
79 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
80 | unet_use_temporal_attention=unet_use_temporal_attention,
81 | )
82 | for d in range(num_layers)
83 | ]
84 | )
85 |
86 | # 4. Define output layers
87 | if use_linear_projection:
88 | self.proj_out = nn.Linear(in_channels, inner_dim)
89 | else:
90 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
91 |
92 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
93 | # Input
94 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
95 | video_length = hidden_states.shape[2]
96 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
97 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
98 |
99 | batch, channel, height, weight = hidden_states.shape
100 | residual = hidden_states
101 |
102 | hidden_states = self.norm(hidden_states)
103 | if not self.use_linear_projection:
104 | hidden_states = self.proj_in(hidden_states)
105 | inner_dim = hidden_states.shape[1]
106 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
107 | else:
108 | inner_dim = hidden_states.shape[1]
109 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
110 | hidden_states = self.proj_in(hidden_states)
111 |
112 | # Blocks
113 | for block in self.transformer_blocks:
114 | hidden_states = block(
115 | hidden_states,
116 | encoder_hidden_states=encoder_hidden_states,
117 | timestep=timestep,
118 | video_length=video_length
119 | )
120 |
121 | # Output
122 | if not self.use_linear_projection:
123 | hidden_states = (
124 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
125 | )
126 | hidden_states = self.proj_out(hidden_states)
127 | else:
128 | hidden_states = self.proj_out(hidden_states)
129 | hidden_states = (
130 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
131 | )
132 |
133 | output = hidden_states + residual
134 |
135 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
136 | if not return_dict:
137 | return (output,)
138 |
139 | return Transformer3DModelOutput(sample=output)
140 |
141 |
142 | class BasicTransformerBlock(nn.Module):
143 | def __init__(
144 | self,
145 | dim: int,
146 | num_attention_heads: int,
147 | attention_head_dim: int,
148 | dropout=0.0,
149 | cross_attention_dim: Optional[int] = None,
150 | activation_fn: str = "geglu",
151 | num_embeds_ada_norm: Optional[int] = None,
152 | attention_bias: bool = False,
153 | only_cross_attention: bool = False,
154 | upcast_attention: bool = False,
155 |
156 | unet_use_cross_frame_attention = None,
157 | unet_use_temporal_attention = None,
158 | ):
159 | super().__init__()
160 | self.only_cross_attention = only_cross_attention
161 | self.use_ada_layer_norm = num_embeds_ada_norm is not None
162 | self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
163 | self.unet_use_temporal_attention = unet_use_temporal_attention
164 |
165 | # SC-Attn
166 | assert unet_use_cross_frame_attention is not None
167 | if unet_use_cross_frame_attention:
168 | self.attn1 = SparseCausalAttention2D(
169 | query_dim=dim,
170 | heads=num_attention_heads,
171 | dim_head=attention_head_dim,
172 | dropout=dropout,
173 | bias=attention_bias,
174 | cross_attention_dim=cross_attention_dim if only_cross_attention else None,
175 | upcast_attention=upcast_attention,
176 | )
177 | else:
178 | self.attn1 = CrossAttention(
179 | query_dim=dim,
180 | heads=num_attention_heads,
181 | dim_head=attention_head_dim,
182 | dropout=dropout,
183 | bias=attention_bias,
184 | upcast_attention=upcast_attention,
185 | )
186 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
187 |
188 | # Cross-Attn
189 | if cross_attention_dim is not None:
190 | self.attn2 = CrossAttention(
191 | query_dim=dim,
192 | cross_attention_dim=cross_attention_dim,
193 | heads=num_attention_heads,
194 | dim_head=attention_head_dim,
195 | dropout=dropout,
196 | bias=attention_bias,
197 | upcast_attention=upcast_attention,
198 | )
199 | else:
200 | self.attn2 = None
201 |
202 | if cross_attention_dim is not None:
203 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
204 | else:
205 | self.norm2 = None
206 |
207 | # Feed-forward
208 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
209 | self.norm3 = nn.LayerNorm(dim)
210 |
211 | # Temp-Attn
212 | assert unet_use_temporal_attention is not None
213 | if unet_use_temporal_attention:
214 | self.attn_temp = CrossAttention(
215 | query_dim=dim,
216 | heads=num_attention_heads,
217 | dim_head=attention_head_dim,
218 | dropout=dropout,
219 | bias=attention_bias,
220 | upcast_attention=upcast_attention,
221 | )
222 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
223 | self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
224 |
225 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
226 | if not is_xformers_available():
227 | raise ModuleNotFoundError(
228 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
229 | " xformers",
230 | name="xformers",
231 | )
232 | elif not torch.cuda.is_available():
233 | raise ValueError(
234 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
235 | " available for GPU "
236 | )
237 | else:
238 | try:
239 | # Make sure we can run the memory efficient attention
240 | _ = xformers.ops.memory_efficient_attention(
241 | torch.randn((1, 2, 40), device="cuda"),
242 | torch.randn((1, 2, 40), device="cuda"),
243 | torch.randn((1, 2, 40), device="cuda"),
244 | )
245 | except Exception as e:
246 | raise e
247 | self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
248 | if self.attn2 is not None:
249 | self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
250 | # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
251 |
252 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
253 | # SparseCausal-Attention
254 | norm_hidden_states = (
255 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
256 | )
257 |
258 | # if self.only_cross_attention:
259 | # hidden_states = (
260 | # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
261 | # )
262 | # else:
263 | # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
264 |
265 | # pdb.set_trace()
266 | if self.unet_use_cross_frame_attention:
267 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
268 | else:
269 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
270 |
271 | if self.attn2 is not None:
272 | # Cross-Attention
273 | norm_hidden_states = (
274 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
275 | )
276 | hidden_states = (
277 | self.attn2(
278 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
279 | )
280 | + hidden_states
281 | )
282 |
283 | # Feed-forward
284 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
285 |
286 | # Temporal-Attention
287 | if self.unet_use_temporal_attention:
288 | d = hidden_states.shape[1]
289 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
290 | norm_hidden_states = (
291 | self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
292 | )
293 | hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
294 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
295 |
296 | return hidden_states
297 |
--------------------------------------------------------------------------------
/animatelcm_sd15/animatelcm/models/embeddings.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import math
15 |
16 | import numpy as np
17 | import torch
18 | from torch import nn
19 |
20 |
21 | def get_timestep_embedding(
22 | timesteps: torch.Tensor,
23 | embedding_dim: int,
24 | flip_sin_to_cos: bool = False,
25 | downscale_freq_shift: float = 1,
26 | scale: float = 1,
27 | max_period: int = 10000,
28 | ):
29 | """
30 | This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
31 |
32 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
33 | These may be fractional.
34 | :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
35 | embeddings. :return: an [N x dim] Tensor of positional embeddings.
36 | """
37 | assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
38 |
39 | half_dim = embedding_dim // 2
40 | exponent = -math.log(max_period) * torch.arange(
41 | start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
42 | )
43 | exponent = exponent / (half_dim - downscale_freq_shift)
44 |
45 | emb = torch.exp(exponent)
46 | emb = timesteps[:, None].float() * emb[None, :]
47 |
48 | # scale embeddings
49 | emb = scale * emb
50 |
51 | # concat sine and cosine embeddings
52 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
53 |
54 | # flip sine and cosine embeddings
55 | if flip_sin_to_cos:
56 | emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
57 |
58 | # zero pad
59 | if embedding_dim % 2 == 1:
60 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
61 | return emb
62 |
63 | def zero_module(module):
64 | # Zero out the parameters of a module and return it.
65 | for p in module.parameters():
66 | p.detach().zero_()
67 | return module
68 |
69 | class TimestepEmbedding(nn.Module):
70 | def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None, time_cond_proj_dim=None):
71 | super().__init__()
72 |
73 | self.linear_1 = nn.Linear(in_channels, time_embed_dim)
74 | self.act = None
75 | if act_fn == "silu":
76 | self.act = nn.SiLU()
77 | elif act_fn == "mish":
78 | self.act = nn.Mish()
79 |
80 | if time_cond_proj_dim is not None:
81 | self.cond_proj = zero_module(nn.Linear(time_cond_proj_dim, in_channels, bias=False))
82 | else:
83 | self.cond_proj = None
84 |
85 |
86 | if out_dim is not None:
87 | time_embed_dim_out = out_dim
88 | else:
89 | time_embed_dim_out = time_embed_dim
90 | self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
91 |
92 | def forward(self, sample, condition=None):
93 | if condition is not None:
94 | sample = sample + self.cond_proj(condition)
95 | sample = self.linear_1(sample)
96 |
97 | if self.act is not None:
98 | sample = self.act(sample)
99 |
100 | sample = self.linear_2(sample)
101 | return sample
102 |
103 |
104 | class Timesteps(nn.Module):
105 | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
106 | super().__init__()
107 | self.num_channels = num_channels
108 | self.flip_sin_to_cos = flip_sin_to_cos
109 | self.downscale_freq_shift = downscale_freq_shift
110 |
111 | def forward(self, timesteps):
112 | t_emb = get_timestep_embedding(
113 | timesteps,
114 | self.num_channels,
115 | flip_sin_to_cos=self.flip_sin_to_cos,
116 | downscale_freq_shift=self.downscale_freq_shift,
117 | )
118 | return t_emb
119 |
120 |
121 | class GaussianFourierProjection(nn.Module):
122 | """Gaussian Fourier embeddings for noise levels."""
123 |
124 | def __init__(
125 | self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
126 | ):
127 | super().__init__()
128 | self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
129 | self.log = log
130 | self.flip_sin_to_cos = flip_sin_to_cos
131 |
132 | if set_W_to_weight:
133 | # to delete later
134 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
135 |
136 | self.weight = self.W
137 |
138 | def forward(self, x):
139 | if self.log:
140 | x = torch.log(x)
141 |
142 | x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
143 |
144 | if self.flip_sin_to_cos:
145 | out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
146 | else:
147 | out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
148 | return out
149 |
150 |
151 | class ImagePositionalEmbeddings(nn.Module):
152 | """
153 | Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
154 | height and width of the latent space.
155 |
156 | For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
157 |
158 | For VQ-diffusion:
159 |
160 | Output vector embeddings are used as input for the transformer.
161 |
162 | Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
163 |
164 | Args:
165 | num_embed (`int`):
166 | Number of embeddings for the latent pixels embeddings.
167 | height (`int`):
168 | Height of the latent image i.e. the number of height embeddings.
169 | width (`int`):
170 | Width of the latent image i.e. the number of width embeddings.
171 | embed_dim (`int`):
172 | Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
173 | """
174 |
175 | def __init__(
176 | self,
177 | num_embed: int,
178 | height: int,
179 | width: int,
180 | embed_dim: int,
181 | ):
182 | super().__init__()
183 |
184 | self.height = height
185 | self.width = width
186 | self.num_embed = num_embed
187 | self.embed_dim = embed_dim
188 |
189 | self.emb = nn.Embedding(self.num_embed, embed_dim)
190 | self.height_emb = nn.Embedding(self.height, embed_dim)
191 | self.width_emb = nn.Embedding(self.width, embed_dim)
192 |
193 | def forward(self, index):
194 | emb = self.emb(index)
195 |
196 | height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
197 |
198 | # 1 x H x D -> 1 x H x 1 x D
199 | height_emb = height_emb.unsqueeze(2)
200 |
201 | width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
202 |
203 | # 1 x W x D -> 1 x 1 x W x D
204 | width_emb = width_emb.unsqueeze(1)
205 |
206 | pos_emb = height_emb + width_emb
207 |
208 | # 1 x H x W x D -> 1 x L xD
209 | pos_emb = pos_emb.view(1, self.height * self.width, -1)
210 |
211 | emb = emb + pos_emb[:, : emb.shape[1], :]
212 |
213 | return emb
214 |
--------------------------------------------------------------------------------
/animatelcm_sd15/animatelcm/models/motion_module.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional, Tuple, Union
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 |
8 | from diffusers.configuration_utils import ConfigMixin, register_to_config
9 | from diffusers.modeling_utils import ModelMixin
10 | from diffusers.utils import BaseOutput
11 | from diffusers.utils.import_utils import is_xformers_available
12 | from diffusers.models.attention import CrossAttention, FeedForward
13 |
14 | from einops import rearrange, repeat
15 | import math
16 |
17 |
18 | def zero_module(module):
19 | for p in module.parameters():
20 | p.detach().zero_()
21 | return module
22 |
23 |
24 | @dataclass
25 | class TemporalTransformer3DModelOutput(BaseOutput):
26 | sample: torch.FloatTensor
27 |
28 |
29 | if is_xformers_available():
30 | import xformers
31 | import xformers.ops
32 | else:
33 | xformers = None
34 |
35 |
36 | def get_motion_module(
37 | in_channels,
38 | motion_module_type: str,
39 | motion_module_kwargs: dict
40 | ):
41 | if motion_module_type == "Vanilla":
42 | return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
43 | else:
44 | raise ValueError
45 |
46 |
47 | class VanillaTemporalModule(nn.Module):
48 | def __init__(
49 | self,
50 | in_channels,
51 | num_attention_heads=8,
52 | num_transformer_block=2,
53 | attention_block_types=("Temporal_Self", "Temporal_Self"),
54 | cross_frame_attention_mode=None,
55 | temporal_position_encoding=False,
56 | temporal_attention_dim_div=1,
57 | zero_initialize=True,
58 | ):
59 | super().__init__()
60 |
61 | self.temporal_transformer = TemporalTransformer3DModel(
62 | in_channels=in_channels,
63 | num_attention_heads=num_attention_heads,
64 | attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
65 | num_layers=num_transformer_block,
66 | attention_block_types=attention_block_types,
67 | cross_frame_attention_mode=cross_frame_attention_mode,
68 | temporal_position_encoding=temporal_position_encoding,
69 | )
70 |
71 | if zero_initialize:
72 | self.temporal_transformer.proj_out = zero_module(
73 | self.temporal_transformer.proj_out)
74 |
75 | def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
76 | hidden_states = input_tensor
77 | hidden_states = self.temporal_transformer(
78 | hidden_states, encoder_hidden_states, attention_mask)
79 |
80 | output = hidden_states
81 | return output
82 |
83 |
84 | class TemporalTransformer3DModel(nn.Module):
85 | def __init__(
86 | self,
87 | in_channels,
88 | num_attention_heads,
89 | attention_head_dim,
90 |
91 | num_layers,
92 | attention_block_types=("Temporal_Self", "Temporal_Self", ),
93 | dropout=0.0,
94 | norm_num_groups=32,
95 | cross_attention_dim=768,
96 | activation_fn="geglu",
97 | attention_bias=False,
98 | upcast_attention=False,
99 |
100 | cross_frame_attention_mode=None,
101 | temporal_position_encoding=False,
102 | ):
103 | super().__init__()
104 |
105 | inner_dim = num_attention_heads * attention_head_dim
106 |
107 | self.norm = torch.nn.GroupNorm(
108 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
109 | self.proj_in = nn.Linear(in_channels, inner_dim)
110 |
111 | self.transformer_blocks = nn.ModuleList(
112 | [
113 | TemporalTransformerBlock(
114 | dim=inner_dim,
115 | num_attention_heads=num_attention_heads,
116 | attention_head_dim=attention_head_dim,
117 | attention_block_types=attention_block_types,
118 | dropout=dropout,
119 | norm_num_groups=norm_num_groups,
120 | cross_attention_dim=cross_attention_dim,
121 | activation_fn=activation_fn,
122 | attention_bias=attention_bias,
123 | upcast_attention=upcast_attention,
124 | cross_frame_attention_mode=cross_frame_attention_mode,
125 | temporal_position_encoding=temporal_position_encoding,
126 | )
127 | for d in range(num_layers)
128 | ]
129 | )
130 | self.proj_out = nn.Linear(inner_dim, in_channels)
131 |
132 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
133 | assert hidden_states.dim(
134 | ) == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
135 | video_length = hidden_states.shape[2]
136 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
137 |
138 | batch, channel, height, weight = hidden_states.shape
139 | residual = hidden_states
140 |
141 | hidden_states = self.norm(hidden_states)
142 | inner_dim = hidden_states.shape[1]
143 | hidden_states = hidden_states.permute(
144 | 0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
145 | hidden_states = self.proj_in(hidden_states)
146 |
147 | # Transformer Blocks
148 | for block in self.transformer_blocks:
149 | hidden_states = block(
150 | hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
151 |
152 | # output
153 | hidden_states = self.proj_out(hidden_states)
154 | hidden_states = hidden_states.reshape(
155 | batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
156 |
157 | output = hidden_states + residual
158 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
159 |
160 | return output
161 |
162 |
163 | class TemporalTransformerBlock(nn.Module):
164 | def __init__(
165 | self,
166 | dim,
167 | num_attention_heads,
168 | attention_head_dim,
169 | attention_block_types=("Temporal_Self", "Temporal_Self", ),
170 | dropout=0.0,
171 | norm_num_groups=32,
172 | cross_attention_dim=768,
173 | activation_fn="geglu",
174 | attention_bias=False,
175 | upcast_attention=False,
176 | cross_frame_attention_mode=None,
177 | temporal_position_encoding=False,
178 | ):
179 | super().__init__()
180 |
181 | attention_blocks = []
182 | norms = []
183 |
184 | for block_name in attention_block_types:
185 | attention_blocks.append(
186 | VersatileAttention(
187 | attention_mode=block_name.split("_")[0],
188 | cross_attention_dim=cross_attention_dim if block_name.endswith(
189 | "_Cross") else None,
190 |
191 | query_dim=dim,
192 | heads=num_attention_heads,
193 | dim_head=attention_head_dim,
194 | dropout=dropout,
195 | bias=attention_bias,
196 | upcast_attention=upcast_attention,
197 |
198 | cross_frame_attention_mode=cross_frame_attention_mode,
199 | temporal_position_encoding=temporal_position_encoding,
200 | )
201 | )
202 | norms.append(nn.LayerNorm(dim))
203 |
204 | self.attention_blocks = nn.ModuleList(attention_blocks)
205 | self.norms = nn.ModuleList(norms)
206 |
207 | self.ff = FeedForward(dim, dropout=dropout,
208 | activation_fn=activation_fn)
209 | self.ff_norm = nn.LayerNorm(dim)
210 |
211 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
212 | for attention_block, norm in zip(self.attention_blocks, self.norms):
213 | norm_hidden_states = norm(hidden_states)
214 | hidden_states = attention_block(
215 | norm_hidden_states,
216 | encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
217 | video_length=video_length,
218 | ) + hidden_states
219 |
220 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
221 |
222 | output = hidden_states
223 | return output
224 |
225 |
226 | class PositionalEncoding(nn.Module):
227 | def __init__(
228 | self,
229 | d_model,
230 | dropout=0.,
231 | ):
232 | super().__init__()
233 |
234 | max_length = 64
235 | self.dropout = nn.Dropout(p=dropout)
236 | position = torch.arange(max_length).unsqueeze(1)
237 | div_term = torch.exp(torch.arange(0, d_model, 2)
238 | * (-math.log(10000.0) / d_model))
239 | pe = torch.zeros(1, max_length, d_model)
240 | pe[0, :, 0::2] = torch.sin(position * div_term)
241 | pe[0, :, 1::2] = torch.cos(position * div_term)
242 | self.register_buffer('pos_encoding', pe)
243 |
244 | def forward(self, x):
245 | x = x + self.pos_encoding[:, :x.size(1)]
246 | return self.dropout(x)
247 |
248 |
249 | class VersatileAttention(CrossAttention):
250 | def __init__(
251 | self,
252 | attention_mode=None,
253 | cross_frame_attention_mode=None,
254 | temporal_position_encoding=False,
255 | *args, **kwargs
256 | ):
257 | super().__init__(*args, **kwargs)
258 | assert attention_mode == "Temporal"
259 |
260 | self.attention_mode = attention_mode
261 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None
262 |
263 | self.pos_encoder = PositionalEncoding(
264 | kwargs["query_dim"],
265 | dropout=0.,
266 | ) if (temporal_position_encoding and attention_mode == "Temporal") else None
267 |
268 | def extra_repr(self):
269 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
270 |
271 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
272 | batch_size, sequence_length, _ = hidden_states.shape
273 |
274 | if self.attention_mode == "Temporal":
275 | d = hidden_states.shape[1]
276 | hidden_states = rearrange(
277 | hidden_states, "(b f) d c -> (b d) f c", f=video_length)
278 |
279 | if self.pos_encoder is not None:
280 | hidden_states = self.pos_encoder(hidden_states)
281 |
282 | encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c",
283 | d=d) if encoder_hidden_states is not None else encoder_hidden_states
284 | else:
285 | raise NotImplementedError
286 |
287 | encoder_hidden_states = encoder_hidden_states
288 |
289 | if self.group_norm is not None:
290 | hidden_states = self.group_norm(
291 | hidden_states.transpose(1, 2)).transpose(1, 2)
292 |
293 | query = self.to_q(hidden_states)
294 | dim = query.shape[-1]
295 | query = self.reshape_heads_to_batch_dim(query)
296 |
297 | if self.added_kv_proj_dim is not None:
298 | raise NotImplementedError
299 |
300 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
301 | key = self.to_k(encoder_hidden_states)
302 | value = self.to_v(encoder_hidden_states)
303 |
304 | key = self.reshape_heads_to_batch_dim(key)
305 | value = self.reshape_heads_to_batch_dim(value)
306 |
307 | if attention_mask is not None:
308 | if attention_mask.shape[-1] != query.shape[1]:
309 | target_length = query.shape[1]
310 | attention_mask = F.pad(
311 | attention_mask, (0, target_length), value=0.0)
312 | attention_mask = attention_mask.repeat_interleave(
313 | self.heads, dim=0)
314 |
315 | if self._use_memory_efficient_attention_xformers:
316 | hidden_states = self._memory_efficient_attention_xformers(
317 | query, key, value, attention_mask)
318 | hidden_states = hidden_states.to(query.dtype)
319 | else:
320 | if self._slice_size is None or query.shape[0] // self._slice_size == 1:
321 | hidden_states = self._attention(
322 | query, key, value, attention_mask)
323 | else:
324 | hidden_states = self._sliced_attention(
325 | query, key, value, sequence_length, dim, attention_mask)
326 |
327 | # linear proj
328 | hidden_states = self.to_out[0](hidden_states)
329 |
330 | # dropout
331 | hidden_states = self.to_out[1](hidden_states)
332 |
333 | if self.attention_mode == "Temporal":
334 | hidden_states = rearrange(
335 | hidden_states, "(b d) f c -> (b f) d c", d=d)
336 |
337 | return hidden_states
338 |
--------------------------------------------------------------------------------
/animatelcm_sd15/animatelcm/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from typing import Optional
7 |
8 | from einops import rearrange
9 |
10 |
11 | class InflatedConv3d(nn.Conv2d):
12 | def forward(self, x):
13 | video_length = x.shape[2]
14 |
15 | x = rearrange(x, "b c f h w -> (b f) c h w")
16 | x = super().forward(x)
17 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
18 |
19 | return x
20 |
21 |
22 | class InflatedGroupNorm(nn.GroupNorm):
23 | def forward(self, x):
24 | video_length = x.shape[2]
25 |
26 | x = rearrange(x, "b c f h w -> (b f) c h w")
27 | x = super().forward(x)
28 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
29 |
30 | return x
31 |
32 |
33 | class Upsample3D(nn.Module):
34 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
35 | super().__init__()
36 | self.channels = channels
37 | self.out_channels = out_channels or channels
38 | self.use_conv = use_conv
39 | self.use_conv_transpose = use_conv_transpose
40 | self.name = name
41 |
42 | conv = None
43 | if use_conv_transpose:
44 | raise NotImplementedError
45 | elif use_conv:
46 | self.conv = InflatedConv3d(
47 | self.channels, self.out_channels, 3, padding=1)
48 |
49 | def forward(self, hidden_states, output_size=None):
50 | assert hidden_states.shape[1] == self.channels
51 |
52 | if self.use_conv_transpose:
53 | raise NotImplementedError
54 |
55 |         # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
56 | dtype = hidden_states.dtype
57 | if dtype == torch.bfloat16:
58 | hidden_states = hidden_states.to(torch.float32)
59 |
60 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
61 | if hidden_states.shape[0] >= 64:
62 | hidden_states = hidden_states.contiguous()
63 |
64 | # if `output_size` is passed we force the interpolation output
65 | # size and do not make use of `scale_factor=2`
66 | if output_size is None:
67 | hidden_states = F.interpolate(hidden_states, scale_factor=[
68 | 1.0, 2.0, 2.0], mode="nearest")
69 | else:
70 | hidden_states = F.interpolate(
71 | hidden_states, size=output_size, mode="nearest")
72 |
73 | # If the input is bfloat16, we cast back to bfloat16
74 | if dtype == torch.bfloat16:
75 | hidden_states = hidden_states.to(dtype)
76 |
77 | # if self.use_conv:
78 | # if self.name == "conv":
79 | # hidden_states = self.conv(hidden_states)
80 | # else:
81 | # hidden_states = self.Conv2d_0(hidden_states)
82 | hidden_states = self.conv(hidden_states)
83 |
84 | return hidden_states
85 |
86 |
87 | class Downsample3D(nn.Module):
88 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
89 | super().__init__()
90 | self.channels = channels
91 | self.out_channels = out_channels or channels
92 | self.use_conv = use_conv
93 | self.padding = padding
94 | stride = 2
95 | self.name = name
96 |
97 | if use_conv:
98 | self.conv = InflatedConv3d(
99 | self.channels, self.out_channels, 3, stride=stride, padding=padding)
100 | else:
101 | raise NotImplementedError
102 |
103 | def forward(self, hidden_states):
104 | assert hidden_states.shape[1] == self.channels
105 | if self.use_conv and self.padding == 0:
106 | raise NotImplementedError
107 |
108 | assert hidden_states.shape[1] == self.channels
109 | hidden_states = self.conv(hidden_states)
110 |
111 | return hidden_states
112 |
113 |
114 | class ResnetBlock3D(nn.Module):
115 | def __init__(
116 | self,
117 | *,
118 | in_channels,
119 | out_channels=None,
120 | conv_shortcut=False,
121 | dropout=0.0,
122 | temb_channels=512,
123 | groups=32,
124 | groups_out=None,
125 | pre_norm=True,
126 | eps=1e-6,
127 | non_linearity="swish",
128 | time_embedding_norm="default",
129 | output_scale_factor=1.0,
130 | use_in_shortcut=None,
131 | use_inflated_groupnorm=None,
132 | use_temporal_conv=False,
133 | use_temporal_mixer=False,
134 | ):
135 | super().__init__()
136 | self.pre_norm = pre_norm
137 | self.pre_norm = True
138 | self.in_channels = in_channels
139 | out_channels = in_channels if out_channels is None else out_channels
140 | self.out_channels = out_channels
141 | self.use_conv_shortcut = conv_shortcut
142 | self.time_embedding_norm = time_embedding_norm
143 | self.output_scale_factor = output_scale_factor
144 | self.use_temporal_mixer = use_temporal_mixer
145 | if use_temporal_mixer:
146 | self.temporal_mixer = AlphaBlender(0.3, "learned", None)
147 |
148 | if groups_out is None:
149 | groups_out = groups
150 |
151 | assert use_inflated_groupnorm != None
152 | if use_inflated_groupnorm:
153 | self.norm1 = InflatedGroupNorm(
154 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
155 | else:
156 | self.norm1 = torch.nn.GroupNorm(
157 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
158 |
159 | if use_temporal_conv:
160 | self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=(
161 | 3, 1, 1), stride=1, padding=(1, 0, 0))
162 | else:
163 | self.conv1 = InflatedConv3d(
164 | in_channels, out_channels, kernel_size=3, stride=1, padding=1)
165 |
166 | if temb_channels is not None:
167 | if self.time_embedding_norm == "default":
168 | time_emb_proj_out_channels = out_channels
169 | elif self.time_embedding_norm == "scale_shift":
170 | time_emb_proj_out_channels = out_channels * 2
171 | else:
172 | raise ValueError(
173 | f"unknown time_embedding_norm : {self.time_embedding_norm} ")
174 |
175 | self.time_emb_proj = torch.nn.Linear(
176 | temb_channels, time_emb_proj_out_channels)
177 | else:
178 | self.time_emb_proj = None
179 |
180 | if use_inflated_groupnorm:
181 | self.norm2 = InflatedGroupNorm(
182 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
183 | else:
184 | self.norm2 = torch.nn.GroupNorm(
185 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
186 |
187 | self.dropout = torch.nn.Dropout(dropout)
188 | if use_temporal_conv:
189 |             self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=(
190 |                 3, 1, 1), stride=1, padding=(1, 0, 0))
191 | else:
192 | self.conv2 = InflatedConv3d(
193 | out_channels, out_channels, kernel_size=3, stride=1, padding=1)
194 |
195 | if non_linearity == "swish":
196 | self.nonlinearity = lambda x: F.silu(x)
197 | elif non_linearity == "mish":
198 | self.nonlinearity = Mish()
199 | elif non_linearity == "silu":
200 | self.nonlinearity = nn.SiLU()
201 |
202 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
203 |
204 | self.conv_shortcut = None
205 | if self.use_in_shortcut:
206 | self.conv_shortcut = InflatedConv3d(
207 | in_channels, out_channels, kernel_size=1, stride=1, padding=0)
208 |
209 | def forward(self, input_tensor, temb):
210 | if self.use_temporal_mixer:
211 | residual = input_tensor
212 |
213 | hidden_states = input_tensor
214 |
215 | hidden_states = self.norm1(hidden_states)
216 | hidden_states = self.nonlinearity(hidden_states)
217 |
218 | hidden_states = self.conv1(hidden_states)
219 |
220 | if temb is not None:
221 | temb = self.time_emb_proj(self.nonlinearity(temb))[
222 | :, :, None, None, None]
223 |
224 | if temb is not None and self.time_embedding_norm == "default":
225 | hidden_states = hidden_states + temb
226 |
227 | hidden_states = self.norm2(hidden_states)
228 |
229 | if temb is not None and self.time_embedding_norm == "scale_shift":
230 | scale, shift = torch.chunk(temb, 2, dim=1)
231 | hidden_states = hidden_states * (1 + scale) + shift
232 |
233 | hidden_states = self.nonlinearity(hidden_states)
234 |
235 | hidden_states = self.dropout(hidden_states)
236 | hidden_states = self.conv2(hidden_states)
237 |
238 | if self.conv_shortcut is not None:
239 | input_tensor = self.conv_shortcut(input_tensor)
240 |
241 | output_tensor = (input_tensor + hidden_states) / \
242 | self.output_scale_factor
243 |
244 | if self.use_temporal_mixer:
245 | output_tensor = self.temporal_mixer(residual, output_tensor, None)
246 | # return residual + 0.0 * self.temporal_mixer(residual, output_tensor, None)
247 | return output_tensor
248 |
249 |
250 | class Mish(torch.nn.Module):
251 | def forward(self, hidden_states):
252 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
253 |
254 |
255 | class AlphaBlender(nn.Module):
256 | strategies = ["learned", "fixed", "learned_with_images"]
257 |
258 | def __init__(
259 | self,
260 | alpha: float,
261 | merge_strategy: str = "learned_with_images",
262 | rearrange_pattern: str = "b t -> (b t) 1 1",
263 | ):
264 | super().__init__()
265 | self.merge_strategy = merge_strategy
266 | self.rearrange_pattern = rearrange_pattern
267 | self.scaler = 10.
268 |
269 | assert (
270 | merge_strategy in self.strategies
271 | ), f"merge_strategy needs to be in {self.strategies}"
272 |
273 | if self.merge_strategy == "fixed":
274 | self.register_buffer("mix_factor", torch.Tensor([alpha]))
275 | elif (
276 | self.merge_strategy == "learned"
277 | or self.merge_strategy == "learned_with_images"
278 | ):
279 | self.register_parameter(
280 | "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
281 | )
282 | else:
283 | raise ValueError(f"unknown merge strategy {self.merge_strategy}")
284 |
285 | def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
286 | if self.merge_strategy == "fixed":
287 | alpha = self.mix_factor
288 | elif self.merge_strategy == "learned":
289 | alpha = torch.sigmoid(self.mix_factor*self.scaler)
290 | elif self.merge_strategy == "learned_with_images":
291 | assert image_only_indicator is not None, "need image_only_indicator ..."
292 | alpha = torch.where(
293 | image_only_indicator.bool(),
294 | torch.ones(1, 1, device=image_only_indicator.device),
295 | rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
296 | )
297 | alpha = rearrange(alpha, self.rearrange_pattern)
298 | else:
299 | raise NotImplementedError
300 | return alpha
301 |
302 | def forward(
303 | self,
304 | x_spatial: torch.Tensor,
305 | x_temporal: torch.Tensor,
306 | image_only_indicator: Optional[torch.Tensor] = None,
307 | ) -> torch.Tensor:
308 | alpha = self.get_alpha(image_only_indicator)
309 | x = (
310 | alpha.to(x_spatial.dtype) * x_spatial
311 | + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
312 | )
313 | return x
314 |
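# --- Usage sketch (added for illustration; shapes and values below are assumptions, not taken
# from this repository). With the "learned" strategy the blend weight is a single
# sigmoid-squashed scalar, so spatial and temporal features of the same shape mix directly:
#
#   blender = AlphaBlender(alpha=0.3, merge_strategy="learned", rearrange_pattern=None)
#   x_spatial = torch.randn(2, 320, 16, 32, 32)    # hypothetical (b, c, t, h, w) features
#   x_temporal = torch.randn(2, 320, 16, 32, 32)
#   mixed = blender(x_spatial, x_temporal)         # sigmoid(10 * mix_factor) weighted blend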
--------------------------------------------------------------------------------
/animatelcm_sd15/animatelcm/utils/convert_lora_safetensor_to_diffusers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Conversion script for LoRA safetensors checkpoints."""
17 |
18 | import argparse
19 |
20 | import torch
21 | from safetensors.torch import load_file
22 |
23 | from diffusers import StableDiffusionPipeline
24 |
25 |
26 | def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
27 | # directly update weight in diffusers model
28 | for key in state_dict:
29 | # only process lora down key
30 | if "up." in key: continue
31 |
32 | up_key = key.replace(".down.", ".up.")
33 | model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
34 | model_key = model_key.replace("to_out.", "to_out.0.")
35 | layer_infos = model_key.split(".")[:-1]
36 |
37 | curr_layer = pipeline.unet
38 | while len(layer_infos) > 0:
39 | temp_name = layer_infos.pop(0)
40 | curr_layer = curr_layer.__getattr__(temp_name)
41 |
42 | weight_down = state_dict[key]
43 | weight_up = state_dict[up_key]
44 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
45 |
46 | return pipeline
47 |
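# Key-mapping sketch (hypothetical key, for illustration only): a motion-LoRA entry such as
#   "...motion_modules.0.attention_blocks.0.processor.to_q_lora.down.weight"
# has "processor.", "_lora" and the "down."/"up." markers stripped above to recover the
# diffusers module path, and the merged update is the low-rank product
#   delta_W = alpha * (W_up @ W_down)
# added in place onto that module's weight.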
48 |
49 |
50 | def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
51 | # load base model
52 | # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
53 |
54 | # load LoRA weight from .safetensors
55 | # state_dict = load_file(checkpoint_path)
56 |
57 | visited = []
58 |
59 | # directly update weight in diffusers model
60 | for key in state_dict:
61 | # it is suggested to print out the key, it usually will be something like below
62 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
63 |
64 | # as we have set the alpha beforehand, so just skip
65 | if ".alpha" in key or key in visited:
66 | continue
67 |
68 | if "text" in key:
69 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
70 | curr_layer = pipeline.text_encoder
71 | else:
72 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
73 | curr_layer = pipeline.unet
74 |
75 | # find the target layer
76 | temp_name = layer_infos.pop(0)
77 | while len(layer_infos) > -1:
78 | try:
79 | curr_layer = curr_layer.__getattr__(temp_name)
80 | if len(layer_infos) > 0:
81 | temp_name = layer_infos.pop(0)
82 | elif len(layer_infos) == 0:
83 | break
84 | except Exception:
85 | if len(temp_name) > 0:
86 | temp_name += "_" + layer_infos.pop(0)
87 | else:
88 | temp_name = layer_infos.pop(0)
89 |
90 | pair_keys = []
91 | if "lora_down" in key:
92 | pair_keys.append(key.replace("lora_down", "lora_up"))
93 | pair_keys.append(key)
94 | else:
95 | pair_keys.append(key)
96 | pair_keys.append(key.replace("lora_up", "lora_down"))
97 |
98 | # update weight
99 | if len(state_dict[pair_keys[0]].shape) == 4:
100 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
101 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
102 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
103 | else:
104 | weight_up = state_dict[pair_keys[0]].to(torch.float32)
105 | weight_down = state_dict[pair_keys[1]].to(torch.float32)
106 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
107 |
108 | # update visited list
109 | for item in pair_keys:
110 | visited.append(item)
111 |
112 | return pipeline
113 |
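# Usage sketch (model id and file name are placeholders): merge a kohya-style LoRA into a
# diffusers pipeline in place.
#
#   pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
#   lora_sd = load_file("style_lora.safetensors")
#   pipe = convert_lora(pipe, lora_sd, alpha=0.6)   # adds alpha * (up @ down) to each weight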
114 |
115 | if __name__ == "__main__":
116 | parser = argparse.ArgumentParser()
117 |
118 | parser.add_argument(
119 | "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
120 | )
121 | parser.add_argument(
122 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
123 | )
124 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
125 | parser.add_argument(
126 | "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
127 | )
128 | parser.add_argument(
129 | "--lora_prefix_text_encoder",
130 | default="lora_te",
131 | type=str,
132 | help="The prefix of text encoder weight in safetensors",
133 | )
134 | parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
135 | parser.add_argument(
136 | "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
137 | )
138 | parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
139 |
140 | args = parser.parse_args()
141 |
142 | base_model_path = args.base_model_path
143 | checkpoint_path = args.checkpoint_path
144 | dump_path = args.dump_path
145 | lora_prefix_unet = args.lora_prefix_unet
146 | lora_prefix_text_encoder = args.lora_prefix_text_encoder
147 | alpha = args.alpha
148 |
149 |     pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
150 |     lora_state_dict = load_file(checkpoint_path)
151 |     pipe = convert_lora(pipe, lora_state_dict, lora_prefix_unet, lora_prefix_text_encoder, alpha)
152 |     pipe = pipe.to(args.device)
153 |     pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
--------------------------------------------------------------------------------
/animatelcm_sd15/animatelcm/utils/lcm_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from safetensors import safe_open
4 |
5 |
6 | def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
7 | """
8 | See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
9 |
10 | Args:
11 |         w (`torch.Tensor`):
12 |             guidance scale values to embed, one per sample
13 | embedding_dim (`int`, *optional*, defaults to 512):
14 | dimension of the embeddings to generate
15 | dtype:
16 | data type of the generated embeddings
17 |
18 | Returns:
19 |         `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
20 | """
21 | assert len(w.shape) == 1
22 | w = w * 1000.0
23 |
24 | half_dim = embedding_dim // 2
25 | emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
26 | emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
27 | emb = w.to(dtype)[:, None] * emb[None, :]
28 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
29 | if embedding_dim % 2 == 1: # zero pad
30 | emb = torch.nn.functional.pad(emb, (0, 1))
31 | assert emb.shape == (w.shape[0], embedding_dim)
32 | return emb
33 |
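# Shape sanity check (illustrative): embed a batch of two guidance scales.
#
#   w = torch.tensor([1.5, 7.5])
#   emb = guidance_scale_embedding(w, embedding_dim=256)
#   assert emb.shape == (2, 256)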
34 |
35 | def append_dims(x, target_dims):
36 | """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
37 | dims_to_append = target_dims - x.ndim
38 | if dims_to_append < 0:
39 | raise ValueError(
40 | f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
41 | return x[(...,) + (None,) * dims_to_append]
42 |
43 |
44 | # From LCMScheduler.get_scalings_for_boundary_condition_discrete
45 | def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
46 |     c_skip = sigma_data**2 / ((timestep * timestep_scaling) ** 2 + sigma_data**2)
47 |     c_out = (timestep * timestep_scaling) / ((timestep * timestep_scaling) ** 2 + sigma_data**2) ** 0.5
48 | return c_skip, c_out
49 |
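# Boundary-condition check (illustrative): at timestep 0 the consistency parameterization
# should reduce to the identity, i.e. c_skip -> 1 and c_out -> 0.
#
#   c_skip, c_out = scalings_for_boundary_conditions(torch.tensor([0.0]))
#   # c_skip == 1.0, c_out == 0.0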
50 |
51 | # Compare LCMScheduler.step, Step 4
52 | def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
53 | if prediction_type == "epsilon":
54 | sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
55 | alphas = extract_into_tensor(alphas, timesteps, sample.shape)
56 | pred_x_0 = (sample - sigmas * model_output) / alphas
57 | elif prediction_type == "v_prediction":
58 | sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
59 | alphas = extract_into_tensor(alphas, timesteps, sample.shape)
60 | pred_x_0 = alphas * sample - sigmas * model_output
61 | else:
62 | raise ValueError(
63 | f"Prediction type {prediction_type} currently not supported.")
64 |
65 | return pred_x_0
66 |
67 |
68 | def scale_for_loss(timesteps, sample, prediction_type, alphas, sigmas):
69 | if prediction_type == "epsilon":
70 | sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
71 | alphas = extract_into_tensor(alphas, timesteps, sample.shape)
72 | sample = sample * alphas / sigmas
73 | else:
74 | raise ValueError(
75 | f"Prediction type {prediction_type} currently not supported.")
76 |
77 | return sample
78 |
79 |
80 | def extract_into_tensor(a, t, x_shape):
81 | b, *_ = t.shape
82 | out = a.gather(-1, t)
83 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
84 |
85 |
86 | class DDIMSolver:
87 | def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
88 | # DDIM sampling parameters
89 | step_ratio = timesteps // ddim_timesteps
90 | self.ddim_timesteps = (
91 | np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
92 | # self.ddim_timesteps = (torch.linspace(100**2,1000**2,30)**0.5).round().numpy().astype(np.int64) - 1
93 | self.ddim_timesteps_prev = np.asarray(
94 | [0] + self.ddim_timesteps[:-1].tolist()
95 | )
96 | self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
97 | self.ddim_alpha_cumprods_prev = np.asarray(
98 | [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
99 | )
104 | # convert to torch tensors
105 | self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
106 | self.ddim_timesteps_prev = torch.from_numpy(
107 | self.ddim_timesteps_prev).long()
108 | self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
109 | self.ddim_alpha_cumprods_prev = torch.from_numpy(
110 | self.ddim_alpha_cumprods_prev)
111 |
112 | def to(self, device):
113 | self.ddim_timesteps = self.ddim_timesteps.to(device)
114 | self.ddim_timesteps_prev = self.ddim_timesteps_prev.to(device)
115 | self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
116 | self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(
117 | device)
118 | return self
119 |
120 | def ddim_step(self, pred_x0, pred_noise, timestep_index):
121 | alpha_cumprod_prev = extract_into_tensor(
122 | self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
123 | dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
124 | x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
125 | return x_prev
126 |
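# Usage sketch (names are assumptions; `noise_scheduler` stands for any diffusers scheduler
# exposing `alphas_cumprod`): build the teacher's DDIM skeleton once, move it to the training
# device, then take deterministic DDIM steps from a predicted x0 / noise pair.
#
#   solver = DDIMSolver(noise_scheduler.alphas_cumprod.numpy(), timesteps=1000, ddim_timesteps=50)
#   solver = solver.to(device)
#   x_prev = solver.ddim_step(pred_x0, pred_noise, timestep_index)  # indices into the DDIM schedule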
127 |
128 | @torch.no_grad()
129 | def update_ema(target_params, source_params, rate=0.99):
130 | """
131 | Update target parameters to be closer to those of source parameters using
132 | an exponential moving average.
133 |
134 | :param target_params: the target parameter sequence.
135 | :param source_params: the source parameter sequence.
136 | :param rate: the EMA rate (closer to 1 means slower).
137 | """
138 | for targ, src in zip(target_params, source_params):
139 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
140 |
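# EMA usage sketch (illustrative): keep a target (EMA) copy trailing the online UNet.
#
#   update_ema(target_unet.parameters(), unet.parameters(), rate=0.95)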
141 |
142 | def convert_lcm_lora(unet, path, alpha=1.0):
143 |
144 | if path.endswith(("ckpt",)):
145 | state_dict = torch.load(path, map_location="cpu")
146 | else:
147 | state_dict = {}
148 | with safe_open(path, framework="pt", device="cpu") as f:
149 | for key in f.keys():
150 | state_dict[key] = f.get_tensor(key)
151 |
152 | num_alpha = 0
153 | for key in state_dict.keys():
154 | if "alpha" in key:
155 | num_alpha += 1
156 |
157 | lora_keys = [k for k in state_dict.keys(
158 | ) if k.endswith("lora_down.weight")]
159 |
160 | updated_state_dict = {}
161 | for key in lora_keys:
162 | lora_name = key.split(".")[0]
163 |
164 | if lora_name.startswith("lora_unet_"):
165 | diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
166 |
167 | if "input.blocks" in diffusers_name:
168 | diffusers_name = diffusers_name.replace(
169 | "input.blocks", "down_blocks")
170 | else:
171 | diffusers_name = diffusers_name.replace(
172 | "down.blocks", "down_blocks")
173 |
174 | if "middle.block" in diffusers_name:
175 | diffusers_name = diffusers_name.replace(
176 | "middle.block", "mid_block")
177 | else:
178 | diffusers_name = diffusers_name.replace(
179 | "mid.block", "mid_block")
180 | if "output.blocks" in diffusers_name:
181 | diffusers_name = diffusers_name.replace(
182 | "output.blocks", "up_blocks")
183 | else:
184 | diffusers_name = diffusers_name.replace(
185 | "up.blocks", "up_blocks")
186 |
187 | diffusers_name = diffusers_name.replace(
188 | "transformer.blocks", "transformer_blocks")
189 | diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
190 | diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
191 | diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
192 | diffusers_name = diffusers_name.replace(
193 | "to.out.0.lora", "to_out_lora")
194 | diffusers_name = diffusers_name.replace("proj.in", "proj_in")
195 | diffusers_name = diffusers_name.replace("proj.out", "proj_out")
196 | diffusers_name = diffusers_name.replace(
197 | "time.emb.proj", "time_emb_proj")
198 | diffusers_name = diffusers_name.replace(
199 | "conv.shortcut", "conv_shortcut")
200 |
201 | updated_state_dict[diffusers_name] = state_dict[key]
202 | up_diffusers_name = diffusers_name.replace(".down.", ".up.")
203 | up_key = key.replace("lora_down.weight", "lora_up.weight")
204 | updated_state_dict[up_diffusers_name] = state_dict[up_key]
205 |
206 | state_dict = updated_state_dict
207 |
208 | num_lora = 0
209 | for key in state_dict:
210 | if "up." in key:
211 | continue
212 | up_key = key.replace(".down.", ".up.")
213 | model_key = key.replace("processor.", "").replace("_lora", "").replace(
214 | "down.", "").replace("up.", "").replace(".lora", "")
215 | model_key = model_key.replace("to_out.", "to_out.0.")
216 | layer_infos = model_key.split(".")[:-1]
217 |
218 | curr_layer = unet
219 | while len(layer_infos) > 0:
220 | temp_name = layer_infos.pop(0)
221 | curr_layer = curr_layer.__getattr__(temp_name)
222 |
223 | weight_down = state_dict[key].to(
224 | curr_layer.weight.data.device, curr_layer.weight.data.dtype)
225 | weight_up = state_dict[up_key].to(
226 | curr_layer.weight.data.device, curr_layer.weight.data.dtype)
227 |
228 | if weight_up.ndim == 2:
229 | curr_layer.weight.data += 1/8 * alpha * \
230 | torch.mm(weight_up, weight_down)
231 | else:
232 | assert weight_up.ndim == 4
233 | curr_layer.weight.data += 1/8 * alpha * torch.mm(weight_up.flatten(
234 | start_dim=1), weight_down.flatten(start_dim=1)).reshape(curr_layer.weight.data.shape)
235 | num_lora += 1
236 |
237 | return unet
238 |
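# Usage sketch (path is a placeholder; mirrors how the Gradio apps call this helper): fuse the
# AnimateLCM LoRA into a deep copy of the UNet so the original weights stay untouched. The fixed
# 1/8 factor above appears to act as the LoRA scaling (alpha / rank) for these checkpoints.
#
#   lcm_unet = convert_lcm_lora(copy.deepcopy(unet),
#                               "models/LCM_LoRA/AnimateLCM_sd15_t2v_lora.safetensors",
#                               alpha=0.8)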
--------------------------------------------------------------------------------
/animatelcm_sd15/animatelcm/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import imageio
3 | import numpy as np
4 | from typing import Union
5 |
6 | import torch
7 | import torchvision
8 | import torch.distributed as dist
9 | import cv2
10 | from safetensors import safe_open
11 | from tqdm import tqdm
12 | from PIL import Image
13 |
14 | from einops import rearrange
15 | from animatelcm.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
16 | from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers
17 |
18 |
19 | def resize_and_crop(image_path, target_width, target_height):
20 | image = Image.open(image_path).convert("RGB")
21 |
22 | orig_width, orig_height = image.size
23 |
24 | target_ratio = target_width / target_height
25 | orig_ratio = orig_width / orig_height
26 |
27 | if target_ratio > orig_ratio:
28 | resize_width = target_width
29 | resize_height = round((target_width / orig_width) * orig_height)
30 | else:
31 | resize_height = target_height
32 | resize_width = round((target_height / orig_height) * orig_width)
33 |
34 | resized_image = image.resize((resize_width, resize_height), Image.LANCZOS)
35 |
36 | x0 = (resize_width - target_width) / 2
37 | y0 = (resize_height - target_height) / 2
38 | x1 = x0 + target_width
39 | y1 = y0 + target_height
40 |
41 | cropped_image = resized_image.crop((x0, y0, x1, y1))
42 |
43 | return cropped_image
44 |
45 | def adjust_colors(src_img, target_img):
46 | src_hsv = cv2.cvtColor(src_img, cv2.COLOR_BGR2HSV)
47 | target_hsv = cv2.cvtColor(target_img, cv2.COLOR_BGR2HSV)
48 |
49 | for idx in range(3):
50 | src_mean = np.mean(src_hsv[:, :, idx])
51 | target_mean = np.mean(target_hsv[:, :, idx])
52 |
53 | diff = src_mean - target_mean
54 |
55 | if idx == 0:
56 | target_hsv[:, :, idx] = np.clip(target_hsv[:, :, idx] + 0.03 * diff, 0, 180)
57 | else:
58 | target_hsv[:, :, idx] = np.clip(target_hsv[:, :, idx] + 1.0 * diff, 0, 255)
59 |
60 | adjusted_img = cv2.cvtColor(target_hsv, cv2.COLOR_HSV2BGR)
61 | return adjusted_img
62 |
63 | def zero_rank_print(s):
64 |     if (not dist.is_initialized()) or dist.get_rank() == 0: print("### " + s)
65 |
66 |
67 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
68 | videos = rearrange(videos, "b c t h w -> t b c h w")
69 | outputs = []
70 | for x in videos:
71 | x = torchvision.utils.make_grid(x, nrow=n_rows)
72 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
73 | if rescale:
74 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
75 | x = (x * 255).numpy().astype(np.uint8)
76 | outputs.append(x)
77 |
78 | os.makedirs(os.path.dirname(path), exist_ok=True)
79 | imageio.mimsave(path, outputs, fps=fps)
80 |
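# Usage sketch (shapes illustrative): `videos` is expected as (batch, channels, frames, h, w)
# in [0, 1] (or [-1, 1] with rescale=True); frames are tiled into a grid and written as video.
#
#   fake = torch.rand(1, 3, 16, 256, 256)
#   save_videos_grid(fake, "samples/debug/fake.mp4", fps=8)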
81 |
82 | # DDIM Inversion
83 | @torch.no_grad()
84 | def init_prompt(prompt, pipeline):
85 | uncond_input = pipeline.tokenizer(
86 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
87 | return_tensors="pt"
88 | )
89 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
90 | text_input = pipeline.tokenizer(
91 | [prompt],
92 | padding="max_length",
93 | max_length=pipeline.tokenizer.model_max_length,
94 | truncation=True,
95 | return_tensors="pt",
96 | )
97 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
98 | context = torch.cat([uncond_embeddings, text_embeddings])
99 |
100 | return context
101 |
102 |
103 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
104 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
105 | timestep, next_timestep = min(
106 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
107 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
108 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
109 | beta_prod_t = 1 - alpha_prod_t
110 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
111 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
112 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
113 | return next_sample
114 |
115 |
116 | def get_noise_pred_single(latents, t, context, unet):
117 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
118 | return noise_pred
119 |
120 |
121 | @torch.no_grad()
122 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
123 | context = init_prompt(prompt, pipeline)
124 | uncond_embeddings, cond_embeddings = context.chunk(2)
125 | all_latent = [latent]
126 | latent = latent.clone().detach()
127 | for i in tqdm(range(num_inv_steps)):
128 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
129 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
130 | latent = next_step(noise_pred, t, latent, ddim_scheduler)
131 | all_latent.append(latent)
132 | return all_latent
133 |
134 |
135 | @torch.no_grad()
136 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
137 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
138 | return ddim_latents
139 |
140 | def load_weights(
141 | animation_pipeline,
142 | motion_module_path = "",
143 | motion_module_lora_configs = [],
144 | dreambooth_model_path = "",
145 | lora_model_path = "",
146 | lora_alpha = 0.8,
147 | ):
148 | unet_state_dict = {}
149 | if motion_module_path != "":
150 | print(f"load motion module from {motion_module_path}")
151 | motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
152 | motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
153 | unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
154 |
155 | missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
156 | assert len(unexpected) == 0
157 | del unet_state_dict
158 |
159 | if dreambooth_model_path != "":
160 | print(f"load dreambooth model from {dreambooth_model_path}")
161 | if dreambooth_model_path.endswith(".safetensors"):
162 | dreambooth_state_dict = {}
163 | with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
164 | for key in f.keys():
165 | dreambooth_state_dict[key] = f.get_tensor(key)
166 | elif dreambooth_model_path.endswith(".ckpt"):
167 | dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu")
168 |
169 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config)
170 | animation_pipeline.vae.load_state_dict(converted_vae_checkpoint)
171 |
172 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config)
173 | animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
174 |
175 | animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
176 | del dreambooth_state_dict
177 |
178 | if lora_model_path != "":
179 | print(f"load lora model from {lora_model_path}")
180 | assert lora_model_path.endswith(".safetensors")
181 | lora_state_dict = {}
182 | with safe_open(lora_model_path, framework="pt", device="cpu") as f:
183 | for key in f.keys():
184 | lora_state_dict[key] = f.get_tensor(key)
185 |
186 | animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha)
187 | del lora_state_dict
188 |
189 |
190 | for motion_module_lora_config in motion_module_lora_configs:
191 | path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
192 | print(f"load motion LoRA from {path}")
193 |
194 | motion_lora_state_dict = torch.load(path, map_location="cpu")
195 | motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
196 |
197 | animation_pipeline = convert_motion_lora_ckpt_to_diffusers(animation_pipeline, motion_lora_state_dict, alpha)
198 |
199 | return animation_pipeline
200 |
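# Usage sketch (file names are placeholders matching the repo's `models/` layout): attach the
# motion module and an optional personalized checkpoint to an AnimationPipeline before sampling.
#
#   pipeline = load_weights(
#       pipeline,
#       motion_module_path="models/Motion_Module/AnimateLCM_sd15_t2v.ckpt",
#       dreambooth_model_path="models/Personalized/realistic2.safetensors",
#       lora_model_path="",
#       lora_alpha=0.8,
#   )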
--------------------------------------------------------------------------------
/animatelcm_sd15/app-i2v.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import json
4 | import torch
5 | import random
6 |
7 | import gradio as gr
8 | from glob import glob
9 | from omegaconf import OmegaConf
10 | from datetime import datetime
11 | from safetensors import safe_open
12 |
13 | from diffusers import AutoencoderKL
14 | from diffusers.utils.import_utils import is_xformers_available
15 | from transformers import CLIPTextModel, CLIPTokenizer
16 |
17 | from animatelcm.scheduler.lcm_scheduler import LCMScheduler
18 | from animatelcm.models.unet import UNet3DConditionModel
19 | from animatelcm.pipelines.pipeline_animation import AnimationPipeline
20 | from animatelcm.utils.util import save_videos_grid
21 | from animatelcm.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
22 | from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora
23 | from animatelcm.utils.lcm_utils import convert_lcm_lora
24 | import copy
25 |
26 | sample_idx = 0
27 | scheduler_dict = {
28 | "LCM": LCMScheduler,
29 | }
30 |
31 | css = """
32 | .toolbutton {
33 |     margin-bottom: 0em;
34 | max-width: 2.5em;
35 | min-width: 2.5em !important;
36 | height: 2.5em;
37 | }
38 | """
39 |
40 |
41 | class AnimateController:
42 | def __init__(self):
43 |
44 | # config dirs
45 | self.basedir = os.getcwd()
46 | self.stable_diffusion_dir = os.path.join(
47 | self.basedir, "models", "StableDiffusion")
48 | self.motion_module_dir = os.path.join(
49 | self.basedir, "models", "Motion_Module")
50 | self.personalized_model_dir = os.path.join(
51 | self.basedir, "models", "DreamBooth_LoRA")
52 | self.savedir = os.path.join(
53 | self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
54 | self.savedir_sample = os.path.join(self.savedir, "sample")
55 | self.lcm_lora_path = "models/LCM_LoRA/AnimateLCM_sd15_i2v_lora.safetensors"
56 | os.makedirs(self.savedir, exist_ok=True)
57 |
58 | self.stable_diffusion_list = []
59 | self.motion_module_list = []
60 | self.personalized_model_list = []
61 |
62 | self.refresh_stable_diffusion()
63 | self.refresh_motion_module()
64 | self.refresh_personalized_model()
65 |
66 | # config models
67 | self.tokenizer = None
68 | self.text_encoder = None
69 | self.vae = None
70 | self.unet = None
71 | self.pipeline = None
72 | self.lora_model_state_dict = {}
73 |
74 | self.inference_config = OmegaConf.load("configs/inference-i2v.yaml")
75 |
76 | def refresh_stable_diffusion(self):
77 | self.stable_diffusion_list = glob(
78 | os.path.join(self.stable_diffusion_dir, "*/"))
79 |
80 | def refresh_motion_module(self):
81 | motion_module_list = glob(os.path.join(
82 | self.motion_module_dir, "*.ckpt"))
83 | self.motion_module_list = [
84 | os.path.basename(p) for p in motion_module_list]
85 |
86 | def refresh_personalized_model(self):
87 | personalized_model_list = glob(os.path.join(
88 | self.personalized_model_dir, "*.safetensors"))
89 | self.personalized_model_list = [
90 | os.path.basename(p) for p in personalized_model_list]
91 |
92 | def update_stable_diffusion(self, stable_diffusion_dropdown):
93 | stable_diffusion_dropdown = os.path.join(self.stable_diffusion_dir,stable_diffusion_dropdown)
94 | self.tokenizer = CLIPTokenizer.from_pretrained(
95 | stable_diffusion_dropdown, subfolder="tokenizer")
96 | self.text_encoder = CLIPTextModel.from_pretrained(
97 | stable_diffusion_dropdown, subfolder="text_encoder").cuda()
98 | self.vae = AutoencoderKL.from_pretrained(
99 | stable_diffusion_dropdown, subfolder="vae").cuda()
100 | self.unet = UNet3DConditionModel.from_pretrained_2d(
101 | stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
102 | return gr.Dropdown.update()
103 |
104 | def update_motion_module(self, motion_module_dropdown):
105 | if self.unet is None:
106 | gr.Info(f"Please select a pretrained model path.")
107 | return gr.Dropdown.update(value=None)
108 | else:
109 | motion_module_dropdown = os.path.join(
110 | self.motion_module_dir, motion_module_dropdown)
111 | motion_module_state_dict = torch.load(
112 | motion_module_dropdown, map_location="cpu")
113 | missing, unexpected = self.unet.load_state_dict(
114 | motion_module_state_dict, strict=False)
115 | assert len(unexpected) == 0
116 | return gr.Dropdown.update()
117 |
118 | def update_base_model(self, base_model_dropdown):
119 | if self.unet is None:
120 | gr.Info(f"Please select a pretrained model path.")
121 | return gr.Dropdown.update(value=None)
122 | else:
123 | base_model_dropdown = os.path.join(
124 | self.personalized_model_dir, base_model_dropdown)
125 | base_model_state_dict = {}
126 | with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
127 | for key in f.keys():
128 | base_model_state_dict[key] = f.get_tensor(key)
129 |
130 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(
131 | base_model_state_dict, self.vae.config)
132 | self.vae.load_state_dict(converted_vae_checkpoint)
133 |
134 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(
135 | base_model_state_dict, self.unet.config)
136 | self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
137 |
138 | self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
139 | return gr.Dropdown.update()
140 |
141 | def update_lora_model(self, lora_model_dropdown):
142 | lora_model_dropdown = os.path.join(
143 | self.personalized_model_dir, lora_model_dropdown)
144 | self.lora_model_state_dict = {}
145 | if lora_model_dropdown == "none":
146 | pass
147 | else:
148 | with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f:
149 | for key in f.keys():
150 | self.lora_model_state_dict[key] = f.get_tensor(key)
151 | return gr.Dropdown.update()
152 |
153 | def animate(
154 | self,
155 | lora_alpha_slider,
156 | spatial_lora_slider,
157 | prompt_textbox,
158 | negative_prompt_textbox,
159 | sampler_dropdown,
160 | sample_step_slider,
161 | width_slider,
162 | length_slider,
163 | height_slider,
164 | cfg_scale_slider,
165 | seed_textbox,
166 | image_upload,
167 | beta_end_slider,
168 | motion_scale_slider,
169 | ):
170 | print(image_upload)
171 |
172 | if is_xformers_available():
173 | self.unet.enable_xformers_memory_efficient_attention()
174 |
175 | print(self.inference_config.noise_scheduler_kwargs["beta_end"])
176 | self.inference_config.noise_scheduler_kwargs["beta_end"] = beta_end_slider
177 |
178 | self.unet.img_encoder.motion_scale = motion_scale_slider
179 | pipeline = AnimationPipeline(
180 | vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
181 | scheduler=scheduler_dict[sampler_dropdown](
182 | **OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
183 | ).to("cuda")
184 |
185 | if self.lora_model_state_dict != {}:
186 | pipeline = convert_lora(
187 | pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)
188 |
189 | pipeline.unet = convert_lcm_lora(copy.deepcopy(
190 | self.unet), self.lcm_lora_path, spatial_lora_slider)
191 |
192 | pipeline.to("cuda")
193 |
194 | if seed_textbox != -1 and seed_textbox != "":
195 | torch.manual_seed(int(seed_textbox))
196 | else:
197 | torch.seed()
198 | seed = torch.initial_seed()
199 |
200 | with torch.autocast("cuda",torch.float16):
201 | sample = pipeline(
202 | prompt_textbox,
203 | negative_prompt=negative_prompt_textbox,
204 | num_inference_steps=sample_step_slider,
205 | guidance_scale=cfg_scale_slider,
206 | width=width_slider,
207 | height=height_slider,
208 | video_length=length_slider,
209 | image_path=image_upload
210 | ).videos
211 |
212 | save_sample_path = os.path.join(
213 | self.savedir_sample, f"{sample_idx}.mp4")
214 | save_videos_grid(sample, save_sample_path)
215 |
216 | sample_config = {
217 | "prompt": prompt_textbox,
218 | "n_prompt": negative_prompt_textbox,
219 | "sampler": sampler_dropdown,
220 | "num_inference_steps": sample_step_slider,
221 | "guidance_scale": cfg_scale_slider,
222 | "width": width_slider,
223 | "height": height_slider,
224 | "video_length": length_slider,
225 | "seed": seed,
226 | }
227 | json_str = json.dumps(sample_config, indent=4)
228 | with open(os.path.join(self.savedir, "logs.json"), "a") as f:
229 | f.write(json_str)
230 | f.write("\n\n")
231 | return gr.Video.update(value=save_sample_path)
232 |
233 |
234 | controller = AnimateController()
235 |
236 | controller.update_stable_diffusion("stable-diffusion-v1-5")
237 | controller.update_motion_module("AnimateLCM_sd15_i2v.ckpt")
238 | controller.update_base_model("realistic1.safetensors")
239 |
240 |
241 | def ui():
242 | with gr.Blocks(css=css) as demo:
243 | gr.Markdown(
244 | """
245 | # [AnimateLCM: Accelerating the Animation of Personalized Diffusion Models and Adapters with Decoupled Consistency Learning](https://arxiv.org/abs/2402.00769)
246 | Fu-Yun Wang, Zhaoyang Huang (*Corresponding Author), Xiaoyu Shi, Weikang Bian, Guanglu Song, Yu Liu, Hongsheng Li (*Corresponding Author)
247 | [arXiv Report](https://arxiv.org/abs/2402.00769) | [Project Page](https://animatelcm.github.io/) | [Github](https://github.com/G-U-N/AnimateLCM) | [Civitai](https://civitai.com/models/290375/animatelcm-fast-video-generation) | [Replicate](https://replicate.com/camenduru/animate-lcm)
248 | """
249 |
250 | '''
251 | Important Notes:
252 |         1. The generation speed is around a few seconds.
253 |         2. Increase the sampling steps for better sample quality.
254 | '''
255 | )
256 | with gr.Column(variant="panel"):
257 | with gr.Row():
258 | image_upload = gr.Image(label="Upload Image", tool="select", type="filepath")
259 |
260 | base_model_dropdown = gr.Dropdown(
261 | label="Select base Dreambooth model (required)",
262 | choices=controller.personalized_model_list,
263 | interactive=True,
264 | value="realistic1.safetensors"
265 | )
266 | base_model_dropdown.change(fn=controller.update_base_model, inputs=[
267 | base_model_dropdown], outputs=[base_model_dropdown])
268 |
269 | lora_model_dropdown = gr.Dropdown(
270 | label="Select LoRA model (optional)",
271 | choices=["none"],
272 | value="none",
273 | interactive=True,
274 | )
275 | lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[
276 | lora_model_dropdown], outputs=[lora_model_dropdown])
277 |
278 | lora_alpha_slider = gr.Slider(
279 | label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
280 | spatial_lora_slider = gr.Slider(
281 | label="LCM LoRA alpha", value=0.8, minimum=0.0, maximum=1.0, interactive=True)
282 |
283 | personalized_refresh_button = gr.Button(
284 | value="\U0001F503", elem_classes="toolbutton")
285 |
286 | def update_personalized_model():
287 | controller.refresh_personalized_model()
288 | return [
289 | gr.Dropdown.update(
290 | choices=controller.personalized_model_list),
291 | gr.Dropdown.update(
292 | choices=["none"] + controller.personalized_model_list)
293 | ]
294 | personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[
295 | base_model_dropdown, lora_model_dropdown])
296 |
297 | with gr.Column(variant="panel"):
298 | gr.Markdown(
299 | """
300 | ### 2. Configs for AnimateLCM.
301 | """
302 | )
303 |
304 | prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="best quality")
305 | negative_prompt_textbox = gr.Textbox(
306 | label="Negative prompt", lines=2, value="bad quality")
307 |
308 | with gr.Row().style(equal_height=False):
309 | with gr.Column():
310 | with gr.Row():
311 | sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(
312 | scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
313 | sample_step_slider = gr.Slider(
314 | label="Sampling steps", value=4, minimum=1, maximum=25, step=1)
315 |
316 | motion_scale_slider = gr.Slider(
317 | label="motion scale (better identity with smaller scale)", value=0.8, minimum=0.0, maximum=1.5, step=0.05
318 | )
319 | beta_end_slider = gr.Slider(
320 | label="beta end (a tricky way for selecting noisy steps)", value=0.014, minimum=0.012, maximum=0.016, step=0.001)
321 | width_slider = gr.Slider(
322 | label="Width", value=512, minimum=256, maximum=1024, step=64)
323 | height_slider = gr.Slider(
324 | label="Height", value=512, minimum=256, maximum=1024, step=64)
325 | length_slider = gr.Slider(
326 | label="Animation length", value=16, minimum=12, maximum=20, step=1)
327 | cfg_scale_slider = gr.Slider(
328 | label="CFG Scale", value=1, minimum=1, maximum=2)
329 |
330 | with gr.Row():
331 | seed_textbox = gr.Textbox(label="Seed", value=-1)
332 | seed_button = gr.Button(
333 | value="\U0001F3B2", elem_classes="toolbutton")
334 | seed_button.click(fn=lambda: gr.Textbox.update(
335 |                     value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])
336 |
337 | generate_button = gr.Button(
338 | value="Generate", variant='primary')
339 |
340 | result_video = gr.Video(
341 | label="Generated Animation", interactive=False)
342 |
343 |
344 | generate_button.click(
345 | fn=controller.animate,
346 | inputs=[
347 | lora_alpha_slider,
348 | spatial_lora_slider,
349 | prompt_textbox,
350 | negative_prompt_textbox,
351 | sampler_dropdown,
352 | sample_step_slider,
353 | width_slider,
354 | length_slider,
355 | height_slider,
356 | cfg_scale_slider,
357 | seed_textbox,
358 | image_upload,
359 | beta_end_slider,
360 | motion_scale_slider,
361 | ],
362 | outputs=[result_video]
363 | )
364 | examples = [
365 | [0.8, 0.8, "good quality, cloud", "bad quality", "LCM", 4, 768, 16, 512, 1, 1234, "test_imgs/cloud.jpeg", 0.014, 0.8],
366 | [0.8, 0.8, "good quality, dog", "bad quality", "LCM", 4, 768, 16, 512, 1, 1234, "test_imgs/dog.jpg", 0.014, 0.8],
367 | [0.8, 0.8, "good quality, fire", "bad quality", "LCM", 4, 768, 16, 512, 1, 1234, "test_imgs/fire.jpg", 0.014, 0.8],
368 | [0.8, 0.8, "good quality, fox, snow", "bad quality", "LCM", 4, 768, 16, 512, 1, 1234, "test_imgs/fox.jpg", 0.014, 0.8],
369 | [0.8, 0.8, "good quality, girl, wind, flower", "bad quality", "LCM", 4, 768, 16, 512, 1, 1235, "test_imgs/girl_flower.jpg", 0.014, 0.8],
370 | [0.8, 0.8, "good quality, lighter, fire", "bad quality", "LCM", 4, 768, 16, 512, 1, 1235, "test_imgs/lighter.jpg", 0.014, 0.8],
371 | [0.8, 0.8, "good quality, snowman, fire", "bad quality", "LCM", 4, 768, 16, 512, 1, 1234, "test_imgs/snow_man_fire.jpg", 0.014, 0.8],
372 | ]
373 | gr.Examples(
374 | examples = examples,
375 | inputs=[
376 | lora_alpha_slider,
377 | spatial_lora_slider,
378 | prompt_textbox,
379 | negative_prompt_textbox,
380 | sampler_dropdown,
381 | sample_step_slider,
382 | width_slider,
383 | length_slider,
384 | height_slider,
385 | cfg_scale_slider,
386 | seed_textbox,
387 | image_upload,
388 | beta_end_slider,
389 | motion_scale_slider,
390 | ],
391 | outputs=[result_video],
392 | fn=controller.animate,
393 | cache_examples=True,
394 | )
395 | return demo
396 |
397 |
398 | if __name__ == "__main__":
399 | demo = ui()
400 | demo.queue(concurrency_count=3, max_size=20)
401 | demo.launch(share=True, server_name="127.0.0.1")
402 |
--------------------------------------------------------------------------------
/animatelcm_sd15/app.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import json
4 | import torch
5 | import random
6 |
7 | import gradio as gr
8 | from glob import glob
9 | from omegaconf import OmegaConf
10 | from datetime import datetime
11 | from safetensors import safe_open
12 |
13 | from diffusers import AutoencoderKL
14 | from diffusers.utils.import_utils import is_xformers_available
15 | from transformers import CLIPTextModel, CLIPTokenizer
16 |
17 | from animatelcm.scheduler.lcm_scheduler import LCMScheduler
18 | from animatelcm.models.unet import UNet3DConditionModel
19 | from animatelcm.pipelines.pipeline_animation import AnimationPipeline
20 | from animatelcm.utils.util import save_videos_grid
21 | from animatelcm.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
22 | from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora
23 | from animatelcm.utils.lcm_utils import convert_lcm_lora
24 | import copy
25 |
26 | sample_idx = 0
27 | scheduler_dict = {
28 | "LCM": LCMScheduler,
29 | }
30 |
31 | css = """
32 | .toolbutton {
33 |     margin-bottom: 0em;
34 | max-width: 2.5em;
35 | min-width: 2.5em !important;
36 | height: 2.5em;
37 | }
38 | """
39 |
40 |
41 | class AnimateController:
42 | def __init__(self):
43 |
44 | # config dirs
45 | self.basedir = os.getcwd()
46 | self.stable_diffusion_dir = os.path.join(
47 | self.basedir, "models", "StableDiffusion")
48 | self.motion_module_dir = os.path.join(
49 | self.basedir, "models", "Motion_Module")
50 | self.personalized_model_dir = os.path.join(
51 | self.basedir, "models", "Personalized")
52 | self.savedir = os.path.join(
53 | self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
54 | self.savedir_sample = os.path.join(self.savedir, "sample")
55 | self.lcm_lora_path = "models/LCM_LoRA/AnimateLCM_sd15_t2v_lora.safetensors"
56 | os.makedirs(self.savedir, exist_ok=True)
57 |
58 | self.stable_diffusion_list = []
59 | self.motion_module_list = []
60 | self.personalized_model_list = []
61 |
62 | self.refresh_stable_diffusion()
63 | self.refresh_motion_module()
64 | self.refresh_personalized_model()
65 |
66 | # config models
67 | self.tokenizer = None
68 | self.text_encoder = None
69 | self.vae = None
70 | self.unet = None
71 | self.pipeline = None
72 | self.lora_model_state_dict = {}
73 |
74 | self.inference_config = OmegaConf.load("configs/inference-t2v.yaml")
75 |
76 | def refresh_stable_diffusion(self):
77 | self.stable_diffusion_list = glob(
78 | os.path.join(self.stable_diffusion_dir, "*/"))
79 |
80 | def refresh_motion_module(self):
81 | motion_module_list = glob(os.path.join(
82 | self.motion_module_dir, "*.ckpt"))
83 | self.motion_module_list = [
84 | os.path.basename(p) for p in motion_module_list]
85 |
86 | def refresh_personalized_model(self):
87 | personalized_model_list = glob(os.path.join(
88 | self.personalized_model_dir, "*.safetensors"))
89 | self.personalized_model_list = [
90 | os.path.basename(p) for p in personalized_model_list]
91 |
92 | def update_stable_diffusion(self, stable_diffusion_dropdown):
93 | stable_diffusion_dropdown = os.path.join(self.stable_diffusion_dir,stable_diffusion_dropdown)
94 | self.tokenizer = CLIPTokenizer.from_pretrained(
95 | stable_diffusion_dropdown, subfolder="tokenizer")
96 | self.text_encoder = CLIPTextModel.from_pretrained(
97 | stable_diffusion_dropdown, subfolder="text_encoder").cuda()
98 | self.vae = AutoencoderKL.from_pretrained(
99 | stable_diffusion_dropdown, subfolder="vae").cuda()
100 | self.unet = UNet3DConditionModel.from_pretrained_2d(
101 | stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
102 | return gr.Dropdown.update()
103 |
104 | def update_motion_module(self, motion_module_dropdown):
105 | if self.unet is None:
106 | gr.Info(f"Please select a pretrained model path.")
107 | return gr.Dropdown.update(value=None)
108 | else:
109 | motion_module_dropdown = os.path.join(
110 | self.motion_module_dir, motion_module_dropdown)
111 | motion_module_state_dict = torch.load(
112 | motion_module_dropdown, map_location="cpu")
113 | missing, unexpected = self.unet.load_state_dict(
114 | motion_module_state_dict, strict=False)
115 | assert len(unexpected) == 0
116 | return gr.Dropdown.update()
117 |
118 | def update_base_model(self, base_model_dropdown):
119 | if self.unet is None:
120 | gr.Info(f"Please select a pretrained model path.")
121 | return gr.Dropdown.update(value=None)
122 | else:
123 | base_model_dropdown = os.path.join(
124 | self.personalized_model_dir, base_model_dropdown)
125 | base_model_state_dict = {}
126 | with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
127 | for key in f.keys():
128 | base_model_state_dict[key] = f.get_tensor(key)
129 |
130 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(
131 | base_model_state_dict, self.vae.config)
132 | self.vae.load_state_dict(converted_vae_checkpoint)
133 |
134 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(
135 | base_model_state_dict, self.unet.config)
136 | self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
137 |
138 | # self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
139 | return gr.Dropdown.update()
140 |
141 | def update_lora_model(self, lora_model_dropdown):
142 | lora_model_dropdown = os.path.join(
143 | self.personalized_model_dir, lora_model_dropdown)
144 | self.lora_model_state_dict = {}
145 | if lora_model_dropdown == "none":
146 | pass
147 | else:
148 | with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f:
149 | for key in f.keys():
150 | self.lora_model_state_dict[key] = f.get_tensor(key)
151 | return gr.Dropdown.update()
152 |
153 | def animate(
154 | self,
155 | lora_alpha_slider,
156 | spatial_lora_slider,
157 | prompt_textbox,
158 | negative_prompt_textbox,
159 | sampler_dropdown,
160 | sample_step_slider,
161 | width_slider,
162 | length_slider,
163 | height_slider,
164 | cfg_scale_slider,
165 | seed_textbox
166 | ):
167 |
168 | if is_xformers_available():
169 | self.unet.enable_xformers_memory_efficient_attention()
170 |
171 | pipeline = AnimationPipeline(
172 | vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
173 | scheduler=scheduler_dict[sampler_dropdown](
174 | **OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
175 | ).to("cuda")
176 |
177 | if self.lora_model_state_dict != {}:
178 | pipeline = convert_lora(
179 | pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)
180 |
181 | pipeline.unet = convert_lcm_lora(copy.deepcopy(
182 | self.unet), self.lcm_lora_path, spatial_lora_slider)
183 |
184 | pipeline.to("cuda")
185 |
186 | if seed_textbox != -1 and seed_textbox != "":
187 | torch.manual_seed(int(seed_textbox))
188 | else:
189 | torch.seed()
190 | seed = torch.initial_seed()
191 |
192 | sample = pipeline(
193 | prompt_textbox,
194 | negative_prompt=negative_prompt_textbox,
195 | num_inference_steps=sample_step_slider,
196 | guidance_scale=cfg_scale_slider,
197 | width=width_slider,
198 | height=height_slider,
199 | video_length=length_slider,
200 | ).videos
201 |
202 | save_sample_path = os.path.join(
203 | self.savedir_sample, f"{sample_idx}.mp4")
204 | save_videos_grid(sample, save_sample_path)
205 |
206 | sample_config = {
207 | "prompt": prompt_textbox,
208 | "n_prompt": negative_prompt_textbox,
209 | "sampler": sampler_dropdown,
210 | "num_inference_steps": sample_step_slider,
211 | "guidance_scale": cfg_scale_slider,
212 | "width": width_slider,
213 | "height": height_slider,
214 | "video_length": length_slider,
215 | "seed": seed
216 | }
217 | json_str = json.dumps(sample_config, indent=4)
218 | with open(os.path.join(self.savedir, "logs.json"), "a") as f:
219 | f.write(json_str)
220 | f.write("\n\n")
221 | return gr.Video.update(value=save_sample_path)
222 |
223 |
224 | controller = AnimateController()
225 |
226 | controller.update_stable_diffusion("stable-diffusion-v1-5")
227 | controller.update_motion_module("AnimateLCM_sd15_t2v.ckpt")
228 | controller.update_base_model("realistic2.safetensors")
229 |
230 |
231 | def ui():
232 | with gr.Blocks(css=css) as demo:
233 | gr.Markdown(
234 | """
235 | # [AnimateLCM: Accelerating the Animation of Personalized Diffusion Models and Adapters with Decoupled Consistency Learning](https://arxiv.org/abs/2402.00769)
236 | Fu-Yun Wang, Zhaoyang Huang (*Corresponding Author), Xiaoyu Shi, Weikang Bian, Guanglu Song, Yu Liu, Hongsheng Li (*Corresponding Author)
237 | [arXiv Report](https://arxiv.org/abs/2402.00769) | [Project Page](https://animatelcm.github.io/) | [Github](https://github.com/G-U-N/AnimateLCM) | [Civitai](https://civitai.com/models/290375/animatelcm-fast-video-generation) | [Replicate](https://replicate.com/camenduru/animate-lcm)
238 | """
239 |
240 | '''
241 | Important Notes:
242 |         1. The generation speed is around 1~2 seconds. There may be a delay in the hosted Space.
243 |         2. Increase the sampling steps and CFG scale if you want fancier videos.
244 | '''
245 | )
246 | with gr.Column(variant="panel"):
247 | with gr.Row():
248 |
249 | base_model_dropdown = gr.Dropdown(
250 | label="Select base Dreambooth model (required)",
251 | choices=controller.personalized_model_list,
252 | interactive=True,
253 | value="realistic2.safetensors"
254 | )
255 |
256 | motion_module_dropdown = gr.Dropdown(
257 | label="Select motion modules",
258 | choices=controller.motion_module_list,
259 | interactive=True,
260 | value="sd15_t2v_beta_motion.ckpt"
261 | )
262 | base_model_dropdown.change(fn=controller.update_base_model, inputs=[
263 | base_model_dropdown], outputs=[base_model_dropdown])
264 |
265 | motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown],outputs=[motion_module_dropdown])
266 |
267 | lora_model_dropdown = gr.Dropdown(
268 | label="Select LoRA model (optional)",
269 | choices=["none"],
270 | value="none",
271 | interactive=True,
272 | )
273 | lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[
274 | lora_model_dropdown], outputs=[lora_model_dropdown])
275 |
276 | lora_alpha_slider = gr.Slider(
277 | label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
278 | spatial_lora_slider = gr.Slider(
279 | label="LCM LoRA alpha", value=0.8, minimum=0.0, maximum=1.0, interactive=True)
280 |
281 | personalized_refresh_button = gr.Button(
282 | value="\U0001F503", elem_classes="toolbutton")
283 |
284 | def update_personalized_model():
285 | controller.refresh_personalized_model()
286 | return [
287 | gr.Dropdown.update(
288 | choices=controller.personalized_model_list),
289 | gr.Dropdown.update(
290 | choices=["none"] + controller.personalized_model_list)
291 | ]
292 | personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[
293 | base_model_dropdown, lora_model_dropdown])
294 |
295 | with gr.Column(variant="panel"):
296 | gr.Markdown(
297 | """
298 | ### 2. Configs for AnimateLCM.
299 | """
300 | )
301 |
302 | prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="a boy holding a rabbit")
303 | negative_prompt_textbox = gr.Textbox(
304 | label="Negative prompt", lines=2, value="bad quality")
305 |
306 | with gr.Row().style(equal_height=False):
307 | with gr.Column():
308 | with gr.Row():
309 | sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(
310 | scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
311 | sample_step_slider = gr.Slider(
312 | label="Sampling steps", value=6, minimum=1, maximum=25, step=1)
313 |
314 | width_slider = gr.Slider(
315 | label="Width", value=512, minimum=256, maximum=1024, step=64)
316 | height_slider = gr.Slider(
317 | label="Height", value=512, minimum=256, maximum=1024, step=64)
318 | length_slider = gr.Slider(
319 | label="Animation length", value=16, minimum=12, maximum=20, step=1)
320 | cfg_scale_slider = gr.Slider(
321 | label="CFG Scale", value=1.5, minimum=1, maximum=2)
322 |
323 | with gr.Row():
324 | seed_textbox = gr.Textbox(label="Seed", value=-1)
325 | seed_button = gr.Button(
326 | value="\U0001F3B2", elem_classes="toolbutton")
327 | seed_button.click(fn=lambda: gr.Textbox.update(
328 | value=random.randint(1, 10**8)), inputs=[], outputs=[seed_textbox])
329 |
330 | generate_button = gr.Button(
331 | value="Generate", variant='primary')
332 |
333 | result_video = gr.Video(
334 | label="Generated Animation", interactive=False)
335 |
336 |
337 | generate_button.click(
338 | fn=controller.animate,
339 | inputs=[
340 | lora_alpha_slider,
341 | spatial_lora_slider,
342 | prompt_textbox,
343 | negative_prompt_textbox,
344 | sampler_dropdown,
345 | sample_step_slider,
346 | width_slider,
347 | length_slider,
348 | height_slider,
349 | cfg_scale_slider,
350 | seed_textbox,
351 | ],
352 | outputs=[result_video]
353 | )
354 | examples = [
355 | [0.8, 0.8, "a boy is holding a rabbit", "bad quality", "LCM", 8, 512, 16, 512, 1.5, 1234],
356 | [0.8, 0.8, "1girl smiling", "bad quality", "LCM", 4, 512, 16, 512, 1.5, 1233],
357 | [0.8, 0.8, "1girl,face,white background,", "bad quality", "LCM", 6, 512, 16, 512, 1.5, 1234],
358 | [0.8, 0.8, "clouds in the sky, best quality", "bad quality", "LCM", 4, 512, 16, 512, 1.5, 1234],
359 |
360 |
361 | ]
362 | gr.Examples(
363 | examples = examples,
364 | inputs=[
365 | lora_alpha_slider,
366 | spatial_lora_slider,
367 | prompt_textbox,
368 | negative_prompt_textbox,
369 | sampler_dropdown,
370 | sample_step_slider,
371 | width_slider,
372 | length_slider,
373 | height_slider,
374 | cfg_scale_slider,
375 | seed_textbox,
376 | ],
377 | outputs=[result_video],
378 | fn=controller.animate,
379 | cache_examples=True,
380 | )
381 |
382 | return demo
383 |
384 |
385 | if __name__ == "__main__":
386 | demo = ui()
387 | # gr.close_all()
388 | demo.queue(concurrency_count=3, max_size=20)
389 | demo.launch(share=True, server_name="127.0.0.1")
390 |
--------------------------------------------------------------------------------
/animatelcm_sd15/batch_inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import inspect
4 | import os
5 | from omegaconf import OmegaConf
6 |
7 | import torch
8 |
9 | from diffusers import AutoencoderKL
10 |
11 | from tqdm.auto import tqdm
12 | from transformers import CLIPTextModel, CLIPTokenizer
13 |
14 | from animatelcm.models.unet import UNet3DConditionModel
15 | from animatelcm.pipelines.pipeline_animation import AnimationPipeline
16 | from animatelcm.utils.util import save_videos_grid
17 | from animatelcm.utils.util import load_weights
18 | from animatelcm.scheduler.lcm_scheduler import LCMScheduler
19 | from animatelcm.utils.lcm_utils import convert_lcm_lora
20 | from diffusers.utils.import_utils import is_xformers_available
21 | from pathlib import Path
22 |
23 |
24 | def main(args):
25 | *_, func_args = inspect.getargvalues(inspect.currentframe())
26 | func_args = dict(func_args)
27 |
28 | time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
29 | savedir = f"samples/{Path(args.config).stem}-{time_str}"
30 | os.makedirs(savedir)
31 |
32 | config = OmegaConf.load(args.config)
33 | samples = []
34 |
35 | sample_idx = 0
36 | for model_idx, (config_key, model_config) in enumerate(list(config.items())):
37 |
38 | motion_modules = model_config.motion_module
39 | lcm_lora = model_config.lcm_lora_path
40 | motion_modules = (
41 | [motion_modules]
42 | if isinstance(motion_modules, str)
43 | else list(motion_modules)
44 | )
45 | lcm_lora = [lcm_lora] if isinstance(lcm_lora, str) else list(lcm_lora)
46 | lcm_lora = lcm_lora * len(motion_modules) if len(lcm_lora) == 1 else lcm_lora
47 | for motion_module, lcm_lora in zip(motion_modules, lcm_lora):
48 | inference_config = OmegaConf.load(
49 | model_config.get("inference_config", args.inference_config)
50 | )
51 |
52 | tokenizer = CLIPTokenizer.from_pretrained(
53 | args.pretrained_model_path, subfolder="tokenizer"
54 | )
55 | text_encoder = CLIPTextModel.from_pretrained(
56 | args.pretrained_model_path, subfolder="text_encoder"
57 | )
58 | vae = AutoencoderKL.from_pretrained(
59 | args.pretrained_model_path, subfolder="vae"
60 | )
61 | unet = UNet3DConditionModel.from_pretrained_2d(
62 | args.pretrained_model_path,
63 | subfolder="unet",
64 | unet_additional_kwargs=OmegaConf.to_container(
65 | inference_config.unet_additional_kwargs
66 | ),
67 | )
68 |
69 | if is_xformers_available():
70 | unet.enable_xformers_memory_efficient_attention()
71 | else:
72 | raise RuntimeError("xformers is not available; install xformers to enable memory-efficient attention")
73 |
74 | pipeline = AnimationPipeline(
75 | vae=vae,
76 | text_encoder=text_encoder,
77 | tokenizer=tokenizer,
78 | unet=unet,
79 | scheduler=LCMScheduler(
80 | **OmegaConf.to_container(inference_config.noise_scheduler_kwargs)
81 | ),
82 | ).to("cuda")
83 |
84 | pipeline = load_weights(
85 | pipeline,
86 | motion_module_path=motion_module,
87 | motion_module_lora_configs=model_config.get(
88 | "motion_module_lora_configs", []
89 | ),
90 | dreambooth_model_path=model_config.get("dreambooth_path", ""),
91 | lora_model_path=model_config.get("lora_model_path", ""),
92 | lora_alpha=model_config.get("lora_alpha", 0.8),
93 | ).to("cuda")
94 |
95 | pipeline.unet = convert_lcm_lora(pipeline.unet, lcm_lora, 1.0)
96 | prompts = model_config.prompt
97 | image_paths = (
98 | model_config.image_paths
99 | if hasattr(model_config, "image_paths")
100 | else [None for _ in range(len(prompts))]
101 | )
102 | control_paths = (
103 | model_config.control_paths
104 | if hasattr(model_config, "control_paths")
105 | else [None for _ in range(len(prompts))]
106 | )
107 | n_prompts = (
108 | list(model_config.n_prompt) * len(prompts)
109 | if len(model_config.n_prompt) == 1
110 | else model_config.n_prompt
111 | )
112 |
113 | random_seeds = model_config.get("seed", [-1])
114 | random_seeds = (
115 | [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
116 | )
117 | random_seeds = (
118 | random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
119 | )
120 |
121 | config[config_key].random_seed = []
122 | for prompt_idx, (
123 | prompt,
124 | n_prompt,
125 | random_seed,
126 | image_path,
127 | control_path,
128 | ) in enumerate(
129 | zip(prompts, n_prompts, random_seeds, image_paths, control_paths)
130 | ):
131 |
132 | if random_seed != -1:
133 | torch.manual_seed(random_seed)
134 | else:
135 | torch.seed()
136 | config[config_key].random_seed.append(torch.initial_seed())
137 |
138 | print(f"current seed: {torch.initial_seed()}")
139 | print(f"sampling {prompt} ...")
140 | sample = pipeline(
141 | prompt,
142 | negative_prompt=n_prompt,
143 | control_path=control_path,
144 | image_path=image_path,
145 | num_inference_steps=model_config.steps,
146 | guidance_scale=model_config.guidance_scale,
147 | width=model_config.W,
148 | height=model_config.H,
149 | video_length=model_config.L,
150 | do_classifier_free_guidance=model_config.get(
151 | "do_classifier_free_guidance", False
152 | ),
153 | ).videos
154 | samples.append(sample)
155 |
156 | prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
157 | save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
158 | print(f"save to {savedir}/sample/{prompt}.gif")
159 |
160 | sample_idx += 1
161 |
162 | samples = torch.concat(samples)
163 | save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)
164 |
165 | OmegaConf.save(config, f"{savedir}/config.yaml")
166 |
167 |
168 | if __name__ == "__main__":
169 | parser = argparse.ArgumentParser()
170 | parser.add_argument(
171 | "--pretrained_model_path",
172 | type=str,
173 | default="models/StableDiffusion/stable-diffusion-v1-5",
174 | )
175 | parser.add_argument(
176 | "--inference_config", type=str, default="configs/inference-t2v.yaml"
177 | )
178 | parser.add_argument("--config", type=str, required=True)
179 |
180 | args = parser.parse_args()
181 | main(args)
182 |
--------------------------------------------------------------------------------
/animatelcm_sd15/configs/batch_inference_example.yaml:
--------------------------------------------------------------------------------
1 | T2V-Example:
2 | inference_config: "configs/inference-t2v.yaml"
3 | motion_module:
4 | - "models/Motion_Module/xxx.ckpt"
5 | lcm_lora:
6 | - "models/LCM_LoRA/xxx.safetensors"
7 |
8 | dreambooth_path: "models/Personalized/xxx.safetensors"
9 | lora_model_path: ""
10 |
11 | seed: [2, 0, 2, 4]
12 | steps: 4
13 | guidance_scale: 1.5
14 | do_classifier_free_guidance: True
15 | H: 512
16 | W: 512
17 | L: 16
18 |
19 | prompt:
20 | - "xxx, specify your prompt here"
21 | - "xxx, specify your prompt here"
22 | - "xxx, specify your prompt here"
23 | - "xxx, specify your prompt here"
24 |
25 | n_prompt:
26 | - "xxx, specify your negative prompt here"
27 | - "xxx, specify your negative prompt here"
28 | - "xxx, specify your negative prompt here"
29 | - "xxx, specify your negative prompt here"
30 |
31 | I2V-Example:
32 |
33 | inference_config: "configs/inference-i2v.yaml"
34 | motion_module:
35 | - "models/Motion_Module/xxx.ckpt"
36 | lcm_lora:
37 | - "models/LCM_LoRA/xxx.safetensors"
38 |
39 | dreambooth_path: "models/Personalized/xxx.safetensors"
40 | lora_model_path: ""
41 |
42 | seed: [2, 0, 2, 4]
43 | steps: 4
44 | guidance_scale: 1 # should be within the range [1, 2]
45 | do_classifier_free_guidance: True
46 | H: 512
47 | W: 512
48 | L: 16
49 |
50 | prompt:
51 | - "xxx, specify your prompt here"
52 | - "xxx, specify your prompt here"
53 | - "xxx, specify your prompt here"
54 | - "xxx, specify your prompt here"
55 |
56 | n_prompt:
57 | - "xxx, specify your negative prompt here"
58 | - "xxx, specify your negative prompt here"
59 | - "xxx, specify your negative prompt here"
60 | - "xxx, specify your negative prompt here"
61 |
62 | image_paths:
63 | - "xxx, specify your image paths here"
64 | - "xxx, specify your image paths here"
65 | - "xxx, specify your image paths here"
66 | - "xxx, specify your image paths here"
67 |
68 |
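69 | # Usage: replace every "xxx" placeholder above with real checkpoint paths and prompts, then run
70 | # from the animatelcm_sd15 folder:
71 | #   python batch_inference.py --config configs/batch_inference_example.yaml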
--------------------------------------------------------------------------------
/animatelcm_sd15/configs/inference-i2v.yaml:
--------------------------------------------------------------------------------
1 | unet_additional_kwargs:
2 | use_pixel_encoder: false
3 | use_img_encoder: true
4 | use_inflated_groupnorm: true
5 | unet_use_cross_frame_attention: false
6 | unet_use_temporal_attention: false
7 | use_motion_module: true
8 | use_motion_resnet: false
9 | motion_module_resolutions:
10 | - 1
11 | - 2
12 | - 4
13 | - 8
14 | motion_module_mid_block: true
15 | motion_module_decoder_only: false
16 | motion_module_type: Vanilla
17 | motion_module_kwargs:
18 | num_attention_heads: 8
19 | num_transformer_block: 1
20 | attention_block_types:
21 | - Temporal_Self
22 | - Temporal_Self
23 | temporal_position_encoding: true
24 | temporal_attention_dim_div: 1
25 |
26 | noise_scheduler_kwargs:
27 | beta_start: 0.00085
28 | beta_end: 0.014
29 | beta_schedule: "linear"
30 | original_inference_steps: 200
31 | steps_offset: 1
--------------------------------------------------------------------------------
/animatelcm_sd15/configs/inference-t2v.yaml:
--------------------------------------------------------------------------------
1 | unet_additional_kwargs:
2 | use_pixel_encoder: false
3 | use_img_encoder: false
4 | use_inflated_groupnorm: true
5 | unet_use_cross_frame_attention: false
6 | unet_use_temporal_attention: false
7 | use_motion_module: true
8 | use_motion_resnet: false
9 | motion_module_resolutions:
10 | - 1
11 | - 2
12 | - 4
13 | - 8
14 | motion_module_mid_block: true
15 | motion_module_decoder_only: false
16 | motion_module_type: Vanilla
17 | motion_module_kwargs:
18 | num_attention_heads: 8
19 | num_transformer_block: 1
20 | attention_block_types:
21 | - Temporal_Self
22 | - Temporal_Self
23 | temporal_position_encoding: true
24 | temporal_attention_dim_div: 1
25 |
26 | noise_scheduler_kwargs:
27 | beta_start: 0.00085
28 | beta_end: 0.012
29 | beta_schedule: "linear"
30 | original_inference_steps: 100
31 | steps_offset: 1
--------------------------------------------------------------------------------
/animatelcm_sd15/models/LCM_LoRA/put_spatial_lora_here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/models/LCM_LoRA/put_spatial_lora_here.txt
--------------------------------------------------------------------------------
/animatelcm_sd15/models/Motion_Module/put_motion_module_here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/models/Motion_Module/put_motion_module_here.txt
--------------------------------------------------------------------------------
/animatelcm_sd15/models/Personalized/put_personalized_weights_here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/models/Personalized/put_personalized_weights_here.txt
--------------------------------------------------------------------------------
/animatelcm_sd15/models/StableDiffusion/put_stable_diffusion_v15_here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/models/StableDiffusion/put_stable_diffusion_v15_here.txt
--------------------------------------------------------------------------------
/animatelcm_sd15/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.1
2 | torchvision==0.14.1
3 | torchaudio==0.13.1
4 | diffusers==0.11.1
5 | transformers==4.25.1
6 | xformers==0.0.16
7 | imageio==2.27.0
8 | gradio==3.48.0
9 | gdown
10 | einops
11 | omegaconf
12 | safetensors
13 | imageio[ffmpeg]
14 | imageio[pyav]
15 | accelerate
--------------------------------------------------------------------------------
/animatelcm_sd15/test_imgs/cloud.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/test_imgs/cloud.jpeg
--------------------------------------------------------------------------------
/animatelcm_sd15/test_imgs/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/test_imgs/dog.jpg
--------------------------------------------------------------------------------
/animatelcm_sd15/test_imgs/fire.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/test_imgs/fire.jpg
--------------------------------------------------------------------------------
/animatelcm_sd15/test_imgs/fox.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/test_imgs/fox.jpg
--------------------------------------------------------------------------------
/animatelcm_sd15/test_imgs/girl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/test_imgs/girl.png
--------------------------------------------------------------------------------
/animatelcm_sd15/test_imgs/girl_flower.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/test_imgs/girl_flower.jpg
--------------------------------------------------------------------------------
/animatelcm_sd15/test_imgs/lighter.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/test_imgs/lighter.jpg
--------------------------------------------------------------------------------
/animatelcm_sd15/test_imgs/snow_man_fire.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/test_imgs/snow_man_fire.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/README.md:
--------------------------------------------------------------------------------
1 | ## AnimateLCM SVD
2 |
3 | ### Environment
4 |
5 | You can create the environment directly from the provided `enviroment.yaml`:
6 | ```
7 | conda env create -f enviroment.yaml
8 | conda activate animatelcm_svd
9 | ```
10 | or install the dependencies manually with `requirements.txt`:
11 |
12 | ```
13 | conda create -n animatelcm_svd python=3.9
14 | conda activate animatelcm_svd
15 | pip install -r requirements.txt
16 | ```
17 |
18 | ### Models
19 |
20 |
21 | You can download the model weights with `wget`:
22 |
23 | ```
24 | cd safetensors
25 | wget -c https://huggingface.co/wangfuyun/AnimateLCM-SVD-xt/resolve/main/AnimateLCM-SVD-xt-1.1.safetensors
26 | wget -c https://huggingface.co/wangfuyun/AnimateLCM-SVD-xt/resolve/main/AnimateLCM-SVD-xt.safetensors
27 | cd ..
28 | ```
29 |
30 | ### Running
31 |
32 | Simply run:
33 | ```
34 | python app.py
35 | ```
36 |
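37 | ### Programmatic use
38 |
39 | If you prefer to skip the Gradio UI, the snippet below is a minimal sketch of what `app.py` in this folder does: build the AnimateLCM scheduler, load the SVD-xt pipeline, swap in the AnimateLCM UNet weights from `safetensors/`, and generate a short clip. The input image and output filename are only examples; `app.py` additionally center-crops the input to preserve its aspect ratio.
40 |
41 | ```
42 | import torch
43 | from safetensors import safe_open
44 | from diffusers.utils import load_image, export_to_video
45 |
46 | from pipeline import StableVideoDiffusionPipeline
47 | from animatelcm_scheduler import AnimateLCMSVDStochasticIterativeScheduler
48 |
49 | # Same scheduler settings as app.py
50 | scheduler = AnimateLCMSVDStochasticIterativeScheduler(
51 |     num_train_timesteps=40, sigma_min=0.002, sigma_max=700.0,
52 |     sigma_data=1.0, s_noise=1.0, rho=7, clip_denoised=False,
53 | )
54 | pipe = StableVideoDiffusionPipeline.from_pretrained(
55 |     "stabilityai/stable-video-diffusion-img2vid-xt",
56 |     scheduler=scheduler, torch_dtype=torch.float16, variant="fp16",
57 | ).to("cuda")
58 |
59 | # Replace the UNet weights with the AnimateLCM-SVD checkpoint
60 | state_dict = {}
61 | with safe_open("safetensors/AnimateLCM-SVD-xt-1.1.safetensors", framework="pt", device="cpu") as f:
62 |     for key in f.keys():
63 |         state_dict[key] = f.get_tensor(key)
64 | pipe.unet.load_state_dict(state_dict, strict=True)
65 |
66 | image = load_image("test_imgs/dog-7396912_1280.jpg").resize((1024, 576))
67 | generator = torch.manual_seed(42)
68 | frames = pipe(
69 |     image, decode_chunk_size=4, generator=generator, motion_bucket_id=80,
70 |     height=576, width=1024, num_inference_steps=4,
71 |     min_guidance_scale=1.0, max_guidance_scale=1.2,
72 | ).frames[0]
73 | export_to_video(frames, "output.mp4", fps=8)
74 | ```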
--------------------------------------------------------------------------------
/animatelcm_svd/animate_lcm_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from safetensors import safe_open
4 |
5 |
6 |
7 | def calculate_probabilities(sigma_values, Pmean=0.7, Pstd=1.6):
8 |
9 | log_sigma_values = torch.log(sigma_values)
10 |
11 | erf_diff = torch.erf((log_sigma_values[:-1] - Pmean) / (np.sqrt(2) * Pstd)) - \
12 | torch.erf((log_sigma_values[1:] - Pmean) / (np.sqrt(2) * Pstd))
13 |
14 | probabilities = erf_diff / torch.sum(erf_diff)
15 |
16 | return probabilities
17 |
18 | def append_dims(x, target_dims):
19 | """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
20 | dims_to_append = target_dims - x.ndim
21 | if dims_to_append < 0:
22 | raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
23 | return x[(...,) + (None,) * dims_to_append]
24 |
25 |
26 | class SVDSolver():
27 | def __init__(self, N, sigma_min, sigma_max, rho, Pmean, Pstd):
28 | self.sigma_min = sigma_min
29 | self.sigma_max = sigma_max
30 | self.rho = rho
31 | self.N = N
32 | self.Pmean = Pmean
33 | self.Pstd = Pstd
34 |
35 |
36 | self.indices = torch.arange(0, N, dtype=torch.float)
37 | self.sigmas = (sigma_max ** (1 / rho) + self.indices / (N - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)))**rho
38 |
39 | self.indices = torch.cat([self.indices, torch.tensor([N])])
40 | self.sigmas = torch.cat([self.sigmas, torch.tensor([0])])
41 |
42 |
43 | self.probs = torch.ones_like(self.sigmas[:-1])*(1/N)
44 |
45 | self.sigmas = self.sigmas[:,None,None,None,None]
46 | self.timesteps = torch.Tensor([0.25 * (sigma + 1e-44).log() for sigma in self.sigmas])
47 |
48 | self.weights = (1/(self.sigmas[:-1] - self.sigmas[1:]))**0.1 # NOTE: this weighting is not optimal and can strongly influence training dynamics; it likely deserves further tuning.
49 | self.c_out = -self.sigmas / ((self.sigmas**2 + 1)**0.5)
50 | self.c_skip = 1 / (self.sigmas**2 + 1)
51 | self.c_in = 1 /((self.sigmas**2 + 1) ** 0.5)
52 |
53 | def sample_params(self, indices):
54 |
55 | sampled_sigmas = self.sigmas[indices]
56 | sampled_timesteps = self.timesteps[indices]
57 | sampled_weights = self.weights[torch.where(indices>self.weights.shape[0]-1,self.weights.shape[0]-1,indices)]
58 | sampled_c_out = self.c_out[indices]
59 | sampled_c_in = self.c_in[indices]
60 | sampled_c_skip = self.c_skip[indices]
61 |
62 | return indices, sampled_sigmas, sampled_timesteps, sampled_weights, sampled_c_in, sampled_c_out, sampled_c_skip
63 |
64 |
65 | def sample_timesteps(self, bsz):
66 |
67 | sampled_indices = torch.multinomial(self.probs, bsz, replacement=True)
68 |
69 | sampled_indices, sampled_sigmas, sampled_timesteps, sampled_weights, sampled_c_in, sampled_c_out, sampled_c_skip = self.sample_params(sampled_indices)
70 |
71 | return sampled_indices, sampled_sigmas, sampled_timesteps, sampled_weights, sampled_c_in, sampled_c_out, sampled_c_skip
72 |
73 |
74 | def predicted_origin(self, model_output, indices, sample):
75 | return model_output * self.c_out[indices] + sample * self.c_skip[indices]
76 |
77 | @torch.no_grad()
78 | def euler_solver(self, model_output, sample, indices, indices_next):
79 | x = sample
80 | denoiser = self.predicted_origin(model_output, indices, sample)
81 | d = (x - denoiser) / self.sigmas[indices]
82 | sample = x + d * (self.sigmas[indices_next] - self.sigmas[indices])
83 |
84 | return sample
85 |
86 | @torch.no_grad()
87 | def heun_solver(self, model_output, sample, indices, indices_next, model_fn):
88 | pass
89 |
90 | def to(self,device,dtype):
91 | self.indices = self.indices.to(device, dtype)
92 | self.sigmas = self.sigmas.to(device,dtype)
93 | self.timesteps=self.timesteps.to(device,dtype)
94 | self.probs=self.probs.to(device,dtype)
95 | self.weights=self.weights.to(device,dtype)
96 | self.c_out=self.c_out.to(device,dtype)
97 | self.c_skip=self.c_skip.to(device,dtype)
98 | self.c_in=self.c_in.to(device,dtype)
99 |
100 |
101 | def extract_into_tensor(a, t, x_shape):
102 | b, *_ = t.shape
103 | out = a.gather(-1, t)
104 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
105 |
106 |
107 | @torch.no_grad()
108 | def update_ema(target_params, source_params, rate=0.99):
109 | """
110 | Update target parameters to be closer to those of source parameters using
111 | an exponential moving average.
112 |
113 | :param target_params: the target parameter sequence.
114 | :param source_params: the source parameter sequence.
115 | :param rate: the EMA rate (closer to 1 means slower).
116 | """
117 | for targ, src in zip(target_params, source_params):
118 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
119 |
120 |
121 |
122 |
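123 | if __name__ == "__main__":
124 |     # Minimal sanity-check sketch (illustrative only; the values below mirror the SVD settings
125 |     # used in app.py and are not prescriptive): build the Karras-style sigma schedule and
126 |     # inspect the lognormal-weighted sampling probabilities over adjacent sigma intervals.
127 |     N, sigma_min, sigma_max, rho = 40, 0.002, 700.0, 7.0
128 |     ramp = torch.arange(0, N, dtype=torch.float32) / (N - 1)
129 |     sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
130 |     probs = calculate_probabilities(sigmas, Pmean=0.7, Pstd=1.6)
131 |     print(f"sigma range: {sigmas[0].item():.1f} -> {sigmas[-1].item():.4f}")
132 |     print(f"{probs.numel()} interval probabilities, sum = {probs.sum().item():.4f}")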
--------------------------------------------------------------------------------
/animatelcm_svd/animatelcm_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import numpy as np
19 | import torch
20 |
21 | from diffusers.configuration_utils import ConfigMixin, register_to_config
22 | from diffusers.utils import BaseOutput, logging
23 | from diffusers.utils.torch_utils import randn_tensor
24 | from diffusers.schedulers.scheduling_utils import SchedulerMixin
25 |
26 |
27 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28 |
29 |
30 | @dataclass
31 | class AnimateLCMSVDStochasticIterativeSchedulerOutput(BaseOutput):
32 | """
33 | Output class for the scheduler's `step` function.
34 |
35 | Args:
36 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
37 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
38 | denoising loop.
39 | """
40 |
41 | prev_sample: torch.FloatTensor
42 |
43 |
44 | class AnimateLCMSVDStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
45 | """
46 | Multistep and onestep sampling for consistency models.
47 |
48 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
49 | methods the library implements for all schedulers such as loading and saving.
50 |
51 | Args:
52 | num_train_timesteps (`int`, defaults to 40):
53 | The number of diffusion steps to train the model.
54 | sigma_min (`float`, defaults to 0.002):
55 | Minimum noise magnitude in the sigma schedule. Defaults to 0.002 from the original implementation.
56 | sigma_max (`float`, defaults to 80.0):
57 | Maximum noise magnitude in the sigma schedule. Defaults to 80.0 from the original implementation.
58 | sigma_data (`float`, defaults to 0.5):
59 | The standard deviation of the data distribution from the EDM
60 | [paper](https://huggingface.co/papers/2206.00364). Defaults to 0.5 from the original implementation.
61 | s_noise (`float`, defaults to 1.0):
62 | The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000,
63 | 1.011]. Defaults to 1.0 from the original implementation.
64 | rho (`float`, defaults to 7.0):
65 | The parameter for calculating the Karras sigma schedule from the EDM
66 | [paper](https://huggingface.co/papers/2206.00364). Defaults to 7.0 from the original implementation.
67 | clip_denoised (`bool`, defaults to `True`):
68 | Whether to clip the denoised outputs to `(-1, 1)`.
69 | timesteps (`List` or `np.ndarray` or `torch.Tensor`, *optional*):
70 | An explicit timestep schedule that can be optionally specified. The timesteps are expected to be in
71 | increasing order.
72 | """
73 |
74 | order = 1
75 |
76 | @register_to_config
77 | def __init__(
78 | self,
79 | num_train_timesteps: int = 40,
80 | sigma_min: float = 0.002,
81 | sigma_max: float = 80.0,
82 | sigma_data: float = 0.5,
83 | s_noise: float = 1.0,
84 | rho: float = 7.0,
85 | clip_denoised: bool = True,
86 | ):
87 | # standard deviation of the initial noise distribution
88 | self.init_noise_sigma = (sigma_max**2 + 1) ** 0.5
89 | # self.init_noise_sigma = sigma_max
90 |
91 | ramp = np.linspace(0, 1, num_train_timesteps)
92 | sigmas = self._convert_to_karras(ramp)
93 | sigmas = np.concatenate([sigmas, np.array([0])])
94 | timesteps = self.sigma_to_t(sigmas)
95 |
96 | # setable values
97 | self.num_inference_steps = None
98 | self.sigmas = torch.from_numpy(sigmas)
99 | self.timesteps = torch.from_numpy(timesteps)
100 | self.custom_timesteps = False
101 | self.is_scale_input_called = False
102 | self._step_index = None
103 | self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
104 |
105 | def index_for_timestep(self, timestep, schedule_timesteps=None):
106 | if schedule_timesteps is None:
107 | schedule_timesteps = self.timesteps
108 |
109 | indices = (schedule_timesteps == timestep).nonzero()
110 | return indices.item()
111 |
112 | @property
113 | def step_index(self):
114 | """
115 | The index counter for the current timestep. It will increase by 1 after each scheduler step.
116 | """
117 | return self._step_index
118 |
119 | def scale_model_input(
120 | self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
121 | ) -> torch.FloatTensor:
122 | """
123 | Scales the consistency model input by `(sigma**2 + sigma_data**2) ** 0.5`.
124 |
125 | Args:
126 | sample (`torch.FloatTensor`):
127 | The input sample.
128 | timestep (`float` or `torch.FloatTensor`):
129 | The current timestep in the diffusion chain.
130 |
131 | Returns:
132 | `torch.FloatTensor`:
133 | A scaled input sample.
134 | """
135 | # Get sigma corresponding to timestep
136 | if self.step_index is None:
137 | self._init_step_index(timestep)
138 |
139 | sigma = self.sigmas[self.step_index]
140 | sample = sample / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
141 |
142 | self.is_scale_input_called = True
143 | return sample
144 |
145 | # def _sigma_to_t(self, sigma, log_sigmas):
146 | # # get log sigma
147 | # log_sigma = np.log(np.maximum(sigma, 1e-10))
148 |
149 | # # get distribution
150 | # dists = log_sigma - log_sigmas[:, np.newaxis]
151 |
152 | # # get sigmas range
153 | # low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
154 | # high_idx = low_idx + 1
155 |
156 | # low = log_sigmas[low_idx]
157 | # high = log_sigmas[high_idx]
158 |
159 | # # interpolate sigmas
160 | # w = (low - log_sigma) / (low - high)
161 | # w = np.clip(w, 0, 1)
162 |
163 | # # transform interpolation to time range
164 | # t = (1 - w) * low_idx + w * high_idx
165 | # t = t.reshape(sigma.shape)
166 | # return t
167 |
168 | def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
169 | """
170 | Gets scaled timesteps from the Karras sigmas for input to the consistency model.
171 |
172 | Args:
173 | sigmas (`float` or `np.ndarray`):
174 | A single Karras sigma or an array of Karras sigmas.
175 |
176 | Returns:
177 | `float` or `np.ndarray`:
178 | A scaled input timestep or scaled input timestep array.
179 | """
180 | if not isinstance(sigmas, np.ndarray):
181 | sigmas = np.array(sigmas, dtype=np.float64)
182 |
183 | timesteps = 0.25 * np.log(sigmas + 1e-44)
184 |
185 | return timesteps
186 |
187 | def set_timesteps(
188 | self,
189 | num_inference_steps: Optional[int] = None,
190 | device: Union[str, torch.device] = None,
191 | timesteps: Optional[List[int]] = None,
192 | ):
193 | """
194 | Sets the timesteps used for the diffusion chain (to be run before inference).
195 |
196 | Args:
197 | num_inference_steps (`int`):
198 | The number of diffusion steps used when generating samples with a pre-trained model.
199 | device (`str` or `torch.device`, *optional*):
200 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
201 | timesteps (`List[int]`, *optional*):
202 | Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
203 | timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed,
204 | `num_inference_steps` must be `None`.
205 | """
206 | if num_inference_steps is None and timesteps is None:
207 | raise ValueError(
208 | "Exactly one of `num_inference_steps` or `timesteps` must be supplied."
209 | )
210 |
211 | if num_inference_steps is not None and timesteps is not None:
212 | raise ValueError(
213 | "Can only pass one of `num_inference_steps` or `timesteps`."
214 | )
215 |
216 | # Follow DDPMScheduler custom timesteps logic
217 | if timesteps is not None:
218 | for i in range(1, len(timesteps)):
219 | if timesteps[i] >= timesteps[i - 1]:
220 | raise ValueError("`timesteps` must be in descending order.")
221 |
222 | if timesteps[0] >= self.config.num_train_timesteps:
223 | raise ValueError(
224 | f"`timesteps` must start before `self.config.train_timesteps`:"
225 | f" {self.config.num_train_timesteps}."
226 | )
227 |
228 | timesteps = np.array(timesteps, dtype=np.int64)
229 | self.custom_timesteps = True
230 | else:
231 | if num_inference_steps > self.config.num_train_timesteps:
232 | raise ValueError(
233 | f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
234 | f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
235 | f" maximal {self.config.num_train_timesteps} timesteps."
236 | )
237 |
238 | self.num_inference_steps = num_inference_steps
239 |
240 | step_ratio = self.config.num_train_timesteps // self.num_inference_steps
241 | timesteps = (
242 | (np.arange(0, num_inference_steps) * step_ratio)
243 | .round()[::-1]
244 | .copy()
245 | .astype(np.int64)
246 | )
247 | self.custom_timesteps = False
248 |
249 | # Map timesteps to Karras sigmas directly for multistep sampling
250 | # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675
251 | num_train_timesteps = self.config.num_train_timesteps
252 | ramp = timesteps[::-1].copy()
253 | ramp = ramp / (num_train_timesteps - 1)
254 | sigmas = self._convert_to_karras(ramp)
255 | timesteps = self.sigma_to_t(sigmas)
256 |
257 | sigmas = np.concatenate([sigmas, [0]]).astype(np.float32)
258 | self.sigmas = torch.from_numpy(sigmas).to(device=device)
259 |
260 | if str(device).startswith("mps"):
261 | # mps does not support float64
262 | self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
263 | else:
264 | self.timesteps = torch.from_numpy(timesteps).to(device=device)
265 |
266 | self._step_index = None
267 | self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
268 |
269 | # Modified _convert_to_karras implementation that takes in ramp as argument
270 | def _convert_to_karras(self, ramp):
271 | """Constructs the noise schedule of Karras et al. (2022)."""
272 |
273 | sigma_min: float = self.config.sigma_min
274 | sigma_max: float = self.config.sigma_max
275 |
276 | rho = self.config.rho
277 | min_inv_rho = sigma_min ** (1 / rho)
278 | max_inv_rho = sigma_max ** (1 / rho)
279 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
280 | return sigmas
281 |
282 | def get_scalings(self, sigma):
283 | sigma_data = self.config.sigma_data
284 |
285 | c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
286 | c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
287 | return c_skip, c_out
288 |
289 | def get_scalings_for_boundary_condition(self, sigma):
290 | """
291 | Gets the scalings used in the consistency model parameterization (from Appendix C of the
292 | [paper](https://huggingface.co/papers/2303.01469)) to enforce boundary condition.
293 |
294 |
295 |
296 | `epsilon` in the equations for `c_skip` and `c_out` is set to `sigma_min`.
297 |
298 |
299 |
300 | Args:
301 | sigma (`torch.FloatTensor`):
302 | The current sigma in the Karras sigma schedule.
303 |
304 | Returns:
305 | `tuple`:
306 | A two-element tuple where `c_skip` (which weights the current sample) is the first element and `c_out`
307 | (which weights the consistency model output) is the second element.
308 | """
309 | sigma_min = self.config.sigma_min
310 | sigma_data = self.config.sigma_data
311 |
312 | c_skip = sigma_data**2 / ((sigma) ** 2 + sigma_data**2)
313 | c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
314 | return c_skip, c_out
315 |
316 | # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
317 | def _init_step_index(self, timestep):
318 | if isinstance(timestep, torch.Tensor):
319 | timestep = timestep.to(self.timesteps.device)
320 |
321 | index_candidates = (self.timesteps == timestep).nonzero()
322 |
323 | # The sigma index that is taken for the **very** first `step`
324 | # is always the second index (or the last index if there is only 1)
325 | # This way we can ensure we don't accidentally skip a sigma in
326 | # case we start in the middle of the denoising schedule (e.g. for image-to-image)
327 | if len(index_candidates) > 1:
328 | step_index = index_candidates[1]
329 | else:
330 | step_index = index_candidates[0]
331 |
332 | self._step_index = step_index.item()
333 |
334 | def step(
335 | self,
336 | model_output: torch.FloatTensor,
337 | timestep: Union[float, torch.FloatTensor],
338 | sample: torch.FloatTensor,
339 | generator: Optional[torch.Generator] = None,
340 | return_dict: bool = True,
341 | ) -> Union[AnimateLCMSVDStochasticIterativeSchedulerOutput, Tuple]:
342 | """
343 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
344 | process from the learned model outputs (most often the predicted noise).
345 |
346 | Args:
347 | model_output (`torch.FloatTensor`):
348 | The direct output from the learned diffusion model.
349 | timestep (`float`):
350 | The current timestep in the diffusion chain.
351 | sample (`torch.FloatTensor`):
352 | A current instance of a sample created by the diffusion process.
353 | generator (`torch.Generator`, *optional*):
354 | A random number generator.
355 | return_dict (`bool`, *optional*, defaults to `True`):
356 | Whether or not to return a
357 | [`~schedulers.scheduling_consistency_models.AnimateLCMSVDStochasticIterativeSchedulerOutput`] or `tuple`.
358 |
359 | Returns:
360 | [`~schedulers.scheduling_consistency_models.AnimateLCMSVDStochasticIterativeSchedulerOutput`] or `tuple`:
361 | If return_dict is `True`,
362 | [`~schedulers.scheduling_consistency_models.AnimateLCMSVDStochasticIterativeSchedulerOutput`] is returned,
363 | otherwise a tuple is returned where the first element is the sample tensor.
364 | """
365 |
366 | if (
367 | isinstance(timestep, int)
368 | or isinstance(timestep, torch.IntTensor)
369 | or isinstance(timestep, torch.LongTensor)
370 | ):
371 | raise ValueError(
372 | (
373 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
374 | f" `{self.__class__}.step()` is not supported. Make sure to pass"
375 | " one of the `scheduler.timesteps` as a timestep."
376 | ),
377 | )
378 |
379 | if not self.is_scale_input_called:
380 | logger.warning(
381 | "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
382 | "See `StableDiffusionPipeline` for a usage example."
383 | )
384 |
385 | sigma_min = self.config.sigma_min
386 | sigma_max = self.config.sigma_max
387 |
388 | if self.step_index is None:
389 | self._init_step_index(timestep)
390 |
391 | # sigma_next corresponds to next_t in original implementation
392 | sigma = self.sigmas[self.step_index]
393 | if self.step_index + 1 < self.config.num_train_timesteps:
394 | sigma_next = self.sigmas[self.step_index + 1]
395 | else:
396 | # Set sigma_next to sigma_min
397 | sigma_next = self.sigmas[-1]
398 |
399 | # Get scalings for boundary conditions
400 |
401 | c_skip, c_out = self.get_scalings_for_boundary_condition(sigma)
402 |
403 | # 1. Denoise model output using boundary conditions
404 | denoised = c_out * model_output + c_skip * sample
405 | if self.config.clip_denoised:
406 | denoised = denoised.clamp(-1, 1)
407 |
408 | # 2. Sample z ~ N(0, s_noise^2 * I)
409 | # Noise is not used for onestep sampling.
410 | if len(self.timesteps) > 1:
411 | noise = randn_tensor(
412 | model_output.shape,
413 | dtype=model_output.dtype,
414 | device=model_output.device,
415 | generator=generator,
416 | )
417 | else:
418 | noise = torch.zeros_like(model_output)
419 | z = noise * self.config.s_noise
420 |
421 | sigma_hat = sigma_next.clamp(min=0, max=sigma_max)
422 |
423 | print("denoise currently")
424 | print(sigma_hat)
425 |
426 | # origin
427 | prev_sample = denoised + z * sigma_hat
428 |
429 | # upon completion increase step index by one
430 | self._step_index += 1
431 |
432 | if not return_dict:
433 | return (prev_sample,)
434 |
435 | return AnimateLCMSVDStochasticIterativeSchedulerOutput(prev_sample=prev_sample)
436 |
437 | # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
438 | def add_noise(
439 | self,
440 | original_samples: torch.FloatTensor,
441 | noise: torch.FloatTensor,
442 | timesteps: torch.FloatTensor,
443 | ) -> torch.FloatTensor:
444 | # Make sure sigmas and timesteps have the same device and dtype as original_samples
445 | sigmas = self.sigmas.to(
446 | device=original_samples.device, dtype=original_samples.dtype
447 | )
448 | if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
449 | # mps does not support float64
450 | schedule_timesteps = self.timesteps.to(
451 | original_samples.device, dtype=torch.float32
452 | )
453 | timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
454 | else:
455 | schedule_timesteps = self.timesteps.to(original_samples.device)
456 | timesteps = timesteps.to(original_samples.device)
457 |
458 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
459 |
460 | sigma = sigmas[step_indices].flatten()
461 | while len(sigma.shape) < len(original_samples.shape):
462 | sigma = sigma.unsqueeze(-1)
463 |
464 | noisy_samples = original_samples + noise * sigma
465 | return noisy_samples
466 |
467 | def __len__(self):
468 | return self.config.num_train_timesteps
469 |
--------------------------------------------------------------------------------
/animatelcm_svd/app.py:
--------------------------------------------------------------------------------
1 | # import spaces
2 |
3 | import gradio as gr
4 | # import gradio.helpers
5 | import torch
6 | import os
7 | from glob import glob
8 | from pathlib import Path
9 | from typing import Optional
10 |
11 | from PIL import Image
12 | from diffusers.utils import load_image, export_to_video
13 | from pipeline import StableVideoDiffusionPipeline
14 |
15 | import random
16 | from safetensors import safe_open
17 | from animatelcm_scheduler import AnimateLCMSVDStochasticIterativeScheduler
18 |
19 |
20 | def get_safetensors_files():
21 | models_dir = "./safetensors"
22 | safetensors_files = [
23 | f for f in os.listdir(models_dir) if f.endswith(".safetensors")
24 | ]
25 | return safetensors_files
26 |
27 |
28 | def model_select(selected_file):
29 | print("load model weights", selected_file)
30 | pipe.unet.cpu()
31 | file_path = os.path.join("./safetensors", selected_file)
32 | state_dict = {}
33 | with safe_open(file_path, framework="pt", device="cpu") as f:
34 | for key in f.keys():
35 | state_dict[key] = f.get_tensor(key)
36 | missing, unexpected = pipe.unet.load_state_dict(state_dict, strict=True)
37 | pipe.unet.cuda()
38 | del state_dict
39 | return
40 |
41 |
42 | noise_scheduler = AnimateLCMSVDStochasticIterativeScheduler(
43 | num_train_timesteps=40,
44 | sigma_min=0.002,
45 | sigma_max=700.0,
46 | sigma_data=1.0,
47 | s_noise=1.0,
48 | rho=7,
49 | clip_denoised=False,
50 | )
51 | pipe = StableVideoDiffusionPipeline.from_pretrained(
52 | "stabilityai/stable-video-diffusion-img2vid-xt",
53 | scheduler=noise_scheduler,
54 | torch_dtype=torch.float16,
55 | variant="fp16",
56 | )
57 | pipe.to("cuda")
58 | pipe.enable_model_cpu_offload() # for smaller cost
59 | model_select("AnimateLCM-SVD-xt-1.1.safetensors")
60 | # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) # for faster inference
61 |
62 |
63 | max_64_bit_int = 2**63 - 1
64 |
65 | # @spaces.GPU
66 | def sample(
67 | image: Image,
68 | seed: Optional[int] = 42,
69 | randomize_seed: bool = False,
70 | motion_bucket_id: int = 80,
71 | fps_id: int = 8,
72 | max_guidance_scale: float = 1.2,
73 | min_guidance_scale: float = 1,
74 | width: int = 1024,
75 | height: int = 576,
76 | num_inference_steps: int = 4,
77 | decoding_t: int = 4, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
78 | output_folder: str = "outputs_gradio",
79 | ):
80 | if image.mode == "RGBA":
81 | image = image.convert("RGB")
82 |
83 | if randomize_seed:
84 | seed = random.randint(0, max_64_bit_int)
85 | generator = torch.manual_seed(seed)
86 |
87 | os.makedirs(output_folder, exist_ok=True)
88 | base_count = len(glob(os.path.join(output_folder, "*.mp4")))
89 | video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
90 |
91 | with torch.autocast("cuda"):
92 | frames = pipe(
93 | image,
94 | decode_chunk_size=decoding_t,
95 | generator=generator,
96 | motion_bucket_id=motion_bucket_id,
97 | height=height,
98 | width=width,
99 | num_inference_steps=num_inference_steps,
100 | min_guidance_scale=min_guidance_scale,
101 | max_guidance_scale=max_guidance_scale,
102 | ).frames[0]
103 | export_to_video(frames, video_path, fps=fps_id)
104 | torch.manual_seed(seed)
105 |
106 | return video_path, seed
107 |
108 |
109 | def resize_image(image, output_size=(1024, 576)):
110 | # Calculate aspect ratios
111 | target_aspect = output_size[0] / output_size[1] # Aspect ratio of the desired size
112 | image_aspect = image.width / image.height # Aspect ratio of the original image
113 |
114 | # Resize then crop if the original image is larger
115 | if image_aspect > target_aspect:
116 | # Resize the image to match the target height, maintaining aspect ratio
117 | new_height = output_size[1]
118 | new_width = int(new_height * image_aspect)
119 | resized_image = image.resize((new_width, new_height), Image.LANCZOS)
120 | # Calculate coordinates for cropping
121 | left = (new_width - output_size[0]) / 2
122 | top = 0
123 | right = (new_width + output_size[0]) / 2
124 | bottom = output_size[1]
125 | else:
126 | # Resize the image to match the target width, maintaining aspect ratio
127 | new_width = output_size[0]
128 | new_height = int(new_width / image_aspect)
129 | resized_image = image.resize((new_width, new_height), Image.LANCZOS)
130 | # Calculate coordinates for cropping
131 | left = 0
132 | top = (new_height - output_size[1]) / 2
133 | right = output_size[0]
134 | bottom = (new_height + output_size[1]) / 2
135 |
136 | # Crop the image
137 | cropped_image = resized_image.crop((left, top, right, bottom))
138 | return cropped_image
139 |
140 |
141 | with gr.Blocks() as demo:
142 | gr.Markdown(
143 | """
144 | # [AnimateLCM: Accelerating the Animation of Personalized Diffusion Models and Adapters with Decoupled Consistency Learning](https://arxiv.org/abs/2402.00769)
145 | Fu-Yun Wang, Zhaoyang Huang (*Corresponding Author), Xiaoyu Shi, Weikang Bian, Guanglu Song, Yu Liu, Hongsheng Li (*Corresponding Author)
146 |
147 | [arXiv Report](https://arxiv.org/abs/2402.00769) | [Project Page](https://animatelcm.github.io/) | [Github](https://github.com/G-U-N/AnimateLCM) | [Civitai](https://civitai.com/models/290375/animatelcm-fast-video-generation) | [Replicate](https://replicate.com/camenduru/animate-lcm)
148 |
149 | Related Models:
150 | [AnimateLCM-t2v](https://huggingface.co/wangfuyun/AnimateLCM): Personalized Text-to-Video Generation
151 | [AnimateLCM-SVD-xt](https://huggingface.co/wangfuyun/AnimateLCM-SVD-xt): General Image-to-Video Generation
152 | [AnimateLCM-i2v](https://huggingface.co/wangfuyun/AnimateLCM-I2V): Personalized Image-to-Video Generation
153 | """
154 | )
155 | with gr.Row():
156 | with gr.Column():
157 | image = gr.Image(label="Upload your image", type="pil")
158 | generate_btn = gr.Button("Generate")
159 | video = gr.Video()
160 | with gr.Accordion("Advanced options", open=False):
161 | safetensors_dropdown = gr.Dropdown(
162 | label="Choose Safetensors", choices=get_safetensors_files()
163 | )
164 | seed = gr.Slider(
165 | label="Seed",
166 | value=42,
167 | randomize=False,
168 | minimum=0,
169 | maximum=max_64_bit_int,
170 | step=1,
171 | )
172 | randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
173 | motion_bucket_id = gr.Slider(
174 | label="Motion bucket id",
175 | info="Controls how much motion to add/remove from the image",
176 | value=80,
177 | minimum=1,
178 | maximum=255,
179 | )
180 | fps_id = gr.Slider(
181 | label="Frames per second",
182 | info="The length of your video in seconds will be 25/fps",
183 | value=8,
184 | minimum=5,
185 | maximum=30,
186 | )
187 | width = gr.Slider(
188 | label="Width of input image",
189 | info="It should be divisible by 64",
190 | value=1024,
191 | minimum=576,
192 | maximum=2048,
193 | )
194 | height = gr.Slider(
195 | label="Height of input image",
196 | info="It should be divisible by 64",
197 | value=576,
198 | minimum=320,
199 | maximum=1152,
200 | )
201 | max_guidance_scale = gr.Slider(
202 | label="Max guidance scale",
203 | info="classifier-free guidance strength",
204 | value=1.2,
205 | minimum=1,
206 | maximum=2,
207 | )
208 | min_guidance_scale = gr.Slider(
209 | label="Min guidance scale",
210 | info="classifier-free guidance strength",
211 | value=1,
212 | minimum=1,
213 | maximum=1.5,
214 | )
215 | num_inference_steps = gr.Slider(
216 | label="Num inference steps",
217 | info="steps for inference",
218 | value=4,
219 | minimum=1,
220 | maximum=20,
221 | step=1,
222 | )
223 |
224 | image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
225 | generate_btn.click(
226 | fn=sample,
227 | inputs=[
228 | image,
229 | seed,
230 | randomize_seed,
231 | motion_bucket_id,
232 | fps_id,
233 | max_guidance_scale,
234 | min_guidance_scale,
235 | width,
236 | height,
237 | num_inference_steps,
238 | ],
239 | outputs=[video, seed],
240 | api_name="video",
241 | )
242 | safetensors_dropdown.change(fn=model_select, inputs=safetensors_dropdown)
243 |
244 | gr.Examples(
245 | examples=[
246 | ["test_imgs/ai-generated-8496135_1280.jpg"],
247 | ["test_imgs/dog-7396912_1280.jpg"],
248 | ["test_imgs/ship-7833921_1280.jpg"],
249 | ["test_imgs/girl-4898696_1280.jpg"],
250 | ["test_imgs/power-station-6579092_1280.jpg"]
251 | ],
252 | inputs=[image],
253 | outputs=[video, seed],
254 | fn=sample,
255 | cache_examples=True,
256 | )
257 |
258 | if __name__ == "__main__":
259 | demo.queue(max_size=20, api_open=False)
260 | demo.launch(share=True, show_api=False)
261 |
--------------------------------------------------------------------------------
/animatelcm_svd/batch_inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from safetensors import safe_open
3 | from pipeline import StableVideoDiffusionPipeline
4 | from animatelcm_scheduler import AnimateLCMSVDStochasticIterativeScheduler
5 | from diffusers.utils import load_image, export_to_gif
6 | import os
7 |
8 |
9 | def run_inference_once(image_path, height, width, inference_time, min_guidance_scale, max_guidance_scale, noise_scheduler, weight = None,):
10 | if noise_scheduler is not None:
11 | pipe = StableVideoDiffusionPipeline.from_pretrained(
12 | "/mnt/afs/wangfuyun/SVD-xt/stable-video-diffusion-img2vid-xt", scheduler = noise_scheduler, torch_dtype=torch.float16, variant="fp16"
13 | )
14 | else:
15 | pipe = StableVideoDiffusionPipeline.from_pretrained(
16 | "/mnt/afs/wangfuyun/SVD-xt/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
17 | )
18 | pipe.enable_model_cpu_offload()
19 |
20 | if weight is not None:
21 | state_dict = {}
22 | with safe_open(weight, framework="pt", device="cpu") as f:
23 | for key in f.keys():
24 | state_dict[key] = f.get_tensor(key)
25 | m,u = pipe.unet.load_state_dict(state_dict,strict=True)
26 | assert len(u) == 0
27 | del state_dict
28 |
29 | image = load_image(image_path)
30 | image = image.resize((width, height))  # PIL expects (width, height)
31 |
32 | generator = torch.manual_seed(42)
33 | frames = pipe(image, decode_chunk_size=8, generator=generator, num_frames=25, height=height, width=width, num_inference_steps=inference_time, min_guidance_scale = min_guidance_scale, max_guidance_scale = max_guidance_scale).frames[0]
34 | export_to_gif(frames, f"output_gifs/{image_path[-20:-5]}-{height}-{width}-{inference_time}-{min_guidance_scale}-{max_guidance_scale}-{weight is None}.gif")
35 |
36 | if __name__ == "__main__":
37 | path = "test_imgs"
38 | weight_path = None  # set this to an AnimateLCM-SVD checkpoint, e.g. "safetensors/AnimateLCM-SVD-xt-1.1.safetensors"
39 | assert weight_path is not None, "please set weight_path to the AnimateLCM-SVD weights before running"
40 | noise_scheduler = AnimateLCMSVDStochasticIterativeScheduler(
41 | num_train_timesteps= 40,
42 | sigma_min = 0.002,
43 | sigma_max = 700.0,
44 | sigma_data = 1.0,
45 | s_noise = 1.0,
46 | rho = 7,
47 | clip_denoised = False,
48 | )
49 | # noise_scheduler = None
50 | assert noise_scheduler is not None
51 | for image_path in os.listdir(path)[5:]:
52 | image_path = os.path.join(path, image_path)
53 | for inference_time in [1, 2, 4, 8]:
54 | for height, width in [(576, 1024)]:
55 | # for min_scale, max_scale in [(1, 1.5), (1.2, 1.5), (1, 2)]:
56 | for min_scale, max_scale in [(1.0,1.0)]:
57 | run_inference_once(image_path, height, width, inference_time, min_scale, max_scale, noise_scheduler, weight = weight_path)
58 |
--------------------------------------------------------------------------------
/animatelcm_svd/dataset.py:
--------------------------------------------------------------------------------
1 | import os, io, csv, math, random
2 | import numpy as np
3 | from einops import rearrange
4 | from decord import VideoReader
5 |
6 | import torch
7 | import torchvision.transforms as transforms
8 | from torch.utils.data.dataset import Dataset
9 |
10 | import imageio.v2 as imageio
11 |
12 | class WebVid10M(Dataset):
13 | def __init__(
14 | self,
15 | video_folder,
16 | sample_size=256, sample_stride=6, sample_n_frames=16,
17 | is_image=False,
18 | ):
19 | self.dataset = [os.path.join(video_folder,video_path) for video_path in os.listdir(video_folder) if video_path.endswith(("mp4",))]
20 | random.shuffle(self.dataset)
21 | self.length = len(self.dataset)
22 |
23 | self.video_folder = video_folder
24 |
25 | self.sample_stride = sample_stride
26 | self.sample_n_frames = sample_n_frames
27 | self.is_image = is_image
28 |
29 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
30 | self.pixel_transforms = transforms.Compose([
31 | transforms.RandomHorizontalFlip(),
32 | transforms.Resize(sample_size[0]),
33 | transforms.CenterCrop(sample_size),
34 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
35 | ])
36 |
37 | def get_batch(self, idx):
38 | video_dir = self.dataset[idx]
39 | name = open(video_dir.replace("mp4","txt"),"r").readline().strip()
40 | video_reader = VideoReader(video_dir)
41 | video_length = len(video_reader)
42 |
43 | if not self.is_image:
44 | clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
45 | start_idx = random.randint(0, video_length - clip_length)
46 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
47 | else:
48 | batch_index = [random.randint(0, video_length - 1)]
49 |
50 | pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
51 | pixel_values = pixel_values / 255.
52 | del video_reader
53 |
54 | if self.is_image:
55 | pixel_values = pixel_values[0]
56 |
57 | return pixel_values, name
58 |
59 | def __len__(self):
60 | return self.length
61 |
62 | def __getitem__(self, idx):
63 | while True:
64 | try:
65 | pixel_values, name = self.get_batch(idx)
66 | break
67 |
68 | except Exception as e:
69 | print(e)
70 | idx = random.randint(0, self.length-1)
71 |
72 | pixel_values = self.pixel_transforms(pixel_values)
73 | sample = dict(pixel_values=pixel_values, text=name)
74 | return sample
75 |
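76 | if __name__ == "__main__":
77 |     # Minimal usage sketch (illustrative): the loader expects a folder of .mp4 clips where each
78 |     # clip has a sibling caption file with the same name, e.g. clip0001.mp4 + clip0001.txt.
79 |     # The folder path below is a placeholder.
80 |     from torch.utils.data import DataLoader
81 |
82 |     dataset = WebVid10M(video_folder="path/to/video_clips", sample_size=256,
83 |                         sample_stride=6, sample_n_frames=16)
84 |     loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
85 |     batch = next(iter(loader))
86 |     # pixel_values: (B, F, C, H, W) in [-1, 1]; text: a list of captions
87 |     print(batch["pixel_values"].shape, batch["text"])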
--------------------------------------------------------------------------------
/animatelcm_svd/enviroment.yaml:
--------------------------------------------------------------------------------
1 | name: animatelcm_svd
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=5.1=1_gnu
7 | - ca-certificates=2023.12.12=h06a4308_0
8 | - ld_impl_linux-64=2.38=h1181459_1
9 | - libffi=3.4.4=h6a678d5_0
10 | - libgcc-ng=11.2.0=h1234567_1
11 | - libgomp=11.2.0=h1234567_1
12 | - libstdcxx-ng=11.2.0=h1234567_1
13 | - ncurses=6.4=h6a678d5_0
14 | - openssl=3.0.12=h7f8727e_0
15 | - pip=23.3.1=py39h06a4308_0
16 | - python=3.9.18=h955ad1f_0
17 | - readline=8.2=h5eee18b_0
18 | - setuptools=68.2.2=py39h06a4308_0
19 | - sqlite=3.41.2=h5eee18b_0
20 | - tk=8.6.12=h1ccaba5_0
21 | - wheel=0.41.2=py39h06a4308_0
22 | - xz=5.4.5=h5eee18b_0
23 | - zlib=1.2.13=h5eee18b_0
24 | - pip:
25 | - accelerate==0.26.1
26 | - albumentations==1.3.1
27 | - antlr4-python3-runtime==4.9.3
28 | - appdirs==1.4.4
29 | - bitsandbytes==0.42.0
30 | - braceexpand==0.1.7
31 | - brotli==1.1.0
32 | - certifi==2023.11.17
33 | - cffi==1.16.0
34 | - charset-normalizer==3.3.2
35 | - click==8.1.7
36 | - dataclasses==0.6
37 | - decord==0.6.0
38 | - diffusers==0.25.1
39 | - docker-pycreds==0.4.0
40 | - docopt==0.6.2
41 | - einops==0.7.0
42 | - exifread-nocycle==3.0.1
43 | - ffmpeg-python==0.2.0
44 | - filelock==3.13.1
45 | - fire==0.5.0
46 | - fsspec==2023.12.2
47 | - future==0.18.3
48 | - gitdb==4.0.11
49 | - gitpython==3.1.41
50 | - huggingface-hub==0.20.3
51 | - idna==3.6
52 | - imageio==2.33.1
53 | - img2dataset==1.45.0
54 | - importlib-metadata==7.0.1
55 | - jinja2==3.1.3
56 | - joblib==1.3.2
57 | - langdetect==1.0.9
58 | - lazy-loader==0.3
59 | - markupsafe==2.1.4
60 | - mpmath==1.3.0
61 | - mutagen==1.47.0
62 | - networkx==3.2.1
63 | - numpy==1.26.3
64 | - nvidia-cublas-cu12==12.1.3.1
65 | - nvidia-cuda-cupti-cu12==12.1.105
66 | - nvidia-cuda-nvrtc-cu12==12.1.105
67 | - nvidia-cuda-runtime-cu12==12.1.105
68 | - nvidia-cudnn-cu12==8.9.2.26
69 | - nvidia-cufft-cu12==11.0.2.54
70 | - nvidia-curand-cu12==10.3.2.106
71 | - nvidia-cusolver-cu12==11.4.5.107
72 | - nvidia-cusparse-cu12==12.1.0.106
73 | - nvidia-nccl-cu12==2.19.3
74 | - nvidia-nvjitlink-cu12==12.3.101
75 | - nvidia-nvtx-cu12==12.1.105
76 | - omegaconf==2.3.0
77 | - opencv-python==4.9.0.80
78 | - opencv-python-headless==4.9.0.80
79 | - packaging==23.2
80 | - pandas==2.2.0
81 | - pillow==10.2.0
82 | - platformdirs==4.1.0
83 | - protobuf==4.25.2
84 | - psutil==5.9.8
85 | - pyarrow==15.0.0
86 | - pycparser==2.21
87 | - pycryptodomex==3.20.0
88 | - python-dateutil==2.8.2
89 | - pytz==2023.3.post1
90 | - pyyaml==6.0.1
91 | - qudida==0.0.4
92 | - regex==2023.12.25
93 | - requests==2.31.0
94 | - safetensors==0.4.2
95 | - scenedetect==0.6.2
96 | - scikit-image==0.22.0
97 | - scikit-learn==1.4.0
98 | - scipy==1.12.0
99 | - sentry-sdk==1.39.2
100 | - setproctitle==1.3.3
101 | - six==1.16.0
102 | - smmap==5.0.1
103 | - soundfile==0.12.1
104 | - sympy==1.12
105 | - termcolor==2.4.0
106 | - threadpoolctl==3.2.0
107 | - tifffile==2023.12.9
108 | - timeout-decorator==0.5.0
109 | - tokenizers==0.15.1
110 | - torch==2.2.0
111 | - torchdata==0.7.1
112 | - torchvision==0.17.0
113 | - tqdm==4.66.1
114 | - transformers==4.37.0
115 | - triton==2.2.0
116 | - typing-extensions==4.9.0
117 | - tzdata==2023.4
118 | - urllib3==2.1.0
119 | - wandb==0.16.2
120 | - webdataset==0.2.86
121 | - websockets==12.0
122 | - webvtt-py==0.4.6
123 | - xformers==0.0.24
124 | - yt-dlp==2023.12.30
125 | - zipp==3.17.0
126 |
--------------------------------------------------------------------------------
/animatelcm_svd/outputs_gradio/000000.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/outputs_gradio/000000.mp4
--------------------------------------------------------------------------------
/animatelcm_svd/outputs_gradio/000001.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/outputs_gradio/000001.mp4
--------------------------------------------------------------------------------
/animatelcm_svd/outputs_gradio/000002.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/outputs_gradio/000002.mp4
--------------------------------------------------------------------------------
/animatelcm_svd/outputs_gradio/000003.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/outputs_gradio/000003.mp4
--------------------------------------------------------------------------------
/animatelcm_svd/outputs_gradio/000004.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/outputs_gradio/000004.mp4
--------------------------------------------------------------------------------
/animatelcm_svd/outputs_gradio/000005.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/outputs_gradio/000005.mp4
--------------------------------------------------------------------------------
/animatelcm_svd/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | diffusers==0.25.1
3 | gradio==4.19.2
4 | numpy==1.26.4
5 | Pillow==10.2.0
6 | torch==2.2.0
7 | transformers==4.37.0
8 | spaces==0.23.2
9 | opencv-python
10 | xformers
--------------------------------------------------------------------------------
/animatelcm_svd/safetensors/AnimateLCM-SVD-xt-1.1.safetensors:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:94b2da1fcca8d03458ef0b07b94b4c55f94117cb0d90265c6c8452239ecc166e
3 | size 6098682464
4 |
--------------------------------------------------------------------------------
/animatelcm_svd/safetensors/AnimateLCM-SVD-xt.safetensors:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:6ca55e35f29437e8a65e8a1a9ce75262d5bab3d4fe137bdc3f3a94512c54b377
3 | size 6098682464
4 |
--------------------------------------------------------------------------------
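The two checkpoints above are stored as Git LFS pointers, so the actual weights need to be fetched (e.g. via `git lfs pull` or the project's Hugging Face page) before use. Below is a minimal loading sketch, assuming the file holds a full SVD UNet state dict whose keys match diffusers' `UNetSpatioTemporalConditionModel`; the repo's own `animatelcm_svd/app.py` and `pipeline.py` remain the authoritative loading path.

from safetensors.torch import load_file
from diffusers import UNetSpatioTemporalConditionModel

# Assumption: checkpoint keys line up with the stock SVD-xt UNet layout.
unet = UNetSpatioTemporalConditionModel.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", subfolder="unet"
)
state_dict = load_file("animatelcm_svd/safetensors/AnimateLCM-SVD-xt.safetensors")
missing, unexpected = unet.load_state_dict(state_dict, strict=False)
print(f"missing: {len(missing)}, unexpected: {len(unexpected)}")  # both should be empty if keys match
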
/animatelcm_svd/test_imgs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/.DS_Store
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/ai-generated-8411866_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/ai-generated-8411866_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/ai-generated-8463496_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/ai-generated-8463496_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/ai-generated-8476858_1280.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/ai-generated-8476858_1280.png
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/ai-generated-8479572_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/ai-generated-8479572_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/ai-generated-8481641_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/ai-generated-8481641_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/ai-generated-8496135_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/ai-generated-8496135_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/ai-generated-8496952_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/ai-generated-8496952_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/ai-generated-8498844_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/ai-generated-8498844_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/bird-7411270_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/bird-7411270_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/bird-7586857_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/bird-7586857_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/bird-8014191_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/bird-8014191_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/couple-8019370_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/couple-8019370_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/cupcakes-380178_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/cupcakes-380178_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/dog-7330712_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/dog-7330712_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/dog-7396912_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/dog-7396912_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/girl-4898696_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/girl-4898696_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/grey-capped-flycatcher-8071233_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/grey-capped-flycatcher-8071233_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/halloween-4585684_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/halloween-4585684_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/leaf-7260246_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/leaf-7260246_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/meerkat-7465819_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/meerkat-7465819_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/mobile-phone-1875813_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/mobile-phone-1875813_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/mother-8097324_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/mother-8097324_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/plane-8145957_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/plane-8145957_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/power-station-6579092_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/power-station-6579092_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/ship-7833921_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/ship-7833921_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/sleep-7871915_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/sleep-7871915_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/squirrel-7985502_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/squirrel-7985502_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/squirrel-8211238_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/squirrel-8211238_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/training-8122941_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/training-8122941_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/violin-8405558_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/violin-8405558_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/weight-8246973_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/weight-8246973_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/woman-4549327_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/woman-4549327_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/woman-4757707_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/woman-4757707_1280.jpg
--------------------------------------------------------------------------------
/animatelcm_svd/test_imgs/woman-5667299_1280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/woman-5667299_1280.jpg
--------------------------------------------------------------------------------
/metrics/UCF101_prompts.yaml:
--------------------------------------------------------------------------------
1 | # Brief captions for the UCF101 action classes, generated with GPT-4. Released by [Fu-Yun Wang](https://github.com/G-U-N).
2 | ApplyEyeMakeup: A girl applying eye makeup.
3 | ApplyLipstick: A person applying bright lipstick.
4 | Archery: A person in sportswear aiming an arrow.
5 | BabyCrawling: A baby crawling with toys around.
6 | BalanceBeam: An athlete on a balance beam.
7 | BandMarching: A band marching with instruments.
8 | BaseballPitch: A pitcher throwing a baseball.
9 | Basketball: A player dribbling a basketball.
10 | BasketballDunk: A player performing a basketball dunk.
11 | BenchPress: An athlete doing a bench press.
12 | Biking: A cyclist riding on a trail.
13 | Billiards: A player aiming in billiards.
14 | BlowDryHair: A person using a blow dryer.
15 | BlowingCandles: Someone blowing out birthday candles.
16 | BodyWeightSquats: A person doing squats.
17 | Bowling: A bowler releasing a bowling ball.
18 | BoxingPunchingBag: A boxer punching a heavy bag.
19 | BoxingSpeedBag: A boxer hitting a speed bag.
20 | BreastStroke: A swimmer doing breaststroke.
21 | BrushingTeeth: A person brushing their teeth.
22 | CleanAndJerk: An athlete lifting a barbell overhead.
23 | CliffDiving: A diver leaping off a cliff.
24 | CricketBowling: A cricket bowler in delivery stride.
25 | CricketShot: A batsman playing a cricket shot.
26 | CuttingInKitchen: A person cutting ingredients.
27 | Diving: A diver performing from a board.
28 | Drumming: A drummer playing a drum set.
29 | Fencing: Two fencers dueling.
30 | FieldHockeyPenalty: A player taking a hockey penalty shot.
31 | FloorGymnastics: A gymnast in a floor routine.
32 | FrisbeeCatch: A person catching a frisbee.
33 | FrontCrawl: A swimmer doing front crawl.
34 | GolfSwing: A golfer mid-swing.
35 | Haircut: A person getting a haircut.
36 | HammerThrow: An athlete throwing a hammer.
37 | Hammering: A person hammering a nail.
38 | HandstandPushups: An athlete doing handstand pushups.
39 | HandstandWalking: A person walking on hands.
40 | HeadMassage: A person receiving a head massage.
41 | HighJump: An athlete doing a high jump.
42 | HorseRace: Jockeys racing on horses.
43 | HorseRiding: A person riding a horse.
44 | HulaHoop: A person using a hula hoop.
45 | IceDancing: Ice dancers performing on a rink.
46 | JavelinThrow: An athlete throwing a javelin.
47 | JugglingBalls: A person juggling balls.
48 | JumpingJack: A person doing jumping jacks.
49 | JumpRope: An athlete skipping rope.
50 | Kayaking: A kayaker navigating rapids.
51 | Knitting: Hands knitting wool.
52 | LongJump: An athlete doing a long jump.
53 | Lunges: A person performing lunges.
54 | MilitaryParade: A military parade.
55 | Mixing: A chef mixing ingredients.
56 | MoppingFloor: Someone mopping a floor.
57 | Nunchucks: A martial artist using nunchucks.
58 | ParallelBars: An athlete on parallel bars.
59 | PizzaTossing: A chef tossing pizza dough.
60 | PlayingCello: A musician playing the cello.
61 | PlayingDaf: A musician playing a Daf.
62 | PlayingDhol: A person playing a Dhol.
63 | PlayingFlute: A musician playing the flute.
64 | PlayingGuitar: A guitarist strumming chords.
65 | PlayingPiano: Hands on piano keys.
66 | PlayingSitar: A musician playing the Sitar.
67 | PlayingTabla: A musician playing the Tabla.
68 | PlayingViolin: A violinist playing.
69 | PoleVault: An athlete doing pole vault.
70 | PommelHorse: A gymnast on the pommel horse.
71 | PullUps: A person doing pull-ups.
72 | Punch: A boxer throwing a punch.
73 | PushUps: A person doing push-ups.
74 | Rafting: A team rafting in rapids.
75 | RockClimbingIndoor: A climber reaching for a hold.
76 | RopeClimbing: An athlete climbing a rope.
77 | Rowing: A team rowing in unison.
78 | SalsaSpin: A couple doing a salsa spin.
79 | ShavingBeard: A person shaving their beard.
80 | Shotput: An athlete throwing shot put.
81 | SkateBoarding: A skateboarder doing a trick.
82 | Skiing: A skier descending a slope.
83 | Skijet: A person on a jet ski.
84 | SkyDiving: A skydiver in free fall.
85 | SoccerJuggling: A player juggling a soccer ball.
86 | SoccerPenalty: A player taking a soccer penalty.
87 | StillRings: A gymnast on still rings.
88 | SumoWrestling: Sumo wrestlers in a match.
89 | Surfing: A surfer on a wave.
90 | Swing: A person on a swing.
91 | TableTennisShot: A player making a table tennis shot.
92 | TaiChi: A person practicing Tai Chi.
93 | TennisSwing: A tennis player mid-swing.
94 | ThrowDiscus: An athlete throwing a discus.
95 | TrampolineJumping: A person on a trampoline.
96 | Typing: Hands typing on a keyboard.
97 | UnevenBars: A gymnast on uneven bars.
98 | VolleyballSpiking: A player spiking a volleyball.
99 | WalkingWithDog: A person walking with a dog.
100 | WallPushups: A person doing wall pushups.
101 | WritingOnBoard: A person writing on a board.
102 | YoYo: A person performing yo-yo tricks.
103 |
--------------------------------------------------------------------------------
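A small sketch of how these captions can be consumed, for example when sampling one video per UCF101 class for the FVD / CLIP-score evaluation below. The path is relative to the repository root; the actual generation call is left to the evaluation script.

import yaml

with open("metrics/UCF101_prompts.yaml") as f:
    prompts = yaml.safe_load(f)  # {"ApplyEyeMakeup": "A girl applying eye makeup.", ...}

for class_name, caption in prompts.items():
    # generate_video(caption) would be the model-specific sampling call (not shown here)
    print(class_name, "->", caption)
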
/metrics/clip_score.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import CLIPModel, CLIPProcessor
3 |
4 |
5 | def load_clip():
6 |     clip_model = CLIPModel.from_pretrained("CLIP-ViT-H-14-laion2B-s32B-b79K")
7 |     clip_processor = CLIPProcessor.from_pretrained("CLIP-ViT-H-14-laion2B-s32B-b79K")
8 |     if torch.cuda.is_available():
9 |         clip_model = clip_model.to("cuda")
10 |     clip_model.eval()
11 |     return clip_model, clip_processor
12 |
13 |
14 | def get_clip_score(image_pil, text, clip_model, clip_processor):
15 |     inputs = clip_processor(text=text, images=image_pil, return_tensors="pt", padding=True)
16 |     if torch.cuda.is_available():
17 |         inputs = {key: value.to("cuda") for key, value in inputs.items()}
18 |     with torch.no_grad():
19 |         outputs = clip_model(**inputs)
20 |     logits_per_image = outputs.logits_per_image  # (num_images, num_texts) similarity logits
21 |     return logits_per_image
22 |
--------------------------------------------------------------------------------
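A usage sketch for the helpers above, assuming the checkpoint identifier in `load_clip` resolves (e.g. to a local directory or a Hub repo of the same name) and that Pillow is installed; the image path is a placeholder.

from PIL import Image
from clip_score import load_clip, get_clip_score

clip_model, clip_processor = load_clip()
image = Image.open("example_frame.png").convert("RGB")  # any RGB frame sampled from a generated video
score = get_clip_score(image, ["a dog running on grass"], clip_model, clip_processor)
print(score)  # higher logits_per_image means better image-text agreement
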
/metrics/fvd.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | import torch.utils.data as data
7 | from tqdm import tqdm
8 |
9 | from pytorch_i3d import InceptionI3d  # pytorch_i3d.py lives alongside this file in metrics/
10 | import os
11 |
12 | from sklearn.metrics.pairwise import polynomial_kernel
13 |
14 | MAX_BATCH = 8
15 | FVD_SAMPLE_SIZE = 2048
16 | TARGET_RESOLUTION = (224, 224)
17 |
18 | def preprocess(videos, target_resolution):
19 | # videos in {0, ..., 255} as np.uint8 array
20 | b, t, h, w, c = videos.shape
21 | all_frames = torch.FloatTensor(videos).flatten(end_dim=1) # (b * t, h, w, c)
22 | all_frames = all_frames.permute(0, 3, 1, 2).contiguous() # (b * t, c, h, w)
23 | resized_videos = F.interpolate(all_frames, size=target_resolution,
24 | mode='bilinear', align_corners=False)
25 | resized_videos = resized_videos.view(b, t, c, *target_resolution)
26 | output_videos = resized_videos.transpose(1, 2).contiguous() # (b, c, t, *)
27 | scaled_videos = 2. * output_videos / 255. - 1 # [-1, 1]
28 | return scaled_videos
29 |
30 | def get_fvd_logits(videos, i3d, device):
31 | videos = preprocess(videos, TARGET_RESOLUTION)
32 | embeddings = get_logits(i3d, videos, device)
33 | return embeddings
34 |
35 | def load_fvd_model(device):
36 | i3d = InceptionI3d(400, in_channels=3).to(device)
37 | current_dir = os.path.dirname(os.path.abspath(__file__))
38 | i3d_path = os.path.join(current_dir, 'i3d_pretrained_400.pt')
39 | i3d.load_state_dict(torch.load(i3d_path, map_location=device))
40 | i3d.eval()
41 | return i3d
42 |
43 |
44 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161
45 | def _symmetric_matrix_square_root(mat, eps=1e-10):
46 | u, s, v = torch.svd(mat)
47 | si = torch.where(s < eps, s, torch.sqrt(s))
48 | return torch.matmul(torch.matmul(u, torch.diag(si)), v.t())
49 |
50 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400
51 | def trace_sqrt_product(sigma, sigma_v):
52 | sqrt_sigma = _symmetric_matrix_square_root(sigma)
53 | sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma))
54 | return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a))
55 |
56 | # https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2
57 | def cov(m, rowvar=False):
58 | '''Estimate a covariance matrix given data.
59 |
60 | Covariance indicates the level to which two variables vary together.
61 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
62 | then the covariance matrix element `C_{ij}` is the covariance of
63 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.
64 |
65 | Args:
66 | m: A 1-D or 2-D array containing multiple variables and observations.
67 | Each row of `m` represents a variable, and each column a single
68 | observation of all those variables.
69 | rowvar: If `rowvar` is True, then each row represents a
70 | variable, with observations in the columns. Otherwise, the
71 | relationship is transposed: each column represents a variable,
72 | while the rows contain observations.
73 |
74 | Returns:
75 | The covariance matrix of the variables.
76 | '''
77 | if m.dim() > 2:
78 | raise ValueError('m has more than 2 dimensions')
79 | if m.dim() < 2:
80 | m = m.view(1, -1)
81 | if not rowvar and m.size(0) != 1:
82 | m = m.t()
83 |
84 | fact = 1.0 / (m.size(1) - 1) # unbiased estimate
85 | m_center = m - torch.mean(m, dim=1, keepdim=True)
86 | mt = m_center.t() # if complex: mt = m.t().conj()
87 | return fact * m_center.matmul(mt).squeeze()
88 |
89 |
90 | def frechet_distance(x1, x2):
91 | x1 = x1.flatten(start_dim=1)
92 | x2 = x2.flatten(start_dim=1)
93 | m, m_w = x1.mean(dim=0), x2.mean(dim=0)
94 | sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False)
95 |
96 | sqrt_trace_component = trace_sqrt_product(sigma, sigma_w)
97 | trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component
98 |
99 | mean = torch.sum((m - m_w) ** 2)
100 | fd = trace + mean
101 | return fd
102 |
103 |
104 | def polynomial_mmd(X, Y):
105 | m = X.shape[0]
106 | n = Y.shape[0]
107 | # compute kernels
108 | K_XX = polynomial_kernel(X)
109 | K_YY = polynomial_kernel(Y)
110 | K_XY = polynomial_kernel(X, Y)
111 | # compute mmd distance
112 | K_XX_sum = (K_XX.sum() - np.diagonal(K_XX).sum()) / (m * (m - 1))
113 | K_YY_sum = (K_YY.sum() - np.diagonal(K_YY).sum()) / (n * (n - 1))
114 | K_XY_sum = K_XY.sum() / (m * n)
115 | mmd = K_XX_sum + K_YY_sum - 2 * K_XY_sum
116 | return mmd
117 |
118 |
119 |
120 | def get_logits(i3d, videos, device):
121 | MAX_BATCH_INPUT = min(MAX_BATCH, videos.shape[0])
122 | # assert videos.shape[0] % MAX_BATCH_INPUT == 0
123 | with torch.no_grad():
124 | logits = []
125 | for i in tqdm(range(0, videos.shape[0], MAX_BATCH_INPUT)):
126 | batch = videos[i:i + MAX_BATCH_INPUT].to(device)
127 | logits.append(i3d(batch))
128 | logits = torch.cat(logits, dim=0)
129 | return logits
130 |
131 |
132 | def compute_logits(samples, i3d, device=torch.device("cpu")):
133 | samples = preprocess(samples, (224, 224))
134 | logits = get_logits(i3d, samples, device)
135 |
136 | return logits
137 |
138 |
139 | def compute_fvd(real, samples, i3d, device=torch.device('cpu')):
140 | # real, samples are (N, T, H, W, C) numpy arrays in np.uint8
141 | real, samples = preprocess(real, (224, 224)), preprocess(samples, (224, 224))
142 | first_embed = get_logits(i3d, real, device)
143 | second_embed = get_logits(i3d, samples, device)
144 |
145 | return frechet_distance(first_embed, second_embed)
146 |
147 |
148 |
--------------------------------------------------------------------------------
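A smoke-test sketch for `compute_fvd` above, run from the `metrics/` directory so that the local import and `i3d_pretrained_400.pt` resolve. Random uint8 videos are used only to exercise the expected (N, T, H, W, C) shape and dtype; they do not produce a meaningful FVD value.

import numpy as np
import torch
from fvd import load_fvd_model, compute_fvd

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
i3d = load_fvd_model(device)

# (N, T, H, W, C) uint8 arrays in [0, 255]; 16 frames is enough for the I3D temporal pooling.
real = np.random.randint(0, 256, size=(8, 16, 256, 256, 3), dtype=np.uint8)
fake = np.random.randint(0, 256, size=(8, 16, 256, 256, 3), dtype=np.uint8)

print(float(compute_fvd(real, fake, i3d, device=device)))
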
/metrics/i3d_pretrained_400.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/metrics/i3d_pretrained_400.pt
--------------------------------------------------------------------------------
/metrics/pytorch_i3d.py:
--------------------------------------------------------------------------------
1 | # https://github.com/piergiaj/pytorch-i3d
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 | import numpy as np
8 |
9 | import os
10 | import sys
11 | from collections import OrderedDict
12 |
13 |
14 | class MaxPool3dSamePadding(nn.MaxPool3d):
15 |
16 | def compute_pad(self, dim, s):
17 | if s % self.stride[dim] == 0:
18 | return max(self.kernel_size[dim] - self.stride[dim], 0)
19 | else:
20 | return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)
21 |
22 | def forward(self, x):
23 | # compute 'same' padding
24 | (batch, channel, t, h, w) = x.size()
25 | #print t,h,w
26 | out_t = np.ceil(float(t) / float(self.stride[0]))
27 | out_h = np.ceil(float(h) / float(self.stride[1]))
28 | out_w = np.ceil(float(w) / float(self.stride[2]))
29 | #print out_t, out_h, out_w
30 | pad_t = self.compute_pad(0, t)
31 | pad_h = self.compute_pad(1, h)
32 | pad_w = self.compute_pad(2, w)
33 | #print pad_t, pad_h, pad_w
34 |
35 | pad_t_f = pad_t // 2
36 | pad_t_b = pad_t - pad_t_f
37 | pad_h_f = pad_h // 2
38 | pad_h_b = pad_h - pad_h_f
39 | pad_w_f = pad_w // 2
40 | pad_w_b = pad_w - pad_w_f
41 |
42 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
43 | #print x.size()
44 | #print pad
45 | x = F.pad(x, pad)
46 | return super(MaxPool3dSamePadding, self).forward(x)
47 |
48 |
49 | class Unit3D(nn.Module):
50 |
51 | def __init__(self, in_channels,
52 | output_channels,
53 | kernel_shape=(1, 1, 1),
54 | stride=(1, 1, 1),
55 | padding=0,
56 | activation_fn=F.relu,
57 | use_batch_norm=True,
58 | use_bias=False,
59 | name='unit_3d'):
60 |
61 | """Initializes Unit3D module."""
62 | super(Unit3D, self).__init__()
63 |
64 | self._output_channels = output_channels
65 | self._kernel_shape = kernel_shape
66 | self._stride = stride
67 | self._use_batch_norm = use_batch_norm
68 | self._activation_fn = activation_fn
69 | self._use_bias = use_bias
70 | self.name = name
71 | self.padding = padding
72 |
73 | self.conv3d = nn.Conv3d(in_channels=in_channels,
74 | out_channels=self._output_channels,
75 | kernel_size=self._kernel_shape,
76 | stride=self._stride,
77 | padding=0, # we always want padding to be 0 here. We will dynamically pad based on input size in forward function
78 | bias=self._use_bias)
79 |
80 | if self._use_batch_norm:
81 | self.bn = nn.BatchNorm3d(self._output_channels, eps=1e-5, momentum=0.001)
82 |
83 | def compute_pad(self, dim, s):
84 | if s % self._stride[dim] == 0:
85 | return max(self._kernel_shape[dim] - self._stride[dim], 0)
86 | else:
87 | return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)
88 |
89 |
90 | def forward(self, x):
91 | # compute 'same' padding
92 | (batch, channel, t, h, w) = x.size()
93 | #print t,h,w
94 | out_t = np.ceil(float(t) / float(self._stride[0]))
95 | out_h = np.ceil(float(h) / float(self._stride[1]))
96 | out_w = np.ceil(float(w) / float(self._stride[2]))
97 | #print out_t, out_h, out_w
98 | pad_t = self.compute_pad(0, t)
99 | pad_h = self.compute_pad(1, h)
100 | pad_w = self.compute_pad(2, w)
101 | #print pad_t, pad_h, pad_w
102 |
103 | pad_t_f = pad_t // 2
104 | pad_t_b = pad_t - pad_t_f
105 | pad_h_f = pad_h // 2
106 | pad_h_b = pad_h - pad_h_f
107 | pad_w_f = pad_w // 2
108 | pad_w_b = pad_w - pad_w_f
109 |
110 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
111 | #print x.size()
112 | #print pad
113 | x = F.pad(x, pad)
114 | #print x.size()
115 |
116 | x = self.conv3d(x)
117 | if self._use_batch_norm:
118 | x = self.bn(x)
119 | if self._activation_fn is not None:
120 | x = self._activation_fn(x)
121 | return x
122 |
123 |
124 |
125 | class InceptionModule(nn.Module):
126 | def __init__(self, in_channels, out_channels, name):
127 | super(InceptionModule, self).__init__()
128 |
129 | self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0,
130 | name=name+'/Branch_0/Conv3d_0a_1x1')
131 | self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0,
132 | name=name+'/Branch_1/Conv3d_0a_1x1')
133 | self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3],
134 | name=name+'/Branch_1/Conv3d_0b_3x3')
135 | self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0,
136 | name=name+'/Branch_2/Conv3d_0a_1x1')
137 | self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3],
138 | name=name+'/Branch_2/Conv3d_0b_3x3')
139 | self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],
140 | stride=(1, 1, 1), padding=0)
141 | self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0,
142 | name=name+'/Branch_3/Conv3d_0b_1x1')
143 | self.name = name
144 |
145 | def forward(self, x):
146 | b0 = self.b0(x)
147 | b1 = self.b1b(self.b1a(x))
148 | b2 = self.b2b(self.b2a(x))
149 | b3 = self.b3b(self.b3a(x))
150 | return torch.cat([b0,b1,b2,b3], dim=1)
151 |
152 |
153 | class InceptionI3d(nn.Module):
154 | """Inception-v1 I3D architecture.
155 | The model is introduced in:
156 | Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
157 | Joao Carreira, Andrew Zisserman
158 | https://arxiv.org/pdf/1705.07750v1.pdf.
159 | See also the Inception architecture, introduced in:
160 | Going deeper with convolutions
161 | Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
162 | Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
163 | http://arxiv.org/pdf/1409.4842v1.pdf.
164 | """
165 |
166 | # Endpoints of the model in order. During construction, all the endpoints up
167 | # to a designated `final_endpoint` are returned in a dictionary as the
168 | # second return value.
169 | VALID_ENDPOINTS = (
170 | 'Conv3d_1a_7x7',
171 | 'MaxPool3d_2a_3x3',
172 | 'Conv3d_2b_1x1',
173 | 'Conv3d_2c_3x3',
174 | 'MaxPool3d_3a_3x3',
175 | 'Mixed_3b',
176 | 'Mixed_3c',
177 | 'MaxPool3d_4a_3x3',
178 | 'Mixed_4b',
179 | 'Mixed_4c',
180 | 'Mixed_4d',
181 | 'Mixed_4e',
182 | 'Mixed_4f',
183 | 'MaxPool3d_5a_2x2',
184 | 'Mixed_5b',
185 | 'Mixed_5c',
186 | 'Logits',
187 | 'Predictions',
188 | )
189 |
190 | FEAT_ENDPOINTS = (
191 | 'Conv3d_1a_7x7',
192 | 'Conv3d_2c_3x3',
193 | 'Mixed_3c',
194 | 'Mixed_4f',
195 | 'Mixed_5c',
196 | )
197 | def __init__(self,
198 | num_classes=400,
199 | spatial_squeeze=True,
200 | final_endpoint='Logits',
201 | name='inception_i3d',
202 | in_channels=3,
203 | dropout_keep_prob=0.5,
204 | is_coinrun=False,
205 | ):
206 | """Initializes I3D model instance.
207 | Args:
208 | num_classes: The number of outputs in the logit layer (default 400, which
209 | matches the Kinetics dataset).
210 | spatial_squeeze: Whether to squeeze the spatial dimensions for the logits
211 | before returning (default True).
212 | final_endpoint: The model contains many possible endpoints.
213 | `final_endpoint` specifies the last endpoint for the model to be built
214 | up to. In addition to the output at `final_endpoint`, all the outputs
215 | at endpoints up to `final_endpoint` will also be returned, in a
216 | dictionary. `final_endpoint` must be one of
217 | InceptionI3d.VALID_ENDPOINTS (default 'Logits').
218 | name: A string (optional). The name of this module.
219 | Raises:
220 | ValueError: if `final_endpoint` is not recognized.
221 | """
222 |
223 | if final_endpoint not in self.VALID_ENDPOINTS:
224 | raise ValueError('Unknown final endpoint %s' % final_endpoint)
225 |
226 | super(InceptionI3d, self).__init__()
227 | self._num_classes = num_classes
228 | self._spatial_squeeze = spatial_squeeze
229 | self._final_endpoint = final_endpoint
230 | self.logits = None
231 | self.is_coinrun = is_coinrun
232 |
233 | if self._final_endpoint not in self.VALID_ENDPOINTS:
234 | raise ValueError('Unknown final endpoint %s' % self._final_endpoint)
235 |
236 | self.end_points = {}
237 | end_point = 'Conv3d_1a_7x7'
238 | self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7],
239 | stride=(1 if is_coinrun else 2, 2, 2), padding=(3,3,3), name=name+end_point)
240 | if self._final_endpoint == end_point: return
241 |
242 | end_point = 'MaxPool3d_2a_3x3'
243 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
244 | padding=0)
245 | if self._final_endpoint == end_point: return
246 |
247 | end_point = 'Conv3d_2b_1x1'
248 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0,
249 | name=name+end_point)
250 | if self._final_endpoint == end_point: return
251 |
252 | end_point = 'Conv3d_2c_3x3'
253 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1,
254 | name=name+end_point)
255 | if self._final_endpoint == end_point: return
256 |
257 | end_point = 'MaxPool3d_3a_3x3'
258 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
259 | padding=0)
260 | if self._final_endpoint == end_point: return
261 |
262 | end_point = 'Mixed_3b'
263 | self.end_points[end_point] = InceptionModule(192, [64,96,128,16,32,32], name+end_point)
264 | if self._final_endpoint == end_point: return
265 |
266 | end_point = 'Mixed_3c'
267 | self.end_points[end_point] = InceptionModule(256, [128,128,192,32,96,64], name+end_point)
268 | if self._final_endpoint == end_point: return
269 |
270 | end_point = 'MaxPool3d_4a_3x3'
271 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1 if is_coinrun else 3, 3, 3], stride=(1 if is_coinrun else 2, 2, 2),
272 | padding=0)
273 | if self._final_endpoint == end_point: return
274 |
275 | end_point = 'Mixed_4b'
276 | self.end_points[end_point] = InceptionModule(128+192+96+64, [192,96,208,16,48,64], name+end_point)
277 | if self._final_endpoint == end_point: return
278 |
279 | end_point = 'Mixed_4c'
280 | self.end_points[end_point] = InceptionModule(192+208+48+64, [160,112,224,24,64,64], name+end_point)
281 | if self._final_endpoint == end_point: return
282 |
283 | end_point = 'Mixed_4d'
284 | self.end_points[end_point] = InceptionModule(160+224+64+64, [128,128,256,24,64,64], name+end_point)
285 | if self._final_endpoint == end_point: return
286 |
287 | end_point = 'Mixed_4e'
288 | self.end_points[end_point] = InceptionModule(128+256+64+64, [112,144,288,32,64,64], name+end_point)
289 | if self._final_endpoint == end_point: return
290 |
291 | end_point = 'Mixed_4f'
292 | self.end_points[end_point] = InceptionModule(112+288+64+64, [256,160,320,32,128,128], name+end_point)
293 | if self._final_endpoint == end_point: return
294 |
295 | end_point = 'MaxPool3d_5a_2x2'
296 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(1 if is_coinrun else 2, 2, 2),
297 | padding=0)
298 | if self._final_endpoint == end_point: return
299 |
300 | end_point = 'Mixed_5b'
301 | self.end_points[end_point] = InceptionModule(256+320+128+128, [256,160,320,32,128,128], name+end_point)
302 | if self._final_endpoint == end_point: return
303 |
304 | end_point = 'Mixed_5c'
305 | self.end_points[end_point] = InceptionModule(256+320+128+128, [384,192,384,48,128,128], name+end_point)
306 | if self._final_endpoint == end_point: return
307 |
308 | end_point = 'Logits'
309 | self.avg_pool = nn.AvgPool3d(kernel_size=[1, 8, 8] if is_coinrun else [2, 7, 7],
310 | stride=(1, 1, 1))
311 | self.dropout = nn.Dropout(dropout_keep_prob)
312 | self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,
313 | kernel_shape=[1, 1, 1],
314 | padding=0,
315 | activation_fn=None,
316 | use_batch_norm=False,
317 | use_bias=True,
318 | name='logits')
319 |
320 | self.build()
321 |
322 |
323 | def replace_logits(self, num_classes):
324 | self._num_classes = num_classes
325 | self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,
326 | kernel_shape=[1, 1, 1],
327 | padding=0,
328 | activation_fn=None,
329 | use_batch_norm=False,
330 | use_bias=True,
331 | name='logits')
332 |
333 |
334 | def build(self):
335 | for k in self.end_points.keys():
336 | self.add_module(k, self.end_points[k])
337 |
338 | def forward(self, x):
339 | for end_point in self.VALID_ENDPOINTS:
340 | if end_point in self.end_points:
341 | x = self._modules[end_point](x) # use _modules to work with dataparallel
342 |
343 | x = self.logits(self.dropout(self.avg_pool(x)))
344 | if self._spatial_squeeze:
345 | logits = x.squeeze(3).squeeze(3)
346 | logits = logits.mean(dim=2)
347 |         # after the temporal mean, logits is batch X classes
348 | return logits
349 |
350 |
351 | def extract_features(self, x):
352 | for end_point in self.VALID_ENDPOINTS:
353 | if end_point in self.end_points:
354 | x = self._modules[end_point](x)
355 | return self.avg_pool(x)
356 |
357 |
358 | def extract_pre_pool_features(self, x):
359 | for end_point in self.VALID_ENDPOINTS:
360 | if end_point in self.end_points:
361 | x = self._modules[end_point](x)
362 | return x
363 |
364 |
365 | def extract_features_multiscale(self, x):
366 | xs = []
367 | for end_point in self.VALID_ENDPOINTS:
368 | if end_point in self.end_points:
369 | x = self._modules[end_point](x)
370 | if end_point in self.FEAT_ENDPOINTS:
371 | xs.append(x)
372 | return xs
373 |
374 |
--------------------------------------------------------------------------------
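A direct-usage sketch for the I3D backbone above, independent of fvd.py. Run it from `metrics/`; the input is expected as (B, C, T, H, W) in [-1, 1] at 224x224, matching `preprocess` in fvd.py.

import torch
from pytorch_i3d import InceptionI3d

i3d = InceptionI3d(num_classes=400, in_channels=3)
i3d.load_state_dict(torch.load("i3d_pretrained_400.pt", map_location="cpu"))
i3d.eval()

clip = torch.rand(1, 3, 16, 224, 224) * 2 - 1  # dummy clip in [-1, 1]
with torch.no_grad():
    logits = i3d(clip)                  # (1, 400) Kinetics-400 class logits
    feats = i3d.extract_features(clip)  # (1, 1024, 1, 1, 1) pooled features
print(logits.shape, feats.shape)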