├── .gitignore ├── LICENSE ├── README.md ├── assets ├── 00.gif ├── 01.gif ├── 02.gif ├── 03.gif ├── 04.gif ├── 05.gif ├── 06.gif ├── 07.gif ├── 08.gif ├── 09.gif ├── 10.gif ├── 11.gif ├── 12.gif ├── 13.gif ├── 72105_388.mp4_00-00.png ├── 72105_388.mp4_00-01.png ├── 72109_125.mp4_00-00.png ├── 72109_125.mp4_00-01.png ├── 72110_255.mp4_00-00.png ├── 72110_255.mp4_00-01.png ├── 74302_1349_frame1.png ├── 74302_1349_frame3.png ├── Japan_v2_1_070321_s3_frame1.png ├── Japan_v2_1_070321_s3_frame3.png ├── Japan_v2_2_062266_s2_frame1.png ├── Japan_v2_2_062266_s2_frame3.png ├── frame0001_05.png ├── frame0001_09.png ├── frame0001_10.png ├── frame0001_11.png ├── frame0016_10.png ├── frame0016_11.png └── logo │ └── logo2.png ├── configs ├── inference_512_v1.0.yaml ├── training_1024_v1.0 │ ├── config.yaml │ └── run.sh └── training_512_v1.0 │ ├── config.yaml │ └── run.sh ├── gradio_app.py ├── lvdm ├── basics.py ├── common.py ├── data │ ├── base.py │ └── webvid.py ├── distributions.py ├── ema.py ├── models │ ├── autoencoder.py │ ├── autoencoder_dualref.py │ ├── ddpm3d.py │ ├── samplers │ │ ├── ddim.py │ │ └── ddim_multiplecond.py │ └── utils_diffusion.py └── modules │ ├── attention.py │ ├── attention_svd.py │ ├── encoders │ ├── condition.py │ └── resampler.py │ ├── networks │ ├── ae_modules.py │ └── openaimodel3d.py │ └── x_transformer.py ├── main ├── callbacks.py ├── trainer.py ├── utils_data.py └── utils_train.py ├── prompts └── 512_interp │ ├── 74906_1462_frame1.png │ ├── 74906_1462_frame3.png │ ├── Japan_v2_2_062266_s2_frame1.png │ ├── Japan_v2_2_062266_s2_frame3.png │ ├── Japan_v2_3_119235_s2_frame1.png │ ├── Japan_v2_3_119235_s2_frame3.png │ └── prompts.txt ├── requirements.txt ├── scripts ├── evaluation │ ├── ddp_wrapper.py │ ├── funcs.py │ └── inference.py ├── gradio │ ├── i2v_test.py │ └── i2v_test_application.py └── run.sh └── utils ├── save_video.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *pyc 3 | .vscode 4 | __pycache__ 5 | *.egg-info 6 | 7 | checkpoints 8 | results 9 | backup 10 | LOG -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright Tencent 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## ___***ToonCrafter: Generative Cartoon Interpolation***___ 2 | 3 | 4 |
5 | 6 | 7 |   8 |   9 |
10 |    11 |   12 | 13 | 14 | 15 | _**[Jinbo Xing](https://doubiiu.github.io/), [Hanyuan Liu](https://github.com/hyliu), [Menghan Xia](https://menghanxia.github.io), [Yong Zhang](https://yzhang2016.github.io), [Xintao Wang](https://xinntao.github.io/), [Ying Shan](https://scholar.google.com/citations?hl=en&user=4oXBp9UAAAAJ&view_op=list_works&sortby=pubdate), [Tien-Tsin Wong](https://ttwong12.github.io/myself.html)**_ 16 |

17 | From CUHK and Tencent AI Lab. 18 | 19 | at SIGGRAPH Asia 2024, Journal Track 20 | 21 | 22 |
23 | 24 | ## 🔆 Introduction 25 | 26 | ⚠️ We have not set up any official profit-making projects or web applications. Please be cautious!!! 27 | 28 | 🤗 ToonCrafter can interpolate two cartoon images by leveraging the pre-trained image-to-video diffusion priors. Please check our project page and paper for more information.
29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | ### 1.1 Showcases (512x320) 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 47 | 50 | 53 | 54 | 55 | 56 | 57 | 60 | 63 | 66 | 67 | 68 | 71 | 74 | 77 | 78 | 79 | 82 | 85 | 88 | 89 |
Input starting frameInput ending frameGenerated video
45 | 46 | 48 | 49 | 51 | 52 |
58 | 59 | 61 | 62 | 64 | 65 |
69 | 70 | 72 | 73 | 75 | 76 |
80 | 81 | 83 | 84 | 86 | 87 |
90 | 91 | ### 1.2 Sparse sketch guidance 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 103 | 106 | 109 | 112 | 113 | 114 | 115 | 118 | 121 | 124 | 127 | 128 | 129 | 130 |
Input starting frameInput ending frameInput sketch guidanceGenerated video
101 | 102 | 104 | 105 | 107 | 108 | 110 | 111 |
116 | 117 | 119 | 120 | 122 | 123 | 125 | 126 |
131 | 132 | 133 | ### 2. Applications 134 | #### 2.1 Cartoon Sketch Interpolation (see project page for more details) 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 146 | 149 | 152 | 153 | 154 | 155 | 156 | 159 | 162 | 165 | 166 | 167 |
Input starting frameInput ending frameGenerated video
144 | 145 | 147 | 148 | 150 | 151 |
157 | 158 | 160 | 161 | 163 | 164 |
168 | 169 | 170 | #### 2.2 Reference-based Sketch Colorization 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 182 | 185 | 188 | 189 | 190 | 191 | 192 | 195 | 198 | 201 | 202 | 203 |
Input sketchInput referenceColorization results
180 | 181 | 183 | 184 | 186 | 187 |
193 | 194 | 196 | 197 | 199 | 200 |
204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | ## 📝 Changelog 212 | - [ ] Add sketch control and colorization function. 213 | - __[2024.05.29]__: 🔥🔥 Release code and model weights. 214 | - __[2024.05.28]__: Launch the project page and update the arXiv preprint. 215 |
216 | 217 | 218 | ## 🧰 Models 219 | 220 | |Model|Resolution|GPU Mem. & Inference Time (A100, DDIM 50 steps)|Checkpoint| 221 | |:---------|:---------|:--------|:--------| 222 | |ToonCrafter_512|320x512| ~24G & 24s (`perframe_ae=True`)|[Hugging Face](https://huggingface.co/Doubiiu/ToonCrafter/blob/main/model.ckpt)| 223 | 224 | Feedback from GitHub issues indicates that the model may consume about 24-27 GB of GPU memory in this implementation, but the community has lowered the consumption to roughly 10 GB. 225 | 226 | Currently, ToonCrafter supports generating videos of up to 16 frames at a resolution of 512x320. Inference time can be reduced by using fewer DDIM steps. 227 | 228 | 229 | 230 | ## ⚙️ Setup 231 | 232 | ### Install Environment via Anaconda (Recommended) 233 | ```bash 234 | conda create -n tooncrafter python=3.8.5 235 | conda activate tooncrafter 236 | pip install -r requirements.txt 237 | ``` 238 | 239 | 240 | ## 💫 Inference 241 | ### 1. Command line 242 | 243 | Download the pretrained ToonCrafter_512 model and place `model.ckpt` at `checkpoints/tooncrafter_512_interp_v1/model.ckpt`. 244 | ```bash 245 | sh scripts/run.sh 246 | ``` 247 | 248 | 249 | ### 2. Local Gradio demo 250 | 251 | Download the pretrained model and place it in the directory described above, then launch the demo. 252 | ```bash 253 | python gradio_app.py 254 | ```
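Both entry points above expect the checkpoint at `checkpoints/tooncrafter_512_interp_v1/model.ckpt`. The snippet below is a minimal convenience sketch, not part of the original scripts: it assumes a recent `huggingface_hub` CLI is installed, and downloading `model.ckpt` manually from the Hugging Face link in the Models table works just as well.

```bash
# Hypothetical helper: fetch model.ckpt from the Hugging Face repo linked above
# and place it where scripts/run.sh and gradio_app.py expect to find it.
mkdir -p checkpoints/tooncrafter_512_interp_v1
huggingface-cli download Doubiiu/ToonCrafter model.ckpt \
  --local-dir checkpoints/tooncrafter_512_interp_v1
```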
255 | 256 | 257 | 258 | 259 | 260 | 261 | ## 🤝 Community Support 262 | 1. ComfyUI and pruned models (fp16): [ComfyUI-DynamiCrafterWrapper](https://github.com/kijai/ComfyUI-DynamiCrafterWrapper) (Thanks to [kijai](https://twitter.com/kijaidesign)) 263 | 264 | |Model|Resolution|GPU Mem. |Checkpoint| 265 | |:---------|:---------|:--------|:--------| 266 | |ToonCrafter|512x320|12GB |[Hugging Face](https://huggingface.co/Kijai/DynamiCrafter_pruned/blob/main/tooncrafter_512_interp-fp16.safetensors)| 267 | 268 | 2. ComfyUI: [ComfyUI-ToonCrafter](https://github.com/AIGODLIKE/ComfyUI-ToonCrafter) (Thanks to [Yorha4D](https://github.com/Yorha4D)) 269 | 270 | 3. Colab: [Code](https://github.com/camenduru/ToonCrafter-jupyter) (Thanks to [camenduru](https://github.com/camenduru)), [Code](https://gist.github.com/0smboy/baef995b8f5974f19ac114ec20ac37d5) (Thanks to [0smboy](https://github.com/0smboy)) 271 | 272 | 4. Windows platform support: [ToonCrafter-for-windows](https://github.com/sdbds/ToonCrafter-for-windows) (Thanks to [sdbds](https://github.com/sdbds)) 273 | 274 | 5. Sketch-guidance implementation: [ToonCrafter_with_SketchGuidance](https://github.com/mattyamonaca/ToonCrafter_with_SketchGuidance) (Thanks to [mattyamonaca](https://github.com/mattyamonaca)) 275 | 276 | ## 😉 Citation 277 | Please consider citing our paper if our code is useful: 278 | ```bib 279 | @article{xing2024tooncrafter, 280 | title={Tooncrafter: Generative cartoon interpolation}, 281 | author={Xing, Jinbo and Liu, Hanyuan and Xia, Menghan and Zhang, Yong and Wang, Xintao and Shan, Ying and Wong, Tien-Tsin}, 282 | journal={ACM Transactions on Graphics (TOG)}, 283 | volume={43}, 284 | number={6}, 285 | pages={1--11}, 286 | year={2024} 287 | } 288 | ``` 289 | 290 | 291 | ## 🙏 Acknowledgements 292 | We would like to thank [Xiaoyu](https://engineering.purdue.edu/people/xiaoyu.xiang.1) for providing the [sketch extractor](https://github.com/Mukosame/Anime2Sketch), and [supraxylon](https://github.com/supraxylon) for the Windows batch script. 293 | 294 | 295 | ## 📢 Disclaimer 296 | Calm down. Our framework opens up the era of generative cartoon interpolation, but due to the variability of the generative video prior, the success rate is not guaranteed. 297 | 298 | ⚠️ This is an open-source research exploration, not a commercial product. It may not meet all of your expectations. 299 | 300 | This project strives to impact the domain of AI-driven video generation positively. Users are granted the freedom to create videos using this tool, but they are expected to comply with local laws and use it responsibly. The developers do not assume any responsibility for potential misuse by users. 301 | **** 302 | -------------------------------------------------------------------------------- /assets/00.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/00.gif -------------------------------------------------------------------------------- /assets/01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/01.gif -------------------------------------------------------------------------------- /assets/02.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/02.gif -------------------------------------------------------------------------------- /assets/03.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/03.gif -------------------------------------------------------------------------------- /assets/04.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/04.gif -------------------------------------------------------------------------------- /assets/05.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/05.gif -------------------------------------------------------------------------------- /assets/06.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/06.gif -------------------------------------------------------------------------------- /assets/07.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/07.gif -------------------------------------------------------------------------------- /assets/08.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/08.gif -------------------------------------------------------------------------------- /assets/09.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/09.gif
-------------------------------------------------------------------------------- /assets/10.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/10.gif -------------------------------------------------------------------------------- /assets/11.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/11.gif -------------------------------------------------------------------------------- /assets/12.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/12.gif -------------------------------------------------------------------------------- /assets/13.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/13.gif -------------------------------------------------------------------------------- /assets/72105_388.mp4_00-00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/72105_388.mp4_00-00.png -------------------------------------------------------------------------------- /assets/72105_388.mp4_00-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/72105_388.mp4_00-01.png -------------------------------------------------------------------------------- /assets/72109_125.mp4_00-00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/72109_125.mp4_00-00.png -------------------------------------------------------------------------------- /assets/72109_125.mp4_00-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/72109_125.mp4_00-01.png -------------------------------------------------------------------------------- /assets/72110_255.mp4_00-00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/72110_255.mp4_00-00.png -------------------------------------------------------------------------------- /assets/72110_255.mp4_00-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/72110_255.mp4_00-01.png -------------------------------------------------------------------------------- /assets/74302_1349_frame1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/74302_1349_frame1.png -------------------------------------------------------------------------------- /assets/74302_1349_frame3.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/74302_1349_frame3.png -------------------------------------------------------------------------------- /assets/Japan_v2_1_070321_s3_frame1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/Japan_v2_1_070321_s3_frame1.png -------------------------------------------------------------------------------- /assets/Japan_v2_1_070321_s3_frame3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/Japan_v2_1_070321_s3_frame3.png -------------------------------------------------------------------------------- /assets/Japan_v2_2_062266_s2_frame1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/Japan_v2_2_062266_s2_frame1.png -------------------------------------------------------------------------------- /assets/Japan_v2_2_062266_s2_frame3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/Japan_v2_2_062266_s2_frame3.png -------------------------------------------------------------------------------- /assets/frame0001_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/frame0001_05.png -------------------------------------------------------------------------------- /assets/frame0001_09.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/frame0001_09.png -------------------------------------------------------------------------------- /assets/frame0001_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/frame0001_10.png -------------------------------------------------------------------------------- /assets/frame0001_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/frame0001_11.png -------------------------------------------------------------------------------- /assets/frame0016_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/frame0016_10.png -------------------------------------------------------------------------------- /assets/frame0016_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/frame0016_11.png -------------------------------------------------------------------------------- /assets/logo/logo2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/assets/logo/logo2.png -------------------------------------------------------------------------------- /configs/inference_512_v1.0.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: lvdm.models.ddpm3d.LatentVisualDiffusion 3 | params: 4 | rescale_betas_zero_snr: True 5 | parameterization: "v" 6 | linear_start: 0.00085 7 | linear_end: 0.012 8 | num_timesteps_cond: 1 9 | timesteps: 1000 10 | first_stage_key: video 11 | cond_stage_key: caption 12 | cond_stage_trainable: False 13 | conditioning_key: hybrid 14 | image_size: [40, 64] 15 | channels: 4 16 | scale_by_std: False 17 | scale_factor: 0.18215 18 | use_ema: False 19 | uncond_type: 'empty_seq' 20 | use_dynamic_rescale: true 21 | base_scale: 0.7 22 | fps_condition_type: 'fps' 23 | perframe_ae: True 24 | loop_video: true 25 | unet_config: 26 | target: lvdm.modules.networks.openaimodel3d.UNetModel 27 | params: 28 | in_channels: 8 29 | out_channels: 4 30 | model_channels: 320 31 | attention_resolutions: 32 | - 4 33 | - 2 34 | - 1 35 | num_res_blocks: 2 36 | channel_mult: 37 | - 1 38 | - 2 39 | - 4 40 | - 4 41 | dropout: 0.1 42 | num_head_channels: 64 43 | transformer_depth: 1 44 | context_dim: 1024 45 | use_linear: true 46 | use_checkpoint: True 47 | temporal_conv: True 48 | temporal_attention: True 49 | temporal_selfatt_only: true 50 | use_relative_position: false 51 | use_causal_attention: False 52 | temporal_length: 16 53 | addition_attention: true 54 | image_cross_attention: true 55 | default_fs: 24 56 | fs_condition: true 57 | 58 | first_stage_config: 59 | target: lvdm.models.autoencoder.AutoencoderKL_Dualref 60 | params: 61 | embed_dim: 4 62 | monitor: val/rec_loss 63 | ddconfig: 64 | double_z: True 65 | z_channels: 4 66 | resolution: 256 67 | in_channels: 3 68 | out_ch: 3 69 | ch: 128 70 | ch_mult: 71 | - 1 72 | - 2 73 | - 4 74 | - 4 75 | num_res_blocks: 2 76 | attn_resolutions: [] 77 | dropout: 0.0 78 | lossconfig: 79 | target: torch.nn.Identity 80 | 81 | cond_stage_config: 82 | target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder 83 | params: 84 | freeze: true 85 | layer: "penultimate" 86 | 87 | img_cond_stage_config: 88 | target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2 89 | params: 90 | freeze: true 91 | 92 | image_proj_stage_config: 93 | target: lvdm.modules.encoders.resampler.Resampler 94 | params: 95 | dim: 1024 96 | depth: 4 97 | dim_head: 64 98 | heads: 12 99 | num_queries: 16 100 | embedding_dim: 1280 101 | output_dim: 1024 102 | ff_mult: 4 103 | video_length: 16 104 | -------------------------------------------------------------------------------- /configs/training_1024_v1.0/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | pretrained_checkpoint: checkpoints/dynamicrafter_1024_v1/model.ckpt 3 | base_learning_rate: 1.0e-05 4 | scale_lr: False 5 | target: lvdm.models.ddpm3d.LatentVisualDiffusion 6 | params: 7 | rescale_betas_zero_snr: True 8 | parameterization: "v" 9 | linear_start: 0.00085 10 | linear_end: 0.012 11 | num_timesteps_cond: 1 12 | log_every_t: 200 13 | timesteps: 1000 14 | first_stage_key: video 15 | cond_stage_key: caption 16 | cond_stage_trainable: False 17 | image_proj_model_trainable: True 18 | conditioning_key: hybrid 19 | image_size: [72, 128] 20 | channels: 4 21 | scale_by_std: False 22 
| scale_factor: 0.18215 23 | use_ema: False 24 | uncond_prob: 0.05 25 | uncond_type: 'empty_seq' 26 | rand_cond_frame: true 27 | use_dynamic_rescale: true 28 | base_scale: 0.3 29 | fps_condition_type: 'fps' 30 | perframe_ae: True 31 | 32 | unet_config: 33 | target: lvdm.modules.networks.openaimodel3d.UNetModel 34 | params: 35 | in_channels: 8 36 | out_channels: 4 37 | model_channels: 320 38 | attention_resolutions: 39 | - 4 40 | - 2 41 | - 1 42 | num_res_blocks: 2 43 | channel_mult: 44 | - 1 45 | - 2 46 | - 4 47 | - 4 48 | dropout: 0.1 49 | num_head_channels: 64 50 | transformer_depth: 1 51 | context_dim: 1024 52 | use_linear: true 53 | use_checkpoint: True 54 | temporal_conv: True 55 | temporal_attention: True 56 | temporal_selfatt_only: true 57 | use_relative_position: false 58 | use_causal_attention: False 59 | temporal_length: 16 60 | addition_attention: true 61 | image_cross_attention: true 62 | default_fs: 10 63 | fs_condition: true 64 | 65 | first_stage_config: 66 | target: lvdm.models.autoencoder.AutoencoderKL 67 | params: 68 | embed_dim: 4 69 | monitor: val/rec_loss 70 | ddconfig: 71 | double_z: True 72 | z_channels: 4 73 | resolution: 256 74 | in_channels: 3 75 | out_ch: 3 76 | ch: 128 77 | ch_mult: 78 | - 1 79 | - 2 80 | - 4 81 | - 4 82 | num_res_blocks: 2 83 | attn_resolutions: [] 84 | dropout: 0.0 85 | lossconfig: 86 | target: torch.nn.Identity 87 | 88 | cond_stage_config: 89 | target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder 90 | params: 91 | freeze: true 92 | layer: "penultimate" 93 | 94 | img_cond_stage_config: 95 | target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2 96 | params: 97 | freeze: true 98 | 99 | image_proj_stage_config: 100 | target: lvdm.modules.encoders.resampler.Resampler 101 | params: 102 | dim: 1024 103 | depth: 4 104 | dim_head: 64 105 | heads: 12 106 | num_queries: 16 107 | embedding_dim: 1280 108 | output_dim: 1024 109 | ff_mult: 4 110 | video_length: 16 111 | 112 | data: 113 | target: utils_data.DataModuleFromConfig 114 | params: 115 | batch_size: 1 116 | num_workers: 12 117 | wrap: false 118 | train: 119 | target: lvdm.data.webvid.WebVid 120 | params: 121 | data_dir: 122 | meta_path: <.csv FILE> 123 | video_length: 16 124 | frame_stride: 6 125 | load_raw_resolution: true 126 | resolution: [576, 1024] 127 | spatial_transform: resize_center_crop 128 | random_fs: true ## if true, we uniformly sample fs with max_fs=frame_stride (above) 129 | 130 | lightning: 131 | precision: 16 132 | # strategy: deepspeed_stage_2 133 | trainer: 134 | benchmark: True 135 | accumulate_grad_batches: 2 136 | max_steps: 100000 137 | # logger 138 | log_every_n_steps: 50 139 | # val 140 | val_check_interval: 0.5 141 | gradient_clip_algorithm: 'norm' 142 | gradient_clip_val: 0.5 143 | callbacks: 144 | model_checkpoint: 145 | target: pytorch_lightning.callbacks.ModelCheckpoint 146 | params: 147 | every_n_train_steps: 9000 #1000 148 | filename: "{epoch}-{step}" 149 | save_weights_only: True 150 | metrics_over_trainsteps_checkpoint: 151 | target: pytorch_lightning.callbacks.ModelCheckpoint 152 | params: 153 | filename: '{epoch}-{step}' 154 | save_weights_only: True 155 | every_n_train_steps: 10000 #20000 # 3s/step*2w= 156 | batch_logger: 157 | target: callbacks.ImageLogger 158 | params: 159 | batch_frequency: 500 160 | to_local: False 161 | max_images: 8 162 | log_images_kwargs: 163 | ddim_steps: 50 164 | unconditional_guidance_scale: 7.5 165 | timestep_spacing: uniform_trailing 166 | guidance_rescale: 0.7 
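The `target:`/`params:` pairs in the config above (and in the other YAML configs in this repo) follow the usual latent-diffusion convention: each `target` names a class that is imported and constructed with its `params`. The snippet below is an illustrative sketch only, not part of the repo; it assumes OmegaConf and that `instantiate_from_config` (imported from `utils/utils.py`, as in `lvdm/basics.py`) has the standard behavior. The real training entry point is `main/trainer.py`, launched by the `run.sh` script that follows.

```python
# Illustrative sketch: how a `target`/`params` block from the YAML above is resolved.
# Assumes the standard instantiate_from_config helper (imports `target`, then calls
# it with `params`); building the full model also constructs the frozen CLIP
# encoders, so this is heavyweight and shown only to clarify the config structure.
from omegaconf import OmegaConf
from utils.utils import instantiate_from_config

config = OmegaConf.load("configs/training_1024_v1.0/config.yaml")
print(config.model.target)  # -> lvdm.models.ddpm3d.LatentVisualDiffusion

model = instantiate_from_config(config.model)  # LatentVisualDiffusion instance
```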
-------------------------------------------------------------------------------- /configs/training_1024_v1.0/run.sh: -------------------------------------------------------------------------------- 1 | # NCCL configuration 2 | # export NCCL_DEBUG=INFO 3 | # export NCCL_IB_DISABLE=0 4 | # export NCCL_IB_GID_INDEX=3 5 | # export NCCL_NET_GDR_LEVEL=3 6 | # export NCCL_TOPO_FILE=/tmp/topo.txt 7 | 8 | # args 9 | name="training_1024_v1.0" 10 | config_file=configs/${name}/config.yaml 11 | 12 | # save root dir for logs, checkpoints, tensorboard record, etc. 13 | save_root="" 14 | 15 | mkdir -p $save_root/$name 16 | 17 | ## run 18 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \ 19 | --nproc_per_node=$HOST_GPU_NUM --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \ 20 | ./main/trainer.py \ 21 | --base $config_file \ 22 | --train \ 23 | --name $name \ 24 | --logdir $save_root \ 25 | --devices $HOST_GPU_NUM \ 26 | lightning.trainer.num_nodes=1 27 | 28 | ## debugging 29 | # CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch \ 30 | # --nproc_per_node=4 --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \ 31 | # ./main/trainer.py \ 32 | # --base $config_file \ 33 | # --train \ 34 | # --name $name \ 35 | # --logdir $save_root \ 36 | # --devices 4 \ 37 | # lightning.trainer.num_nodes=1 -------------------------------------------------------------------------------- /configs/training_512_v1.0/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | pretrained_checkpoint: checkpoints/dynamicrafter_512_v1/model.ckpt 3 | base_learning_rate: 1.0e-05 4 | scale_lr: False 5 | target: lvdm.models.ddpm3d.LatentVisualDiffusion 6 | params: 7 | rescale_betas_zero_snr: True 8 | parameterization: "v" 9 | linear_start: 0.00085 10 | linear_end: 0.012 11 | num_timesteps_cond: 1 12 | log_every_t: 200 13 | timesteps: 1000 14 | first_stage_key: video 15 | cond_stage_key: caption 16 | cond_stage_trainable: False 17 | image_proj_model_trainable: True 18 | conditioning_key: hybrid 19 | image_size: [40, 64] 20 | channels: 4 21 | scale_by_std: False 22 | scale_factor: 0.18215 23 | use_ema: False 24 | uncond_prob: 0.05 25 | uncond_type: 'empty_seq' 26 | rand_cond_frame: true 27 | use_dynamic_rescale: true 28 | base_scale: 0.7 29 | fps_condition_type: 'fps' 30 | perframe_ae: True 31 | 32 | unet_config: 33 | target: lvdm.modules.networks.openaimodel3d.UNetModel 34 | params: 35 | in_channels: 8 36 | out_channels: 4 37 | model_channels: 320 38 | attention_resolutions: 39 | - 4 40 | - 2 41 | - 1 42 | num_res_blocks: 2 43 | channel_mult: 44 | - 1 45 | - 2 46 | - 4 47 | - 4 48 | dropout: 0.1 49 | num_head_channels: 64 50 | transformer_depth: 1 51 | context_dim: 1024 52 | use_linear: true 53 | use_checkpoint: True 54 | temporal_conv: True 55 | temporal_attention: True 56 | temporal_selfatt_only: true 57 | use_relative_position: false 58 | use_causal_attention: False 59 | temporal_length: 16 60 | addition_attention: true 61 | image_cross_attention: true 62 | default_fs: 10 63 | fs_condition: true 64 | 65 | first_stage_config: 66 | target: lvdm.models.autoencoder.AutoencoderKL 67 | params: 68 | embed_dim: 4 69 | monitor: val/rec_loss 70 | ddconfig: 71 | double_z: True 72 | z_channels: 4 73 | resolution: 256 74 | in_channels: 3 75 | out_ch: 3 76 | ch: 128 77 | ch_mult: 78 | - 1 79 | - 2 80 | - 4 81 | - 4 82 | num_res_blocks: 2 83 | attn_resolutions: [] 84 | dropout: 0.0 85 | lossconfig: 86 | target: 
torch.nn.Identity 87 | 88 | cond_stage_config: 89 | target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder 90 | params: 91 | freeze: true 92 | layer: "penultimate" 93 | 94 | img_cond_stage_config: 95 | target: lvdm.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2 96 | params: 97 | freeze: true 98 | 99 | image_proj_stage_config: 100 | target: lvdm.modules.encoders.resampler.Resampler 101 | params: 102 | dim: 1024 103 | depth: 4 104 | dim_head: 64 105 | heads: 12 106 | num_queries: 16 107 | embedding_dim: 1280 108 | output_dim: 1024 109 | ff_mult: 4 110 | video_length: 16 111 | 112 | data: 113 | target: utils_data.DataModuleFromConfig 114 | params: 115 | batch_size: 2 116 | num_workers: 12 117 | wrap: false 118 | train: 119 | target: lvdm.data.webvid.WebVid 120 | params: 121 | data_dir: 122 | meta_path: <.csv FILE> 123 | video_length: 16 124 | frame_stride: 6 125 | load_raw_resolution: true 126 | resolution: [320, 512] 127 | spatial_transform: resize_center_crop 128 | random_fs: true ## if true, we uniformly sample fs with max_fs=frame_stride (above) 129 | 130 | lightning: 131 | precision: 16 132 | # strategy: deepspeed_stage_2 133 | trainer: 134 | benchmark: True 135 | accumulate_grad_batches: 2 136 | max_steps: 100000 137 | # logger 138 | log_every_n_steps: 50 139 | # val 140 | val_check_interval: 0.5 141 | gradient_clip_algorithm: 'norm' 142 | gradient_clip_val: 0.5 143 | callbacks: 144 | model_checkpoint: 145 | target: pytorch_lightning.callbacks.ModelCheckpoint 146 | params: 147 | every_n_train_steps: 9000 #1000 148 | filename: "{epoch}-{step}" 149 | save_weights_only: True 150 | metrics_over_trainsteps_checkpoint: 151 | target: pytorch_lightning.callbacks.ModelCheckpoint 152 | params: 153 | filename: '{epoch}-{step}' 154 | save_weights_only: True 155 | every_n_train_steps: 10000 #20000 # 3s/step*2w= 156 | batch_logger: 157 | target: callbacks.ImageLogger 158 | params: 159 | batch_frequency: 500 160 | to_local: False 161 | max_images: 8 162 | log_images_kwargs: 163 | ddim_steps: 50 164 | unconditional_guidance_scale: 7.5 165 | timestep_spacing: uniform_trailing 166 | guidance_rescale: 0.7 -------------------------------------------------------------------------------- /configs/training_512_v1.0/run.sh: -------------------------------------------------------------------------------- 1 | # NCCL configuration 2 | # export NCCL_DEBUG=INFO 3 | # export NCCL_IB_DISABLE=0 4 | # export NCCL_IB_GID_INDEX=3 5 | # export NCCL_NET_GDR_LEVEL=3 6 | # export NCCL_TOPO_FILE=/tmp/topo.txt 7 | 8 | # args 9 | name="training_512_v1.0" 10 | config_file=configs/${name}/config.yaml 11 | 12 | # save root dir for logs, checkpoints, tensorboard record, etc. 
13 | save_root="" 14 | 15 | mkdir -p $save_root/$name 16 | 17 | ## run 18 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \ 19 | --nproc_per_node=$HOST_GPU_NUM --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \ 20 | ./main/trainer.py \ 21 | --base $config_file \ 22 | --train \ 23 | --name $name \ 24 | --logdir $save_root \ 25 | --devices $HOST_GPU_NUM \ 26 | lightning.trainer.num_nodes=1 27 | 28 | ## debugging 29 | # CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch \ 30 | # --nproc_per_node=4 --nnodes=1 --master_addr=127.0.0.1 --master_port=12352 --node_rank=0 \ 31 | # ./main/trainer.py \ 32 | # --base $config_file \ 33 | # --train \ 34 | # --name $name \ 35 | # --logdir $save_root \ 36 | # --devices 4 \ 37 | # lightning.trainer.num_nodes=1 -------------------------------------------------------------------------------- /gradio_app.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | import sys 3 | import gradio as gr 4 | from scripts.gradio.i2v_test_application import Image2Video 5 | sys.path.insert(1, os.path.join(sys.path[0], 'lvdm')) 6 | 7 | 8 | i2v_examples_interp_512 = [ 9 | ['prompts/512_interp/74906_1462_frame1.png', 'walking man', 50, 7.5, 1.0, 10, 123, 'prompts/512_interp/74906_1462_frame3.png'], 10 | ['prompts/512_interp/Japan_v2_2_062266_s2_frame1.png', 'an anime scene', 50, 7.5, 1.0, 10, 789, 'prompts/512_interp/Japan_v2_2_062266_s2_frame3.png'], 11 | ['prompts/512_interp/Japan_v2_3_119235_s2_frame1.png', 'an anime scene', 50, 7.5, 1.0, 10, 123, 'prompts/512_interp/Japan_v2_3_119235_s2_frame3.png'], 12 | ] 13 | 14 | 15 | 16 | 17 | def dynamicrafter_demo(result_dir='./tmp/', res=512): 18 | if res == 1024: 19 | resolution = '576_1024' 20 | css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height:576px}""" 21 | elif res == 512: 22 | resolution = '320_512' 23 | css = """#input_img {max-width: 512px !important} #output_vid {max-width: 512px; max-height: 320px} #input_img2 {max-width: 512px !important} #output_vid {max-width: 512px; max-height: 320px}""" 24 | elif res == 256: 25 | resolution = '256_256' 26 | css = """#input_img {max-width: 256px !important} #output_vid {max-width: 256px; max-height: 256px}""" 27 | else: 28 | raise NotImplementedError(f"Unsupported resolution: {res}") 29 | image2video = Image2Video(result_dir, resolution=resolution) 30 | with gr.Blocks(analytics_enabled=False, css=css) as dynamicrafter_iface: 31 | 32 | 33 | 34 | with gr.Tab(label='ToonCrafter_320x512'): 35 | with gr.Column(): 36 | with gr.Row(): 37 | with gr.Column(): 38 | with gr.Row(): 39 | i2v_input_image = gr.Image(label="Input Image1",elem_id="input_img") 40 | with gr.Row(): 41 | i2v_input_text = gr.Text(label='Prompts') 42 | with gr.Row(): 43 | i2v_seed = gr.Slider(label='Random Seed', minimum=0, maximum=50000, step=1, value=123) 44 | i2v_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label='ETA', value=1.0, elem_id="i2v_eta") 45 | i2v_cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.5, elem_id="i2v_cfg_scale") 46 | with gr.Row(): 47 | i2v_steps = gr.Slider(minimum=1, maximum=60, step=1, elem_id="i2v_steps", label="Sampling steps", value=50) 48 | i2v_motion = gr.Slider(minimum=5, maximum=30, step=1, elem_id="i2v_motion", label="FPS", value=10) 49 | i2v_end_btn = gr.Button("Generate") 50 | with gr.Column(): 51 | with gr.Row(): 52 | i2v_input_image2 = gr.Image(label="Input 
Image2",elem_id="input_img2") 53 | with gr.Row(): 54 | i2v_output_video = gr.Video(label="Generated Video",elem_id="output_vid",autoplay=True,show_share_button=True) 55 | 56 | gr.Examples(examples=i2v_examples_interp_512, 57 | inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed, i2v_input_image2], 58 | outputs=[i2v_output_video], 59 | fn = image2video.get_image, 60 | cache_examples=False, 61 | ) 62 | i2v_end_btn.click(inputs=[i2v_input_image, i2v_input_text, i2v_steps, i2v_cfg_scale, i2v_eta, i2v_motion, i2v_seed, i2v_input_image2], 63 | outputs=[i2v_output_video], 64 | fn = image2video.get_image 65 | ) 66 | 67 | 68 | return dynamicrafter_iface 69 | 70 | def get_parser(): 71 | parser = argparse.ArgumentParser() 72 | return parser 73 | 74 | if __name__ == "__main__": 75 | parser = get_parser() 76 | args = parser.parse_args() 77 | 78 | result_dir = os.path.join('./', 'results') 79 | dynamicrafter_iface = dynamicrafter_demo(result_dir) 80 | dynamicrafter_iface.queue(max_size=12) 81 | dynamicrafter_iface.launch(max_threads=1) 82 | # dynamicrafter_iface.launch(server_name='0.0.0.0', server_port=80, max_threads=1) -------------------------------------------------------------------------------- /lvdm/basics.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | import torch.nn as nn 11 | from utils.utils import instantiate_from_config 12 | 13 | 14 | def disabled_train(self, mode=True): 15 | """Overwrite model.train with this function to make sure train/eval mode 16 | does not change anymore.""" 17 | return self 18 | 19 | def zero_module(module): 20 | """ 21 | Zero out the parameters of a module and return it. 22 | """ 23 | for p in module.parameters(): 24 | p.detach().zero_() 25 | return module 26 | 27 | def scale_module(module, scale): 28 | """ 29 | Scale the parameters of a module and return it. 30 | """ 31 | for p in module.parameters(): 32 | p.detach().mul_(scale) 33 | return module 34 | 35 | 36 | def conv_nd(dims, *args, **kwargs): 37 | """ 38 | Create a 1D, 2D, or 3D convolution module. 39 | """ 40 | if dims == 1: 41 | return nn.Conv1d(*args, **kwargs) 42 | elif dims == 2: 43 | return nn.Conv2d(*args, **kwargs) 44 | elif dims == 3: 45 | return nn.Conv3d(*args, **kwargs) 46 | raise ValueError(f"unsupported dimensions: {dims}") 47 | 48 | 49 | def linear(*args, **kwargs): 50 | """ 51 | Create a linear module. 52 | """ 53 | return nn.Linear(*args, **kwargs) 54 | 55 | 56 | def avg_pool_nd(dims, *args, **kwargs): 57 | """ 58 | Create a 1D, 2D, or 3D average pooling module. 
59 | """ 60 | if dims == 1: 61 | return nn.AvgPool1d(*args, **kwargs) 62 | elif dims == 2: 63 | return nn.AvgPool2d(*args, **kwargs) 64 | elif dims == 3: 65 | return nn.AvgPool3d(*args, **kwargs) 66 | raise ValueError(f"unsupported dimensions: {dims}") 67 | 68 | 69 | def nonlinearity(type='silu'): 70 | if type == 'silu': 71 | return nn.SiLU() 72 | elif type == 'leaky_relu': 73 | return nn.LeakyReLU() 74 | 75 | 76 | class GroupNormSpecific(nn.GroupNorm): 77 | def forward(self, x): 78 | return super().forward(x.float()).type(x.dtype) 79 | 80 | 81 | def normalization(channels, num_groups=32): 82 | """ 83 | Make a standard normalization layer. 84 | :param channels: number of input channels. 85 | :return: an nn.Module for normalization. 86 | """ 87 | return GroupNormSpecific(num_groups, channels) 88 | 89 | 90 | class HybridConditioner(nn.Module): 91 | 92 | def __init__(self, c_concat_config, c_crossattn_config): 93 | super().__init__() 94 | self.concat_conditioner = instantiate_from_config(c_concat_config) 95 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 96 | 97 | def forward(self, c_concat, c_crossattn): 98 | c_concat = self.concat_conditioner(c_concat) 99 | c_crossattn = self.crossattn_conditioner(c_crossattn) 100 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} -------------------------------------------------------------------------------- /lvdm/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | from inspect import isfunction 3 | import torch 4 | from torch import nn 5 | import torch.distributed as dist 6 | 7 | 8 | def gather_data(data, return_np=True): 9 | ''' gather data from multiple processes to one list ''' 10 | data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())] 11 | dist.all_gather(data_list, data) # gather not supported with NCCL 12 | if return_np: 13 | data_list = [data.cpu().numpy() for data in data_list] 14 | return data_list 15 | 16 | def autocast(f): 17 | def do_autocast(*args, **kwargs): 18 | with torch.cuda.amp.autocast(enabled=True, 19 | dtype=torch.get_autocast_gpu_dtype(), 20 | cache_enabled=torch.is_autocast_cache_enabled()): 21 | return f(*args, **kwargs) 22 | return do_autocast 23 | 24 | 25 | def extract_into_tensor(a, t, x_shape): 26 | b, *_ = t.shape 27 | out = a.gather(-1, t) 28 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 29 | 30 | 31 | def noise_like(shape, device, repeat=False): 32 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 33 | noise = lambda: torch.randn(shape, device=device) 34 | return repeat_noise() if repeat else noise() 35 | 36 | 37 | def default(val, d): 38 | if exists(val): 39 | return val 40 | return d() if isfunction(d) else d 41 | 42 | def exists(val): 43 | return val is not None 44 | 45 | def identity(*args, **kwargs): 46 | return nn.Identity() 47 | 48 | def uniq(arr): 49 | return{el: True for el in arr}.keys() 50 | 51 | def mean_flat(tensor): 52 | """ 53 | Take the mean over all non-batch dimensions. 
54 | """ 55 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 56 | 57 | def ismap(x): 58 | if not isinstance(x, torch.Tensor): 59 | return False 60 | return (len(x.shape) == 4) and (x.shape[1] > 3) 61 | 62 | def isimage(x): 63 | if not isinstance(x,torch.Tensor): 64 | return False 65 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 66 | 67 | def max_neg_value(t): 68 | return -torch.finfo(t.dtype).max 69 | 70 | def shape_to_str(x): 71 | shape_str = "x".join([str(x) for x in x.shape]) 72 | return shape_str 73 | 74 | def init_(tensor): 75 | dim = tensor.shape[-1] 76 | std = 1 / math.sqrt(dim) 77 | tensor.uniform_(-std, std) 78 | return tensor 79 | 80 | ckpt = torch.utils.checkpoint.checkpoint 81 | def checkpoint(func, inputs, params, flag): 82 | """ 83 | Evaluate a function without caching intermediate activations, allowing for 84 | reduced memory at the expense of extra compute in the backward pass. 85 | :param func: the function to evaluate. 86 | :param inputs: the argument sequence to pass to `func`. 87 | :param params: a sequence of parameters `func` depends on but does not 88 | explicitly take as arguments. 89 | :param flag: if False, disable gradient checkpointing. 90 | """ 91 | if flag: 92 | return ckpt(func, *inputs, use_reentrant=False) 93 | else: 94 | return func(*inputs) -------------------------------------------------------------------------------- /lvdm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /lvdm/data/webvid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from tqdm import tqdm 4 | import pandas as pd 5 | from decord import VideoReader, cpu 6 | 7 | import torch 8 | from torch.utils.data import Dataset 9 | from torch.utils.data import DataLoader 10 | from torchvision import transforms 11 | 12 | 13 | class WebVid(Dataset): 14 | """ 15 | WebVid Dataset. 16 | Assumes webvid data is structured as follows. 17 | Webvid/ 18 | videos/ 19 | 000001_000050/ ($page_dir) 20 | 1.mp4 (videoid.mp4) 21 | ... 22 | 5000.mp4 23 | ... 
24 | """ 25 | def __init__(self, 26 | meta_path, 27 | data_dir, 28 | subsample=None, 29 | video_length=16, 30 | resolution=[256, 512], 31 | frame_stride=1, 32 | frame_stride_min=1, 33 | spatial_transform=None, 34 | crop_resolution=None, 35 | fps_max=None, 36 | load_raw_resolution=False, 37 | fixed_fps=None, 38 | random_fs=False, 39 | ): 40 | self.meta_path = meta_path 41 | self.data_dir = data_dir 42 | self.subsample = subsample 43 | self.video_length = video_length 44 | self.resolution = [resolution, resolution] if isinstance(resolution, int) else resolution 45 | self.fps_max = fps_max 46 | self.frame_stride = frame_stride 47 | self.frame_stride_min = frame_stride_min 48 | self.fixed_fps = fixed_fps 49 | self.load_raw_resolution = load_raw_resolution 50 | self.random_fs = random_fs 51 | self._load_metadata() 52 | if spatial_transform is not None: 53 | if spatial_transform == "random_crop": 54 | self.spatial_transform = transforms.RandomCrop(crop_resolution) 55 | elif spatial_transform == "center_crop": 56 | self.spatial_transform = transforms.Compose([ 57 | transforms.CenterCrop(resolution), 58 | ]) 59 | elif spatial_transform == "resize_center_crop": 60 | # assert(self.resolution[0] == self.resolution[1]) 61 | self.spatial_transform = transforms.Compose([ 62 | transforms.Resize(min(self.resolution)), 63 | transforms.CenterCrop(self.resolution), 64 | ]) 65 | elif spatial_transform == "resize": 66 | self.spatial_transform = transforms.Resize(self.resolution) 67 | else: 68 | raise NotImplementedError 69 | else: 70 | self.spatial_transform = None 71 | 72 | def _load_metadata(self): 73 | metadata = pd.read_csv(self.meta_path) 74 | print(f'>>> {len(metadata)} data samples loaded.') 75 | if self.subsample is not None: 76 | metadata = metadata.sample(self.subsample, random_state=0) 77 | 78 | metadata['caption'] = metadata['name'] 79 | del metadata['name'] 80 | self.metadata = metadata 81 | self.metadata.dropna(inplace=True) 82 | 83 | def _get_video_path(self, sample): 84 | rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4') 85 | full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp) 86 | return full_video_fp 87 | 88 | def __getitem__(self, index): 89 | if self.random_fs: 90 | frame_stride = random.randint(self.frame_stride_min, self.frame_stride) 91 | else: 92 | frame_stride = self.frame_stride 93 | 94 | ## get frames until success 95 | while True: 96 | index = index % len(self.metadata) 97 | sample = self.metadata.iloc[index] 98 | video_path = self._get_video_path(sample) 99 | ## video_path should be in the format of "....../WebVid/videos/$page_dir/$videoid.mp4" 100 | caption = sample['caption'] 101 | 102 | try: 103 | if self.load_raw_resolution: 104 | video_reader = VideoReader(video_path, ctx=cpu(0)) 105 | else: 106 | video_reader = VideoReader(video_path, ctx=cpu(0), width=530, height=300) 107 | if len(video_reader) < self.video_length: 108 | print(f"video length ({len(video_reader)}) is smaller than target length({self.video_length})") 109 | index += 1 110 | continue 111 | else: 112 | pass 113 | except: 114 | index += 1 115 | print(f"Load video failed! 
path = {video_path}") 116 | continue 117 | 118 | fps_ori = video_reader.get_avg_fps() 119 | if self.fixed_fps is not None: 120 | frame_stride = int(frame_stride * (1.0 * fps_ori / self.fixed_fps)) 121 | 122 | ## to avoid extreme cases when fixed_fps is used 123 | frame_stride = max(frame_stride, 1) 124 | 125 | ## get valid range (adapting case by case) 126 | required_frame_num = frame_stride * (self.video_length-1) + 1 127 | frame_num = len(video_reader) 128 | if frame_num < required_frame_num: 129 | ## drop extra samples if fixed fps is required 130 | if self.fixed_fps is not None and frame_num < required_frame_num * 0.5: 131 | index += 1 132 | continue 133 | else: 134 | frame_stride = frame_num // self.video_length 135 | required_frame_num = frame_stride * (self.video_length-1) + 1 136 | 137 | ## select a random clip 138 | random_range = frame_num - required_frame_num 139 | start_idx = random.randint(0, random_range) if random_range > 0 else 0 140 | 141 | ## calculate frame indices 142 | frame_indices = [start_idx + frame_stride*i for i in range(self.video_length)] 143 | try: 144 | frames = video_reader.get_batch(frame_indices) 145 | break 146 | except: 147 | print(f"Get frames failed! path = {video_path}; [max_ind vs frame_total:{max(frame_indices)} / {frame_num}]") 148 | index += 1 149 | continue 150 | 151 | ## process data 152 | assert(frames.shape[0] == self.video_length),f'{len(frames)}, self.video_length={self.video_length}' 153 | frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float() # [t,h,w,c] -> [c,t,h,w] 154 | 155 | if self.spatial_transform is not None: 156 | frames = self.spatial_transform(frames) 157 | 158 | if self.resolution is not None: 159 | assert (frames.shape[2], frames.shape[3]) == (self.resolution[0], self.resolution[1]), f'frames={frames.shape}, self.resolution={self.resolution}' 160 | 161 | ## turn frames tensors to [-1,1] 162 | frames = (frames / 255 - 0.5) * 2 163 | fps_clip = fps_ori // frame_stride 164 | if self.fps_max is not None and fps_clip > self.fps_max: 165 | fps_clip = self.fps_max 166 | 167 | data = {'video': frames, 'caption': caption, 'path': video_path, 'fps': fps_clip, 'frame_stride': frame_stride} 168 | return data 169 | 170 | def __len__(self): 171 | return len(self.metadata) 172 | 173 | 174 | if __name__== "__main__": 175 | meta_path = "" ## path to the meta file 176 | data_dir = "" ## path to the data directory 177 | save_dir = "" ## path to the save directory 178 | dataset = WebVid(meta_path, 179 | data_dir, 180 | subsample=None, 181 | video_length=16, 182 | resolution=[256,448], 183 | frame_stride=4, 184 | spatial_transform="resize_center_crop", 185 | crop_resolution=None, 186 | fps_max=None, 187 | load_raw_resolution=True 188 | ) 189 | dataloader = DataLoader(dataset, 190 | batch_size=1, 191 | num_workers=0, 192 | shuffle=False) 193 | 194 | 195 | import sys 196 | sys.path.insert(1, os.path.join(sys.path[0], '..', '..')) 197 | from utils.save_video import tensor_to_mp4 198 | for i, batch in tqdm(enumerate(dataloader), desc="Data Batch"): 199 | video = batch['video'] 200 | name = batch['path'][0].split('videos/')[-1].replace('/','_') 201 | tensor_to_mp4(video, save_dir+'/'+name, fps=8) 202 | 203 | -------------------------------------------------------------------------------- /lvdm/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def 
mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self, noise=None): 36 | if noise is None: 37 | noise = torch.randn(self.mean.shape) 38 | 39 | x = self.mean + self.std * noise.to(device=self.parameters.device) 40 | return x 41 | 42 | def kl(self, other=None): 43 | if self.deterministic: 44 | return torch.Tensor([0.]) 45 | else: 46 | if other is None: 47 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 48 | + self.var - 1.0 - self.logvar, 49 | dim=[1, 2, 3]) 50 | else: 51 | return 0.5 * torch.sum( 52 | torch.pow(self.mean - other.mean, 2) / other.var 53 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 54 | dim=[1, 2, 3]) 55 | 56 | def nll(self, sample, dims=[1,2,3]): 57 | if self.deterministic: 58 | return torch.Tensor([0.]) 59 | logtwopi = np.log(2.0 * np.pi) 60 | return 0.5 * torch.sum( 61 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 62 | dim=dims) 63 | 64 | def mode(self): 65 | return self.mean 66 | 67 | 68 | def normal_kl(mean1, logvar1, mean2, logvar2): 69 | """ 70 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 71 | Compute the KL divergence between two gaussians. 72 | Shapes are automatically broadcasted, so batches can be compared to 73 | scalars, among other use cases. 74 | """ 75 | tensor = None 76 | for obj in (mean1, logvar1, mean2, logvar2): 77 | if isinstance(obj, torch.Tensor): 78 | tensor = obj 79 | break 80 | assert tensor is not None, "at least one argument must be a Tensor" 81 | 82 | # Force variances to be Tensors. Broadcasting helps convert scalars to 83 | # Tensors, but it does not work for torch.exp(). 
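# --------------------------------------------------------------------------
# Illustrative sketch (editor's example, not part of the original module):
# typical use of DiagonalGaussianDistribution defined above. An encoder is
# expected to emit 2*C channels (mean and log-variance concatenated along
# dim=1); the tensor shapes below are made up for the demo.
# --------------------------------------------------------------------------
def _demo_diagonal_gaussian():
    import torch
    from lvdm.distributions import DiagonalGaussianDistribution

    mean = torch.zeros(2, 4, 8, 8)
    logvar = torch.zeros(2, 4, 8, 8)
    posterior = DiagonalGaussianDistribution(torch.cat([mean, logvar], dim=1))
    z = posterior.sample()        # mean + std * eps, shape [2, 4, 8, 8]
    kl = posterior.kl()           # per-sample KL to N(0, I), shape [2]
    return z, kl, posterior.mode()   # mode() simply returns the mean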
84 | logvar1, logvar2 = [ 85 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 86 | for x in (logvar1, logvar2) 87 | ] 88 | 89 | return 0.5 * ( 90 | -1.0 91 | + logvar2 92 | - logvar1 93 | + torch.exp(logvar1 - logvar2) 94 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 95 | ) -------------------------------------------------------------------------------- /lvdm/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
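# --------------------------------------------------------------------------
# Illustrative sketch (editor's example, not part of the original module): the
# usual store / copy_to / restore sequence around evaluation with LitEma. The
# nn.Linear stand-in is an assumption; any module with trainable parameters
# works the same way.
# --------------------------------------------------------------------------
def _demo_ema_swap():
    import torch.nn as nn
    from lvdm.ema import LitEma

    model = nn.Linear(4, 4)
    ema = LitEma(model, decay=0.999)
    ema(model)                        # after every optimizer step: update shadow weights
    ema.store(model.parameters())     # stash the current "live" weights
    ema.copy_to(model)                # evaluate / save checkpoints with EMA weights
    # ... run validation here ...
    ema.restore(model.parameters())   # put the live weights back for training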
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) -------------------------------------------------------------------------------- /lvdm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import contextmanager 3 | import torch 4 | import numpy as np 5 | from einops import rearrange 6 | import torch.nn.functional as F 7 | import pytorch_lightning as pl 8 | from lvdm.modules.networks.ae_modules import Encoder, Decoder 9 | from lvdm.distributions import DiagonalGaussianDistribution 10 | from utils.utils import instantiate_from_config 11 | 12 | TIMESTEPS=16 13 | class AutoencoderKL(pl.LightningModule): 14 | def __init__(self, 15 | ddconfig, 16 | lossconfig, 17 | embed_dim, 18 | ckpt_path=None, 19 | ignore_keys=[], 20 | image_key="image", 21 | colorize_nlabels=None, 22 | monitor=None, 23 | test=False, 24 | logdir=None, 25 | input_dim=4, 26 | test_args=None, 27 | additional_decode_keys=None, 28 | use_checkpoint=False, 29 | diff_boost_factor=3.0, 30 | ): 31 | super().__init__() 32 | self.image_key = image_key 33 | self.encoder = Encoder(**ddconfig) 34 | self.decoder = Decoder(**ddconfig) 35 | self.loss = instantiate_from_config(lossconfig) 36 | assert ddconfig["double_z"] 37 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 38 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 39 | self.embed_dim = embed_dim 40 | self.input_dim = input_dim 41 | self.test = test 42 | self.test_args = test_args 43 | self.logdir = logdir 44 | if colorize_nlabels is not None: 45 | assert type(colorize_nlabels)==int 46 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 47 | if monitor is not None: 48 | self.monitor = monitor 49 | if ckpt_path is not None: 50 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 51 | if self.test: 52 | self.init_test() 53 | 54 | def init_test(self,): 55 | self.test = True 56 | save_dir = os.path.join(self.logdir, "test") 57 | if 'ckpt' in self.test_args: 58 | ckpt_name = os.path.basename(self.test_args.ckpt).split('.ckpt')[0] + f'_epoch{self._cur_epoch}' 59 | self.root = os.path.join(save_dir, ckpt_name) 60 | else: 61 | self.root = save_dir 62 | if 'test_subdir' in self.test_args: 63 | self.root = os.path.join(save_dir, self.test_args.test_subdir) 64 | 65 | self.root_zs = os.path.join(self.root, "zs") 66 | self.root_dec = os.path.join(self.root, "reconstructions") 67 | self.root_inputs = os.path.join(self.root, "inputs") 68 | os.makedirs(self.root, exist_ok=True) 69 | 70 | if self.test_args.save_z: 71 | os.makedirs(self.root_zs, exist_ok=True) 72 | if self.test_args.save_reconstruction: 73 | os.makedirs(self.root_dec, exist_ok=True) 74 | if self.test_args.save_input: 75 | os.makedirs(self.root_inputs, exist_ok=True) 76 | assert(self.test_args is not None) 77 | self.test_maximum = getattr(self.test_args, 'test_maximum', None) 78 | self.count = 0 79 | self.eval_metrics = {} 80 | self.decodes = [] 81 | self.save_decode_samples = 2048 82 | 83 | def init_from_ckpt(self, path, ignore_keys=list()): 84 | sd = torch.load(path, map_location="cpu") 85 | try: 86 | self._cur_epoch = sd['epoch'] 87 | sd = sd["state_dict"] 88 | except: 89 | self._cur_epoch = 'null' 90 | keys = list(sd.keys()) 91 | for k in keys: 92 | for ik in ignore_keys: 93 | if k.startswith(ik): 94 | print("Deleting key {} from state_dict.".format(k)) 95 | del sd[k] 96 | self.load_state_dict(sd, 
strict=False) 97 | # self.load_state_dict(sd, strict=True) 98 | print(f"Restored from {path}") 99 | 100 | def encode(self, x, return_hidden_states=False, **kwargs): 101 | if return_hidden_states: 102 | h, hidden = self.encoder(x, return_hidden_states) 103 | moments = self.quant_conv(h) 104 | posterior = DiagonalGaussianDistribution(moments) 105 | return posterior, hidden 106 | else: 107 | h = self.encoder(x) 108 | moments = self.quant_conv(h) 109 | posterior = DiagonalGaussianDistribution(moments) 110 | return posterior 111 | 112 | def decode(self, z, **kwargs): 113 | if len(kwargs) == 0: ## use the original decoder in AutoencoderKL 114 | z = self.post_quant_conv(z) 115 | dec = self.decoder(z, **kwargs) ##change for SVD decoder by adding **kwargs 116 | return dec 117 | 118 | def forward(self, input, sample_posterior=True, **additional_decode_kwargs): 119 | input_tuple = (input, ) 120 | forward_temp = partial(self._forward, sample_posterior=sample_posterior, **additional_decode_kwargs) 121 | return checkpoint(forward_temp, input_tuple, self.parameters(), self.use_checkpoint) 122 | 123 | 124 | def _forward(self, input, sample_posterior=True, **additional_decode_kwargs): 125 | posterior = self.encode(input) 126 | if sample_posterior: 127 | z = posterior.sample() 128 | else: 129 | z = posterior.mode() 130 | dec = self.decode(z, **additional_decode_kwargs) 131 | ## print(input.shape, dec.shape) torch.Size([16, 3, 256, 256]) torch.Size([16, 3, 256, 256]) 132 | return dec, posterior 133 | 134 | def get_input(self, batch, k): 135 | x = batch[k] 136 | if x.dim() == 5 and self.input_dim == 4: 137 | b,c,t,h,w = x.shape 138 | self.b = b 139 | self.t = t 140 | x = rearrange(x, 'b c t h w -> (b t) c h w') 141 | 142 | return x 143 | 144 | def training_step(self, batch, batch_idx, optimizer_idx): 145 | inputs = self.get_input(batch, self.image_key) 146 | reconstructions, posterior = self(inputs) 147 | 148 | if optimizer_idx == 0: 149 | # train encoder+decoder+logvar 150 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 151 | last_layer=self.get_last_layer(), split="train") 152 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 153 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 154 | return aeloss 155 | 156 | if optimizer_idx == 1: 157 | # train the discriminator 158 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 159 | last_layer=self.get_last_layer(), split="train") 160 | 161 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 162 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 163 | return discloss 164 | 165 | def validation_step(self, batch, batch_idx): 166 | inputs = self.get_input(batch, self.image_key) 167 | reconstructions, posterior = self(inputs) 168 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 169 | last_layer=self.get_last_layer(), split="val") 170 | 171 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 172 | last_layer=self.get_last_layer(), split="val") 173 | 174 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) 175 | self.log_dict(log_dict_ae) 176 | self.log_dict(log_dict_disc) 177 | return self.log_dict 178 | 179 | def configure_optimizers(self): 180 | lr = self.learning_rate 181 | opt_ae = 
torch.optim.Adam(list(self.encoder.parameters())+ 182 | list(self.decoder.parameters())+ 183 | list(self.quant_conv.parameters())+ 184 | list(self.post_quant_conv.parameters()), 185 | lr=lr, betas=(0.5, 0.9)) 186 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 187 | lr=lr, betas=(0.5, 0.9)) 188 | return [opt_ae, opt_disc], [] 189 | 190 | def get_last_layer(self): 191 | return self.decoder.conv_out.weight 192 | 193 | @torch.no_grad() 194 | def log_images(self, batch, only_inputs=False, **kwargs): 195 | log = dict() 196 | x = self.get_input(batch, self.image_key) 197 | x = x.to(self.device) 198 | if not only_inputs: 199 | xrec, posterior = self(x) 200 | if x.shape[1] > 3: 201 | # colorize with random projection 202 | assert xrec.shape[1] > 3 203 | x = self.to_rgb(x) 204 | xrec = self.to_rgb(xrec) 205 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 206 | log["reconstructions"] = xrec 207 | log["inputs"] = x 208 | return log 209 | 210 | def to_rgb(self, x): 211 | assert self.image_key == "segmentation" 212 | if not hasattr(self, "colorize"): 213 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 214 | x = F.conv2d(x, weight=self.colorize) 215 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 216 | return x 217 | 218 | class IdentityFirstStage(torch.nn.Module): 219 | def __init__(self, *args, vq_interface=False, **kwargs): 220 | self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff 221 | super().__init__() 222 | 223 | def encode(self, x, *args, **kwargs): 224 | return x 225 | 226 | def decode(self, x, *args, **kwargs): 227 | return x 228 | 229 | def quantize(self, x, *args, **kwargs): 230 | if self.vq_interface: 231 | return x, None, [None, None, None] 232 | return x 233 | 234 | def forward(self, x, *args, **kwargs): 235 | return x 236 | 237 | from lvdm.models.autoencoder_dualref import VideoDecoder 238 | class AutoencoderKL_Dualref(AutoencoderKL): 239 | def __init__(self, 240 | ddconfig, 241 | lossconfig, 242 | embed_dim, 243 | ckpt_path=None, 244 | ignore_keys=[], 245 | image_key="image", 246 | colorize_nlabels=None, 247 | monitor=None, 248 | test=False, 249 | logdir=None, 250 | input_dim=4, 251 | test_args=None, 252 | additional_decode_keys=None, 253 | use_checkpoint=False, 254 | diff_boost_factor=3.0, 255 | ): 256 | super().__init__(ddconfig, lossconfig, embed_dim, ckpt_path, ignore_keys, image_key, colorize_nlabels, monitor, test, logdir, input_dim, test_args, additional_decode_keys, use_checkpoint, diff_boost_factor) 257 | self.decoder = VideoDecoder(**ddconfig) 258 | 259 | def _forward(self, input, sample_posterior=True, **additional_decode_kwargs): 260 | posterior, hidden_states = self.encode(input, return_hidden_states=True) 261 | 262 | hidden_states_first_last = [] 263 | ### use only the first and last hidden states 264 | for hid in hidden_states: 265 | hid = rearrange(hid, '(b t) c h w -> b c t h w', t=TIMESTEPS) 266 | hid_new = torch.cat([hid[:, :, 0:1], hid[:, :, -1:]], dim=2) 267 | hidden_states_first_last.append(hid_new) 268 | 269 | if sample_posterior: 270 | z = posterior.sample() 271 | else: 272 | z = posterior.mode() 273 | dec = self.decode(z, ref_context=hidden_states_first_last, **additional_decode_kwargs) 274 | ## print(input.shape, dec.shape) torch.Size([16, 3, 256, 256]) torch.Size([16, 3, 256, 256]) 275 | return dec, posterior -------------------------------------------------------------------------------- /lvdm/models/samplers/ddim.py: 
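# --------------------------------------------------------------------------
# Illustrative sketch (editor's example, not from the repository): the
# first/last-frame selection performed in AutoencoderKL_Dualref._forward
# above. Encoder hidden states arrive flattened over time as [(b*t), c, h, w];
# they are unflattened to [b, c, t, h, w] and only the first and last frames
# are kept as reference context for the video decoder. Shapes are made up;
# t matches TIMESTEPS = 16 above.
# --------------------------------------------------------------------------
import torch
from einops import rearrange

b, t, c, h, w = 2, 16, 8, 32, 32
hid = torch.randn(b * t, c, h, w)                      # one hidden state from the encoder
hid = rearrange(hid, '(b t) c h w -> b c t h w', t=t)
hid_first_last = torch.cat([hid[:, :, 0:1], hid[:, :, -1:]], dim=2)
assert hid_first_last.shape == (b, c, 2, h, w)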
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import torch 4 | from lvdm.models.utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg 5 | from lvdm.common import noise_like 6 | from lvdm.common import extract_into_tensor 7 | import copy 8 | 9 | 10 | class DDIMSampler(object): 11 | def __init__(self, model, schedule="linear", **kwargs): 12 | super().__init__() 13 | self.model = model 14 | self.ddpm_num_timesteps = model.num_timesteps 15 | self.schedule = schedule 16 | self.counter = 0 17 | 18 | def register_buffer(self, name, attr): 19 | if type(attr) == torch.Tensor: 20 | if attr.device != torch.device("cuda"): 21 | attr = attr.to(torch.device("cuda")) 22 | setattr(self, name, attr) 23 | 24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 27 | alphas_cumprod = self.model.alphas_cumprod 28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 29 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 30 | 31 | if self.model.use_dynamic_rescale: 32 | self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps] 33 | self.ddim_scale_arr_prev = torch.cat([self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]]) 34 | 35 | self.register_buffer('betas', to_torch(self.model.betas)) 36 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 37 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 38 | 39 | # calculations for diffusion q(x_t | x_{t-1}) and others 40 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 41 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 42 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 44 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 45 | 46 | # ddim sampling parameters 47 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 48 | ddim_timesteps=self.ddim_timesteps, 49 | eta=ddim_eta,verbose=verbose) 50 | self.register_buffer('ddim_sigmas', ddim_sigmas) 51 | self.register_buffer('ddim_alphas', ddim_alphas) 52 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 53 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. 
- ddim_alphas)) 54 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 55 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 56 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 57 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 58 | 59 | @torch.no_grad() 60 | def sample(self, 61 | S, 62 | batch_size, 63 | shape, 64 | conditioning=None, 65 | callback=None, 66 | normals_sequence=None, 67 | img_callback=None, 68 | quantize_x0=False, 69 | eta=0., 70 | mask=None, 71 | x0=None, 72 | temperature=1., 73 | noise_dropout=0., 74 | score_corrector=None, 75 | corrector_kwargs=None, 76 | verbose=True, 77 | schedule_verbose=False, 78 | x_T=None, 79 | log_every_t=100, 80 | unconditional_guidance_scale=1., 81 | unconditional_conditioning=None, 82 | precision=None, 83 | fs=None, 84 | timestep_spacing='uniform', #uniform_trailing for starting from last timestep 85 | guidance_rescale=0.0, 86 | **kwargs 87 | ): 88 | 89 | # check condition bs 90 | if conditioning is not None: 91 | if isinstance(conditioning, dict): 92 | try: 93 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 94 | except: 95 | cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] 96 | 97 | if cbs != batch_size: 98 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 99 | else: 100 | if conditioning.shape[0] != batch_size: 101 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 102 | 103 | self.make_schedule(ddim_num_steps=S, ddim_discretize=timestep_spacing, ddim_eta=eta, verbose=schedule_verbose) 104 | 105 | # make shape 106 | if len(shape) == 3: 107 | C, H, W = shape 108 | size = (batch_size, C, H, W) 109 | elif len(shape) == 4: 110 | C, T, H, W = shape 111 | size = (batch_size, C, T, H, W) 112 | 113 | samples, intermediates = self.ddim_sampling(conditioning, size, 114 | callback=callback, 115 | img_callback=img_callback, 116 | quantize_denoised=quantize_x0, 117 | mask=mask, x0=x0, 118 | ddim_use_original_steps=False, 119 | noise_dropout=noise_dropout, 120 | temperature=temperature, 121 | score_corrector=score_corrector, 122 | corrector_kwargs=corrector_kwargs, 123 | x_T=x_T, 124 | log_every_t=log_every_t, 125 | unconditional_guidance_scale=unconditional_guidance_scale, 126 | unconditional_conditioning=unconditional_conditioning, 127 | verbose=verbose, 128 | precision=precision, 129 | fs=fs, 130 | guidance_rescale=guidance_rescale, 131 | **kwargs) 132 | return samples, intermediates 133 | 134 | @torch.no_grad() 135 | def ddim_sampling(self, cond, shape, 136 | x_T=None, ddim_use_original_steps=False, 137 | callback=None, timesteps=None, quantize_denoised=False, 138 | mask=None, x0=None, img_callback=None, log_every_t=100, 139 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 140 | unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True,precision=None,fs=None,guidance_rescale=0.0, 141 | **kwargs): 142 | device = self.model.betas.device 143 | b = shape[0] 144 | if x_T is None: 145 | img = torch.randn(shape, device=device) 146 | else: 147 | img = x_T 148 | if precision is not None: 149 | if precision == 16: 150 | img = img.to(dtype=torch.float16) 151 | 152 | if timesteps is None: 153 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 154 | elif timesteps is not None and not ddim_use_original_steps: 155 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * 
self.ddim_timesteps.shape[0]) - 1 156 | timesteps = self.ddim_timesteps[:subset_end] 157 | 158 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 159 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 160 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 161 | if verbose: 162 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 163 | else: 164 | iterator = time_range 165 | 166 | clean_cond = kwargs.pop("clean_cond", False) 167 | 168 | # cond_copy, unconditional_conditioning_copy = copy.deepcopy(cond), copy.deepcopy(unconditional_conditioning) 169 | for i, step in enumerate(iterator): 170 | index = total_steps - i - 1 171 | ts = torch.full((b,), step, device=device, dtype=torch.long) 172 | 173 | ## use mask to blend noised original latent (img_orig) & new sampled latent (img) 174 | if mask is not None: 175 | assert x0 is not None 176 | if clean_cond: 177 | img_orig = x0 178 | else: 179 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 180 | img = img_orig * mask + (1. - mask) * img # keep original & modify use img 181 | 182 | 183 | 184 | 185 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 186 | quantize_denoised=quantize_denoised, temperature=temperature, 187 | noise_dropout=noise_dropout, score_corrector=score_corrector, 188 | corrector_kwargs=corrector_kwargs, 189 | unconditional_guidance_scale=unconditional_guidance_scale, 190 | unconditional_conditioning=unconditional_conditioning, 191 | mask=mask,x0=x0,fs=fs,guidance_rescale=guidance_rescale, 192 | **kwargs) 193 | 194 | 195 | img, pred_x0 = outs 196 | if callback: callback(i) 197 | if img_callback: img_callback(pred_x0, i) 198 | 199 | if index % log_every_t == 0 or index == total_steps - 1: 200 | intermediates['x_inter'].append(img) 201 | intermediates['pred_x0'].append(pred_x0) 202 | 203 | return img, intermediates 204 | 205 | @torch.no_grad() 206 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 207 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 208 | unconditional_guidance_scale=1., unconditional_conditioning=None, 209 | uc_type=None, conditional_guidance_scale_temporal=None,mask=None,x0=None,guidance_rescale=0.0,**kwargs): 210 | b, *_, device = *x.shape, x.device 211 | if x.dim() == 5: 212 | is_video = True 213 | else: 214 | is_video = False 215 | 216 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 217 | model_output = self.model.apply_model(x, t, c, **kwargs) # unet denoiser 218 | else: 219 | ### do_classifier_free_guidance 220 | if isinstance(c, torch.Tensor) or isinstance(c, dict): 221 | e_t_cond = self.model.apply_model(x, t, c, **kwargs) 222 | e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs) 223 | else: 224 | raise NotImplementedError 225 | 226 | model_output = e_t_uncond + unconditional_guidance_scale * (e_t_cond - e_t_uncond) 227 | 228 | if guidance_rescale > 0.0: 229 | model_output = rescale_noise_cfg(model_output, e_t_cond, guidance_rescale=guidance_rescale) 230 | 231 | if self.model.parameterization == "v": 232 | e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) 233 | else: 234 | e_t = model_output 235 | 236 | if score_corrector is not None: 237 | assert self.model.parameterization == "eps", 'not implemented' 238 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, 
**corrector_kwargs) 239 | 240 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 241 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 242 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 243 | # sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 244 | sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 245 | # select parameters corresponding to the currently considered timestep 246 | 247 | if is_video: 248 | size = (b, 1, 1, 1, 1) 249 | else: 250 | size = (b, 1, 1, 1) 251 | a_t = torch.full(size, alphas[index], device=device) 252 | a_prev = torch.full(size, alphas_prev[index], device=device) 253 | sigma_t = torch.full(size, sigmas[index], device=device) 254 | sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device) 255 | 256 | # current prediction for x_0 257 | if self.model.parameterization != "v": 258 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 259 | else: 260 | pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) 261 | 262 | if self.model.use_dynamic_rescale: 263 | scale_t = torch.full(size, self.ddim_scale_arr[index], device=device) 264 | prev_scale_t = torch.full(size, self.ddim_scale_arr_prev[index], device=device) 265 | rescale = (prev_scale_t / scale_t) 266 | pred_x0 *= rescale 267 | 268 | if quantize_denoised: 269 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 270 | # direction pointing to x_t 271 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 272 | 273 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 274 | if noise_dropout > 0.: 275 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 276 | 277 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 278 | 279 | return x_prev, pred_x0 280 | 281 | @torch.no_grad() 282 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, 283 | use_original_steps=False, callback=None): 284 | 285 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps 286 | timesteps = timesteps[:t_start] 287 | 288 | time_range = np.flip(timesteps) 289 | total_steps = timesteps.shape[0] 290 | print(f"Running DDIM Sampling with {total_steps} timesteps") 291 | 292 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps) 293 | x_dec = x_latent 294 | for i, step in enumerate(iterator): 295 | index = total_steps - i - 1 296 | ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) 297 | x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, 298 | unconditional_guidance_scale=unconditional_guidance_scale, 299 | unconditional_conditioning=unconditional_conditioning) 300 | if callback: callback(i) 301 | return x_dec 302 | 303 | @torch.no_grad() 304 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): 305 | # fast, but does not allow for exact reconstruction 306 | # t serves as an index to gather the correct alphas 307 | if use_original_steps: 308 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod 309 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod 310 | else: 311 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) 312 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas 313 | 314 | if noise is 
None: 315 | noise = torch.randn_like(x0) 316 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + 317 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) 318 | -------------------------------------------------------------------------------- /lvdm/models/samplers/ddim_multiplecond.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import torch 4 | from lvdm.models.utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg 5 | from lvdm.common import noise_like 6 | from lvdm.common import extract_into_tensor 7 | import copy 8 | 9 | 10 | class DDIMSampler(object): 11 | def __init__(self, model, schedule="linear", **kwargs): 12 | super().__init__() 13 | self.model = model 14 | self.ddpm_num_timesteps = model.num_timesteps 15 | self.schedule = schedule 16 | self.counter = 0 17 | 18 | def register_buffer(self, name, attr): 19 | if type(attr) == torch.Tensor: 20 | if attr.device != torch.device("cuda"): 21 | attr = attr.to(torch.device("cuda")) 22 | setattr(self, name, attr) 23 | 24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 27 | alphas_cumprod = self.model.alphas_cumprod 28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 29 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 30 | 31 | if self.model.use_dynamic_rescale: 32 | self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps] 33 | self.ddim_scale_arr_prev = torch.cat([self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]]) 34 | 35 | self.register_buffer('betas', to_torch(self.model.betas)) 36 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 37 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 38 | 39 | # calculations for diffusion q(x_t | x_{t-1}) and others 40 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 41 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 42 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 44 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 45 | 46 | # ddim sampling parameters 47 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 48 | ddim_timesteps=self.ddim_timesteps, 49 | eta=ddim_eta,verbose=verbose) 50 | self.register_buffer('ddim_sigmas', ddim_sigmas) 51 | self.register_buffer('ddim_alphas', ddim_alphas) 52 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 53 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. 
- ddim_alphas)) 54 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 55 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 56 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 57 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 58 | 59 | @torch.no_grad() 60 | def sample(self, 61 | S, 62 | batch_size, 63 | shape, 64 | conditioning=None, 65 | callback=None, 66 | normals_sequence=None, 67 | img_callback=None, 68 | quantize_x0=False, 69 | eta=0., 70 | mask=None, 71 | x0=None, 72 | temperature=1., 73 | noise_dropout=0., 74 | score_corrector=None, 75 | corrector_kwargs=None, 76 | verbose=True, 77 | schedule_verbose=False, 78 | x_T=None, 79 | log_every_t=100, 80 | unconditional_guidance_scale=1., 81 | unconditional_conditioning=None, 82 | precision=None, 83 | fs=None, 84 | timestep_spacing='uniform', #uniform_trailing for starting from last timestep 85 | guidance_rescale=0.0, 86 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 87 | **kwargs 88 | ): 89 | 90 | # check condition bs 91 | if conditioning is not None: 92 | if isinstance(conditioning, dict): 93 | try: 94 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 95 | except: 96 | cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] 97 | 98 | if cbs != batch_size: 99 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 100 | else: 101 | if conditioning.shape[0] != batch_size: 102 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 103 | 104 | # print('==> timestep_spacing: ', timestep_spacing, guidance_rescale) 105 | self.make_schedule(ddim_num_steps=S, ddim_discretize=timestep_spacing, ddim_eta=eta, verbose=schedule_verbose) 106 | 107 | # make shape 108 | if len(shape) == 3: 109 | C, H, W = shape 110 | size = (batch_size, C, H, W) 111 | elif len(shape) == 4: 112 | C, T, H, W = shape 113 | size = (batch_size, C, T, H, W) 114 | # print(f'Data shape for DDIM sampling is {size}, eta {eta}') 115 | 116 | samples, intermediates = self.ddim_sampling(conditioning, size, 117 | callback=callback, 118 | img_callback=img_callback, 119 | quantize_denoised=quantize_x0, 120 | mask=mask, x0=x0, 121 | ddim_use_original_steps=False, 122 | noise_dropout=noise_dropout, 123 | temperature=temperature, 124 | score_corrector=score_corrector, 125 | corrector_kwargs=corrector_kwargs, 126 | x_T=x_T, 127 | log_every_t=log_every_t, 128 | unconditional_guidance_scale=unconditional_guidance_scale, 129 | unconditional_conditioning=unconditional_conditioning, 130 | verbose=verbose, 131 | precision=precision, 132 | fs=fs, 133 | guidance_rescale=guidance_rescale, 134 | **kwargs) 135 | return samples, intermediates 136 | 137 | @torch.no_grad() 138 | def ddim_sampling(self, cond, shape, 139 | x_T=None, ddim_use_original_steps=False, 140 | callback=None, timesteps=None, quantize_denoised=False, 141 | mask=None, x0=None, img_callback=None, log_every_t=100, 142 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 143 | unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True,precision=None,fs=None,guidance_rescale=0.0, 144 | **kwargs): 145 | device = self.model.betas.device 146 | b = shape[0] 147 | if x_T is None: 148 | img = torch.randn(shape, device=device) 149 | else: 150 | img = x_T 151 | if precision is not None: 152 | if precision == 16: 153 | img = img.to(dtype=torch.float16) 154 | 155 | 156 | if timesteps is None: 157 | 
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 158 | elif timesteps is not None and not ddim_use_original_steps: 159 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 160 | timesteps = self.ddim_timesteps[:subset_end] 161 | 162 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 163 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 164 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 165 | if verbose: 166 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 167 | else: 168 | iterator = time_range 169 | 170 | clean_cond = kwargs.pop("clean_cond", False) 171 | 172 | # cond_copy, unconditional_conditioning_copy = copy.deepcopy(cond), copy.deepcopy(unconditional_conditioning) 173 | for i, step in enumerate(iterator): 174 | index = total_steps - i - 1 175 | ts = torch.full((b,), step, device=device, dtype=torch.long) 176 | 177 | ## use mask to blend noised original latent (img_orig) & new sampled latent (img) 178 | if mask is not None: 179 | assert x0 is not None 180 | if clean_cond: 181 | img_orig = x0 182 | else: 183 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 184 | img = img_orig * mask + (1. - mask) * img # keep original & modify use img 185 | 186 | 187 | 188 | 189 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 190 | quantize_denoised=quantize_denoised, temperature=temperature, 191 | noise_dropout=noise_dropout, score_corrector=score_corrector, 192 | corrector_kwargs=corrector_kwargs, 193 | unconditional_guidance_scale=unconditional_guidance_scale, 194 | unconditional_conditioning=unconditional_conditioning, 195 | mask=mask,x0=x0,fs=fs,guidance_rescale=guidance_rescale, 196 | **kwargs) 197 | 198 | 199 | 200 | img, pred_x0 = outs 201 | if callback: callback(i) 202 | if img_callback: img_callback(pred_x0, i) 203 | 204 | if index % log_every_t == 0 or index == total_steps - 1: 205 | intermediates['x_inter'].append(img) 206 | intermediates['pred_x0'].append(pred_x0) 207 | 208 | return img, intermediates 209 | 210 | @torch.no_grad() 211 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 212 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 213 | unconditional_guidance_scale=1., unconditional_conditioning=None, 214 | uc_type=None, cfg_img=None,mask=None,x0=None,guidance_rescale=0.0, **kwargs): 215 | b, *_, device = *x.shape, x.device 216 | if x.dim() == 5: 217 | is_video = True 218 | else: 219 | is_video = False 220 | if cfg_img is None: 221 | cfg_img = unconditional_guidance_scale 222 | 223 | unconditional_conditioning_img_nonetext = kwargs['unconditional_conditioning_img_nonetext'] 224 | 225 | 226 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 227 | model_output = self.model.apply_model(x, t, c, **kwargs) # unet denoiser 228 | else: 229 | ### with unconditional condition 230 | e_t_cond = self.model.apply_model(x, t, c, **kwargs) 231 | e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs) 232 | e_t_uncond_img = self.model.apply_model(x, t, unconditional_conditioning_img_nonetext, **kwargs) 233 | # text cfg 234 | model_output = e_t_uncond + cfg_img * (e_t_uncond_img - e_t_uncond) + unconditional_guidance_scale * (e_t_cond - e_t_uncond_img) 235 | if guidance_rescale > 0.0: 236 | 
model_output = rescale_noise_cfg(model_output, e_t_cond, guidance_rescale=guidance_rescale) 237 | 238 | if self.model.parameterization == "v": 239 | e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) 240 | else: 241 | e_t = model_output 242 | 243 | if score_corrector is not None: 244 | assert self.model.parameterization == "eps", 'not implemented' 245 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 246 | 247 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 248 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 249 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 250 | sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 251 | # select parameters corresponding to the currently considered timestep 252 | 253 | if is_video: 254 | size = (b, 1, 1, 1, 1) 255 | else: 256 | size = (b, 1, 1, 1) 257 | a_t = torch.full(size, alphas[index], device=device) 258 | a_prev = torch.full(size, alphas_prev[index], device=device) 259 | sigma_t = torch.full(size, sigmas[index], device=device) 260 | sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device) 261 | 262 | # current prediction for x_0 263 | if self.model.parameterization != "v": 264 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 265 | else: 266 | pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) 267 | 268 | if self.model.use_dynamic_rescale: 269 | scale_t = torch.full(size, self.ddim_scale_arr[index], device=device) 270 | prev_scale_t = torch.full(size, self.ddim_scale_arr_prev[index], device=device) 271 | rescale = (prev_scale_t / scale_t) 272 | pred_x0 *= rescale 273 | 274 | if quantize_denoised: 275 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 276 | # direction pointing to x_t 277 | dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t 278 | 279 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 280 | if noise_dropout > 0.: 281 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 282 | 283 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 284 | 285 | return x_prev, pred_x0 286 | 287 | @torch.no_grad() 288 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, 289 | use_original_steps=False, callback=None): 290 | 291 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps 292 | timesteps = timesteps[:t_start] 293 | 294 | time_range = np.flip(timesteps) 295 | total_steps = timesteps.shape[0] 296 | print(f"Running DDIM Sampling with {total_steps} timesteps") 297 | 298 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps) 299 | x_dec = x_latent 300 | for i, step in enumerate(iterator): 301 | index = total_steps - i - 1 302 | ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) 303 | x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, 304 | unconditional_guidance_scale=unconditional_guidance_scale, 305 | unconditional_conditioning=unconditional_conditioning) 306 | if callback: callback(i) 307 | return x_dec 308 | 309 | @torch.no_grad() 310 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): 311 | # fast, but does not allow for exact reconstruction 312 | # t serves as an index to gather the correct alphas 313 | if use_original_steps: 314 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod 315 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod 316 | else: 317 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) 318 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas 319 | 320 | if noise is None: 321 | noise = torch.randn_like(x0) 322 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + 323 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) -------------------------------------------------------------------------------- /lvdm/models/utils_diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import repeat 6 | 7 | 8 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 9 | """ 10 | Create sinusoidal timestep embeddings. 11 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 12 | These may be fractional. 13 | :param dim: the dimension of the output. 14 | :param max_period: controls the minimum frequency of the embeddings. 15 | :return: an [N x dim] Tensor of positional embeddings. 
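# --------------------------------------------------------------------------
# Illustrative note (editor's example, not part of the original docstring):
# the function below fills the first dim/2 channels with cosines and the rest
# with sines, i.e. for k < half = dim // 2 and frequency
# f_k = exp(-ln(max_period) * k / half):
#   emb[n, k]        = cos(t_n * f_k)
#   emb[n, k + half] = sin(t_n * f_k)
# A minimal call:
# --------------------------------------------------------------------------
import torch
from lvdm.models.utils_diffusion import timestep_embedding

t = torch.tensor([0, 10, 999])
emb = timestep_embedding(t, dim=320)
assert emb.shape == (3, 320)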
16 | """ 17 | if not repeat_only: 18 | half = dim // 2 19 | freqs = torch.exp( 20 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 21 | ).to(device=timesteps.device) 22 | args = timesteps[:, None].float() * freqs[None] 23 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 24 | if dim % 2: 25 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 26 | else: 27 | embedding = repeat(timesteps, 'b -> b d', d=dim) 28 | return embedding 29 | 30 | 31 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 32 | if schedule == "linear": 33 | betas = ( 34 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 35 | ) 36 | 37 | elif schedule == "cosine": 38 | timesteps = ( 39 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 40 | ) 41 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 42 | alphas = torch.cos(alphas).pow(2) 43 | alphas = alphas / alphas[0] 44 | betas = 1 - alphas[1:] / alphas[:-1] 45 | betas = np.clip(betas, a_min=0, a_max=0.999) 46 | 47 | elif schedule == "sqrt_linear": 48 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 49 | elif schedule == "sqrt": 50 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 51 | else: 52 | raise ValueError(f"schedule '{schedule}' unknown.") 53 | return betas.numpy() 54 | 55 | 56 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 57 | if ddim_discr_method == 'uniform': 58 | c = num_ddpm_timesteps // num_ddim_timesteps 59 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 60 | steps_out = ddim_timesteps + 1 61 | elif ddim_discr_method == 'uniform_trailing': 62 | c = num_ddpm_timesteps / num_ddim_timesteps 63 | ddim_timesteps = np.flip(np.round(np.arange(num_ddpm_timesteps, 0, -c))).astype(np.int64) 64 | steps_out = ddim_timesteps - 1 65 | elif ddim_discr_method == 'quad': 66 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 67 | steps_out = ddim_timesteps + 1 68 | else: 69 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 70 | 71 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 72 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 73 | # steps_out = ddim_timesteps + 1 74 | if verbose: 75 | print(f'Selected timesteps for ddim sampler: {steps_out}') 76 | return steps_out 77 | 78 | 79 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 80 | # select alphas for computing the variance schedule 81 | # print(f'ddim_timesteps={ddim_timesteps}, len_alphacums={len(alphacums)}') 82 | alphas = alphacums[ddim_timesteps] 83 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 84 | 85 | # according the formula provided in https://arxiv.org/abs/2010.02502 86 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 87 | if verbose: 88 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 89 | print(f'For the chosen value of eta, which is {eta}, ' 90 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 91 | return sigmas, alphas, alphas_prev 92 | 93 | 94 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 95 | """ 
96 | Create a beta schedule that discretizes the given alpha_t_bar function, 97 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 98 | :param num_diffusion_timesteps: the number of betas to produce. 99 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 100 | produces the cumulative product of (1-beta) up to that 101 | part of the diffusion process. 102 | :param max_beta: the maximum beta to use; use values lower than 1 to 103 | prevent singularities. 104 | """ 105 | betas = [] 106 | for i in range(num_diffusion_timesteps): 107 | t1 = i / num_diffusion_timesteps 108 | t2 = (i + 1) / num_diffusion_timesteps 109 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 110 | return np.array(betas) 111 | 112 | def rescale_zero_terminal_snr(betas): 113 | """ 114 | Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) 115 | 116 | Args: 117 | betas (`numpy.ndarray`): 118 | the betas that the scheduler is being initialized with. 119 | 120 | Returns: 121 | `numpy.ndarray`: rescaled betas with zero terminal SNR 122 | """ 123 | # Convert betas to alphas_bar_sqrt 124 | alphas = 1.0 - betas 125 | alphas_cumprod = np.cumprod(alphas, axis=0) 126 | alphas_bar_sqrt = np.sqrt(alphas_cumprod) 127 | 128 | # Store old values. 129 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy() 130 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy() 131 | 132 | # Shift so the last timestep is zero. 133 | alphas_bar_sqrt -= alphas_bar_sqrt_T 134 | 135 | # Scale so the first timestep is back to the old value. 136 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) 137 | 138 | # Convert alphas_bar_sqrt to betas 139 | alphas_bar = alphas_bar_sqrt**2 # Revert sqrt 140 | alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod 141 | alphas = np.concatenate([alphas_bar[0:1], alphas]) 142 | betas = 1 - alphas 143 | 144 | return betas 145 | 146 | 147 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): 148 | """ 149 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and 150 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 151 | """ 152 | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) 153 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) 154 | # rescale the results from guidance (fixes overexposure) 155 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) 156 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images 157 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg 158 | return noise_cfg -------------------------------------------------------------------------------- /lvdm/modules/encoders/condition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import kornia 4 | import open_clip 5 | from torch.utils.checkpoint import checkpoint 6 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel 7 | from lvdm.common import autocast 8 | from utils.utils import count_params 9 | 10 | 11 | class AbstractEncoder(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def encode(self, *args, **kwargs): 16 | raise NotImplementedError 17 | 18 | 19 | class IdentityEncoder(AbstractEncoder): 20 | def encode(self, x): 21 | return x 22 | 23 | 24 | class ClassEmbedder(nn.Module): 25 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): 26 | super().__init__() 27 | self.key = key 28 | self.embedding = nn.Embedding(n_classes, embed_dim) 29 | self.n_classes = n_classes 30 | self.ucg_rate = ucg_rate 31 | 32 | def forward(self, batch, key=None, disable_dropout=False): 33 | if key is None: 34 | key = self.key 35 | # this is for use in crossattn 36 | c = batch[key][:, None] 37 | if self.ucg_rate > 0. and not disable_dropout: 38 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) 39 | c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) 40 | c = c.long() 41 | c = self.embedding(c) 42 | return c 43 | 44 | def get_unconditional_conditioning(self, bs, device="cuda"): 45 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) 46 | uc = torch.ones((bs,), device=device) * uc_class 47 | uc = {self.key: uc} 48 | return uc 49 | 50 | 51 | def disabled_train(self, mode=True): 52 | """Overwrite model.train with this function to make sure train/eval mode 53 | does not change anymore.""" 54 | return self 55 | 56 | 57 | class FrozenT5Embedder(AbstractEncoder): 58 | """Uses the T5 transformer encoder for text""" 59 | 60 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, 61 | freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 62 | super().__init__() 63 | self.tokenizer = T5Tokenizer.from_pretrained(version) 64 | self.transformer = T5EncoderModel.from_pretrained(version) 65 | self.device = device 66 | self.max_length = max_length # TODO: typical value? 
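# --------------------------------------------------------------------------
# Illustrative sketch (editor's example, not part of the original class): how
# these frozen text encoders are queried. Prompts are padded/truncated to
# max_length tokens and encoded to a [batch, max_length, hidden] tensor while
# the weights stay frozen. Assumes the HuggingFace checkpoint can be
# downloaded; device="cpu" is used so the demo does not require a GPU.
# --------------------------------------------------------------------------
def _demo_frozen_t5():
    from lvdm.modules.encoders.condition import FrozenT5Embedder

    enc = FrozenT5Embedder(version="google/t5-v1_1-large", device="cpu")
    z = enc.encode(["a corgi surfing a wave", "a timelapse of a city at night"])
    return z.shape   # torch.Size([2, 77, 1024]) -- (batch, max_length, d_model of t5-v1_1-large)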
67 | if freeze: 68 | self.freeze() 69 | 70 | def freeze(self): 71 | self.transformer = self.transformer.eval() 72 | # self.train = disabled_train 73 | for param in self.parameters(): 74 | param.requires_grad = False 75 | 76 | def forward(self, text): 77 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 78 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 79 | tokens = batch_encoding["input_ids"].to(self.device) 80 | outputs = self.transformer(input_ids=tokens) 81 | 82 | z = outputs.last_hidden_state 83 | return z 84 | 85 | def encode(self, text): 86 | return self(text) 87 | 88 | 89 | class FrozenCLIPEmbedder(AbstractEncoder): 90 | """Uses the CLIP transformer encoder for text (from huggingface)""" 91 | LAYERS = [ 92 | "last", 93 | "pooled", 94 | "hidden" 95 | ] 96 | 97 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, 98 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 99 | super().__init__() 100 | assert layer in self.LAYERS 101 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 102 | self.transformer = CLIPTextModel.from_pretrained(version) 103 | self.device = device 104 | self.max_length = max_length 105 | if freeze: 106 | self.freeze() 107 | self.layer = layer 108 | self.layer_idx = layer_idx 109 | if layer == "hidden": 110 | assert layer_idx is not None 111 | assert 0 <= abs(layer_idx) <= 12 112 | 113 | def freeze(self): 114 | self.transformer = self.transformer.eval() 115 | # self.train = disabled_train 116 | for param in self.parameters(): 117 | param.requires_grad = False 118 | 119 | def forward(self, text): 120 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 121 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 122 | tokens = batch_encoding["input_ids"].to(self.device) 123 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden") 124 | if self.layer == "last": 125 | z = outputs.last_hidden_state 126 | elif self.layer == "pooled": 127 | z = outputs.pooler_output[:, None, :] 128 | else: 129 | z = outputs.hidden_states[self.layer_idx] 130 | return z 131 | 132 | def encode(self, text): 133 | return self(text) 134 | 135 | 136 | class ClipImageEmbedder(nn.Module): 137 | def __init__( 138 | self, 139 | model, 140 | jit=False, 141 | device='cuda' if torch.cuda.is_available() else 'cpu', 142 | antialias=True, 143 | ucg_rate=0. 144 | ): 145 | super().__init__() 146 | from clip import load as load_clip 147 | self.model, _ = load_clip(name=model, device=device, jit=jit) 148 | 149 | self.antialias = antialias 150 | 151 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 152 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 153 | self.ucg_rate = ucg_rate 154 | 155 | def preprocess(self, x): 156 | # normalize to [0,1] 157 | x = kornia.geometry.resize(x, (224, 224), 158 | interpolation='bicubic', align_corners=True, 159 | antialias=self.antialias) 160 | x = (x + 1.) / 2. 161 | # re-normalize according to clip 162 | x = kornia.enhance.normalize(x, self.mean, self.std) 163 | return x 164 | 165 | def forward(self, x, no_dropout=False): 166 | # x is assumed to be in range [-1,1] 167 | out = self.model.encode_image(self.preprocess(x)) 168 | out = out.to(x.dtype) 169 | if self.ucg_rate > 0. 
and not no_dropout: 170 | out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out 171 | return out 172 | 173 | 174 | class FrozenOpenCLIPEmbedder(AbstractEncoder): 175 | """ 176 | Uses the OpenCLIP transformer encoder for text 177 | """ 178 | LAYERS = [ 179 | # "pooled", 180 | "last", 181 | "penultimate" 182 | ] 183 | 184 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 185 | freeze=True, layer="last"): 186 | super().__init__() 187 | assert layer in self.LAYERS 188 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) 189 | del model.visual 190 | self.model = model 191 | 192 | self.device = device 193 | self.max_length = max_length 194 | if freeze: 195 | self.freeze() 196 | self.layer = layer 197 | if self.layer == "last": 198 | self.layer_idx = 0 199 | elif self.layer == "penultimate": 200 | self.layer_idx = 1 201 | else: 202 | raise NotImplementedError() 203 | 204 | def freeze(self): 205 | self.model = self.model.eval() 206 | for param in self.parameters(): 207 | param.requires_grad = False 208 | 209 | def forward(self, text): 210 | tokens = open_clip.tokenize(text) ## all clip models use 77 as context length 211 | z = self.encode_with_transformer(tokens.to(self.device)) 212 | return z 213 | 214 | def encode_with_transformer(self, text): 215 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 216 | x = x + self.model.positional_embedding 217 | x = x.permute(1, 0, 2) # NLD -> LND 218 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 219 | x = x.permute(1, 0, 2) # LND -> NLD 220 | x = self.model.ln_final(x) 221 | return x 222 | 223 | def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): 224 | for i, r in enumerate(self.model.transformer.resblocks): 225 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 226 | break 227 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 228 | x = checkpoint(r, x, attn_mask) 229 | else: 230 | x = r(x, attn_mask=attn_mask) 231 | return x 232 | 233 | def encode(self, text): 234 | return self(text) 235 | 236 | 237 | class FrozenOpenCLIPImageEmbedder(AbstractEncoder): 238 | """ 239 | Uses the OpenCLIP vision transformer encoder for images 240 | """ 241 | 242 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 243 | freeze=True, layer="pooled", antialias=True, ucg_rate=0.): 244 | super().__init__() 245 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), 246 | pretrained=version, ) 247 | del model.transformer 248 | self.model = model 249 | # self.mapper = torch.nn.Linear(1280, 1024) 250 | self.device = device 251 | self.max_length = max_length 252 | if freeze: 253 | self.freeze() 254 | self.layer = layer 255 | if self.layer == "penultimate": 256 | raise NotImplementedError() 257 | self.layer_idx = 1 258 | 259 | self.antialias = antialias 260 | 261 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 262 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 263 | self.ucg_rate = ucg_rate 264 | 265 | def preprocess(self, x): 266 | # normalize to [0,1] 267 | x = kornia.geometry.resize(x, (224, 224), 268 | interpolation='bicubic', align_corners=True, 269 | antialias=self.antialias) 270 | x = (x + 1.) / 2. 
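        # x is expected in [-1, 1]; it has been resized to 224x224 above and is mapped
        # to [0, 1] here. The mean/std buffers registered in __init__ are the standard
        # CLIP channel statistics, so the normalize call below reproduces the
        # preprocessing the OpenCLIP vision tower was trained with.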
271 | # renormalize according to clip 272 | x = kornia.enhance.normalize(x, self.mean, self.std) 273 | return x 274 | 275 | def freeze(self): 276 | self.model = self.model.eval() 277 | for param in self.model.parameters(): 278 | param.requires_grad = False 279 | 280 | @autocast 281 | def forward(self, image, no_dropout=False): 282 | z = self.encode_with_vision_transformer(image) 283 | if self.ucg_rate > 0. and not no_dropout: 284 | z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z 285 | return z 286 | 287 | def encode_with_vision_transformer(self, img): 288 | img = self.preprocess(img) 289 | x = self.model.visual(img) 290 | return x 291 | 292 | def encode(self, text): 293 | return self(text) 294 | 295 | class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder): 296 | """ 297 | Uses the OpenCLIP vision transformer encoder for images 298 | """ 299 | 300 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", 301 | freeze=True, layer="pooled", antialias=True): 302 | super().__init__() 303 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), 304 | pretrained=version, ) 305 | del model.transformer 306 | self.model = model 307 | self.device = device 308 | 309 | if freeze: 310 | self.freeze() 311 | self.layer = layer 312 | if self.layer == "penultimate": 313 | raise NotImplementedError() 314 | self.layer_idx = 1 315 | 316 | self.antialias = antialias 317 | 318 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 319 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 320 | 321 | 322 | def preprocess(self, x): 323 | # normalize to [0,1] 324 | x = kornia.geometry.resize(x, (224, 224), 325 | interpolation='bicubic', align_corners=True, 326 | antialias=self.antialias) 327 | x = (x + 1.) / 2. 
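        # preprocess() matches the embedder above; what distinguishes this V2 class is
        # encode_with_vision_transformer() below, which re-implements the ViT forward
        # pass so that the full token sequence (class token plus patch tokens) is
        # returned instead of a pooled embedding, presumably so it can be fed to the
        # Resampler in lvdm/modules/encoders/resampler.py.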
328 | # renormalize according to clip 329 | x = kornia.enhance.normalize(x, self.mean, self.std) 330 | return x 331 | 332 | def freeze(self): 333 | self.model = self.model.eval() 334 | for param in self.model.parameters(): 335 | param.requires_grad = False 336 | 337 | def forward(self, image, no_dropout=False): 338 | ## image: b c h w 339 | z = self.encode_with_vision_transformer(image) 340 | return z 341 | 342 | def encode_with_vision_transformer(self, x): 343 | x = self.preprocess(x) 344 | 345 | # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 346 | if self.model.visual.input_patchnorm: 347 | # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') 348 | x = x.reshape(x.shape[0], x.shape[1], self.model.visual.grid_size[0], self.model.visual.patch_size[0], self.model.visual.grid_size[1], self.model.visual.patch_size[1]) 349 | x = x.permute(0, 2, 4, 1, 3, 5) 350 | x = x.reshape(x.shape[0], self.model.visual.grid_size[0] * self.model.visual.grid_size[1], -1) 351 | x = self.model.visual.patchnorm_pre_ln(x) 352 | x = self.model.visual.conv1(x) 353 | else: 354 | x = self.model.visual.conv1(x) # shape = [*, width, grid, grid] 355 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 356 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 357 | 358 | # class embeddings and positional embeddings 359 | x = torch.cat( 360 | [self.model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 361 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 362 | x = x + self.model.visual.positional_embedding.to(x.dtype) 363 | 364 | # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in 365 | x = self.model.visual.patch_dropout(x) 366 | x = self.model.visual.ln_pre(x) 367 | 368 | x = x.permute(1, 0, 2) # NLD -> LND 369 | x = self.model.visual.transformer(x) 370 | x = x.permute(1, 0, 2) # LND -> NLD 371 | 372 | return x 373 | 374 | class FrozenCLIPT5Encoder(AbstractEncoder): 375 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", 376 | clip_max_length=77, t5_max_length=77): 377 | super().__init__() 378 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) 379 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) 380 | print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " 381 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.") 382 | 383 | def encode(self, text): 384 | return self(text) 385 | 386 | def forward(self, text): 387 | clip_z = self.clip_encoder.encode(text) 388 | t5_z = self.t5_encoder.encode(text) 389 | return [clip_z, t5_z] 390 | -------------------------------------------------------------------------------- /lvdm/modules/encoders/resampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py 3 | # and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class ImageProjModel(nn.Module): 10 | """Projection Model""" 11 | def __init__(self, 
cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 12 | super().__init__() 13 | self.cross_attention_dim = cross_attention_dim 14 | self.clip_extra_context_tokens = clip_extra_context_tokens 15 | self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 16 | self.norm = nn.LayerNorm(cross_attention_dim) 17 | 18 | def forward(self, image_embeds): 19 | #embeds = image_embeds 20 | embeds = image_embeds.type(list(self.proj.parameters())[0].dtype) 21 | clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) 22 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 23 | return clip_extra_context_tokens 24 | 25 | 26 | # FFN 27 | def FeedForward(dim, mult=4): 28 | inner_dim = int(dim * mult) 29 | return nn.Sequential( 30 | nn.LayerNorm(dim), 31 | nn.Linear(dim, inner_dim, bias=False), 32 | nn.GELU(), 33 | nn.Linear(inner_dim, dim, bias=False), 34 | ) 35 | 36 | 37 | def reshape_tensor(x, heads): 38 | bs, length, width = x.shape 39 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head) 40 | x = x.view(bs, length, heads, -1) 41 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 42 | x = x.transpose(1, 2) 43 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 44 | x = x.reshape(bs, heads, length, -1) 45 | return x 46 | 47 | 48 | class PerceiverAttention(nn.Module): 49 | def __init__(self, *, dim, dim_head=64, heads=8): 50 | super().__init__() 51 | self.scale = dim_head**-0.5 52 | self.dim_head = dim_head 53 | self.heads = heads 54 | inner_dim = dim_head * heads 55 | 56 | self.norm1 = nn.LayerNorm(dim) 57 | self.norm2 = nn.LayerNorm(dim) 58 | 59 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 60 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 61 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 62 | 63 | 64 | def forward(self, x, latents): 65 | """ 66 | Args: 67 | x (torch.Tensor): image features 68 | shape (b, n1, D) 69 | latent (torch.Tensor): latent features 70 | shape (b, n2, D) 71 | """ 72 | x = self.norm1(x) 73 | latents = self.norm2(latents) 74 | 75 | b, l, _ = latents.shape 76 | 77 | q = self.to_q(latents) 78 | kv_input = torch.cat((x, latents), dim=-2) 79 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 80 | 81 | q = reshape_tensor(q, self.heads) 82 | k = reshape_tensor(k, self.heads) 83 | v = reshape_tensor(v, self.heads) 84 | 85 | # attention 86 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 87 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 88 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 89 | out = weight @ v 90 | 91 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 92 | 93 | return self.to_out(out) 94 | 95 | 96 | class Resampler(nn.Module): 97 | def __init__( 98 | self, 99 | dim=1024, 100 | depth=8, 101 | dim_head=64, 102 | heads=16, 103 | num_queries=8, 104 | embedding_dim=768, 105 | output_dim=1024, 106 | ff_mult=4, 107 | video_length=None, # using frame-wise version or not 108 | ): 109 | super().__init__() 110 | ## queries for a single frame / image 111 | self.num_queries = num_queries 112 | self.video_length = video_length 113 | 114 | ## queries for each frame 115 | if video_length is not None: 116 | num_queries = num_queries * video_length 117 | 118 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 119 | self.proj_in = nn.Linear(embedding_dim, dim) 120 | 
self.proj_out = nn.Linear(dim, output_dim) 121 | self.norm_out = nn.LayerNorm(output_dim) 122 | 123 | self.layers = nn.ModuleList([]) 124 | for _ in range(depth): 125 | self.layers.append( 126 | nn.ModuleList( 127 | [ 128 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 129 | FeedForward(dim=dim, mult=ff_mult), 130 | ] 131 | ) 132 | ) 133 | 134 | def forward(self, x): 135 | latents = self.latents.repeat(x.size(0), 1, 1) ## B (T L) C 136 | x = self.proj_in(x) 137 | 138 | for attn, ff in self.layers: 139 | latents = attn(x, latents) + latents 140 | latents = ff(latents) + latents 141 | 142 | latents = self.proj_out(latents) 143 | latents = self.norm_out(latents) # B L C or B (T L) C 144 | 145 | return latents -------------------------------------------------------------------------------- /main/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | mainlogger = logging.getLogger('mainlogger') 5 | 6 | import torch 7 | import torchvision 8 | import pytorch_lightning as pl 9 | from pytorch_lightning.callbacks import Callback 10 | from pytorch_lightning.utilities import rank_zero_only 11 | from pytorch_lightning.utilities import rank_zero_info 12 | from utils.save_video import log_local, prepare_to_log 13 | 14 | 15 | class ImageLogger(Callback): 16 | def __init__(self, batch_frequency, max_images=8, clamp=True, rescale=True, save_dir=None, \ 17 | to_local=False, log_images_kwargs=None): 18 | super().__init__() 19 | self.rescale = rescale 20 | self.batch_freq = batch_frequency 21 | self.max_images = max_images 22 | self.to_local = to_local 23 | self.clamp = clamp 24 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 25 | if self.to_local: 26 | ## default save dir 27 | self.save_dir = os.path.join(save_dir, "images") 28 | os.makedirs(os.path.join(self.save_dir, "train"), exist_ok=True) 29 | os.makedirs(os.path.join(self.save_dir, "val"), exist_ok=True) 30 | 31 | def log_to_tensorboard(self, pl_module, batch_logs, filename, split, save_fps=8): 32 | """ log images and videos to tensorboard """ 33 | global_step = pl_module.global_step 34 | for key in batch_logs: 35 | value = batch_logs[key] 36 | tag = "gs%d-%s/%s-%s"%(global_step, split, filename, key) 37 | if isinstance(value, list) and isinstance(value[0], str): 38 | captions = ' |------| '.join(value) 39 | pl_module.logger.experiment.add_text(tag, captions, global_step=global_step) 40 | elif isinstance(value, torch.Tensor) and value.dim() == 5: 41 | video = value 42 | n = video.shape[0] 43 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 44 | frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video] #[3, n*h, 1*w] 45 | grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w] 46 | grid = (grid + 1.0) / 2.0 47 | grid = grid.unsqueeze(dim=0) 48 | pl_module.logger.experiment.add_video(tag, grid, fps=save_fps, global_step=global_step) 49 | elif isinstance(value, torch.Tensor) and value.dim() == 4: 50 | img = value 51 | grid = torchvision.utils.make_grid(img, nrow=int(n), padding=0) 52 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 53 | pl_module.logger.experiment.add_image(tag, grid, global_step=global_step) 54 | else: 55 | pass 56 | 57 | @rank_zero_only 58 | def log_batch_imgs(self, pl_module, batch, batch_idx, split="train"): 59 | """ generate images, then save and log to tensorboard """ 60 | skip_freq = self.batch_freq if split == "train" else 5 61 | 
if (batch_idx+1) % skip_freq == 0: 62 | is_train = pl_module.training 63 | if is_train: 64 | pl_module.eval() 65 | torch.cuda.empty_cache() 66 | with torch.no_grad(): 67 | log_func = pl_module.log_images 68 | batch_logs = log_func(batch, split=split, **self.log_images_kwargs) 69 | 70 | ## process: move to CPU and clamp 71 | batch_logs = prepare_to_log(batch_logs, self.max_images, self.clamp) 72 | torch.cuda.empty_cache() 73 | 74 | filename = "ep{}_idx{}_rank{}".format( 75 | pl_module.current_epoch, 76 | batch_idx, 77 | pl_module.global_rank) 78 | if self.to_local: 79 | mainlogger.info("Log [%s] batch <%s> to local ..."%(split, filename)) 80 | filename = "gs{}_".format(pl_module.global_step) + filename 81 | log_local(batch_logs, os.path.join(self.save_dir, split), filename, save_fps=10) 82 | else: 83 | mainlogger.info("Log [%s] batch <%s> to tensorboard ..."%(split, filename)) 84 | self.log_to_tensorboard(pl_module, batch_logs, filename, split, save_fps=10) 85 | mainlogger.info('Finish!') 86 | 87 | if is_train: 88 | pl_module.train() 89 | 90 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None): 91 | if self.batch_freq != -1 and pl_module.logdir: 92 | self.log_batch_imgs(pl_module, batch, batch_idx, split="train") 93 | 94 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None): 95 | ## different with validation_step() that saving the whole validation set and only keep the latest, 96 | ## it records the performance of every validation (without overwritten) by only keep a subset 97 | if self.batch_freq != -1 and pl_module.logdir: 98 | self.log_batch_imgs(pl_module, batch, batch_idx, split="val") 99 | if hasattr(pl_module, 'calibrate_grad_norm'): 100 | if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: 101 | self.log_gradients(trainer, pl_module, batch_idx=batch_idx) 102 | 103 | 104 | class CUDACallback(Callback): 105 | # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py 106 | def on_train_epoch_start(self, trainer, pl_module): 107 | # Reset the memory use counter 108 | # lightning update 109 | if int((pl.__version__).split('.')[1])>=7: 110 | gpu_index = trainer.strategy.root_device.index 111 | else: 112 | gpu_index = trainer.root_gpu 113 | torch.cuda.reset_peak_memory_stats(gpu_index) 114 | torch.cuda.synchronize(gpu_index) 115 | self.start_time = time.time() 116 | 117 | def on_train_epoch_end(self, trainer, pl_module): 118 | if int((pl.__version__).split('.')[1])>=7: 119 | gpu_index = trainer.strategy.root_device.index 120 | else: 121 | gpu_index = trainer.root_gpu 122 | torch.cuda.synchronize(gpu_index) 123 | max_memory = torch.cuda.max_memory_allocated(gpu_index) / 2 ** 20 124 | epoch_time = time.time() - self.start_time 125 | 126 | try: 127 | max_memory = trainer.training_type_plugin.reduce(max_memory) 128 | epoch_time = trainer.training_type_plugin.reduce(epoch_time) 129 | 130 | rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") 131 | rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") 132 | except AttributeError: 133 | pass 134 | -------------------------------------------------------------------------------- /main/trainer.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, datetime 2 | from omegaconf import OmegaConf 3 | from transformers import logging as transf_logging 4 | import pytorch_lightning as pl 5 | from pytorch_lightning import seed_everything 6 | from 
pytorch_lightning.trainer import Trainer 7 | import torch 8 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 9 | from utils.utils import instantiate_from_config 10 | from utils_train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy 11 | from utils_train import set_logger, init_workspace, load_checkpoints 12 | 13 | 14 | def get_parser(**parser_kwargs): 15 | parser = argparse.ArgumentParser(**parser_kwargs) 16 | parser.add_argument("--seed", "-s", type=int, default=20230211, help="seed for seed_everything") 17 | parser.add_argument("--name", "-n", type=str, default="", help="experiment name, as saving folder") 18 | 19 | parser.add_argument("--base", "-b", nargs="*", metavar="base_config.yaml", help="paths to base configs. Loaded from left-to-right. " 20 | "Parameters can be overwritten or added with command-line options of the form `--key value`.", default=list()) 21 | 22 | parser.add_argument("--train", "-t", action='store_true', default=False, help='train') 23 | parser.add_argument("--val", "-v", action='store_true', default=False, help='val') 24 | parser.add_argument("--test", action='store_true', default=False, help='test') 25 | 26 | parser.add_argument("--logdir", "-l", type=str, default="logs", help="directory for logging dat shit") 27 | parser.add_argument("--auto_resume", action='store_true', default=False, help="resume from full-info checkpoint") 28 | parser.add_argument("--auto_resume_weight_only", action='store_true', default=False, help="resume from weight-only checkpoint") 29 | parser.add_argument("--debug", "-d", action='store_true', default=False, help="enable post-mortem debugging") 30 | 31 | return parser 32 | 33 | def get_nondefault_trainer_args(args): 34 | parser = argparse.ArgumentParser() 35 | parser = Trainer.add_argparse_args(parser) 36 | default_trainer_args = parser.parse_args([]) 37 | return sorted(k for k in vars(default_trainer_args) if getattr(args, k) != getattr(default_trainer_args, k)) 38 | 39 | 40 | if __name__ == "__main__": 41 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 42 | local_rank = int(os.environ.get('LOCAL_RANK')) 43 | global_rank = int(os.environ.get('RANK')) 44 | num_rank = int(os.environ.get('WORLD_SIZE')) 45 | 46 | parser = get_parser() 47 | ## Extends existing argparse by default Trainer attributes 48 | parser = Trainer.add_argparse_args(parser) 49 | args, unknown = parser.parse_known_args() 50 | ## disable transformer warning 51 | transf_logging.set_verbosity_error() 52 | seed_everything(args.seed) 53 | 54 | ## yaml configs: "model" | "data" | "lightning" 55 | configs = [OmegaConf.load(cfg) for cfg in args.base] 56 | cli = OmegaConf.from_dotlist(unknown) 57 | config = OmegaConf.merge(*configs, cli) 58 | lightning_config = config.pop("lightning", OmegaConf.create()) 59 | trainer_config = lightning_config.get("trainer", OmegaConf.create()) 60 | 61 | ## setup workspace directories 62 | workdir, ckptdir, cfgdir, loginfo = init_workspace(args.name, args.logdir, config, lightning_config, global_rank) 63 | logger = set_logger(logfile=os.path.join(loginfo, 'log_%d:%s.txt'%(global_rank, now))) 64 | logger.info("@lightning version: %s [>=1.8 required]"%(pl.__version__)) 65 | 66 | ## MODEL CONFIG >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 67 | logger.info("***** Configing Model *****") 68 | config.model.params.logdir = workdir 69 | model = instantiate_from_config(config.model) 70 | 71 | ## load checkpoints 72 | model = load_checkpoints(model, config.model) 73 | 74 | ## register_schedule 
again to make ZTSNR work 75 | if model.rescale_betas_zero_snr: 76 | model.register_schedule(given_betas=model.given_betas, beta_schedule=model.beta_schedule, timesteps=model.timesteps, 77 | linear_start=model.linear_start, linear_end=model.linear_end, cosine_s=model.cosine_s) 78 | 79 | ## update trainer config 80 | for k in get_nondefault_trainer_args(args): 81 | trainer_config[k] = getattr(args, k) 82 | 83 | num_nodes = trainer_config.num_nodes 84 | ngpu_per_node = trainer_config.devices 85 | logger.info(f"Running on {num_rank}={num_nodes}x{ngpu_per_node} GPUs") 86 | 87 | ## setup learning rate 88 | base_lr = config.model.base_learning_rate 89 | bs = config.data.params.batch_size 90 | if getattr(config.model, 'scale_lr', True): 91 | model.learning_rate = num_rank * bs * base_lr 92 | else: 93 | model.learning_rate = base_lr 94 | 95 | 96 | ## DATA CONFIG >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 97 | logger.info("***** Configing Data *****") 98 | data = instantiate_from_config(config.data) 99 | data.setup() 100 | for k in data.datasets: 101 | logger.info(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") 102 | 103 | 104 | ## TRAINER CONFIG >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 105 | logger.info("***** Configing Trainer *****") 106 | if "accelerator" not in trainer_config: 107 | trainer_config["accelerator"] = "gpu" 108 | 109 | ## setup trainer args: pl-logger and callbacks 110 | trainer_kwargs = dict() 111 | trainer_kwargs["num_sanity_val_steps"] = 0 112 | logger_cfg = get_trainer_logger(lightning_config, workdir, args.debug) 113 | trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) 114 | 115 | ## setup callbacks 116 | callbacks_cfg = get_trainer_callbacks(lightning_config, config, workdir, ckptdir, logger) 117 | trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] 118 | strategy_cfg = get_trainer_strategy(lightning_config) 119 | trainer_kwargs["strategy"] = strategy_cfg if type(strategy_cfg) == str else instantiate_from_config(strategy_cfg) 120 | trainer_kwargs['precision'] = lightning_config.get('precision', 32) 121 | trainer_kwargs["sync_batchnorm"] = False 122 | 123 | ## trainer config: others 124 | 125 | trainer_args = argparse.Namespace(**trainer_config) 126 | trainer = Trainer.from_argparse_args(trainer_args, **trainer_kwargs) 127 | 128 | ## allow checkpointing via USR1 129 | def melk(*args, **kwargs): 130 | ## run all checkpoint hooks 131 | if trainer.global_rank == 0: 132 | print("Summoning checkpoint.") 133 | ckpt_path = os.path.join(ckptdir, "last_summoning.ckpt") 134 | trainer.save_checkpoint(ckpt_path) 135 | 136 | def divein(*args, **kwargs): 137 | if trainer.global_rank == 0: 138 | import pudb; 139 | pudb.set_trace() 140 | 141 | import signal 142 | signal.signal(signal.SIGUSR1, melk) 143 | signal.signal(signal.SIGUSR2, divein) 144 | 145 | ## Running LOOP >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 146 | logger.info("***** Running the Loop *****") 147 | if args.train: 148 | try: 149 | if "strategy" in lightning_config and lightning_config['strategy'].startswith('deepspeed'): 150 | logger.info("") 151 | ## deepspeed 152 | if trainer_kwargs['precision'] == 16: 153 | with torch.cuda.amp.autocast(): 154 | trainer.fit(model, data) 155 | else: 156 | trainer.fit(model, data) 157 | else: 158 | logger.info("") ## this is default 159 | ## ddpsharded 160 | trainer.fit(model, data) 161 | except Exception: 162 | #melk() 163 | raise 164 | 165 | 
# if args.val: 166 | # trainer.validate(model, data) 167 | # if args.test or not trainer.interrupted: 168 | # trainer.test(model, data) -------------------------------------------------------------------------------- /main/utils_data.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | 4 | import torch 5 | import pytorch_lightning as pl 6 | from torch.utils.data import DataLoader, Dataset 7 | 8 | import os, sys 9 | os.chdir(sys.path[0]) 10 | sys.path.append("..") 11 | from lvdm.data.base import Txt2ImgIterableBaseDataset 12 | from utils.utils import instantiate_from_config 13 | 14 | 15 | def worker_init_fn(_): 16 | worker_info = torch.utils.data.get_worker_info() 17 | 18 | dataset = worker_info.dataset 19 | worker_id = worker_info.id 20 | 21 | if isinstance(dataset, Txt2ImgIterableBaseDataset): 22 | split_size = dataset.num_records // worker_info.num_workers 23 | # reset num_records to the true number to retain reliable length information 24 | dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] 25 | current_id = np.random.choice(len(np.random.get_state()[1]), 1) 26 | return np.random.seed(np.random.get_state()[1][current_id] + worker_id) 27 | else: 28 | return np.random.seed(np.random.get_state()[1][0] + worker_id) 29 | 30 | 31 | class WrappedDataset(Dataset): 32 | """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" 33 | 34 | def __init__(self, dataset): 35 | self.data = dataset 36 | 37 | def __len__(self): 38 | return len(self.data) 39 | 40 | def __getitem__(self, idx): 41 | return self.data[idx] 42 | 43 | 44 | class DataModuleFromConfig(pl.LightningDataModule): 45 | def __init__(self, batch_size, train=None, validation=None, test=None, predict=None, 46 | wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False, 47 | shuffle_val_dataloader=False, train_img=None, 48 | test_max_n_samples=None): 49 | super().__init__() 50 | self.batch_size = batch_size 51 | self.dataset_configs = dict() 52 | self.num_workers = num_workers if num_workers is not None else batch_size * 2 53 | self.use_worker_init_fn = use_worker_init_fn 54 | if train is not None: 55 | self.dataset_configs["train"] = train 56 | self.train_dataloader = self._train_dataloader 57 | if validation is not None: 58 | self.dataset_configs["validation"] = validation 59 | self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader) 60 | if test is not None: 61 | self.dataset_configs["test"] = test 62 | self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader) 63 | if predict is not None: 64 | self.dataset_configs["predict"] = predict 65 | self.predict_dataloader = self._predict_dataloader 66 | 67 | self.img_loader = None 68 | self.wrap = wrap 69 | self.test_max_n_samples = test_max_n_samples 70 | self.collate_fn = None 71 | 72 | def prepare_data(self): 73 | pass 74 | 75 | def setup(self, stage=None): 76 | self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) 77 | if self.wrap: 78 | for k in self.datasets: 79 | self.datasets[k] = WrappedDataset(self.datasets[k]) 80 | 81 | def _train_dataloader(self): 82 | is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) 83 | if is_iterable_dataset or self.use_worker_init_fn: 84 | init_fn = worker_init_fn 85 | else: 86 | init_fn = None 87 | loader = DataLoader(self.datasets["train"], 
batch_size=self.batch_size, 88 | num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True, 89 | worker_init_fn=init_fn, collate_fn=self.collate_fn, 90 | ) 91 | return loader 92 | 93 | def _val_dataloader(self, shuffle=False): 94 | if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: 95 | init_fn = worker_init_fn 96 | else: 97 | init_fn = None 98 | return DataLoader(self.datasets["validation"], 99 | batch_size=self.batch_size, 100 | num_workers=self.num_workers, 101 | worker_init_fn=init_fn, 102 | shuffle=shuffle, 103 | collate_fn=self.collate_fn, 104 | ) 105 | 106 | def _test_dataloader(self, shuffle=False): 107 | try: 108 | is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset) 109 | except: 110 | is_iterable_dataset = isinstance(self.datasets['test'], Txt2ImgIterableBaseDataset) 111 | 112 | if is_iterable_dataset or self.use_worker_init_fn: 113 | init_fn = worker_init_fn 114 | else: 115 | init_fn = None 116 | 117 | # do not shuffle dataloader for iterable dataset 118 | shuffle = shuffle and (not is_iterable_dataset) 119 | if self.test_max_n_samples is not None: 120 | dataset = torch.utils.data.Subset(self.datasets["test"], list(range(self.test_max_n_samples))) 121 | else: 122 | dataset = self.datasets["test"] 123 | return DataLoader(dataset, batch_size=self.batch_size, 124 | num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle, 125 | collate_fn=self.collate_fn, 126 | ) 127 | 128 | def _predict_dataloader(self, shuffle=False): 129 | if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: 130 | init_fn = worker_init_fn 131 | else: 132 | init_fn = None 133 | return DataLoader(self.datasets["predict"], batch_size=self.batch_size, 134 | num_workers=self.num_workers, worker_init_fn=init_fn, 135 | collate_fn=self.collate_fn, 136 | ) 137 | -------------------------------------------------------------------------------- /main/utils_train.py: -------------------------------------------------------------------------------- 1 | import os, re 2 | from omegaconf import OmegaConf 3 | import logging 4 | mainlogger = logging.getLogger('mainlogger') 5 | 6 | import torch 7 | from collections import OrderedDict 8 | 9 | def init_workspace(name, logdir, model_config, lightning_config, rank=0): 10 | workdir = os.path.join(logdir, name) 11 | ckptdir = os.path.join(workdir, "checkpoints") 12 | cfgdir = os.path.join(workdir, "configs") 13 | loginfo = os.path.join(workdir, "loginfo") 14 | 15 | # Create logdirs and save configs (all ranks will do to avoid missing directory error if rank:0 is slower) 16 | os.makedirs(workdir, exist_ok=True) 17 | os.makedirs(ckptdir, exist_ok=True) 18 | os.makedirs(cfgdir, exist_ok=True) 19 | os.makedirs(loginfo, exist_ok=True) 20 | 21 | if rank == 0: 22 | if "callbacks" in lightning_config and 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks: 23 | os.makedirs(os.path.join(ckptdir, 'trainstep_checkpoints'), exist_ok=True) 24 | OmegaConf.save(model_config, os.path.join(cfgdir, "model.yaml")) 25 | OmegaConf.save(OmegaConf.create({"lightning": lightning_config}), os.path.join(cfgdir, "lightning.yaml")) 26 | return workdir, ckptdir, cfgdir, loginfo 27 | 28 | def check_config_attribute(config, name): 29 | if name in config: 30 | value = getattr(config, name) 31 | return value 32 | else: 33 | return None 34 | 35 | def get_trainer_callbacks(lightning_config, config, logdir, ckptdir, logger): 36 | default_callbacks_cfg = { 
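        # Four callbacks are registered by default: a ModelCheckpoint writing
        # "{epoch}.ckpt" into ckptdir, the ImageLogger from callbacks.py (every 1000
        # batches, up to 4 images, clamped), a per-step LearningRateMonitor, and the
        # CUDACallback that reports epoch time and peak GPU memory.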
37 | "model_checkpoint": { 38 | "target": "pytorch_lightning.callbacks.ModelCheckpoint", 39 | "params": { 40 | "dirpath": ckptdir, 41 | "filename": "{epoch}", 42 | "verbose": True, 43 | "save_last": False, 44 | } 45 | }, 46 | "batch_logger": { 47 | "target": "callbacks.ImageLogger", 48 | "params": { 49 | "save_dir": logdir, 50 | "batch_frequency": 1000, 51 | "max_images": 4, 52 | "clamp": True, 53 | } 54 | }, 55 | "learning_rate_logger": { 56 | "target": "pytorch_lightning.callbacks.LearningRateMonitor", 57 | "params": { 58 | "logging_interval": "step", 59 | "log_momentum": False 60 | } 61 | }, 62 | "cuda_callback": { 63 | "target": "callbacks.CUDACallback" 64 | }, 65 | } 66 | 67 | ## optional setting for saving checkpoints 68 | monitor_metric = check_config_attribute(config.model.params, "monitor") 69 | if monitor_metric is not None: 70 | mainlogger.info(f"Monitoring {monitor_metric} as checkpoint metric.") 71 | default_callbacks_cfg["model_checkpoint"]["params"]["monitor"] = monitor_metric 72 | default_callbacks_cfg["model_checkpoint"]["params"]["save_top_k"] = 3 73 | default_callbacks_cfg["model_checkpoint"]["params"]["mode"] = "min" 74 | 75 | if 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks: 76 | mainlogger.info('Caution: Saving checkpoints every n train steps without deleting. This might require some free space.') 77 | default_metrics_over_trainsteps_ckpt_dict = { 78 | 'metrics_over_trainsteps_checkpoint': {"target": 'pytorch_lightning.callbacks.ModelCheckpoint', 79 | 'params': { 80 | "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), 81 | "filename": "{epoch}-{step}", 82 | "verbose": True, 83 | 'save_top_k': -1, 84 | 'every_n_train_steps': 10000, 85 | 'save_weights_only': True 86 | } 87 | } 88 | } 89 | default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) 90 | 91 | if "callbacks" in lightning_config: 92 | callbacks_cfg = lightning_config.callbacks 93 | else: 94 | callbacks_cfg = OmegaConf.create() 95 | callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) 96 | 97 | return callbacks_cfg 98 | 99 | def get_trainer_logger(lightning_config, logdir, on_debug): 100 | default_logger_cfgs = { 101 | "tensorboard": { 102 | "target": "pytorch_lightning.loggers.TensorBoardLogger", 103 | "params": { 104 | "save_dir": logdir, 105 | "name": "tensorboard", 106 | } 107 | }, 108 | "testtube": { 109 | "target": "pytorch_lightning.loggers.CSVLogger", 110 | "params": { 111 | "name": "testtube", 112 | "save_dir": logdir, 113 | } 114 | }, 115 | } 116 | os.makedirs(os.path.join(logdir, "tensorboard"), exist_ok=True) 117 | default_logger_cfg = default_logger_cfgs["tensorboard"] 118 | if "logger" in lightning_config: 119 | logger_cfg = lightning_config.logger 120 | else: 121 | logger_cfg = OmegaConf.create() 122 | logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) 123 | return logger_cfg 124 | 125 | def get_trainer_strategy(lightning_config): 126 | default_strategy_dict = { 127 | "target": "pytorch_lightning.strategies.DDPShardedStrategy" 128 | } 129 | if "strategy" in lightning_config: 130 | strategy_cfg = lightning_config.strategy 131 | return strategy_cfg 132 | else: 133 | strategy_cfg = OmegaConf.create() 134 | 135 | strategy_cfg = OmegaConf.merge(default_strategy_dict, strategy_cfg) 136 | return strategy_cfg 137 | 138 | def load_checkpoints(model, model_cfg): 139 | if check_config_attribute(model_cfg, "pretrained_checkpoint"): 140 | pretrained_ckpt = model_cfg.pretrained_checkpoint 141 | assert os.path.exists(pretrained_ckpt), 
"Error: Pre-trained checkpoint NOT found at:%s"%pretrained_ckpt 142 | mainlogger.info(">>> Load weights from pretrained checkpoint") 143 | 144 | pl_sd = torch.load(pretrained_ckpt, map_location="cpu") 145 | try: 146 | if 'state_dict' in pl_sd.keys(): 147 | model.load_state_dict(pl_sd["state_dict"], strict=True) 148 | mainlogger.info(">>> Loaded weights from pretrained checkpoint: %s"%pretrained_ckpt) 149 | else: 150 | # deepspeed 151 | new_pl_sd = OrderedDict() 152 | for key in pl_sd['module'].keys(): 153 | new_pl_sd[key[16:]]=pl_sd['module'][key] 154 | model.load_state_dict(new_pl_sd, strict=True) 155 | except: 156 | model.load_state_dict(pl_sd) 157 | else: 158 | mainlogger.info(">>> Start training from scratch") 159 | 160 | return model 161 | 162 | def set_logger(logfile, name='mainlogger'): 163 | logger = logging.getLogger(name) 164 | logger.setLevel(logging.INFO) 165 | fh = logging.FileHandler(logfile, mode='w') 166 | fh.setLevel(logging.INFO) 167 | ch = logging.StreamHandler() 168 | ch.setLevel(logging.DEBUG) 169 | fh.setFormatter(logging.Formatter("%(asctime)s-%(levelname)s: %(message)s")) 170 | ch.setFormatter(logging.Formatter("%(message)s")) 171 | logger.addHandler(fh) 172 | logger.addHandler(ch) 173 | return logger -------------------------------------------------------------------------------- /prompts/512_interp/74906_1462_frame1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/prompts/512_interp/74906_1462_frame1.png -------------------------------------------------------------------------------- /prompts/512_interp/74906_1462_frame3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/prompts/512_interp/74906_1462_frame3.png -------------------------------------------------------------------------------- /prompts/512_interp/Japan_v2_2_062266_s2_frame1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/prompts/512_interp/Japan_v2_2_062266_s2_frame1.png -------------------------------------------------------------------------------- /prompts/512_interp/Japan_v2_2_062266_s2_frame3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/prompts/512_interp/Japan_v2_2_062266_s2_frame3.png -------------------------------------------------------------------------------- /prompts/512_interp/Japan_v2_3_119235_s2_frame1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/prompts/512_interp/Japan_v2_3_119235_s2_frame1.png -------------------------------------------------------------------------------- /prompts/512_interp/Japan_v2_3_119235_s2_frame3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Doubiiu/ToonCrafter/b0c47ff339c5e5ec45b84d0c6587850f242d41ef/prompts/512_interp/Japan_v2_3_119235_s2_frame3.png -------------------------------------------------------------------------------- /prompts/512_interp/prompts.txt: 
-------------------------------------------------------------------------------- 1 | walking man 2 | an anime scene 3 | an anime scene -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | decord==0.6.0 2 | einops==0.3.0 3 | imageio==2.9.0 4 | numpy==1.24.2 5 | omegaconf==2.1.1 6 | opencv_python 7 | pandas==2.0.0 8 | Pillow==9.5.0 9 | pytorch_lightning==1.9.3 10 | PyYAML==6.0 11 | setuptools==65.6.3 12 | torch==2.0.0 13 | torchvision 14 | tqdm==4.65.0 15 | transformers==4.25.1 16 | moviepy 17 | av 18 | xformers 19 | gradio 20 | timm 21 | scikit-learn 22 | open_clip_torch==2.22.0 23 | kornia -------------------------------------------------------------------------------- /scripts/evaluation/ddp_wrapper.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import argparse, importlib 3 | from pytorch_lightning import seed_everything 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | def setup_dist(local_rank): 9 | if dist.is_initialized(): 10 | return 11 | torch.cuda.set_device(local_rank) 12 | torch.distributed.init_process_group('nccl', init_method='env://') 13 | 14 | 15 | def get_dist_info(): 16 | if dist.is_available(): 17 | initialized = dist.is_initialized() 18 | else: 19 | initialized = False 20 | if initialized: 21 | rank = dist.get_rank() 22 | world_size = dist.get_world_size() 23 | else: 24 | rank = 0 25 | world_size = 1 26 | return rank, world_size 27 | 28 | 29 | if __name__ == '__main__': 30 | now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--module", type=str, help="module name", default="inference") 33 | parser.add_argument("--local_rank", type=int, nargs="?", help="for ddp", default=0) 34 | args, unknown = parser.parse_known_args() 35 | inference_api = importlib.import_module(args.module, package=None) 36 | 37 | inference_parser = inference_api.get_parser() 38 | inference_args, unknown = inference_parser.parse_known_args() 39 | 40 | seed_everything(inference_args.seed) 41 | setup_dist(args.local_rank) 42 | torch.backends.cudnn.benchmark = True 43 | rank, gpu_num = get_dist_info() 44 | 45 | # inference_args.savedir = inference_args.savedir+str('_seed')+str(inference_args.seed) 46 | print("@DynamiCrafter Inference [rank%d]: %s"%(rank, now)) 47 | inference_api.run_inference(inference_args, gpu_num, rank) -------------------------------------------------------------------------------- /scripts/evaluation/funcs.py: -------------------------------------------------------------------------------- 1 | import os, sys, glob 2 | import numpy as np 3 | from collections import OrderedDict 4 | from decord import VideoReader, cpu 5 | import cv2 6 | 7 | import torch 8 | import torchvision 9 | sys.path.insert(1, os.path.join(sys.path[0], '..', '..')) 10 | from lvdm.models.samplers.ddim import DDIMSampler 11 | from einops import rearrange 12 | 13 | 14 | def batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1.0,\ 15 | cfg_scale=1.0, hs=None, temporal_cfg_scale=None, **kwargs): 16 | ddim_sampler = DDIMSampler(model) 17 | uncond_type = model.uncond_type 18 | batch_size = noise_shape[0] 19 | fs = cond["fs"] 20 | del cond["fs"] 21 | if noise_shape[-1] == 32: 22 | timestep_spacing = "uniform" 23 | guidance_rescale = 0.0 24 | else: 25 | timestep_spacing = "uniform_trailing" 26 | guidance_rescale 
= 0.7 27 | ## construct unconditional guidance 28 | if cfg_scale != 1.0: 29 | if uncond_type == "empty_seq": 30 | prompts = batch_size * [""] 31 | #prompts = N * T * [""] ## if is_imgbatch=True 32 | uc_emb = model.get_learned_conditioning(prompts) 33 | elif uncond_type == "zero_embed": 34 | c_emb = cond["c_crossattn"][0] if isinstance(cond, dict) else cond 35 | uc_emb = torch.zeros_like(c_emb) 36 | 37 | ## process image embedding token 38 | if hasattr(model, 'embedder'): 39 | uc_img = torch.zeros(noise_shape[0],3,224,224).to(model.device) 40 | ## img: b c h w >> b l c 41 | uc_img = model.embedder(uc_img) 42 | uc_img = model.image_proj_model(uc_img) 43 | uc_emb = torch.cat([uc_emb, uc_img], dim=1) 44 | 45 | if isinstance(cond, dict): 46 | uc = {key:cond[key] for key in cond.keys()} 47 | uc.update({'c_crossattn': [uc_emb]}) 48 | else: 49 | uc = uc_emb 50 | else: 51 | uc = None 52 | 53 | 54 | additional_decode_kwargs = {'ref_context': hs} 55 | x_T = None 56 | batch_variants = [] 57 | 58 | for _ in range(n_samples): 59 | if ddim_sampler is not None: 60 | kwargs.update({"clean_cond": True}) 61 | samples, _ = ddim_sampler.sample(S=ddim_steps, 62 | conditioning=cond, 63 | batch_size=noise_shape[0], 64 | shape=noise_shape[1:], 65 | verbose=False, 66 | unconditional_guidance_scale=cfg_scale, 67 | unconditional_conditioning=uc, 68 | eta=ddim_eta, 69 | temporal_length=noise_shape[2], 70 | conditional_guidance_scale_temporal=temporal_cfg_scale, 71 | x_T=x_T, 72 | fs=fs, 73 | timestep_spacing=timestep_spacing, 74 | guidance_rescale=guidance_rescale, 75 | **kwargs 76 | ) 77 | ## reconstruct from latent to pixel space 78 | batch_images = model.decode_first_stage(samples, **additional_decode_kwargs) 79 | 80 | index = list(range(samples.shape[2])) 81 | del index[1] 82 | del index[-2] 83 | samples = samples[:,:,index,:,:] 84 | ## reconstruct from latent to pixel space 85 | batch_images_middle = model.decode_first_stage(samples, **additional_decode_kwargs) 86 | batch_images[:,:,batch_images.shape[2]//2-1:batch_images.shape[2]//2+1] = batch_images_middle[:,:,batch_images.shape[2]//2-2:batch_images.shape[2]//2] 87 | 88 | 89 | 90 | batch_variants.append(batch_images) 91 | ## batch, , c, t, h, w 92 | batch_variants = torch.stack(batch_variants, dim=1) 93 | return batch_variants 94 | 95 | 96 | def get_filelist(data_dir, ext='*'): 97 | file_list = glob.glob(os.path.join(data_dir, '*.%s'%ext)) 98 | file_list.sort() 99 | return file_list 100 | 101 | def get_dirlist(path): 102 | list = [] 103 | if (os.path.exists(path)): 104 | files = os.listdir(path) 105 | for file in files: 106 | m = os.path.join(path,file) 107 | if (os.path.isdir(m)): 108 | list.append(m) 109 | list.sort() 110 | return list 111 | 112 | 113 | def load_model_checkpoint(model, ckpt): 114 | def load_checkpoint(model, ckpt, full_strict): 115 | state_dict = torch.load(ckpt, map_location="cpu") 116 | if "state_dict" in list(state_dict.keys()): 117 | state_dict = state_dict["state_dict"] 118 | try: 119 | model.load_state_dict(state_dict, strict=full_strict) 120 | except: 121 | ## rename the keys for 256x256 model 122 | new_pl_sd = OrderedDict() 123 | for k,v in state_dict.items(): 124 | new_pl_sd[k] = v 125 | 126 | for k in list(new_pl_sd.keys()): 127 | if "framestride_embed" in k: 128 | new_key = k.replace("framestride_embed", "fps_embedding") 129 | new_pl_sd[new_key] = new_pl_sd[k] 130 | del new_pl_sd[k] 131 | model.load_state_dict(new_pl_sd, strict=full_strict) 132 | else: 133 | ## deepspeed 134 | new_pl_sd = OrderedDict() 135 | for key in 
state_dict['module'].keys(): 136 | new_pl_sd[key[16:]]=state_dict['module'][key] 137 | model.load_state_dict(new_pl_sd, strict=full_strict) 138 | 139 | return model 140 | load_checkpoint(model, ckpt, full_strict=True) 141 | print('>>> model checkpoint loaded.') 142 | return model 143 | 144 | 145 | def load_prompts(prompt_file): 146 | f = open(prompt_file, 'r') 147 | prompt_list = [] 148 | for idx, line in enumerate(f.readlines()): 149 | l = line.strip() 150 | if len(l) != 0: 151 | prompt_list.append(l) 152 | f.close() 153 | return prompt_list 154 | 155 | 156 | def load_video_batch(filepath_list, frame_stride, video_size=(256,256), video_frames=16): 157 | ''' 158 | Notice about some special cases: 159 | 1. video_frames=-1 means to take all the frames (with fs=1) 160 | 2. when the total video frames is less than required, padding strategy will be used (repeated last frame) 161 | ''' 162 | fps_list = [] 163 | batch_tensor = [] 164 | assert frame_stride > 0, "valid frame stride should be a positive interge!" 165 | for filepath in filepath_list: 166 | padding_num = 0 167 | vidreader = VideoReader(filepath, ctx=cpu(0), width=video_size[1], height=video_size[0]) 168 | fps = vidreader.get_avg_fps() 169 | total_frames = len(vidreader) 170 | max_valid_frames = (total_frames-1) // frame_stride + 1 171 | if video_frames < 0: 172 | ## all frames are collected: fs=1 is a must 173 | required_frames = total_frames 174 | frame_stride = 1 175 | else: 176 | required_frames = video_frames 177 | query_frames = min(required_frames, max_valid_frames) 178 | frame_indices = [frame_stride*i for i in range(query_frames)] 179 | 180 | ## [t,h,w,c] -> [c,t,h,w] 181 | frames = vidreader.get_batch(frame_indices) 182 | frame_tensor = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float() 183 | frame_tensor = (frame_tensor / 255. - 0.5) * 2 184 | if max_valid_frames < required_frames: 185 | padding_num = required_frames - max_valid_frames 186 | frame_tensor = torch.cat([frame_tensor, *([frame_tensor[:,-1:,:,:]]*padding_num)], dim=1) 187 | print(f'{os.path.split(filepath)[1]} is not long enough: {padding_num} frames padded.') 188 | batch_tensor.append(frame_tensor) 189 | sample_fps = int(fps/frame_stride) 190 | fps_list.append(sample_fps) 191 | 192 | return torch.stack(batch_tensor, dim=0) 193 | 194 | from PIL import Image 195 | def load_image_batch(filepath_list, image_size=(256,256)): 196 | batch_tensor = [] 197 | for filepath in filepath_list: 198 | _, filename = os.path.split(filepath) 199 | _, ext = os.path.splitext(filename) 200 | if ext == '.mp4': 201 | vidreader = VideoReader(filepath, ctx=cpu(0), width=image_size[1], height=image_size[0]) 202 | frame = vidreader.get_batch([0]) 203 | img_tensor = torch.tensor(frame.asnumpy()).squeeze(0).permute(2, 0, 1).float() 204 | elif ext == '.png' or ext == '.jpg': 205 | img = Image.open(filepath).convert("RGB") 206 | rgb_img = np.array(img, np.float32) 207 | #bgr_img = cv2.imread(filepath, cv2.IMREAD_COLOR) 208 | #bgr_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB) 209 | rgb_img = cv2.resize(rgb_img, (image_size[1],image_size[0]), interpolation=cv2.INTER_LINEAR) 210 | img_tensor = torch.from_numpy(rgb_img).permute(2, 0, 1).float() 211 | else: 212 | print(f'ERROR: <{ext}> image loading only support format: [mp4], [png], [jpg]') 213 | raise NotImplementedError 214 | img_tensor = (img_tensor / 255. 
- 0.5) * 2 215 | batch_tensor.append(img_tensor) 216 | return torch.stack(batch_tensor, dim=0) 217 | 218 | 219 | def save_videos(batch_tensors, savedir, filenames, fps=10): 220 | # b,samples,c,t,h,w 221 | n_samples = batch_tensors.shape[1] 222 | for idx, vid_tensor in enumerate(batch_tensors): 223 | video = vid_tensor.detach().cpu() 224 | video = torch.clamp(video.float(), -1., 1.) 225 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 226 | frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n_samples)) for framesheet in video] #[3, 1*h, n*w] 227 | grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w] 228 | grid = (grid + 1.0) / 2.0 229 | grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) 230 | savepath = os.path.join(savedir, f"{filenames[idx]}.mp4") 231 | torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'}) 232 | 233 | 234 | def get_latent_z(model, videos): 235 | b, c, t, h, w = videos.shape 236 | x = rearrange(videos, 'b c t h w -> (b t) c h w') 237 | z = model.encode_first_stage(x) 238 | z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) 239 | return z 240 | 241 | -------------------------------------------------------------------------------- /scripts/gradio/i2v_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from omegaconf import OmegaConf 4 | import torch 5 | from scripts.evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling, get_latent_z 6 | from utils.utils import instantiate_from_config 7 | from huggingface_hub import hf_hub_download 8 | from einops import repeat 9 | import torchvision.transforms as transforms 10 | from pytorch_lightning import seed_everything 11 | 12 | 13 | class Image2Video(): 14 | def __init__(self,result_dir='./tmp/',gpu_num=1,resolution='256_256') -> None: 15 | self.resolution = (int(resolution.split('_')[0]), int(resolution.split('_')[1])) #hw 16 | self.download_model() 17 | 18 | self.result_dir = result_dir 19 | if not os.path.exists(self.result_dir): 20 | os.mkdir(self.result_dir) 21 | ckpt_path='checkpoints/dynamicrafter_'+resolution.split('_')[1]+'_v1/model.ckpt' 22 | config_file='configs/inference_'+resolution.split('_')[1]+'_v1.0.yaml' 23 | config = OmegaConf.load(config_file) 24 | model_config = config.pop("model", OmegaConf.create()) 25 | model_config['params']['unet_config']['params']['use_checkpoint']=False 26 | model_list = [] 27 | for gpu_id in range(gpu_num): 28 | model = instantiate_from_config(model_config) 29 | # model = model.cuda(gpu_id) 30 | assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!" 
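            # One model copy is instantiated per GPU; weights stay on the CPU here and
            # are moved to the GPU only inside get_image(), which moves them back to
            # the CPU once sampling finishes.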
31 | model = load_model_checkpoint(model, ckpt_path) 32 | model.eval() 33 | model_list.append(model) 34 | self.model_list = model_list 35 | self.save_fps = 8 36 | 37 | def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123): 38 | seed_everything(seed) 39 | transform = transforms.Compose([ 40 | transforms.Resize(min(self.resolution)), 41 | transforms.CenterCrop(self.resolution), 42 | ]) 43 | torch.cuda.empty_cache() 44 | print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))) 45 | start = time.time() 46 | gpu_id=0 47 | if steps > 60: 48 | steps = 60 49 | model = self.model_list[gpu_id] 50 | model = model.cuda() 51 | batch_size=1 52 | channels = model.model.diffusion_model.out_channels 53 | frames = model.temporal_length 54 | h, w = self.resolution[0] // 8, self.resolution[1] // 8 55 | noise_shape = [batch_size, channels, frames, h, w] 56 | 57 | # text cond 58 | with torch.no_grad(), torch.cuda.amp.autocast(): 59 | text_emb = model.get_learned_conditioning([prompt]) 60 | 61 | # img cond 62 | img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device) 63 | img_tensor = (img_tensor / 255. - 0.5) * 2 64 | 65 | image_tensor_resized = transform(img_tensor) #3,h,w 66 | videos = image_tensor_resized.unsqueeze(0) # bchw 67 | 68 | z = get_latent_z(model, videos.unsqueeze(2)) #bc,1,hw 69 | 70 | img_tensor_repeat = repeat(z, 'b c t h w -> b c (repeat t) h w', repeat=frames) 71 | 72 | cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc 73 | img_emb = model.image_proj_model(cond_images) 74 | 75 | imtext_cond = torch.cat([text_emb, img_emb], dim=1) 76 | 77 | fs = torch.tensor([fs], dtype=torch.long, device=model.device) 78 | cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]} 79 | 80 | ## inference 81 | batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale) 82 | ## b,samples,c,t,h,w 83 | prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt 84 | prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str 85 | prompt_str=prompt_str[:40] 86 | if len(prompt_str) == 0: 87 | prompt_str = 'empty_prompt' 88 | 89 | save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps) 90 | print(f"Saved in {prompt_str}. 
Time used: {(time.time() - start):.2f} seconds") 91 | model = model.cpu() 92 | return os.path.join(self.result_dir, f"{prompt_str}.mp4") 93 | 94 | def download_model(self): 95 | REPO_ID = 'Doubiiu/DynamiCrafter_'+str(self.resolution[1]) if self.resolution[1]!=256 else 'Doubiiu/DynamiCrafter' 96 | filename_list = ['model.ckpt'] 97 | if not os.path.exists('./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/'): 98 | os.makedirs('./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/') 99 | for filename in filename_list: 100 | local_file = os.path.join('./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/', filename) 101 | if not os.path.exists(local_file): 102 | hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/dynamicrafter_'+str(self.resolution[1])+'_v1/', local_dir_use_symlinks=False) 103 | 104 | if __name__ == '__main__': 105 | i2v = Image2Video() 106 | video_path = i2v.get_image('prompts/art.png','man fishing in a boat at sunset') 107 | print('done', video_path) -------------------------------------------------------------------------------- /scripts/gradio/i2v_test_application.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from omegaconf import OmegaConf 4 | import torch 5 | from scripts.evaluation.funcs import load_model_checkpoint, save_videos, batch_ddim_sampling, get_latent_z 6 | from utils.utils import instantiate_from_config 7 | from huggingface_hub import hf_hub_download 8 | from einops import repeat 9 | import torchvision.transforms as transforms 10 | from pytorch_lightning import seed_everything 11 | from einops import rearrange 12 | 13 | class Image2Video(): 14 | def __init__(self,result_dir='./tmp/',gpu_num=1,resolution='256_256') -> None: 15 | self.resolution = (int(resolution.split('_')[0]), int(resolution.split('_')[1])) #hw 16 | self.download_model() 17 | 18 | self.result_dir = result_dir 19 | if not os.path.exists(self.result_dir): 20 | os.mkdir(self.result_dir) 21 | ckpt_path='checkpoints/tooncrafter_'+resolution.split('_')[1]+'_interp_v1/model.ckpt' 22 | config_file='configs/inference_'+resolution.split('_')[1]+'_v1.0.yaml' 23 | config = OmegaConf.load(config_file) 24 | model_config = config.pop("model", OmegaConf.create()) 25 | model_config['params']['unet_config']['params']['use_checkpoint']=False 26 | model_list = [] 27 | for gpu_id in range(gpu_num): 28 | model = instantiate_from_config(model_config) 29 | # model = model.cuda(gpu_id) 30 | print(ckpt_path) 31 | assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!" 
32 | model = load_model_checkpoint(model, ckpt_path) 33 | model.eval() 34 | model_list.append(model) 35 | self.model_list = model_list 36 | self.save_fps = 8 37 | 38 | def get_image(self, image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123, image2=None): 39 | seed_everything(seed) 40 | transform = transforms.Compose([ 41 | transforms.Resize(min(self.resolution)), 42 | transforms.CenterCrop(self.resolution), 43 | ]) 44 | torch.cuda.empty_cache() 45 | print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))) 46 | start = time.time() 47 | gpu_id=0 48 | if steps > 60: 49 | steps = 60 50 | model = self.model_list[gpu_id] 51 | model = model.cuda() 52 | batch_size=1 53 | channels = model.model.diffusion_model.out_channels 54 | frames = model.temporal_length 55 | h, w = self.resolution[0] // 8, self.resolution[1] // 8 56 | noise_shape = [batch_size, channels, frames, h, w] 57 | 58 | # text cond 59 | with torch.no_grad(), torch.cuda.amp.autocast(): 60 | text_emb = model.get_learned_conditioning([prompt]) 61 | 62 | # img cond 63 | img_tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(model.device) 64 | img_tensor = (img_tensor / 255. - 0.5) * 2 65 | 66 | image_tensor_resized = transform(img_tensor) #3,h,w 67 | videos = image_tensor_resized.unsqueeze(0).unsqueeze(2) # bc1hw 68 | 69 | # z = get_latent_z(model, videos) #bc,1,hw 70 | videos = repeat(videos, 'b c t h w -> b c (repeat t) h w', repeat=frames//2) 71 | 72 | 73 | 74 | 75 | img_tensor2 = torch.from_numpy(image2).permute(2, 0, 1).float().to(model.device) 76 | img_tensor2 = (img_tensor2 / 255. - 0.5) * 2 77 | image_tensor_resized2 = transform(img_tensor2) #3,h,w 78 | videos2 = image_tensor_resized2.unsqueeze(0).unsqueeze(2) # bchw 79 | videos2 = repeat(videos2, 'b c t h w -> b c (repeat t) h w', repeat=frames//2) 80 | 81 | 82 | videos = torch.cat([videos, videos2], dim=2) 83 | z, hs = self.get_latent_z_with_hidden_states(model, videos) 84 | 85 | img_tensor_repeat = torch.zeros_like(z) 86 | 87 | img_tensor_repeat[:,:,:1,:,:] = z[:,:,:1,:,:] 88 | img_tensor_repeat[:,:,-1:,:,:] = z[:,:,-1:,:,:] 89 | 90 | 91 | cond_images = model.embedder(img_tensor.unsqueeze(0)) ## blc 92 | img_emb = model.image_proj_model(cond_images) 93 | 94 | imtext_cond = torch.cat([text_emb, img_emb], dim=1) 95 | 96 | fs = torch.tensor([fs], dtype=torch.long, device=model.device) 97 | cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]} 98 | 99 | ## inference 100 | batch_samples = batch_ddim_sampling(model, cond, noise_shape, n_samples=1, ddim_steps=steps, ddim_eta=eta, cfg_scale=cfg_scale, hs=hs) 101 | 102 | ## remove the last frame 103 | if image2 is None: 104 | batch_samples = batch_samples[:,:,:,:-1,...] 105 | ## b,samples,c,t,h,w 106 | prompt_str = prompt.replace("/", "_slash_") if "/" in prompt else prompt 107 | prompt_str = prompt_str.replace(" ", "_") if " " in prompt else prompt_str 108 | prompt_str=prompt_str[:40] 109 | if len(prompt_str) == 0: 110 | prompt_str = 'empty_prompt' 111 | 112 | save_videos(batch_samples, self.result_dir, filenames=[prompt_str], fps=self.save_fps) 113 | print(f"Saved in {prompt_str}. 
Time used: {(time.time() - start):.2f} seconds") 114 | model = model.cpu() 115 | return os.path.join(self.result_dir, f"{prompt_str}.mp4") 116 | 117 | def download_model(self): 118 | REPO_ID = 'Doubiiu/ToonCrafter' 119 | filename_list = ['model.ckpt'] 120 | if not os.path.exists('./checkpoints/tooncrafter_'+str(self.resolution[1])+'_interp_v1/'): 121 | os.makedirs('./checkpoints/tooncrafter_'+str(self.resolution[1])+'_interp_v1/') 122 | for filename in filename_list: 123 | local_file = os.path.join('./checkpoints/tooncrafter_'+str(self.resolution[1])+'_interp_v1/', filename) 124 | if not os.path.exists(local_file): 125 | hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir='./checkpoints/tooncrafter_'+str(self.resolution[1])+'_interp_v1/', local_dir_use_symlinks=False) 126 | 127 | def get_latent_z_with_hidden_states(self, model, videos): 128 | b, c, t, h, w = videos.shape 129 | x = rearrange(videos, 'b c t h w -> (b t) c h w') 130 | encoder_posterior, hidden_states = model.first_stage_model.encode(x, return_hidden_states=True) 131 | 132 | hidden_states_first_last = [] 133 | ### use only the first and last hidden states 134 | for hid in hidden_states: 135 | hid = rearrange(hid, '(b t) c h w -> b c t h w', t=t) 136 | hid_new = torch.cat([hid[:, :, 0:1], hid[:, :, -1:]], dim=2) 137 | hidden_states_first_last.append(hid_new) 138 | 139 | z = model.get_first_stage_encoding(encoder_posterior).detach() 140 | z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) 141 | return z, hidden_states_first_last 142 | if __name__ == '__main__': 143 | i2v = Image2Video() 144 | video_path = i2v.get_image('prompts/art.png','man fishing in a boat at sunset') 145 | print('done', video_path) -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | 2 | ckpt=checkpoints/tooncrafter_512_interp_v1/model.ckpt 3 | config=configs/inference_512_v1.0.yaml 4 | 5 | prompt_dir=prompts/512_interp/ 6 | res_dir="results" 7 | 8 | FS=10 ## This model adopts FPS=5, range recommended: 5-30 (smaller value -> larger motion) 9 | 10 | 11 | 12 | seed=123 13 | name=tooncrafter_512_interp_seed${seed} 14 | CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/inference.py \ 15 | --seed ${seed} \ 16 | --ckpt_path $ckpt \ 17 | --config $config \ 18 | --savedir $res_dir/$name \ 19 | --n_samples 1 \ 20 | --bs 1 --height 320 --width 512 \ 21 | --unconditional_guidance_scale 7.5 \ 22 | --ddim_steps 50 \ 23 | --ddim_eta 1.0 \ 24 | --prompt_dir $prompt_dir \ 25 | --text_input \ 26 | --video_length 16 \ 27 | --frame_stride ${FS} \ 28 | --timestep_spacing 'uniform_trailing' --guidance_rescale 0.7 --perframe_ae --interp 29 | -------------------------------------------------------------------------------- /utils/save_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | from PIL import Image 5 | from einops import rearrange 6 | 7 | import torch 8 | import torchvision 9 | from torch import Tensor 10 | from torchvision.utils import make_grid 11 | from torchvision.transforms.functional import to_tensor 12 | 13 | 14 | def frames_to_mp4(frame_dir,output_path,fps): 15 | def read_first_n_frames(d: os.PathLike, num_frames: int): 16 | if num_frames: 17 | images = [Image.open(os.path.join(d, f)) for f in sorted(os.listdir(d))[:num_frames]] 18 | else: 19 | images = [Image.open(os.path.join(d, f)) for f in sorted(os.listdir(d))] 20 
| images = [to_tensor(x) for x in images] 21 | return torch.stack(images) 22 | videos = read_first_n_frames(frame_dir, num_frames=None) 23 | videos = videos.mul(255).to(torch.uint8).permute(0, 2, 3, 1) 24 | torchvision.io.write_video(output_path, videos, fps=fps, video_codec='h264', options={'crf': '10'}) 25 | 26 | 27 | def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None): 28 | """ 29 | video: torch.Tensor, b,c,t,h,w, 0-1 30 | if -1~1, enable rescale=True 31 | """ 32 | n = video.shape[0] 33 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 34 | nrow = int(np.sqrt(n)) if nrow is None else nrow 35 | frame_grids = [torchvision.utils.make_grid(framesheet, nrow=nrow, padding=0) for framesheet in video] # [3, grid_h, grid_w] 36 | grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [T, 3, grid_h, grid_w] 37 | grid = torch.clamp(grid.float(), -1., 1.) 38 | if rescale: 39 | grid = (grid + 1.0) / 2.0 40 | grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3] 41 | torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'}) 42 | 43 | 44 | def tensor2videogrids(video, root, filename, fps, rescale=True, clamp=True): 45 | assert(video.dim() == 5) # b,c,t,h,w 46 | assert(isinstance(video, torch.Tensor)) 47 | 48 | video = video.detach().cpu() 49 | if clamp: 50 | video = torch.clamp(video, -1., 1.) 51 | n = video.shape[0] 52 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 53 | frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(np.sqrt(n))) for framesheet in video] # [3, grid_h, grid_w] 54 | grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [T, 3, grid_h, grid_w] 55 | if rescale: 56 | grid = (grid + 1.0) / 2.0 57 | grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3] 58 | path = os.path.join(root, filename) 59 | torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'}) 60 | 61 | 62 | def log_local(batch_logs, save_dir, filename, save_fps=10, rescale=True): 63 | if batch_logs is None: 64 | return None 65 | """ save images and videos from images dict """ 66 | def save_img_grid(grid, path, rescale): 67 | if rescale: 68 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 69 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) 70 | grid = grid.numpy() 71 | grid = (grid * 255).astype(np.uint8) 72 | os.makedirs(os.path.split(path)[0], exist_ok=True) 73 | Image.fromarray(grid).save(path) 74 | 75 | for key in batch_logs: 76 | value = batch_logs[key] 77 | if isinstance(value, list) and isinstance(value[0], str): 78 | ## a batch of captions 79 | path = os.path.join(save_dir, "%s-%s.txt"%(key, filename)) 80 | with open(path, 'w') as f: 81 | for i, txt in enumerate(value): 82 | f.write(f'idx={i}, txt={txt}\n') 83 | f.close() 84 | elif isinstance(value, torch.Tensor) and value.dim() == 5: 85 | ## save video grids 86 | video = value # b,c,t,h,w 87 | ## only save grayscale or rgb mode 88 | if video.shape[1] != 1 and video.shape[1] != 3: 89 | continue 90 | n = video.shape[0] 91 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 92 | frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(1), padding=0) for framesheet in video] #[3, n*h, 1*w] 93 | grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w] 94 | if rescale: 95 | grid = (grid + 1.0) / 2.0 96 | grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) 97 | path = os.path.join(save_dir, "%s-%s.mp4"%(key, 
filename)) 98 | torchvision.io.write_video(path, grid, fps=save_fps, video_codec='h264', options={'crf': '10'}) 99 | 100 | ## save frame sheet 101 | img = value 102 | video_frames = rearrange(img, 'b c t h w -> (b t) c h w') 103 | t = img.shape[2] 104 | grid = torchvision.utils.make_grid(video_frames, nrow=t, padding=0) 105 | path = os.path.join(save_dir, "%s-%s.jpg"%(key, filename)) 106 | #save_img_grid(grid, path, rescale) 107 | elif isinstance(value, torch.Tensor) and value.dim() == 4: 108 | ## save image grids 109 | img = value 110 | ## only save grayscale or rgb mode 111 | if img.shape[1] != 1 and img.shape[1] != 3: 112 | continue 113 | n = img.shape[0] 114 | grid = torchvision.utils.make_grid(img, nrow=1, padding=0) 115 | path = os.path.join(save_dir, "%s-%s.jpg"%(key, filename)) 116 | save_img_grid(grid, path, rescale) 117 | else: 118 | pass 119 | 120 | def prepare_to_log(batch_logs, max_images=100000, clamp=True): 121 | if batch_logs is None: 122 | return None 123 | # process 124 | for key in batch_logs: 125 | N = batch_logs[key].shape[0] if hasattr(batch_logs[key], 'shape') else len(batch_logs[key]) 126 | N = min(N, max_images) 127 | batch_logs[key] = batch_logs[key][:N] 128 | ## in batch_logs: images & caption 129 | if isinstance(batch_logs[key], torch.Tensor): 130 | batch_logs[key] = batch_logs[key].detach().cpu() 131 | if clamp: 132 | try: 133 | batch_logs[key] = torch.clamp(batch_logs[key].float(), -1., 1.) 134 | except RuntimeError: 135 | print("clamp_scalar_cpu not implemented for Half") 136 | return batch_logs 137 | 138 | # ---------------------------------------------------------------------------------------------- 139 | 140 | def fill_with_black_squares(video, desired_len: int) -> Tensor: 141 | if len(video) >= desired_len: 142 | return video 143 | 144 | return torch.cat([ 145 | video, 146 | torch.zeros_like(video[0]).unsqueeze(0).repeat(desired_len - len(video), 1, 1, 1), 147 | ], dim=0) 148 | 149 | # ---------------------------------------------------------------------------------------------- 150 | def load_num_videos(data_path, num_videos): 151 | # first argument can be either data_path of np array 152 | if isinstance(data_path, str): 153 | videos = np.load(data_path)['arr_0'] # NTHWC 154 | elif isinstance(data_path, np.ndarray): 155 | videos = data_path 156 | else: 157 | raise Exception 158 | 159 | if num_videos is not None: 160 | videos = videos[:num_videos, :, :, :, :] 161 | return videos 162 | 163 | def npz_to_video_grid(data_path, out_path, num_frames, fps, num_videos=None, nrow=None, verbose=True): 164 | # videos = torch.tensor(np.load(data_path)['arr_0']).permute(0,1,4,2,3).div_(255).mul_(2) - 1.0 # NTHWC->NTCHW, np int -> torch tensor 0-1 165 | if isinstance(data_path, str): 166 | videos = load_num_videos(data_path, num_videos) 167 | elif isinstance(data_path, np.ndarray): 168 | videos = data_path 169 | else: 170 | raise Exception 171 | n,t,h,w,c = videos.shape 172 | videos_th = [] 173 | for i in range(n): 174 | video = videos[i, :,:,:,:] 175 | images = [video[j, :,:,:] for j in range(t)] 176 | images = [to_tensor(img) for img in images] 177 | video = torch.stack(images) 178 | videos_th.append(video) 179 | if verbose: 180 | videos = [fill_with_black_squares(v, num_frames) for v in tqdm(videos_th, desc='Adding empty frames')] # NTCHW 181 | else: 182 | videos = [fill_with_black_squares(v, num_frames) for v in videos_th] # NTCHW 183 | 184 | frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4) # [T, N, C, H, W] 185 | if nrow is None: 186 | nrow = 
int(np.ceil(np.sqrt(n))) 187 | if verbose: 188 | frame_grids = [make_grid(fs, nrow=nrow) for fs in tqdm(frame_grids, desc='Making grids')] 189 | else: 190 | frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids] 191 | 192 | if os.path.dirname(out_path) != "": 193 | os.makedirs(os.path.dirname(out_path), exist_ok=True) 194 | frame_grids = (torch.stack(frame_grids) * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, H, W, C] 195 | torchvision.io.write_video(out_path, frame_grids, fps=fps, video_codec='h264', options={'crf': '10'}) 196 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import cv2 4 | import torch 5 | import torch.distributed as dist 6 | import os 7 | 8 | def count_params(model, verbose=False): 9 | total_params = sum(p.numel() for p in model.parameters()) 10 | if verbose: 11 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 12 | return total_params 13 | 14 | 15 | def check_istarget(name, para_list): 16 | """ 17 | name: full name of source para 18 | para_list: partial name of target para 19 | """ 20 | istarget=False 21 | for para in para_list: 22 | if para in name: 23 | return True 24 | return istarget 25 | 26 | 27 | def instantiate_from_config(config): 28 | if not "target" in config: 29 | if config == '__is_first_stage__': 30 | return None 31 | elif config == "__is_unconditional__": 32 | return None 33 | raise KeyError("Expected key `target` to instantiate.") 34 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 35 | 36 | 37 | def get_obj_from_str(string, reload=False): 38 | module, cls = string.rsplit(".", 1) 39 | if reload: 40 | module_imp = importlib.import_module(module) 41 | importlib.reload(module_imp) 42 | return getattr(importlib.import_module(module, package=None), cls) 43 | 44 | 45 | def load_npz_from_dir(data_dir): 46 | data = [np.load(os.path.join(data_dir, data_name))['arr_0'] for data_name in os.listdir(data_dir)] 47 | data = np.concatenate(data, axis=0) 48 | return data 49 | 50 | 51 | def load_npz_from_paths(data_paths): 52 | data = [np.load(data_path)['arr_0'] for data_path in data_paths] 53 | data = np.concatenate(data, axis=0) 54 | return data 55 | 56 | 57 | def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None): 58 | h, w = image.shape[:2] 59 | if resize_short_edge is not None: 60 | k = resize_short_edge / min(h, w) 61 | else: 62 | k = max_resolution / (h * w) 63 | k = k**0.5 64 | h = int(np.round(h * k / 64)) * 64 65 | w = int(np.round(w * k / 64)) * 64 66 | image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) 67 | return image 68 | 69 | 70 | def setup_dist(args): 71 | if dist.is_initialized(): 72 | return 73 | torch.cuda.set_device(args.local_rank) 74 | torch.distributed.init_process_group( 75 | 'nccl', 76 | init_method='env://' 77 | ) --------------------------------------------------------------------------------
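A few usage sketches for the code above follow; they are illustrative only and not part of the repository. The frame-selection rule in load_video_batch (scripts/evaluation/funcs.py) takes every frame_stride-th frame and pads by repeating the last frame when the clip is too short. The sketch below replays that arithmetic on a synthetic [c, t, h, w] tensor so it can be checked without decord or a real video file; sample_with_padding and fake_clip are made-up names:

```python
import torch

# Mirror of the sampling rule in load_video_batch: pick every `frame_stride`-th
# frame, then pad by repeating the last frame when the clip is too short.
def sample_with_padding(clip, frame_stride=4, video_frames=16):
    c, total_frames, h, w = clip.shape
    max_valid_frames = (total_frames - 1) // frame_stride + 1
    query_frames = min(video_frames, max_valid_frames)
    frame_indices = [frame_stride * i for i in range(query_frames)]
    frames = clip[:, frame_indices]                      # [c, query_frames, h, w]
    if max_valid_frames < video_frames:
        padding_num = video_frames - max_valid_frames
        last = frames[:, -1:].repeat(1, padding_num, 1, 1)
        frames = torch.cat([frames, last], dim=1)        # repeat-last-frame padding
    return frames

fake_clip = torch.rand(3, 40, 320, 512)                  # dummy 40-frame clip
out = sample_with_padding(fake_clip, frame_stride=4, video_frames=16)
print(out.shape)                                         # torch.Size([3, 16, 320, 512])
```

The real function additionally normalizes frames to [-1, 1] and stacks one such tensor per file path into a batch.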
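get_latent_z (scripts/evaluation/funcs.py) folds time into the batch dimension before the first-stage encoder and unfolds it afterwards. The shape bookkeeping can be verified with a stand-in encoder; FakeVAE, the 4 latent channels and the 8x downsample are assumptions chosen to match the resolution // 8 latent sizes used in the Gradio scripts, not the real autoencoder:

```python
import torch
from einops import rearrange

class FakeVAE:
    """Stand-in for model.encode_first_stage: 8x spatial downsample, 4 latent channels (assumed)."""
    def __call__(self, x):                               # x: [(b t), 3, h, w]
        n, _, h, w = x.shape
        return torch.zeros(n, 4, h // 8, w // 8)

def get_latent_z_sketch(encode_first_stage, videos):
    b, c, t, h, w = videos.shape
    x = rearrange(videos, 'b c t h w -> (b t) c h w')    # merge time into batch
    z = encode_first_stage(x)
    return rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)

videos = torch.rand(2, 3, 16, 320, 512)
z = get_latent_z_sketch(FakeVAE(), videos)
print(z.shape)                                           # torch.Size([2, 4, 16, 40, 64])
```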
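In Image2Video.get_image (scripts/gradio/i2v_test.py), the sampler receives one cond dict: text and image embeddings concatenated along the token axis for cross-attention, the encoded first frame repeated across all time steps for channel-concat conditioning, and a long tensor carrying the frame stride fs. The sketch below assembles the same structure from dummy tensors; the token counts and channel sizes (77, 16, 1024, 4) are placeholders, not the model's actual dimensions:

```python
import torch
from einops import repeat

batch_size, frames, h, w = 1, 16, 40, 64             # latent height/width (pixels // 8)

text_emb = torch.rand(batch_size, 77, 1024)          # placeholder text tokens  [b, l_txt, c]
img_emb  = torch.rand(batch_size, 16, 1024)          # placeholder image tokens [b, l_img, c]
z_first  = torch.rand(batch_size, 4, 1, h, w)        # encoded conditioning frame [b, c, 1, h, w]

imtext_cond = torch.cat([text_emb, img_emb], dim=1)  # [b, l_txt + l_img, c]
img_tensor_repeat = repeat(z_first, 'b c t h w -> b c (repeat t) h w',
                           repeat=frames)            # same frame at every time step
fs = torch.tensor([3], dtype=torch.long)             # frame-stride condition

cond = {"c_crossattn": [imtext_cond], "fs": fs, "c_concat": [img_tensor_repeat]}
print(cond["c_crossattn"][0].shape, cond["c_concat"][0].shape)
```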
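The interpolation variant (scripts/gradio/i2v_test_application.py) conditions on the two endpoint frames only: the latent video is replaced by zeros except for its first and last time slices, and the VAE hidden states are likewise reduced to their first and last frames. A minimal sketch of that masking step with dummy tensors (all shapes are illustrative):

```python
import torch

frames = 16
z = torch.rand(1, 4, frames, 40, 64)           # dummy latent video [b, c, t, h, w]

# Keep only the two endpoint frames as concat-conditioning; zero out the rest.
img_tensor_repeat = torch.zeros_like(z)
img_tensor_repeat[:, :, :1]  = z[:, :, :1]     # first frame
img_tensor_repeat[:, :, -1:] = z[:, :, -1:]    # last frame

# Same idea for a per-level hidden state from the encoder.
hid = torch.rand(1, 128, frames, 80, 128)
hid_first_last = torch.cat([hid[:, :, 0:1], hid[:, :, -1:]], dim=2)

print(img_tensor_repeat.abs().sum(dim=(0, 1, 3, 4)))   # nonzero only at t=0 and t=15
print(hid_first_last.shape)                            # torch.Size([1, 128, 2, 80, 128])
```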
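tensor_to_mp4 (utils/save_video.py) expects a [b, c, t, h, w] tensor, tiles the batch into a per-frame grid and writes H.264. A small usage sketch with random clips; the output path and fps are arbitrary, and actually writing the file requires a torchvision video backend (PyAV/ffmpeg) and running from the repository root so the utils package resolves:

```python
import os
import torch
from utils.save_video import tensor_to_mp4

# Four fake 16-frame RGB clips in [-1, 1]; tiny spatial size keeps the write fast.
video = torch.rand(4, 3, 16, 64, 64) * 2 - 1

os.makedirs('results', exist_ok=True)
# rescale=True maps [-1, 1] back to [0, 1] before the uint8 conversion;
# with nrow=None the per-frame grid defaults to a roughly square sqrt(b) layout.
tensor_to_mp4(video, savepath='results/demo_grid.mp4', fps=8, rescale=True)
```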
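prepare_to_log and log_local (utils/save_video.py) turn a dict of logged tensors and captions into caption files, video grids and image grids. A usage sketch with a made-up log dict; the keys, save directory and filename are arbitrary, and the same environment caveats as above apply:

```python
import os
import torch
from utils.save_video import prepare_to_log, log_local

batch_logs = {
    "reconst": torch.rand(2, 3, 16, 64, 64) * 2 - 1,   # videos: [b, c, t, h, w] in [-1, 1]
    "inputs":  torch.rand(2, 3, 64, 64) * 2 - 1,       # images: [b, c, h, w]    in [-1, 1]
    "caption": ["a cat on a skateboard", "sunset over the sea"],
}

# Truncate to max_images, move tensors to CPU and clamp to [-1, 1].
batch_logs = prepare_to_log(batch_logs, max_images=4, clamp=True)

save_dir = "results/logs"
os.makedirs(save_dir, exist_ok=True)
# Writes reconst-step-000001.mp4, inputs-step-000001.jpg and caption-step-000001.txt.
log_local(batch_logs, save_dir, filename="step-000001", save_fps=8, rescale=True)
```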
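instantiate_from_config (utils/utils.py) is the glue used by the configs and the Gradio scripts: it imports the dotted target path and constructs it with params. The sketch below builds a plain torch module from a dict of the same shape as the YAML entries; the target and params values here are made up for illustration:

```python
from utils.utils import instantiate_from_config

# Same {target, params} layout as the entries in configs/inference_512_v1.0.yaml.
config = {
    "target": "torch.nn.Conv2d",
    "params": {"in_channels": 3, "out_channels": 8, "kernel_size": 3},
}
layer = instantiate_from_config(config)    # equivalent to torch.nn.Conv2d(3, 8, kernel_size=3)
print(layer)
```

The Gradio scripts above feed the model block of configs/inference_512_v1.0.yaml (loaded via OmegaConf) through the same helper to build the diffusion model.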
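resize_numpy_image (utils/utils.py) rescales an image so that its area stays near max_resolution (or its short edge matches resize_short_edge) and snaps both sides to multiples of 64, which keeps the 8x-downsampled latent dimensions integral. A quick check of the rounding on a dummy array (the input size is arbitrary):

```python
import numpy as np
from utils.utils import resize_numpy_image

img = (np.random.rand(720, 1280, 3) * 255).astype(np.uint8)

resized = resize_numpy_image(img, max_resolution=512 * 512)
print(resized.shape)    # (384, 704, 3): both sides snapped to multiples of 64
```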