├── .gitignore ├── LICENSE.txt ├── README.md ├── __assets__ ├── animations │ ├── compare │ │ ├── ffmpeg │ │ ├── new_0.gif │ │ ├── new_1.gif │ │ ├── new_2.gif │ │ ├── new_3.gif │ │ ├── old_0.gif │ │ ├── old_1.gif │ │ ├── old_2.gif │ │ └── old_3.gif │ ├── model_01 │ │ ├── 01.gif │ │ ├── 02.gif │ │ ├── 03.gif │ │ └── 04.gif │ ├── model_02 │ │ ├── 01.gif │ │ ├── 02.gif │ │ ├── 03.gif │ │ └── 04.gif │ ├── model_03 │ │ ├── 01.gif │ │ ├── 02.gif │ │ ├── 03.gif │ │ └── 04.gif │ ├── model_04 │ │ ├── 01.gif │ │ ├── 02.gif │ │ ├── 03.gif │ │ └── 04.gif │ ├── model_05 │ │ ├── 01.gif │ │ ├── 02.gif │ │ ├── 03.gif │ │ └── 04.gif │ ├── model_06 │ │ ├── 01.gif │ │ ├── 02.gif │ │ ├── 03.gif │ │ └── 04.gif │ ├── model_07 │ │ ├── 01.gif │ │ ├── 02.gif │ │ ├── 03.gif │ │ ├── 04.gif │ │ └── init.jpg │ ├── model_08 │ │ ├── 01.gif │ │ ├── 02.gif │ │ ├── 03.gif │ │ └── 04.gif │ ├── motion_lora │ │ ├── model_01 │ │ │ ├── 01.gif │ │ │ ├── 02.gif │ │ │ ├── 03.gif │ │ │ ├── 04.gif │ │ │ ├── 05.gif │ │ │ ├── 06.gif │ │ │ ├── 07.gif │ │ │ └── 08.gif │ │ └── model_02 │ │ │ ├── 01.gif │ │ │ ├── 02.gif │ │ │ ├── 03.gif │ │ │ ├── 04.gif │ │ │ ├── 05.gif │ │ │ ├── 06.gif │ │ │ ├── 07.gif │ │ │ └── 08.gif │ ├── motion_xl │ │ ├── 01.gif │ │ ├── 02.gif │ │ └── 03.gif │ └── v3 │ │ ├── animation_fireworks.gif │ │ ├── animation_sunset.gif │ │ ├── sketch_boy.gif │ │ └── sketch_city.gif ├── demos │ ├── image │ │ ├── RealisticVision_firework.png │ │ ├── RealisticVision_sunset.png │ │ ├── interpolation_1.png │ │ ├── interpolation_2.png │ │ ├── low_fps_1.png │ │ ├── low_fps_2.png │ │ ├── low_fps_3.png │ │ ├── low_fps_4.png │ │ ├── painting.png │ │ ├── prediction_1.png │ │ ├── prediction_2.png │ │ ├── prediction_3.png │ │ └── prediction_4.png │ └── scribble │ │ ├── scribble_1.png │ │ ├── scribble_2_1.png │ │ ├── scribble_2_2.png │ │ ├── scribble_2_3.png │ │ └── scribble_2_readme.png ├── docs │ ├── animatediff.md │ └── gallery.md └── figs │ ├── adapter_explain.png │ └── gradio.jpg ├── animatediff ├── data │ └── dataset.py ├── models │ ├── attention.py │ ├── motion_module.py │ ├── resnet.py │ ├── sparse_controlnet.py │ ├── unet.py │ └── unet_blocks.py ├── pipelines │ └── pipeline_animation.py └── utils │ ├── convert_from_ckpt.py │ ├── convert_lora_safetensor_to_diffusers.py │ └── util.py ├── app.py ├── configs ├── inference │ ├── inference-v1.yaml │ ├── inference-v2.yaml │ ├── inference-v3.yaml │ └── sparsectrl │ │ ├── image_condition.yaml │ │ └── latent_condition.yaml ├── prompts │ ├── 1_animate │ │ ├── 1_1_animate_RealisticVision.yaml │ │ ├── 1_2_animate_FilmVelvia.yaml │ │ ├── 1_3_animate_ToonYou.yaml │ │ ├── 1_4_animate_MajicMix.yaml │ │ ├── 1_5_animate_RcnzCartoon.yaml │ │ ├── 1_6_animate_Lyriel.yaml │ │ └── 1_7_animate_Tusun.yaml │ ├── 2_motionlora │ │ └── 2_motionlora_RealisticVision.yaml │ └── 3_sparsectrl │ │ ├── 3_1_sparsectrl_i2v.yaml │ │ ├── 3_2_sparsectrl_rgb_RealisticVision.yaml │ │ └── 3_3_sparsectrl_sketch_RealisticVision.yaml └── training │ └── v1 │ ├── image_finetune.yaml │ └── training.yaml ├── models ├── DreamBooth_LoRA │ └── Put personalized T2I checkpoints here.txt ├── MotionLoRA │ └── Put MotionLoRA checkpoints here.txt ├── Motion_Module │ └── Put motion module checkpoints here.txt └── StableDiffusion │ └── Put diffusers stable-diffusion-v1-5 repo here.txt ├── requirements.txt ├── scripts └── animate.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | *debug* 3 | debugs/ 4 | outputs/ 5 | samples/ 6 | __pycache__/ 7 | ossutil_output/ 8 | 
.ossutil_checkpoint/ 9 | 10 | scripts/* 11 | !scripts/animate.py 12 | 13 | *.ipynb 14 | *.safetensors 15 | *.ckpt 16 | 17 | models/* 18 | !models/StableDiffusion/ 19 | models/StableDiffusion/* 20 | !models/StableDiffusion/*.txt 21 | !models/Motion_Module/ 22 | !models/Motion_Module/*.txt 23 | !models/DreamBooth_LoRA/ 24 | !models/DreamBooth_LoRA/*.txt 25 | !models/MotionLoRA/ 26 | !models/MotionLoRA/*.txt 27 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AnimateDiff 2 | 3 | This repository is the official implementation of [AnimateDiff](https://arxiv.org/abs/2307.04725) [ICLR2024 Spotlight]. 4 | It is a plug-and-play module turning most community text-to-image models into animation generators, without the need of additional training. 
5 | 6 | **[AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)** 7 |
8 | [Yuwei Guo](https://guoyww.github.io/), 9 | [Ceyuan Yang✝](https://ceyuan.me/), 10 | [Anyi Rao](https://anyirao.com/), 11 | [Zhengyang Liang](https://maxleung99.github.io/), 12 | [Yaohui Wang](https://wyhsirius.github.io/), 13 | [Yu Qiao](https://scholar.google.com.hk/citations?user=gFtI-8QAAAAJ), 14 | [Maneesh Agrawala](https://graphics.stanford.edu/~maneesh/), 15 | [Dahua Lin](http://dahua.site), 16 | [Bo Dai](https://daibo.info) 17 | (✝Corresponding Author) 18 | [![arXiv](https://img.shields.io/badge/arXiv-2307.04725-b31b1b.svg)](https://arxiv.org/abs/2307.04725) 19 | [![Project Page](https://img.shields.io/badge/Project-Website-green)](https://animatediff.github.io/) 20 | [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/Masbfca/AnimateDiff) 21 | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow)](https://huggingface.co/spaces/guoyww/AnimateDiff) 22 | 23 | ***Note:*** The `main` branch is for [Stable Diffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5); for [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), please refer to the `sdxl-beta` branch. 24 | 25 | 26 | ## Quick Demos 27 | More results can be found in the [Gallery](__assets__/docs/gallery.md). 28 | Some of them are contributed by the community. 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 |
38 |

Model:ToonYou

39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 |
48 |

Model:Realistic Vision V2.0

49 | 50 | 51 | ## Quick Start 52 | ***Note:*** AnimateDiff is also officially supported by Diffusers. 53 | Visit the [AnimateDiff Diffusers Tutorial](https://huggingface.co/docs/diffusers/api/pipelines/animatediff) for more details. 54 | *The following instructions are for working with this repository*. 55 | 56 | ***Note:*** For all scripts, checkpoint downloading is handled *automatically*, so a script may take longer to run when first executed. 57 | 58 | ### 1. Set up the repository and environment 59 | 60 | ``` 61 | git clone https://github.com/guoyww/AnimateDiff.git 62 | cd AnimateDiff 63 | 64 | pip install -r requirements.txt 65 | ``` 66 | 67 | ### 2. Launch the sampling script! 68 | The generated samples can be found in the `samples/` folder. 69 | 70 | #### 2.1 Generate animations with community models 71 | ``` 72 | python -m scripts.animate --config configs/prompts/1_animate/1_1_animate_RealisticVision.yaml 73 | python -m scripts.animate --config configs/prompts/1_animate/1_2_animate_FilmVelvia.yaml 74 | python -m scripts.animate --config configs/prompts/1_animate/1_3_animate_ToonYou.yaml 75 | python -m scripts.animate --config configs/prompts/1_animate/1_4_animate_MajicMix.yaml 76 | python -m scripts.animate --config configs/prompts/1_animate/1_5_animate_RcnzCartoon.yaml 77 | python -m scripts.animate --config configs/prompts/1_animate/1_6_animate_Lyriel.yaml 78 | python -m scripts.animate --config configs/prompts/1_animate/1_7_animate_Tusun.yaml 79 | ``` 80 | 81 | #### 2.2 Generate animations with MotionLoRA control 82 | ``` 83 | python -m scripts.animate --config configs/prompts/2_motionlora/2_motionlora_RealisticVision.yaml 84 | ``` 85 | 86 | #### 2.3 More control with SparseCtrl RGB and sketch 87 | ``` 88 | python -m scripts.animate --config configs/prompts/3_sparsectrl/3_1_sparsectrl_i2v.yaml 89 | python -m scripts.animate --config configs/prompts/3_sparsectrl/3_2_sparsectrl_rgb_RealisticVision.yaml 90 | python -m scripts.animate --config configs/prompts/3_sparsectrl/3_3_sparsectrl_sketch_RealisticVision.yaml 91 | ``` 92 | 93 | #### 2.4 Gradio app 94 | We created a Gradio demo to make AnimateDiff easier to use. 95 | By default, the demo runs at `localhost:7860`. 96 | ``` 97 | python -u app.py 98 | ``` 99 | 100 | 101 | 102 | ## Technical Explanation 103 |
104 | Technical Explanation 105 | 106 | ### AnimateDiff 107 | 108 | **AnimateDiff aims to learn transferable motion priors that can be applied to other variants of the Stable Diffusion family.** 109 | To this end, we design the following training pipeline, consisting of three stages. 110 | 111 | 112 | 113 | - In the **1. Alleviate Negative Effects** stage, we train the **domain adapter**, e.g., `v3_sd15_adapter.ckpt`, to fit the defective visual artifacts (e.g., watermarks) in the training dataset. 114 | This also benefits the disentangled learning of motion and spatial appearance. 115 | By default, the adapter can be removed at inference. It can also be integrated into the model, and its effect can be adjusted with a LoRA scale. 116 | 117 | - In the **2. Learn Motion Priors** stage, we train the **motion module**, e.g., `v3_sd15_mm.ckpt`, to learn real-world motion patterns from videos. 118 | 119 | - In the **3. (optional) Adapt to New Patterns** stage, we train **MotionLoRA**, e.g., `v2_lora_ZoomIn.ckpt`, to efficiently adapt the motion module to specific motion patterns (camera zooming, rolling, etc.). 120 | 121 | ### SparseCtrl 122 | 123 | **SparseCtrl aims to add more control to text-to-video models by adopting sparse inputs (e.g., a few RGB images or sketches).** 124 | Its technical details can be found in the following paper: 125 | 126 | **[SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion Models](https://arxiv.org/abs/2311.16933)** 127 | [Yuwei Guo](https://guoyww.github.io/), 128 | [Ceyuan Yang✝](https://ceyuan.me/), 129 | [Anyi Rao](https://anyirao.com/), 130 | [Maneesh Agrawala](https://graphics.stanford.edu/~maneesh/), 131 | [Dahua Lin](http://dahua.site), 132 | [Bo Dai](https://daibo.info) 133 | (✝Corresponding Author) 134 | [![arXiv](https://img.shields.io/badge/arXiv-2311.16933-b31b1b.svg)](https://arxiv.org/abs/2311.16933) 135 | [![Project Page](https://img.shields.io/badge/Project-Website-green)](https://guoyww.github.io/projects/SparseCtrl/) 136 | 137 |
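To make the stages above concrete, here is a minimal, hedged sketch of how a trained motion module plugs into a community text-to-image model through the Diffusers integration mentioned in the Quick Start. The Hub IDs, scheduler settings, and prompt are illustrative assumptions rather than part of this repository; the repository's own workflow is the `scripts.animate` configs shown earlier.

```python
# Hedged sketch (Diffusers integration): plug a pre-trained motion module into a
# community SD1.5 model. Hub IDs and settings below are illustrative assumptions.
import torch
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif

# Stage-2 artifact: the motion module, packaged as a MotionAdapter on the Hub.
adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
)

# Any SD1.5-based personalized T2I model can serve as the spatial backbone.
pipe = AnimateDiffPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE",
    motion_adapter=adapter,
    torch_dtype=torch.float16,
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, beta_schedule="linear", clip_sample=False
)

frames = pipe(
    prompt="sunset over the sea, golden hour, photorealistic, film grain",
    negative_prompt="low quality, blurry, deformed",
    num_frames=16,
    guidance_scale=7.5,
    num_inference_steps=25,
).frames[0]
export_to_gif(frames, "animation.gif")
```

Note that the stage-1 domain adapter is removed at inference by default, which is why it does not appear in this sketch.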
138 | 139 | 140 | ## Model Versions 141 |
142 | Model Versions 143 | 144 | ### AnimateDiff v3 and SparseCtrl (2023.12) 145 | 146 | In this version, we use **Domain Adapter LoRA** for image model finetuning, which provides more flexibility at inference. 147 | We also implement two (RGB image/scribble) [SparseCtrl](https://arxiv.org/abs/2311.16933) encoders, which can take an arbitrary number of condition maps to control the animation content. 148 | 149 |
150 | AnimateDiff v3 Model Zoo 151 | 152 | | Name | HuggingFace | Type | Storage | Description | 153 | | - | - | - | - | - | 154 | | `v3_sd15_adapter.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_adapter.ckpt) | Domain Adapter | 97.4 MB | | 155 | | `v3_sd15_mm.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_mm.ckpt) | Motion Module | 1.56 GB | | 156 | | `v3_sd15_sparsectrl_scribble.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_sparsectrl_scribble.ckpt) | SparseCtrl Encoder | 1.86 GB | scribble condition | 157 | | `v3_sd15_sparsectrl_rgb.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_sparsectrl_rgb.ckpt) | SparseCtrl Encoder | 1.85 GB | RGB image condition | 158 |
159 | 160 | #### Limitations 161 | 1. Minor flickering is noticeable; 162 | 2. To stay compatible with community models, there are no specific optimizations for general T2V, leading to limited visual quality under this setting; 163 | 3. **(Style Alignment) For usage such as image animation/interpolation, it is recommended to use images generated by the same community model.** 164 | 165 | #### Demos 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 |
Input (by RealisticVision) | Animation | Input | Animation
180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 |
Input Scribble | Output | Input Scribbles | Output
195 | 196 | 197 | ### AnimateDiff SDXL-Beta (2023.11) 198 | 199 | We release the Motion Module (beta version) for SDXL, available at [Google Drive](https://drive.google.com/file/d/1EK_D9hDOPfJdK4z8YDB8JYvPracNx2SX/view?usp=share_link 200 | ) / [HuggingFace](https://huggingface.co/guoyww/animatediff/blob/main/mm_sdxl_v10_beta.ckpt 201 | ) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules). High-resolution videos (i.e., 1024x1024x16 frames with various aspect ratios) can be produced **with/without** personalized models. Inference usually requires ~13GB VRAM and tuned hyperparameters (e.g., sampling steps), depending on the chosen personalized models. 202 | Check out the [sdxl](https://github.com/guoyww/AnimateDiff/tree/sdxl) branch for more details on inference. 203 | 204 |
205 | AnimateDiff SDXL-Beta Model Zoo 206 | 207 | | Name | HuggingFace | Type | Storage Space | 208 | | - | - | - | - | 209 | | `mm_sdxl_v10_beta.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/mm_sdxl_v10_beta.ckpt) | Motion Module | 950 MB | 210 |
211 | 212 | #### Demos 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 |
Original SDXL | Community SDXL | Community SDXL
225 | 226 | 227 | ### AnimateDiff v2 (2023.09) 228 | 229 | In this version, the motion module `mm_sd_v15_v2.ckpt` ([Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI?usp=sharing) / [HuggingFace](https://huggingface.co/guoyww/animatediff) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules)) is trained at a larger resolution and batch size. 230 | We found that this scaled-up training significantly improves motion quality and diversity. 231 | We also support **MotionLoRA** for eight basic camera movements. 232 | MotionLoRA checkpoints take up only **77 MB of storage per model**, and are available at [Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI?usp=sharing) / [HuggingFace](https://huggingface.co/guoyww/animatediff) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules). 233 | 234 |
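For readers using the Diffusers integration rather than this repository's YAML configs, a MotionLoRA can be layered on top of the motion module in a similar way. Below is a hedged sketch continuing the pipeline from the Technical Explanation section above; the Hub repository name and the 0.8 scale are illustrative assumptions.

```python
# Hedged sketch: stack a camera-movement MotionLoRA on an existing AnimateDiffPipeline
# (the `pipe` built in the earlier sketch). Hub ID and scale are illustrative assumptions.
pipe.load_lora_weights(
    "guoyww/animatediff-motion-lora-zoom-in", adapter_name="zoom-in"
)
pipe.set_adapters(["zoom-in"], adapter_weights=[0.8])  # dial the camera-motion strength

frames = pipe(
    prompt="photo of a coastline, rocks, storm weather, waves, film grain",
    num_frames=16,
    guidance_scale=7.5,
    num_inference_steps=25,
).frames[0]
```

Within this repository, the equivalent is configured through the prompt configs under `configs/prompts/2_motionlora/`.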
235 | AnimateDiff v2 Model Zoo 236 | 237 | | Name | HuggingFace | Type | Parameter | Storage | 238 | | - | - | - | - | - | 239 | | `mm_sd_v15_v2.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15_v2.ckpt) | Motion Module | 453 M | 1.7 GB | 240 | | `v2_lora_ZoomIn.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v2_lora_ZoomIn.ckpt) | MotionLoRA | 19 M | 74 MB | 241 | | `v2_lora_ZoomOut.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v2_lora_ZoomOut.ckpt) | MotionLoRA | 19 M | 74 MB | 242 | | `v2_lora_PanLeft.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v2_lora_PanLeft.ckpt) | MotionLoRA | 19 M | 74 MB | 243 | | `v2_lora_PanRight.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v2_lora_PanRight.ckpt) | MotionLoRA | 19 M | 74 MB | 244 | | `v2_lora_TiltUp.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v2_lora_TiltUp.ckpt) | MotionLoRA | 19 M | 74 MB | 245 | | `v2_lora_TiltDown.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v2_lora_TiltDown.ckpt) | MotionLoRA | 19 M | 74 MB | 246 | | `v2_lora_RollingClockwise.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v2_lora_RollingClockwise.ckpt) | MotionLoRA | 19 M | 74 MB | 247 | | `v2_lora_RollingAnticlockwise.ckpt` | [Link](https://huggingface.co/guoyww/animatediff/blob/main/v2_lora_RollingAnticlockwise.ckpt) | MotionLoRA | 19 M | 74 MB | 248 |
249 | 250 | 251 | #### Demos (MotionLoRA) 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 |
Zoom In | Zoom Out | Zoom Pan Left | Zoom Pan Right
Tilt Up | Tilt Down | Rolling Anti-Clockwise | Rolling Clockwise
286 | 287 | 288 | #### Demos (Improved Motions) 289 | Here's a comparison between `mm_sd_v15.ckpt` (left) and improved `mm_sd_v15_v2.ckpt` (right). 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 |
303 | 304 | 305 | ### AnimateDiff v1 (2023.07) 306 | 307 | The first version of AnimateDiff! 308 | 309 |
310 | AnimateDiff v1 Model Zoo 311 | 312 | | Name | HuggingFace | Parameter | Storage Space | 313 | | - | - | - | - | 314 | | mm_sd_v14.ckpt | [Link](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v14.ckpt) | 417 M | 1.6 GB | 315 | | mm_sd_v15.ckpt | [Link](https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15.ckpt) | 417 M | 1.6 GB | 316 |
317 | 318 |
319 | 320 | 321 | ## Training 322 | Please check [Steps for Training](__assets__/docs/animatediff.md) for details. 323 | 324 | 325 | ## Related Resources 326 | 327 | AnimateDiff for Stable Diffusion WebUI: [sd-webui-animatediff](https://github.com/continue-revolution/sd-webui-animatediff) (by [@continue-revolution](https://github.com/continue-revolution)) 328 | AnimateDiff for ComfyUI: [ComfyUI-AnimateDiff-Evolved](https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) (by [@Kosinkadink](https://github.com/Kosinkadink)) 329 | Google Colab: [Colab](https://colab.research.google.com/github/camenduru/AnimateDiff-colab/blob/main/AnimateDiff_colab.ipynb) (by [@camenduru](https://github.com/camenduru)) 330 | 331 | 332 | ## Disclaimer 333 | This project is released for academic use. 334 | We disclaim responsibility for user-generated content. 335 | Also, please be advised that our only official websites are https://github.com/guoyww/AnimateDiff and https://animatediff.github.io; all other websites are NOT associated with us at AnimateDiff. 336 | 337 | 338 | ## Contact Us 339 | Yuwei Guo: [guoyw@ie.cuhk.edu.hk](mailto:guoyw@ie.cuhk.edu.hk) 340 | Ceyuan Yang: [limbo0066@gmail.com](mailto:limbo0066@gmail.com) 341 | Bo Dai: [doubledaibo@gmail.com](mailto:doubledaibo@gmail.com) 342 | 343 | 344 | ## BibTeX 345 | ``` 346 | @article{guo2023animatediff, 347 | title={AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning}, 348 | author={Guo, Yuwei and Yang, Ceyuan and Rao, Anyi and Liang, Zhengyang and Wang, Yaohui and Qiao, Yu and Agrawala, Maneesh and Lin, Dahua and Dai, Bo}, 349 | journal={International Conference on Learning Representations}, 350 | year={2024} 351 | } 352 | 353 | @article{guo2023sparsectrl, 354 | title={SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion Models}, 355 | author={Guo, Yuwei and Yang, Ceyuan and Rao, Anyi and Agrawala, Maneesh and Lin, Dahua and Dai, Bo}, 356 | journal={arXiv preprint arXiv:2311.16933}, 357 | year={2023} 358 | } 359 | ``` 360 | 361 | 362 | ## Acknowledgements 363 | This codebase is built upon [Tune-a-Video](https://github.com/showlab/Tune-A-Video).
364 | -------------------------------------------------------------------------------- /__assets__/animations/compare/ffmpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/compare/ffmpeg -------------------------------------------------------------------------------- /__assets__/animations/compare/new_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/compare/new_0.gif -------------------------------------------------------------------------------- /__assets__/animations/compare/new_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/compare/new_1.gif -------------------------------------------------------------------------------- /__assets__/animations/compare/new_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/compare/new_2.gif -------------------------------------------------------------------------------- /__assets__/animations/compare/new_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/compare/new_3.gif -------------------------------------------------------------------------------- /__assets__/animations/compare/old_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/compare/old_0.gif -------------------------------------------------------------------------------- /__assets__/animations/compare/old_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/compare/old_1.gif -------------------------------------------------------------------------------- /__assets__/animations/compare/old_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/compare/old_2.gif -------------------------------------------------------------------------------- /__assets__/animations/compare/old_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/compare/old_3.gif -------------------------------------------------------------------------------- /__assets__/animations/model_01/01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_01/01.gif -------------------------------------------------------------------------------- /__assets__/animations/model_01/02.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_01/02.gif -------------------------------------------------------------------------------- /__assets__/animations/model_01/03.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_01/03.gif -------------------------------------------------------------------------------- /__assets__/animations/model_01/04.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_01/04.gif -------------------------------------------------------------------------------- /__assets__/animations/model_02/01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_02/01.gif -------------------------------------------------------------------------------- /__assets__/animations/model_02/02.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_02/02.gif -------------------------------------------------------------------------------- /__assets__/animations/model_02/03.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_02/03.gif -------------------------------------------------------------------------------- /__assets__/animations/model_02/04.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_02/04.gif -------------------------------------------------------------------------------- /__assets__/animations/model_03/01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_03/01.gif -------------------------------------------------------------------------------- /__assets__/animations/model_03/02.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_03/02.gif -------------------------------------------------------------------------------- /__assets__/animations/model_03/03.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_03/03.gif -------------------------------------------------------------------------------- /__assets__/animations/model_03/04.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_03/04.gif 
-------------------------------------------------------------------------------- /__assets__/animations/model_04/01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_04/01.gif -------------------------------------------------------------------------------- /__assets__/animations/model_04/02.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_04/02.gif -------------------------------------------------------------------------------- /__assets__/animations/model_04/03.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_04/03.gif -------------------------------------------------------------------------------- /__assets__/animations/model_04/04.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_04/04.gif -------------------------------------------------------------------------------- /__assets__/animations/model_05/01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_05/01.gif -------------------------------------------------------------------------------- /__assets__/animations/model_05/02.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_05/02.gif -------------------------------------------------------------------------------- /__assets__/animations/model_05/03.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_05/03.gif -------------------------------------------------------------------------------- /__assets__/animations/model_05/04.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_05/04.gif -------------------------------------------------------------------------------- /__assets__/animations/model_06/01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_06/01.gif -------------------------------------------------------------------------------- /__assets__/animations/model_06/02.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_06/02.gif -------------------------------------------------------------------------------- /__assets__/animations/model_06/03.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_06/03.gif -------------------------------------------------------------------------------- /__assets__/animations/model_06/04.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_06/04.gif -------------------------------------------------------------------------------- /__assets__/animations/model_07/01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_07/01.gif -------------------------------------------------------------------------------- /__assets__/animations/model_07/02.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_07/02.gif -------------------------------------------------------------------------------- /__assets__/animations/model_07/03.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_07/03.gif -------------------------------------------------------------------------------- /__assets__/animations/model_07/04.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_07/04.gif -------------------------------------------------------------------------------- /__assets__/animations/model_07/init.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_07/init.jpg -------------------------------------------------------------------------------- /__assets__/animations/model_08/01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_08/01.gif -------------------------------------------------------------------------------- /__assets__/animations/model_08/02.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_08/02.gif -------------------------------------------------------------------------------- /__assets__/animations/model_08/03.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_08/03.gif -------------------------------------------------------------------------------- /__assets__/animations/model_08/04.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/model_08/04.gif 
-------------------------------------------------------------------------------- /__assets__/animations/motion_lora/model_01/01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_lora/model_01/01.gif -------------------------------------------------------------------------------- /__assets__/animations/motion_lora/model_01/02.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_lora/model_01/02.gif -------------------------------------------------------------------------------- /__assets__/animations/motion_lora/model_01/03.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_lora/model_01/03.gif -------------------------------------------------------------------------------- /__assets__/animations/motion_lora/model_01/04.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_lora/model_01/04.gif -------------------------------------------------------------------------------- /__assets__/animations/motion_lora/model_01/05.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_lora/model_01/05.gif -------------------------------------------------------------------------------- /__assets__/animations/motion_lora/model_01/06.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_lora/model_01/06.gif -------------------------------------------------------------------------------- /__assets__/animations/motion_lora/model_01/07.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_lora/model_01/07.gif -------------------------------------------------------------------------------- /__assets__/animations/motion_lora/model_01/08.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_lora/model_01/08.gif -------------------------------------------------------------------------------- /__assets__/animations/motion_lora/model_02/01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_lora/model_02/01.gif -------------------------------------------------------------------------------- /__assets__/animations/motion_lora/model_02/02.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_lora/model_02/02.gif -------------------------------------------------------------------------------- /__assets__/animations/motion_lora/model_02/03.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_lora/model_02/03.gif -------------------------------------------------------------------------------- /__assets__/animations/motion_lora/model_02/04.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_lora/model_02/04.gif -------------------------------------------------------------------------------- /__assets__/animations/motion_lora/model_02/05.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_lora/model_02/05.gif -------------------------------------------------------------------------------- /__assets__/animations/motion_lora/model_02/06.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_lora/model_02/06.gif -------------------------------------------------------------------------------- /__assets__/animations/motion_lora/model_02/07.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_lora/model_02/07.gif -------------------------------------------------------------------------------- /__assets__/animations/motion_lora/model_02/08.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_lora/model_02/08.gif -------------------------------------------------------------------------------- /__assets__/animations/motion_xl/01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_xl/01.gif -------------------------------------------------------------------------------- /__assets__/animations/motion_xl/02.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_xl/02.gif -------------------------------------------------------------------------------- /__assets__/animations/motion_xl/03.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/motion_xl/03.gif -------------------------------------------------------------------------------- /__assets__/animations/v3/animation_fireworks.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/v3/animation_fireworks.gif -------------------------------------------------------------------------------- /__assets__/animations/v3/animation_sunset.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/v3/animation_sunset.gif -------------------------------------------------------------------------------- /__assets__/animations/v3/sketch_boy.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/v3/sketch_boy.gif -------------------------------------------------------------------------------- /__assets__/animations/v3/sketch_city.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/animations/v3/sketch_city.gif -------------------------------------------------------------------------------- /__assets__/demos/image/RealisticVision_firework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/demos/image/RealisticVision_firework.png -------------------------------------------------------------------------------- /__assets__/demos/image/RealisticVision_sunset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/demos/image/RealisticVision_sunset.png -------------------------------------------------------------------------------- /__assets__/demos/image/interpolation_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/demos/image/interpolation_1.png -------------------------------------------------------------------------------- /__assets__/demos/image/interpolation_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/demos/image/interpolation_2.png -------------------------------------------------------------------------------- /__assets__/demos/image/low_fps_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/demos/image/low_fps_1.png -------------------------------------------------------------------------------- /__assets__/demos/image/low_fps_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/demos/image/low_fps_2.png -------------------------------------------------------------------------------- /__assets__/demos/image/low_fps_3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/demos/image/low_fps_3.png -------------------------------------------------------------------------------- /__assets__/demos/image/low_fps_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/demos/image/low_fps_4.png -------------------------------------------------------------------------------- /__assets__/demos/image/painting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/demos/image/painting.png -------------------------------------------------------------------------------- /__assets__/demos/image/prediction_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/demos/image/prediction_1.png -------------------------------------------------------------------------------- /__assets__/demos/image/prediction_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/demos/image/prediction_2.png -------------------------------------------------------------------------------- /__assets__/demos/image/prediction_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/demos/image/prediction_3.png -------------------------------------------------------------------------------- /__assets__/demos/image/prediction_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/demos/image/prediction_4.png -------------------------------------------------------------------------------- /__assets__/demos/scribble/scribble_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/demos/scribble/scribble_1.png -------------------------------------------------------------------------------- /__assets__/demos/scribble/scribble_2_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/demos/scribble/scribble_2_1.png -------------------------------------------------------------------------------- /__assets__/demos/scribble/scribble_2_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/demos/scribble/scribble_2_2.png -------------------------------------------------------------------------------- /__assets__/demos/scribble/scribble_2_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/demos/scribble/scribble_2_3.png 
-------------------------------------------------------------------------------- /__assets__/demos/scribble/scribble_2_readme.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/demos/scribble/scribble_2_readme.png -------------------------------------------------------------------------------- /__assets__/docs/animatediff.md: -------------------------------------------------------------------------------- 1 | ## Steps for Training 2 | 3 | ### Dataset 4 | Before training, download the video files and the `.csv` annotations of [WebVid10M](https://maxbain.com/webvid-dataset/) to the local machine. 5 | Note that our exemplar training script requires all the videos to be saved in a single folder. You may change this by modifying `animatediff/data/dataset.py` (a sketch of such a change is given below). 6 | 7 | ### Configuration 8 | After dataset preparation, update the data paths below in the config `.yaml` files in the `configs/training/` folder: 9 | ``` 10 | train_data: 11 | csv_path: [Replace with .csv Annotation File Path] 12 | video_folder: [Replace with Video Folder Path] 13 | sample_size: 256 14 | ``` 15 | Other training parameters (learning rate, epochs, validation settings, etc.) are also included in the config files. 16 | 17 | ### Training 18 | To finetune the UNet's image layers: 19 | ``` 20 | torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/v1/image_finetune.yaml 21 | ``` 22 | 23 | To train the motion modules: 24 | ``` 25 | torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/v1/training.yaml 26 | ``` 27 |
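The training note above assumes a flat video folder. The following is a minimal, illustrative sketch (not part of the repository) of how the path handling in `animatediff/data/dataset.py` could be adapted when the WebVid clips keep their per-`page_dir` subfolders from the official download. The helper name `resolve_webvid_path` is hypothetical; only `videoid` and `page_dir`, which `WebVid10M.get_batch` already reads from the `.csv` annotations, are taken from the existing code.
```python
import os

def resolve_webvid_path(video_folder: str, videoid: str, page_dir: str) -> str:
    """Illustrative helper (hypothetical, not part of the repository).

    The example dataset expects a flat layout, <video_folder>/<videoid>.mp4,
    while the official WebVid download nests clips as
    <video_folder>/<page_dir>/<videoid>.mp4. Prefer the nested path when it
    exists and fall back to the flat one.
    """
    nested = os.path.join(video_folder, page_dir, f"{videoid}.mp4")
    flat = os.path.join(video_folder, f"{videoid}.mp4")
    return nested if os.path.isfile(nested) else flat

# Inside WebVid10M.get_batch, the line
#     video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
# could then become
#     video_dir = resolve_webvid_path(self.video_folder, videoid, page_dir)
```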
-------------------------------------------------------------------------------- /__assets__/docs/gallery.md: -------------------------------------------------------------------------------- 1 | # Gallery 2 | Here we demonstrate several best results we found in our experiments.
[sample animation grid] Model: ToonYou
[sample animation grid] Model: Counterfeit V3.0
[sample animation grid] Model: Realistic Vision V2.0
[sample animation grid] Model: majicMIX Realistic
[sample animation grid] Model: RCNZ Cartoon
[sample animation grid] Model: FilmVelvia
#### Community Cases
Here are some samples contributed by the community artists. Create a Pull Request if you would like to show your results here😚.
[sample animation grid] Character Model: Yoimiya (with an initial reference image, see WIP fork for the extended implementation.)
[sample animation grid] Character Model: Paimon; Pose Model: Hold Sign
92 | 93 | 94 | -------------------------------------------------------------------------------- /__assets__/figs/adapter_explain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/figs/adapter_explain.png -------------------------------------------------------------------------------- /__assets__/figs/gradio.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/__assets__/figs/gradio.jpg -------------------------------------------------------------------------------- /animatediff/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os, io, csv, math, random 2 | import numpy as np 3 | from einops import rearrange 4 | from decord import VideoReader 5 | 6 | import torch 7 | import torchvision.transforms as transforms 8 | from torch.utils.data.dataset import Dataset 9 | from animatediff.utils.util import zero_rank_print 10 | 11 | 12 | 13 | class WebVid10M(Dataset): 14 | def __init__( 15 | self, 16 | csv_path, video_folder, 17 | sample_size=256, sample_stride=4, sample_n_frames=16, 18 | is_image=False, 19 | ): 20 | zero_rank_print(f"loading annotations from {csv_path} ...") 21 | with open(csv_path, 'r') as csvfile: 22 | self.dataset = list(csv.DictReader(csvfile)) 23 | self.length = len(self.dataset) 24 | zero_rank_print(f"data scale: {self.length}") 25 | 26 | self.video_folder = video_folder 27 | self.sample_stride = sample_stride 28 | self.sample_n_frames = sample_n_frames 29 | self.is_image = is_image 30 | 31 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) 32 | self.pixel_transforms = transforms.Compose([ 33 | transforms.RandomHorizontalFlip(), 34 | transforms.Resize(sample_size[0]), 35 | transforms.CenterCrop(sample_size), 36 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 37 | ]) 38 | 39 | def get_batch(self, idx): 40 | video_dict = self.dataset[idx] 41 | videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir'] 42 | 43 | video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") 44 | video_reader = VideoReader(video_dir) 45 | video_length = len(video_reader) 46 | 47 | if not self.is_image: 48 | clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1) 49 | start_idx = random.randint(0, video_length - clip_length) 50 | batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) 51 | else: 52 | batch_index = [random.randint(0, video_length - 1)] 53 | 54 | pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() 55 | pixel_values = pixel_values / 255. 
56 | del video_reader 57 | 58 | if self.is_image: 59 | pixel_values = pixel_values[0] 60 | 61 | return pixel_values, name 62 | 63 | def __len__(self): 64 | return self.length 65 | 66 | def __getitem__(self, idx): 67 | while True: 68 | try: 69 | pixel_values, name = self.get_batch(idx) 70 | break 71 | 72 | except Exception as e: 73 | idx = random.randint(0, self.length-1) 74 | 75 | pixel_values = self.pixel_transforms(pixel_values) 76 | sample = dict(pixel_values=pixel_values, text=name) 77 | return sample 78 | 79 | 80 | 81 | if __name__ == "__main__": 82 | from animatediff.utils.util import save_videos_grid 83 | 84 | dataset = WebVid10M( 85 | csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv", 86 | video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val", 87 | sample_size=256, 88 | sample_stride=4, sample_n_frames=16, 89 | is_image=True, 90 | ) 91 | import pdb 92 | pdb.set_trace() 93 | 94 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16,) 95 | for idx, batch in enumerate(dataloader): 96 | print(batch["pixel_values"].shape, len(batch["text"])) 97 | # for i in range(batch["pixel_values"].shape[0]): 98 | # save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True) 99 | -------------------------------------------------------------------------------- /animatediff/models/attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | from diffusers.configuration_utils import ConfigMixin, register_to_config 11 | from diffusers.modeling_utils import ModelMixin 12 | from diffusers.utils import BaseOutput 13 | from diffusers.utils.import_utils import is_xformers_available 14 | from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm 15 | 16 | from einops import rearrange, repeat 17 | import pdb 18 | 19 | @dataclass 20 | class Transformer3DModelOutput(BaseOutput): 21 | sample: torch.FloatTensor 22 | 23 | 24 | if is_xformers_available(): 25 | import xformers 26 | import xformers.ops 27 | else: 28 | xformers = None 29 | 30 | 31 | class Transformer3DModel(ModelMixin, ConfigMixin): 32 | @register_to_config 33 | def __init__( 34 | self, 35 | num_attention_heads: int = 16, 36 | attention_head_dim: int = 88, 37 | in_channels: Optional[int] = None, 38 | num_layers: int = 1, 39 | dropout: float = 0.0, 40 | norm_num_groups: int = 32, 41 | cross_attention_dim: Optional[int] = None, 42 | attention_bias: bool = False, 43 | activation_fn: str = "geglu", 44 | num_embeds_ada_norm: Optional[int] = None, 45 | use_linear_projection: bool = False, 46 | only_cross_attention: bool = False, 47 | upcast_attention: bool = False, 48 | 49 | unet_use_cross_frame_attention=None, 50 | unet_use_temporal_attention=None, 51 | ): 52 | super().__init__() 53 | self.use_linear_projection = use_linear_projection 54 | self.num_attention_heads = num_attention_heads 55 | self.attention_head_dim = attention_head_dim 56 | inner_dim = num_attention_heads * attention_head_dim 57 | 58 | # Define input layers 59 | self.in_channels = in_channels 60 | 61 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 62 | if use_linear_projection: 63 | 
self.proj_in = nn.Linear(in_channels, inner_dim) 64 | else: 65 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 66 | 67 | # Define transformers blocks 68 | self.transformer_blocks = nn.ModuleList( 69 | [ 70 | BasicTransformerBlock( 71 | inner_dim, 72 | num_attention_heads, 73 | attention_head_dim, 74 | dropout=dropout, 75 | cross_attention_dim=cross_attention_dim, 76 | activation_fn=activation_fn, 77 | num_embeds_ada_norm=num_embeds_ada_norm, 78 | attention_bias=attention_bias, 79 | only_cross_attention=only_cross_attention, 80 | upcast_attention=upcast_attention, 81 | 82 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 83 | unet_use_temporal_attention=unet_use_temporal_attention, 84 | ) 85 | for d in range(num_layers) 86 | ] 87 | ) 88 | 89 | # 4. Define output layers 90 | if use_linear_projection: 91 | self.proj_out = nn.Linear(in_channels, inner_dim) 92 | else: 93 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 94 | 95 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): 96 | # Input 97 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 98 | video_length = hidden_states.shape[2] 99 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 100 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) 101 | 102 | batch, channel, height, weight = hidden_states.shape 103 | residual = hidden_states 104 | 105 | hidden_states = self.norm(hidden_states) 106 | if not self.use_linear_projection: 107 | hidden_states = self.proj_in(hidden_states) 108 | inner_dim = hidden_states.shape[1] 109 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 110 | else: 111 | inner_dim = hidden_states.shape[1] 112 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 113 | hidden_states = self.proj_in(hidden_states) 114 | 115 | # Blocks 116 | for block in self.transformer_blocks: 117 | hidden_states = block( 118 | hidden_states, 119 | encoder_hidden_states=encoder_hidden_states, 120 | timestep=timestep, 121 | video_length=video_length 122 | ) 123 | 124 | # Output 125 | if not self.use_linear_projection: 126 | hidden_states = ( 127 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 128 | ) 129 | hidden_states = self.proj_out(hidden_states) 130 | else: 131 | hidden_states = self.proj_out(hidden_states) 132 | hidden_states = ( 133 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 134 | ) 135 | 136 | output = hidden_states + residual 137 | 138 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 139 | if not return_dict: 140 | return (output,) 141 | 142 | return Transformer3DModelOutput(sample=output) 143 | 144 | 145 | class BasicTransformerBlock(nn.Module): 146 | def __init__( 147 | self, 148 | dim: int, 149 | num_attention_heads: int, 150 | attention_head_dim: int, 151 | dropout=0.0, 152 | cross_attention_dim: Optional[int] = None, 153 | activation_fn: str = "geglu", 154 | num_embeds_ada_norm: Optional[int] = None, 155 | attention_bias: bool = False, 156 | only_cross_attention: bool = False, 157 | upcast_attention: bool = False, 158 | 159 | unet_use_cross_frame_attention = None, 160 | unet_use_temporal_attention = None, 161 | ): 162 | super().__init__() 163 | self.only_cross_attention 
= only_cross_attention 164 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 165 | self.unet_use_cross_frame_attention = unet_use_cross_frame_attention 166 | self.unet_use_temporal_attention = unet_use_temporal_attention 167 | 168 | # SC-Attn 169 | assert unet_use_cross_frame_attention is not None 170 | if unet_use_cross_frame_attention: 171 | self.attn1 = SparseCausalAttention2D( 172 | query_dim=dim, 173 | heads=num_attention_heads, 174 | dim_head=attention_head_dim, 175 | dropout=dropout, 176 | bias=attention_bias, 177 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 178 | upcast_attention=upcast_attention, 179 | ) 180 | else: 181 | self.attn1 = CrossAttention( 182 | query_dim=dim, 183 | heads=num_attention_heads, 184 | dim_head=attention_head_dim, 185 | dropout=dropout, 186 | bias=attention_bias, 187 | upcast_attention=upcast_attention, 188 | ) 189 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 190 | 191 | # Cross-Attn 192 | if cross_attention_dim is not None: 193 | self.attn2 = CrossAttention( 194 | query_dim=dim, 195 | cross_attention_dim=cross_attention_dim, 196 | heads=num_attention_heads, 197 | dim_head=attention_head_dim, 198 | dropout=dropout, 199 | bias=attention_bias, 200 | upcast_attention=upcast_attention, 201 | ) 202 | else: 203 | self.attn2 = None 204 | 205 | if cross_attention_dim is not None: 206 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 207 | else: 208 | self.norm2 = None 209 | 210 | # Feed-forward 211 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 212 | self.norm3 = nn.LayerNorm(dim) 213 | 214 | # Temp-Attn 215 | assert unet_use_temporal_attention is not None 216 | if unet_use_temporal_attention: 217 | self.attn_temp = CrossAttention( 218 | query_dim=dim, 219 | heads=num_attention_heads, 220 | dim_head=attention_head_dim, 221 | dropout=dropout, 222 | bias=attention_bias, 223 | upcast_attention=upcast_attention, 224 | ) 225 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data) 226 | self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 227 | 228 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): 229 | if not is_xformers_available(): 230 | print("Here is how to install it") 231 | raise ModuleNotFoundError( 232 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 233 | " xformers", 234 | name="xformers", 235 | ) 236 | elif not torch.cuda.is_available(): 237 | raise ValueError( 238 | "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is only" 239 | " available for GPU " 240 | ) 241 | else: 242 | try: 243 | # Make sure we can run the memory efficient attention 244 | _ = xformers.ops.memory_efficient_attention( 245 | torch.randn((1, 2, 40), device="cuda"), 246 | torch.randn((1, 2, 40), device="cuda"), 247 | torch.randn((1, 2, 40), device="cuda"), 248 | ) 249 | except Exception as e: 250 | raise e 251 | self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 252 | if self.attn2 is not None: 253 | self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 254 | # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 255 | 256 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): 257 | # SparseCausal-Attention 258 | norm_hidden_states = ( 259 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) 260 | ) 261 | 262 | # if self.only_cross_attention: 263 | # hidden_states = ( 264 | # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states 265 | # ) 266 | # else: 267 | # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states 268 | 269 | # pdb.set_trace() 270 | if self.unet_use_cross_frame_attention: 271 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states 272 | else: 273 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states 274 | 275 | if self.attn2 is not None: 276 | # Cross-Attention 277 | norm_hidden_states = ( 278 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 279 | ) 280 | hidden_states = ( 281 | self.attn2( 282 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 283 | ) 284 | + hidden_states 285 | ) 286 | 287 | # Feed-forward 288 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 289 | 290 | # Temporal-Attention 291 | if self.unet_use_temporal_attention: 292 | d = hidden_states.shape[1] 293 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 294 | norm_hidden_states = ( 295 | self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) 296 | ) 297 | hidden_states = self.attn_temp(norm_hidden_states) + hidden_states 298 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 299 | 300 | return hidden_states 301 | -------------------------------------------------------------------------------- /animatediff/models/motion_module.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import torch 5 | import numpy as np 6 | import torch.nn.functional as F 7 | from torch import nn 8 | import torchvision 9 | 10 | from diffusers.configuration_utils import ConfigMixin, register_to_config 11 | from diffusers.modeling_utils import ModelMixin 12 | from diffusers.utils import BaseOutput 13 | from diffusers.utils.import_utils import is_xformers_available 14 | from diffusers.models.attention import CrossAttention, FeedForward 15 | 16 | from einops import rearrange, repeat 17 | import math 18 | 19 | 20 | def zero_module(module): 21 | # 
Zero out the parameters of a module and return it. 22 | for p in module.parameters(): 23 | p.detach().zero_() 24 | return module 25 | 26 | 27 | @dataclass 28 | class TemporalTransformer3DModelOutput(BaseOutput): 29 | sample: torch.FloatTensor 30 | 31 | 32 | if is_xformers_available(): 33 | import xformers 34 | import xformers.ops 35 | else: 36 | xformers = None 37 | 38 | 39 | def get_motion_module( 40 | in_channels, 41 | motion_module_type: str, 42 | motion_module_kwargs: dict 43 | ): 44 | if motion_module_type == "Vanilla": 45 | return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,) 46 | else: 47 | raise ValueError 48 | 49 | 50 | class VanillaTemporalModule(nn.Module): 51 | def __init__( 52 | self, 53 | in_channels, 54 | num_attention_heads = 8, 55 | num_transformer_block = 2, 56 | attention_block_types =( "Temporal_Self", "Temporal_Self" ), 57 | cross_frame_attention_mode = None, 58 | temporal_position_encoding = False, 59 | temporal_position_encoding_max_len = 24, 60 | temporal_attention_dim_div = 1, 61 | zero_initialize = True, 62 | ): 63 | super().__init__() 64 | 65 | self.temporal_transformer = TemporalTransformer3DModel( 66 | in_channels=in_channels, 67 | num_attention_heads=num_attention_heads, 68 | attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, 69 | num_layers=num_transformer_block, 70 | attention_block_types=attention_block_types, 71 | cross_frame_attention_mode=cross_frame_attention_mode, 72 | temporal_position_encoding=temporal_position_encoding, 73 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 74 | ) 75 | 76 | if zero_initialize: 77 | self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) 78 | 79 | def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None): 80 | hidden_states = input_tensor 81 | hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask) 82 | 83 | output = hidden_states 84 | return output 85 | 86 | 87 | class TemporalTransformer3DModel(nn.Module): 88 | def __init__( 89 | self, 90 | in_channels, 91 | num_attention_heads, 92 | attention_head_dim, 93 | 94 | num_layers, 95 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ), 96 | dropout = 0.0, 97 | norm_num_groups = 32, 98 | cross_attention_dim = 768, 99 | activation_fn = "geglu", 100 | attention_bias = False, 101 | upcast_attention = False, 102 | 103 | cross_frame_attention_mode = None, 104 | temporal_position_encoding = False, 105 | temporal_position_encoding_max_len = 24, 106 | ): 107 | super().__init__() 108 | 109 | inner_dim = num_attention_heads * attention_head_dim 110 | 111 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 112 | self.proj_in = nn.Linear(in_channels, inner_dim) 113 | 114 | self.transformer_blocks = nn.ModuleList( 115 | [ 116 | TemporalTransformerBlock( 117 | dim=inner_dim, 118 | num_attention_heads=num_attention_heads, 119 | attention_head_dim=attention_head_dim, 120 | attention_block_types=attention_block_types, 121 | dropout=dropout, 122 | norm_num_groups=norm_num_groups, 123 | cross_attention_dim=cross_attention_dim, 124 | activation_fn=activation_fn, 125 | attention_bias=attention_bias, 126 | upcast_attention=upcast_attention, 127 | cross_frame_attention_mode=cross_frame_attention_mode, 128 | temporal_position_encoding=temporal_position_encoding, 129 | 
temporal_position_encoding_max_len=temporal_position_encoding_max_len, 130 | ) 131 | for d in range(num_layers) 132 | ] 133 | ) 134 | self.proj_out = nn.Linear(inner_dim, in_channels) 135 | 136 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): 137 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 138 | video_length = hidden_states.shape[2] 139 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 140 | 141 | batch, channel, height, weight = hidden_states.shape 142 | residual = hidden_states 143 | 144 | hidden_states = self.norm(hidden_states) 145 | inner_dim = hidden_states.shape[1] 146 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 147 | hidden_states = self.proj_in(hidden_states) 148 | 149 | # Transformer Blocks 150 | for block in self.transformer_blocks: 151 | hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length) 152 | 153 | # output 154 | hidden_states = self.proj_out(hidden_states) 155 | hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 156 | 157 | output = hidden_states + residual 158 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 159 | 160 | return output 161 | 162 | 163 | class TemporalTransformerBlock(nn.Module): 164 | def __init__( 165 | self, 166 | dim, 167 | num_attention_heads, 168 | attention_head_dim, 169 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ), 170 | dropout = 0.0, 171 | norm_num_groups = 32, 172 | cross_attention_dim = 768, 173 | activation_fn = "geglu", 174 | attention_bias = False, 175 | upcast_attention = False, 176 | cross_frame_attention_mode = None, 177 | temporal_position_encoding = False, 178 | temporal_position_encoding_max_len = 24, 179 | ): 180 | super().__init__() 181 | 182 | attention_blocks = [] 183 | norms = [] 184 | 185 | for block_name in attention_block_types: 186 | attention_blocks.append( 187 | VersatileAttention( 188 | attention_mode=block_name.split("_")[0], 189 | cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, 190 | 191 | query_dim=dim, 192 | heads=num_attention_heads, 193 | dim_head=attention_head_dim, 194 | dropout=dropout, 195 | bias=attention_bias, 196 | upcast_attention=upcast_attention, 197 | 198 | cross_frame_attention_mode=cross_frame_attention_mode, 199 | temporal_position_encoding=temporal_position_encoding, 200 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 201 | ) 202 | ) 203 | norms.append(nn.LayerNorm(dim)) 204 | 205 | self.attention_blocks = nn.ModuleList(attention_blocks) 206 | self.norms = nn.ModuleList(norms) 207 | 208 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 209 | self.ff_norm = nn.LayerNorm(dim) 210 | 211 | 212 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 213 | for attention_block, norm in zip(self.attention_blocks, self.norms): 214 | norm_hidden_states = norm(hidden_states) 215 | hidden_states = attention_block( 216 | norm_hidden_states, 217 | encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, 218 | video_length=video_length, 219 | ) + hidden_states 220 | 221 | hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states 222 | 223 | output = hidden_states 224 | return output 225 | 226 | 227 | class 
PositionalEncoding(nn.Module): 228 | def __init__( 229 | self, 230 | d_model, 231 | dropout = 0., 232 | max_len = 24 233 | ): 234 | super().__init__() 235 | self.dropout = nn.Dropout(p=dropout) 236 | position = torch.arange(max_len).unsqueeze(1) 237 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 238 | pe = torch.zeros(1, max_len, d_model) 239 | pe[0, :, 0::2] = torch.sin(position * div_term) 240 | pe[0, :, 1::2] = torch.cos(position * div_term) 241 | self.register_buffer('pe', pe, persistent=False) 242 | 243 | def forward(self, x): 244 | x = x + self.pe[:, :x.size(1)] 245 | return self.dropout(x) 246 | 247 | 248 | class VersatileAttention(CrossAttention): 249 | def __init__( 250 | self, 251 | attention_mode = None, 252 | cross_frame_attention_mode = None, 253 | temporal_position_encoding = False, 254 | temporal_position_encoding_max_len = 32, 255 | *args, **kwargs 256 | ): 257 | super().__init__(*args, **kwargs) 258 | assert attention_mode == "Temporal" 259 | 260 | self.attention_mode = attention_mode 261 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None 262 | 263 | self.pos_encoder = PositionalEncoding( 264 | kwargs["query_dim"], 265 | dropout=0., 266 | max_len=temporal_position_encoding_max_len 267 | ) if (temporal_position_encoding and attention_mode == "Temporal") else None 268 | 269 | def extra_repr(self): 270 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" 271 | 272 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 273 | batch_size, sequence_length, _ = hidden_states.shape 274 | 275 | if self.attention_mode == "Temporal": 276 | d = hidden_states.shape[1] 277 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 278 | 279 | if self.pos_encoder is not None: 280 | hidden_states = self.pos_encoder(hidden_states) 281 | 282 | encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states 283 | else: 284 | raise NotImplementedError 285 | 286 | encoder_hidden_states = encoder_hidden_states 287 | 288 | if self.group_norm is not None: 289 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 290 | 291 | query = self.to_q(hidden_states) 292 | dim = query.shape[-1] 293 | query = self.reshape_heads_to_batch_dim(query) 294 | 295 | if self.added_kv_proj_dim is not None: 296 | raise NotImplementedError 297 | 298 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 299 | key = self.to_k(encoder_hidden_states) 300 | value = self.to_v(encoder_hidden_states) 301 | 302 | key = self.reshape_heads_to_batch_dim(key) 303 | value = self.reshape_heads_to_batch_dim(value) 304 | 305 | if attention_mask is not None: 306 | if attention_mask.shape[-1] != query.shape[1]: 307 | target_length = query.shape[1] 308 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 309 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 310 | 311 | # attention, what we cannot get enough of 312 | if self._use_memory_efficient_attention_xformers: 313 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) 314 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input 315 | hidden_states = hidden_states.to(query.dtype) 316 | else: 317 | if self._slice_size is None or 
query.shape[0] // self._slice_size == 1: 318 | hidden_states = self._attention(query, key, value, attention_mask) 319 | else: 320 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) 321 | 322 | # linear proj 323 | hidden_states = self.to_out[0](hidden_states) 324 | 325 | # dropout 326 | hidden_states = self.to_out[1](hidden_states) 327 | 328 | if self.attention_mode == "Temporal": 329 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 330 | 331 | return hidden_states 332 | -------------------------------------------------------------------------------- /animatediff/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange 8 | 9 | 10 | class InflatedConv3d(nn.Conv2d): 11 | def forward(self, x): 12 | video_length = x.shape[2] 13 | 14 | x = rearrange(x, "b c f h w -> (b f) c h w") 15 | x = super().forward(x) 16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 17 | 18 | return x 19 | 20 | 21 | class InflatedGroupNorm(nn.GroupNorm): 22 | def forward(self, x): 23 | video_length = x.shape[2] 24 | 25 | x = rearrange(x, "b c f h w -> (b f) c h w") 26 | x = super().forward(x) 27 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 28 | 29 | return x 30 | 31 | 32 | class Upsample3D(nn.Module): 33 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 34 | super().__init__() 35 | self.channels = channels 36 | self.out_channels = out_channels or channels 37 | self.use_conv = use_conv 38 | self.use_conv_transpose = use_conv_transpose 39 | self.name = name 40 | 41 | conv = None 42 | if use_conv_transpose: 43 | raise NotImplementedError 44 | elif use_conv: 45 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 46 | 47 | def forward(self, hidden_states, output_size=None): 48 | assert hidden_states.shape[1] == self.channels 49 | 50 | if self.use_conv_transpose: 51 | raise NotImplementedError 52 | 53 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 54 | dtype = hidden_states.dtype 55 | if dtype == torch.bfloat16: 56 | hidden_states = hidden_states.to(torch.float32) 57 | 58 | # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 59 | if hidden_states.shape[0] >= 64: 60 | hidden_states = hidden_states.contiguous() 61 | 62 | # if `output_size` is passed we force the interpolation output 63 | # size and do not make use of `scale_factor=2` 64 | if output_size is None: 65 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 66 | else: 67 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 68 | 69 | # If the input is bfloat16, we cast back to bfloat16 70 | if dtype == torch.bfloat16: 71 | hidden_states = hidden_states.to(dtype) 72 | 73 | # if self.use_conv: 74 | # if self.name == "conv": 75 | # hidden_states = self.conv(hidden_states) 76 | # else: 77 | # hidden_states = self.Conv2d_0(hidden_states) 78 | hidden_states = self.conv(hidden_states) 79 | 80 | return hidden_states 81 | 82 | 83 | class Downsample3D(nn.Module): 84 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 85 | super().__init__() 86 | self.channels = channels 87 | self.out_channels = out_channels or channels 88 | self.use_conv = use_conv 89 | self.padding = padding 90 | stride = 2 91 | self.name = name 92 | 93 | if use_conv: 94 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 95 | else: 96 | raise NotImplementedError 97 | 98 | def forward(self, hidden_states): 99 | assert hidden_states.shape[1] == self.channels 100 | if self.use_conv and self.padding == 0: 101 | raise NotImplementedError 102 | 103 | assert hidden_states.shape[1] == self.channels 104 | hidden_states = self.conv(hidden_states) 105 | 106 | return hidden_states 107 | 108 | 109 | class ResnetBlock3D(nn.Module): 110 | def __init__( 111 | self, 112 | *, 113 | in_channels, 114 | out_channels=None, 115 | conv_shortcut=False, 116 | dropout=0.0, 117 | temb_channels=512, 118 | groups=32, 119 | groups_out=None, 120 | pre_norm=True, 121 | eps=1e-6, 122 | non_linearity="swish", 123 | time_embedding_norm="default", 124 | output_scale_factor=1.0, 125 | use_in_shortcut=None, 126 | use_inflated_groupnorm=False, 127 | ): 128 | super().__init__() 129 | self.pre_norm = pre_norm 130 | self.pre_norm = True 131 | self.in_channels = in_channels 132 | out_channels = in_channels if out_channels is None else out_channels 133 | self.out_channels = out_channels 134 | self.use_conv_shortcut = conv_shortcut 135 | self.time_embedding_norm = time_embedding_norm 136 | self.output_scale_factor = output_scale_factor 137 | 138 | if groups_out is None: 139 | groups_out = groups 140 | 141 | assert use_inflated_groupnorm != None 142 | if use_inflated_groupnorm: 143 | self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 144 | else: 145 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 146 | 147 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 148 | 149 | if temb_channels is not None: 150 | if self.time_embedding_norm == "default": 151 | time_emb_proj_out_channels = out_channels 152 | elif self.time_embedding_norm == "scale_shift": 153 | time_emb_proj_out_channels = out_channels * 2 154 | else: 155 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 156 | 157 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 158 | else: 159 | self.time_emb_proj = None 160 | 161 | if use_inflated_groupnorm: 162 | self.norm2 = 
InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 163 | else: 164 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 165 | 166 | self.dropout = torch.nn.Dropout(dropout) 167 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 168 | 169 | if non_linearity == "swish": 170 | self.nonlinearity = lambda x: F.silu(x) 171 | elif non_linearity == "mish": 172 | self.nonlinearity = Mish() 173 | elif non_linearity == "silu": 174 | self.nonlinearity = nn.SiLU() 175 | 176 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 177 | 178 | self.conv_shortcut = None 179 | if self.use_in_shortcut: 180 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 181 | 182 | def forward(self, input_tensor, temb): 183 | hidden_states = input_tensor 184 | 185 | hidden_states = self.norm1(hidden_states) 186 | hidden_states = self.nonlinearity(hidden_states) 187 | 188 | hidden_states = self.conv1(hidden_states) 189 | 190 | if temb is not None: 191 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 192 | 193 | if temb is not None and self.time_embedding_norm == "default": 194 | hidden_states = hidden_states + temb 195 | 196 | hidden_states = self.norm2(hidden_states) 197 | 198 | if temb is not None and self.time_embedding_norm == "scale_shift": 199 | scale, shift = torch.chunk(temb, 2, dim=1) 200 | hidden_states = hidden_states * (1 + scale) + shift 201 | 202 | hidden_states = self.nonlinearity(hidden_states) 203 | 204 | hidden_states = self.dropout(hidden_states) 205 | hidden_states = self.conv2(hidden_states) 206 | 207 | if self.conv_shortcut is not None: 208 | input_tensor = self.conv_shortcut(input_tensor) 209 | 210 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 211 | 212 | return output_tensor 213 | 214 | 215 | class Mish(torch.nn.Module): 216 | def forward(self, hidden_states): 217 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) -------------------------------------------------------------------------------- /animatediff/utils/convert_lora_safetensor_to_diffusers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | # Changes were made to this source code by Yuwei Guo. 17 | """ Conversion script for the LoRA's safetensors checkpoints. """ 18 | 19 | import argparse 20 | 21 | import torch 22 | from safetensors.torch import load_file 23 | 24 | from diffusers import StableDiffusionPipeline 25 | 26 | 27 | def load_diffusers_lora(pipeline, state_dict, alpha=1.0): 28 | # directly update weight in diffusers model 29 | for key in state_dict: 30 | # only process lora down key 31 | if "up." 
in key: continue 32 | 33 | up_key = key.replace(".down.", ".up.") 34 | model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") 35 | model_key = model_key.replace("to_out.", "to_out.0.") 36 | layer_infos = model_key.split(".")[:-1] 37 | 38 | curr_layer = pipeline.unet 39 | while len(layer_infos) > 0: 40 | temp_name = layer_infos.pop(0) 41 | curr_layer = curr_layer.__getattr__(temp_name) 42 | 43 | weight_down = state_dict[key] 44 | weight_up = state_dict[up_key] 45 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 46 | 47 | return pipeline 48 | 49 | 50 | def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): 51 | # load base model 52 | # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) 53 | 54 | # load LoRA weight from .safetensors 55 | # state_dict = load_file(checkpoint_path) 56 | 57 | visited = [] 58 | 59 | # directly update weight in diffusers model 60 | for key in state_dict: 61 | # it is suggested to print out the key, it usually will be something like below 62 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 63 | 64 | # as we have set the alpha beforehand, so just skip 65 | if ".alpha" in key or key in visited: 66 | continue 67 | 68 | if "text" in key: 69 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") 70 | curr_layer = pipeline.text_encoder 71 | else: 72 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") 73 | curr_layer = pipeline.unet 74 | 75 | # find the target layer 76 | temp_name = layer_infos.pop(0) 77 | while len(layer_infos) > -1: 78 | try: 79 | curr_layer = curr_layer.__getattr__(temp_name) 80 | if len(layer_infos) > 0: 81 | temp_name = layer_infos.pop(0) 82 | elif len(layer_infos) == 0: 83 | break 84 | except Exception: 85 | if len(temp_name) > 0: 86 | temp_name += "_" + layer_infos.pop(0) 87 | else: 88 | temp_name = layer_infos.pop(0) 89 | 90 | pair_keys = [] 91 | if "lora_down" in key: 92 | pair_keys.append(key.replace("lora_down", "lora_up")) 93 | pair_keys.append(key) 94 | else: 95 | pair_keys.append(key) 96 | pair_keys.append(key.replace("lora_up", "lora_down")) 97 | 98 | # update weight 99 | if len(state_dict[pair_keys[0]].shape) == 4: 100 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) 101 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) 102 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) 103 | else: 104 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 105 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 106 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 107 | 108 | # update visited list 109 | for item in pair_keys: 110 | visited.append(item) 111 | 112 | return pipeline 113 | 114 | 115 | if __name__ == "__main__": 116 | parser = argparse.ArgumentParser() 117 | 118 | parser.add_argument( 119 | "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." 120 | ) 121 | parser.add_argument( 122 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 
123 | ) 124 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") 125 | parser.add_argument( 126 | "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" 127 | ) 128 | parser.add_argument( 129 | "--lora_prefix_text_encoder", 130 | default="lora_te", 131 | type=str, 132 | help="The prefix of text encoder weight in safetensors", 133 | ) 134 | parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") 135 | parser.add_argument( 136 | "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." 137 | ) 138 | parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") 139 | 140 | args = parser.parse_args() 141 | 142 | base_model_path = args.base_model_path 143 | checkpoint_path = args.checkpoint_path 144 | dump_path = args.dump_path 145 | lora_prefix_unet = args.lora_prefix_unet 146 | lora_prefix_text_encoder = args.lora_prefix_text_encoder 147 | alpha = args.alpha 148 | 149 | pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) 150 | 151 | pipe = pipe.to(args.device) 152 | pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) 153 | -------------------------------------------------------------------------------- /animatediff/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | from typing import Union 5 | 6 | import torch 7 | import torchvision 8 | import torch.distributed as dist 9 | 10 | from huggingface_hub import snapshot_download 11 | from safetensors import safe_open 12 | from tqdm import tqdm 13 | from einops import rearrange 14 | from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint 15 | from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora 16 | 17 | 18 | MOTION_MODULES = [ 19 | "mm_sd_v14.ckpt", 20 | "mm_sd_v15.ckpt", 21 | "mm_sd_v15_v2.ckpt", 22 | "v3_sd15_mm.ckpt", 23 | ] 24 | 25 | ADAPTERS = [ 26 | # "mm_sd_v14.ckpt", 27 | # "mm_sd_v15.ckpt", 28 | # "mm_sd_v15_v2.ckpt", 29 | # "mm_sdxl_v10_beta.ckpt", 30 | "v2_lora_PanLeft.ckpt", 31 | "v2_lora_PanRight.ckpt", 32 | "v2_lora_RollingAnticlockwise.ckpt", 33 | "v2_lora_RollingClockwise.ckpt", 34 | "v2_lora_TiltDown.ckpt", 35 | "v2_lora_TiltUp.ckpt", 36 | "v2_lora_ZoomIn.ckpt", 37 | "v2_lora_ZoomOut.ckpt", 38 | "v3_sd15_adapter.ckpt", 39 | # "v3_sd15_mm.ckpt", 40 | "v3_sd15_sparsectrl_rgb.ckpt", 41 | "v3_sd15_sparsectrl_scribble.ckpt", 42 | ] 43 | 44 | BACKUP_DREAMBOOTH_MODELS = [ 45 | "realisticVisionV60B1_v51VAE.safetensors", 46 | "majicmixRealistic_v4.safetensors", 47 | "leosamsFilmgirlUltra_velvia20Lora.safetensors", 48 | "toonyou_beta3.safetensors", 49 | "majicmixRealistic_v5Preview.safetensors", 50 | "rcnzCartoon3d_v10.safetensors", 51 | "lyriel_v16.safetensors", 52 | "leosamsHelloworldXL_filmGrain20.safetensors", 53 | "TUSUN.safetensors", 54 | ] 55 | 56 | 57 | def zero_rank_print(s): 58 | if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) 59 | 60 | 61 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): 62 | videos = rearrange(videos, "b c t h w -> t b c h w") 63 | outputs = [] 64 | for x in videos: 65 | x = 
torchvision.utils.make_grid(x, nrow=n_rows) 66 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 67 | if rescale: 68 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 69 | x = (x * 255).numpy().astype(np.uint8) 70 | outputs.append(x) 71 | 72 | os.makedirs(os.path.dirname(path), exist_ok=True) 73 | imageio.mimsave(path, outputs, fps=fps) 74 | 75 | 76 | def auto_download(local_path, is_dreambooth_lora=False): 77 | hf_repo = "guoyww/animatediff_t2i_backups" if is_dreambooth_lora else "guoyww/animatediff" 78 | folder, filename = os.path.split(local_path) 79 | 80 | if not os.path.exists(local_path): 81 | print(f"local file {local_path} does not exist. trying to download from {hf_repo}") 82 | 83 | if is_dreambooth_lora: assert filename in BACKUP_DREAMBOOTH_MODELS, f"{filename} dose not exist in {hf_repo}" 84 | else: assert filename in MOTION_MODULES + ADAPTERS, f"{filename} dose not exist in {hf_repo}" 85 | 86 | folder = "." if folder == "" else folder 87 | os.makedirs(folder, exist_ok=True) 88 | snapshot_download(repo_id=hf_repo, local_dir=folder, allow_patterns=[filename]) 89 | 90 | 91 | def load_weights( 92 | animation_pipeline, 93 | # motion module 94 | motion_module_path = "", 95 | motion_module_lora_configs = [], 96 | # domain adapter 97 | adapter_lora_path = "", 98 | adapter_lora_scale = 1.0, 99 | # image layers 100 | dreambooth_model_path = "", 101 | lora_model_path = "", 102 | lora_alpha = 0.8, 103 | ): 104 | # motion module 105 | unet_state_dict = {} 106 | if motion_module_path != "": 107 | auto_download(motion_module_path, is_dreambooth_lora=False) 108 | 109 | print(f"load motion module from {motion_module_path}") 110 | motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") 111 | motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict 112 | # filter parameters 113 | for name, param in motion_module_state_dict.items(): 114 | if not "motion_modules." in name: continue 115 | if "pos_encoder.pe" in name: continue 116 | unet_state_dict.update({name: param}) 117 | unet_state_dict.pop("animatediff_config", "") 118 | 119 | missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False) 120 | assert len(unexpected) == 0 121 | del unet_state_dict 122 | 123 | # base model 124 | if dreambooth_model_path != "": 125 | auto_download(dreambooth_model_path, is_dreambooth_lora=True) 126 | 127 | print(f"load dreambooth model from {dreambooth_model_path}") 128 | if dreambooth_model_path.endswith(".safetensors"): 129 | dreambooth_state_dict = {} 130 | with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f: 131 | for key in f.keys(): 132 | dreambooth_state_dict[key] = f.get_tensor(key) 133 | elif dreambooth_model_path.endswith(".ckpt"): 134 | dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu") 135 | 136 | # 1. vae 137 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config) 138 | animation_pipeline.vae.load_state_dict(converted_vae_checkpoint) 139 | # 2. unet 140 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config) 141 | animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) 142 | # 3. 
text_model 143 | animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict) 144 | del dreambooth_state_dict 145 | 146 | # lora layers 147 | if lora_model_path != "": 148 | auto_download(lora_model_path, is_dreambooth_lora=True) 149 | 150 | print(f"load lora model from {lora_model_path}") 151 | assert lora_model_path.endswith(".safetensors") 152 | lora_state_dict = {} 153 | with safe_open(lora_model_path, framework="pt", device="cpu") as f: 154 | for key in f.keys(): 155 | lora_state_dict[key] = f.get_tensor(key) 156 | 157 | animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha) 158 | del lora_state_dict 159 | 160 | # domain adapter lora 161 | if adapter_lora_path != "": 162 | auto_download(adapter_lora_path, is_dreambooth_lora=False) 163 | 164 | print(f"load domain lora from {adapter_lora_path}") 165 | domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu") 166 | domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict 167 | domain_lora_state_dict.pop("animatediff_config", "") 168 | 169 | animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale) 170 | 171 | # motion module lora 172 | for motion_module_lora_config in motion_module_lora_configs: 173 | path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"] 174 | 175 | auto_download(path, is_dreambooth_lora=False) 176 | 177 | print(f"load motion LoRA from {path}") 178 | motion_lora_state_dict = torch.load(path, map_location="cpu") 179 | motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict 180 | motion_lora_state_dict.pop("animatediff_config", "") 181 | 182 | animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha) 183 | 184 | return animation_pipeline 185 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import torch 5 | import random 6 | 7 | import gradio as gr 8 | from glob import glob 9 | from omegaconf import OmegaConf 10 | from datetime import datetime 11 | from safetensors import safe_open 12 | 13 | from diffusers import AutoencoderKL 14 | from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler 15 | from diffusers.utils.import_utils import is_xformers_available 16 | from transformers import CLIPTextModel, CLIPTokenizer 17 | 18 | from animatediff.models.unet import UNet3DConditionModel 19 | from animatediff.pipelines.pipeline_animation import AnimationPipeline 20 | from animatediff.utils.util import save_videos_grid, load_weights, auto_download, MOTION_MODULES, BACKUP_DREAMBOOTH_MODELS 21 | from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint 22 | from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora 23 | import pdb 24 | 25 | 26 | sample_idx = 0 27 | scheduler_dict = { 28 | "DDIM": DDIMScheduler, 29 | "Euler": EulerDiscreteScheduler, 30 | "PNDM": PNDMScheduler, 31 | } 32 | 33 | css = """ 34 | .toolbutton { 35 | margin-buttom: 0em 0em 0em 0em; 36 | max-width: 2.5em; 37 | min-width: 2.5em !important; 38 | height: 2.5em; 39 | } 40 | """ 41 | 42 | PRETRAINED_SD = "runwayml/stable-diffusion-v1-5" 43 | 44 | 
default_motion_module = "v3_sd15_mm.ckpt" 45 | default_inference_config = "configs/inference/inference-v3.yaml" 46 | default_dreambooth_model = "realisticVisionV60B1_v51VAE.safetensors" 47 | default_prompt = "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" 48 | default_n_prompt = "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck" 49 | default_seed = 8893659352891878017 50 | 51 | device = "cuda" if torch.cuda.is_available() else "cpu" 52 | 53 | 54 | class AnimateController: 55 | def __init__(self): 56 | # config dirs 57 | self.basedir = os.getcwd() 58 | self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion") 59 | self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module") 60 | self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA") 61 | self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")) 62 | self.savedir_sample = os.path.join(self.savedir, "sample") 63 | os.makedirs(self.savedir, exist_ok=True) 64 | 65 | self.stable_diffusion_list = [PRETRAINED_SD] 66 | self.motion_module_list = MOTION_MODULES 67 | self.personalized_model_list = BACKUP_DREAMBOOTH_MODELS 68 | 69 | # config models 70 | self.pipeline = None 71 | # self.lora_model_state_dict = {} 72 | 73 | self.refresh_stable_diffusion() 74 | self.refresh_personalized_model() 75 | 76 | # default setting 77 | self.update_pipeline( 78 | stable_diffusion_dropdown=PRETRAINED_SD, 79 | motion_module_dropdown=default_motion_module, 80 | base_model_dropdown=default_dreambooth_model, 81 | sampler_dropdown="DDIM", 82 | ) 83 | 84 | def refresh_stable_diffusion(self): 85 | self.stable_diffusion_list = [PRETRAINED_SD] + glob(os.path.join(self.stable_diffusion_dir, "*/")) 86 | 87 | def refresh_personalized_model(self): 88 | personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors")) 89 | self.personalized_model_list = BACKUP_DREAMBOOTH_MODELS + [os.path.basename(p) for p in personalized_model_list if os.path.basename(p) not in BACKUP_DREAMBOOTH_MODELS] 90 | 91 | # for dropdown update 92 | def update_pipeline( 93 | self, 94 | stable_diffusion_dropdown, 95 | motion_module_dropdown, 96 | base_model_dropdown="", 97 | lora_model_dropdown="none", 98 | lora_alpha_dropdown="0.6", 99 | sampler_dropdown="DDIM", 100 | ): 101 | if "v2" in motion_module_dropdown: 102 | inference_config = "configs/inference/inference-v2.yaml" 103 | elif "v3" in motion_module_dropdown: 104 | inference_config = "configs/inference/inference-v3.yaml" 105 | else: 106 | inference_config = "configs/inference/inference-v1.yaml" 107 | 108 | unet = UNet3DConditionModel.from_pretrained_2d( 109 | stable_diffusion_dropdown, subfolder="unet", 110 | unet_additional_kwargs=OmegaConf.load(inference_config).unet_additional_kwargs 111 | ) 112 | if is_xformers_available() and torch.cuda.is_available(): 113 | 
unet.enable_xformers_memory_efficient_attention() 114 | 115 | noise_scheduler_cls = scheduler_dict[sampler_dropdown] 116 | noise_scheduler_kwargs = OmegaConf.load(inference_config).noise_scheduler_kwargs 117 | if noise_scheduler_cls == EulerDiscreteScheduler: 118 | noise_scheduler_kwargs.pop("steps_offset") 119 | noise_scheduler_kwargs.pop("clip_sample") 120 | elif noise_scheduler_cls == PNDMScheduler: 121 | noise_scheduler_kwargs.pop("clip_sample") 122 | 123 | pipeline = AnimationPipeline( 124 | unet=unet, 125 | vae=AutoencoderKL.from_pretrained(stable_diffusion_dropdown, subfolder="vae"), 126 | text_encoder=CLIPTextModel.from_pretrained(stable_diffusion_dropdown, subfolder="text_encoder"), 127 | tokenizer=CLIPTokenizer.from_pretrained(stable_diffusion_dropdown, subfolder="tokenizer"), 128 | scheduler=noise_scheduler_cls(**noise_scheduler_kwargs), 129 | ) 130 | 131 | pipeline = load_weights( 132 | pipeline, 133 | motion_module_path=os.path.join(self.motion_module_dir, motion_module_dropdown), 134 | dreambooth_model_path=os.path.join(self.personalized_model_dir, base_model_dropdown) if base_model_dropdown != "" else "", 135 | lora_model_path=os.path.join(self.personalized_model_dir, lora_model_dropdown) if lora_model_dropdown != "none" else "", 136 | lora_alpha=float(lora_alpha_dropdown), 137 | ) 138 | 139 | pipeline.to(device) 140 | self.pipeline = pipeline 141 | print("done.") 142 | 143 | return gr.Dropdown.update() 144 | 145 | def update_pipeline_alpha( 146 | self, 147 | stable_diffusion_dropdown, 148 | motion_module_dropdown, 149 | base_model_dropdown="", 150 | lora_model_dropdown="none", 151 | lora_alpha_dropdown="0.6", 152 | sampler_dropdown="DDIM", 153 | ): 154 | if lora_model_dropdown == "none": 155 | return gr.Slider.update() 156 | 157 | self.update_pipeline( 158 | stable_diffusion_dropdown=stable_diffusion_dropdown, 159 | motion_module_dropdown=motion_module_dropdown, 160 | base_model_dropdown=base_model_dropdown, 161 | lora_model_dropdown=lora_model_dropdown, 162 | lora_alpha_dropdown=lora_alpha_dropdown, 163 | sampler_dropdown=sampler_dropdown, 164 | ) 165 | 166 | return gr.Slider.update() 167 | 168 | 169 | @torch.no_grad() 170 | def animate( 171 | self, 172 | prompt_textbox, 173 | negative_prompt_textbox, 174 | sampler_dropdown, 175 | sample_step_slider, 176 | width_slider, 177 | length_slider, 178 | height_slider, 179 | cfg_scale_slider, 180 | seed_textbox, 181 | ): 182 | if int(seed_textbox) != -1: 183 | torch.manual_seed(int(seed_textbox)) 184 | else: 185 | torch.seed() 186 | seed = torch.initial_seed() 187 | 188 | sample = self.pipeline( 189 | prompt_textbox, 190 | negative_prompt = negative_prompt_textbox, 191 | num_inference_steps = sample_step_slider, 192 | guidance_scale = cfg_scale_slider, 193 | width = width_slider, 194 | height = height_slider, 195 | video_length = length_slider, 196 | ).videos 197 | 198 | save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4") 199 | save_videos_grid(sample, save_sample_path) 200 | 201 | sample_config = { 202 | "prompt": prompt_textbox, 203 | "n_prompt": negative_prompt_textbox, 204 | "sampler": sampler_dropdown, 205 | "num_inference_steps": sample_step_slider, 206 | "guidance_scale": cfg_scale_slider, 207 | "width": width_slider, 208 | "height": height_slider, 209 | "video_length": length_slider, 210 | "seed": seed 211 | } 212 | 213 | json_str = json.dumps(sample_config, indent=4) 214 | with open(os.path.join(self.savedir, "logs.json"), "a") as f: 215 | f.write(json_str) 216 | f.write("\n\n") 217 | 218 | return 
gr.Video.update(value=save_sample_path) 219 | 220 | 221 | controller = AnimateController() 222 | 223 | 224 | def ui(): 225 | with gr.Blocks(css=css) as demo: 226 | gr.Markdown( 227 | """ 228 | # AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning 229 | Yuwei Guo, Ceyuan Yang✝, Anyi Rao, Zhengyang Liang, Yaohui Wang, Yu Qiao, Maneesh Agrawala, Dahua Lin, Bo Dai (✝Corresponding Author)
230 | [Paper](https://arxiv.org/abs/2307.04725) | [Webpage](https://animatediff.github.io/) | [Github](https://github.com/guoyww/animatediff/) 231 | """ 232 | ) 233 | with gr.Column(variant="panel"): 234 | gr.Markdown( 235 | """ 236 | ### 1. Model Checkpoints 237 | """ 238 | ) 239 | with gr.Row(): 240 | stable_diffusion_dropdown = gr.Dropdown( 241 | label="Pretrained Model Path", 242 | choices=controller.stable_diffusion_list, 243 | value=PRETRAINED_SD, 244 | interactive=True, 245 | ) 246 | 247 | with gr.Row(): 248 | motion_module_dropdown = gr.Dropdown( 249 | label="Select motion module", 250 | choices=controller.motion_module_list, 251 | value=default_motion_module, 252 | interactive=True, 253 | ) 254 | 255 | base_model_dropdown = gr.Dropdown( 256 | label="Select base Dreambooth model (required)", 257 | choices=controller.personalized_model_list, 258 | value=default_dreambooth_model, 259 | interactive=True, 260 | ) 261 | 262 | lora_model_dropdown = gr.Dropdown( 263 | label="Select LoRA model (optional)", 264 | choices=["none"] + controller.personalized_model_list, 265 | value="none", 266 | interactive=True, 267 | ) 268 | 269 | lora_alpha_dropdown = gr.Dropdown( 270 | label="LoRA alpha", 271 | choices=["0.", "0.2", "0.4", "0.6", "0.8", "1.0"], 272 | value="0.6", 273 | interactive=True, 274 | ) 275 | 276 | personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton") 277 | def update_personalized_model(): 278 | controller.refresh_stable_diffusion() 279 | controller.refresh_personalized_model() 280 | return [ 281 | gr.Dropdown.update(choices=controller.stable_diffusion_list), 282 | gr.Dropdown.update(choices=controller.personalized_model_list), 283 | gr.Dropdown.update(choices=["none"] + controller.personalized_model_list) 284 | ] 285 | personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[stable_diffusion_dropdown, base_model_dropdown, lora_model_dropdown]) 286 | 287 | with gr.Column(variant="panel"): 288 | gr.Markdown( 289 | """ 290 | ### 2. Configs for AnimateDiff. 
291 | """ 292 | ) 293 | prompt_textbox = gr.Textbox(label="Prompt", lines=2, value=default_prompt) 294 | negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value=default_n_prompt) 295 | 296 | with gr.Row().style(equal_height=False): 297 | with gr.Column(): 298 | with gr.Row(): 299 | sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0]) 300 | sample_step_slider = gr.Slider(label="Sampling steps", value=25, minimum=10, maximum=100, step=1) 301 | 302 | width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64) 303 | height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64) 304 | length_slider = gr.Slider(label="Animation length (default: 16)", value=16, minimum=8, maximum=24, step=1) 305 | cfg_scale_slider = gr.Slider(label="CFG Scale", value=8.0, minimum=0, maximum=20) 306 | 307 | with gr.Row(): 308 | seed_textbox = gr.Textbox(label="Seed (-1 for random seed)", value=default_seed) 309 | seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton") 310 | seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox]) 311 | 312 | generate_button = gr.Button(value="Generate", variant='primary') 313 | 314 | result_video = gr.Video(label="Generated Animation", interactive=False) 315 | 316 | # update method 317 | stable_diffusion_dropdown.change(fn=controller.update_pipeline, inputs=[stable_diffusion_dropdown, motion_module_dropdown, base_model_dropdown, lora_model_dropdown, lora_alpha_dropdown, sampler_dropdown], outputs=[stable_diffusion_dropdown]) 318 | motion_module_dropdown.change(fn=controller.update_pipeline, inputs=[stable_diffusion_dropdown, motion_module_dropdown, base_model_dropdown, lora_model_dropdown, lora_alpha_dropdown, sampler_dropdown], outputs=[motion_module_dropdown]) 319 | base_model_dropdown.change(fn=controller.update_pipeline, inputs=[stable_diffusion_dropdown, motion_module_dropdown, base_model_dropdown, lora_model_dropdown, lora_alpha_dropdown, sampler_dropdown], outputs=[base_model_dropdown]) 320 | lora_model_dropdown.change(fn=controller.update_pipeline, inputs=[stable_diffusion_dropdown, motion_module_dropdown, base_model_dropdown, lora_model_dropdown, lora_alpha_dropdown, sampler_dropdown], outputs=[lora_model_dropdown]) 321 | lora_alpha_dropdown.change(fn=controller.update_pipeline_alpha, inputs=[stable_diffusion_dropdown, motion_module_dropdown, base_model_dropdown, lora_model_dropdown, lora_alpha_dropdown, sampler_dropdown], outputs=[lora_alpha_dropdown]) 322 | 323 | generate_button.click( 324 | fn=controller.animate, 325 | inputs=[ 326 | prompt_textbox, 327 | negative_prompt_textbox, 328 | sampler_dropdown, 329 | sample_step_slider, 330 | width_slider, 331 | length_slider, 332 | height_slider, 333 | cfg_scale_slider, 334 | seed_textbox, 335 | ], 336 | outputs=[result_video] 337 | ) 338 | 339 | return demo 340 | 341 | 342 | if __name__ == "__main__": 343 | demo = ui() 344 | demo.launch(share=True) 345 | -------------------------------------------------------------------------------- /configs/inference/inference-v1.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | use_inflated_groupnorm: false 3 | use_motion_module: true 4 | motion_module_resolutions: [1,2,4,8] 5 | motion_module_mid_block: false 6 | motion_module_type: "Vanilla" 7 | 8 | motion_module_kwargs: 9 | num_attention_heads: 8 10 
| num_transformer_block: 1 11 | attention_block_types: [ "Temporal_Self", "Temporal_Self" ] 12 | temporal_position_encoding: true 13 | temporal_attention_dim_div: 1 14 | zero_initialize: true 15 | 16 | noise_scheduler_kwargs: 17 | beta_start: 0.00085 18 | beta_end: 0.012 19 | beta_schedule: "linear" 20 | steps_offset: 1 21 | clip_sample: false 22 | -------------------------------------------------------------------------------- /configs/inference/inference-v2.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | use_inflated_groupnorm: true 3 | use_motion_module: true 4 | motion_module_resolutions: [1,2,4,8] 5 | motion_module_mid_block: true 6 | motion_module_type: "Vanilla" 7 | 8 | motion_module_kwargs: 9 | num_attention_heads: 8 10 | num_transformer_block: 1 11 | attention_block_types: [ "Temporal_Self", "Temporal_Self" ] 12 | temporal_position_encoding: true 13 | temporal_attention_dim_div: 1 14 | zero_initialize: true 15 | 16 | noise_scheduler_kwargs: 17 | beta_start: 0.00085 18 | beta_end: 0.012 19 | beta_schedule: "linear" 20 | steps_offset: 1 21 | clip_sample: false 22 | -------------------------------------------------------------------------------- /configs/inference/inference-v3.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | use_inflated_groupnorm: true 3 | use_motion_module: true 4 | motion_module_resolutions: [1,2,4,8] 5 | motion_module_mid_block: false 6 | motion_module_type: "Vanilla" 7 | 8 | motion_module_kwargs: 9 | num_attention_heads: 8 10 | num_transformer_block: 1 11 | attention_block_types: [ "Temporal_Self", "Temporal_Self" ] 12 | temporal_position_encoding: true 13 | temporal_attention_dim_div: 1 14 | zero_initialize: true 15 | 16 | noise_scheduler_kwargs: 17 | beta_start: 0.00085 18 | beta_end: 0.012 19 | beta_schedule: "linear" 20 | steps_offset: 1 21 | clip_sample: false 22 | -------------------------------------------------------------------------------- /configs/inference/sparsectrl/image_condition.yaml: -------------------------------------------------------------------------------- 1 | controlnet_additional_kwargs: 2 | set_noisy_sample_input_to_zero: true 3 | use_simplified_condition_embedding: false 4 | conditioning_channels: 3 5 | 6 | use_motion_module: true 7 | motion_module_resolutions: [1,2,4,8] 8 | motion_module_mid_block: false 9 | motion_module_type: "Vanilla" 10 | 11 | motion_module_kwargs: 12 | num_attention_heads: 8 13 | num_transformer_block: 1 14 | attention_block_types: [ "Temporal_Self" ] 15 | temporal_position_encoding: true 16 | temporal_position_encoding_max_len: 32 17 | temporal_attention_dim_div: 1 18 | -------------------------------------------------------------------------------- /configs/inference/sparsectrl/latent_condition.yaml: -------------------------------------------------------------------------------- 1 | controlnet_additional_kwargs: 2 | set_noisy_sample_input_to_zero: true 3 | use_simplified_condition_embedding: true 4 | conditioning_channels: 4 5 | 6 | use_motion_module: true 7 | motion_module_resolutions: [1,2,4,8] 8 | motion_module_mid_block: false 9 | motion_module_type: "Vanilla" 10 | 11 | motion_module_kwargs: 12 | num_attention_heads: 8 13 | num_transformer_block: 1 14 | attention_block_types: [ "Temporal_Self" ] 15 | temporal_position_encoding: true 16 | temporal_position_encoding_max_len: 32 17 | temporal_attention_dim_div: 1 18 | 
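The two SparseCtrl condition configs above differ only in how the control signal enters the sparse ControlNet: `image_condition.yaml` takes raw 3-channel images (e.g. scribbles) through the full condition embedding, while `latent_condition.yaml` takes 4-channel VAE latents through a simplified embedding. A minimal sketch of reading such a config with OmegaConf is shown below; the commented-out hand-off to the ControlNet class in `animatediff/models/sparse_controlnet.py` uses assumed names and is not verbatim repository code.

```python
from omegaconf import OmegaConf

# Load one of the SparseCtrl condition configs shown above.
controlnet_config = OmegaConf.load("configs/inference/sparsectrl/latent_condition.yaml")
kwargs = controlnet_config.controlnet_additional_kwargs

print(kwargs.conditioning_channels)               # 4 -> the condition is a VAE latent
print(kwargs.use_simplified_condition_embedding)  # True here; False in image_condition.yaml

# Hypothetical hand-off (class/method names are assumptions; see animatediff/models/sparse_controlnet.py):
# controlnet = SparseControlNetModel.from_unet(unet, controlnet_additional_kwargs=kwargs)
```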
-------------------------------------------------------------------------------- /configs/prompts/1_animate/1_1_animate_RealisticVision.yaml: -------------------------------------------------------------------------------- 1 | # motion module v3 2 | - dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" 3 | lora_model_path: "" 4 | 5 | inference_config: "configs/inference/inference-v3.yaml" 6 | motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" 7 | 8 | seed: [8893659352891878017, 9317678091797131699, 43242532350557906, 4162228652802886667] 9 | steps: 25 10 | guidance_scale: 8 11 | 12 | prompt: 13 | - "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" 14 | - "close up photo of a rabbit, forest, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot" 15 | - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" 16 | - "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain" 17 | 18 | n_prompt: 19 | - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck" 20 | - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck" 21 | - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" 22 | - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, art, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" 23 | 24 | 25 | # motion module v2 26 | - dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" 27 | lora_model_path: "" 28 | 29 | inference_config: "configs/inference/inference-v2.yaml" 30 | motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" 31 | 32 | seed: [8964153601421814582, 10589116295929063558, 13214918285578813247, 3460258020075528001] 33 | steps: 25 34 | guidance_scale: 8 35 | 36 | prompt: 37 | - "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, 
waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" 38 | - "close up photo of a rabbit, forest, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot" 39 | - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" 40 | - "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain" 41 | 42 | n_prompt: 43 | - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck" 44 | - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck" 45 | - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" 46 | - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, art, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" 47 | -------------------------------------------------------------------------------- /configs/prompts/1_animate/1_2_animate_FilmVelvia.yaml: -------------------------------------------------------------------------------- 1 | # motion module v1_14 2 | - dreambooth_path: "models/DreamBooth_LoRA/majicmixRealistic_v4.safetensors" 3 | lora_model_path: "models/DreamBooth_LoRA/leosamsFilmgirlUltra_velvia20Lora.safetensors" 4 | lora_alpha: 0.6 5 | 6 | inference_config: "configs/inference/inference-v3.yaml" 7 | motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" 8 | 9 | seed: [5726977427157971918, 18368660165286593270, 9350384325017735240, 2097615141377450078] 10 | steps: 25 11 | guidance_scale: 8 12 | 13 | prompt: 14 | - "a woman standing on the side of a road at night,girl, long hair, motor vehicle, car, looking at viewer, ground vehicle, night, hands in pockets, blurry background, coat, black hair, parted lips, bokeh, jacket, brown hair, outdoors, red lips, upper body, artist name" 15 | - "dark shot,0mm, portrait quality of a arab man worker,boy, wasteland that stands out vividly against the background of the desert, barren landscape, closeup, moles skin, soft light, sharp, exposure blend, medium shot, bokeh, hdr, high contrast, cinematic, teal and orange5, muted colors, dim colors, soothing tones, low saturation, hyperdetailed, noir" 16 | - 
"fashion photography portrait of 1girl, offshoulder, fluffy short hair, soft light, rim light, beautiful shadow, low key, photorealistic, raw photo, natural skin texture, realistic eye and face details, hyperrealism, ultra high res, 4K, Best quality, masterpiece, necklace, cleavage, in the dark" 17 | - "In this lighthearted portrait, a woman is dressed as a fierce warrior, armed with an arsenal of paintbrushes and palette knives. Her war paint is composed of thick, vibrant strokes of color, and her armor is made of paint tubes and paint-splattered canvases. She stands victoriously atop a mountain of conquered blank canvases, with a beautiful, colorful landscape behind her, symbolizing the power of art and creativity. bust Portrait, close-up, Bright and transparent scene lighting, " 18 | 19 | n_prompt: 20 | - "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" 21 | - "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" 22 | - "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" 23 | - "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" 24 | 25 | 26 | # motion module v1_15 27 | - dreambooth_path: "models/DreamBooth_LoRA/majicmixRealistic_v4.safetensors" 28 | lora_model_path: "models/DreamBooth_LoRA/leosamsFilmgirlUltra_velvia20Lora.safetensors" 29 | lora_alpha: 0.6 30 | 31 | inference_config: "configs/inference/inference-v2.yaml" 32 | motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" 33 | 34 | seed: [2802659149552239028, 12507673598434739425, 1350017671114249824, 2813556755112853775] 35 | steps: 25 36 | guidance_scale: 8 37 | 38 | prompt: 39 | - "a woman standing on the side of a road at night,girl, long hair, motor vehicle, car, looking at viewer, ground vehicle, night, hands in pockets, blurry background, coat, black hair, parted lips, bokeh, jacket, brown hair, outdoors, red lips, upper body, artist name" 40 | - ", dark shot,0mm, portrait quality of a arab man worker,boy, wasteland that stands out vividly against the background of the desert, barren landscape, closeup, moles skin, soft light, sharp, exposure blend, medium shot, bokeh, hdr, high contrast, cinematic, teal and orange5, muted colors, dim colors, soothing tones, low saturation, hyperdetailed, noir" 41 | - "fashion photography portrait of 1girl, offshoulder, fluffy short hair, soft light, rim light, beautiful shadow, low key, photorealistic, raw photo, natural skin texture, realistic eye and face details, 
hyperrealism, ultra high res, 4K, Best quality, masterpiece, necklace, cleavage, in the dark" 42 | - "In this lighthearted portrait, a woman is dressed as a fierce warrior, armed with an arsenal of paintbrushes and palette knives. Her war paint is composed of thick, vibrant strokes of color, and her armor is made of paint tubes and paint-splattered canvases. She stands victoriously atop a mountain of conquered blank canvases, with a beautiful, colorful landscape behind her, symbolizing the power of art and creativity. bust Portrait, close-up, Bright and transparent scene lighting, " 43 | 44 | n_prompt: 45 | - "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" 46 | - "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" 47 | - "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" 48 | - "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" 49 | -------------------------------------------------------------------------------- /configs/prompts/1_animate/1_3_animate_ToonYou.yaml: -------------------------------------------------------------------------------- 1 | # motion module v3 2 | - dreambooth_path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors" 3 | lora_model_path: "" 4 | 5 | inference_config: "configs/inference/inference-v3.yaml" 6 | motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" 7 | 8 | seed: [12192490710448890259, 12238800062118732365, 13226337751639812613, 16431231374396590344] 9 | steps: 25 10 | guidance_scale: 7.5 11 | 12 | prompt: 13 | - "best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress" 14 | - "masterpiece, best quality, 1girl, solo, cherry blossoms, hanami, pink flower, white flower, spring season, wisteria, petals, flower, plum blossoms, outdoors, falling petals, white hair, black eyes," 15 | - "best quality, masterpiece, 1boy, formal, abstract, looking at viewer, masculine, marble pattern" 16 | - "best quality, masterpiece, 1girl, cloudy sky, dandelion, contrapposto, alternate hairstyle," 17 | 18 | n_prompt: 19 | - "worst quality, low quality, letterboxed" 20 | 21 | 22 | # motion module v2 23 | - dreambooth_path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors" 24 | lora_model_path: "" 25 | 26 | inference_config: "configs/inference/inference-v2.yaml" 27 | motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" 28 | 29 | seed: [2362336635702964940, 8149279371559927917, 1487371078234460867, 
17554906328875363976] 30 | steps: 25 31 | guidance_scale: 7.5 32 | 33 | prompt: 34 | - "best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress" 35 | - "masterpiece, best quality, 1girl, solo, cherry blossoms, hanami, pink flower, white flower, spring season, wisteria, petals, flower, plum blossoms, outdoors, falling petals, white hair, black eyes," 36 | - "best quality, masterpiece, 1boy, formal, abstract, looking at viewer, masculine, marble pattern" 37 | - "best quality, masterpiece, 1girl, cloudy sky, dandelion, contrapposto, alternate hairstyle," 38 | 39 | n_prompt: 40 | - "worst quality, low quality, letterboxed" 41 | -------------------------------------------------------------------------------- /configs/prompts/1_animate/1_4_animate_MajicMix.yaml: -------------------------------------------------------------------------------- 1 | # motion module v1_14 2 | - dreambooth_path: "models/DreamBooth_LoRA/majicmixRealistic_v5Preview.safetensors" 3 | lora_model_path: "" 4 | 5 | inference_config: "configs/inference/inference-v3.yaml" 6 | motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" 7 | 8 | seed: [11413213594134208212, 11357183503136546592, 7315638361411279346, 10191753182015596097] 9 | steps: 25 10 | guidance_scale: 8 11 | 12 | prompt: 13 | - "1girl, offshoulder, light smile, shiny skin best quality, masterpiece, photorealistic" 14 | - "best quality, masterpiece, photorealistic, 1boy, 50 years old beard, dramatic lighting" 15 | - "best quality, masterpiece, photorealistic, 1girl, light smile, shirt with collars, waist up, dramatic lighting, from below" 16 | - "male, man, beard, bodybuilder, skinhead,cold face, tough guy, cowboyshot, tattoo, french windows, luxury hotel masterpiece, best quality, photorealistic" 17 | 18 | n_prompt: 19 | - "ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, watermark, moles" 20 | - "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome" 21 | - "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome" 22 | - "nude, nsfw, ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, monochrome, grayscale watermark, moles, people" 23 | 24 | 25 | # motion module v1_15 26 | - dreambooth_path: "models/DreamBooth_LoRA/majicmixRealistic_v5Preview.safetensors" 27 | lora_model_path: "" 28 | 29 | inference_config: "configs/inference/inference-v2.yaml" 30 | motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" 31 | 32 | seed: [3364626746360550707, 10635741750919791646, 3130334860012077860, 1530101570151479035] 33 | steps: 25 34 | guidance_scale: 8 35 | 36 | prompt: 37 | - "1girl, offshoulder, light smile, shiny skin best quality, masterpiece, photorealistic" 38 | - "best quality, masterpiece, photorealistic, 1boy, 50 years old beard, dramatic lighting" 39 | - "best quality, masterpiece, photorealistic, 1girl, light smile, shirt with collars, waist up, dramatic lighting, from below" 40 | - "male, man, beard, bodybuilder, skinhead,cold face, tough guy, cowboyshot, tattoo, french windows, luxury hotel masterpiece, best quality, photorealistic" 41 | 42 | n_prompt: 43 | - "ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, watermark, moles" 44 | - "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, 
lowres,watermark, monochrome" 45 | - "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome" 46 | - "nude, nsfw, ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, monochrome, grayscale watermark, moles, people" 47 | -------------------------------------------------------------------------------- /configs/prompts/1_animate/1_5_animate_RcnzCartoon.yaml: -------------------------------------------------------------------------------- 1 | # motion module v1_14 2 | - dreambooth_path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors" 3 | lora_model_path: "" 4 | 5 | inference_config: "configs/inference/inference-v3.yaml" 6 | motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" 7 | 8 | seed: [8085079222822100088, 15493278891844617620, 17384760730172253253, 3896292336733512420] 9 | steps: 25 10 | guidance_scale: 8 11 | 12 | prompt: 13 | - "Jane Eyre with headphones, natural skin texture,4mm,k textures, soft cinematic light, adobe lightroom, photolab, hdr, intricate, elegant, highly detailed, sharp focus, cinematic look, soothing tones, insane details, intricate details, hyperdetailed, low contrast, soft cinematic light, dim colors, exposure blend, hdr, faded" 14 | - "close up Portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal [rust], elegant, sharp focus, photo by greg rutkowski, soft lighting, vibrant colors, masterpiece, streets, detailed face" 15 | - "absurdres, photorealistic, masterpiece, a 30 year old man with gold framed, aviator reading glasses and a black hooded jacket and a beard, professional photo, a character portrait, altermodern, detailed eyes, detailed lips, detailed face, grey eyes" 16 | - "a golden labrador, warm vibrant colours, natural lighting, dappled lighting, diffused lighting, absurdres, highres,k, uhd, hdr, rtx, unreal, octane render, RAW photo, photorealistic, global illumination, subsurface scattering" 17 | 18 | n_prompt: 19 | - "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation" 20 | - "nude, cross eyed, tongue, open mouth, inside, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, red eyes, muscular" 21 | - "easynegative, cartoon, anime, sketches, necklace, earrings worst quality, low quality, normal quality, bad anatomy, bad hands, shiny skin, error, missing fingers, extra digit, fewer digits, jpeg artifacts, signature, watermark, username, blurry, chubby, anorectic, bad eyes, old, wrinkled skin, red skin, photograph By bad artist -neg, big eyes, muscular face," 22 | - "beard, EasyNegative, lowres, chromatic aberration, depth of field, motion blur, blurry, bokeh, bad quality, worst quality, multiple arms, badhand" 23 | 24 | 25 | # motion module v1_15 26 | - dreambooth_path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors" 27 | lora_model_path: "" 28 | 29 | inference_config: "configs/inference/inference-v2.yaml" 30 | motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" 31 | 32 | seed: [10087939632512181573, 6440765888009826001, 4292543217695451092, 14003068315619866795] 33 | steps: 25 34 | guidance_scale: 8 35 | 36 | prompt: 37 | - "Jane Eyre with headphones, natural skin texture,4mm,k textures, soft cinematic 
light, adobe lightroom, photolab, hdr, intricate, elegant, highly detailed, sharp focus, cinematic look, soothing tones, insane details, intricate details, hyperdetailed, low contrast, soft cinematic light, dim colors, exposure blend, hdr, faded" 38 | - "close up Portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal [rust], elegant, sharp focus, photo by greg rutkowski, soft lighting, vibrant colors, masterpiece, streets, detailed face" 39 | - "absurdres, photorealistic, masterpiece, a 30 year old man with gold framed, aviator reading glasses and a black hooded jacket and a beard, professional photo, a character portrait, altermodern, detailed eyes, detailed lips, detailed face, grey eyes" 40 | - "a golden labrador, warm vibrant colours, natural lighting, dappled lighting, diffused lighting, absurdres, highres,k, uhd, hdr, rtx, unreal, octane render, RAW photo, photorealistic, global illumination, subsurface scattering" 41 | 42 | n_prompt: 43 | - "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation" 44 | - "nude, cross eyed, tongue, open mouth, inside, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, red eyes, muscular" 45 | - "easynegative, cartoon, anime, sketches, necklace, earrings worst quality, low quality, normal quality, bad anatomy, bad hands, shiny skin, error, missing fingers, extra digit, fewer digits, jpeg artifacts, signature, watermark, username, blurry, chubby, anorectic, bad eyes, old, wrinkled skin, red skin, photograph By bad artist -neg, big eyes, muscular face," 46 | - "beard, EasyNegative, lowres, chromatic aberration, depth of field, motion blur, blurry, bokeh, bad quality, worst quality, multiple arms, badhand" 47 | -------------------------------------------------------------------------------- /configs/prompts/1_animate/1_6_animate_Lyriel.yaml: -------------------------------------------------------------------------------- 1 | # motion module v1_14 2 | - dreambooth_path: "models/DreamBooth_LoRA/lyriel_v16.safetensors" 3 | lora_model_path: "" 4 | 5 | inference_config: "configs/inference/inference-v3.yaml" 6 | motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" 7 | 8 | seed: [10917152860782582783, 6399018107401806238, 15875751942533906793, 6653196880059936551] 9 | steps: 25 10 | guidance_scale: 8 11 | 12 | prompt: 13 | - "dark shot, epic realistic, portrait of halo, sunglasses, blue eyes, tartan scarf, white hair by atey ghailan, by greg rutkowski, by greg tocchini, by james gilleard, by joe fenton, by kaethe butcher, gradient yellow, black, brown and magenta color scheme, grunge aesthetic!!! 
graffiti tag wall background, art by greg rutkowski and artgerm, soft cinematic light, adobe lightroom, photolab, hdr, intricate, highly detailed, depth of field, faded, neutral colors, hdr, muted colors, hyperdetailed, artstation, cinematic, warm lights, dramatic light, intricate details, complex background, rutkowski, teal and orange" 14 | - "A forbidden castle high up in the mountains, pixel art, intricate details2, hdr, intricate details, hyperdetailed5, natural skin texture, hyperrealism, soft light, sharp, game art, key visual, surreal" 15 | - "dark theme, medieval portrait of a man sharp features, grim, cold stare, dark colors, Volumetric lighting, baroque oil painting by Greg Rutkowski, Artgerm, WLOP, Alphonse Mucha dynamic lighting hyperdetailed intricately detailed, hdr, muted colors, complex background, hyperrealism, hyperdetailed, amandine van ray" 16 | - "As I have gone alone in there and with my treasures bold, I can keep my secret where and hint of riches new and old. Begin it where warm waters halt and take it in a canyon down, not far but too far to walk, put in below the home of brown." 17 | 18 | n_prompt: 19 | - "3d, cartoon, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, young, loli, elf, 3d, illustration" 20 | - "3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular" 21 | - "dof, grayscale, black and white, bw, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular,badhandsv5-neg, By bad artist -neg 1, monochrome" 22 | - "holding an item, cowboy, hat, cartoon, 3d, disfigured, bad art, deformed,extra limbs,close up,b&w, wierd colors, blurry, duplicate, morbid, mutilated, [out of frame], extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, out of frame, ugly, extra limbs, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, Photoshop, video game, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, 3d render" 23 | 24 | 25 | # motion module v1_15 26 | - dreambooth_path: "models/DreamBooth_LoRA/lyriel_v16.safetensors" 27 | lora_model_path: "" 28 | 29 | inference_config: "configs/inference/inference-v2.yaml" 30 | motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" 31 | 32 | seed: [9217823730598265840, 10815047877257294769, 15033600051075248739, 3730216622332453211] 33 | steps: 25 34 | guidance_scale: 8 35 | 36 | prompt: 37 | - "dark shot, epic realistic, portrait of halo, sunglasses, blue eyes, tartan scarf, white hair by atey ghailan, by greg rutkowski, by greg tocchini, by james gilleard, by joe fenton, by kaethe butcher, gradient yellow, black, brown and magenta color scheme, grunge aesthetic!!! 
graffiti tag wall background, art by greg rutkowski and artgerm, soft cinematic light, adobe lightroom, photolab, hdr, intricate, highly detailed, depth of field, faded, neutral colors, hdr, muted colors, hyperdetailed, artstation, cinematic, warm lights, dramatic light, intricate details, complex background, rutkowski, teal and orange" 38 | - "A forbidden castle high up in the mountains, pixel art, intricate details2, hdr, intricate details, hyperdetailed5, natural skin texture, hyperrealism, soft light, sharp, game art, key visual, surreal" 39 | - "dark theme, medieval portrait of a man sharp features, grim, cold stare, dark colors, Volumetric lighting, baroque oil painting by Greg Rutkowski, Artgerm, WLOP, Alphonse Mucha dynamic lighting hyperdetailed intricately detailed, hdr, muted colors, complex background, hyperrealism, hyperdetailed, amandine van ray" 40 | - "As I have gone alone in there and with my treasures bold, I can keep my secret where and hint of riches new and old. Begin it where warm waters halt and take it in a canyon down, not far but too far to walk, put in below the home of brown." 41 | 42 | n_prompt: 43 | - "3d, cartoon, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, young, loli, elf, 3d, illustration" 44 | - "3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular" 45 | - "dof, grayscale, black and white, bw, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular,badhandsv5-neg, By bad artist -neg 1, monochrome" 46 | - "holding an item, cowboy, hat, cartoon, 3d, disfigured, bad art, deformed,extra limbs,close up,b&w, wierd colors, blurry, duplicate, morbid, mutilated, [out of frame], extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, out of frame, ugly, extra limbs, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, Photoshop, video game, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, 3d render" 47 | -------------------------------------------------------------------------------- /configs/prompts/1_animate/1_7_animate_Tusun.yaml: -------------------------------------------------------------------------------- 1 | # motion module v1_14 2 | - dreambooth_path: "models/DreamBooth_LoRA/leosamsHelloworldXL_filmGrain20.safetensors" 3 | lora_model_path: "models/DreamBooth_LoRA/TUSUN.safetensors" 4 | lora_alpha: 0.6 5 | 6 | inference_config: "configs/inference/inference-v3.yaml" 7 | motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" 8 | 9 | seed: [7107114461349773341, 17169636352587613974, 9844335976427375435, 6372518434592560610] 10 | steps: 25 11 | guidance_scale: 8 12 | 13 | prompt: 14 | - "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, 
tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing" 15 | - "cute tusun with a blurry background, black background, simple background, signature, face, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing" 16 | - "cut tusuncub walking in the snow, blurry, looking at viewer, depth of field, blurry background, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing" 17 | - "character design, cyberpunk tusun kitten wearing astronaut suit, sci-fic, realistic eye color and details, fluffy, big head, science fiction, communist ideology, Cyborg, fantasy, intense angle, soft lighting, photograph, 4k, hyper detailed, portrait wallpaper, realistic, photo-realistic, DSLR, 24 Megapixels, Full Frame, vibrant details, octane render, finely detail, best quality, incredibly absurdres, robotic parts, rim light, vibrant details, luxurious cyberpunk, hyperrealistic, cable electric wires, microchip, full body" 18 | 19 | n_prompt: 20 | - "worst quality, low quality, deformed, distorted, disfigured, bad eyes, bad anatomy, disconnected limbs, wrong body proportions, low quality, worst quality, text, watermark, signatre, logo, illustration, painting, cartoons, ugly, easy_negative" 21 | 22 | 23 | # motion module v1_15 24 | - dreambooth_path: "models/DreamBooth_LoRA/leosamsHelloworldXL_filmGrain20.safetensors" 25 | lora_model_path: "models/DreamBooth_LoRA/TUSUN.safetensors" 26 | lora_alpha: 0.6 27 | 28 | inference_config: "configs/inference/inference-v2.yaml" 29 | motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" 30 | 31 | seed: [8605999221232672724, 110148213803975296, 9191327304973552413, 174075196208604916] 32 | steps: 25 33 | guidance_scale: 8 34 | 35 | prompt: 36 | - "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing" 37 | - "cute tusun with a blurry background, black background, simple background, signature, face, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing" 38 | - "cut tusuncub walking in the snow, blurry, looking at viewer, depth of field, blurry background, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect 
anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing" 39 | - "character design, cyberpunk tusun kitten wearing astronaut suit, sci-fic, realistic eye color and details, fluffy, big head, science fiction, communist ideology, Cyborg, fantasy, intense angle, soft lighting, photograph, 4k, hyper detailed, portrait wallpaper, realistic, photo-realistic, DSLR, 24 Megapixels, Full Frame, vibrant details, octane render, finely detail, best quality, incredibly absurdres, robotic parts, rim light, vibrant details, luxurious cyberpunk, hyperrealistic, cable electric wires, microchip, full body" 40 | 41 | n_prompt: 42 | - "worst quality, low quality, deformed, distorted, disfigured, bad eyes, bad anatomy, disconnected limbs, wrong body proportions, low quality, worst quality, text, watermark, signatre, logo, illustration, painting, cartoons, ugly, easy_negative" 43 | -------------------------------------------------------------------------------- /configs/prompts/2_motionlora/2_motionlora_RealisticVision.yaml: -------------------------------------------------------------------------------- 1 | # ZoomIn 2 | - inference_config: "configs/inference/inference-v2.yaml" 3 | motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" 4 | 5 | motion_module_lora_configs: 6 | - path: "models/MotionLoRA/v2_lora_ZoomIn.ckpt" 7 | alpha: 1.0 8 | 9 | dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" 10 | lora_model_path: "" 11 | 12 | seed: 43242532350557906 13 | steps: 25 14 | guidance_scale: 7.5 15 | 16 | prompt: 17 | - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" 18 | 19 | n_prompt: 20 | - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" 21 | 22 | 23 | # ZoomOut 24 | - inference_config: "configs/inference/inference-v2.yaml" 25 | motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" 26 | 27 | motion_module_lora_configs: 28 | - path: "models/MotionLoRA/v2_lora_ZoomOut.ckpt" 29 | alpha: 1.0 30 | 31 | dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" 32 | lora_model_path: "" 33 | 34 | seed: 43242532350557906 35 | steps: 25 36 | guidance_scale: 7.5 37 | 38 | prompt: 39 | - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" 40 | 41 | n_prompt: 42 | - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" 43 | 44 | 45 | # PanLeft 46 | - inference_config: "configs/inference/inference-v2.yaml" 47 | motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" 48 | 49 | motion_module_lora_configs: 50 | - path: "models/MotionLoRA/v2_lora_PanLeft.ckpt" 51 | alpha: 1.0 52 | 53 | dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" 54 | 
lora_model_path: "" 55 | 56 | seed: 43242532350557906 57 | steps: 25 58 | guidance_scale: 7.5 59 | 60 | prompt: 61 | - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" 62 | 63 | n_prompt: 64 | - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" 65 | 66 | 67 | # PanRight 68 | - inference_config: "configs/inference/inference-v2.yaml" 69 | motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" 70 | 71 | motion_module_lora_configs: 72 | - path: "models/MotionLoRA/v2_lora_PanRight.ckpt" 73 | alpha: 1.0 74 | 75 | dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" 76 | lora_model_path: "" 77 | 78 | seed: 43242532350557906 79 | steps: 25 80 | guidance_scale: 7.5 81 | 82 | prompt: 83 | - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" 84 | 85 | n_prompt: 86 | - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" 87 | 88 | 89 | # TiltUp 90 | - inference_config: "configs/inference/inference-v2.yaml" 91 | motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" 92 | 93 | motion_module_lora_configs: 94 | - path: "models/MotionLoRA/v2_lora_TiltUp.ckpt" 95 | alpha: 1.0 96 | 97 | dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" 98 | lora_model_path: "" 99 | 100 | seed: 43242532350557906 101 | steps: 25 102 | guidance_scale: 7.5 103 | 104 | prompt: 105 | - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" 106 | 107 | n_prompt: 108 | - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" 109 | 110 | 111 | # TiltDown 112 | - inference_config: "configs/inference/inference-v2.yaml" 113 | motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" 114 | 115 | motion_module_lora_configs: 116 | - path: "models/MotionLoRA/v2_lora_TiltDown.ckpt" 117 | alpha: 1.0 118 | 119 | dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" 120 | lora_model_path: "" 121 | 122 | seed: 43242532350557906 123 | steps: 25 124 | guidance_scale: 7.5 125 | 126 | prompt: 127 | - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" 128 | 129 | n_prompt: 130 | - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" 131 | 132 | 133 | # 
RollingAnticlockwise 134 | - inference_config: "configs/inference/inference-v2.yaml" 135 | motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" 136 | 137 | motion_module_lora_configs: 138 | - path: "models/MotionLoRA/v2_lora_RollingAnticlockwise.ckpt" 139 | alpha: 1.0 140 | 141 | dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" 142 | lora_model_path: "" 143 | 144 | seed: 43242532350557906 145 | steps: 25 146 | guidance_scale: 7.5 147 | 148 | prompt: 149 | - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" 150 | 151 | n_prompt: 152 | - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" 153 | 154 | 155 | # RollingClockwise 156 | - inference_config: "configs/inference/inference-v2.yaml" 157 | motion_module: "models/Motion_Module/mm_sd_v15_v2.ckpt" 158 | 159 | motion_module_lora_configs: 160 | - path: "models/MotionLoRA/v2_lora_RollingClockwise.ckpt" 161 | alpha: 1.0 162 | 163 | dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" 164 | lora_model_path: "" 165 | 166 | seed: 43242532350557906 167 | steps: 25 168 | guidance_scale: 7.5 169 | 170 | prompt: 171 | - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" 172 | 173 | n_prompt: 174 | - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" 175 | -------------------------------------------------------------------------------- /configs/prompts/3_sparsectrl/3_1_sparsectrl_i2v.yaml: -------------------------------------------------------------------------------- 1 | # 1-animation 2 | - adapter_lora_scale: 1.0 3 | adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" 4 | dreambooth_path: "" 5 | 6 | inference_config: "configs/inference/inference-v3.yaml" 7 | motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" 8 | 9 | controlnet_config: "configs/inference/sparsectrl/latent_condition.yaml" 10 | controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt" 11 | 12 | H: 256 13 | W: 384 14 | seed: [123,234] 15 | steps: 25 16 | guidance_scale: 8.5 17 | 18 | controlnet_image_indexs: [0] 19 | controlnet_images: 20 | - "__assets__/demos/image/painting.png" 21 | 22 | prompt: 23 | - an oil painting of a sailboat in the ocean wave 24 | - an oil painting of a sailboat in the ocean wave 25 | n_prompt: 26 | - "worst quality, low quality, letterboxed" 27 | 28 | 29 | # 2-interpolation 30 | - adapter_lora_scale: 1.0 31 | adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" 32 | dreambooth_path: "" 33 | 34 | inference_config: "configs/inference/inference-v3.yaml" 35 | motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" 36 | 37 | controlnet_config: "configs/inference/sparsectrl/latent_condition.yaml" 38 | controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt" 39 | 40 | H: 256 41 | W: 384 42 | seed: [123,234] 43 | steps: 25 44 | guidance_scale: 8.5 45 | 46 | 
controlnet_image_indexs: [0,-1] 47 | controlnet_images: 48 | - "__assets__/demos/image/interpolation_1.png" 49 | - "__assets__/demos/image/interpolation_2.png" 50 | 51 | prompt: 52 | - "aerial view, beautiful forest, autumn, 4k, high quality" 53 | - "aerial view, beautiful forest, autumn, 4k, high quality" 54 | n_prompt: 55 | - "worst quality, low quality, letterboxed" 56 | 57 | 58 | # 3-interpolation 59 | - adapter_lora_scale: 1.0 60 | adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" 61 | dreambooth_path: "" 62 | 63 | inference_config: "configs/inference/inference-v3.yaml" 64 | motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" 65 | 66 | controlnet_config: "configs/inference/sparsectrl/latent_condition.yaml" 67 | controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt" 68 | 69 | H: 256 70 | W: 384 71 | seed: [123,234] 72 | steps: 25 73 | guidance_scale: 8.5 74 | 75 | controlnet_image_indexs: [0,5,10,15] 76 | controlnet_images: 77 | - "__assets__/demos/image/low_fps_1.png" 78 | - "__assets__/demos/image/low_fps_2.png" 79 | - "__assets__/demos/image/low_fps_3.png" 80 | - "__assets__/demos/image/low_fps_4.png" 81 | 82 | prompt: 83 | - "two people holding hands in a field with wind turbines in the background" 84 | - "two people holding hands in a field with wind turbines in the background" 85 | n_prompt: 86 | - "worst quality, low quality, letterboxed" 87 | 88 | 89 | # 3-prediction 90 | - adapter_lora_scale: 1.0 91 | adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" 92 | dreambooth_path: "" 93 | 94 | inference_config: "configs/inference/inference-v3.yaml" 95 | motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" 96 | 97 | controlnet_config: "configs/inference/sparsectrl/latent_condition.yaml" 98 | controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt" 99 | 100 | H: 256 101 | W: 384 102 | seed: [123,234] 103 | steps: 25 104 | guidance_scale: 8.5 105 | 106 | controlnet_image_indexs: [0,1,2,3] 107 | controlnet_images: 108 | - "__assets__/demos/image/prediction_1.png" 109 | - "__assets__/demos/image/prediction_2.png" 110 | - "__assets__/demos/image/prediction_3.png" 111 | - "__assets__/demos/image/prediction_4.png" 112 | 113 | prompt: 114 | - "an astronaut is flying in the space, 4k, high resolution" 115 | - "an astronaut is flying in the space, 4k, high resolution" 116 | n_prompt: 117 | - "worst quality, low quality, letterboxed" 118 | -------------------------------------------------------------------------------- /configs/prompts/3_sparsectrl/3_2_sparsectrl_rgb_RealisticVision.yaml: -------------------------------------------------------------------------------- 1 | # animation-1 2 | - adapter_lora_scale: 1.0 3 | adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" 4 | dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" 5 | 6 | inference_config: "configs/inference/inference-v3.yaml" 7 | motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" 8 | 9 | controlnet_config: "configs/inference/sparsectrl/latent_condition.yaml" 10 | controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt" 11 | 12 | seed: -1 13 | steps: 25 14 | guidance_scale: 8.5 15 | 16 | controlnet_image_indexs: [0] 17 | controlnet_images: 18 | - "__assets__/demos/image/RealisticVision_firework.png" 19 | 20 | prompt: 21 | - "closeup face photo of man in black clothes, night city street, bokeh, fireworks in background" 22 | - "closeup face photo of man in black clothes, night city street, bokeh, fireworks in background" 23 | n_prompt: 24 | 
- "worst quality, low quality, letterboxed" 25 | 26 | 27 | # animation-2 28 | - adapter_lora_scale: 1.0 29 | adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" 30 | dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" 31 | 32 | inference_config: "configs/inference/inference-v3.yaml" 33 | motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" 34 | 35 | controlnet_config: "configs/inference/sparsectrl/latent_condition.yaml" 36 | controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_rgb.ckpt" 37 | 38 | seed: -1 39 | steps: 25 40 | guidance_scale: 8.5 41 | 42 | controlnet_image_indexs: [0] 43 | controlnet_images: 44 | - "__assets__/demos/image/RealisticVision_sunset.png" 45 | 46 | prompt: 47 | - "masterpiece, bestquality, highlydetailed, ultradetailed, sunset, orange sky, warm lighting, fishing boats, ocean waves, seagulls, rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, golden hour, coastal landscape, seaside scenery" 48 | - "masterpiece, bestquality, highlydetailed, ultradetailed, sunset, orange sky, warm lighting, fishing boats, ocean waves, seagulls, rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, golden hour, coastal landscape, seaside scenery" 49 | n_prompt: 50 | - "worst quality, low quality, letterboxed" 51 | -------------------------------------------------------------------------------- /configs/prompts/3_sparsectrl/3_3_sparsectrl_sketch_RealisticVision.yaml: -------------------------------------------------------------------------------- 1 | # 1-sketch-to-video 2 | - adapter_lora_scale: 1.0 3 | adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" 4 | dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" 5 | 6 | inference_config: "configs/inference/inference-v3.yaml" 7 | motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" 8 | 9 | controlnet_config: "configs/inference/sparsectrl/image_condition.yaml" 10 | controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt" 11 | 12 | seed: -1 13 | steps: 25 14 | guidance_scale: 8.5 15 | 16 | controlnet_image_indexs: [0] 17 | controlnet_images: 18 | - "__assets__/demos/scribble/scribble_1.png" 19 | 20 | prompt: 21 | - "a back view of a boy, standing on the ground, looking at the sky, sunlight, masterpieces" 22 | - "a back view of a boy, standing on the ground, looking at the sky, clouds, sunset, orange sky, beautiful sunlight, masterpieces" 23 | n_prompt: 24 | - "worst quality, low quality, letterboxed" 25 | 26 | 27 | # 2-storyboarding 28 | - adapter_lora_scale: 1.0 29 | adapter_lora_path: "models/Motion_Module/v3_sd15_adapter.ckpt" 30 | dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v51VAE.safetensors" 31 | 32 | inference_config: "configs/inference/inference-v3.yaml" 33 | motion_module: "models/Motion_Module/v3_sd15_mm.ckpt" 34 | 35 | controlnet_config: "configs/inference/sparsectrl/image_condition.yaml" 36 | controlnet_path: "models/SparseCtrl/v3_sd15_sparsectrl_scribble.ckpt" 37 | 38 | seed: -1 39 | steps: 25 40 | guidance_scale: 8.5 41 | 42 | controlnet_image_indexs: [0,8,15] 43 | controlnet_images: 44 | - "__assets__/demos/scribble/scribble_2_1.png" 45 | - "__assets__/demos/scribble/scribble_2_2.png" 46 | - "__assets__/demos/scribble/scribble_2_3.png" 47 | 48 | prompt: 49 | - "an aerial view of a modern city, sunlight, day time, masterpiece, high quality" 50 | - "an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality" 51 | n_prompt: 52 | - "worst 
quality, low quality, letterboxed" 53 | -------------------------------------------------------------------------------- /configs/training/v1/image_finetune.yaml: -------------------------------------------------------------------------------- 1 | image_finetune: true 2 | 3 | output_dir: "outputs" 4 | pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5" 5 | 6 | noise_scheduler_kwargs: 7 | num_train_timesteps: 1000 8 | beta_start: 0.00085 9 | beta_end: 0.012 10 | beta_schedule: "scaled_linear" 11 | steps_offset: 1 12 | clip_sample: false 13 | 14 | train_data: 15 | csv_path: "path_to_csv_file" 16 | video_folder: "path_to_video_folder" 17 | sample_size: 256 18 | 19 | validation_data: 20 | prompts: 21 | - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons." 22 | - "A drone view of celebration with Christmas tree and fireworks, starry sky - background." 23 | - "Robot dancing in times square." 24 | - "Pacific coast, carmel by the sea ocean and waves." 25 | num_inference_steps: 25 26 | guidance_scale: 8. 27 | 28 | trainable_modules: 29 | - "." 30 | 31 | unet_checkpoint_path: "" 32 | 33 | learning_rate: 1.e-5 34 | train_batch_size: 50 35 | 36 | max_train_epoch: -1 37 | max_train_steps: 100 38 | checkpointing_epochs: -1 39 | checkpointing_steps: 60 40 | 41 | validation_steps: 5000 42 | validation_steps_tuple: [2, 50] 43 | 44 | global_seed: 42 45 | mixed_precision_training: true 46 | enable_xformers_memory_efficient_attention: True 47 | 48 | is_debug: False 49 | -------------------------------------------------------------------------------- /configs/training/v1/training.yaml: -------------------------------------------------------------------------------- 1 | image_finetune: false 2 | 3 | output_dir: "outputs" 4 | pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5" 5 | 6 | unet_additional_kwargs: 7 | use_motion_module : true 8 | motion_module_resolutions : [ 1,2,4,8 ] 9 | unet_use_cross_frame_attention : false 10 | unet_use_temporal_attention : false 11 | 12 | motion_module_type: Vanilla 13 | motion_module_kwargs: 14 | num_attention_heads : 8 15 | num_transformer_block : 1 16 | attention_block_types : [ "Temporal_Self", "Temporal_Self" ] 17 | temporal_position_encoding : true 18 | temporal_position_encoding_max_len : 24 19 | temporal_attention_dim_div : 1 20 | zero_initialize : true 21 | 22 | noise_scheduler_kwargs: 23 | num_train_timesteps: 1000 24 | beta_start: 0.00085 25 | beta_end: 0.012 26 | beta_schedule: "linear" 27 | steps_offset: 1 28 | clip_sample: false 29 | 30 | train_data: 31 | csv_path: "path_to_csv_file" 32 | video_folder: "path_to_video_folder" 33 | sample_size: 256 34 | sample_stride: 4 35 | sample_n_frames: 16 36 | 37 | validation_data: 38 | prompts: 39 | - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons." 40 | - "A drone view of celebration with Christmas tree and fireworks, starry sky - background." 41 | - "Robot dancing in times square." 42 | - "Pacific coast, carmel by the sea ocean and waves." 43 | num_inference_steps: 25 44 | guidance_scale: 8. 45 | 46 | trainable_modules: 47 | - "motion_modules."
48 | 49 | unet_checkpoint_path: "" 50 | 51 | learning_rate: 1.e-4 52 | train_batch_size: 4 53 | 54 | max_train_epoch: -1 55 | max_train_steps: 100 56 | checkpointing_epochs: -1 57 | checkpointing_steps: 60 58 | 59 | validation_steps: 5000 60 | validation_steps_tuple: [2, 50] 61 | 62 | global_seed: 42 63 | mixed_precision_training: true 64 | enable_xformers_memory_efficient_attention: True 65 | 66 | is_debug: False 67 | -------------------------------------------------------------------------------- /models/DreamBooth_LoRA/Put personalized T2I checkpoints here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/models/DreamBooth_LoRA/Put personalized T2I checkpoints here.txt -------------------------------------------------------------------------------- /models/MotionLoRA/Put MotionLoRA checkpoints here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/models/MotionLoRA/Put MotionLoRA checkpoints here.txt -------------------------------------------------------------------------------- /models/Motion_Module/Put motion module checkpoints here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/models/Motion_Module/Put motion module checkpoints here.txt -------------------------------------------------------------------------------- /models/StableDiffusion/Put diffusers stable-diffusion-v1-5 repo here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guoyww/AnimateDiff/e92bd5671ba62c0d774a32951453e328018b7c5b/models/StableDiffusion/Put diffusers stable-diffusion-v1-5 repo here.txt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.3.1 2 | torchvision==0.18.1 3 | diffusers==0.11.1 4 | transformers==4.25.1 5 | xformers==0.0.27 6 | imageio==2.27.0 7 | imageio-ffmpeg==0.4.9 8 | decord==0.6.0 9 | omegaconf==2.3.0 10 | gradio==3.36.1 11 | safetensors 12 | einops 13 | wandb 14 | -------------------------------------------------------------------------------- /scripts/animate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import inspect 4 | import os 5 | from omegaconf import OmegaConf 6 | 7 | import torch 8 | import torchvision.transforms as transforms 9 | 10 | import diffusers 11 | from diffusers import AutoencoderKL, DDIMScheduler 12 | 13 | from tqdm.auto import tqdm 14 | from transformers import CLIPTextModel, CLIPTokenizer 15 | 16 | from animatediff.models.unet import UNet3DConditionModel 17 | from animatediff.models.sparse_controlnet import SparseControlNetModel 18 | from animatediff.pipelines.pipeline_animation import AnimationPipeline 19 | from animatediff.utils.util import save_videos_grid 20 | from animatediff.utils.util import load_weights, auto_download 21 | from diffusers.utils.import_utils import is_xformers_available 22 | 23 | from einops import rearrange, repeat 24 | 25 | import csv, pdb, glob, math 26 | from pathlib import Path 27 | from PIL import Image 28 | import numpy as np 29 | 30 | 31 | @torch.no_grad() 32 | 
def main(args): 33 | *_, func_args = inspect.getargvalues(inspect.currentframe()) 34 | func_args = dict(func_args) 35 | 36 | time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 37 | savedir = f"samples/{Path(args.config).stem}-{time_str}" 38 | os.makedirs(savedir) 39 | 40 | config = OmegaConf.load(args.config) 41 | samples = [] 42 | 43 | # create validation pipeline 44 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer") 45 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder").cuda() 46 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").cuda() 47 | 48 | sample_idx = 0 49 | for model_idx, model_config in enumerate(config): 50 | model_config.W = model_config.get("W", args.W) 51 | model_config.H = model_config.get("H", args.H) 52 | model_config.L = model_config.get("L", args.L) 53 | 54 | inference_config = OmegaConf.load(model_config.get("inference_config", args.inference_config)) 55 | unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)).cuda() 56 | 57 | # load controlnet model 58 | controlnet = controlnet_images = None 59 | if model_config.get("controlnet_path", "") != "": 60 | assert model_config.get("controlnet_images", "") != "" 61 | assert model_config.get("controlnet_config", "") != "" 62 | 63 | unet.config.num_attention_heads = 8 64 | unet.config.projection_class_embeddings_input_dim = None 65 | 66 | controlnet_config = OmegaConf.load(model_config.controlnet_config) 67 | controlnet = SparseControlNetModel.from_unet(unet, controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {})) 68 | 69 | auto_download(model_config.controlnet_path, is_dreambooth_lora=False) 70 | print(f"loading controlnet checkpoint from {model_config.controlnet_path} ...") 71 | controlnet_state_dict = torch.load(model_config.controlnet_path, map_location="cpu") 72 | controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict 73 | controlnet_state_dict = {name: param for name, param in controlnet_state_dict.items() if "pos_encoder.pe" not in name} 74 | controlnet_state_dict.pop("animatediff_config", "") 75 | controlnet.load_state_dict(controlnet_state_dict) 76 | controlnet.cuda() 77 | 78 | image_paths = model_config.controlnet_images 79 | if isinstance(image_paths, str): image_paths = [image_paths] 80 | 81 | print(f"controlnet image paths:") 82 | for path in image_paths: print(path) 83 | assert len(image_paths) <= model_config.L 84 | 85 | image_transforms = transforms.Compose([ 86 | transforms.RandomResizedCrop( 87 | (model_config.H, model_config.W), (1.0, 1.0), 88 | ratio=(model_config.W/model_config.H, model_config.W/model_config.H) 89 | ), 90 | transforms.ToTensor(), 91 | ]) 92 | 93 | if model_config.get("normalize_condition_images", False): 94 | def image_norm(image): 95 | image = image.mean(dim=0, keepdim=True).repeat(3,1,1) 96 | image -= image.min() 97 | image /= image.max() 98 | return image 99 | else: image_norm = lambda x: x 100 | 101 | controlnet_images = [image_norm(image_transforms(Image.open(path).convert("RGB"))) for path in image_paths] 102 | 103 | os.makedirs(os.path.join(savedir, "control_images"), exist_ok=True) 104 | for i, image in enumerate(controlnet_images): 105 | Image.fromarray((255. 
* (image.numpy().transpose(1,2,0))).astype(np.uint8)).save(f"{savedir}/control_images/{i}.png") 106 | 107 | controlnet_images = torch.stack(controlnet_images).unsqueeze(0).cuda() 108 | controlnet_images = rearrange(controlnet_images, "b f c h w -> b c f h w") 109 | 110 | if controlnet.use_simplified_condition_embedding: 111 | num_controlnet_images = controlnet_images.shape[2] 112 | controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w") 113 | controlnet_images = vae.encode(controlnet_images * 2. - 1.).latent_dist.sample() * 0.18215 114 | controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images) 115 | 116 | # set xformers 117 | if is_xformers_available() and (not args.without_xformers): 118 | unet.enable_xformers_memory_efficient_attention() 119 | if controlnet is not None: controlnet.enable_xformers_memory_efficient_attention() 120 | 121 | pipeline = AnimationPipeline( 122 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, 123 | controlnet=controlnet, 124 | scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), 125 | ).to("cuda") 126 | 127 | pipeline = load_weights( 128 | pipeline, 129 | # motion module 130 | motion_module_path = model_config.get("motion_module", ""), 131 | motion_module_lora_configs = model_config.get("motion_module_lora_configs", []), 132 | # domain adapter 133 | adapter_lora_path = model_config.get("adapter_lora_path", ""), 134 | adapter_lora_scale = model_config.get("adapter_lora_scale", 1.0), 135 | # image layers 136 | dreambooth_model_path = model_config.get("dreambooth_path", ""), 137 | lora_model_path = model_config.get("lora_model_path", ""), 138 | lora_alpha = model_config.get("lora_alpha", 0.8), 139 | ).to("cuda") 140 | 141 | prompts = model_config.prompt 142 | n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt 143 | 144 | random_seeds = model_config.get("seed", [-1]) 145 | random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds) 146 | random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds 147 | 148 | config[model_idx].random_seed = [] 149 | for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)): 150 | 151 | # manually set random seed for reproducibility 152 | if random_seed != -1: torch.manual_seed(random_seed) 153 | else: torch.seed() 154 | config[model_idx].random_seed.append(torch.initial_seed()) 155 | 156 | print(f"current seed: {torch.initial_seed()}") 157 | print(f"sampling {prompt} ...") 158 | sample = pipeline( 159 | prompt, 160 | negative_prompt = n_prompt, 161 | num_inference_steps = model_config.steps, 162 | guidance_scale = model_config.guidance_scale, 163 | width = model_config.W, 164 | height = model_config.H, 165 | video_length = model_config.L, 166 | 167 | controlnet_images = controlnet_images, 168 | controlnet_image_index = model_config.get("controlnet_image_indexs", [0]), 169 | ).videos 170 | samples.append(sample) 171 | 172 | prompt = "-".join((prompt.replace("/", "").split(" ")[:10])) 173 | save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif") 174 | print(f"save to {savedir}/sample/{sample_idx}-{prompt}.gif") 175 | 176 | sample_idx += 1 177 | 178 | samples = torch.concat(samples) 179 | save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4) 180 | 181 | OmegaConf.save(config, f"{savedir}/config.yaml") 182 | 183 | 184 | if __name__ == "__main__":
185 | parser = argparse.ArgumentParser() 186 | parser.add_argument("--pretrained-model-path", type=str, default="runwayml/stable-diffusion-v1-5") 187 | parser.add_argument("--inference-config", type=str, default="configs/inference/inference-v1.yaml") 188 | parser.add_argument("--config", type=str, required=True) 189 | 190 | parser.add_argument("--L", type=int, default=16 ) 191 | parser.add_argument("--W", type=int, default=512) 192 | parser.add_argument("--H", type=int, default=512) 193 | 194 | parser.add_argument("--without-xformers", action="store_true") 195 | 196 | args = parser.parse_args() 197 | main(args) 198 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import wandb 4 | import random 5 | import logging 6 | import inspect 7 | import argparse 8 | import datetime 9 | import subprocess 10 | 11 | from pathlib import Path 12 | from tqdm.auto import tqdm 13 | from einops import rearrange 14 | from omegaconf import OmegaConf 15 | from safetensors import safe_open 16 | from typing import Dict, Optional, Tuple 17 | 18 | import torch 19 | import torchvision 20 | import torch.nn.functional as F 21 | import torch.distributed as dist 22 | from torch.optim.swa_utils import AveragedModel 23 | from torch.utils.data.distributed import DistributedSampler 24 | from torch.nn.parallel import DistributedDataParallel as DDP 25 | 26 | import diffusers 27 | from diffusers import AutoencoderKL, DDIMScheduler 28 | from diffusers.models import UNet2DConditionModel 29 | from diffusers.pipelines import StableDiffusionPipeline 30 | from diffusers.optimization import get_scheduler 31 | from diffusers.utils import check_min_version 32 | from diffusers.utils.import_utils import is_xformers_available 33 | 34 | import transformers 35 | from transformers import CLIPTextModel, CLIPTokenizer 36 | 37 | from animatediff.data.dataset import WebVid10M 38 | from animatediff.models.unet import UNet3DConditionModel 39 | from animatediff.pipelines.pipeline_animation import AnimationPipeline 40 | from animatediff.utils.util import save_videos_grid, zero_rank_print 41 | 42 | 43 | 44 | def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs): 45 | """Initializes distributed environment.""" 46 | if launcher == 'pytorch': 47 | rank = int(os.environ['RANK']) 48 | num_gpus = torch.cuda.device_count() 49 | local_rank = rank % num_gpus 50 | torch.cuda.set_device(local_rank) 51 | dist.init_process_group(backend=backend, **kwargs) 52 | 53 | elif launcher == 'slurm': 54 | proc_id = int(os.environ['SLURM_PROCID']) 55 | ntasks = int(os.environ['SLURM_NTASKS']) 56 | node_list = os.environ['SLURM_NODELIST'] 57 | num_gpus = torch.cuda.device_count() 58 | local_rank = proc_id % num_gpus 59 | torch.cuda.set_device(local_rank) 60 | addr = subprocess.getoutput( 61 | f'scontrol show hostname {node_list} | head -n1') 62 | os.environ['MASTER_ADDR'] = addr 63 | os.environ['WORLD_SIZE'] = str(ntasks) 64 | os.environ['RANK'] = str(proc_id) 65 | port = os.environ.get('PORT', port) 66 | os.environ['MASTER_PORT'] = str(port) 67 | dist.init_process_group(backend=backend) 68 | zero_rank_print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}") 69 | 70 | else: 71 | raise NotImplementedError(f'Not implemented launcher type: `{launcher}`!') 72 | 73 | return local_rank 74 | 75 | 76 | 77 | def main( 78 | 
image_finetune: bool, 79 | 80 | name: str, 81 | use_wandb: bool, 82 | launcher: str, 83 | 84 | output_dir: str, 85 | pretrained_model_path: str, 86 | 87 | train_data: Dict, 88 | validation_data: Dict, 89 | cfg_random_null_text: bool = True, 90 | cfg_random_null_text_ratio: float = 0.1, 91 | 92 | unet_checkpoint_path: str = "", 93 | unet_additional_kwargs: Dict = {}, 94 | ema_decay: float = 0.9999, 95 | noise_scheduler_kwargs = None, 96 | 97 | max_train_epoch: int = -1, 98 | max_train_steps: int = 100, 99 | validation_steps: int = 100, 100 | validation_steps_tuple: Tuple = (-1,), 101 | 102 | learning_rate: float = 3e-5, 103 | scale_lr: bool = False, 104 | lr_warmup_steps: int = 0, 105 | lr_scheduler: str = "constant", 106 | 107 | trainable_modules: Tuple[str] = (None, ), 108 | num_workers: int = 32, 109 | train_batch_size: int = 1, 110 | adam_beta1: float = 0.9, 111 | adam_beta2: float = 0.999, 112 | adam_weight_decay: float = 1e-2, 113 | adam_epsilon: float = 1e-08, 114 | max_grad_norm: float = 1.0, 115 | gradient_accumulation_steps: int = 1, 116 | gradient_checkpointing: bool = False, 117 | checkpointing_epochs: int = 5, 118 | checkpointing_steps: int = -1, 119 | 120 | mixed_precision_training: bool = True, 121 | enable_xformers_memory_efficient_attention: bool = True, 122 | 123 | global_seed: int = 42, 124 | is_debug: bool = False, 125 | ): 126 | check_min_version("0.10.0.dev0") 127 | 128 | # Initialize distributed training 129 | local_rank = init_dist(launcher=launcher) 130 | global_rank = dist.get_rank() 131 | num_processes = dist.get_world_size() 132 | is_main_process = global_rank == 0 133 | 134 | seed = global_seed + global_rank 135 | torch.manual_seed(seed) 136 | 137 | # Logging folder 138 | folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S") 139 | output_dir = os.path.join(output_dir, folder_name) 140 | if is_debug and os.path.exists(output_dir): 141 | os.system(f"rm -rf {output_dir}") 142 | 143 | *_, config = inspect.getargvalues(inspect.currentframe()) 144 | 145 | # Make one log on every process with the configuration for debugging. 146 | logging.basicConfig( 147 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 148 | datefmt="%m/%d/%Y %H:%M:%S", 149 | level=logging.INFO, 150 | ) 151 | 152 | if is_main_process and (not is_debug) and use_wandb: 153 | run = wandb.init(project="animatediff", name=folder_name, config=config) 154 | 155 | # Handle the output folder creation 156 | if is_main_process: 157 | os.makedirs(output_dir, exist_ok=True) 158 | os.makedirs(f"{output_dir}/samples", exist_ok=True) 159 | os.makedirs(f"{output_dir}/sanity_check", exist_ok=True) 160 | os.makedirs(f"{output_dir}/checkpoints", exist_ok=True) 161 | OmegaConf.save(config, os.path.join(output_dir, 'config.yaml')) 162 | 163 | # Load scheduler, tokenizer and models. 
164 | noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs)) 165 | 166 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") 167 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") 168 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") 169 | if not image_finetune: 170 | unet = UNet3DConditionModel.from_pretrained_2d( 171 | pretrained_model_path, subfolder="unet", 172 | unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs) 173 | ) 174 | else: 175 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet") 176 | 177 | # Load pretrained unet weights 178 | if unet_checkpoint_path != "": 179 | zero_rank_print(f"from checkpoint: {unet_checkpoint_path}") 180 | unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu") 181 | if "global_step" in unet_checkpoint_path: zero_rank_print(f"global_step: {unet_checkpoint_path['global_step']}") 182 | state_dict = unet_checkpoint_path["state_dict"] if "state_dict" in unet_checkpoint_path else unet_checkpoint_path 183 | 184 | m, u = unet.load_state_dict(state_dict, strict=False) 185 | zero_rank_print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") 186 | assert len(u) == 0 187 | 188 | # Freeze vae and text_encoder 189 | vae.requires_grad_(False) 190 | text_encoder.requires_grad_(False) 191 | 192 | # Set unet trainable parameters 193 | unet.requires_grad_(False) 194 | for name, param in unet.named_parameters(): 195 | for trainable_module_name in trainable_modules: 196 | if trainable_module_name in name: 197 | param.requires_grad = True 198 | break 199 | 200 | trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters())) 201 | optimizer = torch.optim.AdamW( 202 | trainable_params, 203 | lr=learning_rate, 204 | betas=(adam_beta1, adam_beta2), 205 | weight_decay=adam_weight_decay, 206 | eps=adam_epsilon, 207 | ) 208 | 209 | if is_main_process: 210 | zero_rank_print(f"trainable params number: {len(trainable_params)}") 211 | zero_rank_print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M") 212 | 213 | # Enable xformers 214 | if enable_xformers_memory_efficient_attention: 215 | if is_xformers_available(): 216 | unet.enable_xformers_memory_efficient_attention() 217 | else: 218 | raise ValueError("xformers is not available. 
Make sure it is installed correctly") 219 | 220 | # Enable gradient checkpointing 221 | if gradient_checkpointing: 222 | unet.enable_gradient_checkpointing() 223 | 224 | # Move models to GPU 225 | vae.to(local_rank) 226 | text_encoder.to(local_rank) 227 | 228 | # Get the training dataset 229 | train_dataset = WebVid10M(**train_data, is_image=image_finetune) 230 | distributed_sampler = DistributedSampler( 231 | train_dataset, 232 | num_replicas=num_processes, 233 | rank=global_rank, 234 | shuffle=True, 235 | seed=global_seed, 236 | ) 237 | 238 | # DataLoaders creation: 239 | train_dataloader = torch.utils.data.DataLoader( 240 | train_dataset, 241 | batch_size=train_batch_size, 242 | shuffle=False, 243 | sampler=distributed_sampler, 244 | num_workers=num_workers, 245 | pin_memory=True, 246 | drop_last=True, 247 | ) 248 | 249 | # Get the training iteration 250 | if max_train_steps == -1: 251 | assert max_train_epoch != -1 252 | max_train_steps = max_train_epoch * len(train_dataloader) 253 | 254 | if checkpointing_steps == -1: 255 | assert checkpointing_epochs != -1 256 | checkpointing_steps = checkpointing_epochs * len(train_dataloader) 257 | 258 | if scale_lr: 259 | learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * num_processes) 260 | 261 | # Scheduler 262 | lr_scheduler = get_scheduler( 263 | lr_scheduler, 264 | optimizer=optimizer, 265 | num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps, 266 | num_training_steps=max_train_steps * gradient_accumulation_steps, 267 | ) 268 | 269 | # Validation pipeline 270 | if not image_finetune: 271 | validation_pipeline = AnimationPipeline( 272 | unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, 273 | ).to("cuda") 274 | else: 275 | validation_pipeline = StableDiffusionPipeline.from_pretrained( 276 | pretrained_model_path, 277 | unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, safety_checker=None, 278 | ) 279 | validation_pipeline.enable_vae_slicing() 280 | 281 | # DDP wrapper 282 | unet.to(local_rank) 283 | unet = DDP(unet, device_ids=[local_rank], output_device=local_rank) 284 | 285 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 286 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) 287 | # Afterwards we recalculate our number of training epochs 288 | num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) 289 | 290 | # Train! 291 | total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps 292 | 293 | if is_main_process: 294 | logging.info("***** Running training *****") 295 | logging.info(f" Num examples = {len(train_dataset)}") 296 | logging.info(f" Num Epochs = {num_train_epochs}") 297 | logging.info(f" Instantaneous batch size per device = {train_batch_size}") 298 | logging.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 299 | logging.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}") 300 | logging.info(f" Total optimization steps = {max_train_steps}") 301 | global_step = 0 302 | first_epoch = 0 303 | 304 | # Only show the progress bar once on each machine. 
305 | progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process) 306 | progress_bar.set_description("Steps") 307 | 308 | # Support mixed-precision training 309 | scaler = torch.cuda.amp.GradScaler() if mixed_precision_training else None 310 | 311 | for epoch in range(first_epoch, num_train_epochs): 312 | train_dataloader.sampler.set_epoch(epoch) 313 | unet.train() 314 | 315 | for step, batch in enumerate(train_dataloader): 316 | if cfg_random_null_text: 317 | batch['text'] = [name if random.random() > cfg_random_null_text_ratio else "" for name in batch['text']] 318 | 319 | # Data batch sanity check 320 | if epoch == first_epoch and step == 0: 321 | pixel_values, texts = batch['pixel_values'].cpu(), batch['text'] 322 | if not image_finetune: 323 | pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w") 324 | for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)): 325 | pixel_value = pixel_value[None, ...] 326 | save_videos_grid(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.gif", rescale=True) 327 | else: 328 | for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)): 329 | pixel_value = pixel_value / 2. + 0.5 330 | torchvision.utils.save_image(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.png") 331 | 332 | ### >>>> Training >>>> ### 333 | 334 | # Convert videos to latent space 335 | pixel_values = batch["pixel_values"].to(local_rank) 336 | video_length = pixel_values.shape[1] 337 | with torch.no_grad(): 338 | if not image_finetune: 339 | pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w") 340 | latents = vae.encode(pixel_values).latent_dist 341 | latents = latents.sample() 342 | latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length) 343 | else: 344 | latents = vae.encode(pixel_values).latent_dist 345 | latents = latents.sample() 346 | 347 | latents = latents * 0.18215 348 | 349 | # Sample noise that we'll add to the latents 350 | noise = torch.randn_like(latents) 351 | bsz = latents.shape[0] 352 | 353 | # Sample a random timestep for each video 354 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 355 | timesteps = timesteps.long() 356 | 357 | # Add noise to the latents according to the noise magnitude at each timestep 358 | # (this is the forward diffusion process) 359 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 360 | 361 | # Get the text embedding for conditioning 362 | with torch.no_grad(): 363 | prompt_ids = tokenizer( 364 | batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" 365 | ).input_ids.to(latents.device) 366 | encoder_hidden_states = text_encoder(prompt_ids)[0] 367 | 368 | # Get the target for loss depending on the prediction type 369 | if noise_scheduler.config.prediction_type == "epsilon": 370 | target = noise 371 | elif noise_scheduler.config.prediction_type == "v_prediction": 372 | raise NotImplementedError 373 | else: 374 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 375 | 376 | # Predict the noise residual and compute loss 377 | # Mixed-precision training 378 | with torch.cuda.amp.autocast(enabled=mixed_precision_training): 379 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 380 | loss = 
F.mse_loss(model_pred.float(), target.float(), reduction="mean") 381 | 382 | optimizer.zero_grad() 383 | 384 | # Backpropagate 385 | if mixed_precision_training: 386 | scaler.scale(loss).backward() 387 | """ >>> gradient clipping >>> """ 388 | scaler.unscale_(optimizer) 389 | torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm) 390 | """ <<< gradient clipping <<< """ 391 | scaler.step(optimizer) 392 | scaler.update() 393 | else: 394 | loss.backward() 395 | """ >>> gradient clipping >>> """ 396 | torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm) 397 | """ <<< gradient clipping <<< """ 398 | optimizer.step() 399 | 400 | lr_scheduler.step() 401 | progress_bar.update(1) 402 | global_step += 1 403 | 404 | ### <<<< Training <<<< ### 405 | 406 | # Wandb logging 407 | if is_main_process and (not is_debug) and use_wandb: 408 | wandb.log({"train_loss": loss.item()}, step=global_step) 409 | 410 | # Save checkpoint 411 | if is_main_process and (global_step % checkpointing_steps == 0 or step == len(train_dataloader) - 1): 412 | save_path = os.path.join(output_dir, f"checkpoints") 413 | state_dict = { 414 | "epoch": epoch, 415 | "global_step": global_step, 416 | "state_dict": unet.state_dict(), 417 | } 418 | if step == len(train_dataloader) - 1: 419 | torch.save(state_dict, os.path.join(save_path, f"checkpoint-epoch-{epoch+1}.ckpt")) 420 | else: 421 | torch.save(state_dict, os.path.join(save_path, f"checkpoint.ckpt")) 422 | logging.info(f"Saved state to {save_path} (global_step: {global_step})") 423 | 424 | # Periodic validation 425 | if is_main_process and (global_step % validation_steps == 0 or global_step in validation_steps_tuple): 426 | samples = [] 427 | 428 | generator = torch.Generator(device=latents.device) 429 | generator.manual_seed(global_seed) 430 | 431 | height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size 432 | width = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size 433 | 434 | prompts = validation_data.prompts[:2] if global_step < 1000 and (not image_finetune) else validation_data.prompts 435 | 436 | for idx, prompt in enumerate(prompts): 437 | if not image_finetune: 438 | sample = validation_pipeline( 439 | prompt, 440 | generator = generator, 441 | video_length = train_data.sample_n_frames, 442 | height = height, 443 | width = width, 444 | **validation_data, 445 | ).videos 446 | save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif") 447 | samples.append(sample) 448 | 449 | else: 450 | sample = validation_pipeline( 451 | prompt, 452 | generator = generator, 453 | height = height, 454 | width = width, 455 | num_inference_steps = validation_data.get("num_inference_steps", 25), 456 | guidance_scale = validation_data.get("guidance_scale", 8.), 457 | ).images[0] 458 | sample = torchvision.transforms.functional.to_tensor(sample) 459 | samples.append(sample) 460 | 461 | if not image_finetune: 462 | samples = torch.concat(samples) 463 | save_path = f"{output_dir}/samples/sample-{global_step}.gif" 464 | save_videos_grid(samples, save_path) 465 | 466 | else: 467 | samples = torch.stack(samples) 468 | save_path = f"{output_dir}/samples/sample-{global_step}.png" 469 | torchvision.utils.save_image(samples, save_path, nrow=4) 470 | 471 | logging.info(f"Saved samples to {save_path}") 472 | 473 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 474 | progress_bar.set_postfix(**logs) 475 | 476 | if 
global_step >= max_train_steps: 477 | break 478 | 479 | dist.destroy_process_group() 480 | 481 | 482 | 483 | if __name__ == "__main__": 484 | parser = argparse.ArgumentParser() 485 | parser.add_argument("--config", type=str, required=True) 486 | parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch") 487 | parser.add_argument("--wandb", action="store_true") 488 | args = parser.parse_args() 489 | 490 | name = Path(args.config).stem 491 | config = OmegaConf.load(args.config) 492 | 493 | main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config) 494 | --------------------------------------------------------------------------------
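
Usage sketch: the prompt YAMLs above are lists of per-model configs that scripts/animate.py iterates over. A minimal sketch of driving that entry point programmatically, assuming the interpreter is started from the repository root, the checkpoints referenced by the config have been downloaded into models/ as the placeholder .txt files instruct, and a CUDA device is available; the Namespace fields simply mirror the argparse flags at the bottom of scripts/animate.py, and the chosen config path is just one of the examples above.

    # Equivalent to: python -m scripts.animate --config configs/prompts/3_sparsectrl/3_2_sparsectrl_rgb_RealisticVision.yaml
    import argparse
    from scripts.animate import main

    args = argparse.Namespace(
        pretrained_model_path="runwayml/stable-diffusion-v1-5",  # parser default
        inference_config="configs/inference/inference-v1.yaml",  # overridden per-config by its own `inference_config` key
        config="configs/prompts/3_sparsectrl/3_2_sparsectrl_rgb_RealisticVision.yaml",
        L=16, W=512, H=512,                                      # frame count and resolution defaults
        without_xformers=False,
    )
    main(args)  # writes GIFs under samples/<config-stem>-<timestamp>/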