├── LICENSE.txt ├── README.md ├── __assets__ ├── gifs │ ├── dog.gif │ ├── girl.gif │ └── swan.gif └── imgs │ ├── comparison.png │ ├── demo_figure.png │ └── examples.png ├── animatelcm_sd15 ├── README.md ├── animatelcm │ ├── .DS_Store │ ├── models │ │ ├── adapter.py │ │ ├── attention.py │ │ ├── embeddings.py │ │ ├── motion_module.py │ │ ├── resnet.py │ │ ├── unet.py │ │ └── unet_blocks.py │ ├── pipelines │ │ └── pipeline_animation.py │ ├── scheduler │ │ └── lcm_scheduler.py │ └── utils │ │ ├── convert_from_ckpt.py │ │ ├── convert_lora_safetensor_to_diffusers.py │ │ ├── lcm_utils.py │ │ └── util.py ├── app-i2v.py ├── app.py ├── batch_inference.py ├── configs │ ├── batch_inference_example.yaml │ ├── inference-i2v.yaml │ └── inference-t2v.yaml ├── models │ ├── LCM_LoRA │ │ └── put_spatial_lora_here.txt │ ├── Motion_Module │ │ └── put_motion_module_here.txt │ ├── Personalized │ │ └── put_personalized_weights_here.txt │ └── StableDiffusion │ │ └── put_stable_diffusion_v15_here.txt ├── requirements.txt └── test_imgs │ ├── cloud.jpeg │ ├── dog.jpg │ ├── fire.jpg │ ├── fox.jpg │ ├── girl.png │ ├── girl_flower.jpg │ ├── lighter.jpg │ └── snow_man_fire.jpg ├── animatelcm_svd ├── README.md ├── animate_lcm_utils.py ├── animatelcm_scheduler.py ├── app.py ├── batch_inference.py ├── dataset.py ├── enviroment.yaml ├── outputs_gradio │ ├── 000000.mp4 │ ├── 000001.mp4 │ ├── 000002.mp4 │ ├── 000003.mp4 │ ├── 000004.mp4 │ └── 000005.mp4 ├── pipeline.py ├── requirements.txt ├── safetensors │ ├── AnimateLCM-SVD-xt-1.1.safetensors │ └── AnimateLCM-SVD-xt.safetensors ├── test_imgs │ ├── .DS_Store │ ├── ai-generated-8411866_1280.jpg │ ├── ai-generated-8463496_1280.jpg │ ├── ai-generated-8476858_1280.png │ ├── ai-generated-8479572_1280.jpg │ ├── ai-generated-8481641_1280.jpg │ ├── ai-generated-8496135_1280.jpg │ ├── ai-generated-8496952_1280.jpg │ ├── ai-generated-8498844_1280.jpg │ ├── bird-7411270_1280.jpg │ ├── bird-7586857_1280.jpg │ ├── bird-8014191_1280.jpg │ ├── couple-8019370_1280.jpg │ ├── cupcakes-380178_1280.jpg │ ├── dog-7330712_1280.jpg │ ├── dog-7396912_1280.jpg │ ├── girl-4898696_1280.jpg │ ├── grey-capped-flycatcher-8071233_1280.jpg │ ├── halloween-4585684_1280.jpg │ ├── leaf-7260246_1280.jpg │ ├── meerkat-7465819_1280.jpg │ ├── mobile-phone-1875813_1280.jpg │ ├── mother-8097324_1280.jpg │ ├── plane-8145957_1280.jpg │ ├── power-station-6579092_1280.jpg │ ├── ship-7833921_1280.jpg │ ├── sleep-7871915_1280.jpg │ ├── squirrel-7985502_1280.jpg │ ├── squirrel-8211238_1280.jpg │ ├── training-8122941_1280.jpg │ ├── violin-8405558_1280.jpg │ ├── weight-8246973_1280.jpg │ ├── woman-4549327_1280.jpg │ ├── woman-4757707_1280.jpg │ └── woman-5667299_1280.jpg └── train_svd_lcm.py └── metrics ├── UCF101_prompts.yaml ├── clip_score.py ├── fvd.py ├── i3d_pretrained_400.pt └── pytorch_i3d.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Fu-Yun Wang. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | ## ⚡️AnimateLCM: Computation-Efficient Personalized Style Video Generation without Personalized Video Data 4 | 5 | [[Paper]](https://arxiv.org/abs/2402.00769) [[Project Page ✨]](https://animatelcm.github.io/) [[Demo in 🤗Hugging Face]](https://huggingface.co/spaces/wangfuyun/AnimateLCM-SVD) [[Pre-trained Models]](https://huggingface.co/wangfuyun/AnimateLCM) [[Civitai]](https://civitai.com/models/290375/animatelcm-fast-video-generation) ![visitors](https://visitor-badge.laobi.icu/badge?page_id=G-U-N.AnimateLCM) 6 | 7 | 8 | by *[Fu-Yun Wang](https://g-u-n.github.io), Zhaoyang Huang📮, Weikang Bian, Xiaoyu Shi, Keqiang Sun, Guanglu Song, Yu Liu, Hongsheng Li📮* 9 | 10 |
11 | 12 | | Example 1 | Example 2 | Example 3 | 13 | |-----------------|-----------------|-----------------| 14 | | ![GIF 1](__assets__/gifs/dog.gif) | ![GIF 2](__assets__/gifs/girl.gif) | ![GIF 3](__assets__/gifs/swan.gif) | 15 | 16 | If you use any components of our work, please cite it. 17 | 18 | ``` 19 | @article{wang2024animatelcm, 20 | title={AnimateLCM: Accelerating the Animation of Personalized Diffusion Models and Adapters with Decoupled Consistency Learning}, 21 | author={Wang, Fu-Yun and Huang, Zhaoyang and Shi, Xiaoyu and Bian, Weikang and Song, Guanglu and Liu, Yu and Li, Hongsheng}, 22 | journal={arXiv preprint arXiv:2402.00769}, 23 | year={2024} 24 | } 25 | 26 | ``` 27 | ### News 28 | 29 | - [2024.05]: 🔥🔥🔥 We release the [training script](https://github.com/G-U-N/AnimateLCM/blob/master/animatelcm_svd/train_svd_lcm.py) for accelerating Stable Video Diffusion. 30 | - [2024.03]: 😆😆😆 We release AnimateLCM-I2V and AnimateLCM-SVD for fast image animation. 31 | - [2024.02]: 🤗🤗🤗 We release the pretrained model weights and the Hugging Face demo. 32 | - [2024.02]: 💡💡💡 The technical report is available on arXiv. 33 | 34 | 35 | Here is a screen recording of usage. Prompt: "river reflecting mountain". 36 | 37 | ![case1x2](https://github.com/G-U-N/AnimateLCM/assets/60997859/98f6cefe-b5f8-4bcc-966e-bbca93638e8d) 38 | 39 | 40 | ### Introduction 41 | 42 | 43 | Animate-LCM is **a pioneering and exploratory work** on fast animation generation following consistency models, able to generate good-quality animations in 4 inference steps. 44 | 45 | It relies on a **decoupled** learning paradigm: it first learns the image generation prior and then the temporal generation prior for fast sampling, which greatly boosts training efficiency. 46 | 47 | The high-level workflow of AnimateLCM is illustrated below; a toy sketch of the underlying consistency-distillation objective follows the figure. 48 | 49 | 50 |
51 | comparison 52 |
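To make the decoupled idea above concrete, here is a minimal, self-contained toy sketch of the plain consistency-distillation objective that AnimateLCM builds on (the paper splits it into an image stage and a temporal stage). The tiny MLP "denoiser", the random stand-in data, the toy noising process, and the hand-rolled one-step "teacher" update are all illustrative assumptions, not the repository's training code; the real training logic lives in `animatelcm_svd/train_svd_lcm.py`.

```python
# Toy sketch only: a tiny MLP stands in for the UNet, random vectors for video latents,
# and a hand-rolled one-step update for the frozen teacher / ODE solver.
import torch
import torch.nn as nn
import torch.nn.functional as F

T = 1000            # number of discretized timesteps (toy value)
SIGMA_DATA = 0.5    # data-std constant from the consistency-model parameterization

def c_skip(t):      # boundary-condition coefficients: f(x, 0) must equal x
    return SIGMA_DATA ** 2 / ((t / T) ** 2 + SIGMA_DATA ** 2)

def c_out(t):
    return (t / T) / torch.sqrt((t / T) ** 2 + SIGMA_DATA ** 2)

class ToyDenoiser(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim + 1, 64), nn.SiLU(), nn.Linear(64, dim))

    def forward(self, x, t):
        return self.net(torch.cat([x, (t / T).unsqueeze(-1)], dim=-1))

def consistency_fn(model, x, t):
    # f_theta(x_t, t): map any noisy sample on the trajectory back to the clean sample
    return c_skip(t).unsqueeze(-1) * x + c_out(t).unsqueeze(-1) * model(x, t)

student = ToyDenoiser()
ema_target = ToyDenoiser()                     # EMA copy used as the target network
ema_target.load_state_dict(student.state_dict())
optimizer = torch.optim.Adam(student.parameters(), lr=1e-4)

for step in range(100):
    x0 = torch.randn(32, 16)                   # stand-in for clean (video) latents
    t = torch.randint(1, T, (32,)).float()
    noise = torch.randn_like(x0)
    x_t = x0 + (t / T).unsqueeze(-1) * noise   # toy forward (noising) process
    # One solver step toward t-1; a real run would use the frozen teacher model + DDIM here.
    x_prev = x_t - (1.0 / T) * noise
    loss = F.mse_loss(consistency_fn(student, x_t, t),
                      consistency_fn(ema_target, x_prev, t - 1).detach())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    with torch.no_grad():                      # EMA update of the target network
        for p_t, p_s in zip(ema_target.parameters(), student.parameters()):
            p_t.mul_(0.95).add_(p_s, alpha=0.05)
```

In AnimateLCM, an objective of this form is first applied to the spatial weights (the LCM LoRA) on image data and then, with those frozen, to the motion module on video data, which is what the "decoupled" learning refers to.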
53 | 54 | 55 | ### Demos 56 | 57 | We have **collected many demo videos generated by Animate-LCM on the [Project Page](https://animatelcm.github.io/)**. Generally speaking, AnimateLCM supports fast text-to-video, control-to-video, image-to-video, video-to-video stylization, and longer video generation. 58 | 59 | 60 |
61 | comparison 62 |
63 | 64 | 65 | 66 | 67 | 68 | ### Models 69 | 70 | So far, we have released three models: 71 | 72 | - [Animate-LCM-T2V](https://huggingface.co/wangfuyun/AnimateLCM): A spatial LoRA weight and a motion module for personalized video generation. Community tests indicate that the motion module is also compatible with many personalized models tuned for LCM, for example [Dreamshaper-LCM](https://civitai.com/models/4384?modelVersionId=252914). 73 | 74 | - [AnimateLCM-SVD-xt](https://huggingface.co/wangfuyun/AnimateLCM-SVD-xt): I provide AnimateLCM-SVD-xt and AnimateLCM-SVD-xt 1.1, which are tuned from [SVD-xt](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt) and [SVD-xt 1.1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1) respectively. They generate 25-frame, high-resolution image animations in 1~8 steps. You can try them with the Hugging Face [Demo](https://huggingface.co/spaces/wangfuyun/AnimateLCM-SVD). Thanks to the Hugging Face team for providing the GPU grants. 75 | 76 | - [AnimateLCM-I2V](https://huggingface.co/wangfuyun/AnimateLCM-I2V): A spatial LoRA weight and a motion module with an additional image encoder for personalized image animation. It is our attempt to directly train an image animation model for fast sampling without any teacher model. It can animate a personalized image in 2~4 steps. However, because our training resources were very limited, it is not as stable as I would like (like most I2V models built on Stable-Diffusion-v1-5, it is generally not very stable). 77 | 78 | ### Install & Usage Instructions 79 | 80 | We split animatelcm_sd15 and animatelcm_svd into two folders because they rely on different environments. Please refer to [README_animatelcm_sd15](./animatelcm_sd15/README.md) and [README_animatelcm_svd](./animatelcm_svd/README.md) for instructions. 81 | 82 | ### Usage Tips 83 | 84 | - **AnimateLCM-T2V** (a concrete inference sketch is given after the comparison section below): 85 | - 4 steps generally work well. For better quality, use 6~8 inference steps. 86 | - The CFG scale should be set between 1 and 2. Setting CFG=1 cuts the sampling cost in half. However, I generally prefer CFG 1.5 together with proper negative prompts for better quality. 87 | - Set the video length to 16 frames for sampling; this is the length the model was trained with. 88 | - The models should work with IP-Adapter, ControlNet, and many other adapters tuned for Stable Diffusion in a zero-shot manner. If you want better results from such combinations, you can tune them together with the teacher-free adaptation script I provide; this does not hurt the sampling speed. 89 | 90 | - **AnimateLCM-I2V**: 91 | - 2~4 steps should work for personalized image animation. 92 | - In most cases, the model does not need CFG. Just set CFG=1 to reduce the inference cost. 93 | - I additionally expose a `motion scale` hyper-parameter; 0.8 is the default choice. Setting it to 0.0 always yields static animations. You can increase it for larger motions, but that sometimes causes generation failures. 94 | 95 | - A typical workflow: 96 | - Use your personalized image model to generate a high-quality image. 97 | - Feed the generated image as input and reuse the same prompt for image animation. 98 | - Optionally, apply AnimateLCM-T2V to further refine the motion quality. 99 | 100 | - **AnimateLCM-SVD**: 101 | - 1~4 steps should work.
102 | - SVD requires two CFG values, `CFG_min` and `CFG_max`. By default, `CFG_min` is set to 1. Slightly adjusting `CFG_max` within [1, 1.5] gives good results. Again, you can simply set it to 1 to reduce the inference cost. 103 | - For the other hyper-parameters of AnimateLCM-SVD-xt, just follow the original SVD design. 104 | 105 | ### Related Notes 106 | 107 | - 🎉 Tutorial video of AnimateLCM on ComfyUI: [Tutorial Video](https://www.youtube.com/watch?v=HxlZHsd6xAk&feature=youtu.be) 108 | - 🎉 ComfyUI for AnimateLCM: [AnimateLCM-ComfyUI](https://github.com/dezi-ai/ComfyUI-AnimateLCM) & [ComfyUI-Reddit](https://www.reddit.com/r/comfyui/comments/1ajjp9v/animatelcm_support_just_dropped/) 109 | 110 | 111 | ### Comparison 112 | 113 | Screen recording of AnimateLCM-T2V. Prompt: "dog with sunglasses". 114 | 115 | ![case2x2](https://github.com/G-U-N/AnimateLCM/assets/60997859/b23f3946-f3e2-4800-8662-dff9457a60ac) 116 | 117 | 118 |
119 | comparison 120 |
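As referenced in the AnimateLCM-T2V usage tips above, here is a rough end-to-end sampling sketch. It uses the diffusers `AnimateDiffPipeline`/`MotionAdapter` path rather than this repository's own `app.py`/`batch_inference.py` pipeline, and the personalized base model ID is only a placeholder; treat the exact repo and adapter names as assumptions and check the Hugging Face model card for the current loading recipe.

```python
# Hedged sketch: assumes a recent diffusers release with AnimateDiffPipeline/MotionAdapter
# and that the AnimateLCM weights are fetched from the Hugging Face repos named in this README.
import torch
from diffusers import AnimateDiffPipeline, LCMScheduler, MotionAdapter
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism",              # placeholder personalized SD1.5 base model
    motion_adapter=adapter,
    torch_dtype=torch.float16,
)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")

# Spatial LCM LoRA (same file the install instructions download into models/LCM_LoRA/)
pipe.load_lora_weights(
    "wangfuyun/AnimateLCM",
    weight_name="AnimateLCM_sd15_t2v_lora.safetensors",
    adapter_name="lcm-lora",
)
pipe.set_adapters(["lcm-lora"], [0.8])
pipe.enable_vae_slicing()
pipe.enable_model_cpu_offload()

output = pipe(
    prompt="a dog wearing sunglasses, high quality",
    negative_prompt="bad quality, worse quality, low resolution",
    num_frames=16,             # the length the motion module was trained with
    guidance_scale=1.5,        # CFG in the 1~2 range recommended above
    num_inference_steps=4,
    generator=torch.Generator("cpu").manual_seed(0),
)
export_to_gif(output.frames[0], "animatelcm_t2v.gif")
```

The settings mirror the tips above: 16 frames, 4 inference steps, and CFG 1.5 with a negative prompt; dropping the guidance scale to 1.0 roughly halves the sampling cost.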
121 | 122 | 123 | ### Contact & Collaboration 124 | 125 | I am open to collaboration, but not to a full-time internship. If you find some of my work interesting and hope for collaboration/discussion in any format, please do not hesitate to contact me. 126 | 127 | 📧 Email: fywang@link.cuhk.edu.hk 128 | 129 | ### Acknowledgments 130 | 131 | I would like to thank **[AK](https://twitter.com/_akhaliq)** for broadcasting our work and the Hugging Face team for their help in building the Gradio demo and hosting the models. I would also like to thank [Dhruv Nair](https://twitter.com/_DhruvNair_) for his help with diffusers. 132 | 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /__assets__/gifs/dog.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/__assets__/gifs/dog.gif -------------------------------------------------------------------------------- /__assets__/gifs/girl.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/__assets__/gifs/girl.gif -------------------------------------------------------------------------------- /__assets__/gifs/swan.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/__assets__/gifs/swan.gif -------------------------------------------------------------------------------- /__assets__/imgs/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/__assets__/imgs/comparison.png -------------------------------------------------------------------------------- /__assets__/imgs/demo_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/__assets__/imgs/demo_figure.png -------------------------------------------------------------------------------- /__assets__/imgs/examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/__assets__/imgs/examples.png -------------------------------------------------------------------------------- /animatelcm_sd15/README.md: -------------------------------------------------------------------------------- 1 | ## AnimateLCM SD15 2 | 3 | 4 | ### Environment 5 | 6 | ``` 7 | conda create -n animatelcm python=3.9 8 | conda activate animatelcm 9 | pip install -r requirements.txt 10 | ``` 11 | 12 | ### Models 13 | 14 | 1. stable diffusion 15 | ``` 16 | cd models/StableDiffusion/ 17 | git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 18 | ``` 19 | 2. motion_module 20 | ``` 21 | cd .. 22 | cd Motion_Module 23 | wget -c https://huggingface.co/wangfuyun/AnimateLCM/resolve/main/AnimateLCM_sd15_t2v.ckpt 24 | wget -c https://huggingface.co/wangfuyun/AnimateLCM-I2V/resolve/main/AnimateLCM_sd15_i2v.ckpt 25 | ``` 26 | 27 | 3. spatial_lora 28 | ``` 29 | cd ..
30 | cd LCM_LoRA 31 | wget -c https://huggingface.co/wangfuyun/AnimateLCM/resolve/main/AnimateLCM_sd15_t2v_lora.safetensors 32 | wget -c https://huggingface.co/wangfuyun/AnimateLCM-I2V/resolve/main/AnimateLCM_sd15_i2v_lora.safetensors 33 | ``` 34 | 35 | 4. personalized models 36 | 37 | You can either download models from the Civitai page or use this [civitai downloader](https://github.com/ashleykleynhans/civitai-downloader). Then put the downloaded models in the Personalized folder. 38 | 39 | 40 | ### Inference 41 | 42 | ``` 43 | python app.py 44 | 45 | python app-i2v.py 46 | ``` 47 | 48 | ### Batch Inference 49 | 50 | ``` 51 | python batch_inference.py --config=./configs/batch_inference_example.yaml 52 | ``` -------------------------------------------------------------------------------- /animatelcm_sd15/animatelcm/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/animatelcm/.DS_Store -------------------------------------------------------------------------------- /animatelcm_sd15/animatelcm/models/adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | import math 6 | import numpy as np 7 | 8 | 9 | def zero_module(module): 10 | # Zero out the parameters of a module and return it. 11 | for p in module.parameters(): 12 | p.detach().zero_() 13 | return module 14 | 15 | 16 | def conv_nd(dims, in_channels, out_channels, kernel_size, **kwargs): 17 | """ 18 | Create a 1D, 2D, or 3D convolution module. 19 | """ 20 | if dims == 1: 21 | return nn.Conv1d(in_channels, out_channels, kernel_size, **kwargs) 22 | elif dims == 2: 23 | return nn.Conv2d(in_channels, out_channels, kernel_size, **kwargs) 24 | elif dims == 3: 25 | if isinstance(kernel_size, int): 26 | kernel_size = (1, *((kernel_size,) * 2)) 27 | if 'stride' in kwargs.keys(): 28 | if isinstance(kwargs['stride'], int): 29 | kwargs['stride'] = (1, *((kwargs['stride'],) * 2)) 30 | if 'padding' in kwargs.keys(): 31 | if isinstance(kwargs['padding'], int): 32 | kwargs['padding'] = (0, *((kwargs['padding'],) * 2)) 33 | return nn.Conv3d(in_channels, out_channels, kernel_size, **kwargs) 34 | raise ValueError(f"unsupported dimensions: {dims}") 35 | 36 | 37 | def avg_pool_nd(dims, *args, **kwargs): 38 | """ 39 | Create a 1D, 2D, or 3D average pooling module.
40 | """ 41 | if dims == 1: 42 | return nn.AvgPool1d(*args, **kwargs) 43 | elif dims == 2: 44 | return nn.AvgPool2d(*args, **kwargs) 45 | elif dims == 3: 46 | return nn.AvgPool3d(*args, **kwargs) 47 | raise ValueError(f"unsupported dimensions: {dims}") 48 | 49 | 50 | def fixed_positional_embedding(t, d_model): 51 | position = torch.arange(0, t, dtype=torch.float).unsqueeze(1) 52 | div_term = torch.exp(torch.arange(0, d_model, 2).float() 53 | * (-np.log(10000.0) / d_model)) 54 | pos_embedding = torch.zeros(t, d_model) 55 | pos_embedding[:, 0::2] = torch.sin(position * div_term) 56 | pos_embedding[:, 1::2] = torch.cos(position * div_term) 57 | return pos_embedding 58 | 59 | 60 | class Adapter(nn.Module): 61 | def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True): 62 | super(Adapter, self).__init__() 63 | self.channels = channels 64 | self.nums_rb = nums_rb 65 | self.body = [] 66 | for i in range(len(channels)): 67 | for j in range(nums_rb): 68 | if (i != 0) and (j == 0): 69 | self.body.append(ResnetBlock( 70 | channels[i-1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv)) 71 | else: 72 | self.body.append(ResnetBlock( 73 | channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv)) 74 | self.body = nn.ModuleList(self.body) 75 | self.conv_in = zero_module(nn.Conv2d(cin, channels[0], 3, 1, 1)) 76 | self.motion_scale = 0.8 77 | self.insertion_weights = [1., 1., 1., 1.] 78 | 79 | self.d_model = channels[0] 80 | 81 | def forward(self, x): 82 | b, c, t, h, w = x.shape 83 | x = rearrange(x, 'b c t h w -> (b t) c h w') 84 | 85 | features = [] 86 | x = self.conv_in(x) 87 | 88 | pos_embedding = fixed_positional_embedding( 89 | t, self.d_model).to(x.device) 90 | pos_embedding = pos_embedding.unsqueeze(-1).unsqueeze(-1) 91 | pos_embedding = pos_embedding.expand(-1, -1, h, w) 92 | 93 | x_pos = pos_embedding.repeat(b, 1, 1, 1) 94 | 95 | x = self.motion_scale*x_pos + x 96 | 97 | for i in range(len(self.channels)): 98 | for j in range(self.nums_rb): 99 | idx = i*self.nums_rb + j 100 | x = self.body[idx](x) 101 | features.append(x) 102 | features = [weight*rearrange(fn, '(b t) c h w -> b c t h w', b=b, t=t) 103 | for fn, weight in zip(features, self.insertion_weights)] 104 | return features 105 | 106 | 107 | class ResnetBlock(nn.Module): 108 | def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True): 109 | super().__init__() 110 | ps = ksize//2 111 | if in_c != out_c or sk == False: 112 | self.in_conv = zero_module(nn.Conv2d(in_c, out_c, ksize, 1, ps)) 113 | else: 114 | self.in_conv = None 115 | self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1) 116 | self.act = nn.ReLU() 117 | self.block2 = zero_module(nn.Conv2d(out_c, out_c, ksize, 1, ps)) 118 | if sk == False: 119 | self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps) 120 | else: 121 | self.skep = None 122 | 123 | self.down = down 124 | if self.down == True: 125 | self.down_opt = Downsample(in_c, use_conv=use_conv) 126 | 127 | def forward(self, x): 128 | if self.down == True: 129 | x = self.down_opt(x) 130 | 131 | if self.in_conv is not None: 132 | x = self.in_conv(x) 133 | 134 | h = self.block1(x) 135 | h = self.act(h) 136 | h = self.block2(h) 137 | 138 | if self.skep is not None: 139 | return h + self.skep(x) 140 | else: 141 | return h + x 142 | 143 | 144 | class Downsample(nn.Module): 145 | """ 146 | A downsampling layer with an optional convolution. 147 | :param channels: channels in the inputs and outputs. 
148 | :param use_conv: a bool determining if a convolution is applied. 149 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 150 | downsampling occurs in the inner-two dimensions. 151 | """ 152 | 153 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): 154 | super().__init__() 155 | self.channels = channels 156 | self.out_channels = out_channels or channels 157 | self.use_conv = use_conv 158 | self.dims = dims 159 | stride = 2 if dims != 3 else (1, 2, 2) 160 | if use_conv: 161 | self.op = conv_nd( 162 | dims, self.channels, self.out_channels, 3, stride=stride, padding=padding 163 | ) 164 | else: 165 | assert self.channels == self.out_channels 166 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 167 | 168 | def forward(self, x): 169 | assert x.shape[1] == self.channels 170 | 171 | kernel_size = (2, 2) 172 | 173 | input_height, input_width = x.size(2), x.size(3) 174 | 175 | padding_height = ( 176 | math.ceil(input_height / kernel_size[0]) * kernel_size[0]) - input_height 177 | padding_width = ( 178 | math.ceil(input_width / kernel_size[1]) * kernel_size[1]) - input_width 179 | 180 | x = F.pad(x, (0, padding_width, 0, padding_height), mode='replicate') 181 | 182 | return self.op(x) 183 | -------------------------------------------------------------------------------- /animatelcm_sd15/animatelcm/models/attention.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | from diffusers.configuration_utils import ConfigMixin, register_to_config 9 | from diffusers.modeling_utils import ModelMixin 10 | from diffusers.utils import BaseOutput 11 | from diffusers.utils.import_utils import is_xformers_available 12 | from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm 13 | 14 | from einops import rearrange, repeat 15 | 16 | @dataclass 17 | class Transformer3DModelOutput(BaseOutput): 18 | sample: torch.FloatTensor 19 | 20 | 21 | if is_xformers_available(): 22 | import xformers 23 | import xformers.ops 24 | else: 25 | xformers = None 26 | 27 | 28 | class Transformer3DModel(ModelMixin, ConfigMixin): 29 | @register_to_config 30 | def __init__( 31 | self, 32 | num_attention_heads: int = 16, 33 | attention_head_dim: int = 88, 34 | in_channels: Optional[int] = None, 35 | num_layers: int = 1, 36 | dropout: float = 0.0, 37 | norm_num_groups: int = 32, 38 | cross_attention_dim: Optional[int] = None, 39 | attention_bias: bool = False, 40 | activation_fn: str = "geglu", 41 | num_embeds_ada_norm: Optional[int] = None, 42 | use_linear_projection: bool = False, 43 | only_cross_attention: bool = False, 44 | upcast_attention: bool = False, 45 | 46 | unet_use_cross_frame_attention=None, 47 | unet_use_temporal_attention=None, 48 | ): 49 | super().__init__() 50 | self.use_linear_projection = use_linear_projection 51 | self.num_attention_heads = num_attention_heads 52 | self.attention_head_dim = attention_head_dim 53 | inner_dim = num_attention_heads * attention_head_dim 54 | 55 | # Define input layers 56 | self.in_channels = in_channels 57 | 58 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 59 | if use_linear_projection: 60 | self.proj_in = nn.Linear(in_channels, inner_dim) 61 | else: 62 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 63 | 64 | # 
Define transformers blocks 65 | self.transformer_blocks = nn.ModuleList( 66 | [ 67 | BasicTransformerBlock( 68 | inner_dim, 69 | num_attention_heads, 70 | attention_head_dim, 71 | dropout=dropout, 72 | cross_attention_dim=cross_attention_dim, 73 | activation_fn=activation_fn, 74 | num_embeds_ada_norm=num_embeds_ada_norm, 75 | attention_bias=attention_bias, 76 | only_cross_attention=only_cross_attention, 77 | upcast_attention=upcast_attention, 78 | 79 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 80 | unet_use_temporal_attention=unet_use_temporal_attention, 81 | ) 82 | for d in range(num_layers) 83 | ] 84 | ) 85 | 86 | # 4. Define output layers 87 | if use_linear_projection: 88 | self.proj_out = nn.Linear(in_channels, inner_dim) 89 | else: 90 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 91 | 92 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): 93 | # Input 94 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 95 | video_length = hidden_states.shape[2] 96 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 97 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) 98 | 99 | batch, channel, height, weight = hidden_states.shape 100 | residual = hidden_states 101 | 102 | hidden_states = self.norm(hidden_states) 103 | if not self.use_linear_projection: 104 | hidden_states = self.proj_in(hidden_states) 105 | inner_dim = hidden_states.shape[1] 106 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 107 | else: 108 | inner_dim = hidden_states.shape[1] 109 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 110 | hidden_states = self.proj_in(hidden_states) 111 | 112 | # Blocks 113 | for block in self.transformer_blocks: 114 | hidden_states = block( 115 | hidden_states, 116 | encoder_hidden_states=encoder_hidden_states, 117 | timestep=timestep, 118 | video_length=video_length 119 | ) 120 | 121 | # Output 122 | if not self.use_linear_projection: 123 | hidden_states = ( 124 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 125 | ) 126 | hidden_states = self.proj_out(hidden_states) 127 | else: 128 | hidden_states = self.proj_out(hidden_states) 129 | hidden_states = ( 130 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 131 | ) 132 | 133 | output = hidden_states + residual 134 | 135 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 136 | if not return_dict: 137 | return (output,) 138 | 139 | return Transformer3DModelOutput(sample=output) 140 | 141 | 142 | class BasicTransformerBlock(nn.Module): 143 | def __init__( 144 | self, 145 | dim: int, 146 | num_attention_heads: int, 147 | attention_head_dim: int, 148 | dropout=0.0, 149 | cross_attention_dim: Optional[int] = None, 150 | activation_fn: str = "geglu", 151 | num_embeds_ada_norm: Optional[int] = None, 152 | attention_bias: bool = False, 153 | only_cross_attention: bool = False, 154 | upcast_attention: bool = False, 155 | 156 | unet_use_cross_frame_attention = None, 157 | unet_use_temporal_attention = None, 158 | ): 159 | super().__init__() 160 | self.only_cross_attention = only_cross_attention 161 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 162 | self.unet_use_cross_frame_attention = unet_use_cross_frame_attention 
163 | self.unet_use_temporal_attention = unet_use_temporal_attention 164 | 165 | # SC-Attn 166 | assert unet_use_cross_frame_attention is not None 167 | if unet_use_cross_frame_attention: 168 | self.attn1 = SparseCausalAttention2D( 169 | query_dim=dim, 170 | heads=num_attention_heads, 171 | dim_head=attention_head_dim, 172 | dropout=dropout, 173 | bias=attention_bias, 174 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 175 | upcast_attention=upcast_attention, 176 | ) 177 | else: 178 | self.attn1 = CrossAttention( 179 | query_dim=dim, 180 | heads=num_attention_heads, 181 | dim_head=attention_head_dim, 182 | dropout=dropout, 183 | bias=attention_bias, 184 | upcast_attention=upcast_attention, 185 | ) 186 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 187 | 188 | # Cross-Attn 189 | if cross_attention_dim is not None: 190 | self.attn2 = CrossAttention( 191 | query_dim=dim, 192 | cross_attention_dim=cross_attention_dim, 193 | heads=num_attention_heads, 194 | dim_head=attention_head_dim, 195 | dropout=dropout, 196 | bias=attention_bias, 197 | upcast_attention=upcast_attention, 198 | ) 199 | else: 200 | self.attn2 = None 201 | 202 | if cross_attention_dim is not None: 203 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 204 | else: 205 | self.norm2 = None 206 | 207 | # Feed-forward 208 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 209 | self.norm3 = nn.LayerNorm(dim) 210 | 211 | # Temp-Attn 212 | assert unet_use_temporal_attention is not None 213 | if unet_use_temporal_attention: 214 | self.attn_temp = CrossAttention( 215 | query_dim=dim, 216 | heads=num_attention_heads, 217 | dim_head=attention_head_dim, 218 | dropout=dropout, 219 | bias=attention_bias, 220 | upcast_attention=upcast_attention, 221 | ) 222 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data) 223 | self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 224 | 225 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): 226 | if not is_xformers_available(): 227 | raise ModuleNotFoundError( 228 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 229 | " xformers", 230 | name="xformers", 231 | ) 232 | elif not torch.cuda.is_available(): 233 | raise ValueError( 234 | "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is only" 235 | " available for GPU " 236 | ) 237 | else: 238 | try: 239 | # Make sure we can run the memory efficient attention 240 | _ = xformers.ops.memory_efficient_attention( 241 | torch.randn((1, 2, 40), device="cuda"), 242 | torch.randn((1, 2, 40), device="cuda"), 243 | torch.randn((1, 2, 40), device="cuda"), 244 | ) 245 | except Exception as e: 246 | raise e 247 | self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 248 | if self.attn2 is not None: 249 | self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 250 | # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 251 | 252 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): 253 | # SparseCausal-Attention 254 | norm_hidden_states = ( 255 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) 256 | ) 257 | 258 | # if self.only_cross_attention: 259 | # hidden_states = ( 260 | # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states 261 | # ) 262 | # else: 263 | # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states 264 | 265 | # pdb.set_trace() 266 | if self.unet_use_cross_frame_attention: 267 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states 268 | else: 269 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states 270 | 271 | if self.attn2 is not None: 272 | # Cross-Attention 273 | norm_hidden_states = ( 274 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 275 | ) 276 | hidden_states = ( 277 | self.attn2( 278 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 279 | ) 280 | + hidden_states 281 | ) 282 | 283 | # Feed-forward 284 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 285 | 286 | # Temporal-Attention 287 | if self.unet_use_temporal_attention: 288 | d = hidden_states.shape[1] 289 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 290 | norm_hidden_states = ( 291 | self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) 292 | ) 293 | hidden_states = self.attn_temp(norm_hidden_states) + hidden_states 294 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 295 | 296 | return hidden_states 297 | -------------------------------------------------------------------------------- /animatelcm_sd15/animatelcm/models/embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import math 15 | 16 | import numpy as np 17 | import torch 18 | from torch import nn 19 | 20 | 21 | def get_timestep_embedding( 22 | timesteps: torch.Tensor, 23 | embedding_dim: int, 24 | flip_sin_to_cos: bool = False, 25 | downscale_freq_shift: float = 1, 26 | scale: float = 1, 27 | max_period: int = 10000, 28 | ): 29 | """ 30 | This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. 31 | 32 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 33 | These may be fractional. 34 | :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the 35 | embeddings. :return: an [N x dim] Tensor of positional embeddings. 36 | """ 37 | assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" 38 | 39 | half_dim = embedding_dim // 2 40 | exponent = -math.log(max_period) * torch.arange( 41 | start=0, end=half_dim, dtype=torch.float32, device=timesteps.device 42 | ) 43 | exponent = exponent / (half_dim - downscale_freq_shift) 44 | 45 | emb = torch.exp(exponent) 46 | emb = timesteps[:, None].float() * emb[None, :] 47 | 48 | # scale embeddings 49 | emb = scale * emb 50 | 51 | # concat sine and cosine embeddings 52 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) 53 | 54 | # flip sine and cosine embeddings 55 | if flip_sin_to_cos: 56 | emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) 57 | 58 | # zero pad 59 | if embedding_dim % 2 == 1: 60 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 61 | return emb 62 | 63 | def zero_module(module): 64 | # Zero out the parameters of a module and return it. 
65 | for p in module.parameters(): 66 | p.detach().zero_() 67 | return module 68 | 69 | class TimestepEmbedding(nn.Module): 70 | def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None, time_cond_proj_dim=None): 71 | super().__init__() 72 | 73 | self.linear_1 = nn.Linear(in_channels, time_embed_dim) 74 | self.act = None 75 | if act_fn == "silu": 76 | self.act = nn.SiLU() 77 | elif act_fn == "mish": 78 | self.act = nn.Mish() 79 | 80 | if time_cond_proj_dim is not None: 81 | self.cond_proj = zero_module(nn.Linear(time_cond_proj_dim, in_channels, bias=False)) 82 | else: 83 | self.cond_proj = None 84 | 85 | 86 | if out_dim is not None: 87 | time_embed_dim_out = out_dim 88 | else: 89 | time_embed_dim_out = time_embed_dim 90 | self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) 91 | 92 | def forward(self, sample, condition=None): 93 | if condition is not None: 94 | sample = sample + self.cond_proj(condition) 95 | sample = self.linear_1(sample) 96 | 97 | if self.act is not None: 98 | sample = self.act(sample) 99 | 100 | sample = self.linear_2(sample) 101 | return sample 102 | 103 | 104 | class Timesteps(nn.Module): 105 | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): 106 | super().__init__() 107 | self.num_channels = num_channels 108 | self.flip_sin_to_cos = flip_sin_to_cos 109 | self.downscale_freq_shift = downscale_freq_shift 110 | 111 | def forward(self, timesteps): 112 | t_emb = get_timestep_embedding( 113 | timesteps, 114 | self.num_channels, 115 | flip_sin_to_cos=self.flip_sin_to_cos, 116 | downscale_freq_shift=self.downscale_freq_shift, 117 | ) 118 | return t_emb 119 | 120 | 121 | class GaussianFourierProjection(nn.Module): 122 | """Gaussian Fourier embeddings for noise levels.""" 123 | 124 | def __init__( 125 | self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False 126 | ): 127 | super().__init__() 128 | self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 129 | self.log = log 130 | self.flip_sin_to_cos = flip_sin_to_cos 131 | 132 | if set_W_to_weight: 133 | # to delete later 134 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 135 | 136 | self.weight = self.W 137 | 138 | def forward(self, x): 139 | if self.log: 140 | x = torch.log(x) 141 | 142 | x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi 143 | 144 | if self.flip_sin_to_cos: 145 | out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) 146 | else: 147 | out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 148 | return out 149 | 150 | 151 | class ImagePositionalEmbeddings(nn.Module): 152 | """ 153 | Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the 154 | height and width of the latent space. 155 | 156 | For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 157 | 158 | For VQ-diffusion: 159 | 160 | Output vector embeddings are used as input for the transformer. 161 | 162 | Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. 163 | 164 | Args: 165 | num_embed (`int`): 166 | Number of embeddings for the latent pixels embeddings. 167 | height (`int`): 168 | Height of the latent image i.e. the number of height embeddings. 169 | width (`int`): 170 | Width of the latent image i.e. the number of width embeddings. 
171 | embed_dim (`int`): 172 | Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. 173 | """ 174 | 175 | def __init__( 176 | self, 177 | num_embed: int, 178 | height: int, 179 | width: int, 180 | embed_dim: int, 181 | ): 182 | super().__init__() 183 | 184 | self.height = height 185 | self.width = width 186 | self.num_embed = num_embed 187 | self.embed_dim = embed_dim 188 | 189 | self.emb = nn.Embedding(self.num_embed, embed_dim) 190 | self.height_emb = nn.Embedding(self.height, embed_dim) 191 | self.width_emb = nn.Embedding(self.width, embed_dim) 192 | 193 | def forward(self, index): 194 | emb = self.emb(index) 195 | 196 | height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) 197 | 198 | # 1 x H x D -> 1 x H x 1 x D 199 | height_emb = height_emb.unsqueeze(2) 200 | 201 | width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) 202 | 203 | # 1 x W x D -> 1 x 1 x W x D 204 | width_emb = width_emb.unsqueeze(1) 205 | 206 | pos_emb = height_emb + width_emb 207 | 208 | # 1 x H x W x D -> 1 x L xD 209 | pos_emb = pos_emb.view(1, self.height * self.width, -1) 210 | 211 | emb = emb + pos_emb[:, : emb.shape[1], :] 212 | 213 | return emb 214 | -------------------------------------------------------------------------------- /animatelcm_sd15/animatelcm/models/motion_module.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | from diffusers.configuration_utils import ConfigMixin, register_to_config 9 | from diffusers.modeling_utils import ModelMixin 10 | from diffusers.utils import BaseOutput 11 | from diffusers.utils.import_utils import is_xformers_available 12 | from diffusers.models.attention import CrossAttention, FeedForward 13 | 14 | from einops import rearrange, repeat 15 | import math 16 | 17 | 18 | def zero_module(module): 19 | for p in module.parameters(): 20 | p.detach().zero_() 21 | return module 22 | 23 | 24 | @dataclass 25 | class TemporalTransformer3DModelOutput(BaseOutput): 26 | sample: torch.FloatTensor 27 | 28 | 29 | if is_xformers_available(): 30 | import xformers 31 | import xformers.ops 32 | else: 33 | xformers = None 34 | 35 | 36 | def get_motion_module( 37 | in_channels, 38 | motion_module_type: str, 39 | motion_module_kwargs: dict 40 | ): 41 | if motion_module_type == "Vanilla": 42 | return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,) 43 | else: 44 | raise ValueError 45 | 46 | 47 | class VanillaTemporalModule(nn.Module): 48 | def __init__( 49 | self, 50 | in_channels, 51 | num_attention_heads=8, 52 | num_transformer_block=2, 53 | attention_block_types=("Temporal_Self", "Temporal_Self"), 54 | cross_frame_attention_mode=None, 55 | temporal_position_encoding=False, 56 | temporal_attention_dim_div=1, 57 | zero_initialize=True, 58 | ): 59 | super().__init__() 60 | 61 | self.temporal_transformer = TemporalTransformer3DModel( 62 | in_channels=in_channels, 63 | num_attention_heads=num_attention_heads, 64 | attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, 65 | num_layers=num_transformer_block, 66 | attention_block_types=attention_block_types, 67 | cross_frame_attention_mode=cross_frame_attention_mode, 68 | temporal_position_encoding=temporal_position_encoding, 69 | ) 70 | 71 | if 
zero_initialize: 72 | self.temporal_transformer.proj_out = zero_module( 73 | self.temporal_transformer.proj_out) 74 | 75 | def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None): 76 | hidden_states = input_tensor 77 | hidden_states = self.temporal_transformer( 78 | hidden_states, encoder_hidden_states, attention_mask) 79 | 80 | output = hidden_states 81 | return output 82 | 83 | 84 | class TemporalTransformer3DModel(nn.Module): 85 | def __init__( 86 | self, 87 | in_channels, 88 | num_attention_heads, 89 | attention_head_dim, 90 | 91 | num_layers, 92 | attention_block_types=("Temporal_Self", "Temporal_Self", ), 93 | dropout=0.0, 94 | norm_num_groups=32, 95 | cross_attention_dim=768, 96 | activation_fn="geglu", 97 | attention_bias=False, 98 | upcast_attention=False, 99 | 100 | cross_frame_attention_mode=None, 101 | temporal_position_encoding=False, 102 | ): 103 | super().__init__() 104 | 105 | inner_dim = num_attention_heads * attention_head_dim 106 | 107 | self.norm = torch.nn.GroupNorm( 108 | num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 109 | self.proj_in = nn.Linear(in_channels, inner_dim) 110 | 111 | self.transformer_blocks = nn.ModuleList( 112 | [ 113 | TemporalTransformerBlock( 114 | dim=inner_dim, 115 | num_attention_heads=num_attention_heads, 116 | attention_head_dim=attention_head_dim, 117 | attention_block_types=attention_block_types, 118 | dropout=dropout, 119 | norm_num_groups=norm_num_groups, 120 | cross_attention_dim=cross_attention_dim, 121 | activation_fn=activation_fn, 122 | attention_bias=attention_bias, 123 | upcast_attention=upcast_attention, 124 | cross_frame_attention_mode=cross_frame_attention_mode, 125 | temporal_position_encoding=temporal_position_encoding, 126 | ) 127 | for d in range(num_layers) 128 | ] 129 | ) 130 | self.proj_out = nn.Linear(inner_dim, in_channels) 131 | 132 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): 133 | assert hidden_states.dim( 134 | ) == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
135 | video_length = hidden_states.shape[2] 136 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 137 | 138 | batch, channel, height, weight = hidden_states.shape 139 | residual = hidden_states 140 | 141 | hidden_states = self.norm(hidden_states) 142 | inner_dim = hidden_states.shape[1] 143 | hidden_states = hidden_states.permute( 144 | 0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 145 | hidden_states = self.proj_in(hidden_states) 146 | 147 | # Transformer Blocks 148 | for block in self.transformer_blocks: 149 | hidden_states = block( 150 | hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length) 151 | 152 | # output 153 | hidden_states = self.proj_out(hidden_states) 154 | hidden_states = hidden_states.reshape( 155 | batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 156 | 157 | output = hidden_states + residual 158 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 159 | 160 | return output 161 | 162 | 163 | class TemporalTransformerBlock(nn.Module): 164 | def __init__( 165 | self, 166 | dim, 167 | num_attention_heads, 168 | attention_head_dim, 169 | attention_block_types=("Temporal_Self", "Temporal_Self", ), 170 | dropout=0.0, 171 | norm_num_groups=32, 172 | cross_attention_dim=768, 173 | activation_fn="geglu", 174 | attention_bias=False, 175 | upcast_attention=False, 176 | cross_frame_attention_mode=None, 177 | temporal_position_encoding=False, 178 | ): 179 | super().__init__() 180 | 181 | attention_blocks = [] 182 | norms = [] 183 | 184 | for block_name in attention_block_types: 185 | attention_blocks.append( 186 | VersatileAttention( 187 | attention_mode=block_name.split("_")[0], 188 | cross_attention_dim=cross_attention_dim if block_name.endswith( 189 | "_Cross") else None, 190 | 191 | query_dim=dim, 192 | heads=num_attention_heads, 193 | dim_head=attention_head_dim, 194 | dropout=dropout, 195 | bias=attention_bias, 196 | upcast_attention=upcast_attention, 197 | 198 | cross_frame_attention_mode=cross_frame_attention_mode, 199 | temporal_position_encoding=temporal_position_encoding, 200 | ) 201 | ) 202 | norms.append(nn.LayerNorm(dim)) 203 | 204 | self.attention_blocks = nn.ModuleList(attention_blocks) 205 | self.norms = nn.ModuleList(norms) 206 | 207 | self.ff = FeedForward(dim, dropout=dropout, 208 | activation_fn=activation_fn) 209 | self.ff_norm = nn.LayerNorm(dim) 210 | 211 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 212 | for attention_block, norm in zip(self.attention_blocks, self.norms): 213 | norm_hidden_states = norm(hidden_states) 214 | hidden_states = attention_block( 215 | norm_hidden_states, 216 | encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, 217 | video_length=video_length, 218 | ) + hidden_states 219 | 220 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states 221 | 222 | output = hidden_states 223 | return output 224 | 225 | 226 | class PositionalEncoding(nn.Module): 227 | def __init__( 228 | self, 229 | d_model, 230 | dropout=0., 231 | ): 232 | super().__init__() 233 | 234 | max_length = 64 235 | self.dropout = nn.Dropout(p=dropout) 236 | position = torch.arange(max_length).unsqueeze(1) 237 | div_term = torch.exp(torch.arange(0, d_model, 2) 238 | * (-math.log(10000.0) / d_model)) 239 | pe = torch.zeros(1, max_length, d_model) 240 | pe[0, :, 0::2] = torch.sin(position * div_term) 241 | pe[0, :, 1::2] = torch.cos(position * div_term) 242 | 
self.register_buffer('pos_encoding', pe) 243 | 244 | def forward(self, x): 245 | x = x + self.pos_encoding[:, :x.size(1)] 246 | return self.dropout(x) 247 | 248 | 249 | class VersatileAttention(CrossAttention): 250 | def __init__( 251 | self, 252 | attention_mode=None, 253 | cross_frame_attention_mode=None, 254 | temporal_position_encoding=False, 255 | *args, **kwargs 256 | ): 257 | super().__init__(*args, **kwargs) 258 | assert attention_mode == "Temporal" 259 | 260 | self.attention_mode = attention_mode 261 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None 262 | 263 | self.pos_encoder = PositionalEncoding( 264 | kwargs["query_dim"], 265 | dropout=0., 266 | ) if (temporal_position_encoding and attention_mode == "Temporal") else None 267 | 268 | def extra_repr(self): 269 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" 270 | 271 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 272 | batch_size, sequence_length, _ = hidden_states.shape 273 | 274 | if self.attention_mode == "Temporal": 275 | d = hidden_states.shape[1] 276 | hidden_states = rearrange( 277 | hidden_states, "(b f) d c -> (b d) f c", f=video_length) 278 | 279 | if self.pos_encoder is not None: 280 | hidden_states = self.pos_encoder(hidden_states) 281 | 282 | encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", 283 | d=d) if encoder_hidden_states is not None else encoder_hidden_states 284 | else: 285 | raise NotImplementedError 286 | 287 | encoder_hidden_states = encoder_hidden_states 288 | 289 | if self.group_norm is not None: 290 | hidden_states = self.group_norm( 291 | hidden_states.transpose(1, 2)).transpose(1, 2) 292 | 293 | query = self.to_q(hidden_states) 294 | dim = query.shape[-1] 295 | query = self.reshape_heads_to_batch_dim(query) 296 | 297 | if self.added_kv_proj_dim is not None: 298 | raise NotImplementedError 299 | 300 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 301 | key = self.to_k(encoder_hidden_states) 302 | value = self.to_v(encoder_hidden_states) 303 | 304 | key = self.reshape_heads_to_batch_dim(key) 305 | value = self.reshape_heads_to_batch_dim(value) 306 | 307 | if attention_mask is not None: 308 | if attention_mask.shape[-1] != query.shape[1]: 309 | target_length = query.shape[1] 310 | attention_mask = F.pad( 311 | attention_mask, (0, target_length), value=0.0) 312 | attention_mask = attention_mask.repeat_interleave( 313 | self.heads, dim=0) 314 | 315 | if self._use_memory_efficient_attention_xformers: 316 | hidden_states = self._memory_efficient_attention_xformers( 317 | query, key, value, attention_mask) 318 | hidden_states = hidden_states.to(query.dtype) 319 | else: 320 | if self._slice_size is None or query.shape[0] // self._slice_size == 1: 321 | hidden_states = self._attention( 322 | query, key, value, attention_mask) 323 | else: 324 | hidden_states = self._sliced_attention( 325 | query, key, value, sequence_length, dim, attention_mask) 326 | 327 | # linear proj 328 | hidden_states = self.to_out[0](hidden_states) 329 | 330 | # dropout 331 | hidden_states = self.to_out[1](hidden_states) 332 | 333 | if self.attention_mode == "Temporal": 334 | hidden_states = rearrange( 335 | hidden_states, "(b d) f c -> (b f) d c", d=d) 336 | 337 | return hidden_states 338 | -------------------------------------------------------------------------------- /animatelcm_sd15/animatelcm/models/resnet.py: 
-------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from typing import Optional 7 | 8 | from einops import rearrange 9 | 10 | 11 | class InflatedConv3d(nn.Conv2d): 12 | def forward(self, x): 13 | video_length = x.shape[2] 14 | 15 | x = rearrange(x, "b c f h w -> (b f) c h w") 16 | x = super().forward(x) 17 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 18 | 19 | return x 20 | 21 | 22 | class InflatedGroupNorm(nn.GroupNorm): 23 | def forward(self, x): 24 | video_length = x.shape[2] 25 | 26 | x = rearrange(x, "b c f h w -> (b f) c h w") 27 | x = super().forward(x) 28 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 29 | 30 | return x 31 | 32 | 33 | class Upsample3D(nn.Module): 34 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 35 | super().__init__() 36 | self.channels = channels 37 | self.out_channels = out_channels or channels 38 | self.use_conv = use_conv 39 | self.use_conv_transpose = use_conv_transpose 40 | self.name = name 41 | 42 | conv = None 43 | if use_conv_transpose: 44 | raise NotImplementedError 45 | elif use_conv: 46 | self.conv = InflatedConv3d( 47 | self.channels, self.out_channels, 3, padding=1) 48 | 49 | def forward(self, hidden_states, output_size=None): 50 | assert hidden_states.shape[1] == self.channels 51 | 52 | if self.use_conv_transpose: 53 | raise NotImplementedError 54 | 55 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 56 | dtype = hidden_states.dtype 57 | if dtype == torch.bfloat16: 58 | hidden_states = hidden_states.to(torch.float32) 59 | 60 | # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 61 | if hidden_states.shape[0] >= 64: 62 | hidden_states = hidden_states.contiguous() 63 | 64 | # if `output_size` is passed we force the interpolation output 65 | # size and do not make use of `scale_factor=2` 66 | if output_size is None: 67 | hidden_states = F.interpolate(hidden_states, scale_factor=[ 68 | 1.0, 2.0, 2.0], mode="nearest") 69 | else: 70 | hidden_states = F.interpolate( 71 | hidden_states, size=output_size, mode="nearest") 72 | 73 | # If the input is bfloat16, we cast back to bfloat16 74 | if dtype == torch.bfloat16: 75 | hidden_states = hidden_states.to(dtype) 76 | 77 | # if self.use_conv: 78 | # if self.name == "conv": 79 | # hidden_states = self.conv(hidden_states) 80 | # else: 81 | # hidden_states = self.Conv2d_0(hidden_states) 82 | hidden_states = self.conv(hidden_states) 83 | 84 | return hidden_states 85 | 86 | 87 | class Downsample3D(nn.Module): 88 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 89 | super().__init__() 90 | self.channels = channels 91 | self.out_channels = out_channels or channels 92 | self.use_conv = use_conv 93 | self.padding = padding 94 | stride = 2 95 | self.name = name 96 | 97 | if use_conv: 98 | self.conv = InflatedConv3d( 99 | self.channels, self.out_channels, 3, stride=stride, padding=padding) 100 | else: 101 | raise NotImplementedError 102 | 103 | def forward(self, hidden_states): 104 | assert hidden_states.shape[1] == self.channels 105 | if self.use_conv and self.padding == 0: 106 | raise NotImplementedError 107 | 108 | assert hidden_states.shape[1] == self.channels 109 | hidden_states = self.conv(hidden_states) 110 | 111 | return hidden_states 112 | 113 | 114 | class ResnetBlock3D(nn.Module): 115 | def __init__( 116 | self, 117 | *, 118 | in_channels, 119 | out_channels=None, 120 | conv_shortcut=False, 121 | dropout=0.0, 122 | temb_channels=512, 123 | groups=32, 124 | groups_out=None, 125 | pre_norm=True, 126 | eps=1e-6, 127 | non_linearity="swish", 128 | time_embedding_norm="default", 129 | output_scale_factor=1.0, 130 | use_in_shortcut=None, 131 | use_inflated_groupnorm=None, 132 | use_temporal_conv=False, 133 | use_temporal_mixer=False, 134 | ): 135 | super().__init__() 136 | self.pre_norm = pre_norm 137 | self.pre_norm = True 138 | self.in_channels = in_channels 139 | out_channels = in_channels if out_channels is None else out_channels 140 | self.out_channels = out_channels 141 | self.use_conv_shortcut = conv_shortcut 142 | self.time_embedding_norm = time_embedding_norm 143 | self.output_scale_factor = output_scale_factor 144 | self.use_temporal_mixer = use_temporal_mixer 145 | if use_temporal_mixer: 146 | self.temporal_mixer = AlphaBlender(0.3, "learned", None) 147 | 148 | if groups_out is None: 149 | groups_out = groups 150 | 151 | assert use_inflated_groupnorm != None 152 | if use_inflated_groupnorm: 153 | self.norm1 = InflatedGroupNorm( 154 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 155 | else: 156 | self.norm1 = torch.nn.GroupNorm( 157 | num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 158 | 159 | if use_temporal_conv: 160 | self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=( 161 | 3, 1, 1), stride=1, padding=(1, 0, 0)) 162 | else: 163 | self.conv1 = InflatedConv3d( 164 | in_channels, out_channels, kernel_size=3, stride=1, padding=1) 165 | 166 | if temb_channels is not None: 167 | if self.time_embedding_norm == "default": 168 | time_emb_proj_out_channels = out_channels 169 | 
elif self.time_embedding_norm == "scale_shift": 170 | time_emb_proj_out_channels = out_channels * 2 171 | else: 172 | raise ValueError( 173 | f"unknown time_embedding_norm : {self.time_embedding_norm} ") 174 | 175 | self.time_emb_proj = torch.nn.Linear( 176 | temb_channels, time_emb_proj_out_channels) 177 | else: 178 | self.time_emb_proj = None 179 | 180 | if use_inflated_groupnorm: 181 | self.norm2 = InflatedGroupNorm( 182 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 183 | else: 184 | self.norm2 = torch.nn.GroupNorm( 185 | num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 186 | 187 | self.dropout = torch.nn.Dropout(dropout) 188 | if use_temporal_conv: 189 | self.conv2 = nn.Conv3d(in_channels, out_channels, kernel_size=( 190 | 3, 1, 1), stride=1, padding=(1, 0, 0)) 191 | else: 192 | self.conv2 = InflatedConv3d( 193 | out_channels, out_channels, kernel_size=3, stride=1, padding=1) 194 | 195 | if non_linearity == "swish": 196 | self.nonlinearity = lambda x: F.silu(x) 197 | elif non_linearity == "mish": 198 | self.nonlinearity = Mish() 199 | elif non_linearity == "silu": 200 | self.nonlinearity = nn.SiLU() 201 | 202 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 203 | 204 | self.conv_shortcut = None 205 | if self.use_in_shortcut: 206 | self.conv_shortcut = InflatedConv3d( 207 | in_channels, out_channels, kernel_size=1, stride=1, padding=0) 208 | 209 | def forward(self, input_tensor, temb): 210 | if self.use_temporal_mixer: 211 | residual = input_tensor 212 | 213 | hidden_states = input_tensor 214 | 215 | hidden_states = self.norm1(hidden_states) 216 | hidden_states = self.nonlinearity(hidden_states) 217 | 218 | hidden_states = self.conv1(hidden_states) 219 | 220 | if temb is not None: 221 | temb = self.time_emb_proj(self.nonlinearity(temb))[ 222 | :, :, None, None, None] 223 | 224 | if temb is not None and self.time_embedding_norm == "default": 225 | hidden_states = hidden_states + temb 226 | 227 | hidden_states = self.norm2(hidden_states) 228 | 229 | if temb is not None and self.time_embedding_norm == "scale_shift": 230 | scale, shift = torch.chunk(temb, 2, dim=1) 231 | hidden_states = hidden_states * (1 + scale) + shift 232 | 233 | hidden_states = self.nonlinearity(hidden_states) 234 | 235 | hidden_states = self.dropout(hidden_states) 236 | hidden_states = self.conv2(hidden_states) 237 | 238 | if self.conv_shortcut is not None: 239 | input_tensor = self.conv_shortcut(input_tensor) 240 | 241 | output_tensor = (input_tensor + hidden_states) / \ 242 | self.output_scale_factor 243 | 244 | if self.use_temporal_mixer: 245 | output_tensor = self.temporal_mixer(residual, output_tensor, None) 246 | # return residual + 0.0 * self.temporal_mixer(residual, output_tensor, None) 247 | return output_tensor 248 | 249 | 250 | class Mish(torch.nn.Module): 251 | def forward(self, hidden_states): 252 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 253 | 254 | 255 | class AlphaBlender(nn.Module): 256 | strategies = ["learned", "fixed", "learned_with_images"] 257 | 258 | def __init__( 259 | self, 260 | alpha: float, 261 | merge_strategy: str = "learned_with_images", 262 | rearrange_pattern: str = "b t -> (b t) 1 1", 263 | ): 264 | super().__init__() 265 | self.merge_strategy = merge_strategy 266 | self.rearrange_pattern = rearrange_pattern 267 | self.scaler = 10. 
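        # With merge_strategy="learned", get_alpha() below returns
        # sigmoid(mix_factor * scaler). For the AlphaBlender(0.3, "learned", None)
        # used by ResnetBlock3D above, that is sigmoid(0.3 * 10) ~= 0.953 at
        # initialization, so forward() starts out returning roughly
        #     0.95 * x_spatial + 0.05 * x_temporal
        # (a worked value of the formula, not a tuned constant); the residual passed
        # in as x_spatial therefore dominates until mix_factor is trained.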
268 | 269 | assert ( 270 | merge_strategy in self.strategies 271 | ), f"merge_strategy needs to be in {self.strategies}" 272 | 273 | if self.merge_strategy == "fixed": 274 | self.register_buffer("mix_factor", torch.Tensor([alpha])) 275 | elif ( 276 | self.merge_strategy == "learned" 277 | or self.merge_strategy == "learned_with_images" 278 | ): 279 | self.register_parameter( 280 | "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) 281 | ) 282 | else: 283 | raise ValueError(f"unknown merge strategy {self.merge_strategy}") 284 | 285 | def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: 286 | if self.merge_strategy == "fixed": 287 | alpha = self.mix_factor 288 | elif self.merge_strategy == "learned": 289 | alpha = torch.sigmoid(self.mix_factor*self.scaler) 290 | elif self.merge_strategy == "learned_with_images": 291 | assert image_only_indicator is not None, "need image_only_indicator ..." 292 | alpha = torch.where( 293 | image_only_indicator.bool(), 294 | torch.ones(1, 1, device=image_only_indicator.device), 295 | rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"), 296 | ) 297 | alpha = rearrange(alpha, self.rearrange_pattern) 298 | else: 299 | raise NotImplementedError 300 | return alpha 301 | 302 | def forward( 303 | self, 304 | x_spatial: torch.Tensor, 305 | x_temporal: torch.Tensor, 306 | image_only_indicator: Optional[torch.Tensor] = None, 307 | ) -> torch.Tensor: 308 | alpha = self.get_alpha(image_only_indicator) 309 | x = ( 310 | alpha.to(x_spatial.dtype) * x_spatial 311 | + (1.0 - alpha).to(x_spatial.dtype) * x_temporal 312 | ) 313 | return x 314 | -------------------------------------------------------------------------------- /animatelcm_sd15/animatelcm/utils/convert_lora_safetensor_to_diffusers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ Conversion script for the LoRA's safetensors checkpoints. """ 17 | 18 | import argparse 19 | 20 | import torch 21 | from safetensors.torch import load_file 22 | 23 | from diffusers import StableDiffusionPipeline 24 | 25 | 26 | def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0): 27 | # directly update weight in diffusers model 28 | for key in state_dict: 29 | # only process lora down key 30 | if "up." 
in key: continue 31 | 32 | up_key = key.replace(".down.", ".up.") 33 | model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") 34 | model_key = model_key.replace("to_out.", "to_out.0.") 35 | layer_infos = model_key.split(".")[:-1] 36 | 37 | curr_layer = pipeline.unet 38 | while len(layer_infos) > 0: 39 | temp_name = layer_infos.pop(0) 40 | curr_layer = curr_layer.__getattr__(temp_name) 41 | 42 | weight_down = state_dict[key] 43 | weight_up = state_dict[up_key] 44 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 45 | 46 | return pipeline 47 | 48 | 49 | 50 | def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): 51 | # load base model 52 | # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) 53 | 54 | # load LoRA weight from .safetensors 55 | # state_dict = load_file(checkpoint_path) 56 | 57 | visited = [] 58 | 59 | # directly update weight in diffusers model 60 | for key in state_dict: 61 | # it is suggested to print out the key, it usually will be something like below 62 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 63 | 64 | # as we have set the alpha beforehand, so just skip 65 | if ".alpha" in key or key in visited: 66 | continue 67 | 68 | if "text" in key: 69 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") 70 | curr_layer = pipeline.text_encoder 71 | else: 72 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") 73 | curr_layer = pipeline.unet 74 | 75 | # find the target layer 76 | temp_name = layer_infos.pop(0) 77 | while len(layer_infos) > -1: 78 | try: 79 | curr_layer = curr_layer.__getattr__(temp_name) 80 | if len(layer_infos) > 0: 81 | temp_name = layer_infos.pop(0) 82 | elif len(layer_infos) == 0: 83 | break 84 | except Exception: 85 | if len(temp_name) > 0: 86 | temp_name += "_" + layer_infos.pop(0) 87 | else: 88 | temp_name = layer_infos.pop(0) 89 | 90 | pair_keys = [] 91 | if "lora_down" in key: 92 | pair_keys.append(key.replace("lora_down", "lora_up")) 93 | pair_keys.append(key) 94 | else: 95 | pair_keys.append(key) 96 | pair_keys.append(key.replace("lora_up", "lora_down")) 97 | 98 | # update weight 99 | if len(state_dict[pair_keys[0]].shape) == 4: 100 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) 101 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) 102 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) 103 | else: 104 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 105 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 106 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 107 | 108 | # update visited list 109 | for item in pair_keys: 110 | visited.append(item) 111 | 112 | return pipeline 113 | 114 | 115 | if __name__ == "__main__": 116 | parser = argparse.ArgumentParser() 117 | 118 | parser.add_argument( 119 | "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." 120 | ) 121 | parser.add_argument( 122 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 
123 | ) 124 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") 125 | parser.add_argument( 126 | "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" 127 | ) 128 | parser.add_argument( 129 | "--lora_prefix_text_encoder", 130 | default="lora_te", 131 | type=str, 132 | help="The prefix of text encoder weight in safetensors", 133 | ) 134 | parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") 135 | parser.add_argument( 136 | "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." 137 | ) 138 | parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") 139 | 140 | args = parser.parse_args() 141 | 142 | base_model_path = args.base_model_path 143 | checkpoint_path = args.checkpoint_path 144 | dump_path = args.dump_path 145 | lora_prefix_unet = args.lora_prefix_unet 146 | lora_prefix_text_encoder = args.lora_prefix_text_encoder 147 | alpha = args.alpha 148 | 149 | pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) 150 | 151 | pipe = pipe.to(args.device) 152 | pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) 153 | -------------------------------------------------------------------------------- /animatelcm_sd15/animatelcm/utils/lcm_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from safetensors import safe_open 4 | 5 | 6 | def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32): 7 | """ 8 | See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 9 | 10 | Args: 11 | timesteps (`torch.Tensor`): 12 | generate embedding vectors at these timesteps 13 | embedding_dim (`int`, *optional*, defaults to 512): 14 | dimension of the embeddings to generate 15 | dtype: 16 | data type of the generated embeddings 17 | 18 | Returns: 19 | `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` 20 | """ 21 | assert len(w.shape) == 1 22 | w = w * 1000.0 23 | 24 | half_dim = embedding_dim // 2 25 | emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) 26 | emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) 27 | emb = w.to(dtype)[:, None] * emb[None, :] 28 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 29 | if embedding_dim % 2 == 1: # zero pad 30 | emb = torch.nn.functional.pad(emb, (0, 1)) 31 | assert emb.shape == (w.shape[0], embedding_dim) 32 | return emb 33 | 34 | 35 | def append_dims(x, target_dims): 36 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 37 | dims_to_append = target_dims - x.ndim 38 | if dims_to_append < 0: 39 | raise ValueError( 40 | f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") 41 | return x[(...,) + (None,) * dims_to_append] 42 | 43 | 44 | # From LCMScheduler.get_scalings_for_boundary_condition_discrete 45 | def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): 46 | c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2) 47 | c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5 48 | return c_skip, c_out 49 | 50 | 51 | # Compare LCMScheduler.step, Step 4 52 | def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas): 53 | 
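    # Both branches recover the estimate of the clean sample x_0 from a noisy sample
    # x_t = alpha_t * x_0 + sigma_t * eps (alphas/sigmas are expected to hold
    # sqrt(alpha_bar) and sqrt(1 - alpha_bar), gathered per timestep below):
    #   epsilon-prediction:  x_0 = (x_t - sigma_t * eps_theta) / alpha_t
    #   v-prediction:        x_0 = alpha_t * x_t - sigma_t * v_theta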
if prediction_type == "epsilon": 54 | sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) 55 | alphas = extract_into_tensor(alphas, timesteps, sample.shape) 56 | pred_x_0 = (sample - sigmas * model_output) / alphas 57 | elif prediction_type == "v_prediction": 58 | sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) 59 | alphas = extract_into_tensor(alphas, timesteps, sample.shape) 60 | pred_x_0 = alphas * sample - sigmas * model_output 61 | else: 62 | raise ValueError( 63 | f"Prediction type {prediction_type} currently not supported.") 64 | 65 | return pred_x_0 66 | 67 | 68 | def scale_for_loss(timesteps, sample, prediction_type, alphas, sigmas): 69 | if prediction_type == "epsilon": 70 | sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) 71 | alphas = extract_into_tensor(alphas, timesteps, sample.shape) 72 | sample = sample * alphas / sigmas 73 | else: 74 | raise ValueError( 75 | f"Prediction type {prediction_type} currently not supported.") 76 | 77 | return sample 78 | 79 | 80 | def extract_into_tensor(a, t, x_shape): 81 | b, *_ = t.shape 82 | out = a.gather(-1, t) 83 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 84 | 85 | 86 | class DDIMSolver: 87 | def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50): 88 | # DDIM sampling parameters 89 | step_ratio = timesteps // ddim_timesteps 90 | self.ddim_timesteps = ( 91 | np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1 92 | # self.ddim_timesteps = (torch.linspace(100**2,1000**2,30)**0.5).round().numpy().astype(np.int64) - 1 93 | self.ddim_timesteps_prev = np.asarray( 94 | [0] + self.ddim_timesteps[:-1].tolist() 95 | ) 96 | self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps] 97 | self.ddim_alpha_cumprods_prev = np.asarray( 98 | [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist() 99 | ) 100 | self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps] 101 | self.ddim_alpha_cumprods_prev = np.asarray( 102 | [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist() 103 | ) 104 | # convert to torch tensors 105 | self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long() 106 | self.ddim_timesteps_prev = torch.from_numpy( 107 | self.ddim_timesteps_prev).long() 108 | self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods) 109 | self.ddim_alpha_cumprods_prev = torch.from_numpy( 110 | self.ddim_alpha_cumprods_prev) 111 | 112 | def to(self, device): 113 | self.ddim_timesteps = self.ddim_timesteps.to(device) 114 | self.ddim_timesteps_prev = self.ddim_timesteps_prev.to(device) 115 | self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device) 116 | self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to( 117 | device) 118 | return self 119 | 120 | def ddim_step(self, pred_x0, pred_noise, timestep_index): 121 | alpha_cumprod_prev = extract_into_tensor( 122 | self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape) 123 | dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise 124 | x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt 125 | return x_prev 126 | 127 | 128 | @torch.no_grad() 129 | def update_ema(target_params, source_params, rate=0.99): 130 | """ 131 | Update target parameters to be closer to those of source parameters using 132 | an exponential moving average. 133 | 134 | :param target_params: the target parameter sequence. 135 | :param source_params: the source parameter sequence. 136 | :param rate: the EMA rate (closer to 1 means slower). 
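    Concretely, each call performs targ = rate * targ + (1 - rate) * src in place;
    with rate=0.99 a target parameter moves 1% of the way toward its source per call.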
137 | """ 138 | for targ, src in zip(target_params, source_params): 139 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 140 | 141 | 142 | def convert_lcm_lora(unet, path, alpha=1.0): 143 | 144 | if path.endswith(("ckpt",)): 145 | state_dict = torch.load(path, map_location="cpu") 146 | else: 147 | state_dict = {} 148 | with safe_open(path, framework="pt", device="cpu") as f: 149 | for key in f.keys(): 150 | state_dict[key] = f.get_tensor(key) 151 | 152 | num_alpha = 0 153 | for key in state_dict.keys(): 154 | if "alpha" in key: 155 | num_alpha += 1 156 | 157 | lora_keys = [k for k in state_dict.keys( 158 | ) if k.endswith("lora_down.weight")] 159 | 160 | updated_state_dict = {} 161 | for key in lora_keys: 162 | lora_name = key.split(".")[0] 163 | 164 | if lora_name.startswith("lora_unet_"): 165 | diffusers_name = key.replace("lora_unet_", "").replace("_", ".") 166 | 167 | if "input.blocks" in diffusers_name: 168 | diffusers_name = diffusers_name.replace( 169 | "input.blocks", "down_blocks") 170 | else: 171 | diffusers_name = diffusers_name.replace( 172 | "down.blocks", "down_blocks") 173 | 174 | if "middle.block" in diffusers_name: 175 | diffusers_name = diffusers_name.replace( 176 | "middle.block", "mid_block") 177 | else: 178 | diffusers_name = diffusers_name.replace( 179 | "mid.block", "mid_block") 180 | if "output.blocks" in diffusers_name: 181 | diffusers_name = diffusers_name.replace( 182 | "output.blocks", "up_blocks") 183 | else: 184 | diffusers_name = diffusers_name.replace( 185 | "up.blocks", "up_blocks") 186 | 187 | diffusers_name = diffusers_name.replace( 188 | "transformer.blocks", "transformer_blocks") 189 | diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") 190 | diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") 191 | diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") 192 | diffusers_name = diffusers_name.replace( 193 | "to.out.0.lora", "to_out_lora") 194 | diffusers_name = diffusers_name.replace("proj.in", "proj_in") 195 | diffusers_name = diffusers_name.replace("proj.out", "proj_out") 196 | diffusers_name = diffusers_name.replace( 197 | "time.emb.proj", "time_emb_proj") 198 | diffusers_name = diffusers_name.replace( 199 | "conv.shortcut", "conv_shortcut") 200 | 201 | updated_state_dict[diffusers_name] = state_dict[key] 202 | up_diffusers_name = diffusers_name.replace(".down.", ".up.") 203 | up_key = key.replace("lora_down.weight", "lora_up.weight") 204 | updated_state_dict[up_diffusers_name] = state_dict[up_key] 205 | 206 | state_dict = updated_state_dict 207 | 208 | num_lora = 0 209 | for key in state_dict: 210 | if "up." 
in key: 211 | continue 212 | up_key = key.replace(".down.", ".up.") 213 | model_key = key.replace("processor.", "").replace("_lora", "").replace( 214 | "down.", "").replace("up.", "").replace(".lora", "") 215 | model_key = model_key.replace("to_out.", "to_out.0.") 216 | layer_infos = model_key.split(".")[:-1] 217 | 218 | curr_layer = unet 219 | while len(layer_infos) > 0: 220 | temp_name = layer_infos.pop(0) 221 | curr_layer = curr_layer.__getattr__(temp_name) 222 | 223 | weight_down = state_dict[key].to( 224 | curr_layer.weight.data.device, curr_layer.weight.data.dtype) 225 | weight_up = state_dict[up_key].to( 226 | curr_layer.weight.data.device, curr_layer.weight.data.dtype) 227 | 228 | if weight_up.ndim == 2: 229 | curr_layer.weight.data += 1/8 * alpha * \ 230 | torch.mm(weight_up, weight_down) 231 | else: 232 | assert weight_up.ndim == 4 233 | curr_layer.weight.data += 1/8 * alpha * torch.mm(weight_up.flatten( 234 | start_dim=1), weight_down.flatten(start_dim=1)).reshape(curr_layer.weight.data.shape) 235 | num_lora += 1 236 | 237 | return unet 238 | -------------------------------------------------------------------------------- /animatelcm_sd15/animatelcm/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | from typing import Union 5 | 6 | import torch 7 | import torchvision 8 | import torch.distributed as dist 9 | import cv2 10 | from safetensors import safe_open 11 | from tqdm import tqdm 12 | from PIL import Image 13 | 14 | from einops import rearrange 15 | from animatelcm.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint 16 | from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers 17 | 18 | 19 | def resize_and_crop(image_path, target_width, target_height): 20 | image = Image.open(image_path).convert("RGB") 21 | 22 | orig_width, orig_height = image.size 23 | 24 | target_ratio = target_width / target_height 25 | orig_ratio = orig_width / orig_height 26 | 27 | if target_ratio > orig_ratio: 28 | resize_width = target_width 29 | resize_height = round((target_width / orig_width) * orig_height) 30 | else: 31 | resize_height = target_height 32 | resize_width = round((target_height / orig_height) * orig_width) 33 | 34 | resized_image = image.resize((resize_width, resize_height), Image.LANCZOS) 35 | 36 | x0 = (resize_width - target_width) / 2 37 | y0 = (resize_height - target_height) / 2 38 | x1 = x0 + target_width 39 | y1 = y0 + target_height 40 | 41 | cropped_image = resized_image.crop((x0, y0, x1, y1)) 42 | 43 | return cropped_image 44 | 45 | def adjust_colors(src_img, target_img): 46 | src_hsv = cv2.cvtColor(src_img, cv2.COLOR_BGR2HSV) 47 | target_hsv = cv2.cvtColor(target_img, cv2.COLOR_BGR2HSV) 48 | 49 | for idx in range(3): 50 | src_mean = np.mean(src_hsv[:, :, idx]) 51 | target_mean = np.mean(target_hsv[:, :, idx]) 52 | 53 | diff = src_mean - target_mean 54 | 55 | if idx == 0: 56 | target_hsv[:, :, idx] = np.clip(target_hsv[:, :, idx] + 0.03 * diff, 0, 180) 57 | else: 58 | target_hsv[:, :, idx] = np.clip(target_hsv[:, :, idx] + 1.0 * diff, 0, 255) 59 | 60 | adjusted_img = cv2.cvtColor(target_hsv, cv2.COLOR_HSV2BGR) 61 | return adjusted_img 62 | 63 | def zero_rank_print(s): 64 | if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) 65 | 66 | 67 | def save_videos_grid(videos: torch.Tensor, path: str, 
rescale=False, n_rows=6, fps=8): 68 | videos = rearrange(videos, "b c t h w -> t b c h w") 69 | outputs = [] 70 | for x in videos: 71 | x = torchvision.utils.make_grid(x, nrow=n_rows) 72 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 73 | if rescale: 74 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 75 | x = (x * 255).numpy().astype(np.uint8) 76 | outputs.append(x) 77 | 78 | os.makedirs(os.path.dirname(path), exist_ok=True) 79 | imageio.mimsave(path, outputs, fps=fps) 80 | 81 | 82 | # DDIM Inversion 83 | @torch.no_grad() 84 | def init_prompt(prompt, pipeline): 85 | uncond_input = pipeline.tokenizer( 86 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, 87 | return_tensors="pt" 88 | ) 89 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] 90 | text_input = pipeline.tokenizer( 91 | [prompt], 92 | padding="max_length", 93 | max_length=pipeline.tokenizer.model_max_length, 94 | truncation=True, 95 | return_tensors="pt", 96 | ) 97 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] 98 | context = torch.cat([uncond_embeddings, text_embeddings]) 99 | 100 | return context 101 | 102 | 103 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, 104 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): 105 | timestep, next_timestep = min( 106 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep 107 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod 108 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] 109 | beta_prod_t = 1 - alpha_prod_t 110 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 111 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 112 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 113 | return next_sample 114 | 115 | 116 | def get_noise_pred_single(latents, t, context, unet): 117 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] 118 | return noise_pred 119 | 120 | 121 | @torch.no_grad() 122 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): 123 | context = init_prompt(prompt, pipeline) 124 | uncond_embeddings, cond_embeddings = context.chunk(2) 125 | all_latent = [latent] 126 | latent = latent.clone().detach() 127 | for i in tqdm(range(num_inv_steps)): 128 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] 129 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) 130 | latent = next_step(noise_pred, t, latent, ddim_scheduler) 131 | all_latent.append(latent) 132 | return all_latent 133 | 134 | 135 | @torch.no_grad() 136 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): 137 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) 138 | return ddim_latents 139 | 140 | def load_weights( 141 | animation_pipeline, 142 | motion_module_path = "", 143 | motion_module_lora_configs = [], 144 | dreambooth_model_path = "", 145 | lora_model_path = "", 146 | lora_alpha = 0.8, 147 | ): 148 | unet_state_dict = {} 149 | if motion_module_path != "": 150 | print(f"load motion module from {motion_module_path}") 151 | motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") 152 | motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in 
motion_module_state_dict else motion_module_state_dict 153 | unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name}) 154 | 155 | missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False) 156 | assert len(unexpected) == 0 157 | del unet_state_dict 158 | 159 | if dreambooth_model_path != "": 160 | print(f"load dreambooth model from {dreambooth_model_path}") 161 | if dreambooth_model_path.endswith(".safetensors"): 162 | dreambooth_state_dict = {} 163 | with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f: 164 | for key in f.keys(): 165 | dreambooth_state_dict[key] = f.get_tensor(key) 166 | elif dreambooth_model_path.endswith(".ckpt"): 167 | dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu") 168 | 169 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config) 170 | animation_pipeline.vae.load_state_dict(converted_vae_checkpoint) 171 | 172 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config) 173 | animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) 174 | 175 | animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict) 176 | del dreambooth_state_dict 177 | 178 | if lora_model_path != "": 179 | print(f"load lora model from {lora_model_path}") 180 | assert lora_model_path.endswith(".safetensors") 181 | lora_state_dict = {} 182 | with safe_open(lora_model_path, framework="pt", device="cpu") as f: 183 | for key in f.keys(): 184 | lora_state_dict[key] = f.get_tensor(key) 185 | 186 | animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha) 187 | del lora_state_dict 188 | 189 | 190 | for motion_module_lora_config in motion_module_lora_configs: 191 | path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"] 192 | print(f"load motion LoRA from {path}") 193 | 194 | motion_lora_state_dict = torch.load(path, map_location="cpu") 195 | motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict 196 | 197 | animation_pipeline = convert_motion_lora_ckpt_to_diffusers(animation_pipeline, motion_lora_state_dict, alpha) 198 | 199 | return animation_pipeline 200 | -------------------------------------------------------------------------------- /animatelcm_sd15/app-i2v.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import torch 5 | import random 6 | 7 | import gradio as gr 8 | from glob import glob 9 | from omegaconf import OmegaConf 10 | from datetime import datetime 11 | from safetensors import safe_open 12 | 13 | from diffusers import AutoencoderKL 14 | from diffusers.utils.import_utils import is_xformers_available 15 | from transformers import CLIPTextModel, CLIPTokenizer 16 | 17 | from animatelcm.scheduler.lcm_scheduler import LCMScheduler 18 | from animatelcm.models.unet import UNet3DConditionModel 19 | from animatelcm.pipelines.pipeline_animation import AnimationPipeline 20 | from animatelcm.utils.util import save_videos_grid 21 | from animatelcm.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint 22 | from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora 23 | from animatelcm.utils.lcm_utils import convert_lcm_lora 
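# The checkpoint helpers above (convert_lcm_lora here, plus load_weights from
# animatelcm.utils.util) are combined elsewhere in this repo, e.g. in
# batch_inference.py, roughly as sketched below. The paths are examples based on
# the checkpoints referenced by these demos; the sketch is illustrative only:
#
#     from animatelcm.utils.util import load_weights
#
#     pipeline = load_weights(
#         pipeline,
#         motion_module_path="models/Motion_Module/AnimateLCM_sd15_i2v.ckpt",
#         dreambooth_model_path="models/Personalized/realistic1.safetensors",
#     )
#     pipeline.unet = convert_lcm_lora(
#         pipeline.unet,
#         "models/LCM_LoRA/AnimateLCM_sd15_i2v_lora.safetensors",
#         alpha=0.8,
#     )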
24 | import copy 25 | 26 | sample_idx = 0 27 | scheduler_dict = { 28 | "LCM": LCMScheduler, 29 | } 30 | 31 | css = """ 32 | .toolbutton { 33 | margin-buttom: 0em 0em 0em 0em; 34 | max-width: 2.5em; 35 | min-width: 2.5em !important; 36 | height: 2.5em; 37 | } 38 | """ 39 | 40 | 41 | class AnimateController: 42 | def __init__(self): 43 | 44 | # config dirs 45 | self.basedir = os.getcwd() 46 | self.stable_diffusion_dir = os.path.join( 47 | self.basedir, "models", "StableDiffusion") 48 | self.motion_module_dir = os.path.join( 49 | self.basedir, "models", "Motion_Module") 50 | self.personalized_model_dir = os.path.join( 51 | self.basedir, "models", "DreamBooth_LoRA") 52 | self.savedir = os.path.join( 53 | self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")) 54 | self.savedir_sample = os.path.join(self.savedir, "sample") 55 | self.lcm_lora_path = "models/LCM_LoRA/AnimateLCM_sd15_i2v_lora.safetensors" 56 | os.makedirs(self.savedir, exist_ok=True) 57 | 58 | self.stable_diffusion_list = [] 59 | self.motion_module_list = [] 60 | self.personalized_model_list = [] 61 | 62 | self.refresh_stable_diffusion() 63 | self.refresh_motion_module() 64 | self.refresh_personalized_model() 65 | 66 | # config models 67 | self.tokenizer = None 68 | self.text_encoder = None 69 | self.vae = None 70 | self.unet = None 71 | self.pipeline = None 72 | self.lora_model_state_dict = {} 73 | 74 | self.inference_config = OmegaConf.load("configs/inference-i2v.yaml") 75 | 76 | def refresh_stable_diffusion(self): 77 | self.stable_diffusion_list = glob( 78 | os.path.join(self.stable_diffusion_dir, "*/")) 79 | 80 | def refresh_motion_module(self): 81 | motion_module_list = glob(os.path.join( 82 | self.motion_module_dir, "*.ckpt")) 83 | self.motion_module_list = [ 84 | os.path.basename(p) for p in motion_module_list] 85 | 86 | def refresh_personalized_model(self): 87 | personalized_model_list = glob(os.path.join( 88 | self.personalized_model_dir, "*.safetensors")) 89 | self.personalized_model_list = [ 90 | os.path.basename(p) for p in personalized_model_list] 91 | 92 | def update_stable_diffusion(self, stable_diffusion_dropdown): 93 | stable_diffusion_dropdown = os.path.join(self.stable_diffusion_dir,stable_diffusion_dropdown) 94 | self.tokenizer = CLIPTokenizer.from_pretrained( 95 | stable_diffusion_dropdown, subfolder="tokenizer") 96 | self.text_encoder = CLIPTextModel.from_pretrained( 97 | stable_diffusion_dropdown, subfolder="text_encoder").cuda() 98 | self.vae = AutoencoderKL.from_pretrained( 99 | stable_diffusion_dropdown, subfolder="vae").cuda() 100 | self.unet = UNet3DConditionModel.from_pretrained_2d( 101 | stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda() 102 | return gr.Dropdown.update() 103 | 104 | def update_motion_module(self, motion_module_dropdown): 105 | if self.unet is None: 106 | gr.Info(f"Please select a pretrained model path.") 107 | return gr.Dropdown.update(value=None) 108 | else: 109 | motion_module_dropdown = os.path.join( 110 | self.motion_module_dir, motion_module_dropdown) 111 | motion_module_state_dict = torch.load( 112 | motion_module_dropdown, map_location="cpu") 113 | missing, unexpected = self.unet.load_state_dict( 114 | motion_module_state_dict, strict=False) 115 | assert len(unexpected) == 0 116 | return gr.Dropdown.update() 117 | 118 | def update_base_model(self, base_model_dropdown): 119 | if self.unet is None: 120 | gr.Info(f"Please select a pretrained model path.") 121 | 
return gr.Dropdown.update(value=None) 122 | else: 123 | base_model_dropdown = os.path.join( 124 | self.personalized_model_dir, base_model_dropdown) 125 | base_model_state_dict = {} 126 | with safe_open(base_model_dropdown, framework="pt", device="cpu") as f: 127 | for key in f.keys(): 128 | base_model_state_dict[key] = f.get_tensor(key) 129 | 130 | converted_vae_checkpoint = convert_ldm_vae_checkpoint( 131 | base_model_state_dict, self.vae.config) 132 | self.vae.load_state_dict(converted_vae_checkpoint) 133 | 134 | converted_unet_checkpoint = convert_ldm_unet_checkpoint( 135 | base_model_state_dict, self.unet.config) 136 | self.unet.load_state_dict(converted_unet_checkpoint, strict=False) 137 | 138 | self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict) 139 | return gr.Dropdown.update() 140 | 141 | def update_lora_model(self, lora_model_dropdown): 142 | lora_model_dropdown = os.path.join( 143 | self.personalized_model_dir, lora_model_dropdown) 144 | self.lora_model_state_dict = {} 145 | if lora_model_dropdown == "none": 146 | pass 147 | else: 148 | with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f: 149 | for key in f.keys(): 150 | self.lora_model_state_dict[key] = f.get_tensor(key) 151 | return gr.Dropdown.update() 152 | 153 | def animate( 154 | self, 155 | lora_alpha_slider, 156 | spatial_lora_slider, 157 | prompt_textbox, 158 | negative_prompt_textbox, 159 | sampler_dropdown, 160 | sample_step_slider, 161 | width_slider, 162 | length_slider, 163 | height_slider, 164 | cfg_scale_slider, 165 | seed_textbox, 166 | image_upload, 167 | beta_end_slider, 168 | motion_scale_slider, 169 | ): 170 | print(image_upload) 171 | 172 | if is_xformers_available(): 173 | self.unet.enable_xformers_memory_efficient_attention() 174 | 175 | print(self.inference_config.noise_scheduler_kwargs["beta_end"]) 176 | self.inference_config.noise_scheduler_kwargs["beta_end"] = beta_end_slider 177 | 178 | self.unet.img_encoder.motion_scale = motion_scale_slider 179 | pipeline = AnimationPipeline( 180 | vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet, 181 | scheduler=scheduler_dict[sampler_dropdown]( 182 | **OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)) 183 | ).to("cuda") 184 | 185 | if self.lora_model_state_dict != {}: 186 | pipeline = convert_lora( 187 | pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider) 188 | 189 | pipeline.unet = convert_lcm_lora(copy.deepcopy( 190 | self.unet), self.lcm_lora_path, spatial_lora_slider) 191 | 192 | pipeline.to("cuda") 193 | 194 | if seed_textbox != -1 and seed_textbox != "": 195 | torch.manual_seed(int(seed_textbox)) 196 | else: 197 | torch.seed() 198 | seed = torch.initial_seed() 199 | 200 | with torch.autocast("cuda",torch.float16): 201 | sample = pipeline( 202 | prompt_textbox, 203 | negative_prompt=negative_prompt_textbox, 204 | num_inference_steps=sample_step_slider, 205 | guidance_scale=cfg_scale_slider, 206 | width=width_slider, 207 | height=height_slider, 208 | video_length=length_slider, 209 | image_path=image_upload 210 | ).videos 211 | 212 | save_sample_path = os.path.join( 213 | self.savedir_sample, f"{sample_idx}.mp4") 214 | save_videos_grid(sample, save_sample_path) 215 | 216 | sample_config = { 217 | "prompt": prompt_textbox, 218 | "n_prompt": negative_prompt_textbox, 219 | "sampler": sampler_dropdown, 220 | "num_inference_steps": sample_step_slider, 221 | "guidance_scale": cfg_scale_slider, 222 | "width": width_slider, 223 | "height": height_slider, 224 | 
"video_length": length_slider, 225 | "seed": seed, 226 | } 227 | json_str = json.dumps(sample_config, indent=4) 228 | with open(os.path.join(self.savedir, "logs.json"), "a") as f: 229 | f.write(json_str) 230 | f.write("\n\n") 231 | return gr.Video.update(value=save_sample_path) 232 | 233 | 234 | controller = AnimateController() 235 | 236 | controller.update_stable_diffusion("stable-diffusion-v1-5") 237 | controller.update_motion_module("AnimateLCM_sd15_i2v.ckpt") 238 | controller.update_base_model("realistic1.safetensors") 239 | 240 | 241 | def ui(): 242 | with gr.Blocks(css=css) as demo: 243 | gr.Markdown( 244 | """ 245 | # [AnimateLCM: Accelerating the Animation of Personalized Diffusion Models and Adapters with Decoupled Consistency Learning](https://arxiv.org/abs/2402.00769) 246 | Fu-Yun Wang, Zhaoyang Huang (*Corresponding Author), Xiaoyu Shi, Weikang Bian, Guanglu Song, Yu Liu, Hongsheng Li (*Corresponding Author)
247 | [arXiv Report](https://arxiv.org/abs/2402.00769) | [Project Page](https://animatelcm.github.io/) | [Github](https://github.com/G-U-N/AnimateLCM) | [Civitai](https://civitai.com/models/290375/animatelcm-fast-video-generation) | [Replicate](https://replicate.com/camenduru/animate-lcm) 248 | """ 249 | 250 | ''' 251 | Important Notes: 252 | 1. The generation speed is around few seconds. 253 | 2. Increase the sampling step for better sample quality. 254 | ''' 255 | ) 256 | with gr.Column(variant="panel"): 257 | with gr.Row(): 258 | image_upload = gr.Image(label="Upload Image", tool="select", type="filepath") 259 | 260 | base_model_dropdown = gr.Dropdown( 261 | label="Select base Dreambooth model (required)", 262 | choices=controller.personalized_model_list, 263 | interactive=True, 264 | value="realistic1.safetensors" 265 | ) 266 | base_model_dropdown.change(fn=controller.update_base_model, inputs=[ 267 | base_model_dropdown], outputs=[base_model_dropdown]) 268 | 269 | lora_model_dropdown = gr.Dropdown( 270 | label="Select LoRA model (optional)", 271 | choices=["none"], 272 | value="none", 273 | interactive=True, 274 | ) 275 | lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[ 276 | lora_model_dropdown], outputs=[lora_model_dropdown]) 277 | 278 | lora_alpha_slider = gr.Slider( 279 | label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True) 280 | spatial_lora_slider = gr.Slider( 281 | label="LCM LoRA alpha", value=0.8, minimum=0.0, maximum=1.0, interactive=True) 282 | 283 | personalized_refresh_button = gr.Button( 284 | value="\U0001F503", elem_classes="toolbutton") 285 | 286 | def update_personalized_model(): 287 | controller.refresh_personalized_model() 288 | return [ 289 | gr.Dropdown.update( 290 | choices=controller.personalized_model_list), 291 | gr.Dropdown.update( 292 | choices=["none"] + controller.personalized_model_list) 293 | ] 294 | personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[ 295 | base_model_dropdown, lora_model_dropdown]) 296 | 297 | with gr.Column(variant="panel"): 298 | gr.Markdown( 299 | """ 300 | ### 2. Configs for AnimateLCM. 
301 | """ 302 | ) 303 | 304 | prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="best quality") 305 | negative_prompt_textbox = gr.Textbox( 306 | label="Negative prompt", lines=2, value="bad quality") 307 | 308 | with gr.Row().style(equal_height=False): 309 | with gr.Column(): 310 | with gr.Row(): 311 | sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list( 312 | scheduler_dict.keys()), value=list(scheduler_dict.keys())[0]) 313 | sample_step_slider = gr.Slider( 314 | label="Sampling steps", value=4, minimum=1, maximum=25, step=1) 315 | 316 | motion_scale_slider = gr.Slider( 317 | label="motion scale (better identity with smaller scale)", value=0.8, minimum=0.0, maximum=1.5, step=0.05 318 | ) 319 | beta_end_slider = gr.Slider( 320 | label="beta end (a tricky way for selecting noisy steps)", value=0.014, minimum=0.012, maximum=0.016, step=0.001) 321 | width_slider = gr.Slider( 322 | label="Width", value=512, minimum=256, maximum=1024, step=64) 323 | height_slider = gr.Slider( 324 | label="Height", value=512, minimum=256, maximum=1024, step=64) 325 | length_slider = gr.Slider( 326 | label="Animation length", value=16, minimum=12, maximum=20, step=1) 327 | cfg_scale_slider = gr.Slider( 328 | label="CFG Scale", value=1, minimum=1, maximum=2) 329 | 330 | with gr.Row(): 331 | seed_textbox = gr.Textbox(label="Seed", value=-1) 332 | seed_button = gr.Button( 333 | value="\U0001F3B2", elem_classes="toolbutton") 334 | seed_button.click(fn=lambda: gr.Textbox.update( 335 | value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox]) 336 | 337 | generate_button = gr.Button( 338 | value="Generate", variant='primary') 339 | 340 | result_video = gr.Video( 341 | label="Generated Animation", interactive=False) 342 | 343 | 344 | generate_button.click( 345 | fn=controller.animate, 346 | inputs=[ 347 | lora_alpha_slider, 348 | spatial_lora_slider, 349 | prompt_textbox, 350 | negative_prompt_textbox, 351 | sampler_dropdown, 352 | sample_step_slider, 353 | width_slider, 354 | length_slider, 355 | height_slider, 356 | cfg_scale_slider, 357 | seed_textbox, 358 | image_upload, 359 | beta_end_slider, 360 | motion_scale_slider, 361 | ], 362 | outputs=[result_video] 363 | ) 364 | examples = [ 365 | [0.8, 0.8, "good quality, cloud", "bad quality", "LCM", 4, 768, 16, 512, 1, 1234, "test_imgs/cloud.jpeg", 0.014, 0.8], 366 | [0.8, 0.8, "good quality, dog", "bad quality", "LCM", 4, 768, 16, 512, 1, 1234, "test_imgs/dog.jpg", 0.014, 0.8], 367 | [0.8, 0.8, "good quality, fire", "bad quality", "LCM", 4, 768, 16, 512, 1, 1234, "test_imgs/fire.jpg", 0.014, 0.8], 368 | [0.8, 0.8, "good quality, fox, snow", "bad quality", "LCM", 4, 768, 16, 512, 1, 1234, "test_imgs/fox.jpg", 0.014, 0.8], 369 | [0.8, 0.8, "good quality, girl, wind, flower", "bad quality", "LCM", 4, 768, 16, 512, 1, 1235, "test_imgs/girl_flower.jpg", 0.014, 0.8], 370 | [0.8, 0.8, "good quality, lighter, fire", "bad quality", "LCM", 4, 768, 16, 512, 1, 1235, "test_imgs/lighter.jpg", 0.014, 0.8], 371 | [0.8, 0.8, "good quality, snowman, fire", "bad quality", "LCM", 4, 768, 16, 512, 1, 1234, "test_imgs/snow_man_fire.jpg", 0.014, 0.8], 372 | ] 373 | gr.Examples( 374 | examples = examples, 375 | inputs=[ 376 | lora_alpha_slider, 377 | spatial_lora_slider, 378 | prompt_textbox, 379 | negative_prompt_textbox, 380 | sampler_dropdown, 381 | sample_step_slider, 382 | width_slider, 383 | length_slider, 384 | height_slider, 385 | cfg_scale_slider, 386 | seed_textbox, 387 | image_upload, 388 | beta_end_slider, 389 | motion_scale_slider, 390 | ], 391 | 
outputs=[result_video], 392 | fn=controller.animate, 393 | cache_examples=True, 394 | ) 395 | return demo 396 | 397 | 398 | if __name__ == "__main__": 399 | demo = ui() 400 | demo.queue(concurrency_count=3, max_size=20) 401 | demo.launch(share=True, server_name="127.0.0.1") 402 | -------------------------------------------------------------------------------- /animatelcm_sd15/app.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import torch 5 | import random 6 | 7 | import gradio as gr 8 | from glob import glob 9 | from omegaconf import OmegaConf 10 | from datetime import datetime 11 | from safetensors import safe_open 12 | 13 | from diffusers import AutoencoderKL 14 | from diffusers.utils.import_utils import is_xformers_available 15 | from transformers import CLIPTextModel, CLIPTokenizer 16 | 17 | from animatelcm.scheduler.lcm_scheduler import LCMScheduler 18 | from animatelcm.models.unet import UNet3DConditionModel 19 | from animatelcm.pipelines.pipeline_animation import AnimationPipeline 20 | from animatelcm.utils.util import save_videos_grid 21 | from animatelcm.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint 22 | from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora 23 | from animatelcm.utils.lcm_utils import convert_lcm_lora 24 | import copy 25 | 26 | sample_idx = 0 27 | scheduler_dict = { 28 | "LCM": LCMScheduler, 29 | } 30 | 31 | css = """ 32 | .toolbutton { 33 | margin-buttom: 0em 0em 0em 0em; 34 | max-width: 2.5em; 35 | min-width: 2.5em !important; 36 | height: 2.5em; 37 | } 38 | """ 39 | 40 | 41 | class AnimateController: 42 | def __init__(self): 43 | 44 | # config dirs 45 | self.basedir = os.getcwd() 46 | self.stable_diffusion_dir = os.path.join( 47 | self.basedir, "models", "StableDiffusion") 48 | self.motion_module_dir = os.path.join( 49 | self.basedir, "models", "Motion_Module") 50 | self.personalized_model_dir = os.path.join( 51 | self.basedir, "models", "Personalized") 52 | self.savedir = os.path.join( 53 | self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")) 54 | self.savedir_sample = os.path.join(self.savedir, "sample") 55 | self.lcm_lora_path = "models/LCM_LoRA/AnimateLCM_sd15_t2v_lora.safetensors" 56 | os.makedirs(self.savedir, exist_ok=True) 57 | 58 | self.stable_diffusion_list = [] 59 | self.motion_module_list = [] 60 | self.personalized_model_list = [] 61 | 62 | self.refresh_stable_diffusion() 63 | self.refresh_motion_module() 64 | self.refresh_personalized_model() 65 | 66 | # config models 67 | self.tokenizer = None 68 | self.text_encoder = None 69 | self.vae = None 70 | self.unet = None 71 | self.pipeline = None 72 | self.lora_model_state_dict = {} 73 | 74 | self.inference_config = OmegaConf.load("configs/inference-t2v.yaml") 75 | 76 | def refresh_stable_diffusion(self): 77 | self.stable_diffusion_list = glob( 78 | os.path.join(self.stable_diffusion_dir, "*/")) 79 | 80 | def refresh_motion_module(self): 81 | motion_module_list = glob(os.path.join( 82 | self.motion_module_dir, "*.ckpt")) 83 | self.motion_module_list = [ 84 | os.path.basename(p) for p in motion_module_list] 85 | 86 | def refresh_personalized_model(self): 87 | personalized_model_list = glob(os.path.join( 88 | self.personalized_model_dir, "*.safetensors")) 89 | self.personalized_model_list = [ 90 | os.path.basename(p) for p in personalized_model_list] 91 | 92 | def update_stable_diffusion(self, 
stable_diffusion_dropdown): 93 | stable_diffusion_dropdown = os.path.join(self.stable_diffusion_dir,stable_diffusion_dropdown) 94 | self.tokenizer = CLIPTokenizer.from_pretrained( 95 | stable_diffusion_dropdown, subfolder="tokenizer") 96 | self.text_encoder = CLIPTextModel.from_pretrained( 97 | stable_diffusion_dropdown, subfolder="text_encoder").cuda() 98 | self.vae = AutoencoderKL.from_pretrained( 99 | stable_diffusion_dropdown, subfolder="vae").cuda() 100 | self.unet = UNet3DConditionModel.from_pretrained_2d( 101 | stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda() 102 | return gr.Dropdown.update() 103 | 104 | def update_motion_module(self, motion_module_dropdown): 105 | if self.unet is None: 106 | gr.Info(f"Please select a pretrained model path.") 107 | return gr.Dropdown.update(value=None) 108 | else: 109 | motion_module_dropdown = os.path.join( 110 | self.motion_module_dir, motion_module_dropdown) 111 | motion_module_state_dict = torch.load( 112 | motion_module_dropdown, map_location="cpu") 113 | missing, unexpected = self.unet.load_state_dict( 114 | motion_module_state_dict, strict=False) 115 | assert len(unexpected) == 0 116 | return gr.Dropdown.update() 117 | 118 | def update_base_model(self, base_model_dropdown): 119 | if self.unet is None: 120 | gr.Info(f"Please select a pretrained model path.") 121 | return gr.Dropdown.update(value=None) 122 | else: 123 | base_model_dropdown = os.path.join( 124 | self.personalized_model_dir, base_model_dropdown) 125 | base_model_state_dict = {} 126 | with safe_open(base_model_dropdown, framework="pt", device="cpu") as f: 127 | for key in f.keys(): 128 | base_model_state_dict[key] = f.get_tensor(key) 129 | 130 | converted_vae_checkpoint = convert_ldm_vae_checkpoint( 131 | base_model_state_dict, self.vae.config) 132 | self.vae.load_state_dict(converted_vae_checkpoint) 133 | 134 | converted_unet_checkpoint = convert_ldm_unet_checkpoint( 135 | base_model_state_dict, self.unet.config) 136 | self.unet.load_state_dict(converted_unet_checkpoint, strict=False) 137 | 138 | # self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict) 139 | return gr.Dropdown.update() 140 | 141 | def update_lora_model(self, lora_model_dropdown): 142 | lora_model_dropdown = os.path.join( 143 | self.personalized_model_dir, lora_model_dropdown) 144 | self.lora_model_state_dict = {} 145 | if lora_model_dropdown == "none": 146 | pass 147 | else: 148 | with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f: 149 | for key in f.keys(): 150 | self.lora_model_state_dict[key] = f.get_tensor(key) 151 | return gr.Dropdown.update() 152 | 153 | def animate( 154 | self, 155 | lora_alpha_slider, 156 | spatial_lora_slider, 157 | prompt_textbox, 158 | negative_prompt_textbox, 159 | sampler_dropdown, 160 | sample_step_slider, 161 | width_slider, 162 | length_slider, 163 | height_slider, 164 | cfg_scale_slider, 165 | seed_textbox 166 | ): 167 | 168 | if is_xformers_available(): 169 | self.unet.enable_xformers_memory_efficient_attention() 170 | 171 | pipeline = AnimationPipeline( 172 | vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet, 173 | scheduler=scheduler_dict[sampler_dropdown]( 174 | **OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)) 175 | ).to("cuda") 176 | 177 | if self.lora_model_state_dict != {}: 178 | pipeline = convert_lora( 179 | pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider) 180 | 181 | 
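        # The LCM LoRA is merged into a deep copy of the cached UNet on every call,
        # so changing the LCM-LoRA slider does not keep accumulating low-rank updates
        # into self.unet across generations; convert_lcm_lora adds
        # alpha / 8 * (lora_up @ lora_down) to each targeted weight matrix.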
pipeline.unet = convert_lcm_lora(copy.deepcopy( 182 | self.unet), self.lcm_lora_path, spatial_lora_slider) 183 | 184 | pipeline.to("cuda") 185 | 186 | if seed_textbox != -1 and seed_textbox != "": 187 | torch.manual_seed(int(seed_textbox)) 188 | else: 189 | torch.seed() 190 | seed = torch.initial_seed() 191 | 192 | sample = pipeline( 193 | prompt_textbox, 194 | negative_prompt=negative_prompt_textbox, 195 | num_inference_steps=sample_step_slider, 196 | guidance_scale=cfg_scale_slider, 197 | width=width_slider, 198 | height=height_slider, 199 | video_length=length_slider, 200 | ).videos 201 | 202 | save_sample_path = os.path.join( 203 | self.savedir_sample, f"{sample_idx}.mp4") 204 | save_videos_grid(sample, save_sample_path) 205 | 206 | sample_config = { 207 | "prompt": prompt_textbox, 208 | "n_prompt": negative_prompt_textbox, 209 | "sampler": sampler_dropdown, 210 | "num_inference_steps": sample_step_slider, 211 | "guidance_scale": cfg_scale_slider, 212 | "width": width_slider, 213 | "height": height_slider, 214 | "video_length": length_slider, 215 | "seed": seed 216 | } 217 | json_str = json.dumps(sample_config, indent=4) 218 | with open(os.path.join(self.savedir, "logs.json"), "a") as f: 219 | f.write(json_str) 220 | f.write("\n\n") 221 | return gr.Video.update(value=save_sample_path) 222 | 223 | 224 | controller = AnimateController() 225 | 226 | controller.update_stable_diffusion("stable-diffusion-v1-5") 227 | controller.update_motion_module("AnimateLCM_sd15_t2v.ckpt") 228 | controller.update_base_model("realistic2.safetensors") 229 | 230 | 231 | def ui(): 232 | with gr.Blocks(css=css) as demo: 233 | gr.Markdown( 234 | """ 235 | # [AnimateLCM: Accelerating the Animation of Personalized Diffusion Models and Adapters with Decoupled Consistency Learning](https://arxiv.org/abs/2402.00769) 236 | Fu-Yun Wang, Zhaoyang Huang (*Corresponding Author), Xiaoyu Shi, Weikang Bian, Guanglu Song, Yu Liu, Hongsheng Li (*Corresponding Author)
237 | [arXiv Report](https://arxiv.org/abs/2402.00769) | [Project Page](https://animatelcm.github.io/) | [Github](https://github.com/G-U-N/AnimateLCM) | [Civitai](https://civitai.com/models/290375/animatelcm-fast-video-generation) | [Replicate](https://replicate.com/camenduru/animate-lcm) 238 | """ 239 | 240 | ''' 241 | Important Notes: 242 | 1. The generation speed is around 1~2 seconds. There is delay in the space. 243 | 2. Increase the sampling step and cfg if you want more fancy videos. 244 | ''' 245 | ) 246 | with gr.Column(variant="panel"): 247 | with gr.Row(): 248 | 249 | base_model_dropdown = gr.Dropdown( 250 | label="Select base Dreambooth model (required)", 251 | choices=controller.personalized_model_list, 252 | interactive=True, 253 | value="realistic2.safetensors" 254 | ) 255 | 256 | motion_module_dropdown = gr.Dropdown( 257 | label="Select motion modules", 258 | choices=controller.motion_module_list, 259 | interactive=True, 260 | value="sd15_t2v_beta_motion.ckpt" 261 | ) 262 | base_model_dropdown.change(fn=controller.update_base_model, inputs=[ 263 | base_model_dropdown], outputs=[base_model_dropdown]) 264 | 265 | motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown],outputs=[motion_module_dropdown]) 266 | 267 | lora_model_dropdown = gr.Dropdown( 268 | label="Select LoRA model (optional)", 269 | choices=["none"], 270 | value="none", 271 | interactive=True, 272 | ) 273 | lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[ 274 | lora_model_dropdown], outputs=[lora_model_dropdown]) 275 | 276 | lora_alpha_slider = gr.Slider( 277 | label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True) 278 | spatial_lora_slider = gr.Slider( 279 | label="LCM LoRA alpha", value=0.8, minimum=0.0, maximum=1.0, interactive=True) 280 | 281 | personalized_refresh_button = gr.Button( 282 | value="\U0001F503", elem_classes="toolbutton") 283 | 284 | def update_personalized_model(): 285 | controller.refresh_personalized_model() 286 | return [ 287 | gr.Dropdown.update( 288 | choices=controller.personalized_model_list), 289 | gr.Dropdown.update( 290 | choices=["none"] + controller.personalized_model_list) 291 | ] 292 | personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[ 293 | base_model_dropdown, lora_model_dropdown]) 294 | 295 | with gr.Column(variant="panel"): 296 | gr.Markdown( 297 | """ 298 | ### 2. Configs for AnimateLCM. 
299 | """ 300 | ) 301 | 302 | prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="a boy holding a rabbit") 303 | negative_prompt_textbox = gr.Textbox( 304 | label="Negative prompt", lines=2, value="bad quality") 305 | 306 | with gr.Row().style(equal_height=False): 307 | with gr.Column(): 308 | with gr.Row(): 309 | sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list( 310 | scheduler_dict.keys()), value=list(scheduler_dict.keys())[0]) 311 | sample_step_slider = gr.Slider( 312 | label="Sampling steps", value=6, minimum=1, maximum=25, step=1) 313 | 314 | width_slider = gr.Slider( 315 | label="Width", value=512, minimum=256, maximum=1024, step=64) 316 | height_slider = gr.Slider( 317 | label="Height", value=512, minimum=256, maximum=1024, step=64) 318 | length_slider = gr.Slider( 319 | label="Animation length", value=16, minimum=12, maximum=20, step=1) 320 | cfg_scale_slider = gr.Slider( 321 | label="CFG Scale", value=1.5, minimum=1, maximum=2) 322 | 323 | with gr.Row(): 324 | seed_textbox = gr.Textbox(label="Seed", value=-1) 325 | seed_button = gr.Button( 326 | value="\U0001F3B2", elem_classes="toolbutton") 327 | seed_button.click(fn=lambda: gr.Textbox.update( 328 | value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox]) 329 | 330 | generate_button = gr.Button( 331 | value="Generate", variant='primary') 332 | 333 | result_video = gr.Video( 334 | label="Generated Animation", interactive=False) 335 | 336 | 337 | generate_button.click( 338 | fn=controller.animate, 339 | inputs=[ 340 | lora_alpha_slider, 341 | spatial_lora_slider, 342 | prompt_textbox, 343 | negative_prompt_textbox, 344 | sampler_dropdown, 345 | sample_step_slider, 346 | width_slider, 347 | length_slider, 348 | height_slider, 349 | cfg_scale_slider, 350 | seed_textbox, 351 | ], 352 | outputs=[result_video] 353 | ) 354 | examples = [ 355 | [0.8, 0.8, "a boy is holding a rabbit", "bad quality", "LCM", 8, 512, 16, 512, 1.5, 1234], 356 | [0.8, 0.8, "1girl smiling", "bad quality", "LCM", 4, 512, 16, 512, 1.5, 1233], 357 | [0.8, 0.8, "1girl,face,white background,", "bad quality", "LCM", 6, 512, 16, 512, 1.5, 1234], 358 | [0.8, 0.8, "clouds in the sky, best quality", "bad quality", "LCM", 4, 512, 16, 512, 1.5, 1234], 359 | 360 | 361 | ] 362 | gr.Examples( 363 | examples = examples, 364 | inputs=[ 365 | lora_alpha_slider, 366 | spatial_lora_slider, 367 | prompt_textbox, 368 | negative_prompt_textbox, 369 | sampler_dropdown, 370 | sample_step_slider, 371 | width_slider, 372 | length_slider, 373 | height_slider, 374 | cfg_scale_slider, 375 | seed_textbox, 376 | ], 377 | outputs=[result_video], 378 | fn=controller.animate, 379 | cache_examples=True, 380 | ) 381 | 382 | return demo 383 | 384 | 385 | if __name__ == "__main__": 386 | demo = ui() 387 | # gr.close_all() 388 | demo.queue(concurrency_count=3, max_size=20) 389 | demo.launch(share=True, server_name="127.0.0.1") 390 | -------------------------------------------------------------------------------- /animatelcm_sd15/batch_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import inspect 4 | import os 5 | from omegaconf import OmegaConf 6 | 7 | import torch 8 | 9 | from diffusers import AutoencoderKL 10 | 11 | from tqdm.auto import tqdm 12 | from transformers import CLIPTextModel, CLIPTokenizer 13 | 14 | from animatelcm.models.unet import UNet3DConditionModel 15 | from animatelcm.pipelines.pipeline_animation import AnimationPipeline 16 | from animatelcm.utils.util 
import save_videos_grid 17 | from animatelcm.utils.util import load_weights 18 | from animatelcm.scheduler.lcm_scheduler import LCMScheduler 19 | from animatelcm.utils.lcm_utils import convert_lcm_lora 20 | from diffusers.utils.import_utils import is_xformers_available 21 | from pathlib import Path 22 | 23 | 24 | def main(args): 25 | *_, func_args = inspect.getargvalues(inspect.currentframe()) 26 | func_args = dict(func_args) 27 | 28 | time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 29 | savedir = f"samples/{Path(args.config).stem}-{time_str}" 30 | os.makedirs(savedir) 31 | 32 | config = OmegaConf.load(args.config) 33 | samples = [] 34 | 35 | sample_idx = 0 36 | for model_idx, (config_key, model_config) in enumerate(list(config.items())): 37 | 38 | motion_modules = model_config.motion_module 39 | lcm_lora = model_config.lcm_lora_path 40 | motion_modules = ( 41 | [motion_modules] 42 | if isinstance(motion_modules, str) 43 | else list(motion_modules) 44 | ) 45 | lcm_lora = [lcm_lora] if isinstance(lcm_lora, str) else list(lcm_lora) 46 | lcm_lora = lcm_lora * len(motion_modules) if len(lcm_lora) == 1 else lcm_lora 47 | for motion_module, lcm_lora in zip(motion_modules, lcm_lora): 48 | inference_config = OmegaConf.load( 49 | model_config.get("inference_config", args.inference_config) 50 | ) 51 | 52 | tokenizer = CLIPTokenizer.from_pretrained( 53 | args.pretrained_model_path, subfolder="tokenizer" 54 | ) 55 | text_encoder = CLIPTextModel.from_pretrained( 56 | args.pretrained_model_path, subfolder="text_encoder" 57 | ) 58 | vae = AutoencoderKL.from_pretrained( 59 | args.pretrained_model_path, subfolder="vae" 60 | ) 61 | unet = UNet3DConditionModel.from_pretrained_2d( 62 | args.pretrained_model_path, 63 | subfolder="unet", 64 | unet_additional_kwargs=OmegaConf.to_container( 65 | inference_config.unet_additional_kwargs 66 | ), 67 | ) 68 | 69 | if is_xformers_available(): 70 | unet.enable_xformers_memory_efficient_attention() 71 | else: 72 | assert False 73 | 74 | pipeline = AnimationPipeline( 75 | vae=vae, 76 | text_encoder=text_encoder, 77 | tokenizer=tokenizer, 78 | unet=unet, 79 | scheduler=LCMScheduler( 80 | **OmegaConf.to_container(inference_config.noise_scheduler_kwargs) 81 | ), 82 | ).to("cuda") 83 | 84 | pipeline = load_weights( 85 | pipeline, 86 | motion_module_path=motion_module, 87 | motion_module_lora_configs=model_config.get( 88 | "motion_module_lora_configs", [] 89 | ), 90 | dreambooth_model_path=model_config.get("dreambooth_path", ""), 91 | lora_model_path=model_config.get("lora_model_path", ""), 92 | lora_alpha=model_config.get("lora_alpha", 0.8), 93 | ).to("cuda") 94 | 95 | pipeline.unet = convert_lcm_lora(pipeline.unet, lcm_lora, 1.0) 96 | prompts = model_config.prompt 97 | image_paths = ( 98 | model_config.image_paths 99 | if hasattr(model_config, "image_paths") 100 | else [None for _ in range(len(prompts))] 101 | ) 102 | control_paths = ( 103 | model_config.control_paths 104 | if hasattr(model_config, "control_paths") 105 | else [None for _ in range(len(prompts))] 106 | ) 107 | n_prompts = ( 108 | list(model_config.n_prompt) * len(prompts) 109 | if len(model_config.n_prompt) == 1 110 | else model_config.n_prompt 111 | ) 112 | 113 | random_seeds = model_config.get("seed", [-1]) 114 | random_seeds = ( 115 | [random_seeds] if isinstance(random_seeds, int) else list(random_seeds) 116 | ) 117 | random_seeds = ( 118 | random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds 119 | ) 120 | 121 | config[config_key].random_seed = [] 122 | for prompt_idx, ( 
123 | prompt, 124 | n_prompt, 125 | random_seed, 126 | image_path, 127 | control_path, 128 | ) in enumerate( 129 | zip(prompts, n_prompts, random_seeds, image_paths, control_paths) 130 | ): 131 | 132 | if random_seed != -1: 133 | torch.manual_seed(random_seed) 134 | else: 135 | torch.seed() 136 | config[config_key].random_seed.append(torch.initial_seed()) 137 | 138 | print(f"current seed: {torch.initial_seed()}") 139 | print(f"sampling {prompt} ...") 140 | sample = pipeline( 141 | prompt, 142 | negative_prompt=n_prompt, 143 | control_path=control_path, 144 | image_path=image_path, 145 | num_inference_steps=model_config.steps, 146 | guidance_scale=model_config.guidance_scale, 147 | width=model_config.W, 148 | height=model_config.H, 149 | video_length=model_config.L, 150 | do_classifier_free_guidance=model_config.get( 151 | "do_classifier_free_guidance", False 152 | ), 153 | ).videos 154 | samples.append(sample) 155 | 156 | prompt = "-".join((prompt.replace("/", "").split(" ")[:10])) 157 | save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif") 158 | print(f"save to {savedir}/sample/{prompt}.gif") 159 | 160 | sample_idx += 1 161 | 162 | samples = torch.concat(samples) 163 | save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4) 164 | 165 | OmegaConf.save(config, f"{savedir}/config.yaml") 166 | 167 | 168 | if __name__ == "__main__": 169 | parser = argparse.ArgumentParser() 170 | parser.add_argument( 171 | "--pretrained_model_path", 172 | type=str, 173 | default="models/StableDiffusion/stable-diffusion-v1-5", 174 | ) 175 | parser.add_argument( 176 | "--inference_config", type=str, default="configs/inference-t2v.yaml" 177 | ) 178 | parser.add_argument("--config", type=str, required=True) 179 | 180 | args = parser.parse_args() 181 | main(args) 182 | -------------------------------------------------------------------------------- /animatelcm_sd15/configs/batch_inference_example.yaml: -------------------------------------------------------------------------------- 1 | T2V-Example: 2 | inference_config: "configs/inference/inference-t2v.yaml" 3 | motion_module: 4 | - "models/Motion_Module/xxx.ckpt" 5 | lcm_lora: 6 | - "models/LCM_LoRA/xxx.safetensors" 7 | 8 | dreambooth_path: "models/Personalized/xxx.safetensors" 9 | lora_model_path: "" 10 | 11 | seed: [2, 0, 2 , 4] 12 | steps: 4 13 | guidance_scale: 1.5 14 | do_classifier_free_guidance: True 15 | H: 512 16 | W: 512 17 | L: 16 18 | 19 | prompt: 20 | - "xxx, specify your prompt here" 21 | - "xxx, specify your prompt here" 22 | - "xxx, specify your prompt here" 23 | - "xxx, specify your prompt here" 24 | 25 | n_prompt: 26 | - "xxx, specify your negative prompt here" 27 | - "xxx, specify your negative prompt here" 28 | - "xxx, specify your negative prompt here" 29 | - "xxx, specify your negative prompt here" 30 | 31 | I2V-Example: 32 | 33 | inference_config: "configs/inference/inference-i2v.yaml" 34 | motion_module: 35 | - "models/Motion_Module/xxx.ckpt" 36 | lcm_lora: 37 | - "models/LCM_LoRA/xxx.safetensors" 38 | 39 | dreambooth_path: "models/Personalized/xxx.safetensors" 40 | lora_model_path: "" 41 | 42 | seed: [2, 0, 2, 4] 43 | steps: 4 44 | guidance_scale: 1 # should be [1,2] 45 | do_classifier_free_guidance: True 46 | H: 512 47 | W: 512 48 | L: 16 49 | 50 | prompt: 51 | - "xxx, specify your prompt here" 52 | - "xxx, specify your prompt here" 53 | - "xxx, specify your prompt here" 54 | - "xxx, specify your prompt here" 55 | 56 | n_prompt: 57 | - "xxx, specify your negative prompt here" 58 | - "xxx, specify your negative prompt 
here" 59 | - "xxx, specify your negative prompt here" 60 | - "xxx, specify your negative prompt here" 61 | 62 | image_paths: 63 | - "xxx, specify your image paths here" 64 | - "xxx, specify your image paths here" 65 | - "xxx, specify your image paths here" 66 | - "xxx, specify your image paths here" 67 | 68 | -------------------------------------------------------------------------------- /animatelcm_sd15/configs/inference-i2v.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | use_pixel_encoder: false 3 | use_img_encoder: true 4 | use_inflated_groupnorm: true 5 | unet_use_cross_frame_attention: false 6 | unet_use_temporal_attention: false 7 | use_motion_module: true 8 | use_motion_resnet: false 9 | motion_module_resolutions: 10 | - 1 11 | - 2 12 | - 4 13 | - 8 14 | motion_module_mid_block: true 15 | motion_module_deco5der_only: false 16 | motion_module_type: Vanilla 17 | motion_module_kwargs: 18 | num_attention_heads: 8 19 | num_transformer_block: 1 20 | attention_block_types: 21 | - Temporal_Self 22 | - Temporal_Self 23 | temporal_position_encoding: true 24 | temporal_attention_dim_div: 1 25 | 26 | noise_scheduler_kwargs: 27 | beta_start: 0.00085 28 | beta_end: 0.014 29 | beta_schedule: "linear" 30 | original_inference_steps: 200 31 | steps_offset: 1 -------------------------------------------------------------------------------- /animatelcm_sd15/configs/inference-t2v.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | use_pixel_encoder: false 3 | use_img_encoder: false 4 | use_inflated_groupnorm: true 5 | unet_use_cross_frame_attention: false 6 | unet_use_temporal_attention: false 7 | use_motion_module: true 8 | use_motion_resnet: false 9 | motion_module_resolutions: 10 | - 1 11 | - 2 12 | - 4 13 | - 8 14 | motion_module_mid_block: true 15 | motion_module_decoder_only: false 16 | motion_module_type: Vanilla 17 | motion_module_kwargs: 18 | num_attention_heads: 8 19 | num_transformer_block: 1 20 | attention_block_types: 21 | - Temporal_Self 22 | - Temporal_Self 23 | temporal_position_encoding: true 24 | temporal_attention_dim_div: 1 25 | 26 | noise_scheduler_kwargs: 27 | beta_start: 0.00085 28 | beta_end: 0.012 29 | beta_schedule: "linear" 30 | original_inference_steps: 100 31 | steps_offset: 1 -------------------------------------------------------------------------------- /animatelcm_sd15/models/LCM_LoRA/put_spatial_lora_here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/models/LCM_LoRA/put_spatial_lora_here.txt -------------------------------------------------------------------------------- /animatelcm_sd15/models/Motion_Module/put_motion_module_here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/models/Motion_Module/put_motion_module_here.txt -------------------------------------------------------------------------------- /animatelcm_sd15/models/Personalized/put_personalized_weights_here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/models/Personalized/put_personalized_weights_here.txt 
-------------------------------------------------------------------------------- /animatelcm_sd15/models/StableDiffusion/put_stable_diffusion_v15_here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/models/StableDiffusion/put_stable_diffusion_v15_here.txt -------------------------------------------------------------------------------- /animatelcm_sd15/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1 2 | torchvision==0.14.1 3 | torchaudio==0.13.1 4 | diffusers==0.11.1 5 | transformers==4.25.1 6 | xformers==0.0.16 7 | imageio==2.27.0 8 | gradio==3.48.0 9 | gdown 10 | einops 11 | omegaconf 12 | safetensors 13 | imageio[ffmpeg] 14 | imageio[pyav] 15 | accelerate -------------------------------------------------------------------------------- /animatelcm_sd15/test_imgs/cloud.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/test_imgs/cloud.jpeg -------------------------------------------------------------------------------- /animatelcm_sd15/test_imgs/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/test_imgs/dog.jpg -------------------------------------------------------------------------------- /animatelcm_sd15/test_imgs/fire.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/test_imgs/fire.jpg -------------------------------------------------------------------------------- /animatelcm_sd15/test_imgs/fox.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/test_imgs/fox.jpg -------------------------------------------------------------------------------- /animatelcm_sd15/test_imgs/girl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/test_imgs/girl.png -------------------------------------------------------------------------------- /animatelcm_sd15/test_imgs/girl_flower.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/test_imgs/girl_flower.jpg -------------------------------------------------------------------------------- /animatelcm_sd15/test_imgs/lighter.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/test_imgs/lighter.jpg -------------------------------------------------------------------------------- /animatelcm_sd15/test_imgs/snow_man_fire.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_sd15/test_imgs/snow_man_fire.jpg 
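
For a single prompt, the batch script earlier in this directory boils down to the following condensed sketch (an illustration, not a file in the repo). It assumes Stable Diffusion v1.5 sits under `models/StableDiffusion/stable-diffusion-v1-5` and that a motion module, LCM LoRA, and personalized checkpoint have been placed in the folders above; the motion-module and Dreambooth file names are borrowed from the Gradio demo defaults, and the LCM-LoRA file name is purely illustrative.

```python
# Condensed single-prompt version of animatelcm_sd15/batch_inference.py (sketch; paths illustrative).
import torch
from omegaconf import OmegaConf
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from diffusers.utils.import_utils import is_xformers_available

from animatelcm.models.unet import UNet3DConditionModel
from animatelcm.pipelines.pipeline_animation import AnimationPipeline
from animatelcm.scheduler.lcm_scheduler import LCMScheduler
from animatelcm.utils.util import load_weights, save_videos_grid
from animatelcm.utils.lcm_utils import convert_lcm_lora

sd_path = "models/StableDiffusion/stable-diffusion-v1-5"
inference_config = OmegaConf.load("configs/inference-t2v.yaml")

tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae")
unet = UNet3DConditionModel.from_pretrained_2d(
    sd_path,
    subfolder="unet",
    unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
)
if is_xformers_available():
    unet.enable_xformers_memory_efficient_attention()

pipeline = AnimationPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=LCMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
).to("cuda")

pipeline = load_weights(
    pipeline,
    motion_module_path="models/Motion_Module/sd15_t2v_beta_motion.ckpt",  # name from the demo app
    motion_module_lora_configs=[],
    dreambooth_model_path="models/Personalized/realistic2.safetensors",   # name from the demo app
    lora_model_path="",
    lora_alpha=0.8,
).to("cuda")
pipeline.unet = convert_lcm_lora(
    pipeline.unet, "models/LCM_LoRA/sd15_lcm_lora.safetensors", 1.0  # illustrative file name
)

torch.manual_seed(1234)
sample = pipeline(
    "a boy holding a rabbit",
    negative_prompt="bad quality",
    image_path=None,
    control_path=None,
    num_inference_steps=4,
    guidance_scale=1.5,
    width=512,
    height=512,
    video_length=16,
    do_classifier_free_guidance=True,
).videos
save_videos_grid(sample, "samples/a-boy-holding-a-rabbit.gif")
```
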
-------------------------------------------------------------------------------- /animatelcm_svd/README.md: -------------------------------------------------------------------------------- 1 | ## AnimateLCM SVD 2 | 3 | ### Enviroment 4 | 5 | You can directly using the `environment.yaml` 6 | ``` 7 | conda env create -f enviroment.yaml 8 | conda activate animatelcm_svd 9 | ``` 10 | or through the requirements.txt 11 | 12 | ``` 13 | conda create -n animatelcm_svd python=3.9 14 | conda activate animatelcm_svd 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ### Models 19 | 20 | 21 | You can download the models through `wget` 22 | 23 | ``` 24 | cd safetensors 25 | wget -c https://huggingface.co/wangfuyun/AnimateLCM-SVD-xt/resolve/main/AnimateLCM-SVD-xt-1.1.safetensors 26 | wget -c https://huggingface.co/wangfuyun/AnimateLCM-SVD-xt/resolve/main/AnimateLCM-SVD-xt.safetensors 27 | cd .. 28 | ``` 29 | 30 | ### Runing 31 | 32 | Simply running 33 | ``` 34 | python app.py 35 | ``` 36 | -------------------------------------------------------------------------------- /animatelcm_svd/animate_lcm_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from safetensors import safe_open 4 | 5 | 6 | 7 | def calculate_probabilities(sigma_values, Pmean=0.7, Pstd=1.6): 8 | 9 | log_sigma_values = torch.log(sigma_values) 10 | 11 | erf_diff = torch.erf((log_sigma_values[:-1] - Pmean) / (np.sqrt(2) * Pstd)) - \ 12 | torch.erf((log_sigma_values[1:] - Pmean) / (np.sqrt(2) * Pstd)) 13 | 14 | probabilities = erf_diff / torch.sum(erf_diff) 15 | 16 | return probabilities 17 | 18 | def append_dims(x, target_dims): 19 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 20 | dims_to_append = target_dims - x.ndim 21 | if dims_to_append < 0: 22 | raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") 23 | return x[(...,) + (None,) * dims_to_append] 24 | 25 | 26 | class SVDSolver(): 27 | def __init__(self, N, sigma_min, sigma_max, rho, Pmean, Pstd): 28 | self.sigma_min = sigma_min 29 | self.sigma_max = sigma_max 30 | self.rho = rho 31 | self.N = N 32 | self.Pmean = Pmean 33 | self.Pstd = Pstd 34 | 35 | 36 | self.indices = torch.arange(0, N, dtype=torch.float) 37 | self.sigmas = (sigma_max ** (1 / rho) + self.indices / (N - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)))**rho 38 | 39 | self.indices = torch.cat([self.indices, torch.tensor([N])]) 40 | self.sigmas = torch.cat([self.sigmas, torch.tensor([0])]) 41 | 42 | 43 | self.probs = torch.ones_like(self.sigmas[:-1])*(1/N) 44 | 45 | self.sigmas = self.sigmas[:,None,None,None,None] 46 | self.timesteps = torch.Tensor([0.25 * (sigma + 1e-44).log() for sigma in self.sigmas]) 47 | 48 | self.weights = (1/(self.sigmas[:-1] - self.sigmas[1:]))**0.1 # This is not optimal and can influence the training dynamics a lot. Wish someone can make it better. 
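        # (annotation added for this dump, not in the original file)
        # The schedule above is the Karras/EDM discretization:
        #   sigma_i = (sigma_max**(1/rho) + i/(N-1) * (sigma_min**(1/rho) - sigma_max**(1/rho)))**rho,
        # extended with sigma_N = 0 so that index N denotes fully clean data.
        # The c_out / c_skip / c_in terms defined next are the EDM preconditioning coefficients
        # with sigma_data = 1 (c_out negated, matching get_scalings_for_boundary_condition in
        # animatelcm_scheduler.py), so predicted_origin() returns c_out * F(x) + c_skip * x.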
49 | self.c_out = -self.sigmas / ((self.sigmas**2 + 1)**0.5) 50 | self.c_skip = 1 / (self.sigmas**2 + 1) 51 | self.c_in = 1 /((self.sigmas**2 + 1) ** 0.5) 52 | 53 | def sample_params(self, indices): 54 | 55 | sampled_sigmas = self.sigmas[indices] 56 | sampled_timesteps = self.timesteps[indices] 57 | sampled_weights = self.weights[torch.where(indices>self.weights.shape[0]-1,self.weights.shape[0]-1,indices)] 58 | sampled_c_out = self.c_out[indices] 59 | sampled_c_in = self.c_in[indices] 60 | sampled_c_skip = self.c_skip[indices] 61 | 62 | return indices, sampled_sigmas, sampled_timesteps, sampled_weights, sampled_c_in, sampled_c_out, sampled_c_skip 63 | 64 | 65 | def sample_timesteps(self, bsz): 66 | 67 | sampled_indices = torch.multinomial(self.probs, bsz, replacement=True) 68 | 69 | sampled_indices, sampled_sigmas, sampled_timesteps, sampled_weights, sampled_c_in, sampled_c_out, sampled_c_skip = self.sample_params(sampled_indices) 70 | 71 | return sampled_indices, sampled_sigmas, sampled_timesteps, sampled_weights, sampled_c_in, sampled_c_out, sampled_c_skip 72 | 73 | 74 | def predicted_origin(self, model_output, indices, sample): 75 | return model_output * self.c_out[indices] + sample * self.c_skip[indices] 76 | 77 | @torch.no_grad() 78 | def euler_solver(self, model_output, sample, indices, indices_next): 79 | x = sample 80 | denoiser = self.predicted_origin(model_output, indices, sample) 81 | d = (x - denoiser) / self.sigmas[indices] 82 | sample = x + d * (self.sigmas[indices_next] - self.sigmas[indices]) 83 | 84 | return sample 85 | 86 | @torch.no_grad() 87 | def heun_solver(self, model_output, sample, indices, indices_next, model_fn): 88 | pass 89 | 90 | def to(self,device,dtype): 91 | self.indinces = self.indices.to(device,dtype) 92 | self.sigmas = self.sigmas.to(device,dtype) 93 | self.timesteps=self.timesteps.to(device,dtype) 94 | self.probs=self.probs.to(device,dtype) 95 | self.weights=self.weights.to(device,dtype) 96 | self.c_out=self.c_out.to(device,dtype) 97 | self.c_skip=self.c_skip.to(device,dtype) 98 | self.c_in=self.c_in.to(device,dtype) 99 | 100 | 101 | def extract_into_tensor(a, t, x_shape): 102 | b, *_ = t.shape 103 | out = a.gather(-1, t) 104 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 105 | 106 | 107 | @torch.no_grad() 108 | def update_ema(target_params, source_params, rate=0.99): 109 | """ 110 | Update target parameters to be closer to those of source parameters using 111 | an exponential moving average. 112 | 113 | :param target_params: the target parameter sequence. 114 | :param source_params: the source parameter sequence. 115 | :param rate: the EMA rate (closer to 1 means slower). 116 | """ 117 | for targ, src in zip(target_params, source_params): 118 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 119 | 120 | 121 | 122 | -------------------------------------------------------------------------------- /animatelcm_svd/animatelcm_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import numpy as np 19 | import torch 20 | 21 | from diffusers.configuration_utils import ConfigMixin, register_to_config 22 | from diffusers.utils import BaseOutput, logging 23 | from diffusers.utils.torch_utils import randn_tensor 24 | from diffusers.schedulers.scheduling_utils import SchedulerMixin 25 | 26 | 27 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 28 | 29 | 30 | @dataclass 31 | class AnimateLCMSVDStochasticIterativeSchedulerOutput(BaseOutput): 32 | """ 33 | Output class for the scheduler's `step` function. 34 | 35 | Args: 36 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 37 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the 38 | denoising loop. 39 | """ 40 | 41 | prev_sample: torch.FloatTensor 42 | 43 | 44 | class AnimateLCMSVDStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): 45 | """ 46 | Multistep and onestep sampling for consistency models. 47 | 48 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic 49 | methods the library implements for all schedulers such as loading and saving. 50 | 51 | Args: 52 | num_train_timesteps (`int`, defaults to 40): 53 | The number of diffusion steps to train the model. 54 | sigma_min (`float`, defaults to 0.002): 55 | Minimum noise magnitude in the sigma schedule. Defaults to 0.002 from the original implementation. 56 | sigma_max (`float`, defaults to 80.0): 57 | Maximum noise magnitude in the sigma schedule. Defaults to 80.0 from the original implementation. 58 | sigma_data (`float`, defaults to 0.5): 59 | The standard deviation of the data distribution from the EDM 60 | [paper](https://huggingface.co/papers/2206.00364). Defaults to 0.5 from the original implementation. 61 | s_noise (`float`, defaults to 1.0): 62 | The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000, 63 | 1.011]. Defaults to 1.0 from the original implementation. 64 | rho (`float`, defaults to 7.0): 65 | The parameter for calculating the Karras sigma schedule from the EDM 66 | [paper](https://huggingface.co/papers/2206.00364). Defaults to 7.0 from the original implementation. 67 | clip_denoised (`bool`, defaults to `True`): 68 | Whether to clip the denoised outputs to `(-1, 1)`. 69 | timesteps (`List` or `np.ndarray` or `torch.Tensor`, *optional*): 70 | An explicit timestep schedule that can be optionally specified. The timesteps are expected to be in 71 | increasing order. 
72 | """ 73 | 74 | order = 1 75 | 76 | @register_to_config 77 | def __init__( 78 | self, 79 | num_train_timesteps: int = 40, 80 | sigma_min: float = 0.002, 81 | sigma_max: float = 80.0, 82 | sigma_data: float = 0.5, 83 | s_noise: float = 1.0, 84 | rho: float = 7.0, 85 | clip_denoised: bool = True, 86 | ): 87 | # standard deviation of the initial noise distribution 88 | self.init_noise_sigma = (sigma_max**2 + 1) ** 0.5 89 | # self.init_noise_sigma = sigma_max 90 | 91 | ramp = np.linspace(0, 1, num_train_timesteps) 92 | sigmas = self._convert_to_karras(ramp) 93 | sigmas = np.concatenate([sigmas, np.array([0])]) 94 | timesteps = self.sigma_to_t(sigmas) 95 | 96 | # setable values 97 | self.num_inference_steps = None 98 | self.sigmas = torch.from_numpy(sigmas) 99 | self.timesteps = torch.from_numpy(timesteps) 100 | self.custom_timesteps = False 101 | self.is_scale_input_called = False 102 | self._step_index = None 103 | self.sigmas.to("cpu") # to avoid too much CPU/GPU communication 104 | 105 | def index_for_timestep(self, timestep, schedule_timesteps=None): 106 | if schedule_timesteps is None: 107 | schedule_timesteps = self.timesteps 108 | 109 | indices = (schedule_timesteps == timestep).nonzero() 110 | return indices.item() 111 | 112 | @property 113 | def step_index(self): 114 | """ 115 | The index counter for current timestep. It will increae 1 after each scheduler step. 116 | """ 117 | return self._step_index 118 | 119 | def scale_model_input( 120 | self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] 121 | ) -> torch.FloatTensor: 122 | """ 123 | Scales the consistency model input by `(sigma**2 + sigma_data**2) ** 0.5`. 124 | 125 | Args: 126 | sample (`torch.FloatTensor`): 127 | The input sample. 128 | timestep (`float` or `torch.FloatTensor`): 129 | The current timestep in the diffusion chain. 130 | 131 | Returns: 132 | `torch.FloatTensor`: 133 | A scaled input sample. 134 | """ 135 | # Get sigma corresponding to timestep 136 | if self.step_index is None: 137 | self._init_step_index(timestep) 138 | 139 | sigma = self.sigmas[self.step_index] 140 | sample = sample / ((sigma**2 + self.config.sigma_data**2) ** 0.5) 141 | 142 | self.is_scale_input_called = True 143 | return sample 144 | 145 | # def _sigma_to_t(self, sigma, log_sigmas): 146 | # # get log sigma 147 | # log_sigma = np.log(np.maximum(sigma, 1e-10)) 148 | 149 | # # get distribution 150 | # dists = log_sigma - log_sigmas[:, np.newaxis] 151 | 152 | # # get sigmas range 153 | # low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) 154 | # high_idx = low_idx + 1 155 | 156 | # low = log_sigmas[low_idx] 157 | # high = log_sigmas[high_idx] 158 | 159 | # # interpolate sigmas 160 | # w = (low - log_sigma) / (low - high) 161 | # w = np.clip(w, 0, 1) 162 | 163 | # # transform interpolation to time range 164 | # t = (1 - w) * low_idx + w * high_idx 165 | # t = t.reshape(sigma.shape) 166 | # return t 167 | 168 | def sigma_to_t(self, sigmas: Union[float, np.ndarray]): 169 | """ 170 | Gets scaled timesteps from the Karras sigmas for input to the consistency model. 171 | 172 | Args: 173 | sigmas (`float` or `np.ndarray`): 174 | A single Karras sigma or an array of Karras sigmas. 175 | 176 | Returns: 177 | `float` or `np.ndarray`: 178 | A scaled input timestep or scaled input timestep array. 
179 | """ 180 | if not isinstance(sigmas, np.ndarray): 181 | sigmas = np.array(sigmas, dtype=np.float64) 182 | 183 | timesteps = 0.25 * np.log(sigmas + 1e-44) 184 | 185 | return timesteps 186 | 187 | def set_timesteps( 188 | self, 189 | num_inference_steps: Optional[int] = None, 190 | device: Union[str, torch.device] = None, 191 | timesteps: Optional[List[int]] = None, 192 | ): 193 | """ 194 | Sets the timesteps used for the diffusion chain (to be run before inference). 195 | 196 | Args: 197 | num_inference_steps (`int`): 198 | The number of diffusion steps used when generating samples with a pre-trained model. 199 | device (`str` or `torch.device`, *optional*): 200 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 201 | timesteps (`List[int]`, *optional*): 202 | Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default 203 | timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed, 204 | `num_inference_steps` must be `None`. 205 | """ 206 | if num_inference_steps is None and timesteps is None: 207 | raise ValueError( 208 | "Exactly one of `num_inference_steps` or `timesteps` must be supplied." 209 | ) 210 | 211 | if num_inference_steps is not None and timesteps is not None: 212 | raise ValueError( 213 | "Can only pass one of `num_inference_steps` or `timesteps`." 214 | ) 215 | 216 | # Follow DDPMScheduler custom timesteps logic 217 | if timesteps is not None: 218 | for i in range(1, len(timesteps)): 219 | if timesteps[i] >= timesteps[i - 1]: 220 | raise ValueError("`timesteps` must be in descending order.") 221 | 222 | if timesteps[0] >= self.config.num_train_timesteps: 223 | raise ValueError( 224 | f"`timesteps` must start before `self.config.train_timesteps`:" 225 | f" {self.config.num_train_timesteps}." 226 | ) 227 | 228 | timesteps = np.array(timesteps, dtype=np.int64) 229 | self.custom_timesteps = True 230 | else: 231 | if num_inference_steps > self.config.num_train_timesteps: 232 | raise ValueError( 233 | f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" 234 | f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" 235 | f" maximal {self.config.num_train_timesteps} timesteps." 
236 | ) 237 | 238 | self.num_inference_steps = num_inference_steps 239 | 240 | step_ratio = self.config.num_train_timesteps // self.num_inference_steps 241 | timesteps = ( 242 | (np.arange(0, num_inference_steps) * step_ratio) 243 | .round()[::-1] 244 | .copy() 245 | .astype(np.int64) 246 | ) 247 | self.custom_timesteps = False 248 | 249 | # Map timesteps to Karras sigmas directly for multistep sampling 250 | # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675 251 | num_train_timesteps = self.config.num_train_timesteps 252 | ramp = timesteps[::-1].copy() 253 | ramp = ramp / (num_train_timesteps - 1) 254 | sigmas = self._convert_to_karras(ramp) 255 | timesteps = self.sigma_to_t(sigmas) 256 | 257 | sigmas = np.concatenate([sigmas, [0]]).astype(np.float32) 258 | self.sigmas = torch.from_numpy(sigmas).to(device=device) 259 | 260 | if str(device).startswith("mps"): 261 | # mps does not support float64 262 | self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) 263 | else: 264 | self.timesteps = torch.from_numpy(timesteps).to(device=device) 265 | 266 | self._step_index = None 267 | self.sigmas.to("cpu") # to avoid too much CPU/GPU communication 268 | 269 | # Modified _convert_to_karras implementation that takes in ramp as argument 270 | def _convert_to_karras(self, ramp): 271 | """Constructs the noise schedule of Karras et al. (2022).""" 272 | 273 | sigma_min: float = self.config.sigma_min 274 | sigma_max: float = self.config.sigma_max 275 | 276 | rho = self.config.rho 277 | min_inv_rho = sigma_min ** (1 / rho) 278 | max_inv_rho = sigma_max ** (1 / rho) 279 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho 280 | return sigmas 281 | 282 | def get_scalings(self, sigma): 283 | sigma_data = self.config.sigma_data 284 | 285 | c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) 286 | c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 287 | return c_skip, c_out 288 | 289 | def get_scalings_for_boundary_condition(self, sigma): 290 | """ 291 | Gets the scalings used in the consistency model parameterization (from Appendix C of the 292 | [paper](https://huggingface.co/papers/2303.01469)) to enforce boundary condition. 293 | 294 | 295 | 296 | `epsilon` in the equations for `c_skip` and `c_out` is set to `sigma_min`. 297 | 298 | 299 | 300 | Args: 301 | sigma (`torch.FloatTensor`): 302 | The current sigma in the Karras sigma schedule. 303 | 304 | Returns: 305 | `tuple`: 306 | A two-element tuple where `c_skip` (which weights the current sample) is the first element and `c_out` 307 | (which weights the consistency model output) is the second element. 308 | """ 309 | sigma_min = self.config.sigma_min 310 | sigma_data = self.config.sigma_data 311 | 312 | c_skip = sigma_data**2 / ((sigma) ** 2 + sigma_data**2) 313 | c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 314 | return c_skip, c_out 315 | 316 | # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index 317 | def _init_step_index(self, timestep): 318 | if isinstance(timestep, torch.Tensor): 319 | timestep = timestep.to(self.timesteps.device) 320 | 321 | index_candidates = (self.timesteps == timestep).nonzero() 322 | 323 | # The sigma index that is taken for the **very** first `step` 324 | # is always the second index (or the last index if there is only 1) 325 | # This way we can ensure we don't accidentally skip a sigma in 326 | # case we start in the middle of the denoising schedule (e.g. 
for image-to-image) 327 | if len(index_candidates) > 1: 328 | step_index = index_candidates[1] 329 | else: 330 | step_index = index_candidates[0] 331 | 332 | self._step_index = step_index.item() 333 | 334 | def step( 335 | self, 336 | model_output: torch.FloatTensor, 337 | timestep: Union[float, torch.FloatTensor], 338 | sample: torch.FloatTensor, 339 | generator: Optional[torch.Generator] = None, 340 | return_dict: bool = True, 341 | ) -> Union[AnimateLCMSVDStochasticIterativeSchedulerOutput, Tuple]: 342 | """ 343 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 344 | process from the learned model outputs (most often the predicted noise). 345 | 346 | Args: 347 | model_output (`torch.FloatTensor`): 348 | The direct output from the learned diffusion model. 349 | timestep (`float`): 350 | The current timestep in the diffusion chain. 351 | sample (`torch.FloatTensor`): 352 | A current instance of a sample created by the diffusion process. 353 | generator (`torch.Generator`, *optional*): 354 | A random number generator. 355 | return_dict (`bool`, *optional*, defaults to `True`): 356 | Whether or not to return a 357 | [`~schedulers.scheduling_consistency_models.AnimateLCMSVDStochasticIterativeSchedulerOutput`] or `tuple`. 358 | 359 | Returns: 360 | [`~schedulers.scheduling_consistency_models.AnimateLCMSVDStochasticIterativeSchedulerOutput`] or `tuple`: 361 | If return_dict is `True`, 362 | [`~schedulers.scheduling_consistency_models.AnimateLCMSVDStochasticIterativeSchedulerOutput`] is returned, 363 | otherwise a tuple is returned where the first element is the sample tensor. 364 | """ 365 | 366 | if ( 367 | isinstance(timestep, int) 368 | or isinstance(timestep, torch.IntTensor) 369 | or isinstance(timestep, torch.LongTensor) 370 | ): 371 | raise ValueError( 372 | ( 373 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 374 | f" `{self.__class__}.step()` is not supported. Make sure to pass" 375 | " one of the `scheduler.timesteps` as a timestep." 376 | ), 377 | ) 378 | 379 | if not self.is_scale_input_called: 380 | logger.warning( 381 | "The `scale_model_input` function should be called before `step` to ensure correct denoising. " 382 | "See `StableDiffusionPipeline` for a usage example." 383 | ) 384 | 385 | sigma_min = self.config.sigma_min 386 | sigma_max = self.config.sigma_max 387 | 388 | if self.step_index is None: 389 | self._init_step_index(timestep) 390 | 391 | # sigma_next corresponds to next_t in original implementation 392 | sigma = self.sigmas[self.step_index] 393 | if self.step_index + 1 < self.config.num_train_timesteps: 394 | sigma_next = self.sigmas[self.step_index + 1] 395 | else: 396 | # Set sigma_next to sigma_min 397 | sigma_next = self.sigmas[-1] 398 | 399 | # Get scalings for boundary conditions 400 | 401 | c_skip, c_out = self.get_scalings_for_boundary_condition(sigma) 402 | 403 | # 1. Denoise model output using boundary conditions 404 | denoised = c_out * model_output + c_skip * sample 405 | if self.config.clip_denoised: 406 | denoised = denoised.clamp(-1, 1) 407 | 408 | # 2. Sample z ~ N(0, s_noise^2 * I) 409 | # Noise is not used for onestep sampling. 
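            # (annotation added for this dump, not in the original file)
            # This is the stochastic iterative update of consistency models: the denoised estimate
            # above is x0 = c_out * F(x, sigma) + c_skip * x, and the next sample is
            # x_next = x0 + sigma_next * z with z ~ N(0, s_noise^2 * I). On the final step
            # sigma_next is the appended 0, so x0 is returned unchanged; with a single timestep
            # the noise below is replaced by zeros (one-step generation).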
410 | if len(self.timesteps) > 1: 411 | noise = randn_tensor( 412 | model_output.shape, 413 | dtype=model_output.dtype, 414 | device=model_output.device, 415 | generator=generator, 416 | ) 417 | else: 418 | noise = torch.zeros_like(model_output) 419 | z = noise * self.config.s_noise 420 | 421 | sigma_hat = sigma_next.clamp(min=0, max=sigma_max) 422 | 423 | print("denoise currently") 424 | print(sigma_hat) 425 | 426 | # origin 427 | prev_sample = denoised + z * sigma_hat 428 | 429 | # upon completion increase step index by one 430 | self._step_index += 1 431 | 432 | if not return_dict: 433 | return (prev_sample,) 434 | 435 | return AnimateLCMSVDStochasticIterativeSchedulerOutput(prev_sample=prev_sample) 436 | 437 | # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise 438 | def add_noise( 439 | self, 440 | original_samples: torch.FloatTensor, 441 | noise: torch.FloatTensor, 442 | timesteps: torch.FloatTensor, 443 | ) -> torch.FloatTensor: 444 | # Make sure sigmas and timesteps have the same device and dtype as original_samples 445 | sigmas = self.sigmas.to( 446 | device=original_samples.device, dtype=original_samples.dtype 447 | ) 448 | if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): 449 | # mps does not support float64 450 | schedule_timesteps = self.timesteps.to( 451 | original_samples.device, dtype=torch.float32 452 | ) 453 | timesteps = timesteps.to(original_samples.device, dtype=torch.float32) 454 | else: 455 | schedule_timesteps = self.timesteps.to(original_samples.device) 456 | timesteps = timesteps.to(original_samples.device) 457 | 458 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] 459 | 460 | sigma = sigmas[step_indices].flatten() 461 | while len(sigma.shape) < len(original_samples.shape): 462 | sigma = sigma.unsqueeze(-1) 463 | 464 | noisy_samples = original_samples + noise * sigma 465 | return noisy_samples 466 | 467 | def __len__(self): 468 | return self.config.num_train_timesteps 469 | -------------------------------------------------------------------------------- /animatelcm_svd/app.py: -------------------------------------------------------------------------------- 1 | # import spaces 2 | 3 | import gradio as gr 4 | # import gradio.helpers 5 | import torch 6 | import os 7 | from glob import glob 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | from PIL import Image 12 | from diffusers.utils import load_image, export_to_video 13 | from pipeline import StableVideoDiffusionPipeline 14 | 15 | import random 16 | from safetensors import safe_open 17 | from animatelcm_scheduler import AnimateLCMSVDStochasticIterativeScheduler 18 | 19 | 20 | def get_safetensors_files(): 21 | models_dir = "./safetensors" 22 | safetensors_files = [ 23 | f for f in os.listdir(models_dir) if f.endswith(".safetensors") 24 | ] 25 | return safetensors_files 26 | 27 | 28 | def model_select(selected_file): 29 | print("load model weights", selected_file) 30 | pipe.unet.cpu() 31 | file_path = os.path.join("./safetensors", selected_file) 32 | state_dict = {} 33 | with safe_open(file_path, framework="pt", device="cpu") as f: 34 | for key in f.keys(): 35 | state_dict[key] = f.get_tensor(key) 36 | missing, unexpected = pipe.unet.load_state_dict(state_dict, strict=True) 37 | pipe.unet.cuda() 38 | del state_dict 39 | return 40 | 41 | 42 | noise_scheduler = AnimateLCMSVDStochasticIterativeScheduler( 43 | num_train_timesteps=40, 44 | sigma_min=0.002, 45 | sigma_max=700.0, 46 | 
sigma_data=1.0, 47 | s_noise=1.0, 48 | rho=7, 49 | clip_denoised=False, 50 | ) 51 | pipe = StableVideoDiffusionPipeline.from_pretrained( 52 | "stabilityai/stable-video-diffusion-img2vid-xt", 53 | scheduler=noise_scheduler, 54 | torch_dtype=torch.float16, 55 | variant="fp16", 56 | ) 57 | pipe.to("cuda") 58 | pipe.enable_model_cpu_offload() # for smaller cost 59 | model_select("AnimateLCM-SVD-xt-1.1.safetensors") 60 | # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) # for faster inference 61 | 62 | 63 | max_64_bit_int = 2**63 - 1 64 | 65 | # @spaces.GPU 66 | def sample( 67 | image: Image, 68 | seed: Optional[int] = 42, 69 | randomize_seed: bool = False, 70 | motion_bucket_id: int = 80, 71 | fps_id: int = 8, 72 | max_guidance_scale: float = 1.2, 73 | min_guidance_scale: float = 1, 74 | width: int = 1024, 75 | height: int = 576, 76 | num_inference_steps: int = 4, 77 | decoding_t: int = 4, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. 78 | output_folder: str = "outputs_gradio", 79 | ): 80 | if image.mode == "RGBA": 81 | image = image.convert("RGB") 82 | 83 | if randomize_seed: 84 | seed = random.randint(0, max_64_bit_int) 85 | generator = torch.manual_seed(seed) 86 | 87 | os.makedirs(output_folder, exist_ok=True) 88 | base_count = len(glob(os.path.join(output_folder, "*.mp4"))) 89 | video_path = os.path.join(output_folder, f"{base_count:06d}.mp4") 90 | 91 | with torch.autocast("cuda"): 92 | frames = pipe( 93 | image, 94 | decode_chunk_size=decoding_t, 95 | generator=generator, 96 | motion_bucket_id=motion_bucket_id, 97 | height=height, 98 | width=width, 99 | num_inference_steps=num_inference_steps, 100 | min_guidance_scale=min_guidance_scale, 101 | max_guidance_scale=max_guidance_scale, 102 | ).frames[0] 103 | export_to_video(frames, video_path, fps=fps_id) 104 | torch.manual_seed(seed) 105 | 106 | return video_path, seed 107 | 108 | 109 | def resize_image(image, output_size=(1024, 576)): 110 | # Calculate aspect ratios 111 | target_aspect = output_size[0] / output_size[1] # Aspect ratio of the desired size 112 | image_aspect = image.width / image.height # Aspect ratio of the original image 113 | 114 | # Resize then crop if the original image is larger 115 | if image_aspect > target_aspect: 116 | # Resize the image to match the target height, maintaining aspect ratio 117 | new_height = output_size[1] 118 | new_width = int(new_height * image_aspect) 119 | resized_image = image.resize((new_width, new_height), Image.LANCZOS) 120 | # Calculate coordinates for cropping 121 | left = (new_width - output_size[0]) / 2 122 | top = 0 123 | right = (new_width + output_size[0]) / 2 124 | bottom = output_size[1] 125 | else: 126 | # Resize the image to match the target width, maintaining aspect ratio 127 | new_width = output_size[0] 128 | new_height = int(new_width / image_aspect) 129 | resized_image = image.resize((new_width, new_height), Image.LANCZOS) 130 | # Calculate coordinates for cropping 131 | left = 0 132 | top = (new_height - output_size[1]) / 2 133 | right = output_size[0] 134 | bottom = (new_height + output_size[1]) / 2 135 | 136 | # Crop the image 137 | cropped_image = resized_image.crop((left, top, right, bottom)) 138 | return cropped_image 139 | 140 | 141 | with gr.Blocks() as demo: 142 | gr.Markdown( 143 | """ 144 | # [AnimateLCM: Accelerating the Animation of Personalized Diffusion Models and Adapters with Decoupled Consistency Learning](https://arxiv.org/abs/2402.00769) 145 | Fu-Yun Wang, Zhaoyang Huang (*Corresponding Author), 
Xiaoyu Shi, Weikang Bian, Guanglu Song, Yu Liu, Hongsheng Li (*Corresponding Author)
146 | 147 | [arXiv Report](https://arxiv.org/abs/2402.00769) | [Project Page](https://animatelcm.github.io/) | [Github](https://github.com/G-U-N/AnimateLCM) | [Civitai](https://civitai.com/models/290375/animatelcm-fast-video-generation) | [Replicate](https://replicate.com/camenduru/animate-lcm) 148 | 149 | Related Models: 150 | [AnimateLCM-t2v](https://huggingface.co/wangfuyun/AnimateLCM): Personalized Text-to-Video Generation 151 | [AnimateLCM-SVD-xt](https://huggingface.co/wangfuyun/AnimateLCM-SVD-xt): General Image-to-Video Generation 152 | [AnimateLCM-i2v](https://huggingface.co/wangfuyun/AnimateLCM-I2V): Personalized Image-to-Video Generation 153 | """ 154 | ) 155 | with gr.Row(): 156 | with gr.Column(): 157 | image = gr.Image(label="Upload your image", type="pil") 158 | generate_btn = gr.Button("Generate") 159 | video = gr.Video() 160 | with gr.Accordion("Advanced options", open=False): 161 | safetensors_dropdown = gr.Dropdown( 162 | label="Choose Safetensors", choices=get_safetensors_files() 163 | ) 164 | seed = gr.Slider( 165 | label="Seed", 166 | value=42, 167 | randomize=False, 168 | minimum=0, 169 | maximum=max_64_bit_int, 170 | step=1, 171 | ) 172 | randomize_seed = gr.Checkbox(label="Randomize seed", value=False) 173 | motion_bucket_id = gr.Slider( 174 | label="Motion bucket id", 175 | info="Controls how much motion to add/remove from the image", 176 | value=80, 177 | minimum=1, 178 | maximum=255, 179 | ) 180 | fps_id = gr.Slider( 181 | label="Frames per second", 182 | info="The length of your video in seconds will be 25/fps", 183 | value=8, 184 | minimum=5, 185 | maximum=30, 186 | ) 187 | width = gr.Slider( 188 | label="Width of input image", 189 | info="It should be divisible by 64", 190 | value=1024, 191 | minimum=576, 192 | maximum=2048, 193 | ) 194 | height = gr.Slider( 195 | label="Height of input image", 196 | info="It should be divisible by 64", 197 | value=576, 198 | minimum=320, 199 | maximum=1152, 200 | ) 201 | max_guidance_scale = gr.Slider( 202 | label="Max guidance scale", 203 | info="classifier-free guidance strength", 204 | value=1.2, 205 | minimum=1, 206 | maximum=2, 207 | ) 208 | min_guidance_scale = gr.Slider( 209 | label="Min guidance scale", 210 | info="classifier-free guidance strength", 211 | value=1, 212 | minimum=1, 213 | maximum=1.5, 214 | ) 215 | num_inference_steps = gr.Slider( 216 | label="Num inference steps", 217 | info="steps for inference", 218 | value=4, 219 | minimum=1, 220 | maximum=20, 221 | step=1, 222 | ) 223 | 224 | image.upload(fn=resize_image, inputs=image, outputs=image, queue=False) 225 | generate_btn.click( 226 | fn=sample, 227 | inputs=[ 228 | image, 229 | seed, 230 | randomize_seed, 231 | motion_bucket_id, 232 | fps_id, 233 | max_guidance_scale, 234 | min_guidance_scale, 235 | width, 236 | height, 237 | num_inference_steps, 238 | ], 239 | outputs=[video, seed], 240 | api_name="video", 241 | ) 242 | safetensors_dropdown.change(fn=model_select, inputs=safetensors_dropdown) 243 | 244 | gr.Examples( 245 | examples=[ 246 | ["test_imgs/ai-generated-8496135_1280.jpg"], 247 | ["test_imgs/dog-7396912_1280.jpg"], 248 | ["test_imgs/ship-7833921_1280.jpg"], 249 | ["test_imgs/girl-4898696_1280.jpg"], 250 | ["test_imgs/power-station-6579092_1280.jpg"] 251 | ], 252 | inputs=[image], 253 | outputs=[video, seed], 254 | fn=sample, 255 | cache_examples=True, 256 | ) 257 | 258 | if __name__ == "__main__": 259 | demo.queue(max_size=20, api_open=False) 260 | demo.launch(share=True, show_api=False) 261 | 
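
The Gradio demo above can also be run headlessly. The sketch below (not a file in the repo) condenses the demo's `sample()` path; `batch_inference.py` below does the same against local checkpoint paths. It assumes the AnimateLCM-SVD safetensors have been downloaded into `./safetensors` as described in the README, uses one of the bundled test images, and must be run from the `animatelcm_svd` directory so that `pipeline` and `animatelcm_scheduler` resolve to the local modules.

```python
# UI-free AnimateLCM-SVD inference (sketch; mirrors app.py and batch_inference.py in this folder).
import torch
from safetensors import safe_open
from diffusers.utils import load_image, export_to_video

from pipeline import StableVideoDiffusionPipeline
from animatelcm_scheduler import AnimateLCMSVDStochasticIterativeScheduler

scheduler = AnimateLCMSVDStochasticIterativeScheduler(
    num_train_timesteps=40,
    sigma_min=0.002,
    sigma_max=700.0,
    sigma_data=1.0,
    s_noise=1.0,
    rho=7,
    clip_denoised=False,
)
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    scheduler=scheduler,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.enable_model_cpu_offload()

# Swap the SVD-xt UNet weights for the consistency-distilled AnimateLCM-SVD weights.
state_dict = {}
with safe_open("safetensors/AnimateLCM-SVD-xt-1.1.safetensors", framework="pt", device="cpu") as f:
    for key in f.keys():
        state_dict[key] = f.get_tensor(key)
pipe.unet.load_state_dict(state_dict, strict=True)
del state_dict

image = load_image("test_imgs/dog-7396912_1280.jpg").resize((1024, 576))
generator = torch.manual_seed(42)
frames = pipe(
    image,
    decode_chunk_size=4,
    generator=generator,
    motion_bucket_id=80,
    height=576,
    width=1024,
    num_inference_steps=4,
    min_guidance_scale=1.0,
    max_guidance_scale=1.2,
).frames[0]
export_to_video(frames, "outputs_gradio/sample.mp4", fps=8)
```
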
-------------------------------------------------------------------------------- /animatelcm_svd/batch_inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from safetensors import safe_open 3 | from pipeline import StableVideoDiffusionPipeline 4 | from animatelcm_scheduler import AnimateLCMSVDStochasticIterativeScheduler 5 | from diffusers.utils import load_image, export_to_gif 6 | import os 7 | 8 | 9 | def run_inference_once(image_path, height, width, inference_time, min_guidance_scale, max_guidance_scale, noise_scheduler, weight = None,): 10 | if noise_scheduler is not None: 11 | pipe = StableVideoDiffusionPipeline.from_pretrained( 12 | "/mnt/afs/wangfuyun/SVD-xt/stable-video-diffusion-img2vid-xt", scheduler = noise_scheduler, torch_dtype=torch.float16, variant="fp16" 13 | ) 14 | else: 15 | pipe = StableVideoDiffusionPipeline.from_pretrained( 16 | "/mnt/afs/wangfuyun/SVD-xt/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16" 17 | ) 18 | pipe.enable_model_cpu_offload() 19 | 20 | if weight is not None: 21 | state_dict = {} 22 | with safe_open(weight, framework="pt", device="cpu") as f: 23 | for key in f.keys(): 24 | state_dict[key] = f.get_tensor(key) 25 | m,u = pipe.unet.load_state_dict(state_dict,strict=True) 26 | assert len(u) == 0 27 | del state_dict 28 | 29 | image = load_image(image_path) 30 | image = image.resize((height,width)) 31 | 32 | generator = torch.manual_seed(42) 33 | frames = pipe(image, decode_chunk_size=8, generator=generator, num_frames=25, height=height, width=width, num_inference_steps=inference_time, min_guidance_scale = min_guidance_scale, max_guidance_scale = max_guidance_scale).frames[0] 34 | export_to_gif(frames, f"output_gifs/{image_path[-20:-5]}-{height}-{width}-{inference_time}-{min_guidance_scale}-{max_guidance_scale}-{weight is None}.gif") 35 | 36 | if __name__ == "__main__": 37 | path = "test_imgs" 38 | weight_path = None 39 | assert weight_path is not None 40 | noise_scheduler = AnimateLCMSVDStochasticIterativeScheduler( 41 | num_train_timesteps= 40, 42 | sigma_min = 0.002, 43 | sigma_max = 700.0, 44 | sigma_data = 1.0, 45 | s_noise = 1.0, 46 | rho = 7, 47 | clip_denoised = False, 48 | ) 49 | # noise_scheduler = None 50 | assert noise_scheduler is not None 51 | for image_path in os.listdir(path)[5:]: 52 | image_path = os.path.join(path, image_path) 53 | for inference_time in [1, 2, 4, 8]: 54 | for height, width in [(576, 1024)]: 55 | # for min_scale, max_scale in [(1, 1.5), (1.2, 1.5), (1, 2)]: 56 | for min_scale, max_scale in [(1.0,1.0)]: 57 | run_inference_once(image_path, height, width, inference_time, min_scale, max_scale, noise_scheduler, weight = weight_path) 58 | -------------------------------------------------------------------------------- /animatelcm_svd/dataset.py: -------------------------------------------------------------------------------- 1 | import os, io, csv, math, random 2 | import numpy as np 3 | from einops import rearrange 4 | from decord import VideoReader 5 | 6 | import torch 7 | import torchvision.transforms as transforms 8 | from torch.utils.data.dataset import Dataset 9 | 10 | import imageio.v2 as imageio 11 | 12 | class WebVid10M(Dataset): 13 | def __init__( 14 | self, 15 | video_folder, 16 | sample_size=256, sample_stride=6, sample_n_frames=16, 17 | is_image=False, 18 | ): 19 | self.dataset = [os.path.join(video_folder,video_path) for video_path in os.listdir(video_folder) if video_path.endswith(("mp4",))] 20 | 
random.shuffle(self.dataset) 21 | self.length = len(self.dataset) 22 | 23 | self.video_folder = video_folder 24 | 25 | self.sample_stride = sample_stride 26 | self.sample_n_frames = sample_n_frames 27 | self.is_image = is_image 28 | 29 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) 30 | self.pixel_transforms = transforms.Compose([ 31 | transforms.RandomHorizontalFlip(), 32 | transforms.Resize(sample_size[0]), 33 | transforms.CenterCrop(sample_size), 34 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 35 | ]) 36 | 37 | def get_batch(self, idx): 38 | video_dir = self.dataset[idx] 39 | name = open(video_dir.replace("mp4","txt"),"r").readline().strip() 40 | video_reader = VideoReader(video_dir) 41 | video_length = len(video_reader) 42 | 43 | if not self.is_image: 44 | clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1) 45 | start_idx = random.randint(0, video_length - clip_length) 46 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) 47 | else: 48 | batch_index = [random.randint(0, video_length - 1)] 49 | 50 | pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() 51 | pixel_values = pixel_values / 255. 52 | del video_reader 53 | 54 | if self.is_image: 55 | pixel_values = pixel_values[0] 56 | 57 | return pixel_values, name 58 | 59 | def __len__(self): 60 | return self.length 61 | 62 | def __getitem__(self, idx): 63 | while True: 64 | try: 65 | pixel_values, name = self.get_batch(idx) 66 | break 67 | 68 | except Exception as e: 69 | print(e) 70 | idx = random.randint(0, self.length-1) 71 | 72 | pixel_values = self.pixel_transforms(pixel_values) 73 | sample = dict(pixel_values=pixel_values, text=name) 74 | return sample 75 | -------------------------------------------------------------------------------- /animatelcm_svd/enviroment.yaml: -------------------------------------------------------------------------------- 1 | name: animatelcm_svd 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.12.12=h06a4308_0 8 | - ld_impl_linux-64=2.38=h1181459_1 9 | - libffi=3.4.4=h6a678d5_0 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - ncurses=6.4=h6a678d5_0 14 | - openssl=3.0.12=h7f8727e_0 15 | - pip=23.3.1=py39h06a4308_0 16 | - python=3.9.18=h955ad1f_0 17 | - readline=8.2=h5eee18b_0 18 | - setuptools=68.2.2=py39h06a4308_0 19 | - sqlite=3.41.2=h5eee18b_0 20 | - tk=8.6.12=h1ccaba5_0 21 | - wheel=0.41.2=py39h06a4308_0 22 | - xz=5.4.5=h5eee18b_0 23 | - zlib=1.2.13=h5eee18b_0 24 | - pip: 25 | - accelerate==0.26.1 26 | - albumentations==1.3.1 27 | - antlr4-python3-runtime==4.9.3 28 | - appdirs==1.4.4 29 | - bitsandbytes==0.42.0 30 | - braceexpand==0.1.7 31 | - brotli==1.1.0 32 | - certifi==2023.11.17 33 | - cffi==1.16.0 34 | - charset-normalizer==3.3.2 35 | - click==8.1.7 36 | - dataclasses==0.6 37 | - decord==0.6.0 38 | - diffusers==0.25.1 39 | - docker-pycreds==0.4.0 40 | - docopt==0.6.2 41 | - einops==0.7.0 42 | - exifread-nocycle==3.0.1 43 | - ffmpeg-python==0.2.0 44 | - filelock==3.13.1 45 | - fire==0.5.0 46 | - fsspec==2023.12.2 47 | - future==0.18.3 48 | - gitdb==4.0.11 49 | - gitpython==3.1.41 50 | - huggingface-hub==0.20.3 51 | - idna==3.6 52 | - imageio==2.33.1 53 | - img2dataset==1.45.0 54 | - importlib-metadata==7.0.1 55 | - 
jinja2==3.1.3 56 | - joblib==1.3.2 57 | - langdetect==1.0.9 58 | - lazy-loader==0.3 59 | - markupsafe==2.1.4 60 | - mpmath==1.3.0 61 | - mutagen==1.47.0 62 | - networkx==3.2.1 63 | - numpy==1.26.3 64 | - nvidia-cublas-cu12==12.1.3.1 65 | - nvidia-cuda-cupti-cu12==12.1.105 66 | - nvidia-cuda-nvrtc-cu12==12.1.105 67 | - nvidia-cuda-runtime-cu12==12.1.105 68 | - nvidia-cudnn-cu12==8.9.2.26 69 | - nvidia-cufft-cu12==11.0.2.54 70 | - nvidia-curand-cu12==10.3.2.106 71 | - nvidia-cusolver-cu12==11.4.5.107 72 | - nvidia-cusparse-cu12==12.1.0.106 73 | - nvidia-nccl-cu12==2.19.3 74 | - nvidia-nvjitlink-cu12==12.3.101 75 | - nvidia-nvtx-cu12==12.1.105 76 | - omegaconf==2.3.0 77 | - opencv-python==4.9.0.80 78 | - opencv-python-headless==4.9.0.80 79 | - packaging==23.2 80 | - pandas==2.2.0 81 | - pillow==10.2.0 82 | - platformdirs==4.1.0 83 | - protobuf==4.25.2 84 | - psutil==5.9.8 85 | - pyarrow==15.0.0 86 | - pycparser==2.21 87 | - pycryptodomex==3.20.0 88 | - python-dateutil==2.8.2 89 | - pytz==2023.3.post1 90 | - pyyaml==6.0.1 91 | - qudida==0.0.4 92 | - regex==2023.12.25 93 | - requests==2.31.0 94 | - safetensors==0.4.2 95 | - scenedetect==0.6.2 96 | - scikit-image==0.22.0 97 | - scikit-learn==1.4.0 98 | - scipy==1.12.0 99 | - sentry-sdk==1.39.2 100 | - setproctitle==1.3.3 101 | - six==1.16.0 102 | - smmap==5.0.1 103 | - soundfile==0.12.1 104 | - sympy==1.12 105 | - termcolor==2.4.0 106 | - threadpoolctl==3.2.0 107 | - tifffile==2023.12.9 108 | - timeout-decorator==0.5.0 109 | - tokenizers==0.15.1 110 | - torch==2.2.0 111 | - torchdata==0.7.1 112 | - torchvision==0.17.0 113 | - tqdm==4.66.1 114 | - transformers==4.37.0 115 | - triton==2.2.0 116 | - typing-extensions==4.9.0 117 | - tzdata==2023.4 118 | - urllib3==2.1.0 119 | - wandb==0.16.2 120 | - webdataset==0.2.86 121 | - websockets==12.0 122 | - webvtt-py==0.4.6 123 | - xformers==0.0.24 124 | - yt-dlp==2023.12.30 125 | - zipp==3.17.0 126 | -------------------------------------------------------------------------------- /animatelcm_svd/outputs_gradio/000000.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/outputs_gradio/000000.mp4 -------------------------------------------------------------------------------- /animatelcm_svd/outputs_gradio/000001.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/outputs_gradio/000001.mp4 -------------------------------------------------------------------------------- /animatelcm_svd/outputs_gradio/000002.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/outputs_gradio/000002.mp4 -------------------------------------------------------------------------------- /animatelcm_svd/outputs_gradio/000003.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/outputs_gradio/000003.mp4 -------------------------------------------------------------------------------- /animatelcm_svd/outputs_gradio/000004.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/outputs_gradio/000004.mp4 -------------------------------------------------------------------------------- /animatelcm_svd/outputs_gradio/000005.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/outputs_gradio/000005.mp4 -------------------------------------------------------------------------------- /animatelcm_svd/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | diffusers==0.25.1 3 | gradio==4.19.2 4 | numpy==1.26.4 5 | Pillow==10.2.0 6 | torch==2.2.0 7 | transformers==4.37.0 8 | spaces==0.23.2 9 | opencv-python 10 | xformers -------------------------------------------------------------------------------- /animatelcm_svd/safetensors/AnimateLCM-SVD-xt-1.1.safetensors: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:94b2da1fcca8d03458ef0b07b94b4c55f94117cb0d90265c6c8452239ecc166e 3 | size 6098682464 4 | -------------------------------------------------------------------------------- /animatelcm_svd/safetensors/AnimateLCM-SVD-xt.safetensors: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6ca55e35f29437e8a65e8a1a9ce75262d5bab3d4fe137bdc3f3a94512c54b377 3 | size 6098682464 4 | -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/.DS_Store -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/ai-generated-8411866_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/ai-generated-8411866_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/ai-generated-8463496_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/ai-generated-8463496_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/ai-generated-8476858_1280.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/ai-generated-8476858_1280.png -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/ai-generated-8479572_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/ai-generated-8479572_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/ai-generated-8481641_1280.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/ai-generated-8481641_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/ai-generated-8496135_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/ai-generated-8496135_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/ai-generated-8496952_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/ai-generated-8496952_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/ai-generated-8498844_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/ai-generated-8498844_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/bird-7411270_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/bird-7411270_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/bird-7586857_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/bird-7586857_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/bird-8014191_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/bird-8014191_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/couple-8019370_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/couple-8019370_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/cupcakes-380178_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/cupcakes-380178_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/dog-7330712_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/dog-7330712_1280.jpg -------------------------------------------------------------------------------- 
/animatelcm_svd/test_imgs/dog-7396912_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/dog-7396912_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/girl-4898696_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/girl-4898696_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/grey-capped-flycatcher-8071233_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/grey-capped-flycatcher-8071233_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/halloween-4585684_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/halloween-4585684_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/leaf-7260246_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/leaf-7260246_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/meerkat-7465819_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/meerkat-7465819_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/mobile-phone-1875813_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/mobile-phone-1875813_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/mother-8097324_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/mother-8097324_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/plane-8145957_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/plane-8145957_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/power-station-6579092_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/power-station-6579092_1280.jpg 
-------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/ship-7833921_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/ship-7833921_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/sleep-7871915_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/sleep-7871915_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/squirrel-7985502_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/squirrel-7985502_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/squirrel-8211238_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/squirrel-8211238_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/training-8122941_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/training-8122941_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/violin-8405558_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/violin-8405558_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/weight-8246973_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/weight-8246973_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/woman-4549327_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/woman-4549327_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/woman-4757707_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/woman-4757707_1280.jpg -------------------------------------------------------------------------------- /animatelcm_svd/test_imgs/woman-5667299_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/animatelcm_svd/test_imgs/woman-5667299_1280.jpg 
-------------------------------------------------------------------------------- /metrics/UCF101_prompts.yaml: -------------------------------------------------------------------------------- 1 | # This file contains the brief captions of UCF101 generated by GPT4. Released by [Fu-Yun Wang](https://github.com/G-U-N). 2 | ApplyEyeMakeup: A girl applying eye makeup. 3 | ApplyLipstick: A person applying bright lipstick. 4 | Archery: A person in sportswear aiming an arrow. 5 | BabyCrawling: A baby crawling with toys around. 6 | BalanceBeam: An athlete on a balance beam. 7 | BandMarching: A band marching with instruments. 8 | BaseballPitch: A pitcher throwing a baseball. 9 | Basketball: A player dribbling a basketball. 10 | BasketballDunk: A player performing a basketball dunk. 11 | BenchPress: An athlete doing a bench press. 12 | Biking: A cyclist riding on a trail. 13 | Billiards: A player aiming in billiards. 14 | BlowDryHair: A person using a blow dryer. 15 | BlowingCandles: Someone blowing out birthday candles. 16 | BodyWeightSquats: A person doing squats. 17 | Bowling: A bowler releasing a bowling ball. 18 | BoxingPunchingBag: A boxer punching a heavy bag. 19 | BoxingSpeedBag: A boxer hitting a speed bag. 20 | BreastStroke: A swimmer doing breaststroke. 21 | BrushingTeeth: A person brushing their teeth. 22 | CleanAndJerk: An athlete lifting a barbell overhead. 23 | CliffDiving: A diver leaping off a cliff. 24 | CricketBowling: A cricket bowler in delivery stride. 25 | CricketShot: A batsman playing a cricket shot. 26 | CuttingInKitchen: A person cutting ingredients. 27 | Diving: A diver performing from a board. 28 | Drumming: A drummer playing a drum set. 29 | Fencing: Two fencers dueling. 30 | FieldHockeyPenalty: A player taking a hockey penalty shot. 31 | FloorGymnastics: A gymnast in a floor routine. 32 | FrisbeeCatch: A person catching a frisbee. 33 | FrontCrawl: A swimmer doing front crawl. 34 | GolfSwing: A golfer mid-swing. 35 | Haircut: A person getting a haircut. 36 | HammerThrow: An athlete throwing a hammer. 37 | Hammering: A person hammering a nail. 38 | HandstandPushups: An athlete doing handstand pushups. 39 | HandstandWalking: A person walking on hands. 40 | HeadMassage: A person receiving a head massage. 41 | HighJump: An athlete doing a high jump. 42 | HorseRace: Jockeys racing on horses. 43 | HorseRiding: A person riding a horse. 44 | HulaHoop: A person using a hula hoop. 45 | IceDancing: Ice dancers performing on a rink. 46 | JavelinThrow: An athlete throwing a javelin. 47 | JugglingBalls: A person juggling balls. 48 | JumpingJack: A person doing jumping jacks. 49 | JumpRope: An athlete skipping rope. 50 | Kayaking: A kayaker navigating rapids. 51 | Knitting: Hands knitting wool. 52 | LongJump: An athlete doing a long jump. 53 | Lunges: A person performing lunges. 54 | MilitaryParade: A military parade. 55 | Mixing: A chef mixing ingredients. 56 | MoppingFloor: Someone mopping a floor. 57 | Nunchucks: A martial artist using nunchucks. 58 | ParallelBars: An athlete on parallel bars. 59 | PizzaTossing: A chef tossing pizza dough. 60 | PlayingCello: A musician playing the cello. 61 | PlayingDaf: A musician playing a Daf. 62 | PlayingDhol: A person playing a Dhol. 63 | PlayingFlute: A musician playing the flute. 64 | PlayingGuitar: A guitarist strumming chords. 65 | PlayingPiano: Hands on piano keys. 66 | PlayingSitar: A musician playing the Sitar. 67 | PlayingTabla: A musician playing the Tabla. 68 | PlayingViolin: A violinist playing. 
69 | PoleVault: An athlete doing pole vault. 70 | PommelHorse: A gymnast on the pommel horse. 71 | PullUps: A person doing pull-ups. 72 | Punch: A boxer throwing a punch. 73 | PushUps: A person doing push-ups. 74 | Rafting: A team rafting in rapids. 75 | RockClimbingIndoor: A climber reaching for a hold. 76 | RopeClimbing: An athlete climbing a rope. 77 | Rowing: A team rowing in unison. 78 | SalsaSpin: A couple doing a salsa spin. 79 | ShavingBeard: A person shaving their beard. 80 | Shotput: An athlete throwing shot put. 81 | SkateBoarding: A skateboarder doing a trick. 82 | Skiing: A skier descending a slope. 83 | Skijet: A person on a jet ski. 84 | SkyDiving: A skydiver in free fall. 85 | SoccerJuggling: A player juggling a soccer ball. 86 | SoccerPenalty: A player taking a soccer penalty. 87 | StillRings: A gymnast on still rings. 88 | SumoWrestling: Sumo wrestlers in a match. 89 | Surfing: A surfer on a wave. 90 | Swing: A person on a swing. 91 | TableTennisShot: A player making a table tennis shot. 92 | TaiChi: A person practicing Tai Chi. 93 | TennisSwing: A tennis player mid-swing. 94 | ThrowDiscus: An athlete throwing a discus. 95 | TrampolineJumping: A person on a trampoline. 96 | Typing: Hands typing on a keyboard. 97 | UnevenBars: A gymnast on uneven bars. 98 | VolleyballSpiking: A player spiking a volleyball. 99 | WalkingWithDog: A person walking with a dog. 100 | WallPushups: A person doing wall pushups. 101 | WritingOnBoard: A person writing on a board. 102 | YoYo: A person performing yo-yo tricks. 103 | -------------------------------------------------------------------------------- /metrics/clip_score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import CLIPModel, CLIPProcessor 3 | def load_clip(): 4 | clip_model = CLIPModel.from_pretrained("CLIP-ViT-H-14-laion2B-s32B-b79K").to("cuda" if torch.cuda.is_available() else "cpu") 5 | clip_processor = CLIPProcessor.from_pretrained("CLIP-ViT-H-14-laion2B-s32B-b79K") 6 | return clip_model, clip_processor 7 | 8 | def get_clip_score(image_pil, text, clip_model, clip_processor): 9 | inputs = clip_processor(text=text, images=image_pil, return_tensors="pt", padding=True) 10 | device = clip_model.device 11 | inputs = {key: value.to(device) for key, value in inputs.items()} 12 | outputs = clip_model(**inputs) 13 | logits_per_image = outputs.logits_per_image 14 | return logits_per_image 15 | -------------------------------------------------------------------------------- /metrics/fvd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import torch.utils.data as data 7 | from tqdm import tqdm 8 | 9 | from pytorch_i3d import InceptionI3d 10 | import os 11 | 12 | from sklearn.metrics.pairwise import polynomial_kernel 13 | 14 | MAX_BATCH = 8 15 | FVD_SAMPLE_SIZE = 2048 16 | TARGET_RESOLUTION = (224, 224) 17 | 18 | def preprocess(videos, target_resolution): 19 | # videos in {0, ..., 255} as np.uint8 array 20 | b, t, h, w, c = videos.shape 21 | all_frames = torch.FloatTensor(videos).flatten(end_dim=1) # (b * t, h, w, c) 22 | all_frames = all_frames.permute(0, 3, 1, 2).contiguous() # (b * t, c, h, w) 23 | resized_videos = F.interpolate(all_frames, size=target_resolution, 24 | mode='bilinear', align_corners=False) 25 | resized_videos = resized_videos.view(b, t, c, *target_resolution) 26 | output_videos = resized_videos.transpose(1, 2).contiguous() # (b, c, t, *) 27 |
scaled_videos = 2. * output_videos / 255. - 1 # [-1, 1] 28 | return scaled_videos 29 | 30 | def get_fvd_logits(videos, i3d, device): 31 | videos = preprocess(videos, TARGET_RESOLUTION) 32 | embeddings = get_logits(i3d, videos, device) 33 | return embeddings 34 | 35 | def load_fvd_model(device): 36 | i3d = InceptionI3d(400, in_channels=3).to(device) 37 | current_dir = os.path.dirname(os.path.abspath(__file__)) 38 | i3d_path = os.path.join(current_dir, 'i3d_pretrained_400.pt') 39 | i3d.load_state_dict(torch.load(i3d_path, map_location=device)) 40 | i3d.eval() 41 | return i3d 42 | 43 | 44 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161 45 | def _symmetric_matrix_square_root(mat, eps=1e-10): 46 | u, s, v = torch.svd(mat) 47 | si = torch.where(s < eps, s, torch.sqrt(s)) 48 | return torch.matmul(torch.matmul(u, torch.diag(si)), v.t()) 49 | 50 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400 51 | def trace_sqrt_product(sigma, sigma_v): 52 | sqrt_sigma = _symmetric_matrix_square_root(sigma) 53 | sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma)) 54 | return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) 55 | 56 | # https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2 57 | def cov(m, rowvar=False): 58 | '''Estimate a covariance matrix given data. 59 | 60 | Covariance indicates the level to which two variables vary together. 61 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, 62 | then the covariance matrix element `C_{ij}` is the covariance of 63 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. 64 | 65 | Args: 66 | m: A 1-D or 2-D array containing multiple variables and observations. 67 | Each row of `m` represents a variable, and each column a single 68 | observation of all those variables. 69 | rowvar: If `rowvar` is True, then each row represents a 70 | variable, with observations in the columns. Otherwise, the 71 | relationship is transposed: each column represents a variable, 72 | while the rows contain observations. 73 | 74 | Returns: 75 | The covariance matrix of the variables. 
76 | ''' 77 | if m.dim() > 2: 78 | raise ValueError('m has more than 2 dimensions') 79 | if m.dim() < 2: 80 | m = m.view(1, -1) 81 | if not rowvar and m.size(0) != 1: 82 | m = m.t() 83 | 84 | fact = 1.0 / (m.size(1) - 1) # unbiased estimate 85 | m_center = m - torch.mean(m, dim=1, keepdim=True) 86 | mt = m_center.t() # if complex: mt = m.t().conj() 87 | return fact * m_center.matmul(mt).squeeze() 88 | 89 | 90 | def frechet_distance(x1, x2): 91 | x1 = x1.flatten(start_dim=1) 92 | x2 = x2.flatten(start_dim=1) 93 | m, m_w = x1.mean(dim=0), x2.mean(dim=0) 94 | sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False) 95 | 96 | sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) 97 | trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component 98 | 99 | mean = torch.sum((m - m_w) ** 2) 100 | fd = trace + mean 101 | return fd 102 | 103 | 104 | def polynomial_mmd(X, Y): 105 | m = X.shape[0] 106 | n = Y.shape[0] 107 | # compute kernels 108 | K_XX = polynomial_kernel(X) 109 | K_YY = polynomial_kernel(Y) 110 | K_XY = polynomial_kernel(X, Y) 111 | # compute mmd distance 112 | K_XX_sum = (K_XX.sum() - np.diagonal(K_XX).sum()) / (m * (m - 1)) 113 | K_YY_sum = (K_YY.sum() - np.diagonal(K_YY).sum()) / (n * (n - 1)) 114 | K_XY_sum = K_XY.sum() / (m * n) 115 | mmd = K_XX_sum + K_YY_sum - 2 * K_XY_sum 116 | return mmd 117 | 118 | 119 | 120 | def get_logits(i3d, videos, device): 121 | MAX_BATCH_INPUT = min(MAX_BATCH, videos.shape[0]) 122 | # assert videos.shape[0] % MAX_BATCH_INPUT == 0 123 | with torch.no_grad(): 124 | logits = [] 125 | for i in tqdm(range(0, videos.shape[0], MAX_BATCH_INPUT)): 126 | batch = videos[i:i + MAX_BATCH_INPUT].to(device) 127 | logits.append(i3d(batch)) 128 | logits = torch.cat(logits, dim=0) 129 | return logits 130 | 131 | 132 | def compute_logits(samples, i3d, device=torch.device("cpu")): 133 | samples = preprocess(samples, (224, 224)) 134 | logits = get_logits(i3d, samples, device) 135 | 136 | return logits 137 | 138 | 139 | def compute_fvd(real, samples, i3d, device=torch.device('cpu')): 140 | # real, samples are (N, T, H, W, C) numpy arrays in np.uint8 141 | real, samples = preprocess(real, (224, 224)), preprocess(samples, (224, 224)) 142 | first_embed = get_logits(i3d, real, device) 143 | second_embed = get_logits(i3d, samples, device) 144 | 145 | return frechet_distance(first_embed, second_embed) 146 | 147 | 148 | -------------------------------------------------------------------------------- /metrics/i3d_pretrained_400.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/G-U-N/AnimateLCM/9a5a314f4294340c76f4b49638d78f98d1c6e763/metrics/i3d_pretrained_400.pt -------------------------------------------------------------------------------- /metrics/pytorch_i3d.py: -------------------------------------------------------------------------------- 1 | # https://github.com/piergiaj/pytorch-i3d 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | import numpy as np 8 | 9 | import os 10 | import sys 11 | from collections import OrderedDict 12 | 13 | 14 | class MaxPool3dSamePadding(nn.MaxPool3d): 15 | 16 | def compute_pad(self, dim, s): 17 | if s % self.stride[dim] == 0: 18 | return max(self.kernel_size[dim] - self.stride[dim], 0) 19 | else: 20 | return max(self.kernel_size[dim] - (s % self.stride[dim]), 0) 21 | 22 | def forward(self, x): 23 | # compute 'same' padding 24 | (batch, channel, t, h, w) = x.size() 25 | #print t,h,w 26 
| out_t = np.ceil(float(t) / float(self.stride[0])) 27 | out_h = np.ceil(float(h) / float(self.stride[1])) 28 | out_w = np.ceil(float(w) / float(self.stride[2])) 29 | #print out_t, out_h, out_w 30 | pad_t = self.compute_pad(0, t) 31 | pad_h = self.compute_pad(1, h) 32 | pad_w = self.compute_pad(2, w) 33 | #print pad_t, pad_h, pad_w 34 | 35 | pad_t_f = pad_t // 2 36 | pad_t_b = pad_t - pad_t_f 37 | pad_h_f = pad_h // 2 38 | pad_h_b = pad_h - pad_h_f 39 | pad_w_f = pad_w // 2 40 | pad_w_b = pad_w - pad_w_f 41 | 42 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) 43 | #print x.size() 44 | #print pad 45 | x = F.pad(x, pad) 46 | return super(MaxPool3dSamePadding, self).forward(x) 47 | 48 | 49 | class Unit3D(nn.Module): 50 | 51 | def __init__(self, in_channels, 52 | output_channels, 53 | kernel_shape=(1, 1, 1), 54 | stride=(1, 1, 1), 55 | padding=0, 56 | activation_fn=F.relu, 57 | use_batch_norm=True, 58 | use_bias=False, 59 | name='unit_3d'): 60 | 61 | """Initializes Unit3D module.""" 62 | super(Unit3D, self).__init__() 63 | 64 | self._output_channels = output_channels 65 | self._kernel_shape = kernel_shape 66 | self._stride = stride 67 | self._use_batch_norm = use_batch_norm 68 | self._activation_fn = activation_fn 69 | self._use_bias = use_bias 70 | self.name = name 71 | self.padding = padding 72 | 73 | self.conv3d = nn.Conv3d(in_channels=in_channels, 74 | out_channels=self._output_channels, 75 | kernel_size=self._kernel_shape, 76 | stride=self._stride, 77 | padding=0, # we always want padding to be 0 here. We will dynamically pad based on input size in forward function 78 | bias=self._use_bias) 79 | 80 | if self._use_batch_norm: 81 | self.bn = nn.BatchNorm3d(self._output_channels, eps=1e-5, momentum=0.001) 82 | 83 | def compute_pad(self, dim, s): 84 | if s % self._stride[dim] == 0: 85 | return max(self._kernel_shape[dim] - self._stride[dim], 0) 86 | else: 87 | return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0) 88 | 89 | 90 | def forward(self, x): 91 | # compute 'same' padding 92 | (batch, channel, t, h, w) = x.size() 93 | #print t,h,w 94 | out_t = np.ceil(float(t) / float(self._stride[0])) 95 | out_h = np.ceil(float(h) / float(self._stride[1])) 96 | out_w = np.ceil(float(w) / float(self._stride[2])) 97 | #print out_t, out_h, out_w 98 | pad_t = self.compute_pad(0, t) 99 | pad_h = self.compute_pad(1, h) 100 | pad_w = self.compute_pad(2, w) 101 | #print pad_t, pad_h, pad_w 102 | 103 | pad_t_f = pad_t // 2 104 | pad_t_b = pad_t - pad_t_f 105 | pad_h_f = pad_h // 2 106 | pad_h_b = pad_h - pad_h_f 107 | pad_w_f = pad_w // 2 108 | pad_w_b = pad_w - pad_w_f 109 | 110 | pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) 111 | #print x.size() 112 | #print pad 113 | x = F.pad(x, pad) 114 | #print x.size() 115 | 116 | x = self.conv3d(x) 117 | if self._use_batch_norm: 118 | x = self.bn(x) 119 | if self._activation_fn is not None: 120 | x = self._activation_fn(x) 121 | return x 122 | 123 | 124 | 125 | class InceptionModule(nn.Module): 126 | def __init__(self, in_channels, out_channels, name): 127 | super(InceptionModule, self).__init__() 128 | 129 | self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0, 130 | name=name+'/Branch_0/Conv3d_0a_1x1') 131 | self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0, 132 | name=name+'/Branch_1/Conv3d_0a_1x1') 133 | self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3], 134 
| name=name+'/Branch_1/Conv3d_0b_3x3') 135 | self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0, 136 | name=name+'/Branch_2/Conv3d_0a_1x1') 137 | self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3], 138 | name=name+'/Branch_2/Conv3d_0b_3x3') 139 | self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3], 140 | stride=(1, 1, 1), padding=0) 141 | self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0, 142 | name=name+'/Branch_3/Conv3d_0b_1x1') 143 | self.name = name 144 | 145 | def forward(self, x): 146 | b0 = self.b0(x) 147 | b1 = self.b1b(self.b1a(x)) 148 | b2 = self.b2b(self.b2a(x)) 149 | b3 = self.b3b(self.b3a(x)) 150 | return torch.cat([b0,b1,b2,b3], dim=1) 151 | 152 | 153 | class InceptionI3d(nn.Module): 154 | """Inception-v1 I3D architecture. 155 | The model is introduced in: 156 | Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset 157 | Joao Carreira, Andrew Zisserman 158 | https://arxiv.org/pdf/1705.07750v1.pdf. 159 | See also the Inception architecture, introduced in: 160 | Going deeper with convolutions 161 | Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, 162 | Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. 163 | http://arxiv.org/pdf/1409.4842v1.pdf. 164 | """ 165 | 166 | # Endpoints of the model in order. During construction, all the endpoints up 167 | # to a designated `final_endpoint` are returned in a dictionary as the 168 | # second return value. 169 | VALID_ENDPOINTS = ( 170 | 'Conv3d_1a_7x7', 171 | 'MaxPool3d_2a_3x3', 172 | 'Conv3d_2b_1x1', 173 | 'Conv3d_2c_3x3', 174 | 'MaxPool3d_3a_3x3', 175 | 'Mixed_3b', 176 | 'Mixed_3c', 177 | 'MaxPool3d_4a_3x3', 178 | 'Mixed_4b', 179 | 'Mixed_4c', 180 | 'Mixed_4d', 181 | 'Mixed_4e', 182 | 'Mixed_4f', 183 | 'MaxPool3d_5a_2x2', 184 | 'Mixed_5b', 185 | 'Mixed_5c', 186 | 'Logits', 187 | 'Predictions', 188 | ) 189 | 190 | FEAT_ENDPOINTS = ( 191 | 'Conv3d_1a_7x7', 192 | 'Conv3d_2c_3x3', 193 | 'Mixed_3c', 194 | 'Mixed_4f', 195 | 'Mixed_5c', 196 | ) 197 | def __init__(self, 198 | num_classes=400, 199 | spatial_squeeze=True, 200 | final_endpoint='Logits', 201 | name='inception_i3d', 202 | in_channels=3, 203 | dropout_keep_prob=0.5, 204 | is_coinrun=False, 205 | ): 206 | """Initializes I3D model instance. 207 | Args: 208 | num_classes: The number of outputs in the logit layer (default 400, which 209 | matches the Kinetics dataset). 210 | spatial_squeeze: Whether to squeeze the spatial dimensions for the logits 211 | before returning (default True). 212 | final_endpoint: The model contains many possible endpoints. 213 | `final_endpoint` specifies the last endpoint for the model to be built 214 | up to. In addition to the output at `final_endpoint`, all the outputs 215 | at endpoints up to `final_endpoint` will also be returned, in a 216 | dictionary. `final_endpoint` must be one of 217 | InceptionI3d.VALID_ENDPOINTS (default 'Logits'). 218 | name: A string (optional). The name of this module. 219 | Raises: 220 | ValueError: if `final_endpoint` is not recognized. 
221 | """ 222 | 223 | if final_endpoint not in self.VALID_ENDPOINTS: 224 | raise ValueError('Unknown final endpoint %s' % final_endpoint) 225 | 226 | super(InceptionI3d, self).__init__() 227 | self._num_classes = num_classes 228 | self._spatial_squeeze = spatial_squeeze 229 | self._final_endpoint = final_endpoint 230 | self.logits = None 231 | self.is_coinrun = is_coinrun 232 | 233 | if self._final_endpoint not in self.VALID_ENDPOINTS: 234 | raise ValueError('Unknown final endpoint %s' % self._final_endpoint) 235 | 236 | self.end_points = {} 237 | end_point = 'Conv3d_1a_7x7' 238 | self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7], 239 | stride=(1 if is_coinrun else 2, 2, 2), padding=(3,3,3), name=name+end_point) 240 | if self._final_endpoint == end_point: return 241 | 242 | end_point = 'MaxPool3d_2a_3x3' 243 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), 244 | padding=0) 245 | if self._final_endpoint == end_point: return 246 | 247 | end_point = 'Conv3d_2b_1x1' 248 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0, 249 | name=name+end_point) 250 | if self._final_endpoint == end_point: return 251 | 252 | end_point = 'Conv3d_2c_3x3' 253 | self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1, 254 | name=name+end_point) 255 | if self._final_endpoint == end_point: return 256 | 257 | end_point = 'MaxPool3d_3a_3x3' 258 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2), 259 | padding=0) 260 | if self._final_endpoint == end_point: return 261 | 262 | end_point = 'Mixed_3b' 263 | self.end_points[end_point] = InceptionModule(192, [64,96,128,16,32,32], name+end_point) 264 | if self._final_endpoint == end_point: return 265 | 266 | end_point = 'Mixed_3c' 267 | self.end_points[end_point] = InceptionModule(256, [128,128,192,32,96,64], name+end_point) 268 | if self._final_endpoint == end_point: return 269 | 270 | end_point = 'MaxPool3d_4a_3x3' 271 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1 if is_coinrun else 3, 3, 3], stride=(1 if is_coinrun else 2, 2, 2), 272 | padding=0) 273 | if self._final_endpoint == end_point: return 274 | 275 | end_point = 'Mixed_4b' 276 | self.end_points[end_point] = InceptionModule(128+192+96+64, [192,96,208,16,48,64], name+end_point) 277 | if self._final_endpoint == end_point: return 278 | 279 | end_point = 'Mixed_4c' 280 | self.end_points[end_point] = InceptionModule(192+208+48+64, [160,112,224,24,64,64], name+end_point) 281 | if self._final_endpoint == end_point: return 282 | 283 | end_point = 'Mixed_4d' 284 | self.end_points[end_point] = InceptionModule(160+224+64+64, [128,128,256,24,64,64], name+end_point) 285 | if self._final_endpoint == end_point: return 286 | 287 | end_point = 'Mixed_4e' 288 | self.end_points[end_point] = InceptionModule(128+256+64+64, [112,144,288,32,64,64], name+end_point) 289 | if self._final_endpoint == end_point: return 290 | 291 | end_point = 'Mixed_4f' 292 | self.end_points[end_point] = InceptionModule(112+288+64+64, [256,160,320,32,128,128], name+end_point) 293 | if self._final_endpoint == end_point: return 294 | 295 | end_point = 'MaxPool3d_5a_2x2' 296 | self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(1 if is_coinrun else 2, 2, 2), 297 | padding=0) 298 | if self._final_endpoint == end_point: return 299 | 300 | end_point = 'Mixed_5b' 301 | 
self.end_points[end_point] = InceptionModule(256+320+128+128, [256,160,320,32,128,128], name+end_point) 302 | if self._final_endpoint == end_point: return 303 | 304 | end_point = 'Mixed_5c' 305 | self.end_points[end_point] = InceptionModule(256+320+128+128, [384,192,384,48,128,128], name+end_point) 306 | if self._final_endpoint == end_point: return 307 | 308 | end_point = 'Logits' 309 | self.avg_pool = nn.AvgPool3d(kernel_size=[1, 8, 8] if is_coinrun else [2, 7, 7], 310 | stride=(1, 1, 1)) 311 | self.dropout = nn.Dropout(dropout_keep_prob) 312 | self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes, 313 | kernel_shape=[1, 1, 1], 314 | padding=0, 315 | activation_fn=None, 316 | use_batch_norm=False, 317 | use_bias=True, 318 | name='logits') 319 | 320 | self.build() 321 | 322 | 323 | def replace_logits(self, num_classes): 324 | self._num_classes = num_classes 325 | self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes, 326 | kernel_shape=[1, 1, 1], 327 | padding=0, 328 | activation_fn=None, 329 | use_batch_norm=False, 330 | use_bias=True, 331 | name='logits') 332 | 333 | 334 | def build(self): 335 | for k in self.end_points.keys(): 336 | self.add_module(k, self.end_points[k]) 337 | 338 | def forward(self, x): 339 | for end_point in self.VALID_ENDPOINTS: 340 | if end_point in self.end_points: 341 | x = self._modules[end_point](x) # use _modules to work with dataparallel 342 | 343 | x = self.logits(self.dropout(self.avg_pool(x))) 344 | if self._spatial_squeeze: 345 | logits = x.squeeze(3).squeeze(3) 346 | logits = logits.mean(dim=2) 347 | # logits is batch X time X classes, which is what we want to work with 348 | return logits 349 | 350 | 351 | def extract_features(self, x): 352 | for end_point in self.VALID_ENDPOINTS: 353 | if end_point in self.end_points: 354 | x = self._modules[end_point](x) 355 | return self.avg_pool(x) 356 | 357 | 358 | def extract_pre_pool_features(self, x): 359 | for end_point in self.VALID_ENDPOINTS: 360 | if end_point in self.end_points: 361 | x = self._modules[end_point](x) 362 | return x 363 | 364 | 365 | def extract_features_multiscale(self, x): 366 | xs = [] 367 | for end_point in self.VALID_ENDPOINTS: 368 | if end_point in self.end_points: 369 | x = self._modules[end_point](x) 370 | if end_point in self.FEAT_ENDPOINTS: 371 | xs.append(x) 372 | return xs 373 | 374 | --------------------------------------------------------------------------------
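For reference, here is one way the FVD utilities in metrics/fvd.py might be driven end to end. This is only a minimal sketch, not part of the repository: it assumes the script is launched from inside metrics/ (so that pytorch_i3d and the i3d_pretrained_400.pt checkpoint resolve), and the two (N, T, H, W, C) uint8 arrays real_videos and fake_videos are random placeholders standing in for decoded real and generated clips.

import numpy as np
import torch

from fvd import load_fvd_model, compute_fvd

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
i3d = load_fvd_model(device)  # InceptionI3d(400) initialized from i3d_pretrained_400.pt

# Placeholder clips: 16 videos of 16 frames at 256x256; swap in real decoded frames.
real_videos = np.random.randint(0, 256, (16, 16, 256, 256, 3), dtype=np.uint8)
fake_videos = np.random.randint(0, 256, (16, 16, 256, 256, 3), dtype=np.uint8)

# compute_fvd resizes to 224x224, extracts I3D logits, and returns the Frechet distance.
score = compute_fvd(real_videos, fake_videos, i3d, device=device)
print(f"FVD: {score.item():.2f}")

In practice the number of clips should be much larger than the 400-dimensional I3D feature for the covariance estimate inside frechet_distance to be meaningful; 16 clips are only enough to exercise the code path.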
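Similarly, a sketch of how metrics/clip_score.py and metrics/UCF101_prompts.yaml could be combined to score a generated frame against its class caption. Again this is illustrative only: it assumes execution from metrics/, that the checkpoint name used in load_clip resolves locally (the Hugging Face id laion/CLIP-ViT-H-14-laion2B-s32B-b79K is the likely source otherwise), and the random image below is a placeholder for a frame decoded from a generated video.

import numpy as np
import yaml
from PIL import Image

from clip_score import load_clip, get_clip_score

clip_model, clip_processor = load_clip()

# UCF101_prompts.yaml is a flat mapping from class name to a short caption.
with open("UCF101_prompts.yaml", "r") as f:
    prompts = yaml.safe_load(f)

# Placeholder frame; in practice, decode a frame from the clip generated for this class.
frame = Image.fromarray(np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8))

score = get_clip_score(frame, prompts["Archery"], clip_model, clip_processor)
print(f"CLIP score for 'Archery': {score.item():.2f}")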