├── LICENSE ├── README.md ├── __assets__ ├── magictime_logo.png ├── promtp_opensora.txt └── promtp_unet.txt ├── app.py ├── ckpts ├── Base_Model │ ├── base_model_path.txt │ ├── motion_module │ │ └── motion_module_path.txt │ └── stable-diffusion-v1-5 │ │ └── sd_15_path.txt ├── DreamBooth │ └── dreambooth_path.txt └── Magic_Weights │ └── magic_weights_path.txt ├── data_preprocess ├── README.md ├── run.sh ├── step0_extract_frame_resize.py ├── step2_1_GPT4V_frame_caption.py ├── step2_2_preprocess_frame_caption.py ├── step3_1_GPT4V_video_caption_concise.py ├── step3_1_GPT4V_video_caption_detail.py ├── step3_2_preprocess_video_caption.py └── step4_1_create_webvid_format.py ├── inference.sh ├── inference_cli.sh ├── inference_magictime.py ├── requirements.txt ├── sample_configs ├── RcnzCartoon.yaml ├── RealisticVision.yaml └── ToonYou.yaml └── utils ├── dataset.py ├── pipeline_magictime.py ├── unet.py ├── unet_blocks.py └── util.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 |

MagicTime: Time-lapse Video Generation Models as Metamorphic Simulators

7 |
If you like our project, please give us a star ⭐ on GitHub for the latest updates.
8 | 9 |
10 | 11 | 12 | [![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/BestWishYsh/MagicTime?logs=build) 13 | [![Replicate demo and cloud API](https://replicate.com/camenduru/magictime/badge)](https://replicate.com/camenduru/magictime) 14 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/MagicTime-jupyter/blob/main/MagicTime_jupyter.ipynb) 15 | [![hf_space](https://img.shields.io/badge/🤗-Paper%20In%20HF-red.svg)](https://huggingface.co/papers/2404.05014) 16 | [![arXiv](https://img.shields.io/badge/Arxiv-2404.05014-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2404.05014) 17 | [![Home Page](https://img.shields.io/badge/Project--blue.svg)](https://pku-yuangroup.github.io/MagicTime/) 18 | [![Dataset](https://img.shields.io/badge/Dataset--green)](https://huggingface.co/datasets/BestWishYsh/ChronoMagic) 19 | [![zhihu](https://img.shields.io/badge/-Twitter@AK%20-black?logo=twitter&logoColor=1D9BF0)](https://twitter.com/_akhaliq/status/1777538468043792473) 20 | [![zhihu](https://img.shields.io/badge/-Twitter@Jinfa%20Huang%20-black?logo=twitter&logoColor=1D9BF0)](https://twitter.com/vhjf36495872/status/1777525817087553827?s=61&t=r2HzCsU2AnJKbR8yKSprKw) 21 | [![DOI](https://zenodo.org/badge/783303222.svg)](https://zenodo.org/doi/10.5281/zenodo.10960665) 22 | [![License](https://img.shields.io/badge/License-Apache%202.0-yellow)](https://github.com/PKU-YuanGroup/MagicTime/blob/main/LICENSE) 23 | [![github](https://img.shields.io/github/stars/PKU-YuanGroup/MagicTime.svg?style=social)](https://github.com/PKU-YuanGroup/MagicTime) 24 | 25 |
26 | 27 |
28 | This repository is the official implementation of MagicTime, a metamorphic time-lapse video generation pipeline conditioned on text prompts. The main idea is to enhance the capacity of video generation models to accurately depict the real world through our proposed methods and dataset. 29 |
30 | 31 | 32 |
33 |
💡 We also have other video generation projects that may interest you ✨.

34 | 35 | 36 | 37 | > [**Open-Sora Plan: Open-Source Large Video Generation Model**](https://arxiv.org/abs/2412.00131)
38 | > Bin Lin, Yunyang Ge, Xinhua Cheng, et al.
39 | [![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/PKU-YuanGroup/Open-Sora-Plan) [![github](https://img.shields.io/github/stars/PKU-YuanGroup/Open-Sora-Plan.svg?style=social)](https://github.com/PKU-YuanGroup/Open-Sora-Plan) [![arXiv](https://img.shields.io/badge/Arxiv-2412.00131-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2412.00131)
40 | > 41 | > [**OpenS2V-Nexus: A Detailed Benchmark and Million-Scale Dataset for Subject-to-Video Generation**](https://arxiv.org/abs/2505.20292)
42 | > Shenghai Yuan, Xianyi He, Yufan Deng, et al.
43 | > [![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/PKU-YuanGroup/OpenS2V-Nexus) [![github](https://img.shields.io/github/stars/PKU-YuanGroup/OpenS2V-Nexus.svg?style=social)](https://github.com/PKU-YuanGroup/OpenS2V-Nexus) [![arXiv](https://img.shields.io/badge/Arxiv-2505.20292-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2505.20292)
44 | > 45 | > [**ConsisID: Identity-Preserving Text-to-Video Generation by Frequency Decomposition**](https://arxiv.org/abs/2411.17440)
46 | > Shenghai Yuan, Jinfa Huang, Xianyi He, et al.
47 | > [![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/PKU-YuanGroup/ConsisID/) [![github](https://img.shields.io/github/stars/PKU-YuanGroup/ConsisID.svg?style=social)](https://github.com/PKU-YuanGroup/ConsisID/) [![arXiv](https://img.shields.io/badge/Arxiv-2411.17440-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2411.17440)
48 | > 49 | > [**ChronoMagic-Bench: A Benchmark for Metamorphic Evaluation of Text-to-Time-lapse Video Generation**](https://arxiv.org/abs/2406.18522)
50 | > Shenghai Yuan, Jinfa Huang, Yongqi Xu, et al.
51 | > [![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/PKU-YuanGroup/ChronoMagic-Bench/) [![github](https://img.shields.io/github/stars/PKU-YuanGroup/ChronoMagic-Bench.svg?style=social)](https://github.com/PKU-YuanGroup/ChronoMagic-Bench/) [![arXiv](https://img.shields.io/badge/Arxiv-2406.18522-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2406.18522)
52 | >

53 | 54 | ## 📣 News 55 | * ⏳⏳⏳ Training a stronger model with the support of [Open-Sora Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan). 56 | * ⏳⏳⏳ Release the training code of MagicTime. 57 | * `[2025.04.08]` 🔥 We have updated our technical report. Please click [here](https://arxiv.org/abs/2404.05014) to view it. 58 | * `[2025.03.28]` 🔥 MagicTime has been accepted by **TPAMI**, and we will update arXiv with more details soon, keep tuned! 59 | * `[2024.07.29]` We add *batch inference* to [inference_magictime.py](https://github.com/PKU-YuanGroup/MagicTime/blob/main/inference_magictime.py) for easier usage. 60 | * `[2024.06.27]` Excited to share our latest [ChronoMagic-Bench](https://github.com/PKU-YuanGroup/ChronoMagic-Bench), a benchmark for metamorphic evaluation of text-to-time-lapse video generation, and is fully open source! Please check out the [paper](https://arxiv.org/abs/2406.18522). 61 | * `[2024.05.27]` Excited to share our latest Open-Sora Plan v1.1.0, which significantly improves video quality and length, and is fully open source! Please check out the [report](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.1.0.md). 62 | * `[2024.04.14]` Thanks [@camenduru](https://twitter.com/camenduru) and [@ModelsLab](https://modelslab.com/) for providing [Jupyter Notebook](https://github.com/camenduru/MagicTime-jupyter) and [Replicate Demo](https://replicate.com/camenduru/magictime). 63 | * `[2024.04.13]` 🔥 We have compressed the size of repo with less than 1.0 MB, so that everyone can clone easier and faster. You can click [here](https://github.com/PKU-YuanGroup/MagicTime/archive/refs/heads/main.zip) to download, or use `git clone --depth=1` command to obtain this repo. 64 | * `[2024.04.12]` Thanks [@Kijai](https://github.com/kijai) and [@Baobao Wang](https://www.bilibili.com/video/BV1wx421U7Gn/?spm_id_from=333.1007.top_right_bar_window_history.content.click) for providing ComfyUI Extension [ComfyUI-MagicTimeWrapper](https://github.com/kijai/ComfyUI-MagicTimeWrapper). If you find related work, please let us know. 65 | * `[2024.04.11]` 🔥 We release the Hugging Face Space of MagicTime, you can click [here](https://huggingface.co/spaces/BestWishYsh/MagicTime?logs=build) to have a try. 66 | * `[2024.04.10]` 🔥 We release the inference code and model weight of MagicTime. 67 | * `[2024.04.09]` 🔥 We release the arXiv paper for MagicTime, and you can click [here](https://arxiv.org/abs/2404.05014) to see more details. 68 | * `[2024.04.08]` 🔥 We release the subset of ChronoMagic dataset used to train MagicTime. The dataset includes 2,265 metamorphic video-text pairs and can be downloaded at [HuggingFace Dataset](https://huggingface.co/datasets/BestWishYsh/ChronoMagic) or [Google Drive](https://drive.google.com/drive/folders/1WsomdkmSp3ql3ImcNsmzFuSQ9Qukuyr8?usp=sharing). 69 | * `[2024.04.08]` 🔥 **All codes & datasets** are coming soon! Stay tuned 👀! 70 | 71 | ## 😮 Highlights 72 | 73 | MagicTime shows excellent performance in **metamorphic video generation**. 74 | 75 | ### Related Resources 76 | * [ChronoMagic](https://huggingface.co/datasets/BestWishYsh/ChronoMagic): including 2265 time-lapse video-text pairs. (captioned by GPT-4V) 77 | * [ChronoMagic-Bench](https://huggingface.co/datasets/BestWishYsh/ChronoMagic-Bench/tree/main): including 1649 time-lapse video-text pairs. (captioned by GPT-4o) 78 | * [ChronoMagic-Bench-150](https://huggingface.co/datasets/BestWishYsh/ChronoMagic-Bench/tree/main): including 150 time-lapse video-text pairs. 
(captioned by GPT-4o) 79 | * [ChronoMagic-Pro](https://huggingface.co/datasets/BestWishYsh/ChronoMagic-Pro): including 460K time-lapse video-text pairs. (captioned by ShareGPT4Video) 80 | * [ChronoMagic-ProH](https://huggingface.co/datasets/BestWishYsh/ChronoMagic-ProH): including 150K time-lapse video-text pairs. (captioned by ShareGPT4Video) 81 | 82 | ### Metamorphic Videos vs. General Videos 83 | 84 | Compared to general videos, metamorphic videos contain physical knowledge, long persistence, and strong variation, making them difficult to generate. We show compressed .gif on github, which loses some quality. The general videos are generated by the [Animatediff](https://github.com/guoyww/AnimateDiff) and **MagicTime**. 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 |
Type
"Bean sprouts grow and mature from seeds"
"[...] construction in a Minecraft virtual environment"
"Cupcakes baking in an oven [...]"
"[...] transitioning from a tightly closed bud to a fully bloomed state [...]"
General Videos
Metamorphic Videos
109 | 110 | ### Gallery 111 | 112 | We showcase some metamorphic videos generated by **MagicTime**, [MakeLongVideo](https://github.com/xuduo35/MakeLongVideo), [ModelScopeT2V](https://github.com/modelscope), [VideoCrafter](https://github.com/AILab-CVC/VideoCrafter?tab=readme-ov-file), [ZeroScope](https://huggingface.co/cerspense/zeroscope_v2_576w), [LaVie](https://github.com/Vchitect/LaVie), [T2V-Zero](https://github.com/Picsart-AI-Research/Text2Video-Zero), [Latte](https://github.com/Vchitect/Latte) and [Animatediff](https://github.com/guoyww/AnimateDiff) below. 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 |
Method
"cherry blossoms transitioning [...]"
"dough balls baking process [...]"
"an ice cube is melting [...]"
"a simple modern house's construction [...]"
MakeLongVideo
ModelScopeT2V
VideoCrafter
ZeroScope
LaVie
T2V-Zero
Latte
Animatediff
Ours
186 | 187 | 188 | We show more metamorphic videos generated by **MagicTime** with the help of [Realistic](https://civitai.com/models/4201/realistic-vision-v20), [ToonYou](https://civitai.com/models/30240/toonyou) and [RcnzCartoon](https://civitai.com/models/66347/rcnz-cartoon-3d). 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 |
Realistic
"[...] bean sprouts grow and mature from seeds"
"dough [...] swells and browns in the oven [...]"
"the construction [...] in Minecraft [...]"
RcnzCartoon
"a bud transforms into a yellow flower"
"time-lapse of a plant germinating [...]"
"[...] a modern house being constructed in Minecraft [...]"
ToonYou
"an ice cube is melting"
"bean plant sprouts grow and mature from the soil"
"time-lapse of delicate pink plum blossoms [...]"
222 | 223 | Prompts are trimmed for display, see [here](https://github.com/PKU-YuanGroup/MagicTime/blob/main/__assets__/promtp_unet.txt) for full prompts. 224 | ### Integrate into DiT-based Architecture 225 | 226 | The mission of this project is to help reproduce Sora and provide high-quality video-text data and data annotation pipelines, to support [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) or other DiT-based T2V models. To this end, we take an initial step to integrate our MagicTime scheme into the DiT-based Framework. Specifically, our method supports the Open-Sora-Plan v1.0.0 for fine-tuning. We first scale up with additional metamorphic landscape time-lapse videos in the same annotation framework to get the ChronoMagic-Landscape dataset. Then, we fine-tune the Open-Sora-Plan v1.0.0 with the ChronoMagic-Landscape dataset to get the MagicTime-DiT model. The results are as follows (**257×512×512 (10s)**): 227 | 228 | 229 | 230 | 233 | 236 | 239 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 253 | 256 | 259 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 |
231 | 232 | 234 | 235 | 237 | 238 | 240 | 241 |
"Time-lapse of a coastal landscape [...]"
"Display the serene beauty of twilight [...]"
"Sunrise Splendor: Capture the breathtaking moment [...]"
"Nightfall Elegance: Embrace the tranquil beauty [...]"
251 | 252 | 254 | 255 | 257 | 258 | 260 | 261 |
"The sun descending below the horizon [...]"
"[...] daylight fades into the embrace of the night [...]"
"Time-lapse of the dynamic formations of clouds [...]"
"Capture the dynamic formations of clouds [...]"
270 | 271 | Prompts are trimmed for display, see [here](https://github.com/PKU-YuanGroup/MagicTime/blob/main/__assets__/promtp_opensora.txt) for full prompts. 272 | 273 | ## 🤗 Demo 274 | 275 | ### Gradio Web UI 276 | 277 | Highly recommend trying out our web demo by the following command, which incorporates all features currently supported by MagicTime. We also provide [online demo](https://huggingface.co/spaces/BestWishYsh/MagicTime?logs=build) in Hugging Face Spaces. 278 | 279 | ```bash 280 | python app.py 281 | ``` 282 | 283 | ### CLI Inference 284 | 285 | ```bash 286 | # For Realistic 287 | python inference_magictime.py --config sample_configs/RealisticVision.yaml --human 288 | 289 | # or you can directly run the .sh 290 | sh inference_cli.sh 291 | ``` 292 | 293 | warning: It is worth noting that even if we use the same seed and prompt but we change a machine, the results will be different. 294 | 295 | ## ⚙️ Requirements and Installation 296 | 297 | We recommend the requirements as follows. 298 | 299 | ### Environment 300 | 301 | ```bash 302 | git clone --depth=1 https://github.com/PKU-YuanGroup/MagicTime.git 303 | cd MagicTime 304 | conda create -n magictime python=3.10.13 305 | conda activate magictime 306 | pip install -r requirements.txt 307 | ``` 308 | 309 | ### Download MagicTime 310 | 311 | The weights are available at [🤗HuggingFace](https://huggingface.co/BestWishYsh/MagicTime/tree/main) and [🟣WiseModel](https://wisemodel.cn/models/SHYuanBest/MagicTime/file), or you can download it with the following commands. 312 | 313 | ```bash 314 | # way 1 315 | # if you are in china mainland, run this first: export HF_ENDPOINT=https://hf-mirror.com 316 | huggingface-cli download --repo-type model \ 317 | BestWishYsh/MagicTime \ 318 | --local-dir ckpts 319 | 320 | # way 2 321 | git lfs install 322 | git clone https://www.wisemodel.cn/SHYuanBest/MagicTime.git 323 | ``` 324 | 325 | Once ready, the weights will be organized in this format: 326 | 327 | ``` 328 | 📦 ckpts/ 329 | ├── 📂 Base_Model/ 330 | │ ├── 📂 motion_module/ 331 | │ ├── 📂 stable-diffusion-v1-5/ 332 | ├── 📂 DreamBooth/ 333 | ├── 📂 Magic_Weights/ 334 | │ ├── 📂 magic_adapter_s/ 335 | │ ├── 📂 magic_adapter_t/ 336 | │ ├── 📂 magic_text_encoder/ 337 | ``` 338 | 339 | ## 🗝️ Training & Inference 340 | 341 | The training code is coming soon! 342 | 343 | For inference, some examples are shown below: 344 | 345 | ```bash 346 | # For Realistic 347 | python inference_magictime.py --config sample_configs/RealisticVision.yaml 348 | # For ToonYou 349 | python inference_magictime.py --config sample_configs/ToonYou.yaml 350 | # For RcnzCartoon 351 | python inference_magictime.py --config sample_configs/RcnzCartoon.yaml 352 | # or you can directly run the .sh 353 | sh inference.sh 354 | ``` 355 | 356 | You can also put all your *custom prompts* in a .txt file and run: 357 | 358 | ```bash 359 | # For Realistic 360 | python inference_magictime.py --config sample_configs/RealisticVision.yaml --run-txt XXX.txt --batch-size 2 361 | # For ToonYou 362 | python inference_magictime.py --config sample_configs/ToonYou.yaml --run-txt XXX.txt --batch-size 2 363 | # For RcnzCartoon 364 | python inference_magictime.py --config sample_configs/RcnzCartoon.yaml --run-txt XXX.txt --batch-size 2 365 | ``` 366 | 367 | ## Community Contributions 368 | 369 | We found some plugins created by community developers. Thanks for their efforts: 370 | 371 | - ComfyUI Extension. 
[ComfyUI-MagicTimeWrapper](https://github.com/kijai/ComfyUI-MagicTimeWrapper) (by [@Kijai](https://github.com/kijai)). And you can click [here](https://www.bilibili.com/video/BV1wx421U7Gn/?spm_id_from=333.1007.top_right_bar_window_history.content.click) to view the installation tutorial. 372 | - Replicate Demo & Cloud API. [Replicate-MagicTime](https://replicate.com/camenduru/magictime) (by [@camenduru](https://twitter.com/camenduru)). 373 | - Jupyter Notebook. [Jupyter-MagicTime](https://github.com/camenduru/MagicTime-jupyter) (by [@ModelsLab](https://modelslab.com/)). 374 | 375 | If you find related work, please let us know. 376 | 377 | ## 🐳 ChronoMagic Dataset 378 | ChronoMagic with 2265 metamorphic time-lapse videos, each accompanied by a detailed caption. We released the subset of ChronoMagic used to train MagicTime. The dataset can be downloaded at [HuggingFace Dataset](https://huggingface.co/datasets/BestWishYsh/ChronoMagic), or you can download it with the following command. Some samples can be found on our [Project Page](https://pku-yuangroup.github.io/MagicTime/). 379 | ```bash 380 | huggingface-cli download --repo-type dataset \ 381 | --resume-download BestWishYsh/ChronoMagic \ 382 | --local-dir BestWishYsh/ChronoMagic \ 383 | --local-dir-use-symlinks False 384 | ``` 385 | 386 | ## 👍 Acknowledgement 387 | * [Animatediff](https://github.com/guoyww/AnimateDiff/tree/main) The codebase we built upon and it is a strong U-Net-based text-to-video generation model. 388 | 389 | * [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) The codebase we built upon and it is a simple and scalable DiT-based text-to-video generation repo, to reproduce [Sora](https://openai.com/sora). 390 | 391 | ## 🔒 License 392 | * The majority of this project is released under the Apache 2.0 license as found in the [LICENSE](https://github.com/PKU-YuanGroup/MagicTime/blob/main/LICENSE) file. 393 | * The service is a research preview. Please contact us if you find any potential violations. 394 | 395 | ## ✏️ Citation 396 | If you find our paper and code useful in your research, please consider giving a star :star: and citation :pencil:. 397 | 398 | ```BibTeX 399 | @article{yuan2025magictime, 400 | title={Magictime: Time-lapse video generation models as metamorphic simulators}, 401 | author={Yuan, Shenghai and Huang, Jinfa and Shi, Yujun and Xu, Yongqi and Zhu, Ruijie and Lin, Bin and Cheng, Xinhua and Yuan, Li and Luo, Jiebo}, 402 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 403 | year={2025}, 404 | publisher={IEEE} 405 | } 406 | ``` 407 | 408 | ## 🤝 Contributors 409 | 410 | 411 | 412 | 413 | 414 | -------------------------------------------------------------------------------- /__assets__/magictime_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/MagicTime/33307bc0ce0ef2d15ccef2ed623c99bae63bf35c/__assets__/magictime_logo.png -------------------------------------------------------------------------------- /__assets__/promtp_opensora.txt: -------------------------------------------------------------------------------- 1 | 1. Time-lapse of a coastal landscape transitioning from sunrise to nightfall, with early morning light and soft shadows giving way to a clearer, bright midday sky, and later visible signs of sunset with orange hues and a dimming sky, culminating in a vibrant dusk. 2 | 2. 
Display the serene beauty of twilight, marking the transition from day to night with subtle changes in lighting. 3 | 3. Sunrise Splendor: Capture the breathtaking moment as the sun peeks over the horizon, casting its warm hues across the landscape in a mesmerizing time-lapse. 4 | 4. Nightfall Elegance: Embrace the tranquil beauty of dusk as daylight fades into the embrace of the night, unveiling the twinkling stars against the darkening sky in a mesmerizing time-lapse spectacle. 5 | 5. The sun descending below the horizon at dusk. The video is a time-lapse showcasing the gradual dimming of daylight, leading to the onset of twilight. 6 | 6. Nightfall Elegance: Embrace the tranquil beauty of dusk as daylight fades into the embrace of the night, unveiling the twinkling stars against the darkening sky in a mesmerizing time-lapse spectacle. 7 | 7. Time-lapse of the dynamic formations of clouds, showcasing their continuous motion and evolution over the course of the video. 8 | 8. Capture the dynamic formations of clouds, showcasing their continuous motion and evolution over the course of the video. -------------------------------------------------------------------------------- /__assets__/promtp_unet.txt: -------------------------------------------------------------------------------- 1 | 1. A time-lapse video of bean sprouts grow and mature from seeds. 2 | 2. Dough starts smooth, swells and browns in the oven, finishing as fully expanded, baked bread. 3 | 3. The construction of a simple modern house in Minecraft. As the construction progresses, the roof and walls are completed, and the area around the house is cleared and shaped. 4 | 4. A bud transforms into a yellow flower. 5 | 5. Time-lapse of a plant germinating and developing into a young plant with multiple true leaves in a container, showing progressive growth stages from bare soil to a full plant. 6 | 6. Time-lapse of a modern house being constructed in Minecraft, beginning with a basic structure and progressively adding roof details, and new sections. 7 | 7. An ice cube is melting. 8 | 8. Bean plant sprouts grow and mature from the soil. 9 | 9. Time-lapse of delicate pink plum blossoms transitioning from tightly closed buds to gently unfurling petals, revealing the intricate details of stamens and pistils within. 
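For reference, a prompt list like the one above can be fed to MagicTime's batch-inference entry point described in the README. A minimal sketch, assuming the repository root as the working directory, weights laid out under `ckpts/` as documented, and a hypothetical `my_prompts.txt` with one prompt per line:

```bash
# Sketch: batch inference over a plain-text prompt list (one prompt per line).
# The --config / --run-txt / --batch-size flags follow the README's batch-inference example;
# my_prompts.txt is a placeholder for your own prompt file.
python inference_magictime.py \
  --config sample_configs/RealisticVision.yaml \
  --run-txt my_prompts.txt \
  --batch-size 2
```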
-------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import time 4 | import torch 5 | import random 6 | import gradio as gr 7 | from glob import glob 8 | from omegaconf import OmegaConf 9 | from safetensors import safe_open 10 | from diffusers import AutoencoderKL 11 | from diffusers import DDIMScheduler 12 | from diffusers.utils.import_utils import is_xformers_available 13 | from transformers import CLIPTextModel, CLIPTokenizer 14 | 15 | from utils.unet import UNet3DConditionModel 16 | from utils.pipeline_magictime import MagicTimePipeline 17 | from utils.util import save_videos_grid, convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint, load_diffusers_lora_unet, convert_ldm_clip_text_model 18 | # import spaces 19 | 20 | from huggingface_hub import snapshot_download 21 | 22 | model_path = "ckpts" 23 | 24 | if not os.path.exists(model_path) or not os.path.exists(f"{model_path}/model_real_esran") or not os.path.exists(f"{model_path}/model_rife"): 25 | print("Model not found, downloading from Hugging Face...") 26 | snapshot_download(repo_id="BestWishYsh/MagicTime", local_dir=f"{model_path}") 27 | else: 28 | print(f"Model already exists in {model_path}, skipping download.") 29 | 30 | pretrained_model_path = f"{model_path}/Base_Model/stable-diffusion-v1-5" 31 | inference_config_path = "sample_configs/RealisticVision.yaml" 32 | magic_adapter_s_path = f"{model_path}/Magic_Weights/magic_adapter_s/magic_adapter_s.ckpt" 33 | magic_adapter_t_path = f"{model_path}/Magic_Weights/magic_adapter_t" 34 | magic_text_encoder_path = f"{model_path}/Magic_Weights/magic_text_encoder" 35 | 36 | css = """ 37 | .toolbutton { 38 | margin-buttom: 0em 0em 0em 0em; 39 | max-width: 2.5em; 40 | min-width: 2.5em !important; 41 | height: 2.5em; 42 | } 43 | """ 44 | 45 | examples = [ 46 | # 1-RealisticVision 47 | [ 48 | "RealisticVisionV60B1_v51VAE.safetensors", 49 | "motion_module.ckpt", 50 | "Cherry blossoms transitioning from tightly closed buds to a peak state of bloom. 
The progression moves through stages of bud swelling, petal exposure, and gradual opening, culminating in a full and vibrant display of open blossoms.", 51 | "worst quality, low quality, letterboxed", 52 | 512, 512, "1534851746" 53 | ], 54 | # 2-RCNZ 55 | [ 56 | "RcnzCartoon.safetensors", 57 | "motion_module.ckpt", 58 | "Time-lapse of a simple modern house's construction in a Minecraft virtual environment: beginning with an avatar laying a white foundation, progressing through wall erection and interior furnishing, to adding roof and exterior details, and completed with landscaping and a tall chimney.", 59 | "worst quality, low quality, letterboxed", 60 | 512, 512, "3480796026" 61 | ], 62 | # 3-ToonYou 63 | [ 64 | "ToonYou_beta6.safetensors", 65 | "motion_module.ckpt", 66 | "Bean sprouts grow and mature from seeds.", 67 | "worst quality, low quality, letterboxed", 68 | 512, 512, "1496541313" 69 | ] 70 | ] 71 | 72 | # clean Grdio cache 73 | print(f"### Cleaning cached examples ...") 74 | os.system(f"rm -rf gradio_cached_examples/") 75 | 76 | device = "cuda" 77 | 78 | def random_seed(): 79 | return random.randint(1, 10**16) 80 | 81 | class MagicTimeController: 82 | def __init__(self): 83 | # config dirs 84 | self.basedir = os.getcwd() 85 | self.stable_diffusion_dir = os.path.join(self.basedir, model_path, "Base_Model") 86 | self.motion_module_dir = os.path.join(self.basedir, model_path, "Base_Model", "motion_module") 87 | self.personalized_model_dir = os.path.join(self.basedir, model_path, "DreamBooth") 88 | self.savedir = os.path.join(self.basedir, "outputs") 89 | os.makedirs(self.savedir, exist_ok=True) 90 | 91 | self.dreambooth_list = [] 92 | self.motion_module_list = [] 93 | 94 | self.selected_dreambooth = None 95 | self.selected_motion_module = None 96 | 97 | self.refresh_motion_module() 98 | self.refresh_personalized_model() 99 | 100 | # config models 101 | self.inference_config = OmegaConf.load(inference_config_path)[1] 102 | 103 | self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") 104 | self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device) 105 | self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device) 106 | self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).to(device) 107 | self.text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") 108 | self.unet_model = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)) 109 | 110 | self.update_motion_module(self.motion_module_list[0]) 111 | self.update_motion_module_2(self.motion_module_list[0]) 112 | self.update_dreambooth(self.dreambooth_list[0]) 113 | 114 | def refresh_motion_module(self): 115 | motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt")) 116 | self.motion_module_list = [os.path.basename(p) for p in motion_module_list] 117 | 118 | def refresh_personalized_model(self): 119 | dreambooth_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors")) 120 | self.dreambooth_list = [os.path.basename(p) for p in dreambooth_list] 121 | 122 | def update_dreambooth(self, dreambooth_dropdown, motion_module_dropdown=None): 123 | self.selected_dreambooth = dreambooth_dropdown 124 | 125 | dreambooth_dropdown = 
os.path.join(self.personalized_model_dir, dreambooth_dropdown) 126 | dreambooth_state_dict = {} 127 | with safe_open(dreambooth_dropdown, framework="pt", device="cpu") as f: 128 | for key in f.keys(): dreambooth_state_dict[key] = f.get_tensor(key) 129 | 130 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, self.vae.config) 131 | self.vae.load_state_dict(converted_vae_checkpoint) 132 | 133 | del self.unet 134 | self.unet = None 135 | torch.cuda.empty_cache() 136 | time.sleep(1) 137 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, self.unet_model.config) 138 | self.unet = copy.deepcopy(self.unet_model) 139 | self.unet.load_state_dict(converted_unet_checkpoint, strict=False) 140 | 141 | del self.text_encoder 142 | self.text_encoder = None 143 | torch.cuda.empty_cache() 144 | time.sleep(1) 145 | text_model = copy.deepcopy(self.text_model) 146 | self.text_encoder = convert_ldm_clip_text_model(text_model, dreambooth_state_dict) 147 | 148 | from swift import Swift 149 | magic_adapter_s_state_dict = torch.load(magic_adapter_s_path, map_location="cpu") 150 | self.unet = load_diffusers_lora_unet(self.unet, magic_adapter_s_state_dict, alpha=1.0) 151 | self.unet = Swift.from_pretrained(self.unet, magic_adapter_t_path) 152 | self.text_encoder = Swift.from_pretrained(self.text_encoder, magic_text_encoder_path) 153 | 154 | return gr.Dropdown() 155 | 156 | def update_motion_module(self, motion_module_dropdown): 157 | self.selected_motion_module = motion_module_dropdown 158 | motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown) 159 | motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu") 160 | _, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False) 161 | assert len(unexpected) == 0 162 | return gr.Dropdown() 163 | 164 | def update_motion_module_2(self, motion_module_dropdown): 165 | self.selected_motion_module = motion_module_dropdown 166 | motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown) 167 | motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu") 168 | _, unexpected = self.unet_model.load_state_dict(motion_module_state_dict, strict=False) 169 | assert len(unexpected) == 0 170 | return gr.Dropdown() 171 | 172 | # @spaces.GPU(duration=300) 173 | def magictime( 174 | self, 175 | dreambooth_dropdown, 176 | motion_module_dropdown, 177 | prompt_textbox, 178 | negative_prompt_textbox, 179 | width_slider, 180 | height_slider, 181 | seed_textbox, 182 | ): 183 | torch.cuda.empty_cache() 184 | time.sleep(1) 185 | 186 | if self.selected_motion_module != motion_module_dropdown: self.update_motion_module(motion_module_dropdown) 187 | if self.selected_motion_module != motion_module_dropdown: self.update_motion_module_2(motion_module_dropdown) 188 | if self.selected_dreambooth != dreambooth_dropdown: self.update_dreambooth(dreambooth_dropdown) 189 | 190 | while self.text_encoder is None or self.unet is None: 191 | self.update_dreambooth(dreambooth_dropdown, motion_module_dropdown) 192 | 193 | if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention() 194 | 195 | pipeline = MagicTimePipeline( 196 | vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet, 197 | scheduler=DDIMScheduler(**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)) 198 | ).to(device) 199 | 200 | if int(seed_textbox) > 0: seed = int(seed_textbox) 201 | else: seed = 
random_seed() 202 | torch.manual_seed(int(seed)) 203 | 204 | assert seed == torch.initial_seed() 205 | print(f"### seed: {seed}") 206 | 207 | generator = torch.Generator(device=device) 208 | generator.manual_seed(seed) 209 | 210 | sample = pipeline( 211 | prompt_textbox, 212 | negative_prompt = negative_prompt_textbox, 213 | num_inference_steps = 25, 214 | guidance_scale = 8., 215 | width = width_slider, 216 | height = height_slider, 217 | video_length = 16, 218 | generator = generator, 219 | ).videos 220 | 221 | save_sample_path = os.path.join(self.savedir, f"sample.mp4") 222 | save_videos_grid(sample, save_sample_path) 223 | 224 | json_config = { 225 | "prompt": prompt_textbox, 226 | "n_prompt": negative_prompt_textbox, 227 | "width": width_slider, 228 | "height": height_slider, 229 | "seed": seed, 230 | "dreambooth": dreambooth_dropdown, 231 | } 232 | 233 | torch.cuda.empty_cache() 234 | time.sleep(1) 235 | return gr.Video(value=save_sample_path), gr.Json(value=json_config) 236 | 237 | controller = MagicTimeController() 238 | 239 | def ui(): 240 | with gr.Blocks(css=css) as demo: 241 | gr.Markdown( 242 | """ 243 |
244 | 245 |
246 | 247 |

MagicTime: Time-lapse Video Generation Models as Metamorphic Simulators

248 |
If you like our project, please give us a star ⭐ on GitHub for the latest updates.
249 | 250 | [GitHub](https://github.com/PKU-YuanGroup/MagicTime) | [arXiv](https://arxiv.org/abs/2404.05014) | [Home Page](https://pku-yuangroup.github.io/MagicTime/) | [Dataset](https://drive.google.com/drive/folders/1WsomdkmSp3ql3ImcNsmzFuSQ9Qukuyr8?usp=sharing) 251 | """ 252 | ) 253 | with gr.Row(): 254 | with gr.Column(): 255 | dreambooth_dropdown = gr.Dropdown(label="DreamBooth Model", choices=controller.dreambooth_list, value=controller.dreambooth_list[0], interactive=True) 256 | motion_module_dropdown = gr.Dropdown(label="Motion Module", choices=controller.motion_module_list, value=controller.motion_module_list[0], interactive=True) 257 | 258 | prompt_textbox = gr.Textbox(label="Prompt", lines=3) 259 | negative_prompt_textbox = gr.Textbox(label="Negative Prompt", lines=3, value="worst quality, low quality, nsfw, logo") 260 | 261 | with gr.Accordion("Advance", open=False): 262 | with gr.Row(): 263 | width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64) 264 | height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64) 265 | with gr.Row(): 266 | seed_textbox = gr.Textbox(label="Seed (-1 means random)", value="-1") 267 | seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton") 268 | seed_button.click(fn=random_seed, inputs=[], outputs=[seed_textbox]) 269 | 270 | generate_button = gr.Button(value="Generate", variant='primary') 271 | 272 | with gr.Column(): 273 | result_video = gr.Video(label="Generated Animation", interactive=False) 274 | json_config = gr.Json(label="Config", value={}) 275 | 276 | inputs = [dreambooth_dropdown, motion_module_dropdown, prompt_textbox, negative_prompt_textbox, width_slider, height_slider, seed_textbox] 277 | outputs = [result_video, json_config] 278 | 279 | generate_button.click(fn=controller.magictime, inputs=inputs, outputs=outputs) 280 | 281 | gr.Markdown(""" 282 |
⚠ Warning: Even if you use the same seed and prompt, changing machines may produce different results. 283 | If you find a better seed and prompt, please submit an issue on GitHub.
284 | """) 285 | 286 | gr.Examples(fn=controller.magictime, examples=examples, inputs=inputs, outputs=outputs, cache_examples=True) 287 | 288 | return demo 289 | 290 | if __name__ == "__main__": 291 | demo = ui() 292 | demo.queue(max_size=20) 293 | demo.launch() 294 | -------------------------------------------------------------------------------- /ckpts/Base_Model/base_model_path.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/MagicTime/33307bc0ce0ef2d15ccef2ed623c99bae63bf35c/ckpts/Base_Model/base_model_path.txt -------------------------------------------------------------------------------- /ckpts/Base_Model/motion_module/motion_module_path.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/MagicTime/33307bc0ce0ef2d15ccef2ed623c99bae63bf35c/ckpts/Base_Model/motion_module/motion_module_path.txt -------------------------------------------------------------------------------- /ckpts/Base_Model/stable-diffusion-v1-5/sd_15_path.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/MagicTime/33307bc0ce0ef2d15ccef2ed623c99bae63bf35c/ckpts/Base_Model/stable-diffusion-v1-5/sd_15_path.txt -------------------------------------------------------------------------------- /ckpts/DreamBooth/dreambooth_path.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/MagicTime/33307bc0ce0ef2d15ccef2ed623c99bae63bf35c/ckpts/DreamBooth/dreambooth_path.txt -------------------------------------------------------------------------------- /ckpts/Magic_Weights/magic_weights_path.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-YuanGroup/MagicTime/33307bc0ce0ef2d15ccef2ed623c99bae63bf35c/ckpts/Magic_Weights/magic_weights_path.txt -------------------------------------------------------------------------------- /data_preprocess/README.md: -------------------------------------------------------------------------------- 1 | # Data Preprocessing Pipeline by *MagicTime* 2 | This repo describes how to process your own data like [ChronoMagic](https://huggingface.co/datasets/BestWishYsh/ChronoMagic) datasets in the [MagicTime](https://arxiv.org/abs/2404.05014) paper. 
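Before running the full pipeline below, it can help to sanity-check frame extraction on a small batch of videos. A minimal sketch using the defaults baked into `step0_extract_frame_resize.py` (the `./step_0` input and `./step_1` output folders are simply the script's default paths, not requirements):

```bash
# Extract 8 roughly evenly spaced key frames per video, resized so the short edge is at most 256 px.
python step0_extract_frame_resize.py --input_folder ./step_0 --output_folder ./step_1

# Each processed video should end up with exactly 8 PNGs named <video>_frame<N>.png;
# incomplete sets are deleted and re-extracted on the next run.
ls ./step_1 | head
```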
3 | 4 | ## 🗝️ Usage 5 | 6 | ```bash 7 | #!/bin/bash 8 | 9 | # Global variables 10 | INPUT_FOLDER="./step_0" 11 | OUTPUT_FOLDER_STEP_1="./step_1" 12 | API_KEY="XXX" 13 | NUM_WORKERS=8 14 | 15 | # File paths 16 | FRAME_CAPTION_FILE="./2_1_gpt_frames_caption.json" 17 | GROUP_FRAMES_FILE="./2_1_temp_group_frames.json" 18 | UPDATED_FRAME_CAPTION_FILE="./2_2_updated_gpt_frames_caption.json" 19 | UNMATCHED_FRAME_CAPTION_FILE="./2_2_temp_unmatched_gpt_frames_caption.json" 20 | UNORDERED_FRAME_CAPTION_FILE="./2_2_temp_unordered_gpt_frames_caption.json" 21 | FINAL_USEFUL_FRAME_CAPTION_FILE="./2_2_final_useful_gpt_frames_caption.json" 22 | VIDEO_CAPTION_FILE="./3_1_gpt_video_caption.json" 23 | UNMATCHED_VIDEO_CAPTION_FILE="./3_2_temp_unmatched_gpt_video_caption.json" 24 | EXCLUDE_BY_FRAME_VIDEO_CAPTION_FILE="./3_2_temp_exclude_by_frame_gpt_video_caption.json" 25 | FINAL_USEFUL_VIDEO_CAPTION_FILE="./3_2_final_useful_gpt_video_caption.json" 26 | FINAL_CSV_FILE="./all_clean_data.csv" 27 | 28 | # Step 1: Extract and resize frames 29 | python step0_extract_frame_resize.py --input_folder "$INPUT_FOLDER" --output_folder "$OUTPUT_FOLDER_STEP_1" 30 | 31 | # Step 2.1: Generate frame captions using GPT-4V 32 | python step2_1_GPT4V_frame_caption.py --api_key "$API_KEY" --num_workers "$NUM_WORKERS" \ 33 | --output_file "$FRAME_CAPTION_FILE" --group_frames_file "$GROUP_FRAMES_FILE" --image_directories "$OUTPUT_FOLDER_STEP_1" 34 | 35 | # Step 2.2: Preprocess frame captions 36 | python step2_2_preprocess_frame_caption.py --file_path "$FRAME_CAPTION_FILE" \ 37 | --updated_file_path "$UPDATED_FRAME_CAPTION_FILE" --unmatched_file_path "$UNMATCHED_FRAME_CAPTION_FILE" \ 38 | --unordered_file_path "$UNORDERED_FRAME_CAPTION_FILE" --final_useful_data_file_path "$FINAL_USEFUL_FRAME_CAPTION_FILE" 39 | 40 | # Step 3.1: Generate concise video captions using GPT-4V 41 | python step3_1_GPT4V_video_caption_concise.py --num_workers "$NUM_WORKERS" \ 42 | --input_file "$FINAL_USEFUL_FRAME_CAPTION_FILE" --output_file "$VIDEO_CAPTION_FILE" 43 | 44 | # Optional: Generate detailed video captions (uncomment to enable) 45 | # python step3_1_GPT4V_video_caption_detail.py --num_workers "$NUM_WORKERS" \ 46 | # --input_file "$FINAL_USEFUL_FRAME_CAPTION_FILE" --output_file "$VIDEO_CAPTION_FILE" 47 | 48 | # Step 3.2: Preprocess video captions 49 | python step3_2_preprocess_video_caption.py --file_path "$VIDEO_CAPTION_FILE" \ 50 | --updated_file_path "$VIDEO_CAPTION_FILE" --unmatched_data_path "$UNMATCHED_VIDEO_CAPTION_FILE" \ 51 | --exclude_by_frame_data_path "$EXCLUDE_BY_FRAME_VIDEO_CAPTION_FILE" --final_useful_data_path "$FINAL_USEFUL_VIDEO_CAPTION_FILE" 52 | 53 | # Step 4: Create the final dataset in WebVid format 54 | python step4_1_create_webvid_format.py --caption_file_path "$FINAL_USEFUL_VIDEO_CAPTION_FILE" \ 55 | --output_csv_file_path "$FINAL_CSV_FILE" 56 | ``` 57 | -------------------------------------------------------------------------------- /data_preprocess/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Global variables 4 | INPUT_FOLDER="./step_0" 5 | OUTPUT_FOLDER_STEP_1="./step_1" 6 | API_KEY="XXX" 7 | NUM_WORKERS=8 8 | 9 | # File paths 10 | FRAME_CAPTION_FILE="./2_1_gpt_frames_caption.json" 11 | GROUP_FRAMES_FILE="./2_1_temp_group_frames.json" 12 | UPDATED_FRAME_CAPTION_FILE="./2_2_updated_gpt_frames_caption.json" 13 | UNMATCHED_FRAME_CAPTION_FILE="./2_2_temp_unmatched_gpt_frames_caption.json" 14 | 
UNORDERED_FRAME_CAPTION_FILE="./2_2_temp_unordered_gpt_frames_caption.json" 15 | FINAL_USEFUL_FRAME_CAPTION_FILE="./2_2_final_useful_gpt_frames_caption.json" 16 | VIDEO_CAPTION_FILE="./3_1_gpt_video_caption.json" 17 | UNMATCHED_VIDEO_CAPTION_FILE="./3_2_temp_unmatched_gpt_video_caption.json" 18 | EXCLUDE_BY_FRAME_VIDEO_CAPTION_FILE="./3_2_temp_exclude_by_frame_gpt_video_caption.json" 19 | FINAL_USEFUL_VIDEO_CAPTION_FILE="./3_2_final_useful_gpt_video_caption.json" 20 | FINAL_CSV_FILE="./all_clean_data.csv" 21 | 22 | # Step 1: Extract and resize frames 23 | python step0_extract_frame_resize.py --input_folder "$INPUT_FOLDER" --output_folder "$OUTPUT_FOLDER_STEP_1" 24 | 25 | # Step 2.1: Generate frame captions using GPT-4V 26 | python step2_1_GPT4V_frame_caption.py --api_key "$API_KEY" --num_workers "$NUM_WORKERS" \ 27 | --output_file "$FRAME_CAPTION_FILE" --group_frames_file "$GROUP_FRAMES_FILE" --image_directories "$OUTPUT_FOLDER_STEP_1" 28 | 29 | # Step 2.2: Preprocess frame captions 30 | python step2_2_preprocess_frame_caption.py --file_path "$FRAME_CAPTION_FILE" \ 31 | --updated_file_path "$UPDATED_FRAME_CAPTION_FILE" --unmatched_file_path "$UNMATCHED_FRAME_CAPTION_FILE" \ 32 | --unordered_file_path "$UNORDERED_FRAME_CAPTION_FILE" --final_useful_data_file_path "$FINAL_USEFUL_FRAME_CAPTION_FILE" 33 | 34 | # Step 3.1: Generate concise video captions using GPT-4V 35 | python step3_1_GPT4V_video_caption_concise.py --num_workers "$NUM_WORKERS" \ 36 | --input_file "$FINAL_USEFUL_FRAME_CAPTION_FILE" --output_file "$VIDEO_CAPTION_FILE" 37 | 38 | # Optional: Generate detailed video captions (uncomment to enable) 39 | # python step3_1_GPT4V_video_caption_detail.py --num_workers "$NUM_WORKERS" \ 40 | # --input_file "$FINAL_USEFUL_FRAME_CAPTION_FILE" --output_file "$VIDEO_CAPTION_FILE" 41 | 42 | # Step 3.2: Preprocess video captions 43 | python step3_2_preprocess_video_caption.py --file_path "$VIDEO_CAPTION_FILE" \ 44 | --updated_file_path "$VIDEO_CAPTION_FILE" --unmatched_data_path "$UNMATCHED_VIDEO_CAPTION_FILE" \ 45 | --exclude_by_frame_data_path "$EXCLUDE_BY_FRAME_VIDEO_CAPTION_FILE" --final_useful_data_path "$FINAL_USEFUL_VIDEO_CAPTION_FILE" 46 | 47 | # Step 4: Create the final dataset in WebVid format 48 | python step4_1_create_webvid_format.py --caption_file_path "$FINAL_USEFUL_VIDEO_CAPTION_FILE" \ 49 | --output_csv_file_path "$FINAL_CSV_FILE" -------------------------------------------------------------------------------- /data_preprocess/step0_extract_frame_resize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import argparse 5 | 6 | 7 | def resize_frame(frame, short_edge=256): 8 | height, width = frame.shape[:2] 9 | if min(height, width) <= short_edge: 10 | return frame 11 | else: 12 | scale = short_edge / width if height > width else short_edge / height 13 | new_width = int(width * scale) 14 | new_height = int(height * scale) 15 | resized_frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA) 16 | return resized_frame 17 | 18 | def extract_frames(video_path, output_folder, num_frames=8): 19 | cap = cv2.VideoCapture(video_path) 20 | if not cap.isOpened(): 21 | print(f"Error opening video file {video_path}") 22 | return 23 | 24 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 25 | frames_to_capture = set([0, total_frames - 1]) 26 | frames_interval = (total_frames - 1) // (num_frames - 1) 27 | for i in range(1, num_frames - 1): 28 | frames_to_capture.add(i * 
frames_interval) 29 | 30 | count = 0 31 | while True: 32 | ret, frame = cap.read() 33 | if not ret: 34 | break 35 | 36 | if count in frames_to_capture: 37 | resized_frame = resize_frame(frame) 38 | frame_name = f"{os.path.splitext(os.path.basename(video_path))[0]}_frame{count}.png" 39 | output_path = os.path.join(output_folder, frame_name) 40 | cv2.imwrite(output_path, resized_frame) 41 | print(f"Saved {output_path}") 42 | 43 | count += 1 44 | 45 | cap.release() 46 | 47 | def process_all_videos(folder_path, output_folder): 48 | if not os.path.exists(output_folder): 49 | os.makedirs(output_folder) 50 | 51 | video_files = [f for f in os.listdir(folder_path) if f.endswith((".mp4", ".avi", ".mov"))] 52 | total_videos = len(video_files) 53 | skipped_videos = 0 54 | 55 | print(f"Total videos to check: {total_videos}") 56 | 57 | for filename in video_files: 58 | video_name = os.path.splitext(filename)[0] 59 | video_related_images = glob.glob(os.path.join(output_folder, f"{video_name}_frame*.png")) 60 | 61 | if len(video_related_images) == 8: 62 | print(f"Skipping {filename}, already processed.") 63 | skipped_videos += 1 64 | continue 65 | 66 | # If not 8 images, delete existing ones 67 | for img in video_related_images: 68 | os.remove(img) 69 | print(f"Deleted {img}") 70 | 71 | video_path = os.path.join(folder_path, filename) 72 | print(f"Processing {filename}...") 73 | extract_frames(video_path, output_folder) 74 | 75 | print(f"Skipped {skipped_videos} videos that were already processed.") 76 | print(f"Processed {total_videos - skipped_videos} new or incomplete videos.") 77 | 78 | if __name__ == "__main__": 79 | # Set up argument parser 80 | parser = argparse.ArgumentParser(description="Batch process video files") 81 | parser.add_argument("--input_folder", type=str, default='./step_0', help="Path to the input folder containing videos") 82 | parser.add_argument("--output_folder", type=str, default='./step_1', help="Path to the output folder for processed videos") 83 | 84 | # Parse command-line arguments 85 | args = parser.parse_args() 86 | 87 | # Call the video processing function 88 | process_all_videos(args.input_folder, args.output_folder) -------------------------------------------------------------------------------- /data_preprocess/step2_1_GPT4V_frame_caption.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import base64 5 | import argparse 6 | from tqdm import tqdm 7 | from openai import OpenAI 8 | from threading import Lock 9 | from concurrent.futures import ThreadPoolExecutor, as_completed 10 | from tenacity import retry, wait_exponential, stop_after_attempt 11 | 12 | 13 | txt_prompt = ''' 14 | Suppose you are a data annotator, specialized in generating captions for time-lapse videos. You will be supplied with eight key frames extracted from a video, each with a filename labeled with its position in the video sequence. Your task is to generate a caption for each frame, focusing on the primary subject and integrating all discernible elements. Note: These captions should be brief and concise, avoiding redundancy. 15 | 16 | Your analysis should demonstrate a deep understanding of real-world physics, encompassing aspects such as gravity and elasticity, and align with the principles of perspective geometry in photography. Ensure object identification consistency across all frames, even if an object is temporarily out of sight. Employ logical deductions to bridge any informational gaps. 
Begin each caption with a brief reasoning statement, showcasing your analytical approach. For guidance on the expected format, refer to the provided examples: 17 | 18 | Brief Reasoning Statement: The images provided are sequential frames from a time-lapse video depicting the blooming stages of a yellow flower, likely a ranunculus. The sequence is forward, showing a natural progression from bud to full bloom. Time-related information is not included in these frames. I will describe each frame accordingly. 19 | "[_2p6vHyth14]": { 20 | "Reasoning": [ 21 | "Frame 0: This is the first frame, starting the sequence. The flower is in its initial stages, with petals tightly closed.", 22 | "Frame 224: The petals appear slightly more open than in the first frame, indicating the progression of blooming.", 23 | "Frame 448: The bloom has progressed further; petals are more open than in the previous frame, suggesting the continuation of the blooming process.", 24 | "Frame 672: Continuity in the blooming process is evident, with petals unfurling more than in the last frame.", 25 | "Frame 896: The flower is more open than in frame 672, indicating an advanced stage of the blooming process.", 26 | "Frame 1120: The flower is nearing full bloom, with a majority of the petals open and the inner ones starting to loosen.", 27 | "Frame 1344: The blooming process is almost complete, with the flower more open than in frame 1120 and the center more visible.", 28 | "Frame 1570: This final frame likely represents the peak of the bloom, with the flower fully open and all petals relaxed." 29 | ], 30 | "Captioning": [ 31 | "Frame 0: Closed yellow ranunculus bud amidst green foliage.", 32 | "Frame 224: Yellow ranunculus bud beginning to open, with green sepals visible.", 33 | "Frame 448: Opening yellow ranunculus with distinct petal layers.", 34 | "Frame 672: Further unfurled yellow ranunculus, petals spreading outward.", 35 | "Frame 896: Half-open yellow ranunculus, with inner petals still tightly clustered.", 36 | "Frame 1120: Nearly fully bloomed yellow ranunculus, with central petals loosening.", 37 | "Frame 1344: Yellow ranunculus in full bloom, center clearly visible amidst open petals.", 38 | "Frame 1570: Fully bloomed yellow ranunculus with a fully visible center and relaxed petals." 39 | ] 40 | } 41 | 42 | Brief Reasoning Statement: The images show the germination and growth process of a plant, identified as spinach, over a span of 46 days. This time-lapse video captures the transformation from soil to a fully developed plant in a forward sequence. Time-related information is present, indicating the duration of the captured growth process. I will describe each frame accordingly. 
43 | "[pVmX1v1hDc]_0001": { 44 | "Reasoning": [ 45 | "Frame 0: This is the initial stage where the soil is moist, likely right after sowing the seeds.", 46 | "Frame 69: The soil surface shows signs of disturbance, possibly from seeds beginning to germinate.", 47 | "Frame 138: Germination has occurred, evident from the emergence of seedlings breaking through the soil.", 48 | "Frame 207: The seedlings have elongated and the first true leaves are beginning to form.", 49 | "Frame 276: Growth is evident with larger true leaves, and the plant is entering the vegetative stage.", 50 | "Frame 345: The plants are more developed with a denser leaf canopy, indicating healthy vegetative growth.", 51 | "Frame 414: The spinach plants are fully developed with large leaves, ready for harvesting.", 52 | "Frame 485: The plants are at full maturity with a thick canopy of leaves, showing the complete growth cycle." 53 | ], 54 | "Captioning": [ 55 | "Frame 0: Moist soil on Day 1 after sowing spinach seeds.", 56 | "Frame 69: Soil surface showing early signs of spinach seed germination on Day 6.", 57 | "Frame 138: Spinach seedlings emerging from soil on Day 10.", 58 | "Frame 207: Elongated spinach seedlings with first true leaves on Day 16.", 59 | "Frame 276: Spinach showing significant leaf growth on Day 24.", 60 | "Frame 345: Denser and larger spinach leaves visible on Day 31.", 61 | "Frame 414: Mature spinach plants with large leaves ready for harvest on Day 39.", 62 | "Frame 485: Thick canopy of mature spinach leaves on Day 46." 63 | ] 64 | } 65 | 66 | {Brief Reasoning Statement: Must include time-related information and description of forward processes} 67 | "{Enter the prefix of the image to represent the id}": { 68 | "Reasoning": [ 69 | " ", 70 | " ", 71 | " ", 72 | " ", 73 | " ", 74 | " ", 75 | " ", 76 | " " 77 | ], 78 | "Captioning": [ 79 | " ", 80 | " ", 81 | " ", 82 | " ", 83 | " ", 84 | " ", 85 | " ", 86 | " " 87 | ] 88 | } 89 | 90 | Attention: Do not reply outside the example template! 
Below are the video title and input frames: 91 | ''' 92 | 93 | # Global lock for thread-safe file operations 94 | file_lock = Lock() 95 | 96 | # Function to get all image filenames in the specified directory 97 | def get_image_filenames(directory): 98 | image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif'] 99 | return [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f)) and os.path.splitext(f)[1].lower() in image_extensions] 100 | 101 | # Function to parse the video ID from the image file name 102 | def parse_video_id(filename): 103 | match = re.match(r'(.+)_frame\d+\.png', filename) 104 | return match.group(1) if match else None 105 | 106 | # Function to convert image to base64 107 | def image_b64(image_path): 108 | with open(image_path, "rb") as f: 109 | return base64.b64encode(f.read()).decode('utf-8') 110 | 111 | # Function to group images 112 | def group_images_by_video_id(filenames): 113 | images_by_video = {} 114 | for filename in tqdm(filenames, desc="Grouping images"): 115 | video_id = parse_video_id(filename) 116 | if video_id: 117 | if video_id not in images_by_video: 118 | images_by_video[video_id] = [] 119 | images_by_video[video_id].append(filename) 120 | 121 | valid_groups = {video_id: images for video_id, images in images_by_video.items() if len(images) == 8} 122 | return valid_groups 123 | 124 | # Function to create prompts for the GPT-4 Vision API 125 | def create_prompts(grouped_images, image_directory, txt_prompt): 126 | prompts = {} 127 | for video_id, group in tqdm(grouped_images.items(), desc="Creating prompts"): 128 | # Initialize the prompt with the given text prompt 129 | prompt = [{"type": "text", "text": txt_prompt}] 130 | 131 | # Append information about each image in the group 132 | for image_name in group: 133 | image_path = os.path.join(image_directory, image_name.strip()) 134 | b64_image = image_b64(image_path) 135 | prompt.append({"type": "text", "text": image_name.strip()}) 136 | prompt.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64_image}"}}) 137 | 138 | prompts[video_id] = prompt 139 | return prompts 140 | 141 | def has_been_processed(video_id, output_file): 142 | with file_lock: 143 | if os.path.exists(output_file): 144 | with open(output_file, 'r') as f: 145 | data = json.load(f) 146 | if video_id in data: 147 | print(f"Video ID {video_id} has already been processed.") 148 | return True 149 | return False 150 | 151 | def extract_frame_number(filename): 152 | # Extract the number after 'frame' and convert to integer 153 | return int(filename.split('_frame')[-1].split('.')[0]) 154 | 155 | def load_existing_results(file_path): 156 | if os.path.exists(file_path): 157 | with open(file_path, 'r') as file: 158 | print(f"Loading existing results from {file_path}") 159 | return json.load(file) 160 | else: 161 | print(f"No existing results file found at {file_path}. 
Creating a new file.") 162 | with open(file_path, 'w') as file: 163 | empty_data = {} 164 | json.dump(empty_data, file) 165 | return empty_data 166 | 167 | @retry(wait=wait_exponential(multiplier=1, min=2, max=10), stop=stop_after_attempt(100)) 168 | def call_gpt(prompt, model_name="gpt-4-vision-preview", api_key=None): 169 | client = OpenAI(api_key=api_key) 170 | chat_completion = client.chat.completions.create( 171 | model=model_name, 172 | messages=[ 173 | { 174 | "role": "user", 175 | "content": prompt, 176 | } 177 | ], 178 | max_tokens=2048, 179 | ) 180 | print(chat_completion) 181 | return chat_completion.choices[0].message.content 182 | 183 | def save_output(video_id, prompt, output_file, api_key): 184 | if not has_been_processed(video_id, output_file): 185 | result = call_gpt(prompt, api_key=api_key) 186 | with file_lock: 187 | with open(output_file, 'r+') as f: 188 | # Read the current data and update it 189 | data = json.load(f) 190 | data[video_id] = result 191 | f.seek(0) # Rewind file to the beginning 192 | json.dump(data, f, indent=4) 193 | f.truncate() # Truncate file to new size 194 | print(f"Processed and saved output for Video ID {video_id}") 195 | 196 | def main(num_workers, all_prompts, output_file, api_key): 197 | # Load existing results 198 | existing_results = load_existing_results(output_file) 199 | 200 | # Filter prompts for video IDs that have not been processed 201 | unprocessed_prompts = {vid: prompt for vid, prompt in all_prompts.items() if vid not in existing_results} 202 | if not unprocessed_prompts: 203 | print("No unprocessed video IDs found. All prompts have already been processed.") 204 | return 205 | 206 | print(f"Processing {len(unprocessed_prompts)} unprocessed video IDs.") 207 | 208 | progress_bar = tqdm(total=len(unprocessed_prompts)) 209 | 210 | with ThreadPoolExecutor(max_workers=num_workers) as executor: 211 | future_to_index = { 212 | executor.submit(save_output, video_id, prompt, output_file, api_key): video_id 213 | for video_id, prompt in unprocessed_prompts.items() 214 | } 215 | 216 | for future in as_completed(future_to_index): 217 | progress_bar.update(1) 218 | try: 219 | future.result() 220 | except Exception as e: 221 | print(f"Error processing video ID {future_to_index[future]}: {e}") 222 | 223 | progress_bar.close() 224 | 225 | if __name__ == "__main__": 226 | # Set up argument parser 227 | parser = argparse.ArgumentParser(description="Process video frame captions.") 228 | parser.add_argument("--api_key", type=str, default=None, help="OpenAI API key.") 229 | parser.add_argument("--num_workers", type=int, default=6, help="Number of worker threads for processing.") 230 | parser.add_argument("--output_file", type=str, default="./2_1_gpt_frames_caption.json", help="Path to the output JSON file.") 231 | parser.add_argument("--group_frames_file", type=str, default="./2_1_temp_group_frames.json", help="Path to save grouped frame metadata.") 232 | parser.add_argument("--image_directories", type=str, nargs="+", default=["./step_1"], help="List of directories containing images.") 233 | 234 | # Parse command-line arguments 235 | args = parser.parse_args() 236 | 237 | all_prompts = {} 238 | all_grouped_images = {} 239 | 240 | # Process each image directory 241 | for directory in args.image_directories: 242 | filenames = get_image_filenames(directory) 243 | grouped_images = group_images_by_video_id(filenames) 244 | 245 | # Sort images within each video group 246 | for video_id in grouped_images: 247 | 
grouped_images[video_id].sort(key=extract_frame_number) 248 | 249 | all_grouped_images.update(grouped_images) # Merge into a single dictionary 250 | 251 | # Generate prompts 252 | prompts = create_prompts(grouped_images, directory, txt_prompt) 253 | all_prompts.update(prompts) 254 | 255 | # Save grouped images metadata 256 | with open(args.group_frames_file, 'w') as file: 257 | json.dump(all_grouped_images, file, indent=4) 258 | 259 | # Execute main processing function 260 | main(args.num_workers, all_prompts, args.output_file, args.api_key) -------------------------------------------------------------------------------- /data_preprocess/step2_2_preprocess_frame_caption.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import argparse 4 | 5 | def load_json(file_path): 6 | """Load and return the content of a JSON file.""" 7 | with open(file_path, 'r') as file: 8 | return json.load(file) 9 | 10 | def save_json(data, file_path): 11 | """Save data to a JSON file.""" 12 | with open(file_path, 'w') as file: 13 | json.dump(data, file, ensure_ascii=False, indent=4) 14 | 15 | def process_frame_caption(file_path): 16 | """Process frame captions and save matched data.""" 17 | data = load_json(file_path) 18 | matched_data = {} 19 | unmatched_data = {} 20 | for key, value in data.items(): 21 | brief_reasoning_match = re.search(r'Brief Reasoning Statement: (.*?)(?:\n\n|\n)', value, re.DOTALL) 22 | reasoning_match = re.search(r'"Reasoning": \[(.*?)\]', value, re.DOTALL) 23 | captioning_match = re.search(r'"Captioning": \[(.*?)\]', value, re.DOTALL) 24 | if brief_reasoning_match and reasoning_match and captioning_match: 25 | brief_reasoning = brief_reasoning_match.group(1).strip() 26 | reasoning_list = re.findall(r'"(.*?)"(?:,|$)', reasoning_match.group(1)) 27 | captioning_list = re.findall(r'"(.*?)"(?:,|$)', captioning_match.group(1)) 28 | matched_data[key] = { 29 | "Video_Reasoning": brief_reasoning, 30 | "Frame_Reasoning": reasoning_list, 31 | "Frame_Captioning": captioning_list 32 | } 33 | else: 34 | unmatched_data[key] = value 35 | return matched_data, unmatched_data 36 | 37 | def is_disordered(section): 38 | frames = [] 39 | for entry in section: 40 | try: 41 | # Extracting the frame number 42 | frame_num = int(entry.split(':')[0].split(' ')[1]) 43 | frames.append(frame_num) 44 | except ValueError: 45 | # If parsing fails, skip this entry 46 | continue 47 | return not all(earlier <= later for earlier, later in zip(frames, frames[1:])) 48 | 49 | def find_disorder(data): 50 | """Identify entries with unordered frames.""" 51 | unordered_records = {} 52 | ordered_records = {} 53 | 54 | for key, value in data.items(): 55 | for section_name in ['Frame_Reasoning', 'Frame_Captioning']: 56 | section = value.get(section_name, []) 57 | if is_disordered(section): 58 | unordered_records[key] = value 59 | break 60 | else: 61 | ordered_records[key] = value 62 | return ordered_records, unordered_records 63 | 64 | def remove_disorder(data, unordered_data): 65 | """Remove disordered entries from the dataset.""" 66 | unordered_ids = set(unordered_data.keys()) 67 | ordered_json = {k: v for k, v in data.items() if k not in unordered_ids} 68 | return ordered_json 69 | 70 | def remove_unmatch_records(data, unmatched_data): 71 | """ 72 | Removes records from gpt_results if their ID exists in disordered_records. 
73 | :param data: dict, the data from gpt_results.json 74 | :return: dict, the updated data with matching records removed 75 | """ 76 | unmatch_ids = set(unmatched_data.keys()) 77 | matched_json = {id_: value for id_, value in data.items() if id_ not in unmatch_ids} 78 | return matched_json 79 | 80 | def merge_json_files(info_data, caption_data): 81 | # Load info and caption data from JSON files 82 | # with open(info_file, 'r') as file: 83 | # info_data = json.load(file) 84 | # with open(caption_file, 'r') as file: 85 | # caption_data = json.load(file) 86 | 87 | # Merge info into caption data based on matching key prefixes 88 | for caption_key in caption_data: 89 | for info_key in info_data: 90 | if caption_key.startswith(info_key): 91 | # Update the caption entry with info data 92 | 93 | # caption_data[caption_key].update(info_data[info_key]) 94 | 95 | selected_info = {key: info_data[info_key][key] for key in ['title'] if 96 | key in info_data[info_key]} 97 | caption_data[caption_key].update(selected_info) 98 | 99 | break 100 | 101 | # Save merged data to a new JSON file 102 | # with open(output_file, 'w') as file: 103 | # json.dump(caption_data, file) 104 | return caption_data 105 | 106 | if __name__ == "__main__": 107 | # Set up argument parser 108 | parser = argparse.ArgumentParser(description="Process GPT4V frame captions and clean up data.") 109 | parser.add_argument("--file_path", type=str, default="./2_1_gpt_frames_caption.json", help="Path to the input JSON file.") 110 | parser.add_argument("--updated_file_path", type=str, default="./2_2_updated_gpt_frames_caption.json", help="Path to save the updated JSON file.") 111 | parser.add_argument("--unmatched_file_path", type=str, default="./2_2_temp_unmatched_gpt_frames_caption.json", help="Path to save unmatched records.") 112 | parser.add_argument("--unordered_file_path", type=str, default="./2_2_temp_unordered_gpt_frames_caption.json", help="Path to save unordered records.") 113 | parser.add_argument("--final_useful_data_file_path", type=str, default="./2_2_final_useful_gpt_frames_caption.json", help="Path to save the final cleaned data.") 114 | 115 | # Parse command-line arguments 116 | args = parser.parse_args() 117 | 118 | # Processing steps 119 | matched_data, unmatched_data = process_frame_caption(args.file_path) 120 | ordered_records, unordered_records = find_disorder(matched_data) 121 | 122 | # Clean JSON by removing unmatched and unordered records 123 | updated_json = remove_unmatch_records(remove_disorder(load_json(args.file_path), unordered_records), unmatched_data) 124 | 125 | # Final useful data (can be merged with additional info if needed) 126 | final_useful_data = ordered_records 127 | 128 | # Print stats 129 | print(f"Number of Unmatched Records (GPT4V_Frame): {len(unmatched_data)}") 130 | print(f"Number of Unordered Records (GPT4V_Frame): {len(unordered_records)}") 131 | print(f"Number of Final Useful Records (GPT4V_Frame): {len(final_useful_data)}") 132 | 133 | # Save the processed results 134 | if len(unmatched_data) != 0 or len(unordered_records) != 0: 135 | save_json(updated_json, args.updated_file_path) 136 | print(f"Found {len(unmatched_data)} unmatched records and {len(unordered_records)} unordered records!") 137 | print(f"Updated JSON file has been saved to {args.updated_file_path}. Please rerun GPT4V for captioning.") 138 | else: 139 | print(f"No unmatched/unordered records found! 
You can directly use {args.final_useful_data_file_path} for the next step.") 140 | 141 | # Save intermediate results 142 | save_json(unmatched_data, args.unmatched_file_path) 143 | save_json(unordered_records, args.unordered_file_path) 144 | save_json(final_useful_data, args.final_useful_data_file_path) -------------------------------------------------------------------------------- /data_preprocess/step3_1_GPT4V_video_caption_concise.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from tqdm import tqdm 5 | from openai import OpenAI 6 | from threading import Lock 7 | from tenacity import retry, wait_exponential, stop_after_attempt 8 | from concurrent.futures import ThreadPoolExecutor, as_completed 9 | 10 | 11 | txt_prompt = ''' 12 | Imagine you're an expert data annotator with a specialization in summarizing time-lapse videos. You will be supplied with "Video_Reasoning", "8_Key-Frames_Reasoning", and "8_Key-Frames_Captioning" from a video, your task is to craft a concise summary for the given time-lapse video. 13 | 14 | Since only textual information is given, you can employ logical deductions to bridge any informational gaps if necessary. For guidance on the expected output format and content length (no more than 70 words), refer to the provided examples: 15 | 16 | "Video_Summary": Time-lapse of a ciplukan fruit growing from a small bud to a mature, rounded form among leaves, gradually enlarging and smoothing out by the video's end. 17 | 18 | "Video_Summary": Time-lapse of red onion bulbs sprouting and growing over 10 days: starting dormant, developing shoots and roots by Day 2, significant growth by Day 6, and full development by Day 10. 19 | 20 | "Video_Summary": "{Video Summary}" 21 | 22 | Attention: Do not reply outside the example template! The process of reasoning and thinking should not be included in the {Video Summary}! Do not use words similar to by frame or at frame! Below are the Video, Video_Reasoning, Frame_Reasoning and Frame_Captioning. 23 | ''' 24 | 25 | # Global lock for thread-safe file operations 26 | file_lock = Lock() 27 | 28 | # Function to create prompts for the GPT-4 Vision API 29 | def create_prompts(txt_prompt, data): 30 | prompts = {} 31 | for video_id, value in tqdm(data.items(), desc="Creating prompts"): 32 | prompt = [{"type": "text", "text": txt_prompt}] 33 | prompt.append({"type": "text", "text": f'''The "Video_Reasoning" is: {value['Video_Reasoning']}'''}) 34 | prompt.append({"type": "text", "text": f'''The "8_Key-Frames_Reasoning" are: {value['Frame_Reasoning']}'''}) 35 | prompt.append({"type": "text", "text": f'''The "8_Key-Frames_Captioning" are: {value['Frame_Captioning']}'''}) 36 | prompts[video_id] = prompt 37 | return prompts 38 | 39 | def has_been_processed(video_id, output_file): 40 | with file_lock: 41 | if os.path.exists(output_file): 42 | with open(output_file, 'r') as f: 43 | data = json.load(f) 44 | if video_id in data: 45 | print(f"Video ID {video_id} has already been processed.") 46 | return True 47 | return False 48 | 49 | def load_existing_results(file_path): 50 | if os.path.exists(file_path): 51 | with open(file_path, 'r') as file: 52 | print(f"Loading existing results from {file_path}") 53 | return json.load(file) 54 | else: 55 | print(f"No existing results file found at {file_path}. 
Creating a new file.") 56 | with open(file_path, 'w') as file: 57 | empty_data = {} 58 | json.dump(empty_data, file) 59 | return empty_data 60 | 61 | @retry(wait=wait_exponential(multiplier=1, min=2, max=10), stop=stop_after_attempt(100)) 62 | def call_gpt(prompt, model_name="gpt-4-vision-preview", api_key=None): 63 | client = OpenAI(api_key=api_key) 64 | chat_completion = client.chat.completions.create( 65 | model=model_name, 66 | messages=[ 67 | { 68 | "role": "user", 69 | "content": prompt, 70 | } 71 | ], 72 | max_tokens=1024, 73 | ) 74 | return chat_completion.choices[0].message.content 75 | 76 | def save_output(video_id, prompt, output_file, api_key): 77 | if not has_been_processed(video_id, output_file): 78 | result = call_gpt(prompt, api_key=api_key) 79 | with file_lock: 80 | with open(output_file, 'r+') as f: 81 | # Read the current data and update it 82 | data = json.load(f) 83 | data[video_id] = result 84 | f.seek(0) # Rewind file to the beginning 85 | json.dump(data, f, indent=4) 86 | f.truncate() # Truncate file to new size 87 | print(f"Processed and saved output for Video ID {video_id}") 88 | 89 | def main(num_workers, all_prompts, output_file, api_key): 90 | # Load existing results 91 | existing_results = load_existing_results(output_file) 92 | 93 | # Filter prompts for video IDs that have not been processed 94 | unprocessed_prompts = {vid: prompt for vid, prompt in all_prompts.items() if vid not in existing_results} 95 | if not unprocessed_prompts: 96 | print("No unprocessed video IDs found. All prompts have already been processed.") 97 | return 98 | 99 | print(f"Processing {len(unprocessed_prompts)} unprocessed video IDs.") 100 | 101 | progress_bar = tqdm(total=len(unprocessed_prompts)) 102 | 103 | with ThreadPoolExecutor(max_workers=num_workers) as executor: 104 | future_to_index = { 105 | executor.submit(save_output, video_id, prompt, output_file, api_key): video_id 106 | for video_id, prompt in unprocessed_prompts.items() 107 | } 108 | 109 | for future in as_completed(future_to_index): 110 | progress_bar.update(1) 111 | try: 112 | future.result() 113 | except Exception as e: 114 | print(f"Error processing video ID {future_to_index[future]}: {e}") 115 | 116 | progress_bar.close() 117 | 118 | if __name__ == "__main__": 119 | # Set up argument parser 120 | parser = argparse.ArgumentParser(description="Generate video captions using GPT4V.") 121 | parser.add_argument("--api_key", type=int, default=None, help="OpenAI API key.") 122 | parser.add_argument("--num_workers", type=int, default=8, help="Number of worker threads for processing.") 123 | parser.add_argument("--input_file", type=str, default="./2_2_final_useful_gpt_frames_caption.json", help="Path to the input JSON file.") 124 | parser.add_argument("--output_file", type=str, default="./3_1_gpt_video_caption.json", help="Path to save the generated video captions.") 125 | 126 | # Parse command-line arguments 127 | args = parser.parse_args() 128 | 129 | # Load data from the input file 130 | with open(args.input_file, 'r') as file: 131 | data = json.load(file) 132 | 133 | # Generate prompts 134 | prompts = create_prompts(txt_prompt, data) 135 | 136 | # Execute main processing function 137 | main(args.num_workers, prompts, args.output_file, args.api_key) -------------------------------------------------------------------------------- /data_preprocess/step3_1_GPT4V_video_caption_detail.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from 
tqdm import tqdm 5 | from openai import OpenAI 6 | from threading import Lock 7 | from concurrent.futures import ThreadPoolExecutor, as_completed 8 | from tenacity import retry, wait_exponential, stop_after_attempt 9 | 10 | 11 | txt_prompt = ''' 12 | Imagine you are a data annotator, specialized in generating summaries for time-lapse videos. You will be supplied with "Video_Reasoning", "8_Key-Frames_Reasoning", and "8_Key-Frames_Captioning" from a video, your task is to craft a succinct and precise summary for the given time-lapse video. Note: The summary should efficiently encapsulate all discernible elements, particularly emphasizing the primary subject. It is important to indicate whether the video pertains to a forward or reverse sequence. Additionally, integrate any time-related aspects from the video into the summary. 13 | 14 | Since only textual information is given, you can employ logical deductions to bridge any informational gaps if necessary. For guidance on the expected output format, refer to the provided examples: 15 | 16 | "Video_Summary": "The time-lapse video showcases the growth and ripening process of strawberries in a forward sequence. The video starts with fully bloomed strawberry flowers, which then wilt slightly. As the video progresses, the yellow stamens recede, and the flowers continue to wilt. The white petals disappear, and the green immature strawberries become more prominent. The strawberries then grow in size, displaying a green color with some red hues. As the video continues, the strawberries gradually ripen, turning from green to a deep red color." 17 | 18 | "Video_Summary": "This time-lapse video succinctly documents the 50-day decomposition process of a pear in a forward sequence, from its fresh, ripe state on day 1 to a shrunken, moldy, and rotten form by day 50. Throughout the video, the pear's gradual deterioration is evident through increasing browning, the development of mold patches, and significant changes in color, texture, and structure." 19 | 20 | "Video_Summary": "The time-lapse video showcases a Halloween pumpkin's decomposition process in reverse. The video starts with a pumpkin in a highly decomposed state at 92 days post-carving and then counts down the days, reversing the process. The pumpkin gradually re-inflates, reducing the signs of wrinkling and drying, until it appears freshly carved at 1 day post-carving." 21 | 22 | "Video_Summary": "{Video Summary}" 23 | 24 | Attention: Do not reply outside the example template! The process of reasoning and thinking should not be included in the {Video Summary}! Do not use words similar to by frame or at frame! Below are the Video, Video_Reasoning, Frame_Reasoning and Frame_Captioning. 
25 | ''' 26 | 27 | # Global lock for thread-safe file operations 28 | file_lock = Lock() 29 | 30 | # Function to create prompts for the GPT-4 Vision API 31 | def create_prompts(txt_prompt, data): 32 | prompts = {} 33 | for video_id, value in tqdm(data.items(), desc="Creating prompts"): 34 | prompt = [{"type": "text", "text": txt_prompt}] 35 | prompt.append({"type": "text", "text": f'''The "Video_Reasoning" is: {value['Video_Reasoning']}'''}) 36 | prompt.append({"type": "text", "text": f'''The "8_Key-Frames_Reasoning" are: {value['Frame_Reasoning']}'''}) 37 | prompt.append({"type": "text", "text": f'''The "8_Key-Frames_Captioning" are: {value['Frame_Captioning']}'''}) 38 | prompts[video_id] = prompt 39 | return prompts 40 | 41 | def has_been_processed(video_id, output_file): 42 | with file_lock: 43 | if os.path.exists(output_file): 44 | with open(output_file, 'r') as f: 45 | data = json.load(f) 46 | if video_id in data: 47 | print(f"Video ID {video_id} has already been processed.") 48 | return True 49 | return False 50 | 51 | def load_existing_results(file_path): 52 | if os.path.exists(file_path): 53 | with open(file_path, 'r') as file: 54 | print(f"Loading existing results from {file_path}") 55 | return json.load(file) 56 | else: 57 | print(f"No existing results file found at {file_path}. Creating a new file.") 58 | with open(file_path, 'w') as file: 59 | empty_data = {} 60 | json.dump(empty_data, file) 61 | return empty_data 62 | 63 | @retry(wait=wait_exponential(multiplier=1, min=2, max=10), stop=stop_after_attempt(100)) 64 | def call_gpt(prompt, model_name="gpt-4-vision-preview", api_key=None): 65 | client = OpenAI(api_key=api_key) 66 | chat_completion = client.chat.completions.create( 67 | model=model_name, 68 | messages=[ 69 | { 70 | "role": "user", 71 | "content": prompt, 72 | } 73 | ], 74 | max_tokens=1024, 75 | ) 76 | return chat_completion.choices[0].message.content 77 | 78 | def save_output(video_id, prompt, output_file, api_key): 79 | if not has_been_processed(video_id, output_file): 80 | result = call_gpt(prompt, api_key=api_key) 81 | with file_lock: 82 | with open(output_file, 'r+') as f: 83 | # Read the current data and update it 84 | data = json.load(f) 85 | data[video_id] = result 86 | f.seek(0) # Rewind file to the beginning 87 | json.dump(data, f, indent=4) 88 | f.truncate() # Truncate file to new size 89 | print(f"Processed and saved output for Video ID {video_id}") 90 | 91 | def main(num_workers, all_prompts, output_file, api_key): 92 | # Load existing results 93 | existing_results = load_existing_results(output_file) 94 | 95 | # Filter prompts for video IDs that have not been processed 96 | unprocessed_prompts = {vid: prompt for vid, prompt in all_prompts.items() if vid not in existing_results} 97 | if not unprocessed_prompts: 98 | print("No unprocessed video IDs found. 
All prompts have already been processed.") 99 | return 100 | 101 | print(f"Processing {len(unprocessed_prompts)} unprocessed video IDs.") 102 | 103 | progress_bar = tqdm(total=len(unprocessed_prompts)) 104 | 105 | with ThreadPoolExecutor(max_workers=num_workers) as executor: 106 | future_to_index = { 107 | executor.submit(save_output, video_id, prompt, output_file, api_key): video_id 108 | for video_id, prompt in unprocessed_prompts.items() 109 | } 110 | 111 | for future in as_completed(future_to_index): 112 | progress_bar.update(1) 113 | try: 114 | future.result() 115 | except Exception as e: 116 | print(f"Error processing video ID {future_to_index[future]}: {e}") 117 | 118 | progress_bar.close() 119 | 120 | if __name__ == "__main__": 121 | # Set up argument parser 122 | parser = argparse.ArgumentParser(description="Generate video captions using GPT4V.") 123 | parser.add_argument("--api_key", type=int, default=None, help="OpenAI API key.") 124 | parser.add_argument("--num_workers", type=int, default=8, help="Number of worker threads for processing.") 125 | parser.add_argument("--input_file", type=str, default="2_2_final_useful_gpt_frames_caption.json", help="Path to the input JSON file.") 126 | parser.add_argument("--output_file", type=str, default="./3_1_gpt_video_caption.json", help="Path to save the generated video captions.") 127 | 128 | # Parse command-line arguments 129 | args = parser.parse_args() 130 | 131 | # Load data from the input file 132 | with open(args.input_file, 'r') as file: 133 | data = json.load(file) 134 | 135 | # Generate prompts 136 | prompts = create_prompts(txt_prompt, data) 137 | 138 | # Execute main processing function 139 | main(args.num_workers, prompts, args.output_file, args.api_key) -------------------------------------------------------------------------------- /data_preprocess/step3_2_preprocess_video_caption.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import argparse 4 | 5 | 6 | def process_json(file_path): 7 | with open(file_path, 'r') as file: 8 | data = json.load(file) 9 | 10 | matched_data = {} 11 | unmatched_data = {} 12 | 13 | for key, value in data.items(): 14 | video_summary_match = re.search(r'"Video_Summary": (.*)', value) 15 | 16 | if video_summary_match: 17 | matched_data[key] = { 18 | "Video_GPT4_Caption": video_summary_match.group(1), 19 | } 20 | else: 21 | unmatched_data[key] = value 22 | 23 | return matched_data, unmatched_data 24 | 25 | def read_json_file(file_path): 26 | """Reads a JSON file and returns its content.""" 27 | with open(file_path, 'r') as file: 28 | return json.load(file) 29 | 30 | def remove_by_Frame(data): 31 | # Initialize dictionaries for matched (to exclude) and unmatched data 32 | to_exclude = {} 33 | to_keep = {} 34 | 35 | # Pattern to identify "by Frame X" in the video summary 36 | pattern = re.compile(r'(by|at|in|on) Frame \d+', re.IGNORECASE) 37 | 38 | for key, value in data.items(): 39 | # Assuming "Video_Summary" is a direct key in the value dictionary 40 | video_summary = value.get("Video_GPT4_Caption", "") 41 | # Check if "by Frame X" is in the video summary 42 | if pattern.search(video_summary): 43 | to_exclude[key] = value 44 | else: 45 | to_keep[key] = value 46 | 47 | return to_keep, to_exclude 48 | 49 | def remove_unmatch_records(gpt_data, unmatched_json_data): 50 | """ 51 | Removes records from gpt_results if their ID exists in disordered_records. 
52 | :param gpt_data: dict, the data from gpt_results.json 53 | :param disordered_ids: set, the set of IDs from disordered_records.json 54 | :return: dict, the updated gpt_data with matching records removed 55 | """ 56 | disordered_ids = set(unmatched_json_data.keys()) 57 | return {id_: value for id_, value in gpt_data.items() if id_ not in disordered_ids} 58 | 59 | def save_json_file(data, file_path): 60 | """Saves data to a JSON file.""" 61 | with open(file_path, 'w') as file: 62 | json.dump(data, file, indent=4) 63 | 64 | def merge_json_files(info_data, caption_data): 65 | # Merge info into caption data based on matching key prefixes 66 | for caption_key in caption_data: 67 | for info_key in info_data: 68 | if caption_key.startswith(info_key): 69 | selected_info = {key: info_data[info_key][key] for key in ['title'] if 70 | key in info_data[info_key]} 71 | caption_data[caption_key].update(selected_info) 72 | 73 | break 74 | return caption_data 75 | 76 | if __name__ == "__main__": 77 | # Set up argument parser 78 | parser = argparse.ArgumentParser(description="Process GPT4V video captions and clean up data.") 79 | parser.add_argument("--file_path", type=str, default="./3_1_gpt_video_caption.json", help="Path to the input JSON file.") 80 | parser.add_argument("--updated_file_path", type=str, default="./3_1_gpt_video_caption.json", help="Path to save the updated JSON file.") 81 | parser.add_argument("--unmatched_data_path", type=str, default="./3_2_temp_unmatched_gpt_video_caption.json", help="Path to save unmatched records.") 82 | parser.add_argument("--exclude_by_frame_data_path", type=str, default="./3_2_temp_exclude_by_frame_gpt_video_caption.json", help="Path to save excluded records.") 83 | parser.add_argument("--final_useful_data_path", type=str, default="./3_2_final_useful_gpt_video_caption.json", help="Path to save the final cleaned data.") 84 | 85 | # Parse command-line arguments 86 | args = parser.parse_args() 87 | 88 | # Processing steps 89 | matched_data, unmatched_data = process_json(args.file_path) 90 | to_keep, to_exclude = remove_by_Frame(matched_data) 91 | 92 | # Clean JSON by removing unmatched and excluded records 93 | updated_json = remove_unmatch_records(remove_unmatch_records(read_json_file(args.file_path), unmatched_data), to_exclude) 94 | 95 | # Save intermediate results 96 | save_json_file(unmatched_data, args.unmatched_data_path) 97 | save_json_file(to_exclude, args.exclude_by_frame_data_path) 98 | save_json_file(to_keep, args.final_useful_data_path) 99 | 100 | # Print stats 101 | if len(unmatched_data) != 0 or len(to_exclude) != 0: 102 | save_json_file(updated_json, args.updated_file_path) 103 | print(f"Found {len(unmatched_data)} unmatched_data and {len(to_exclude)} exclude_by_frame_data!") 104 | print(f"Updated JSON file has been saved to {args.updated_file_path}. Please rerun GPT4V for captioning.") 105 | else: 106 | print(f"No unmatched_data and exclude_by_frame_data found! 
You can directly use {args.final_useful_data_path} for the next step.") -------------------------------------------------------------------------------- /data_preprocess/step4_1_create_webvid_format.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import pandas as pd 4 | 5 | 6 | def merge_json_files_with_transmit_status(caption_file, output_file): 7 | # Load caption data from JSON file 8 | with open(caption_file, 'r', encoding='utf-8') as file: 9 | caption_data = json.load(file) 10 | 11 | # Extracting data and adding is_transmit status 12 | data = [{ 13 | 'videoid': key, 14 | 'name': value['Video_GPT4_Caption'], 15 | 'is_transmit': '1' # default flag; every exported video is written with is_transmit = '1' 16 | } for key, value in caption_data.items()] 17 | 18 | # Creating a DataFrame from the extracted data 19 | df = pd.DataFrame(data) 20 | 21 | # Saving the DataFrame as a CSV file 22 | df.to_csv(output_file, index=False) 23 | 24 | # Output the path to the saved CSV file 25 | return f"CSV file saved at: {output_file}" 26 | 27 | if __name__ == "__main__": 28 | # Set up argument parser 29 | parser = argparse.ArgumentParser(description="Convert GPT4V video captions JSON to CSV.") 30 | parser.add_argument("--caption_file_path", type=str, default="./3_2_final_useful_gpt_video_caption.json", help="Path to the input JSON caption file.") 31 | parser.add_argument("--output_csv_file_path", type=str, default="./all_clean_data.csv", help="Path to save the output CSV file.") 32 | 33 | # Parse command-line arguments 34 | args = parser.parse_args() 35 | 36 | # Process the JSON, convert it to CSV, and report where the CSV was written 37 | print(merge_json_files_with_transmit_status(args.caption_file_path, args.output_csv_file_path)) -------------------------------------------------------------------------------- /inference.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python inference_magictime.py \ 2 | --config sample_configs/RealisticVision.yaml -------------------------------------------------------------------------------- /inference_cli.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python inference_magictime.py \ 2 | --config sample_configs/RealisticVision.yaml \ 3 | --human -------------------------------------------------------------------------------- /inference_magictime.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import inspect 5 | import argparse 6 | import pandas as pd 7 | from omegaconf import OmegaConf 8 | from transformers import CLIPTextModel, CLIPTokenizer 9 | from diffusers import AutoencoderKL, DDIMScheduler 10 | from diffusers.utils.import_utils import is_xformers_available 11 | from huggingface_hub import snapshot_download 12 | 13 | from utils.unet import UNet3DConditionModel 14 | from utils.pipeline_magictime import MagicTimePipeline 15 | from utils.util import save_videos_grid 16 | from utils.util import load_weights 17 | 18 | @torch.no_grad() 19 | def main(args): 20 | *_, func_args = inspect.getargvalues(inspect.currentframe()) 21 | func_args = dict(func_args) 22 | 23 | if 'counter' not in globals(): 24 | globals()['counter'] = 0 25 | unique_id = globals()['counter'] 26 | globals()['counter'] += 1 27 | savedir = None 28 | savedir = os.path.join(args.save_path, f"{unique_id}") 29 | while os.path.exists(savedir): 30 | unique_id = globals()['counter'] 31 | 
globals()['counter'] += 1 32 | savedir = os.path.join(args.save_path, f"{unique_id}") 33 | os.makedirs(savedir, exist_ok=True) 34 | print(f"The results will be save to {savedir}") 35 | 36 | model_config = OmegaConf.load(args.config)[0] 37 | inference_config = OmegaConf.load(args.config)[1] 38 | 39 | if model_config.magic_adapter_s_path: 40 | print("Use MagicAdapter-S") 41 | if model_config.magic_adapter_t_path: 42 | print("Use MagicAdapter-T") 43 | if model_config.magic_text_encoder_path: 44 | print("Use Magic_Text_Encoder") 45 | 46 | tokenizer = CLIPTokenizer.from_pretrained(model_config.pretrained_model_path, subfolder="tokenizer") 47 | text_encoder = CLIPTextModel.from_pretrained(model_config.pretrained_model_path, subfolder="text_encoder").cuda() 48 | vae = AutoencoderKL.from_pretrained(model_config.pretrained_model_path, subfolder="vae").cuda() 49 | unet = UNet3DConditionModel.from_pretrained_2d(model_config.pretrained_model_path, subfolder="unet", 50 | unet_additional_kwargs=OmegaConf.to_container( 51 | inference_config.unet_additional_kwargs)).cuda() 52 | 53 | if is_xformers_available() and (not args.without_xformers): 54 | unet.enable_xformers_memory_efficient_attention() 55 | 56 | pipeline = MagicTimePipeline( 57 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, 58 | scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), 59 | ).to("cuda") 60 | 61 | pipeline = load_weights( 62 | pipeline, 63 | motion_module_path=model_config.get("motion_module", ""), 64 | dreambooth_model_path=model_config.get("dreambooth_path", ""), 65 | magic_adapter_s_path=model_config.get("magic_adapter_s_path", ""), 66 | magic_adapter_t_path=model_config.get("magic_adapter_t_path", ""), 67 | magic_text_encoder_path=model_config.get("magic_text_encoder_path", ""), 68 | ).to("cuda") 69 | 70 | if args.human: 71 | sample_idx = 0 72 | while True: 73 | user_prompt = input("Enter your prompt (or type 'exit' to quit): ") 74 | if user_prompt.lower() == "exit": 75 | break 76 | 77 | random_seed = torch.randint(0, 2 ** 32 - 1, (1,)).item() 78 | torch.manual_seed(random_seed) 79 | 80 | print(f"current seed: {random_seed}") 81 | print(f"sampling {user_prompt} ...") 82 | 83 | sample = pipeline( 84 | user_prompt, 85 | negative_prompt = list(model_config.n_prompt), 86 | num_inference_steps = model_config.steps, 87 | guidance_scale = model_config.guidance_scale, 88 | width = model_config.W, 89 | height = model_config.H, 90 | video_length = model_config.L, 91 | ).videos 92 | 93 | prompt_for_filename = "-".join(user_prompt.replace("/", "").split(" ")[:10]) 94 | save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{random_seed}-{prompt_for_filename}.mp4") 95 | print(f"save to {savedir}/sample/{sample_idx}-{random_seed}-{prompt_for_filename}.mp4") 96 | 97 | sample_idx += 1 98 | else: 99 | default = True 100 | batch_size = args.batch_size 101 | 102 | if args.run_csv: 103 | print("run csv") 104 | default = False 105 | file_path = args.run_csv 106 | data = pd.read_csv(file_path) 107 | prompts = data['name'].tolist() 108 | videoids = data['videoid'].tolist() 109 | elif args.run_json: 110 | print("run json") 111 | default = False 112 | file_path = args.run_json 113 | with open(file_path, 'r') as file: 114 | data = json.load(file) 115 | prompts = [] 116 | videoids = [] 117 | senids = [] 118 | for item in data['sentences']: 119 | prompts.append(item['caption']) 120 | videoids.append(item['video_id']) 121 | senids.append(item['sen_id']) 122 | elif args.run_txt: 123 | print("run 
txt") 124 | default = False 125 | file_path = args.run_txt 126 | with open(file_path, 'r') as file: 127 | prompts = [line.strip() for line in file.readlines()] 128 | videoids = [f"video_{i}" for i in range(len(prompts))] 129 | else: 130 | prompts = model_config.prompt 131 | videoids = [f"video_{i}" for i in range(len(prompts))] 132 | 133 | for i in range(0, len(prompts), batch_size): 134 | batch_prompts_raw = prompts[i : i + batch_size] 135 | batch_prompts = [prompt for prompt in batch_prompts_raw] 136 | 137 | if args.run_csv or args.run_json or args.run_txt or default: 138 | batch_videoids = videoids[i : i + batch_size] 139 | if args.run_json: 140 | batch_senids = senids[i : i + batch_size] 141 | 142 | flag = True 143 | for idx in range(len(batch_prompts)): 144 | if args.run_csv or args.run_txt or default: 145 | new_filename = f"{batch_videoids[idx]}.mp4" 146 | if args.run_json: 147 | new_filename = f"{batch_videoids[idx]}-{batch_senids[idx]}.mp4" 148 | if not os.path.exists(os.path.join(savedir, new_filename)): 149 | flag = False 150 | break 151 | if flag: 152 | print("skipping") 153 | continue 154 | 155 | n_prompts = list(model_config.n_prompt) * len(batch_prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt 156 | 157 | random_seed = torch.randint(0, 2**32 - 1, (1,)).item() 158 | torch.manual_seed(random_seed) 159 | 160 | print(f"current seed: {random_seed}") 161 | 162 | results = pipeline( 163 | batch_prompts, 164 | negative_prompt = n_prompts, 165 | num_inference_steps = model_config.steps, 166 | guidance_scale = model_config.guidance_scale, 167 | width = model_config.W, 168 | height = model_config.H, 169 | video_length = model_config.L, 170 | ).videos 171 | 172 | for idx, sample in enumerate(results): 173 | if args.run_csv or args.run_txt or default: 174 | new_filename = f"{batch_videoids[idx]}.mp4" 175 | if args.run_json: 176 | new_filename = f"{batch_videoids[idx]}-{batch_senids[idx]}.mp4" 177 | 178 | save_videos_grid(sample.unsqueeze(0), f"{savedir}/{new_filename}") 179 | print(f"save to {savedir}/{new_filename}") 180 | 181 | OmegaConf.save(model_config, f"{savedir}/model_config.yaml") 182 | 183 | if __name__ == "__main__": 184 | parser = argparse.ArgumentParser() 185 | parser.add_argument("--config", type=str, required=True) 186 | parser.add_argument("--without-xformers", action="store_true") 187 | parser.add_argument("--human", action="store_true", help="Enable human mode for interactive video generation") 188 | parser.add_argument("--run-csv", type=str, default=None) 189 | parser.add_argument("--run-json", type=str, default=None) 190 | parser.add_argument("--run-txt", type=str, default=None) 191 | parser.add_argument("--save-path", type=str, default="outputs") 192 | parser.add_argument("--batch-size", type=int, default=1) 193 | 194 | args = parser.parse_args() 195 | snapshot_download(repo_id="BestWishYsh/MagicTime", local_dir="ckpts") 196 | main(args) 197 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.2.2 2 | torchvision==0.17.2 3 | torchaudio==2.2.2 4 | xformers==0.0.25.post1 5 | imageio==2.27.0 6 | imageio[ffmpeg] 7 | imageio[pyav] 8 | peft==0.9.0 9 | numpy==1.26.4 10 | ms-swift==2.0.0 11 | accelerate==0.28.0 12 | diffusers==0.11.1 13 | transformers==4.38.2 14 | huggingface_hub==0.25.2 15 | gradio==3.50.2 16 | gdown 17 | triton 18 | einops 19 | omegaconf 20 | safetensors 21 | spaces 
-------------------------------------------------------------------------------- /sample_configs/RcnzCartoon.yaml: -------------------------------------------------------------------------------- 1 | - pretrained_model_path: "./ckpts/Base_Model/stable-diffusion-v1-5" 2 | motion_module: "./ckpts/Base_Model/motion_module/motion_module.ckpt" 3 | dreambooth_path: "./ckpts/DreamBooth/RcnzCartoon.safetensors" 4 | magic_adapter_s_path: "./ckpts/Magic_Weights/magic_adapter_s/magic_adapter_s.ckpt" 5 | magic_adapter_t_path: "./ckpts/Magic_Weights/magic_adapter_t" 6 | magic_text_encoder_path: "./ckpts/Magic_Weights/magic_text_encoder" 7 | 8 | H: 512 9 | W: 512 10 | L: 16 11 | seed: [1268480012, 3480796026, 3607977321, 1601344133] 12 | steps: 25 13 | guidance_scale: 8.5 14 | 15 | prompt: 16 | - "Time-lapse of a simple modern house's construction in a Minecraft virtual environment: beginning with an avatar laying a white foundation, progressing through wall erection and interior furnishing, to adding roof and exterior details, and completed with landscaping and a tall chimney." 17 | - "Time-lapse of a simple modern house's construction in a Minecraft virtual environment: beginning with an avatar laying a white foundation, progressing through wall erection and interior furnishing, to adding roof and exterior details, and completed with landscaping and a tall chimney." 18 | - "Bean sprouts grow and mature from seeds." 19 | - "Time-lapse of a yellow ranunculus flower transitioning from a tightly closed bud to a fully bloomed state, with measured petal separation and unfurling observed across the sequence." 20 | 21 | n_prompt: 22 | - "worst quality, low quality, letterboxed" 23 | 24 | - unet_additional_kwargs: 25 | use_inflated_groupnorm: true 26 | use_motion_module: true 27 | motion_module_resolutions: 28 | - 1 29 | - 2 30 | - 4 31 | - 8 32 | motion_module_mid_block: false 33 | motion_module_type: Vanilla 34 | motion_module_kwargs: 35 | num_attention_heads: 8 36 | num_transformer_block: 1 37 | attention_block_types: 38 | - Temporal_Self 39 | - Temporal_Self 40 | temporal_position_encoding: true 41 | temporal_position_encoding_max_len: 32 42 | temporal_attention_dim_div: 1 43 | zero_initialize: true 44 | noise_scheduler_kwargs: 45 | beta_start: 0.00085 46 | beta_end: 0.012 47 | beta_schedule: linear 48 | steps_offset: 1 49 | clip_sample: false -------------------------------------------------------------------------------- /sample_configs/RealisticVision.yaml: -------------------------------------------------------------------------------- 1 | - pretrained_model_path: "./ckpts/Base_Model/stable-diffusion-v1-5" 2 | motion_module: "./ckpts/Base_Model/motion_module/motion_module.ckpt" 3 | dreambooth_path: "./ckpts/DreamBooth/RealisticVisionV60B1_v51VAE.safetensors" 4 | magic_adapter_s_path: "./ckpts/Magic_Weights/magic_adapter_s/magic_adapter_s.ckpt" 5 | magic_adapter_t_path: "./ckpts/Magic_Weights/magic_adapter_t" 6 | magic_text_encoder_path: "./ckpts/Magic_Weights/magic_text_encoder" 7 | 8 | H: 512 9 | W: 512 10 | L: 16 11 | seed: [1587796317, 2883629116, 3068368949, 2038801077] 12 | steps: 25 13 | guidance_scale: 8.5 14 | 15 | prompt: 16 | - "Time-lapse of dough balls transforming into bread rolls: Begins with smooth, proofed dough, gradually expands in early baking, becomes taut and voluminous, and finally browns and fully expands to signal the baking's completion." 
17 | - "Time-lapse of cupcakes progressing through the baking process: starting from liquid batter in cupcake liners, gradually rising with the formation of domes, to fully baked cupcakes with golden, crackled domes." 18 | - "Cherry blossoms transitioning from tightly closed buds to a peak state of bloom. The progression moves through stages of bud swelling, petal exposure, and gradual opening, culminating in a full and vibrant display of open blossoms." 19 | - "Cherry blossoms transitioning from tightly closed buds to a peak state of bloom. The progression moves through stages of bud swelling, petal exposure, and gradual opening, culminating in a full and vibrant display of open blossoms." 20 | 21 | n_prompt: 22 | - "worst quality, low quality, letterboxed" 23 | 24 | - unet_additional_kwargs: 25 | use_inflated_groupnorm: true 26 | use_motion_module: true 27 | motion_module_resolutions: 28 | - 1 29 | - 2 30 | - 4 31 | - 8 32 | motion_module_mid_block: false 33 | motion_module_type: Vanilla 34 | motion_module_kwargs: 35 | num_attention_heads: 8 36 | num_transformer_block: 1 37 | attention_block_types: 38 | - Temporal_Self 39 | - Temporal_Self 40 | temporal_position_encoding: true 41 | temporal_position_encoding_max_len: 32 42 | temporal_attention_dim_div: 1 43 | zero_initialize: true 44 | noise_scheduler_kwargs: 45 | beta_start: 0.00085 46 | beta_end: 0.012 47 | beta_schedule: linear 48 | steps_offset: 1 49 | clip_sample: false -------------------------------------------------------------------------------- /sample_configs/ToonYou.yaml: -------------------------------------------------------------------------------- 1 | - pretrained_model_path: "./ckpts/Base_Model/stable-diffusion-v1-5" 2 | motion_module: "./ckpts/Base_Model/motion_module/motion_module.ckpt" 3 | dreambooth_path: "./ckpts/DreamBooth/ToonYou_beta6.safetensors" 4 | magic_adapter_s_path: "./ckpts/Magic_Weights/magic_adapter_s/magic_adapter_s.ckpt" 5 | magic_adapter_t_path: "./ckpts/Magic_Weights/magic_adapter_t" 6 | magic_text_encoder_path: "./ckpts/Magic_Weights/magic_text_encoder" 7 | 8 | H: 512 9 | W: 512 10 | L: 16 11 | seed: [3832738942, 153403692, 10789633, 1496541313] 12 | steps: 25 13 | guidance_scale: 8.5 14 | 15 | prompt: 16 | - "An ice cube is melting." 17 | - "A mesmerizing time-lapse showcasing the elegant unfolding of pink plum buds blossoms, capturing the gradual bloom from tightly sealed buds to fully open flowers." 18 | - "Time-lapse of a yellow ranunculus flower transitioning from a tightly closed bud to a fully bloomed state, with measured petal separation and unfurling observed across the sequence." 19 | - "Bean sprouts grow and mature from seeds." 
20 | 21 | n_prompt: 22 | - "worst quality, low quality, letterboxed" 23 | 24 | - unet_additional_kwargs: 25 | use_inflated_groupnorm: true 26 | use_motion_module: true 27 | motion_module_resolutions: 28 | - 1 29 | - 2 30 | - 4 31 | - 8 32 | motion_module_mid_block: false 33 | motion_module_type: Vanilla 34 | motion_module_kwargs: 35 | num_attention_heads: 8 36 | num_transformer_block: 1 37 | attention_block_types: 38 | - Temporal_Self 39 | - Temporal_Self 40 | temporal_position_encoding: true 41 | temporal_position_encoding_max_len: 32 42 | temporal_attention_dim_div: 1 43 | zero_initialize: true 44 | noise_scheduler_kwargs: 45 | beta_start: 0.00085 46 | beta_end: 0.012 47 | beta_schedule: linear 48 | steps_offset: 1 49 | clip_sample: false -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os, csv, random 2 | import numpy as np 3 | from decord import VideoReader 4 | import torch 5 | import torchvision.transforms as transforms 6 | from torch.utils.data.dataset import Dataset 7 | 8 | 9 | class ChronoMagic(Dataset): 10 | def __init__( 11 | self, 12 | csv_path, video_folder, 13 | sample_size=512, sample_stride=4, sample_n_frames=16, 14 | is_image=False, 15 | is_uniform=True, 16 | ): 17 | with open(csv_path, 'r') as csvfile: 18 | self.dataset = list(csv.DictReader(csvfile)) 19 | self.length = len(self.dataset) 20 | 21 | self.video_folder = video_folder 22 | self.sample_stride = sample_stride 23 | self.sample_n_frames = sample_n_frames 24 | self.is_image = is_image 25 | self.is_uniform = is_uniform 26 | 27 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) 28 | self.pixel_transforms = transforms.Compose([ 29 | transforms.RandomHorizontalFlip(), 30 | transforms.Resize(sample_size[0], interpolation=transforms.InterpolationMode.BICUBIC), 31 | transforms.CenterCrop(sample_size), 32 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 33 | ]) 34 | 35 | def _get_frame_indices_adjusted(self, video_length, n_frames): 36 | indices = list(range(video_length)) 37 | additional_frames_needed = n_frames - video_length 38 | 39 | repeat_indices = [] 40 | for i in range(additional_frames_needed): 41 | index_to_repeat = i % video_length 42 | repeat_indices.append(indices[index_to_repeat]) 43 | 44 | all_indices = indices + repeat_indices 45 | all_indices.sort() 46 | 47 | return all_indices 48 | 49 | def _generate_frame_indices(self, video_length, n_frames, sample_stride, is_transmit): 50 | prob_execute_original = 1 if int(is_transmit) == 0 else 0 51 | 52 | # Generate a random number to decide which block of code to execute 53 | if random.random() < prob_execute_original: 54 | if video_length <= n_frames: 55 | return self._get_frame_indices_adjusted(video_length, n_frames) 56 | else: 57 | interval = (video_length - 1) / (n_frames - 1) 58 | indices = [int(round(i * interval)) for i in range(n_frames)] 59 | indices[-1] = video_length - 1 60 | return indices 61 | else: 62 | if video_length <= n_frames: 63 | return self._get_frame_indices_adjusted(video_length, n_frames) 64 | else: 65 | clip_length = min(video_length, (n_frames - 1) * sample_stride + 1) 66 | start_idx = random.randint(0, video_length - clip_length) 67 | return np.linspace(start_idx, start_idx + clip_length - 1, n_frames, dtype=int).tolist() 68 | 69 | def get_batch(self, idx): 70 | video_dict = self.dataset[idx] 71 | videoid, name, 
is_transmit = video_dict['videoid'], video_dict['name'], video_dict['is_transmit'] 72 | 73 | video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") 74 | video_reader = VideoReader(video_dir, num_threads=0) 75 | video_length = len(video_reader) 76 | 77 | batch_index = self._generate_frame_indices(video_length, self.sample_n_frames, self.sample_stride, is_transmit) if not self.is_image else [random.randint(0, video_length - 1)] 78 | 79 | pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2) / 255. 80 | del video_reader 81 | 82 | if self.is_image: 83 | pixel_values = pixel_values[0] 84 | 85 | return pixel_values, name, videoid 86 | 87 | def __len__(self): 88 | return self.length 89 | 90 | def __getitem__(self, idx): 91 | while True: 92 | try: 93 | pixel_values, name, videoid = self.get_batch(idx) 94 | break 95 | 96 | except Exception as e: 97 | idx = random.randint(0, self.length-1) 98 | 99 | pixel_values = self.pixel_transforms(pixel_values) 100 | sample = dict(pixel_values=pixel_values, text=name, id=videoid) 101 | return sample -------------------------------------------------------------------------------- /utils/pipeline_magictime.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/guoyww/AnimateDiff/animatediff/pipelines/pipeline_animation.py 2 | 3 | import torch 4 | import inspect 5 | import numpy as np 6 | from tqdm import tqdm 7 | from einops import rearrange 8 | from packaging import version 9 | from dataclasses import dataclass 10 | from typing import Callable, List, Optional, Union 11 | from transformers import CLIPTextModel, CLIPTokenizer 12 | 13 | from diffusers.utils import is_accelerate_available, deprecate, logging, BaseOutput 14 | from diffusers.configuration_utils import FrozenDict 15 | from diffusers.models import AutoencoderKL 16 | from diffusers.pipeline_utils import DiffusionPipeline 17 | from diffusers.schedulers import ( 18 | DDIMScheduler, 19 | DPMSolverMultistepScheduler, 20 | EulerAncestralDiscreteScheduler, 21 | EulerDiscreteScheduler, 22 | LMSDiscreteScheduler, 23 | PNDMScheduler, 24 | ) 25 | 26 | from .unet import UNet3DConditionModel 27 | 28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 29 | 30 | @dataclass 31 | class MagicTimePipelineOutput(BaseOutput): 32 | videos: Union[torch.Tensor, np.ndarray] 33 | 34 | class MagicTimePipeline(DiffusionPipeline): 35 | _optional_components = [] 36 | 37 | def __init__( 38 | self, 39 | vae: AutoencoderKL, 40 | text_encoder: CLIPTextModel, 41 | tokenizer: CLIPTokenizer, 42 | unet: UNet3DConditionModel, 43 | scheduler: Union[ 44 | DDIMScheduler, 45 | PNDMScheduler, 46 | LMSDiscreteScheduler, 47 | EulerDiscreteScheduler, 48 | EulerAncestralDiscreteScheduler, 49 | DPMSolverMultistepScheduler, 50 | ], 51 | ): 52 | super().__init__() 53 | 54 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 55 | deprecation_message = ( 56 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 57 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 58 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" 59 | " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," 60 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 61 | " file" 62 | ) 63 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) 64 | new_config = dict(scheduler.config) 65 | new_config["steps_offset"] = 1 66 | scheduler._internal_dict = FrozenDict(new_config) 67 | 68 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: 69 | deprecation_message = ( 70 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 71 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the" 72 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" 73 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" 74 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" 75 | ) 76 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) 77 | new_config = dict(scheduler.config) 78 | new_config["clip_sample"] = False 79 | scheduler._internal_dict = FrozenDict(new_config) 80 | 81 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( 82 | version.parse(unet.config._diffusers_version).base_version 83 | ) < version.parse("0.9.0.dev0") 84 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 85 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: 86 | deprecation_message = ( 87 | "The configuration file of the unet has set the default `sample_size` to smaller than" 88 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" 89 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" 90 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" 91 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" 92 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" 93 | " in the config might lead to incorrect results in future versions. 
If you have downloaded this" 94 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" 95 | " the `unet/config.json` file" 96 | ) 97 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) 98 | new_config = dict(unet.config) 99 | new_config["sample_size"] = 64 100 | unet._internal_dict = FrozenDict(new_config) 101 | 102 | self.register_modules( 103 | vae=vae, 104 | text_encoder=text_encoder, 105 | tokenizer=tokenizer, 106 | unet=unet, 107 | scheduler=scheduler, 108 | ) 109 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 110 | 111 | def enable_vae_slicing(self): 112 | self.vae.enable_slicing() 113 | 114 | def disable_vae_slicing(self): 115 | self.vae.disable_slicing() 116 | 117 | def enable_sequential_cpu_offload(self, gpu_id=0): 118 | if is_accelerate_available(): 119 | from accelerate import cpu_offload 120 | else: 121 | raise ImportError("Please install accelerate via `pip install accelerate`") 122 | 123 | device = torch.device(f"cuda:{gpu_id}") 124 | 125 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 126 | if cpu_offloaded_model is not None: 127 | cpu_offload(cpu_offloaded_model, device) 128 | 129 | 130 | @property 131 | def _execution_device(self): 132 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 133 | return self.device 134 | for module in self.unet.modules(): 135 | if ( 136 | hasattr(module, "_hf_hook") 137 | and hasattr(module._hf_hook, "execution_device") 138 | and module._hf_hook.execution_device is not None 139 | ): 140 | return torch.device(module._hf_hook.execution_device) 141 | return self.device 142 | 143 | def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): 144 | batch_size = len(prompt) if isinstance(prompt, list) else 1 145 | 146 | text_inputs = self.tokenizer( 147 | prompt, 148 | padding="max_length", 149 | max_length=self.tokenizer.model_max_length, 150 | truncation=True, 151 | return_tensors="pt", 152 | ) 153 | text_input_ids = text_inputs.input_ids 154 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 155 | 156 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 157 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) 158 | logger.warning( 159 | "The following part of your input was truncated because CLIP can only handle sequences up to" 160 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 161 | ) 162 | 163 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 164 | attention_mask = text_inputs.attention_mask.to(device) 165 | else: 166 | attention_mask = None 167 | 168 | text_embeddings = self.text_encoder( 169 | text_input_ids.to(device), 170 | attention_mask=attention_mask, 171 | ) 172 | text_embeddings = text_embeddings[0] 173 | 174 | # duplicate text embeddings for each generation per prompt, using mps friendly method 175 | bs_embed, seq_len, _ = text_embeddings.shape 176 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) 177 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 178 | 179 | # get unconditional embeddings for classifier free guidance 180 | if do_classifier_free_guidance: 181 | uncond_tokens: List[str] 182 | if negative_prompt is None: 183 | uncond_tokens = 
[""] * batch_size 184 | elif type(prompt) is not type(negative_prompt): 185 | raise TypeError( 186 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 187 | f" {type(prompt)}." 188 | ) 189 | elif isinstance(negative_prompt, str): 190 | uncond_tokens = [negative_prompt] 191 | elif batch_size != len(negative_prompt): 192 | raise ValueError( 193 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 194 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 195 | " the batch size of `prompt`." 196 | ) 197 | else: 198 | uncond_tokens = negative_prompt 199 | 200 | max_length = text_input_ids.shape[-1] 201 | uncond_input = self.tokenizer( 202 | uncond_tokens, 203 | padding="max_length", 204 | max_length=max_length, 205 | truncation=True, 206 | return_tensors="pt", 207 | ) 208 | 209 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 210 | attention_mask = uncond_input.attention_mask.to(device) 211 | else: 212 | attention_mask = None 213 | 214 | uncond_embeddings = self.text_encoder( 215 | uncond_input.input_ids.to(device), 216 | attention_mask=attention_mask, 217 | ) 218 | uncond_embeddings = uncond_embeddings[0] 219 | 220 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 221 | seq_len = uncond_embeddings.shape[1] 222 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) 223 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) 224 | 225 | # For classifier free guidance, we need to do two forward passes. 226 | # Here we concatenate the unconditional and text embeddings into a single batch 227 | # to avoid doing two forward passes 228 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 229 | 230 | return text_embeddings 231 | 232 | def decode_latents(self, latents): 233 | video_length = latents.shape[2] 234 | latents = 1 / 0.18215 * latents 235 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 236 | # video = self.vae.decode(latents).sample 237 | video = [] 238 | for frame_idx in tqdm(range(latents.shape[0])): 239 | video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) 240 | video = torch.cat(video) 241 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 242 | video = (video / 2 + 0.5).clamp(0, 1) 243 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 244 | video = video.cpu().float().numpy() 245 | return video 246 | 247 | def prepare_extra_step_kwargs(self, generator, eta): 248 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 249 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
250 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 251 | # and should be between [0, 1] 252 | 253 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 254 | extra_step_kwargs = {} 255 | if accepts_eta: 256 | extra_step_kwargs["eta"] = eta 257 | 258 | # check if the scheduler accepts generator 259 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 260 | if accepts_generator: 261 | extra_step_kwargs["generator"] = generator 262 | return extra_step_kwargs 263 | 264 | def check_inputs(self, prompt, height, width, callback_steps): 265 | if not isinstance(prompt, str) and not isinstance(prompt, list): 266 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 267 | 268 | if height % 8 != 0 or width % 8 != 0: 269 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 270 | 271 | if (callback_steps is None) or ( 272 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 273 | ): 274 | raise ValueError( 275 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 276 | f" {type(callback_steps)}." 277 | ) 278 | 279 | def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): 280 | shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) 281 | if isinstance(generator, list) and len(generator) != batch_size: 282 | raise ValueError( 283 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 284 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
285 | ) 286 | if latents is None: 287 | rand_device = "cpu" if device.type == "mps" else device 288 | 289 | if isinstance(generator, list): 290 | shape = shape 291 | # shape = (1,) + shape[1:] 292 | latents = [ 293 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) 294 | for i in range(batch_size) 295 | ] 296 | latents = torch.cat(latents, dim=0).to(device) 297 | else: 298 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) 299 | else: 300 | if latents.shape != shape: 301 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") 302 | latents = latents.to(device) 303 | 304 | # scale the initial noise by the standard deviation required by the scheduler 305 | latents = latents * self.scheduler.init_noise_sigma 306 | return latents 307 | 308 | @torch.no_grad() 309 | def __call__( 310 | self, 311 | prompt: Union[str, List[str]], 312 | video_length: Optional[int], 313 | height: Optional[int] = None, 314 | width: Optional[int] = None, 315 | num_inference_steps: int = 50, 316 | guidance_scale: float = 7.5, 317 | negative_prompt: Optional[Union[str, List[str]]] = None, 318 | num_videos_per_prompt: Optional[int] = 1, 319 | eta: float = 0.0, 320 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 321 | latents: Optional[torch.FloatTensor] = None, 322 | output_type: Optional[str] = "tensor", 323 | return_dict: bool = True, 324 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 325 | callback_steps: Optional[int] = 1, 326 | **kwargs, 327 | ): 328 | # Default height and width to unet 329 | height = height or self.unet.config.sample_size * self.vae_scale_factor 330 | width = width or self.unet.config.sample_size * self.vae_scale_factor 331 | 332 | # Check inputs. Raise error if not correct 333 | self.check_inputs(prompt, height, width, callback_steps) 334 | 335 | # Define call parameters 336 | # batch_size = 1 if isinstance(prompt, str) else len(prompt) 337 | batch_size = 1 338 | if latents is not None: 339 | batch_size = latents.shape[0] 340 | if isinstance(prompt, list): 341 | batch_size = len(prompt) 342 | 343 | device = self._execution_device 344 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 345 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 346 | # corresponds to doing no classifier free guidance. 347 | do_classifier_free_guidance = guidance_scale > 1.0 348 | 349 | # Encode input prompt 350 | prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size 351 | if negative_prompt is not None: 352 | negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size 353 | text_embeddings = self._encode_prompt( 354 | prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt 355 | ) 356 | 357 | # Prepare timesteps 358 | self.scheduler.set_timesteps(num_inference_steps, device=device) 359 | timesteps = self.scheduler.timesteps 360 | 361 | # Prepare latent variables 362 | num_channels_latents = self.unet.in_channels 363 | latents = self.prepare_latents( 364 | batch_size * num_videos_per_prompt, 365 | num_channels_latents, 366 | video_length, 367 | height, 368 | width, 369 | text_embeddings.dtype, 370 | device, 371 | generator, 372 | latents, 373 | ) 374 | latents_dtype = latents.dtype 375 | 376 | # Prepare extra step kwargs. 
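        # (Only arguments that the scheduler's step() signature actually accepts are
        # forwarded: prepare_extra_step_kwargs inspects the signature and drops `eta`
        # and `generator` for schedulers that do not support them, so eta only takes
        # effect with DDIM-style schedulers.)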
377 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 378 | 379 | # Denoising loop 380 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 381 | with self.progress_bar(total=num_inference_steps) as progress_bar: 382 | for i, t in enumerate(timesteps): 383 | # expand the latents if we are doing classifier free guidance 384 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 385 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 386 | 387 | down_block_additional_residuals = mid_block_additional_residual = None 388 | 389 | # predict the noise residual 390 | noise_pred = self.unet( 391 | latent_model_input, t, 392 | encoder_hidden_states=text_embeddings, 393 | down_block_additional_residuals = down_block_additional_residuals, 394 | mid_block_additional_residual = mid_block_additional_residual, 395 | ).sample.to(dtype=latents_dtype) 396 | 397 | # perform guidance 398 | if do_classifier_free_guidance: 399 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 400 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 401 | 402 | # compute the previous noisy sample x_t -> x_t-1 403 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 404 | 405 | # call the callback, if provided 406 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 407 | progress_bar.update() 408 | if callback is not None and i % callback_steps == 0: 409 | callback(i, t, latents) 410 | 411 | # Post-processing 412 | video = self.decode_latents(latents) 413 | 414 | # Convert to tensor 415 | if output_type == "tensor": 416 | video = torch.from_numpy(video) 417 | 418 | if not return_dict: 419 | return video 420 | 421 | return MagicTimePipelineOutput(videos=video) 422 | -------------------------------------------------------------------------------- /utils/unet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/guoyww/AnimateDiff/animatediff/models/unet.py 2 | import os 3 | import json 4 | import pdb 5 | from dataclasses import dataclass 6 | from typing import List, Optional, Tuple, Union 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.checkpoint 11 | 12 | from diffusers.configuration_utils import ConfigMixin, register_to_config 13 | from diffusers.modeling_utils import ModelMixin 14 | from diffusers.utils import BaseOutput, logging 15 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 16 | from .unet_blocks import ( 17 | CrossAttnDownBlock3D, 18 | CrossAttnUpBlock3D, 19 | DownBlock3D, 20 | UNetMidBlock3DCrossAttn, 21 | UpBlock3D, 22 | get_down_block, 23 | get_up_block, 24 | InflatedConv3d, 25 | InflatedGroupNorm, 26 | ) 27 | 28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 29 | 30 | 31 | @dataclass 32 | class UNet3DConditionOutput(BaseOutput): 33 | sample: torch.FloatTensor 34 | 35 | 36 | class UNet3DConditionModel(ModelMixin, ConfigMixin): 37 | _supports_gradient_checkpointing = True 38 | 39 | @register_to_config 40 | def __init__( 41 | self, 42 | sample_size: Optional[int] = None, 43 | in_channels: int = 4, 44 | out_channels: int = 4, 45 | center_input_sample: bool = False, 46 | flip_sin_to_cos: bool = True, 47 | freq_shift: int = 0, 48 | down_block_types: Tuple[str] = ( 49 | "CrossAttnDownBlock3D", 50 | "CrossAttnDownBlock3D", 51 | "CrossAttnDownBlock3D", 
52 | "DownBlock3D", 53 | ), 54 | mid_block_type: str = "UNetMidBlock3DCrossAttn", 55 | up_block_types: Tuple[str] = ( 56 | "UpBlock3D", 57 | "CrossAttnUpBlock3D", 58 | "CrossAttnUpBlock3D", 59 | "CrossAttnUpBlock3D" 60 | ), 61 | only_cross_attention: Union[bool, Tuple[bool]] = False, 62 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 63 | layers_per_block: int = 2, 64 | downsample_padding: int = 1, 65 | mid_block_scale_factor: float = 1, 66 | act_fn: str = "silu", 67 | norm_num_groups: int = 32, 68 | norm_eps: float = 1e-5, 69 | cross_attention_dim: int = 1280, 70 | attention_head_dim: Union[int, Tuple[int]] = 8, 71 | dual_cross_attention: bool = False, 72 | use_linear_projection: bool = False, 73 | class_embed_type: Optional[str] = None, 74 | num_class_embeds: Optional[int] = None, 75 | upcast_attention: bool = False, 76 | resnet_time_scale_shift: str = "default", 77 | 78 | use_inflated_groupnorm=False, 79 | 80 | # Additional 81 | use_motion_module = False, 82 | motion_module_resolutions = ( 1,2,4,8 ), 83 | motion_module_mid_block = False, 84 | motion_module_decoder_only = False, 85 | motion_module_type = None, 86 | motion_module_kwargs = {}, 87 | unet_use_cross_frame_attention = False, 88 | unet_use_temporal_attention = False, 89 | ): 90 | super().__init__() 91 | 92 | self.sample_size = sample_size 93 | time_embed_dim = block_out_channels[0] * 4 94 | 95 | # input 96 | self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) 97 | 98 | # time 99 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 100 | timestep_input_dim = block_out_channels[0] 101 | 102 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 103 | 104 | # class embedding 105 | if class_embed_type is None and num_class_embeds is not None: 106 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 107 | elif class_embed_type == "timestep": 108 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 109 | elif class_embed_type == "identity": 110 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 111 | else: 112 | self.class_embedding = None 113 | 114 | self.down_blocks = nn.ModuleList([]) 115 | self.mid_block = None 116 | self.up_blocks = nn.ModuleList([]) 117 | 118 | if isinstance(only_cross_attention, bool): 119 | only_cross_attention = [only_cross_attention] * len(down_block_types) 120 | 121 | if isinstance(attention_head_dim, int): 122 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 123 | 124 | # down 125 | output_channel = block_out_channels[0] 126 | for i, down_block_type in enumerate(down_block_types): 127 | res = 2 ** i 128 | input_channel = output_channel 129 | output_channel = block_out_channels[i] 130 | is_final_block = i == len(block_out_channels) - 1 131 | 132 | down_block = get_down_block( 133 | down_block_type, 134 | num_layers=layers_per_block, 135 | in_channels=input_channel, 136 | out_channels=output_channel, 137 | temb_channels=time_embed_dim, 138 | add_downsample=not is_final_block, 139 | resnet_eps=norm_eps, 140 | resnet_act_fn=act_fn, 141 | resnet_groups=norm_num_groups, 142 | cross_attention_dim=cross_attention_dim, 143 | attn_num_head_channels=attention_head_dim[i], 144 | downsample_padding=downsample_padding, 145 | dual_cross_attention=dual_cross_attention, 146 | use_linear_projection=use_linear_projection, 147 | only_cross_attention=only_cross_attention[i], 148 | upcast_attention=upcast_attention, 149 | 
resnet_time_scale_shift=resnet_time_scale_shift, 150 | 151 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 152 | unet_use_temporal_attention=unet_use_temporal_attention, 153 | use_inflated_groupnorm=use_inflated_groupnorm, 154 | 155 | use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), 156 | motion_module_type=motion_module_type, 157 | motion_module_kwargs=motion_module_kwargs, 158 | ) 159 | self.down_blocks.append(down_block) 160 | 161 | # mid 162 | if mid_block_type == "UNetMidBlock3DCrossAttn": 163 | self.mid_block = UNetMidBlock3DCrossAttn( 164 | in_channels=block_out_channels[-1], 165 | temb_channels=time_embed_dim, 166 | resnet_eps=norm_eps, 167 | resnet_act_fn=act_fn, 168 | output_scale_factor=mid_block_scale_factor, 169 | resnet_time_scale_shift=resnet_time_scale_shift, 170 | cross_attention_dim=cross_attention_dim, 171 | attn_num_head_channels=attention_head_dim[-1], 172 | resnet_groups=norm_num_groups, 173 | dual_cross_attention=dual_cross_attention, 174 | use_linear_projection=use_linear_projection, 175 | upcast_attention=upcast_attention, 176 | 177 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 178 | unet_use_temporal_attention=unet_use_temporal_attention, 179 | use_inflated_groupnorm=use_inflated_groupnorm, 180 | 181 | use_motion_module=use_motion_module and motion_module_mid_block, 182 | motion_module_type=motion_module_type, 183 | motion_module_kwargs=motion_module_kwargs, 184 | ) 185 | else: 186 | raise ValueError(f"unknown mid_block_type : {mid_block_type}") 187 | 188 | # count how many layers upsample the videos 189 | self.num_upsamplers = 0 190 | 191 | # up 192 | reversed_block_out_channels = list(reversed(block_out_channels)) 193 | reversed_attention_head_dim = list(reversed(attention_head_dim)) 194 | only_cross_attention = list(reversed(only_cross_attention)) 195 | output_channel = reversed_block_out_channels[0] 196 | for i, up_block_type in enumerate(up_block_types): 197 | res = 2 ** (3 - i) 198 | is_final_block = i == len(block_out_channels) - 1 199 | 200 | prev_output_channel = output_channel 201 | output_channel = reversed_block_out_channels[i] 202 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 203 | 204 | # add upsample block for all BUT final layer 205 | if not is_final_block: 206 | add_upsample = True 207 | self.num_upsamplers += 1 208 | else: 209 | add_upsample = False 210 | 211 | up_block = get_up_block( 212 | up_block_type, 213 | num_layers=layers_per_block + 1, 214 | in_channels=input_channel, 215 | out_channels=output_channel, 216 | prev_output_channel=prev_output_channel, 217 | temb_channels=time_embed_dim, 218 | add_upsample=add_upsample, 219 | resnet_eps=norm_eps, 220 | resnet_act_fn=act_fn, 221 | resnet_groups=norm_num_groups, 222 | cross_attention_dim=cross_attention_dim, 223 | attn_num_head_channels=reversed_attention_head_dim[i], 224 | dual_cross_attention=dual_cross_attention, 225 | use_linear_projection=use_linear_projection, 226 | only_cross_attention=only_cross_attention[i], 227 | upcast_attention=upcast_attention, 228 | resnet_time_scale_shift=resnet_time_scale_shift, 229 | 230 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 231 | unet_use_temporal_attention=unet_use_temporal_attention, 232 | use_inflated_groupnorm=use_inflated_groupnorm, 233 | 234 | use_motion_module=use_motion_module and (res in motion_module_resolutions), 235 | motion_module_type=motion_module_type, 236 | 
motion_module_kwargs=motion_module_kwargs, 237 | ) 238 | self.up_blocks.append(up_block) 239 | prev_output_channel = output_channel 240 | 241 | # out 242 | if use_inflated_groupnorm: 243 | self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) 244 | else: 245 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) 246 | self.conv_act = nn.SiLU() 247 | self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) 248 | 249 | def set_attention_slice(self, slice_size): 250 | r""" 251 | Enable sliced attention computation. 252 | 253 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 254 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 255 | 256 | Args: 257 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 258 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If 259 | `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is 260 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 261 | must be a multiple of `slice_size`. 262 | """ 263 | sliceable_head_dims = [] 264 | 265 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): 266 | if hasattr(module, "set_attention_slice"): 267 | sliceable_head_dims.append(module.sliceable_head_dim) 268 | 269 | for child in module.children(): 270 | fn_recursive_retrieve_slicable_dims(child) 271 | 272 | # retrieve number of attention layers 273 | for module in self.children(): 274 | fn_recursive_retrieve_slicable_dims(module) 275 | 276 | num_slicable_layers = len(sliceable_head_dims) 277 | 278 | if slice_size == "auto": 279 | # half the attention head size is usually a good trade-off between 280 | # speed and memory 281 | slice_size = [dim // 2 for dim in sliceable_head_dims] 282 | elif slice_size == "max": 283 | # make smallest slice possible 284 | slice_size = num_slicable_layers * [1] 285 | 286 | slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 287 | 288 | if len(slice_size) != len(sliceable_head_dims): 289 | raise ValueError( 290 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 291 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 292 | ) 293 | 294 | for i in range(len(slice_size)): 295 | size = slice_size[i] 296 | dim = sliceable_head_dims[i] 297 | if size is not None and size > dim: 298 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 299 | 300 | # Recursively walk through all the children. 
301 | # Any children which exposes the set_attention_slice method 302 | # gets the message 303 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 304 | if hasattr(module, "set_attention_slice"): 305 | module.set_attention_slice(slice_size.pop()) 306 | 307 | for child in module.children(): 308 | fn_recursive_set_attention_slice(child, slice_size) 309 | 310 | reversed_slice_size = list(reversed(slice_size)) 311 | for module in self.children(): 312 | fn_recursive_set_attention_slice(module, reversed_slice_size) 313 | 314 | def _set_gradient_checkpointing(self, module, value=False): 315 | if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): 316 | module.gradient_checkpointing = value 317 | 318 | def forward( 319 | self, 320 | sample: torch.FloatTensor, 321 | timestep: Union[torch.Tensor, float, int], 322 | encoder_hidden_states: torch.Tensor, 323 | class_labels: Optional[torch.Tensor] = None, 324 | attention_mask: Optional[torch.Tensor] = None, 325 | 326 | # support controlnet 327 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 328 | mid_block_additional_residual: Optional[torch.Tensor] = None, 329 | 330 | return_dict: bool = True, 331 | ) -> Union[UNet3DConditionOutput, Tuple]: 332 | r""" 333 | Args: 334 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 335 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 336 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 337 | return_dict (`bool`, *optional*, defaults to `True`): 338 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 339 | 340 | Returns: 341 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 342 | [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When 343 | returning a tuple, the first element is the sample tensor. 344 | """ 345 | # By default samples have to be AT least a multiple of the overall upsampling factor. 346 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 347 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 348 | # on the fly if necessary. 
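        # (With the default four-level block_out_channels there are three upsamplers,
        # so the factor below is 8: the latent height/width must be multiples of 8,
        # which corresponds to pixel sizes that are multiples of 64 for the factor-8
        # VAE; otherwise the upsample size is forwarded explicitly further down.)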
349 | default_overall_up_factor = 2**self.num_upsamplers 350 | 351 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 352 | forward_upsample_size = False 353 | upsample_size = None 354 | 355 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 356 | logger.info("Forward upsample size to force interpolation output size.") 357 | forward_upsample_size = True 358 | 359 | # prepare attention_mask 360 | if attention_mask is not None: 361 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 362 | attention_mask = attention_mask.unsqueeze(1) 363 | 364 | # center input if necessary 365 | if self.config.center_input_sample: 366 | sample = 2 * sample - 1.0 367 | 368 | # time 369 | timesteps = timestep 370 | if not torch.is_tensor(timesteps): 371 | # This would be a good case for the `match` statement (Python 3.10+) 372 | is_mps = sample.device.type == "mps" 373 | if isinstance(timestep, float): 374 | dtype = torch.float32 if is_mps else torch.float64 375 | else: 376 | dtype = torch.int32 if is_mps else torch.int64 377 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 378 | elif len(timesteps.shape) == 0: 379 | timesteps = timesteps[None].to(sample.device) 380 | 381 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 382 | timesteps = timesteps.expand(sample.shape[0]) 383 | 384 | t_emb = self.time_proj(timesteps) 385 | 386 | # timesteps does not contain any weights and will always return f32 tensors 387 | # but time_embedding might actually be running in fp16. so we need to cast here. 388 | # there might be better ways to encapsulate this. 389 | t_emb = t_emb.to(dtype=self.dtype) 390 | emb = self.time_embedding(t_emb) 391 | 392 | if self.class_embedding is not None: 393 | if class_labels is None: 394 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 395 | 396 | if self.config.class_embed_type == "timestep": 397 | class_labels = self.time_proj(class_labels) 398 | 399 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 400 | emb = emb + class_emb 401 | 402 | # pre-process 403 | sample = self.conv_in(sample) 404 | 405 | # down 406 | down_block_res_samples = (sample,) 407 | for downsample_block in self.down_blocks: 408 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 409 | sample, res_samples = downsample_block( 410 | hidden_states=sample, 411 | temb=emb, 412 | encoder_hidden_states=encoder_hidden_states, 413 | attention_mask=attention_mask, 414 | ) 415 | else: 416 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) 417 | 418 | down_block_res_samples += res_samples 419 | 420 | # support controlnet 421 | down_block_res_samples = list(down_block_res_samples) 422 | if down_block_additional_residuals is not None: 423 | for i, down_block_additional_residual in enumerate(down_block_additional_residuals): 424 | if down_block_additional_residual.dim() == 4: # boardcast 425 | down_block_additional_residual = down_block_additional_residual.unsqueeze(2) 426 | down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual 427 | 428 | # mid 429 | sample = self.mid_block( 430 | sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 431 | ) 432 | 433 | # support controlnet 434 | if mid_block_additional_residual is not None: 435 | if mid_block_additional_residual.dim() == 4: 
# boardcast 436 | mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2) 437 | sample = sample + mid_block_additional_residual 438 | 439 | # up 440 | for i, upsample_block in enumerate(self.up_blocks): 441 | is_final_block = i == len(self.up_blocks) - 1 442 | 443 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 444 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 445 | 446 | # if we have not reached the final block and need to forward the 447 | # upsample size, we do it here 448 | if not is_final_block and forward_upsample_size: 449 | upsample_size = down_block_res_samples[-1].shape[2:] 450 | 451 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 452 | sample = upsample_block( 453 | hidden_states=sample, 454 | temb=emb, 455 | res_hidden_states_tuple=res_samples, 456 | encoder_hidden_states=encoder_hidden_states, 457 | upsample_size=upsample_size, 458 | attention_mask=attention_mask, 459 | ) 460 | else: 461 | sample = upsample_block( 462 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, 463 | ) 464 | 465 | # post-process 466 | sample = self.conv_norm_out(sample) 467 | sample = self.conv_act(sample) 468 | sample = self.conv_out(sample) 469 | 470 | if not return_dict: 471 | return (sample,) 472 | 473 | return UNet3DConditionOutput(sample=sample) 474 | 475 | @classmethod 476 | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): 477 | if subfolder is not None: 478 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder) 479 | print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...") 480 | 481 | config_file = os.path.join(pretrained_model_path, 'config.json') 482 | if not os.path.isfile(config_file): 483 | raise RuntimeError(f"{config_file} does not exist") 484 | with open(config_file, "r") as f: 485 | config = json.load(f) 486 | config["_class_name"] = cls.__name__ 487 | config["down_block_types"] = [ 488 | "CrossAttnDownBlock3D", 489 | "CrossAttnDownBlock3D", 490 | "CrossAttnDownBlock3D", 491 | "DownBlock3D" 492 | ] 493 | config["up_block_types"] = [ 494 | "UpBlock3D", 495 | "CrossAttnUpBlock3D", 496 | "CrossAttnUpBlock3D", 497 | "CrossAttnUpBlock3D" 498 | ] 499 | 500 | from diffusers.utils import WEIGHTS_NAME 501 | model = cls.from_config(config, **unet_additional_kwargs) 502 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) 503 | if not os.path.isfile(model_file): 504 | raise RuntimeError(f"{model_file} does not exist") 505 | state_dict = torch.load(model_file, map_location="cpu") 506 | 507 | m, u = model.load_state_dict(state_dict, strict=False) 508 | print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") 509 | 510 | params = [p.numel() if "motion_modules." 
in n else 0 for n, p in model.named_parameters()] 511 | print(f"### Motion Module Parameters: {sum(params) / 1e6} M") 512 | 513 | return model 514 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | from tqdm import tqdm 5 | from typing import Union 6 | from einops import rearrange 7 | from safetensors import safe_open 8 | from transformers import CLIPTextModel 9 | import torch 10 | import torchvision 11 | import torch.distributed as dist 12 | 13 | def zero_rank_print(s): 14 | if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) 15 | 16 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): 17 | videos = rearrange(videos, "b c t h w -> t b c h w") 18 | outputs = [] 19 | for x in videos: 20 | x = torchvision.utils.make_grid(x, nrow=n_rows) 21 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 22 | if rescale: 23 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 24 | x = (x * 255).numpy().astype(np.uint8) 25 | outputs.append(x) 26 | 27 | os.makedirs(os.path.dirname(path), exist_ok=True) 28 | imageio.mimsave(path, outputs, fps=fps) 29 | 30 | # DDIM Inversion 31 | @torch.no_grad() 32 | def init_prompt(prompt, pipeline): 33 | uncond_input = pipeline.tokenizer( 34 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, 35 | return_tensors="pt" 36 | ) 37 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] 38 | text_input = pipeline.tokenizer( 39 | [prompt], 40 | padding="max_length", 41 | max_length=pipeline.tokenizer.model_max_length, 42 | truncation=True, 43 | return_tensors="pt", 44 | ) 45 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] 46 | context = torch.cat([uncond_embeddings, text_embeddings]) 47 | 48 | return context 49 | 50 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, 51 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): 52 | timestep, next_timestep = min( 53 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep 54 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod 55 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] 56 | beta_prod_t = 1 - alpha_prod_t 57 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 58 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 59 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 60 | return next_sample 61 | 62 | def get_noise_pred_single(latents, t, context, unet): 63 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] 64 | return noise_pred 65 | 66 | @torch.no_grad() 67 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): 68 | context = init_prompt(prompt, pipeline) 69 | uncond_embeddings, cond_embeddings = context.chunk(2) 70 | all_latent = [latent] 71 | latent = latent.clone().detach() 72 | for i in tqdm(range(num_inv_steps)): 73 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] 74 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) 75 | latent = next_step(noise_pred, t, latent, ddim_scheduler) 76 | all_latent.append(latent) 77 | 
return all_latent 78 | 79 | @torch.no_grad() 80 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): 81 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) 82 | return ddim_latents 83 | 84 | def load_weights( 85 | magictime_pipeline, 86 | motion_module_path = "", 87 | dreambooth_model_path = "", 88 | magic_adapter_s_path = "", 89 | magic_adapter_t_path = "", 90 | magic_text_encoder_path = "", 91 | ): 92 | # motion module 93 | unet_state_dict = {} 94 | if motion_module_path != "": 95 | print(f"load motion module from {motion_module_path}") 96 | try: 97 | motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") 98 | if "state_dict" in motion_module_state_dict: 99 | motion_module_state_dict = motion_module_state_dict["state_dict"] 100 | for name, param in motion_module_state_dict.items(): 101 | if "motion_modules." in name: 102 | modified_name = name.removeprefix('module.') if name.startswith('module.') else name 103 | unet_state_dict[modified_name] = param 104 | except Exception as e: 105 | print(f"Error loading motion module: {e}") 106 | try: 107 | missing, unexpected = magictime_pipeline.unet.load_state_dict(unet_state_dict, strict=False) 108 | assert len(unexpected) == 0, f"Unexpected keys in state_dict: {unexpected}" 109 | del unet_state_dict 110 | except Exception as e: 111 | print(f"Error loading state dict into UNet: {e}") 112 | 113 | # base model 114 | if dreambooth_model_path != "": 115 | print(f"load dreambooth model from {dreambooth_model_path}") 116 | if dreambooth_model_path.endswith(".safetensors"): 117 | dreambooth_state_dict = {} 118 | with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f: 119 | for key in f.keys(): 120 | dreambooth_state_dict[key] = f.get_tensor(key) 121 | elif dreambooth_model_path.endswith(".ckpt"): 122 | dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu") 123 | 124 | # 1. vae 125 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, magictime_pipeline.vae.config) 126 | magictime_pipeline.vae.load_state_dict(converted_vae_checkpoint) 127 | # 2. unet 128 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, magictime_pipeline.unet.config) 129 | magictime_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) 130 | # 3. 
text_model 131 | magictime_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict) 132 | del dreambooth_state_dict 133 | 134 | # MagicAdapter and MagicTextEncoder 135 | if magic_adapter_s_path != "": 136 | print(f"load domain lora from {magic_adapter_s_path}") 137 | magic_adapter_s_state_dict = torch.load(magic_adapter_s_path, map_location="cpu") 138 | magictime_pipeline = load_diffusers_lora(magictime_pipeline, magic_adapter_s_state_dict, alpha=1.0) 139 | 140 | if magic_adapter_t_path != "" or magic_text_encoder_path != "": 141 | from swift import Swift 142 | 143 | if magic_adapter_t_path != "": 144 | print("load lora from swift for Unet") 145 | Swift.from_pretrained(magictime_pipeline.unet, magic_adapter_t_path) 146 | 147 | if magic_text_encoder_path != "": 148 | print("load lora from swift for text encoder") 149 | Swift.from_pretrained(magictime_pipeline.text_encoder, magic_text_encoder_path) 150 | 151 | return magictime_pipeline 152 | 153 | def load_diffusers_lora(pipeline, state_dict, alpha=1.0): 154 | # directly update weight in diffusers model 155 | for key in state_dict: 156 | # only process lora down key 157 | if "up." in key: continue 158 | 159 | up_key = key.replace(".down.", ".up.") 160 | model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") 161 | model_key = model_key.replace("to_out.", "to_out.0.") 162 | layer_infos = model_key.split(".")[:-1] 163 | 164 | curr_layer = pipeline.unet 165 | while len(layer_infos) > 0: 166 | temp_name = layer_infos.pop(0) 167 | curr_layer = curr_layer.__getattr__(temp_name) 168 | 169 | weight_down = state_dict[key] * 2 170 | weight_up = state_dict[up_key] * 2 171 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 172 | 173 | return pipeline 174 | 175 | def load_diffusers_lora_unet(unet, state_dict, alpha=1.0): 176 | # directly update weight in diffusers model 177 | for key in state_dict: 178 | # only process lora down key 179 | if "up." 
in key: continue 180 | 181 | up_key = key.replace(".down.", ".up.") 182 | model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") 183 | model_key = model_key.replace("to_out.", "to_out.0.") 184 | layer_infos = model_key.split(".")[:-1] 185 | 186 | curr_layer = unet 187 | while len(layer_infos) > 0: 188 | temp_name = layer_infos.pop(0) 189 | curr_layer = curr_layer.__getattr__(temp_name) 190 | 191 | weight_down = state_dict[key] * 2 192 | weight_up = state_dict[up_key] * 2 193 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 194 | 195 | return unet 196 | 197 | def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): 198 | visited = [] 199 | 200 | # directly update weight in diffusers model 201 | for key in state_dict: 202 | # it is suggested to print out the key, it usually will be something like below 203 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 204 | 205 | # as we have set the alpha beforehand, so just skip 206 | if ".alpha" in key or key in visited: 207 | continue 208 | 209 | if "text" in key: 210 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") 211 | curr_layer = pipeline.text_encoder 212 | else: 213 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") 214 | curr_layer = pipeline.unet 215 | 216 | # find the target layer 217 | temp_name = layer_infos.pop(0) 218 | while len(layer_infos) > -1: 219 | try: 220 | curr_layer = curr_layer.__getattr__(temp_name) 221 | if len(layer_infos) > 0: 222 | temp_name = layer_infos.pop(0) 223 | elif len(layer_infos) == 0: 224 | break 225 | except Exception: 226 | if len(temp_name) > 0: 227 | temp_name += "_" + layer_infos.pop(0) 228 | else: 229 | temp_name = layer_infos.pop(0) 230 | 231 | pair_keys = [] 232 | if "lora_down" in key: 233 | pair_keys.append(key.replace("lora_down", "lora_up")) 234 | pair_keys.append(key) 235 | else: 236 | pair_keys.append(key) 237 | pair_keys.append(key.replace("lora_up", "lora_down")) 238 | 239 | # update weight 240 | if len(state_dict[pair_keys[0]].shape) == 4: 241 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) 242 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) 243 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) 244 | else: 245 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 246 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 247 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 248 | 249 | # update visited list 250 | for item in pair_keys: 251 | visited.append(item) 252 | 253 | return pipeline 254 | 255 | def shave_segments(path, n_shave_prefix_segments=1): 256 | """ 257 | Removes segments. Positive values shave the first segments, negative shave the last segments. 
258 | """ 259 | if n_shave_prefix_segments >= 0: 260 | return ".".join(path.split(".")[n_shave_prefix_segments:]) 261 | else: 262 | return ".".join(path.split(".")[:n_shave_prefix_segments]) 263 | 264 | def renew_resnet_paths(old_list, n_shave_prefix_segments=0): 265 | """ 266 | Updates paths inside resnets to the new naming scheme (local renaming) 267 | """ 268 | mapping = [] 269 | for old_item in old_list: 270 | new_item = old_item.replace("in_layers.0", "norm1") 271 | new_item = new_item.replace("in_layers.2", "conv1") 272 | 273 | new_item = new_item.replace("out_layers.0", "norm2") 274 | new_item = new_item.replace("out_layers.3", "conv2") 275 | 276 | new_item = new_item.replace("emb_layers.1", "time_emb_proj") 277 | new_item = new_item.replace("skip_connection", "conv_shortcut") 278 | 279 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 280 | 281 | mapping.append({"old": old_item, "new": new_item}) 282 | 283 | return mapping 284 | 285 | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): 286 | """ 287 | Updates paths inside resnets to the new naming scheme (local renaming) 288 | """ 289 | mapping = [] 290 | for old_item in old_list: 291 | new_item = old_item 292 | 293 | new_item = new_item.replace("nin_shortcut", "conv_shortcut") 294 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 295 | 296 | mapping.append({"old": old_item, "new": new_item}) 297 | 298 | return mapping 299 | 300 | def renew_attention_paths(old_list, n_shave_prefix_segments=0): 301 | """ 302 | Updates paths inside attentions to the new naming scheme (local renaming) 303 | """ 304 | mapping = [] 305 | for old_item in old_list: 306 | new_item = old_item 307 | mapping.append({"old": old_item, "new": new_item}) 308 | return mapping 309 | 310 | def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): 311 | """ 312 | Updates paths inside attentions to the new naming scheme (local renaming) 313 | """ 314 | mapping = [] 315 | for old_item in old_list: 316 | new_item = old_item 317 | 318 | new_item = new_item.replace("norm.weight", "group_norm.weight") 319 | new_item = new_item.replace("norm.bias", "group_norm.bias") 320 | 321 | new_item = new_item.replace("q.weight", "query.weight") 322 | new_item = new_item.replace("q.bias", "query.bias") 323 | 324 | new_item = new_item.replace("k.weight", "key.weight") 325 | new_item = new_item.replace("k.bias", "key.bias") 326 | 327 | new_item = new_item.replace("v.weight", "value.weight") 328 | new_item = new_item.replace("v.bias", "value.bias") 329 | 330 | new_item = new_item.replace("proj_out.weight", "proj_attn.weight") 331 | new_item = new_item.replace("proj_out.bias", "proj_attn.bias") 332 | 333 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 334 | 335 | mapping.append({"old": old_item, "new": new_item}) 336 | 337 | return mapping 338 | 339 | def assign_to_checkpoint( 340 | paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None 341 | ): 342 | """ 343 | This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits 344 | attention layers, and takes into account additional replacements that may arise. 345 | 346 | Assigns the weights to the new checkpoint. 347 | """ 348 | assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." 349 | 350 | # Splits the attention layers into three variables. 
351 | if attention_paths_to_split is not None: 352 | for path, path_map in attention_paths_to_split.items(): 353 | old_tensor = old_checkpoint[path] 354 | channels = old_tensor.shape[0] // 3 355 | 356 | target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) 357 | 358 | num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 359 | 360 | old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) 361 | query, key, value = old_tensor.split(channels // num_heads, dim=1) 362 | 363 | checkpoint[path_map["query"]] = query.reshape(target_shape) 364 | checkpoint[path_map["key"]] = key.reshape(target_shape) 365 | checkpoint[path_map["value"]] = value.reshape(target_shape) 366 | 367 | for path in paths: 368 | new_path = path["new"] 369 | 370 | # These have already been assigned 371 | if attention_paths_to_split is not None and new_path in attention_paths_to_split: 372 | continue 373 | 374 | # Global renaming happens here 375 | new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") 376 | new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") 377 | new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") 378 | 379 | if additional_replacements is not None: 380 | for replacement in additional_replacements: 381 | new_path = new_path.replace(replacement["old"], replacement["new"]) 382 | 383 | # proj_attn.weight has to be converted from conv 1D to linear 384 | if "proj_attn.weight" in new_path: 385 | checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] 386 | else: 387 | checkpoint[new_path] = old_checkpoint[path["old"]] 388 | 389 | def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False): 390 | """ 391 | Takes a state dict and a config, and returns a converted checkpoint. 392 | """ 393 | 394 | # extract state_dict for UNet 395 | unet_state_dict = {} 396 | keys = list(checkpoint.keys()) 397 | 398 | unet_key = "model.diffusion_model." 399 | 400 | # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA 401 | if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: 402 | print(f"Checkpoint {path} has both EMA and non-EMA weights.") 403 | print( 404 | "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" 405 | " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." 406 | ) 407 | for key in keys: 408 | if key.startswith("model.diffusion_model"): 409 | flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) 410 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) 411 | else: 412 | if sum(k.startswith("model_ema") for k in keys) > 100: 413 | print( 414 | "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" 415 | " weights (usually better for inference), please make sure to add the `--extract_ema` flag." 
416 | ) 417 | 418 | for key in keys: 419 | if key.startswith(unet_key): 420 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) 421 | 422 | new_checkpoint = {} 423 | 424 | new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] 425 | new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] 426 | new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] 427 | new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] 428 | 429 | if config["class_embed_type"] is None: 430 | # No parameters to port 431 | ... 432 | elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": 433 | new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] 434 | new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] 435 | new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] 436 | new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] 437 | else: 438 | raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") 439 | 440 | new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] 441 | new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] 442 | new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] 443 | new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] 444 | new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] 445 | new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] 446 | 447 | # Retrieves the keys for the input blocks only 448 | num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) 449 | input_blocks = { 450 | layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] 451 | for layer_id in range(num_input_blocks) 452 | } 453 | 454 | # Retrieves the keys for the middle blocks only 455 | num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) 456 | middle_blocks = { 457 | layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] 458 | for layer_id in range(num_middle_blocks) 459 | } 460 | 461 | # Retrieves the keys for the output blocks only 462 | num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) 463 | output_blocks = { 464 | layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] 465 | for layer_id in range(num_output_blocks) 466 | } 467 | 468 | for i in range(1, num_input_blocks): 469 | block_id = (i - 1) // (config["layers_per_block"] + 1) 470 | layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) 471 | 472 | resnets = [ 473 | key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key 474 | ] 475 | attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] 476 | 477 | if f"input_blocks.{i}.0.op.weight" in unet_state_dict: 478 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( 479 | f"input_blocks.{i}.0.op.weight" 480 | ) 481 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( 482 | f"input_blocks.{i}.0.op.bias" 483 | ) 484 | 485 | paths = renew_resnet_paths(resnets) 486 | 
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} 487 | assign_to_checkpoint( 488 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 489 | ) 490 | 491 | if len(attentions): 492 | paths = renew_attention_paths(attentions) 493 | meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} 494 | assign_to_checkpoint( 495 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 496 | ) 497 | 498 | resnet_0 = middle_blocks[0] 499 | attentions = middle_blocks[1] 500 | resnet_1 = middle_blocks[2] 501 | 502 | resnet_0_paths = renew_resnet_paths(resnet_0) 503 | assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) 504 | 505 | resnet_1_paths = renew_resnet_paths(resnet_1) 506 | assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) 507 | 508 | attentions_paths = renew_attention_paths(attentions) 509 | meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} 510 | assign_to_checkpoint( 511 | attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 512 | ) 513 | 514 | for i in range(num_output_blocks): 515 | block_id = i // (config["layers_per_block"] + 1) 516 | layer_in_block_id = i % (config["layers_per_block"] + 1) 517 | output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] 518 | output_block_list = {} 519 | 520 | for layer in output_block_layers: 521 | layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) 522 | if layer_id in output_block_list: 523 | output_block_list[layer_id].append(layer_name) 524 | else: 525 | output_block_list[layer_id] = [layer_name] 526 | 527 | if len(output_block_list) > 1: 528 | resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] 529 | attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] 530 | 531 | resnet_0_paths = renew_resnet_paths(resnets) 532 | paths = renew_resnet_paths(resnets) 533 | 534 | meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} 535 | assign_to_checkpoint( 536 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 537 | ) 538 | 539 | output_block_list = {k: sorted(v) for k, v in output_block_list.items()} 540 | if ["conv.bias", "conv.weight"] in output_block_list.values(): 541 | index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) 542 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ 543 | f"output_blocks.{i}.{index}.conv.weight" 544 | ] 545 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ 546 | f"output_blocks.{i}.{index}.conv.bias" 547 | ] 548 | 549 | # Clear attentions as they have been attributed above. 
550 | if len(attentions) == 2: 551 | attentions = [] 552 | 553 | if len(attentions): 554 | paths = renew_attention_paths(attentions) 555 | meta_path = { 556 | "old": f"output_blocks.{i}.1", 557 | "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", 558 | } 559 | assign_to_checkpoint( 560 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 561 | ) 562 | else: 563 | resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) 564 | for path in resnet_0_paths: 565 | old_path = ".".join(["output_blocks", str(i), path["old"]]) 566 | new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) 567 | 568 | new_checkpoint[new_path] = unet_state_dict[old_path] 569 | 570 | return new_checkpoint 571 | 572 | def convert_ldm_clip_checkpoint(checkpoint): 573 | from transformers import CLIPTextModel 574 | text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") 575 | 576 | keys = list(checkpoint.keys()) 577 | keys.remove("cond_stage_model.transformer.text_model.embeddings.position_ids") 578 | 579 | text_model_dict = {} 580 | 581 | for key in keys: 582 | if key.startswith("cond_stage_model.transformer"): 583 | text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] 584 | text_model.load_state_dict(text_model_dict) 585 | 586 | return text_model 587 | 588 | def convert_ldm_clip_text_model(text_model, checkpoint): 589 | keys = list(checkpoint.keys()) 590 | keys.remove("cond_stage_model.transformer.text_model.embeddings.position_ids") 591 | 592 | text_model_dict = {} 593 | 594 | for key in keys: 595 | if key.startswith("cond_stage_model.transformer"): 596 | text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] 597 | text_model.load_state_dict(text_model_dict) 598 | 599 | return text_model 600 | 601 | def conv_attn_to_linear(checkpoint): 602 | keys = list(checkpoint.keys()) 603 | attn_keys = ["query.weight", "key.weight", "value.weight"] 604 | for key in keys: 605 | if ".".join(key.split(".")[-2:]) in attn_keys: 606 | if checkpoint[key].ndim > 2: 607 | checkpoint[key] = checkpoint[key][:, :, 0, 0] 608 | elif "proj_attn.weight" in key: 609 | if checkpoint[key].ndim > 2: 610 | checkpoint[key] = checkpoint[key][:, :, 0] 611 | 612 | def convert_ldm_vae_checkpoint(checkpoint, config): 613 | # extract state dict for VAE 614 | vae_state_dict = {} 615 | vae_key = "first_stage_model." 
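# (LDM checkpoints keep the VAE under the "first_stage_model." prefix; it is stripped below so
#  the remaining keys can be renamed to the diffusers autoencoder layout.)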
616 | keys = list(checkpoint.keys()) 617 | for key in keys: 618 | if key.startswith(vae_key): 619 | vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) 620 | 621 | new_checkpoint = {} 622 | 623 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] 624 | new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] 625 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] 626 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] 627 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] 628 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] 629 | 630 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] 631 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] 632 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] 633 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] 634 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] 635 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] 636 | 637 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] 638 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] 639 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] 640 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] 641 | 642 | # Retrieves the keys for the encoder down blocks only 643 | num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) 644 | down_blocks = { 645 | layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) 646 | } 647 | 648 | # Retrieves the keys for the decoder up blocks only 649 | num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) 650 | up_blocks = { 651 | layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) 652 | } 653 | 654 | for i in range(num_down_blocks): 655 | resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] 656 | 657 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: 658 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( 659 | f"encoder.down.{i}.downsample.conv.weight" 660 | ) 661 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( 662 | f"encoder.down.{i}.downsample.conv.bias" 663 | ) 664 | 665 | paths = renew_vae_resnet_paths(resnets) 666 | meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} 667 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 668 | 669 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] 670 | num_mid_res_blocks = 2 671 | for i in range(1, num_mid_res_blocks + 1): 672 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] 673 | 674 | paths = renew_vae_resnet_paths(resnets) 675 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} 676 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], 
config=config) 677 | 678 | mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] 679 | paths = renew_vae_attention_paths(mid_attentions) 680 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 681 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 682 | conv_attn_to_linear(new_checkpoint) 683 | 684 | for i in range(num_up_blocks): 685 | block_id = num_up_blocks - 1 - i 686 | resnets = [ 687 | key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key 688 | ] 689 | 690 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: 691 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ 692 | f"decoder.up.{block_id}.upsample.conv.weight" 693 | ] 694 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ 695 | f"decoder.up.{block_id}.upsample.conv.bias" 696 | ] 697 | 698 | paths = renew_vae_resnet_paths(resnets) 699 | meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} 700 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 701 | 702 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] 703 | num_mid_res_blocks = 2 704 | for i in range(1, num_mid_res_blocks + 1): 705 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] 706 | 707 | paths = renew_vae_resnet_paths(resnets) 708 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} 709 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 710 | 711 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] 712 | paths = renew_vae_attention_paths(mid_attentions) 713 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 714 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 715 | conv_attn_to_linear(new_checkpoint) 716 | 717 | return new_checkpoint --------------------------------------------------------------------------------
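A minimal usage sketch (not part of the repository files above) showing how these converters might be wired together to load an LDM/DreamBooth-style Stable Diffusion checkpoint into a diffusers-style pipeline. The checkpoint paths and the `pipeline` object are placeholder assumptions; the repository's own scripts (e.g. inference_magictime.py) are the authoritative reference for how these functions are actually invoked.

from safetensors.torch import load_file

# Load a full LDM-format checkpoint (UNet, VAE, and text encoder share one state dict).
ldm_state_dict = load_file("ckpts/DreamBooth/your_dreambooth.safetensors", device="cpu")

# Rename UNet and VAE weights to the diffusers layout, then load them into the pipeline.
unet_sd = convert_ldm_unet_checkpoint(ldm_state_dict, pipeline.unet.config)
pipeline.unet.load_state_dict(unet_sd, strict=False)  # strict=False: motion-module params are not in the 2D checkpoint

vae_sd = convert_ldm_vae_checkpoint(ldm_state_dict, pipeline.vae.config)
pipeline.vae.load_state_dict(vae_sd)

# The CLIP text encoder can be rebuilt directly from the same checkpoint.
pipeline.text_encoder = convert_ldm_clip_checkpoint(ldm_state_dict)

# Optionally merge LoRA weights into the pipeline at strength alpha.
# lora_sd = load_file("path/to/lora.safetensors", device="cpu")
# pipeline = convert_lora(pipeline, lora_sd, alpha=0.6)

`strict=False` is used for the UNet because a video UNet of this kind usually carries extra motion-module parameters that a 2D image checkpoint does not provide; parameters that do match by name are still loaded as usual.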