├── LICENSE
├── README.md
├── __assets__
├── magictime_logo.png
├── promtp_opensora.txt
└── promtp_unet.txt
├── app.py
├── ckpts
├── Base_Model
│ ├── base_model_path.txt
│ ├── motion_module
│ │ └── motion_module_path.txt
│ └── stable-diffusion-v1-5
│ │ └── sd_15_path.txt
├── DreamBooth
│ └── dreambooth_path.txt
└── Magic_Weights
│ └── magic_weights_path.txt
├── data_preprocess
├── README.md
├── run.sh
├── step0_extract_frame_resize.py
├── step2_1_GPT4V_frame_caption.py
├── step2_2_preprocess_frame_caption.py
├── step3_1_GPT4V_video_caption_concise.py
├── step3_1_GPT4V_video_caption_detail.py
├── step3_2_preprocess_video_caption.py
└── step4_1_create_webvid_format.py
├── inference.sh
├── inference_cli.sh
├── inference_magictime.py
├── requirements.txt
├── sample_configs
├── RcnzCartoon.yaml
├── RealisticVision.yaml
└── ToonYou.yaml
└── utils
├── dataset.py
├── pipeline_magictime.py
├── unet.py
├── unet_blocks.py
└── util.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |

3 |
4 |
7 | If you like our project, please give us a star ⭐ on GitHub for the latest updates.
8 |
9 |
10 |
11 |
12 | [](https://huggingface.co/spaces/BestWishYsh/MagicTime?logs=build)
13 | [](https://replicate.com/camenduru/magictime)
14 | [](https://colab.research.google.com/github/camenduru/MagicTime-jupyter/blob/main/MagicTime_jupyter.ipynb)
15 | [](https://huggingface.co/papers/2404.05014)
16 | [](https://arxiv.org/abs/2404.05014)
17 | [](https://pku-yuangroup.github.io/MagicTime/)
18 | [](https://huggingface.co/datasets/BestWishYsh/ChronoMagic)
19 | [](https://twitter.com/_akhaliq/status/1777538468043792473)
20 | [](https://twitter.com/vhjf36495872/status/1777525817087553827?s=61&t=r2HzCsU2AnJKbR8yKSprKw)
21 | [](https://zenodo.org/doi/10.5281/zenodo.10960665)
22 | [](https://github.com/PKU-YuanGroup/MagicTime/blob/main/LICENSE)
23 | [](https://github.com/PKU-YuanGroup/MagicTime)
24 |
25 |
26 |
27 |
28 | This repository is the official implementation of MagicTime, a metamorphic video generation pipeline based on the given prompts. The main idea is to enhance the capacity of video generation models to accurately depict the real world through our proposed methods and dataset.
29 |
30 |
31 |
32 |
33 | 💡 We also have other video generation projects that may interest you ✨.
34 |
35 |
36 |
37 | > [**Open-Sora Plan: Open-Source Large Video Generation Model**](https://arxiv.org/abs/2412.00131)
38 | > Bin Lin, Yunyang Ge and Xinhua Cheng etc.
39 | [](https://github.com/PKU-YuanGroup/Open-Sora-Plan) [](https://github.com/PKU-YuanGroup/Open-Sora-Plan) [](https://arxiv.org/abs/2412.00131)
40 | >
41 | > [**OpenS2V-Nexus: A Detailed Benchmark and Million-Scale Dataset for Subject-to-Video Generation**](https://arxiv.org/abs/2505.20292)
42 | > Shenghai Yuan, Xianyi He and Yufan Deng etc.
43 | > [](https://github.com/PKU-YuanGroup/OpenS2V-Nexus) [](https://github.com/PKU-YuanGroup/OpenS2V-Nexus) [](https://arxiv.org/abs/2505.20292)
44 | >
45 | > [**ConsisID: Identity-Preserving Text-to-Video Generation by Frequency Decomposition**](https://arxiv.org/abs/2411.17440)
46 | > Shenghai Yuan, Jinfa Huang and Xianyi He etc.
47 | > [](https://github.com/PKU-YuanGroup/ConsisID/) [](https://github.com/PKU-YuanGroup/ConsisID/) [](https://arxiv.org/abs/2411.17440)
48 | >
49 | > [**ChronoMagic-Bench: A Benchmark for Metamorphic Evaluation of Text-to-Time-lapse Video Generation**](https://arxiv.org/abs/2406.18522)
50 | > Shenghai Yuan, Jinfa Huang and Yongqi Xu etc.
51 | > [](https://github.com/PKU-YuanGroup/ChronoMagic-Bench/) [](https://github.com/PKU-YuanGroup/ChronoMagic-Bench/) [](https://arxiv.org/abs/2406.18522)
52 | >
53 |
54 | ## 📣 News
55 | * ⏳⏳⏳ Training a stronger model with the support of [Open-Sora Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan).
56 | * ⏳⏳⏳ Release the training code of MagicTime.
57 | * `[2025.04.08]` 🔥 We have updated our technical report. Please click [here](https://arxiv.org/abs/2404.05014) to view it.
58 | * `[2025.03.28]` 🔥 MagicTime has been accepted by **TPAMI**, and we will update arXiv with more details soon. Stay tuned!
59 | * `[2024.07.29]` We added *batch inference* to [inference_magictime.py](https://github.com/PKU-YuanGroup/MagicTime/blob/main/inference_magictime.py) for easier usage.
60 | * `[2024.06.27]` Excited to share our latest [ChronoMagic-Bench](https://github.com/PKU-YuanGroup/ChronoMagic-Bench), a benchmark for metamorphic evaluation of text-to-time-lapse video generation, which is fully open source! Please check out the [paper](https://arxiv.org/abs/2406.18522).
61 | * `[2024.05.27]` Excited to share our latest Open-Sora Plan v1.1.0, which significantly improves video quality and length, and is fully open source! Please check out the [report](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.1.0.md).
62 | * `[2024.04.14]` Thanks [@camenduru](https://twitter.com/camenduru) and [@ModelsLab](https://modelslab.com/) for providing [Jupyter Notebook](https://github.com/camenduru/MagicTime-jupyter) and [Replicate Demo](https://replicate.com/camenduru/magictime).
63 | * `[2024.04.13]` 🔥 We have reduced the repo size to less than 1.0 MB, so that everyone can clone it more easily and quickly. You can click [here](https://github.com/PKU-YuanGroup/MagicTime/archive/refs/heads/main.zip) to download it, or use the `git clone --depth=1` command to obtain this repo.
64 | * `[2024.04.12]` Thanks [@Kijai](https://github.com/kijai) and [@Baobao Wang](https://www.bilibili.com/video/BV1wx421U7Gn/?spm_id_from=333.1007.top_right_bar_window_history.content.click) for providing ComfyUI Extension [ComfyUI-MagicTimeWrapper](https://github.com/kijai/ComfyUI-MagicTimeWrapper). If you find related work, please let us know.
65 | * `[2024.04.11]` 🔥 We release the Hugging Face Space of MagicTime, you can click [here](https://huggingface.co/spaces/BestWishYsh/MagicTime?logs=build) to have a try.
66 | * `[2024.04.10]` 🔥 We release the inference code and model weight of MagicTime.
67 | * `[2024.04.09]` 🔥 We release the arXiv paper for MagicTime, and you can click [here](https://arxiv.org/abs/2404.05014) to see more details.
68 | * `[2024.04.08]` 🔥 We release the subset of ChronoMagic dataset used to train MagicTime. The dataset includes 2,265 metamorphic video-text pairs and can be downloaded at [HuggingFace Dataset](https://huggingface.co/datasets/BestWishYsh/ChronoMagic) or [Google Drive](https://drive.google.com/drive/folders/1WsomdkmSp3ql3ImcNsmzFuSQ9Qukuyr8?usp=sharing).
69 | * `[2024.04.08]` 🔥 **All code & datasets** are coming soon! Stay tuned 👀!
70 |
71 | ## 😮 Highlights
72 |
73 | MagicTime shows excellent performance in **metamorphic video generation**.
74 |
75 | ### Related Resources
76 | * [ChronoMagic](https://huggingface.co/datasets/BestWishYsh/ChronoMagic): including 2265 time-lapse video-text pairs. (captioned by GPT-4V)
77 | * [ChronoMagic-Bench](https://huggingface.co/datasets/BestWishYsh/ChronoMagic-Bench/tree/main): including 1649 time-lapse video-text pairs. (captioned by GPT-4o)
78 | * [ChronoMagic-Bench-150](https://huggingface.co/datasets/BestWishYsh/ChronoMagic-Bench/tree/main): including 150 time-lapse video-text pairs. (captioned by GPT-4o)
79 | * [ChronoMagic-Pro](https://huggingface.co/datasets/BestWishYsh/ChronoMagic-Pro): including 460K time-lapse video-text pairs. (captioned by ShareGPT4Video)
80 | * [ChronoMagic-ProH](https://huggingface.co/datasets/BestWishYsh/ChronoMagic-ProH): including 150K time-lapse video-text pairs. (captioned by ShareGPT4Video)
81 |
82 | ### Metamorphic Videos vs. General Videos
83 |
84 | Compared to general videos, metamorphic videos contain physical knowledge, long persistence, and strong variation, making them difficult to generate. We show compressed GIFs on GitHub, which lose some quality. The general videos are generated by [Animatediff](https://github.com/guoyww/AnimateDiff) and **MagicTime**.
85 |
86 |
87 |
88 | Type |
89 | "Bean sprouts grow and mature from seeds" |
90 | "[...] construction in a Minecraft virtual environment" |
91 | "Cupcakes baking in an oven [...]" |
92 | "[...] transitioning from a tightly closed bud to a fully bloomed state [...]" |
93 |
94 |
95 | General Videos |
96 |  |
97 |  |
98 |  |
99 |  |
100 |
101 |
102 | Metamorphic Videos |
103 |  |
104 |  |
105 |  |
106 |  |
107 |
108 |
109 |
110 | ### Gallery
111 |
112 | We showcase some metamorphic videos generated by **MagicTime**, [MakeLongVideo](https://github.com/xuduo35/MakeLongVideo), [ModelScopeT2V](https://github.com/modelscope), [VideoCrafter](https://github.com/AILab-CVC/VideoCrafter?tab=readme-ov-file), [ZeroScope](https://huggingface.co/cerspense/zeroscope_v2_576w), [LaVie](https://github.com/Vchitect/LaVie), [T2V-Zero](https://github.com/Picsart-AI-Research/Text2Video-Zero), [Latte](https://github.com/Vchitect/Latte) and [Animatediff](https://github.com/guoyww/AnimateDiff) below.
113 |
114 |
115 |
116 | Method |
117 | "cherry blossoms transitioning [...]" |
118 | "dough balls baking process [...]" |
119 | "an ice cube is melting [...]" |
120 | "a simple modern house's construction [...]" |
121 |
122 |
123 | MakeLongVideo |
124 |  |
125 |  |
126 |  |
127 |  |
128 |
129 |
130 | ModelScopeT2V |
131 |  |
132 |  |
133 |  |
134 |  |
135 |
136 |
137 | VideoCrafter |
138 |  |
139 |  |
140 |  |
141 |  |
142 |
143 |
144 | ZeroScope |
145 |  |
146 |  |
147 |  |
148 |  |
149 |
150 |
151 | LaVie |
152 |  |
153 |  |
154 |  |
155 |  |
156 |
157 |
158 | T2V-Zero |
159 |  |
160 |  |
161 |  |
162 |  |
163 |
164 |
165 | Latte |
166 |  |
167 |  |
168 |  |
169 |  |
170 |
171 |
172 | Animatediff |
173 |  |
174 |  |
175 |  |
176 |  |
177 |
178 |
179 | Ours |
180 |  |
181 |  |
182 |  |
183 |  |
184 |
185 |
186 |
187 |
188 | We show more metamorphic videos generated by **MagicTime** with the help of [Realistic](https://civitai.com/models/4201/realistic-vision-v20), [ToonYou](https://civitai.com/models/30240/toonyou) and [RcnzCartoon](https://civitai.com/models/66347/rcnz-cartoon-3d).
189 |
190 |
191 |
192 |  |
193 |  |
194 |  |
195 |
196 |
197 | "[...] bean sprouts grow and mature from seeds" |
198 | "dough [...] swells and browns in the oven [...]" |
199 | "the construction [...] in Minecraft [...]" |
200 |
201 |
202 |  |
203 |  |
204 |  |
205 |
206 |
207 | "a bud transforms into a yellow flower" |
208 | "time-lapse of a plant germinating [...]" |
209 | "[...] a modern house being constructed in Minecraft [...]" |
210 |
211 |
212 |  |
213 |  |
214 |  |
215 |
216 |
217 | "an ice cube is melting" |
218 | "bean plant sprouts grow and mature from the soil" |
219 | "time-lapse of delicate pink plum blossoms [...]" |
220 |
221 |
222 |
223 | Prompts are trimmed for display, see [here](https://github.com/PKU-YuanGroup/MagicTime/blob/main/__assets__/promtp_unet.txt) for full prompts.
224 | ### Integrate into DiT-based Architecture
225 |
226 | The mission of this project is to help reproduce Sora and provide high-quality video-text data and data annotation pipelines, to support [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) or other DiT-based T2V models. To this end, we take an initial step to integrate our MagicTime scheme into a DiT-based framework. Specifically, our method supports fine-tuning Open-Sora-Plan v1.0.0. We first scale up with additional metamorphic landscape time-lapse videos in the same annotation framework to get the ChronoMagic-Landscape dataset. Then, we fine-tune Open-Sora-Plan v1.0.0 with the ChronoMagic-Landscape dataset to get the MagicTime-DiT model. The results are as follows (**257×512×512 (10s)**):
227 |
228 |
229 |
230 |
231 |
232 | |
233 |
234 |
235 | |
236 |
237 |
238 | |
239 |
240 |
241 | |
242 |
243 |
244 | "Time-lapse of a coastal landscape [...]" |
245 | "Display the serene beauty of twilight [...]" |
246 | "Sunrise Splendor: Capture the breathtaking moment [...]" |
247 | "Nightfall Elegance: Embrace the tranquil beauty [...]" |
248 |
249 |
250 |
251 |
252 | |
253 |
254 |
255 | |
256 |
257 |
258 | |
259 |
260 |
261 | |
262 |
263 |
264 | "The sun descending below the horizon [...]" |
265 | "[...] daylight fades into the embrace of the night [...]" |
266 | "Time-lapse of the dynamic formations of clouds [...]" |
267 | "Capture the dynamic formations of clouds [...]" |
268 |
269 |
270 |
271 | Prompts are trimmed for display, see [here](https://github.com/PKU-YuanGroup/MagicTime/blob/main/__assets__/promtp_opensora.txt) for full prompts.
272 |
273 | ## 🤗 Demo
274 |
275 | ### Gradio Web UI
276 |
277 | We highly recommend trying out our web demo with the following command, which incorporates all features currently supported by MagicTime. We also provide an [online demo](https://huggingface.co/spaces/BestWishYsh/MagicTime?logs=build) in Hugging Face Spaces.
278 |
279 | ```bash
280 | python app.py
281 | ```
282 |
283 | ### CLI Inference
284 |
285 | ```bash
286 | # For Realistic
287 | python inference_magictime.py --config sample_configs/RealisticVision.yaml --human
288 |
289 | # or you can directly run the .sh
290 | sh inference_cli.sh
291 | ```
292 |
293 | **Warning:** even with the same seed and prompt, results may differ when generated on a different machine.
294 |
295 | ## ⚙️ Requirements and Installation
296 |
297 | We recommend the requirements as follows.
298 |
299 | ### Environment
300 |
301 | ```bash
302 | git clone --depth=1 https://github.com/PKU-YuanGroup/MagicTime.git
303 | cd MagicTime
304 | conda create -n magictime python=3.10.13
305 | conda activate magictime
306 | pip install -r requirements.txt
307 | ```
308 |
309 | ### Download MagicTime
310 |
311 | The weights are available at [🤗HuggingFace](https://huggingface.co/BestWishYsh/MagicTime/tree/main) and [🟣WiseModel](https://wisemodel.cn/models/SHYuanBest/MagicTime/file), or you can download them with the following commands.
312 |
313 | ```bash
314 | # way 1
315 | # if you are in mainland China, run this first: export HF_ENDPOINT=https://hf-mirror.com
316 | huggingface-cli download --repo-type model \
317 | BestWishYsh/MagicTime \
318 | --local-dir ckpts
319 |
320 | # way 2
321 | git lfs install
322 | git clone https://www.wisemodel.cn/SHYuanBest/MagicTime.git
323 | ```
324 |
325 | Once ready, the weights will be organized in this format:
326 |
327 | ```
328 | 📦 ckpts/
329 | ├── 📂 Base_Model/
330 | │ ├── 📂 motion_module/
331 | │ ├── 📂 stable-diffusion-v1-5/
332 | ├── 📂 DreamBooth/
333 | ├── 📂 Magic_Weights/
334 | │ ├── 📂 magic_adapter_s/
335 | │ ├── 📂 magic_adapter_t/
336 | │ ├── 📂 magic_text_encoder/
337 | ```
338 |
339 | ## 🗝️ Training & Inference
340 |
341 | The training code is coming soon!
342 |
343 | For inference, some examples are shown below:
344 |
345 | ```bash
346 | # For Realistic
347 | python inference_magictime.py --config sample_configs/RealisticVision.yaml
348 | # For ToonYou
349 | python inference_magictime.py --config sample_configs/ToonYou.yaml
350 | # For RcnzCartoon
351 | python inference_magictime.py --config sample_configs/RcnzCartoon.yaml
352 | # or you can directly run the .sh
353 | sh inference.sh
354 | ```
355 |
356 | You can also put all your *custom prompts* in a .txt file and run:
357 |
358 | ```bash
359 | # For Realistic
360 | python inference_magictime.py --config sample_configs/RealisticVision.yaml --run-txt XXX.txt --batch-size 2
361 | # For ToonYou
362 | python inference_magictime.py --config sample_configs/ToonYou.yaml --run-txt XXX.txt --batch-size 2
363 | # For RcnzCartoon
364 | python inference_magictime.py --config sample_configs/RcnzCartoon.yaml --run-txt XXX.txt --batch-size 2
365 | ```
366 |
367 | ## Community Contributions
368 |
369 | We found some plugins created by community developers. Thanks for their efforts:
370 |
371 | - ComfyUI Extension. [ComfyUI-MagicTimeWrapper](https://github.com/kijai/ComfyUI-MagicTimeWrapper) (by [@Kijai](https://github.com/kijai)). And you can click [here](https://www.bilibili.com/video/BV1wx421U7Gn/?spm_id_from=333.1007.top_right_bar_window_history.content.click) to view the installation tutorial.
372 | - Replicate Demo & Cloud API. [Replicate-MagicTime](https://replicate.com/camenduru/magictime) (by [@camenduru](https://twitter.com/camenduru)).
373 | - Jupyter Notebook. [Jupyter-MagicTime](https://github.com/camenduru/MagicTime-jupyter) (by [@ModelsLab](https://modelslab.com/)).
374 |
375 | If you find related work, please let us know.
376 |
377 | ## 🐳 ChronoMagic Dataset
378 | ChronoMagic contains 2,265 metamorphic time-lapse videos, each accompanied by a detailed caption. We released the subset of ChronoMagic used to train MagicTime. The dataset can be downloaded at [HuggingFace Dataset](https://huggingface.co/datasets/BestWishYsh/ChronoMagic), or you can download it with the following command. Some samples can be found on our [Project Page](https://pku-yuangroup.github.io/MagicTime/).
379 | ```bash
380 | huggingface-cli download --repo-type dataset \
381 | --resume-download BestWishYsh/ChronoMagic \
382 | --local-dir BestWishYsh/ChronoMagic \
383 | --local-dir-use-symlinks False
384 | ```
385 |
386 | ## 👍 Acknowledgement
387 | * [Animatediff](https://github.com/guoyww/AnimateDiff/tree/main): the codebase we built upon; it is a strong U-Net-based text-to-video generation model.
388 |
389 | * [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan): the codebase we built upon; it is a simple and scalable DiT-based text-to-video generation repo, aiming to reproduce [Sora](https://openai.com/sora).
390 |
391 | ## 🔒 License
392 | * The majority of this project is released under the Apache 2.0 license as found in the [LICENSE](https://github.com/PKU-YuanGroup/MagicTime/blob/main/LICENSE) file.
393 | * The service is a research preview. Please contact us if you find any potential violations.
394 |
395 | ## ✏️ Citation
396 | If you find our paper and code useful in your research, please consider giving a star :star: and citation :pencil:.
397 |
398 | ```BibTeX
399 | @article{yuan2025magictime,
400 | title={Magictime: Time-lapse video generation models as metamorphic simulators},
401 | author={Yuan, Shenghai and Huang, Jinfa and Shi, Yujun and Xu, Yongqi and Zhu, Ruijie and Lin, Bin and Cheng, Xinhua and Yuan, Li and Luo, Jiebo},
402 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
403 | year={2025},
404 | publisher={IEEE}
405 | }
406 | ```
407 |
408 | ## 🤝 Contributors
409 |
410 |
411 |
412 |
413 |
414 |
--------------------------------------------------------------------------------
/__assets__/magictime_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-YuanGroup/MagicTime/33307bc0ce0ef2d15ccef2ed623c99bae63bf35c/__assets__/magictime_logo.png
--------------------------------------------------------------------------------
/__assets__/promtp_opensora.txt:
--------------------------------------------------------------------------------
1 | 1. Time-lapse of a coastal landscape transitioning from sunrise to nightfall, with early morning light and soft shadows giving way to a clearer, bright midday sky, and later visible signs of sunset with orange hues and a dimming sky, culminating in a vibrant dusk.
2 | 2. Display the serene beauty of twilight, marking the transition from day to night with subtle changes in lighting.
3 | 3. Sunrise Splendor: Capture the breathtaking moment as the sun peeks over the horizon, casting its warm hues across the landscape in a mesmerizing time-lapse.
4 | 4. Nightfall Elegance: Embrace the tranquil beauty of dusk as daylight fades into the embrace of the night, unveiling the twinkling stars against the darkening sky in a mesmerizing time-lapse spectacle.
5 | 5. The sun descending below the horizon at dusk. The video is a time-lapse showcasing the gradual dimming of daylight, leading to the onset of twilight.
6 | 6. Nightfall Elegance: Embrace the tranquil beauty of dusk as daylight fades into the embrace of the night, unveiling the twinkling stars against the darkening sky in a mesmerizing time-lapse spectacle.
7 | 7. Time-lapse of the dynamic formations of clouds, showcasing their continuous motion and evolution over the course of the video.
8 | 8. Capture the dynamic formations of clouds, showcasing their continuous motion and evolution over the course of the video.
--------------------------------------------------------------------------------
/__assets__/promtp_unet.txt:
--------------------------------------------------------------------------------
1 | 1. A time-lapse video of bean sprouts growing and maturing from seeds.
2 | 2. Dough starts smooth, swells and browns in the oven, finishing as fully expanded, baked bread.
3 | 3. The construction of a simple modern house in Minecraft. As the construction progresses, the roof and walls are completed, and the area around the house is cleared and shaped.
4 | 4. A bud transforms into a yellow flower.
5 | 5. Time-lapse of a plant germinating and developing into a young plant with multiple true leaves in a container, showing progressive growth stages from bare soil to a full plant.
6 | 6. Time-lapse of a modern house being constructed in Minecraft, beginning with a basic structure and progressively adding roof details, and new sections.
7 | 7. An ice cube is melting.
8 | 8. Bean plant sprouts grow and mature from the soil.
9 | 9. Time-lapse of delicate pink plum blossoms transitioning from tightly closed buds to gently unfurling petals, revealing the intricate details of stamens and pistils within.
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import time
4 | import torch
5 | import random
6 | import gradio as gr
7 | from glob import glob
8 | from omegaconf import OmegaConf
9 | from safetensors import safe_open
10 | from diffusers import AutoencoderKL
11 | from diffusers import DDIMScheduler
12 | from diffusers.utils.import_utils import is_xformers_available
13 | from transformers import CLIPTextModel, CLIPTokenizer
14 |
15 | from utils.unet import UNet3DConditionModel
16 | from utils.pipeline_magictime import MagicTimePipeline
17 | from utils.util import save_videos_grid, convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint, load_diffusers_lora_unet, convert_ldm_clip_text_model
18 | # import spaces
19 |
20 | from huggingface_hub import snapshot_download
21 |
model_path = "ckpts"

pretrained_model_path = f"{model_path}/Base_Model/stable-diffusion-v1-5"
inference_config_path = "sample_configs/RealisticVision.yaml"
magic_adapter_s_path = f"{model_path}/Magic_Weights/magic_adapter_s/magic_adapter_s.ckpt"
magic_adapter_t_path = f"{model_path}/Magic_Weights/magic_adapter_t"
magic_text_encoder_path = f"{model_path}/Magic_Weights/magic_text_encoder"


def _checkpoints_present():
    """Return True when every checkpoint this app loads is already on disk.

    The repository ships ``ckpts/`` with placeholder ``*_path.txt`` files, so
    the existence of a directory alone proves nothing; check for the actual
    weights instead: the SD-1.5 ``unet`` subfolder, the Magic adapter files,
    and at least one motion module (``*.ckpt``) and one DreamBooth model
    (``*.safetensors``), since the controller picks the first of each.
    """
    required = [
        os.path.join(pretrained_model_path, "unet"),
        magic_adapter_s_path,
        magic_adapter_t_path,
        magic_text_encoder_path,
    ]
    if not all(os.path.exists(path) for path in required):
        return False
    has_motion_module = bool(glob(os.path.join(model_path, "Base_Model", "motion_module", "*.ckpt")))
    has_dreambooth = bool(glob(os.path.join(model_path, "DreamBooth", "*.safetensors")))
    return has_motion_module and has_dreambooth


if not _checkpoints_present():
    print("Model not found, downloading from Hugging Face...")
    snapshot_download(repo_id="BestWishYsh/MagicTime", local_dir=f"{model_path}")
else:
    print(f"Model already exists in {model_path}, skipping download.")

css = """
.toolbutton {
    margin-bottom: 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""

examples = [
    # 1-RealisticVision
    [
        "RealisticVisionV60B1_v51VAE.safetensors",
        "motion_module.ckpt",
        "Cherry blossoms transitioning from tightly closed buds to a peak state of bloom. The progression moves through stages of bud swelling, petal exposure, and gradual opening, culminating in a full and vibrant display of open blossoms.",
        "worst quality, low quality, letterboxed",
        512, 512, "1534851746"
    ],
    # 2-RCNZ
    [
        "RcnzCartoon.safetensors",
        "motion_module.ckpt",
        "Time-lapse of a simple modern house's construction in a Minecraft virtual environment: beginning with an avatar laying a white foundation, progressing through wall erection and interior furnishing, to adding roof and exterior details, and completed with landscaping and a tall chimney.",
        "worst quality, low quality, letterboxed",
        512, 512, "3480796026"
    ],
    # 3-ToonYou
    [
        "ToonYou_beta6.safetensors",
        "motion_module.ckpt",
        "Bean sprouts grow and mature from seeds.",
        "worst quality, low quality, letterboxed",
        512, 512, "1496541313"
    ]
]

# Clean the Gradio example cache so examples are regenerated with the
# current models. shutil.rmtree is portable (no shell) and tolerates a
# missing directory.
import shutil

print("### Cleaning cached examples ...")
shutil.rmtree("gradio_cached_examples", ignore_errors=True)

device = "cuda"
77 |
def random_seed():
    """Draw a fresh random seed, an int in the inclusive range [1, 10**16]."""
    upper_bound = 10 ** 16
    return random.randint(1, upper_bound)
80 |
class MagicTimeController:
    """Owns the MagicTime models and serves text-to-video requests for the UI.

    Two UNets are kept:

    * ``self.unet_model`` -- a template built from the SD-1.5 UNet with the
      selected motion module loaded into it; it is not moved to ``device``.
    * ``self.unet`` -- the working copy: a deep copy of the template with the
      selected DreamBooth weights, the ``magic_adapter_s`` LoRA and the
      ``magic_adapter_t`` Swift adapter applied.

    Switching the DreamBooth model or the motion module rebuilds
    ``self.unet`` from the template via ``update_dreambooth``.
    """

    def __init__(self):
        # Model directories, resolved against the current working directory.
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(self.basedir, model_path, "Base_Model")
        self.motion_module_dir = os.path.join(self.basedir, model_path, "Base_Model", "motion_module")
        self.personalized_model_dir = os.path.join(self.basedir, model_path, "DreamBooth")
        self.savedir = os.path.join(self.basedir, "outputs")
        os.makedirs(self.savedir, exist_ok=True)

        self.dreambooth_list = []
        self.motion_module_list = []

        self.selected_dreambooth = None
        self.selected_motion_module = None

        self.refresh_motion_module()
        self.refresh_personalized_model()
        # Fail early with a clear message instead of an IndexError below.
        if not self.motion_module_list:
            raise FileNotFoundError(f"No motion module (*.ckpt) found in {self.motion_module_dir}")
        if not self.dreambooth_list:
            raise FileNotFoundError(f"No DreamBooth model (*.safetensors) found in {self.personalized_model_dir}")

        # config models
        self.inference_config = OmegaConf.load(inference_config_path)[1]

        self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
        self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
        self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).to(device)
        # Pristine templates; update_dreambooth deep-copies them so every
        # DreamBooth switch starts from clean weights.
        self.text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        self.unet_model = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs))

        self.update_motion_module(self.motion_module_list[0])
        self.update_motion_module_2(self.motion_module_list[0])
        self.update_dreambooth(self.dreambooth_list[0])

    def refresh_motion_module(self):
        """Rescan ``motion_module_dir`` for ``*.ckpt`` files (sorted by name)."""
        motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
        # Sorted so the default selection does not depend on filesystem order.
        self.motion_module_list = sorted(os.path.basename(p) for p in motion_module_list)

    def refresh_personalized_model(self):
        """Rescan ``personalized_model_dir`` for ``*.safetensors`` files (sorted by name)."""
        dreambooth_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
        self.dreambooth_list = sorted(os.path.basename(p) for p in dreambooth_list)

    def update_dreambooth(self, dreambooth_dropdown, motion_module_dropdown=None):
        """Load a DreamBooth checkpoint and rebuild the working UNet/text encoder.

        The VAE weights are loaded into ``self.vae`` in place. ``self.unet`` is
        rebuilt from a deep copy of ``self.unet_model`` (so it carries the
        template's motion module), and ``self.text_encoder`` from a deep copy of
        ``self.text_model``; both then get the Magic adapters applied.

        Args:
            dreambooth_dropdown: File name inside ``personalized_model_dir``.
            motion_module_dropdown: Unused; kept for caller compatibility.

        Returns:
            ``gr.Dropdown()`` for use as a Gradio event output.
        """
        self.selected_dreambooth = dreambooth_dropdown

        dreambooth_dropdown = os.path.join(self.personalized_model_dir, dreambooth_dropdown)
        dreambooth_state_dict = {}
        with safe_open(dreambooth_dropdown, framework="pt", device="cpu") as f:
            for key in f.keys():
                dreambooth_state_dict[key] = f.get_tensor(key)

        converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, self.vae.config)
        self.vae.load_state_dict(converted_vae_checkpoint)

        # Drop the old working UNet before building a new one to free memory.
        del self.unet
        self.unet = None
        torch.cuda.empty_cache()
        time.sleep(1)
        converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, self.unet_model.config)
        self.unet = copy.deepcopy(self.unet_model)
        self.unet.load_state_dict(converted_unet_checkpoint, strict=False)

        del self.text_encoder
        self.text_encoder = None
        torch.cuda.empty_cache()
        time.sleep(1)
        text_model = copy.deepcopy(self.text_model)
        self.text_encoder = convert_ldm_clip_text_model(text_model, dreambooth_state_dict)

        from swift import Swift
        magic_adapter_s_state_dict = torch.load(magic_adapter_s_path, map_location="cpu")
        self.unet = load_diffusers_lora_unet(self.unet, magic_adapter_s_state_dict, alpha=1.0)
        self.unet = Swift.from_pretrained(self.unet, magic_adapter_t_path)
        self.text_encoder = Swift.from_pretrained(self.text_encoder, magic_text_encoder_path)

        return gr.Dropdown()

    def _load_motion_module(self, model, motion_module_dropdown):
        """Load a motion-module checkpoint into `model` and record the selection.

        Raises:
            RuntimeError: if the checkpoint has keys `model` does not know.
        """
        self.selected_motion_module = motion_module_dropdown
        motion_module_path = os.path.join(self.motion_module_dir, motion_module_dropdown)
        motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
        _, unexpected = model.load_state_dict(motion_module_state_dict, strict=False)
        if unexpected:
            raise RuntimeError(f"Unexpected keys in motion module {motion_module_dropdown}: {unexpected[:5]}")
        return gr.Dropdown()

    def update_motion_module(self, motion_module_dropdown):
        """Load a motion module into the working UNet ``self.unet``."""
        return self._load_motion_module(self.unet, motion_module_dropdown)

    def update_motion_module_2(self, motion_module_dropdown):
        """Load a motion module into the template UNet ``self.unet_model``."""
        return self._load_motion_module(self.unet_model, motion_module_dropdown)

    # @spaces.GPU(duration=300)
    def magictime(
        self,
        dreambooth_dropdown,
        motion_module_dropdown,
        prompt_textbox,
        negative_prompt_textbox,
        width_slider,
        height_slider,
        seed_textbox,
    ):
        """Generate a 16-frame video for the given prompt and settings.

        Args:
            dreambooth_dropdown: DreamBooth file name to use.
            motion_module_dropdown: Motion-module file name to use.
            prompt_textbox: Positive prompt.
            negative_prompt_textbox: Negative prompt.
            width_slider: Output width in pixels.
            height_slider: Output height in pixels.
            seed_textbox: Seed as text; values <= 0 request a random seed.

        Returns:
            ``(gr.Video, gr.Json)``: the saved video and the generation config.
        """
        torch.cuda.empty_cache()
        time.sleep(1)

        # Evaluate both conditions before any update runs: the update methods
        # overwrite selected_* immediately.
        motion_changed = self.selected_motion_module != motion_module_dropdown
        dreambooth_changed = self.selected_dreambooth != dreambooth_dropdown
        if motion_changed:
            # Update the template, then rebuild the working UNet from it below;
            # loading only into self.unet would be discarded by the next rebuild.
            self.update_motion_module_2(motion_module_dropdown)
        if motion_changed or dreambooth_changed:
            self.update_dreambooth(dreambooth_dropdown)

        while self.text_encoder is None or self.unet is None:
            self.update_dreambooth(dreambooth_dropdown, motion_module_dropdown)

        if is_xformers_available():
            self.unet.enable_xformers_memory_efficient_attention()

        pipeline = MagicTimePipeline(
            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
            scheduler=DDIMScheduler(**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
        ).to(device)

        if int(seed_textbox) > 0:
            seed = int(seed_textbox)
        else:
            seed = random_seed()
        torch.manual_seed(int(seed))

        assert seed == torch.initial_seed()
        print(f"### seed: {seed}")

        generator = torch.Generator(device=device)
        generator.manual_seed(seed)

        sample = pipeline(
            prompt_textbox,
            negative_prompt=negative_prompt_textbox,
            num_inference_steps=25,
            guidance_scale=8.,
            width=width_slider,
            height=height_slider,
            video_length=16,
            generator=generator,
        ).videos

        save_sample_path = os.path.join(self.savedir, "sample.mp4")
        save_videos_grid(sample, save_sample_path)

        json_config = {
            "prompt": prompt_textbox,
            "n_prompt": negative_prompt_textbox,
            "width": width_slider,
            "height": height_slider,
            "seed": seed,
            "dreambooth": dreambooth_dropdown,
        }

        torch.cuda.empty_cache()
        time.sleep(1)
        return gr.Video(value=save_sample_path), gr.Json(value=json_config)
236 |
# Single shared controller; constructing it loads all models at import time.
controller = MagicTimeController()


def ui():
    """Build the Gradio Blocks interface for MagicTime.

    Both the Generate button and the cached examples call
    ``controller.magictime`` with the same input/output components.

    Returns:
        The constructed ``gr.Blocks`` demo (not yet launched).
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """





            If you like our project, please give us a star ⭐ on GitHub for the latest update.

            [GitHub](https://github.com/PKU-YuanGroup/MagicTime) | [arXiv](https://arxiv.org/abs/2404.05014) | [Home Page](https://pku-yuangroup.github.io/MagicTime/) | [Dataset](https://drive.google.com/drive/folders/1WsomdkmSp3ql3ImcNsmzFuSQ9Qukuyr8?usp=sharing)
            """
        )
        with gr.Row():
            with gr.Column():
                dreambooth_dropdown = gr.Dropdown(label="DreamBooth Model", choices=controller.dreambooth_list, value=controller.dreambooth_list[0], interactive=True)
                motion_module_dropdown = gr.Dropdown(label="Motion Module", choices=controller.motion_module_list, value=controller.motion_module_list[0], interactive=True)

                prompt_textbox = gr.Textbox(label="Prompt", lines=3)
                negative_prompt_textbox = gr.Textbox(label="Negative Prompt", lines=3, value="worst quality, low quality, nsfw, logo")

                with gr.Accordion("Advanced", open=False):
                    with gr.Row():
                        width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
                        height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
                    with gr.Row():
                        seed_textbox = gr.Textbox(label="Seed (-1 means random)", value="-1")
                        seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                        seed_button.click(fn=random_seed, inputs=[], outputs=[seed_textbox])

                generate_button = gr.Button(value="Generate", variant='primary')

            with gr.Column():
                result_video = gr.Video(label="Generated Animation", interactive=False)
                json_config = gr.Json(label="Config", value={})

        inputs = [dreambooth_dropdown, motion_module_dropdown, prompt_textbox, negative_prompt_textbox, width_slider, height_slider, seed_textbox]
        outputs = [result_video, json_config]

        generate_button.click(fn=controller.magictime, inputs=inputs, outputs=outputs)

        gr.Markdown("""
        ⚠ Warning: Even if you use the same seed and prompt, changing machines may produce different results.
        If you find a better seed and prompt, please submit an issue on GitHub.
        """)

        gr.Examples(fn=controller.magictime, examples=examples, inputs=inputs, outputs=outputs, cache_examples=True)

    return demo


if __name__ == "__main__":
    demo = ui()
    demo.queue(max_size=20)
    demo.launch()
294 |
--------------------------------------------------------------------------------
/ckpts/Base_Model/base_model_path.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-YuanGroup/MagicTime/33307bc0ce0ef2d15ccef2ed623c99bae63bf35c/ckpts/Base_Model/base_model_path.txt
--------------------------------------------------------------------------------
/ckpts/Base_Model/motion_module/motion_module_path.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-YuanGroup/MagicTime/33307bc0ce0ef2d15ccef2ed623c99bae63bf35c/ckpts/Base_Model/motion_module/motion_module_path.txt
--------------------------------------------------------------------------------
/ckpts/Base_Model/stable-diffusion-v1-5/sd_15_path.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-YuanGroup/MagicTime/33307bc0ce0ef2d15ccef2ed623c99bae63bf35c/ckpts/Base_Model/stable-diffusion-v1-5/sd_15_path.txt
--------------------------------------------------------------------------------
/ckpts/DreamBooth/dreambooth_path.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-YuanGroup/MagicTime/33307bc0ce0ef2d15ccef2ed623c99bae63bf35c/ckpts/DreamBooth/dreambooth_path.txt
--------------------------------------------------------------------------------
/ckpts/Magic_Weights/magic_weights_path.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-YuanGroup/MagicTime/33307bc0ce0ef2d15ccef2ed623c99bae63bf35c/ckpts/Magic_Weights/magic_weights_path.txt
--------------------------------------------------------------------------------
/data_preprocess/README.md:
--------------------------------------------------------------------------------
1 | # Data Preprocessing Pipeline by *MagicTime*
2 | This document describes how to process your own data into the format of the [ChronoMagic](https://huggingface.co/datasets/BestWishYsh/ChronoMagic) dataset used in the [MagicTime](https://arxiv.org/abs/2404.05014) paper.
3 |
4 | ## 🗝️ Usage
5 |
6 | ```bash
7 | #!/bin/bash
8 |
9 | # Global variables
10 | INPUT_FOLDER="./step_0"
11 | OUTPUT_FOLDER_STEP_1="./step_1"
12 | API_KEY="XXX"
13 | NUM_WORKERS=8
14 |
15 | # File paths
16 | FRAME_CAPTION_FILE="./2_1_gpt_frames_caption.json"
17 | GROUP_FRAMES_FILE="./2_1_temp_group_frames.json"
18 | UPDATED_FRAME_CAPTION_FILE="./2_2_updated_gpt_frames_caption.json"
19 | UNMATCHED_FRAME_CAPTION_FILE="./2_2_temp_unmatched_gpt_frames_caption.json"
20 | UNORDERED_FRAME_CAPTION_FILE="./2_2_temp_unordered_gpt_frames_caption.json"
21 | FINAL_USEFUL_FRAME_CAPTION_FILE="./2_2_final_useful_gpt_frames_caption.json"
22 | VIDEO_CAPTION_FILE="./3_1_gpt_video_caption.json"
23 | UNMATCHED_VIDEO_CAPTION_FILE="./3_2_temp_unmatched_gpt_video_caption.json"
24 | EXCLUDE_BY_FRAME_VIDEO_CAPTION_FILE="./3_2_temp_exclude_by_frame_gpt_video_caption.json"
25 | FINAL_USEFUL_VIDEO_CAPTION_FILE="./3_2_final_useful_gpt_video_caption.json"
26 | FINAL_CSV_FILE="./all_clean_data.csv"
27 |
28 | # Step 1: Extract and resize frames
29 | python step0_extract_frame_resize.py --input_folder "$INPUT_FOLDER" --output_folder "$OUTPUT_FOLDER_STEP_1"
30 |
31 | # Step 2.1: Generate frame captions using GPT-4V
32 | python step2_1_GPT4V_frame_caption.py --api_key "$API_KEY" --num_workers "$NUM_WORKERS" \
33 | --output_file "$FRAME_CAPTION_FILE" --group_frames_file "$GROUP_FRAMES_FILE" --image_directories "$OUTPUT_FOLDER_STEP_1"
34 |
35 | # Step 2.2: Preprocess frame captions
36 | python step2_2_preprocess_frame_caption.py --file_path "$FRAME_CAPTION_FILE" \
37 | --updated_file_path "$UPDATED_FRAME_CAPTION_FILE" --unmatched_file_path "$UNMATCHED_FRAME_CAPTION_FILE" \
38 | --unordered_file_path "$UNORDERED_FRAME_CAPTION_FILE" --final_useful_data_file_path "$FINAL_USEFUL_FRAME_CAPTION_FILE"
39 |
40 | # Step 3.1: Generate concise video captions using GPT-4V
41 | python step3_1_GPT4V_video_caption_concise.py --num_workers "$NUM_WORKERS" \
42 | --input_file "$FINAL_USEFUL_FRAME_CAPTION_FILE" --output_file "$VIDEO_CAPTION_FILE"
43 |
44 | # Optional: Generate detailed video captions (uncomment to enable)
45 | # python step3_1_GPT4V_video_caption_detail.py --num_workers "$NUM_WORKERS" \
46 | # --input_file "$FINAL_USEFUL_FRAME_CAPTION_FILE" --output_file "$VIDEO_CAPTION_FILE"
47 |
48 | # Step 3.2: Preprocess video captions
49 | python step3_2_preprocess_video_caption.py --file_path "$VIDEO_CAPTION_FILE" \
50 | --updated_file_path "$VIDEO_CAPTION_FILE" --unmatched_data_path "$UNMATCHED_VIDEO_CAPTION_FILE" \
51 | --exclude_by_frame_data_path "$EXCLUDE_BY_FRAME_VIDEO_CAPTION_FILE" --final_useful_data_path "$FINAL_USEFUL_VIDEO_CAPTION_FILE"
52 |
53 | # Step 4: Create the final dataset in WebVid format
54 | python step4_1_create_webvid_format.py --caption_file_path "$FINAL_USEFUL_VIDEO_CAPTION_FILE" \
55 | --output_csv_file_path "$FINAL_CSV_FILE"
56 | ```
57 |
--------------------------------------------------------------------------------
/data_preprocess/run.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Run the MagicTime data-preprocessing pipeline (steps 0-4) in order.
# Each step reads the file(s) written by the previous one.

# Stop at the first failing step so later steps never run on missing or
# stale intermediate files; also fail on unset variables and pipe errors.
set -euo pipefail

# The step scripts and default folders are relative paths; resolve them
# against this script's directory so it can be invoked from anywhere.
cd "$(dirname "$0")"

# Global variables
INPUT_FOLDER="./step_0"
OUTPUT_FOLDER_STEP_1="./step_1"
# Take the key from the environment when set; "XXX" is only a placeholder.
API_KEY="${OPENAI_API_KEY:-XXX}"
NUM_WORKERS=8

# File paths
FRAME_CAPTION_FILE="./2_1_gpt_frames_caption.json"
GROUP_FRAMES_FILE="./2_1_temp_group_frames.json"
UPDATED_FRAME_CAPTION_FILE="./2_2_updated_gpt_frames_caption.json"
UNMATCHED_FRAME_CAPTION_FILE="./2_2_temp_unmatched_gpt_frames_caption.json"
UNORDERED_FRAME_CAPTION_FILE="./2_2_temp_unordered_gpt_frames_caption.json"
FINAL_USEFUL_FRAME_CAPTION_FILE="./2_2_final_useful_gpt_frames_caption.json"
VIDEO_CAPTION_FILE="./3_1_gpt_video_caption.json"
UNMATCHED_VIDEO_CAPTION_FILE="./3_2_temp_unmatched_gpt_video_caption.json"
EXCLUDE_BY_FRAME_VIDEO_CAPTION_FILE="./3_2_temp_exclude_by_frame_gpt_video_caption.json"
FINAL_USEFUL_VIDEO_CAPTION_FILE="./3_2_final_useful_gpt_video_caption.json"
FINAL_CSV_FILE="./all_clean_data.csv"

# Step 1: Extract and resize frames
python step0_extract_frame_resize.py --input_folder "$INPUT_FOLDER" --output_folder "$OUTPUT_FOLDER_STEP_1"

# Step 2.1: Generate frame captions using GPT-4V
python step2_1_GPT4V_frame_caption.py --api_key "$API_KEY" --num_workers "$NUM_WORKERS" \
    --output_file "$FRAME_CAPTION_FILE" --group_frames_file "$GROUP_FRAMES_FILE" --image_directories "$OUTPUT_FOLDER_STEP_1"

# Step 2.2: Preprocess frame captions
python step2_2_preprocess_frame_caption.py --file_path "$FRAME_CAPTION_FILE" \
    --updated_file_path "$UPDATED_FRAME_CAPTION_FILE" --unmatched_file_path "$UNMATCHED_FRAME_CAPTION_FILE" \
    --unordered_file_path "$UNORDERED_FRAME_CAPTION_FILE" --final_useful_data_file_path "$FINAL_USEFUL_FRAME_CAPTION_FILE"

# Step 3.1: Generate concise video captions using GPT-4V
python step3_1_GPT4V_video_caption_concise.py --num_workers "$NUM_WORKERS" \
    --input_file "$FINAL_USEFUL_FRAME_CAPTION_FILE" --output_file "$VIDEO_CAPTION_FILE"

# Optional: Generate detailed video captions (uncomment to enable)
# python step3_1_GPT4V_video_caption_detail.py --num_workers "$NUM_WORKERS" \
#     --input_file "$FINAL_USEFUL_FRAME_CAPTION_FILE" --output_file "$VIDEO_CAPTION_FILE"

# Step 3.2: Preprocess video captions
python step3_2_preprocess_video_caption.py --file_path "$VIDEO_CAPTION_FILE" \
    --updated_file_path "$VIDEO_CAPTION_FILE" --unmatched_data_path "$UNMATCHED_VIDEO_CAPTION_FILE" \
    --exclude_by_frame_data_path "$EXCLUDE_BY_FRAME_VIDEO_CAPTION_FILE" --final_useful_data_path "$FINAL_USEFUL_VIDEO_CAPTION_FILE"

# Step 4: Create the final dataset in WebVid format
python step4_1_create_webvid_format.py --caption_file_path "$FINAL_USEFUL_VIDEO_CAPTION_FILE" \
    --output_csv_file_path "$FINAL_CSV_FILE"
--------------------------------------------------------------------------------
/data_preprocess/step0_extract_frame_resize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import glob
4 | import argparse
5 |
6 |
def resize_frame(frame, short_edge=256):
    """Downscale `frame` so that its shorter side is exactly `short_edge` pixels.

    Frames whose shorter side is already ``<= short_edge`` are returned
    unchanged (the same object). The aspect ratio is preserved and the
    longer side is rounded to the nearest pixel.

    Args:
        frame: Image array as read by OpenCV (``shape[:2]`` is height, width).
        short_edge: Target length of the shorter side, in pixels.

    Returns:
        The input frame, or a resized copy made with ``cv2.INTER_AREA``.
    """
    height, width = frame.shape[:2]
    if min(height, width) <= short_edge:
        return frame

    # Pin the short side exactly and scale the long side multiply-first with
    # rounding: int(w * (short_edge / w)) can come out as short_edge - 1 due to
    # floating-point error (e.g. w = 392 gives 255).
    if height > width:
        new_width = short_edge
        new_height = round(height * short_edge / width)
    else:
        new_height = short_edge
        new_width = round(width * short_edge / height)
    return cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)
17 |
def extract_frames(video_path, output_folder, num_frames=8):
    """Save `num_frames` evenly spaced, resized frames of a video as PNG files.

    The first and last frames are always selected; the others sit at
    multiples of ``(total_frames - 1) // (num_frames - 1)``. Each selected
    frame goes through ``resize_frame`` and is written as
    ``<video stem>_frame<index>.png`` inside `output_folder`.

    Args:
        video_path: Video file to read with OpenCV.
        output_folder: Existing directory that receives the PNG files.
        num_frames: Number of frames to select (8 by default).
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error opening video file {video_path}")
        return

    video_stem = os.path.splitext(os.path.basename(video_path))[0]

    def save(index, frame):
        # Resize and write one frame; cv2.imwrite signals failure via its
        # return value rather than raising.
        output_path = os.path.join(output_folder, f"{video_stem}_frame{index}.png")
        if cv2.imwrite(output_path, resize_frame(frame)):
            print(f"Saved {output_path}")
        else:
            print(f"Failed to write {output_path}")

    try:
        # The frame count comes from container metadata; it can be 0 (unknown)
        # or inaccurate, which the checks below account for.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            print(f"Could not determine frame count for {video_path}")
            return

        last_index = total_frames - 1
        if num_frames <= 1:
            frames_to_capture = {0}
        else:
            frames_interval = last_index // (num_frames - 1)
            frames_to_capture = {0, last_index}
            frames_to_capture.update(i * frames_interval for i in range(1, num_frames - 1))

        count = 0
        last_frame = None
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if count in frames_to_capture:
                save(count, frame)
            last_frame = frame
            count += 1

        # If the metadata over-reported the frame count, the planned final
        # index was never decoded; save the last frame actually read instead so
        # the video still yields its closing frame.
        if (
            last_index in frames_to_capture
            and last_index >= count
            and last_frame is not None
            and (count - 1) not in frames_to_capture
        ):
            save(count - 1, last_frame)
    finally:
        cap.release()
46 |
def process_all_videos(folder_path, output_folder, num_frames=8):
    """Extract key frames for every video in `folder_path` that lacks them.

    A video is considered done when `output_folder` already holds exactly
    `num_frames` files named ``<video stem>_frame*.png``; otherwise any
    partial frames for it are deleted and it is extracted again.

    Args:
        folder_path: Directory scanned (non-recursively) for .mp4/.avi/.mov
            files; the extension check is case-insensitive.
        output_folder: Directory for the PNG frames; created if missing.
        num_frames: Frames per video (8 by default, as the captioning step expects).
    """
    os.makedirs(output_folder, exist_ok=True)

    video_files = [f for f in os.listdir(folder_path) if f.lower().endswith((".mp4", ".avi", ".mov"))]
    total_videos = len(video_files)
    skipped_videos = 0

    print(f"Total videos to check: {total_videos}")

    for filename in video_files:
        video_name = os.path.splitext(filename)[0]
        # Escape both parts: IDs such as "[_2p6vHyth14]" contain "[...]", which
        # glob would treat as a character class, matching (and then deleting)
        # frames of unrelated videos.
        pattern = os.path.join(glob.escape(output_folder), f"{glob.escape(video_name)}_frame*.png")
        video_related_images = glob.glob(pattern)

        if len(video_related_images) == num_frames:
            print(f"Skipping {filename}, already processed.")
            skipped_videos += 1
            continue

        # Incomplete set: delete existing frames and extract from scratch.
        for img in video_related_images:
            os.remove(img)
            print(f"Deleted {img}")

        video_path = os.path.join(folder_path, filename)
        print(f"Processing {filename}...")
        extract_frames(video_path, output_folder, num_frames=num_frames)

    print(f"Skipped {skipped_videos} videos that were already processed.")
    print(f"Processed {total_videos - skipped_videos} new or incomplete videos.")
77 |
if __name__ == "__main__":
    # Command-line entry: read the input/output folders, then process every video.
    cli = argparse.ArgumentParser(description="Batch process video files")
    cli.add_argument(
        "--input_folder",
        type=str,
        default='./step_0',
        help="Path to the input folder containing videos",
    )
    cli.add_argument(
        "--output_folder",
        type=str,
        default='./step_1',
        help="Path to the output folder for processed videos",
    )
    options = cli.parse_args()
    process_all_videos(options.input_folder, options.output_folder)
--------------------------------------------------------------------------------
/data_preprocess/step2_1_GPT4V_frame_caption.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import json
4 | import base64
5 | import argparse
6 | from tqdm import tqdm
7 | from openai import OpenAI
8 | from threading import Lock
9 | from concurrent.futures import ThreadPoolExecutor, as_completed
10 | from tenacity import retry, wait_exponential, stop_after_attempt
11 |
12 |
# Instruction text for GPT-4V frame captioning. It contains two worked
# examples plus an empty template, and ends by announcing the frames that
# follow. create_prompts puts its `txt_prompt` argument first and appends
# (filename, image) pairs after it; presumably this constant is what callers
# pass there (not visible here - verify).
txt_prompt = '''
Suppose you are a data annotator, specialized in generating captions for time-lapse videos. You will be supplied with eight key frames extracted from a video, each with a filename labeled with its position in the video sequence. Your task is to generate a caption for each frame, focusing on the primary subject and integrating all discernible elements. Note: These captions should be brief and concise, avoiding redundancy.

Your analysis should demonstrate a deep understanding of real-world physics, encompassing aspects such as gravity and elasticity, and align with the principles of perspective geometry in photography. Ensure object identification consistency across all frames, even if an object is temporarily out of sight. Employ logical deductions to bridge any informational gaps. Begin each caption with a brief reasoning statement, showcasing your analytical approach. For guidance on the expected format, refer to the provided examples:

Brief Reasoning Statement: The images provided are sequential frames from a time-lapse video depicting the blooming stages of a yellow flower, likely a ranunculus. The sequence is forward, showing a natural progression from bud to full bloom. Time-related information is not included in these frames. I will describe each frame accordingly.
"[_2p6vHyth14]": {
"Reasoning": [
"Frame 0: This is the first frame, starting the sequence. The flower is in its initial stages, with petals tightly closed.",
"Frame 224: The petals appear slightly more open than in the first frame, indicating the progression of blooming.",
"Frame 448: The bloom has progressed further; petals are more open than in the previous frame, suggesting the continuation of the blooming process.",
"Frame 672: Continuity in the blooming process is evident, with petals unfurling more than in the last frame.",
"Frame 896: The flower is more open than in frame 672, indicating an advanced stage of the blooming process.",
"Frame 1120: The flower is nearing full bloom, with a majority of the petals open and the inner ones starting to loosen.",
"Frame 1344: The blooming process is almost complete, with the flower more open than in frame 1120 and the center more visible.",
"Frame 1570: This final frame likely represents the peak of the bloom, with the flower fully open and all petals relaxed."
],
"Captioning": [
"Frame 0: Closed yellow ranunculus bud amidst green foliage.",
"Frame 224: Yellow ranunculus bud beginning to open, with green sepals visible.",
"Frame 448: Opening yellow ranunculus with distinct petal layers.",
"Frame 672: Further unfurled yellow ranunculus, petals spreading outward.",
"Frame 896: Half-open yellow ranunculus, with inner petals still tightly clustered.",
"Frame 1120: Nearly fully bloomed yellow ranunculus, with central petals loosening.",
"Frame 1344: Yellow ranunculus in full bloom, center clearly visible amidst open petals.",
"Frame 1570: Fully bloomed yellow ranunculus with a fully visible center and relaxed petals."
]
}

Brief Reasoning Statement: The images show the germination and growth process of a plant, identified as spinach, over a span of 46 days. This time-lapse video captures the transformation from soil to a fully developed plant in a forward sequence. Time-related information is present, indicating the duration of the captured growth process. I will describe each frame accordingly.
"[pVmX1v1hDc]_0001": {
"Reasoning": [
"Frame 0: This is the initial stage where the soil is moist, likely right after sowing the seeds.",
"Frame 69: The soil surface shows signs of disturbance, possibly from seeds beginning to germinate.",
"Frame 138: Germination has occurred, evident from the emergence of seedlings breaking through the soil.",
"Frame 207: The seedlings have elongated and the first true leaves are beginning to form.",
"Frame 276: Growth is evident with larger true leaves, and the plant is entering the vegetative stage.",
"Frame 345: The plants are more developed with a denser leaf canopy, indicating healthy vegetative growth.",
"Frame 414: The spinach plants are fully developed with large leaves, ready for harvesting.",
"Frame 485: The plants are at full maturity with a thick canopy of leaves, showing the complete growth cycle."
],
"Captioning": [
"Frame 0: Moist soil on Day 1 after sowing spinach seeds.",
"Frame 69: Soil surface showing early signs of spinach seed germination on Day 6.",
"Frame 138: Spinach seedlings emerging from soil on Day 10.",
"Frame 207: Elongated spinach seedlings with first true leaves on Day 16.",
"Frame 276: Spinach showing significant leaf growth on Day 24.",
"Frame 345: Denser and larger spinach leaves visible on Day 31.",
"Frame 414: Mature spinach plants with large leaves ready for harvest on Day 39.",
"Frame 485: Thick canopy of mature spinach leaves on Day 46."
]
}

{Brief Reasoning Statement: Must include time-related information and description of forward processes}
"{Enter the prefix of the image to represent the id}": {
"Reasoning": [
" ",
" ",
" ",
" ",
" ",
" ",
" ",
" "
],
"Captioning": [
" ",
" ",
" ",
" ",
" ",
" ",
" ",
" "
]
}

Attention: Do not reply outside the example template! Below are the video title and input frames:
'''

# Global lock for thread-safe file operations.
# has_been_processed holds it while reading the output JSON; presumably the
# writers later in this file (not visible here) take it too - verify.
file_lock = Lock()
95 |
def get_image_filenames(directory):
    """Return the names of regular files in `directory` with an image extension.

    The extension check (.jpg, .jpeg, .png, .bmp, .gif) is case-insensitive;
    subdirectories are skipped. Names are returned without the directory.
    """
    allowed_suffixes = {'.jpg', '.jpeg', '.png', '.bmp', '.gif'}
    names = []
    for entry in os.listdir(directory):
        if not os.path.isfile(os.path.join(directory, entry)):
            continue
        if os.path.splitext(entry)[1].lower() in allowed_suffixes:
            names.append(entry)
    return names
100 |
def parse_video_id(filename):
    """Return the video-ID prefix of a ``<id>_frame<N>.png`` name, or None.

    The ID part is matched greedily, so only the last ``_frame<N>.png`` is
    treated as the frame suffix.
    """
    found = re.match(r'(.+)_frame\d+\.png', filename)
    if found is None:
        return None
    return found.group(1)
105 |
def image_b64(image_path):
    """Read the file at *image_path* as bytes and return its base64 text (str)."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
110 |
def group_images_by_video_id(filenames):
    """Group frame file names by the video id parsed from each name.

    Names for which ``parse_video_id`` returns a falsy value are ignored.
    Only groups with exactly 8 frames are returned; incomplete or oversized
    groups are dropped. Frames keep their input order within a group.
    """
    buckets = {}
    for name in tqdm(filenames, desc="Grouping images"):
        vid = parse_video_id(name)
        if not vid:
            continue
        buckets.setdefault(vid, []).append(name)
    return {vid: frames for vid, frames in buckets.items() if len(frames) == 8}
123 |
def create_prompts(grouped_images, image_directory, txt_prompt):
    """Build one multimodal GPT-4V prompt per video.

    Args:
        grouped_images: mapping of video id -> list of frame file names.
        image_directory: directory containing those frame files.
        txt_prompt: instruction text placed first in every prompt.

    Returns:
        Mapping of video id -> list of chat content parts: the instruction,
        then for each frame a text part holding its file name followed by an
        ``image_url`` part holding the frame as a base64 data URI.

    The data URI's MIME type is guessed from the file extension. It used to
    be hard-coded to ``image/png``, which mislabels JPEG/BMP/GIF frames;
    ``image/png`` remains the fallback for unknown extensions.
    """
    import mimetypes

    prompts = {}
    for video_id, group in tqdm(grouped_images.items(), desc="Creating prompts"):
        prompt = [{"type": "text", "text": txt_prompt}]
        for image_name in group:
            name = image_name.strip()
            b64_image = image_b64(os.path.join(image_directory, name))
            mime_type = mimetypes.guess_type(name)[0] or "image/png"
            prompt.append({"type": "text", "text": name})
            prompt.append({"type": "image_url",
                           "image_url": {"url": f"data:{mime_type};base64,{b64_image}"}})
        prompts[video_id] = prompt
    return prompts
140 |
def has_been_processed(video_id, output_file):
    """Return True if *video_id* is already a key in the JSON *output_file*.

    The check holds ``file_lock`` so it cannot read the file while
    ``save_output`` is rewriting it. A missing file means "not processed".
    """
    with file_lock:
        if not os.path.exists(output_file):
            return False
        with open(output_file, 'r') as f:
            already_done = video_id in json.load(f)
        if already_done:
            print(f"Video ID {video_id} has already been processed.")
        return already_done
150 |
def extract_frame_number(filename):
    """Return the integer N from a ``..._frame<N>.<ext>`` file name.

    Uses the text after the last ``_frame`` up to the first following dot.
    Raises ValueError when that text is not an integer.
    """
    tail = filename.rsplit('_frame', 1)[-1]
    digits, _, _ = tail.partition('.')
    return int(digits)
154 |
def load_existing_results(file_path):
    """Return the JSON content of *file_path*, creating it as ``{}`` if absent.

    When the file does not exist, an empty JSON object is written to it and
    an empty dict is returned.
    """
    if not os.path.exists(file_path):
        print(f"No existing results file found at {file_path}. Creating a new file.")
        fresh = {}
        with open(file_path, 'w') as file:
            json.dump(fresh, file)
        return fresh
    with open(file_path, 'r') as file:
        print(f"Loading existing results from {file_path}")
        return json.load(file)
166 |
@retry(wait=wait_exponential(multiplier=1, min=2, max=10), stop=stop_after_attempt(100))
def call_gpt(prompt, model_name="gpt-4-vision-preview", api_key=None):
    """Send *prompt* as a single user message and return the reply text.

    Args:
        prompt: list of chat content parts (text and image_url entries).
        model_name: OpenAI chat model to call.
        api_key: OpenAI API key; when None the OpenAI SDK falls back to its
            own configuration (the OPENAI_API_KEY environment variable).

    Any exception is retried by tenacity: up to 100 attempts, with
    exponential backoff clamped to 2-10 seconds between attempts.

    The full completion object used to be printed on every call. It is no
    longer printed: that flooded the progress output, and the sibling
    captioning scripts do not do it.
    """
    client = OpenAI(api_key=api_key)
    chat_completion = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=2048,
    )
    return chat_completion.choices[0].message.content
182 |
def save_output(video_id, prompt, output_file, api_key):
    """Caption *video_id* with GPT unless already stored, then merge it into *output_file*.

    The JSON read-modify-write runs under ``file_lock``, so concurrent workers
    do not overwrite each other's entries. *output_file* must already exist.
    """
    if has_been_processed(video_id, output_file):
        return
    result = call_gpt(prompt, api_key=api_key)
    with file_lock, open(output_file, 'r+') as f:
        data = json.load(f)
        data[video_id] = result
        # Rewrite in place: rewind, dump, then cut off any leftover old bytes.
        f.seek(0)
        json.dump(data, f, indent=4)
        f.truncate()
    print(f"Processed and saved output for Video ID {video_id}")
195 |
def main(num_workers, all_prompts, output_file, api_key):
    """Caption every video in *all_prompts* that is not yet in *output_file*.

    Work is fanned out over *num_workers* threads. A failure for one video is
    reported and does not stop the others.
    """
    existing_results = load_existing_results(output_file)
    pending = {vid: p for vid, p in all_prompts.items() if vid not in existing_results}
    if not pending:
        print("No unprocessed video IDs found. All prompts have already been processed.")
        return

    print(f"Processing {len(pending)} unprocessed video IDs.")
    progress_bar = tqdm(total=len(pending))

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {}
        for video_id, prompt in pending.items():
            futures[executor.submit(save_output, video_id, prompt, output_file, api_key)] = video_id
        for future in as_completed(futures):
            progress_bar.update(1)
            try:
                future.result()
            except Exception as exc:
                print(f"Error processing video ID {futures[future]}: {exc}")

    progress_bar.close()
224 |
if __name__ == "__main__":
    # Set up argument parser
    parser = argparse.ArgumentParser(description="Process video frame captions.")
    # OpenAI keys are strings ("sk-..."); the previous type=int made argparse
    # reject every real key with "invalid int value".
    parser.add_argument("--api_key", type=str, default=None, help="OpenAI API key.")
    parser.add_argument("--num_workers", type=int, default=6, help="Number of worker threads for processing.")
    parser.add_argument("--output_file", type=str, default="./2_1_gpt_frames_caption.json", help="Path to the output JSON file.")
    parser.add_argument("--group_frames_file", type=str, default="./2_1_temp_group_frames.json", help="Path to save grouped frame metadata.")
    parser.add_argument("--image_directories", type=str, nargs="+", default=["./step_1"], help="List of directories containing images.")

    # Parse command-line arguments
    args = parser.parse_args()

    all_prompts = {}
    all_grouped_images = {}

    # Process each image directory
    for directory in args.image_directories:
        filenames = get_image_filenames(directory)
        grouped_images = group_images_by_video_id(filenames)

        # Sort images within each video group by frame index
        for video_id in grouped_images:
            grouped_images[video_id].sort(key=extract_frame_number)

        all_grouped_images.update(grouped_images)  # Merge into a single dictionary

        # Generate prompts
        prompts = create_prompts(grouped_images, directory, txt_prompt)
        all_prompts.update(prompts)

    # Save grouped images metadata
    with open(args.group_frames_file, 'w') as file:
        json.dump(all_grouped_images, file, indent=4)

    # Execute main processing function
    main(args.num_workers, all_prompts, args.output_file, args.api_key)
--------------------------------------------------------------------------------
/data_preprocess/step2_2_preprocess_frame_caption.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | import argparse
4 |
def load_json(file_path):
    """Load and return the content of a JSON file.

    The file is decoded as UTF-8 explicitly rather than with the platform's
    locale encoding, so non-ASCII text reads the same on every OS.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)
9 |
def save_json(data, file_path):
    """Save *data* to *file_path* as JSON indented by 4 spaces.

    ``ensure_ascii=False`` writes non-ASCII characters verbatim. The file is
    therefore opened as UTF-8 explicitly: with a locale default such as
    cp1252, those characters could raise UnicodeEncodeError or be written in
    an encoding that JSON readers do not expect.
    """
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=4)
14 |
def process_frame_caption(file_path):
    """Parse raw GPT-4V frame-caption replies into structured records.

    Each value in the JSON at *file_path* is expected to be the model's raw
    reply text. It contains a ``Brief Reasoning Statement: ...`` line followed
    by ``"Reasoning": [...]`` and ``"Captioning": [...]`` lists of quoted
    strings.

    Returns:
        (matched_data, unmatched_data). ``matched_data`` maps id -> dict with
        ``Video_Reasoning`` (str) plus ``Frame_Reasoning`` and
        ``Frame_Captioning`` (lists of str). ``unmatched_data`` maps id -> the
        raw value for replies that could not be parsed, including non-string
        values such as a null reply.
    """
    # One quoted list item. The closing quote must be followed by a comma, or
    # by optional whitespace and then the end of the list body. Without the
    # \s*, the last item (usually followed by a newline and the indentation
    # before "]") never matched, because "$" without MULTILINE only matches at
    # the very end of the string.
    item_pattern = re.compile(r'"(.*?)"(?:,|\s*$)')

    data = load_json(file_path)
    matched_data = {}
    unmatched_data = {}
    for key, value in data.items():
        if not isinstance(value, str):
            # e.g. a null reply stored by the captioning step; re.search would raise TypeError.
            unmatched_data[key] = value
            continue
        brief_reasoning_match = re.search(r'Brief Reasoning Statement: (.*?)(?:\n\n|\n)', value, re.DOTALL)
        reasoning_match = re.search(r'"Reasoning": \[(.*?)\]', value, re.DOTALL)
        captioning_match = re.search(r'"Captioning": \[(.*?)\]', value, re.DOTALL)
        if brief_reasoning_match and reasoning_match and captioning_match:
            matched_data[key] = {
                "Video_Reasoning": brief_reasoning_match.group(1).strip(),
                "Frame_Reasoning": item_pattern.findall(reasoning_match.group(1)),
                "Frame_Captioning": item_pattern.findall(captioning_match.group(1)),
            }
        else:
            unmatched_data[key] = value
    return matched_data, unmatched_data
36 |
def is_disordered(section):
    """Return True if the frame numbers in *section* are not non-decreasing.

    Each entry is expected to look like ``"Frame <N>: <text>"``. An entry
    whose frame number cannot be parsed is skipped. This covers a non-integer
    N, and also a missing space before N (e.g. ``"Frame0: ..."``), which used
    to raise an uncaught IndexError and abort the whole cleaning run.
    """
    frames = []
    for entry in section:
        try:
            frames.append(int(entry.split(':')[0].split(' ')[1]))
        except (ValueError, IndexError):
            continue
    return any(later < earlier for earlier, later in zip(frames, frames[1:]))
48 |
def find_disorder(data):
    """Split records into ``(ordered_records, unordered_records)``.

    A record is unordered when ``is_disordered`` flags either its
    ``Frame_Reasoning`` or its ``Frame_Captioning`` list. A missing section
    counts as an empty list. Input key order is preserved in both results.
    """
    ordered_records = {}
    unordered_records = {}
    for record_id, record in data.items():
        out_of_order = any(
            is_disordered(record.get(section_name, []))
            for section_name in ('Frame_Reasoning', 'Frame_Captioning')
        )
        bucket = unordered_records if out_of_order else ordered_records
        bucket[record_id] = record
    return ordered_records, unordered_records
63 |
def remove_disorder(data, unordered_data):
    """Return a new dict with every entry of *data* whose key is not in *unordered_data*."""
    return {key: value for key, value in data.items() if key not in unordered_data}
69 |
def remove_unmatch_records(data, unmatched_data):
    """Return a copy of *data* without the ids that are keys of *unmatched_data*.

    :param data: dict, id -> record (e.g. the raw GPT replies)
    :param unmatched_data: dict whose keys are the ids to drop
    :return: dict, the remaining records in their original order
    """
    drop_ids = frozenset(unmatched_data.keys())
    kept = {}
    for id_, value in data.items():
        if id_ not in drop_ids:
            kept[id_] = value
    return kept
79 |
def merge_json_files(info_data, caption_data):
    """Copy the ``title`` field from *info_data* into matching *caption_data* entries.

    A caption key matches the first info key, in *info_data* order, that it
    starts with. Entries with no matching prefix, or whose info entry has no
    ``title``, are left untouched. *caption_data* is modified in place and
    also returned.
    """
    for caption_key, caption_entry in caption_data.items():
        prefix = next((info_key for info_key in info_data if caption_key.startswith(info_key)), None)
        if prefix is None:
            continue
        info_entry = info_data[prefix]
        caption_entry.update({field: info_entry[field] for field in ('title',) if field in info_entry})
    return caption_data
105 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process GPT4V frame captions and clean up data.")
    parser.add_argument("--file_path", type=str, default="./2_1_gpt_frames_caption.json", help="Path to the input JSON file.")
    parser.add_argument("--updated_file_path", type=str, default="./2_2_updated_gpt_frames_caption.json", help="Path to save the updated JSON file.")
    parser.add_argument("--unmatched_file_path", type=str, default="./2_2_temp_unmatched_gpt_frames_caption.json", help="Path to save unmatched records.")
    parser.add_argument("--unordered_file_path", type=str, default="./2_2_temp_unordered_gpt_frames_caption.json", help="Path to save unordered records.")
    parser.add_argument("--final_useful_data_file_path", type=str, default="./2_2_final_useful_gpt_frames_caption.json", help="Path to save the final cleaned data.")
    args = parser.parse_args()

    # Parse the raw replies, then split the parsed records by frame ordering.
    matched_data, unmatched_data = process_frame_caption(args.file_path)
    ordered_records, unordered_records = find_disorder(matched_data)

    # Raw replies minus every id that failed parsing or ordering. Presumably
    # this is meant to be fed back as the captioning step's output file, since
    # that step only re-requests ids missing from it.
    raw_replies = load_json(args.file_path)
    without_unordered = remove_disorder(raw_replies, unordered_records)
    updated_json = remove_unmatch_records(without_unordered, unmatched_data)

    final_useful_data = ordered_records
    n_unmatched = len(unmatched_data)
    n_unordered = len(unordered_records)

    print(f"Number of Unmatched Records (GPT4V_Frame): {n_unmatched}")
    print(f"Number of Unordered Records (GPT4V_Frame): {n_unordered}")
    print(f"Number of Final Useful Records (GPT4V_Frame): {len(final_useful_data)}")

    if n_unmatched or n_unordered:
        save_json(updated_json, args.updated_file_path)
        print(f"Found {n_unmatched} unmatched records and {n_unordered} unordered records!")
        print(f"Updated JSON file has been saved to {args.updated_file_path}. Please rerun GPT4V for captioning.")
    else:
        print(f"No unmatched/unordered records found! You can directly use {args.final_useful_data_file_path} for the next step.")

    # Intermediate and final outputs are always written.
    for payload, destination in (
        (unmatched_data, args.unmatched_file_path),
        (unordered_records, args.unordered_file_path),
        (final_useful_data, args.final_useful_data_file_path),
    ):
        save_json(payload, destination)
--------------------------------------------------------------------------------
/data_preprocess/step3_1_GPT4V_video_caption_concise.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | from tqdm import tqdm
5 | from openai import OpenAI
6 | from threading import Lock
7 | from tenacity import retry, wait_exponential, stop_after_attempt
8 | from concurrent.futures import ThreadPoolExecutor, as_completed
9 |
10 |
# Instruction text for the concise video-summary pass. create_prompts sends it
# as the first text part, followed by the record's Video_Reasoning,
# Frame_Reasoning and Frame_Captioning. The reply is expected to follow the
# '"Video_Summary": ...' template shown in the examples (at most 70 words).
txt_prompt = '''
Imagine you're an expert data annotator with a specialization in summarizing time-lapse videos. You will be supplied with "Video_Reasoning", "8_Key-Frames_Reasoning", and "8_Key-Frames_Captioning" from a video, your task is to craft a concise summary for the given time-lapse video.

Since only textual information is given, you can employ logical deductions to bridge any informational gaps if necessary. For guidance on the expected output format and content length (no more than 70 words), refer to the provided examples:

"Video_Summary": Time-lapse of a ciplukan fruit growing from a small bud to a mature, rounded form among leaves, gradually enlarging and smoothing out by the video's end.

"Video_Summary": Time-lapse of red onion bulbs sprouting and growing over 10 days: starting dormant, developing shoots and roots by Day 2, significant growth by Day 6, and full development by Day 10.

"Video_Summary": "{Video Summary}"

Attention: Do not reply outside the example template! The process of reasoning and thinking should not be included in the {Video Summary}! Do not use words similar to by frame or at frame! Below are the Video, Video_Reasoning, Frame_Reasoning and Frame_Captioning.
'''
24 |
# Global lock for thread-safe file operations: has_been_processed and
# save_output take it around every read / read-modify-write of the JSON
# output file, which the ThreadPoolExecutor workers in main share.
file_lock = Lock()
27 |
def create_prompts(txt_prompt, data):
    """Build one text-only chat prompt per video id.

    Each prompt is a list of content parts: *txt_prompt*, then the record's
    Video_Reasoning, Frame_Reasoning and Frame_Captioning, each in its own
    labelled text part. Lists are rendered with their Python ``str`` form.
    """
    prompts = {}
    for video_id, record in tqdm(data.items(), desc="Creating prompts"):
        sections = (
            f'The "Video_Reasoning" is: {record["Video_Reasoning"]}',
            f'The "8_Key-Frames_Reasoning" are: {record["Frame_Reasoning"]}',
            f'The "8_Key-Frames_Captioning" are: {record["Frame_Captioning"]}',
        )
        prompts[video_id] = [{"type": "text", "text": txt_prompt}] + [
            {"type": "text", "text": text} for text in sections
        ]
    return prompts
38 |
def has_been_processed(video_id, output_file):
    """Return True if *video_id* is already a key in the JSON *output_file*.

    The check holds ``file_lock`` so it cannot read the file while
    ``save_output`` is rewriting it. A missing file means "not processed".
    """
    with file_lock:
        if not os.path.exists(output_file):
            return False
        with open(output_file, 'r') as f:
            already_done = video_id in json.load(f)
        if already_done:
            print(f"Video ID {video_id} has already been processed.")
        return already_done
48 |
def load_existing_results(file_path):
    """Return the JSON content of *file_path*, creating it as ``{}`` if absent.

    When the file does not exist, an empty JSON object is written to it and
    an empty dict is returned.
    """
    if not os.path.exists(file_path):
        print(f"No existing results file found at {file_path}. Creating a new file.")
        fresh = {}
        with open(file_path, 'w') as file:
            json.dump(fresh, file)
        return fresh
    with open(file_path, 'r') as file:
        print(f"Loading existing results from {file_path}")
        return json.load(file)
60 |
@retry(wait=wait_exponential(multiplier=1, min=2, max=10), stop=stop_after_attempt(100))
def call_gpt(prompt, model_name="gpt-4-vision-preview", api_key=None):
    """Send *prompt* as one user message and return the first reply's text.

    Any exception is retried by tenacity: up to 100 attempts, with
    exponential backoff clamped to 2-10 seconds. The reply is capped at 1024
    tokens.
    """
    messages = [{"role": "user", "content": prompt}]
    response = OpenAI(api_key=api_key).chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=1024,
    )
    first_choice = response.choices[0]
    return first_choice.message.content
75 |
def save_output(video_id, prompt, output_file, api_key):
    """Summarize *video_id* with GPT unless already stored, then merge it into *output_file*.

    The JSON read-modify-write runs under ``file_lock``, so concurrent workers
    do not overwrite each other's entries. *output_file* must already exist.
    """
    if has_been_processed(video_id, output_file):
        return
    result = call_gpt(prompt, api_key=api_key)
    with file_lock, open(output_file, 'r+') as f:
        data = json.load(f)
        data[video_id] = result
        # Rewrite in place: rewind, dump, then cut off any leftover old bytes.
        f.seek(0)
        json.dump(data, f, indent=4)
        f.truncate()
    print(f"Processed and saved output for Video ID {video_id}")
88 |
def main(num_workers, all_prompts, output_file, api_key):
    """Summarize every video in *all_prompts* that is not yet in *output_file*.

    Work is fanned out over *num_workers* threads. A failure for one video is
    reported and does not stop the others.
    """
    existing_results = load_existing_results(output_file)
    pending = {vid: p for vid, p in all_prompts.items() if vid not in existing_results}
    if not pending:
        print("No unprocessed video IDs found. All prompts have already been processed.")
        return

    print(f"Processing {len(pending)} unprocessed video IDs.")
    progress_bar = tqdm(total=len(pending))

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {}
        for video_id, prompt in pending.items():
            futures[executor.submit(save_output, video_id, prompt, output_file, api_key)] = video_id
        for future in as_completed(futures):
            progress_bar.update(1)
            try:
                future.result()
            except Exception as exc:
                print(f"Error processing video ID {futures[future]}: {exc}")

    progress_bar.close()
117 |
if __name__ == "__main__":
    # Set up argument parser
    parser = argparse.ArgumentParser(description="Generate video captions using GPT4V.")
    # OpenAI keys are strings ("sk-..."); the previous type=int made argparse
    # reject every real key with "invalid int value".
    parser.add_argument("--api_key", type=str, default=None, help="OpenAI API key.")
    parser.add_argument("--num_workers", type=int, default=8, help="Number of worker threads for processing.")
    parser.add_argument("--input_file", type=str, default="./2_2_final_useful_gpt_frames_caption.json", help="Path to the input JSON file.")
    parser.add_argument("--output_file", type=str, default="./3_1_gpt_video_caption.json", help="Path to save the generated video captions.")

    # Parse command-line arguments
    args = parser.parse_args()

    # Load data from the input file
    with open(args.input_file, 'r') as file:
        data = json.load(file)

    # Generate prompts
    prompts = create_prompts(txt_prompt, data)

    # Execute main processing function
    main(args.num_workers, prompts, args.output_file, args.api_key)
--------------------------------------------------------------------------------
/data_preprocess/step3_1_GPT4V_video_caption_detail.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | from tqdm import tqdm
5 | from openai import OpenAI
6 | from threading import Lock
7 | from concurrent.futures import ThreadPoolExecutor, as_completed
8 | from tenacity import retry, wait_exponential, stop_after_attempt
9 |
10 |
# Instruction text for the detailed video-summary pass. It asks the model to
# state whether the video plays forward or in reverse and to include any
# time-related details. create_prompts sends it as the first text part,
# followed by the record's Video_Reasoning, Frame_Reasoning and
# Frame_Captioning.
txt_prompt = '''
Imagine you are a data annotator, specialized in generating summaries for time-lapse videos. You will be supplied with "Video_Reasoning", "8_Key-Frames_Reasoning", and "8_Key-Frames_Captioning" from a video, your task is to craft a succinct and precise summary for the given time-lapse video. Note: The summary should efficiently encapsulate all discernible elements, particularly emphasizing the primary subject. It is important to indicate whether the video pertains to a forward or reverse sequence. Additionally, integrate any time-related aspects from the video into the summary.

Since only textual information is given, you can employ logical deductions to bridge any informational gaps if necessary. For guidance on the expected output format, refer to the provided examples:

"Video_Summary": "The time-lapse video showcases the growth and ripening process of strawberries in a forward sequence. The video starts with fully bloomed strawberry flowers, which then wilt slightly. As the video progresses, the yellow stamens recede, and the flowers continue to wilt. The white petals disappear, and the green immature strawberries become more prominent. The strawberries then grow in size, displaying a green color with some red hues. As the video continues, the strawberries gradually ripen, turning from green to a deep red color."

"Video_Summary": "This time-lapse video succinctly documents the 50-day decomposition process of a pear in a forward sequence, from its fresh, ripe state on day 1 to a shrunken, moldy, and rotten form by day 50. Throughout the video, the pear's gradual deterioration is evident through increasing browning, the development of mold patches, and significant changes in color, texture, and structure."

"Video_Summary": "The time-lapse video showcases a Halloween pumpkin's decomposition process in reverse. The video starts with a pumpkin in a highly decomposed state at 92 days post-carving and then counts down the days, reversing the process. The pumpkin gradually re-inflates, reducing the signs of wrinkling and drying, until it appears freshly carved at 1 day post-carving."

"Video_Summary": "{Video Summary}"

Attention: Do not reply outside the example template! The process of reasoning and thinking should not be included in the {Video Summary}! Do not use words similar to by frame or at frame! Below are the Video, Video_Reasoning, Frame_Reasoning and Frame_Captioning.
'''
26 |
# Global lock for thread-safe file operations: has_been_processed and
# save_output take it around every read / read-modify-write of the JSON
# output file, which the ThreadPoolExecutor workers in main share.
file_lock = Lock()
29 |
def create_prompts(txt_prompt, data):
    """Build one text-only chat prompt per video id.

    Each prompt is a list of content parts: *txt_prompt*, then the record's
    Video_Reasoning, Frame_Reasoning and Frame_Captioning, each in its own
    labelled text part. Lists are rendered with their Python ``str`` form.
    """
    prompts = {}
    for video_id, record in tqdm(data.items(), desc="Creating prompts"):
        parts = [{"type": "text", "text": txt_prompt}]
        parts.extend(
            {"type": "text", "text": text}
            for text in (
                f'The "Video_Reasoning" is: {record["Video_Reasoning"]}',
                f'The "8_Key-Frames_Reasoning" are: {record["Frame_Reasoning"]}',
                f'The "8_Key-Frames_Captioning" are: {record["Frame_Captioning"]}',
            )
        )
        prompts[video_id] = parts
    return prompts
40 |
def has_been_processed(video_id, output_file):
    """Return True if *video_id* is already a key in the JSON *output_file*.

    The check holds ``file_lock`` so it cannot read the file while
    ``save_output`` is rewriting it. A missing file means "not processed".
    """
    with file_lock:
        if not os.path.exists(output_file):
            return False
        with open(output_file, 'r') as f:
            already_done = video_id in json.load(f)
        if already_done:
            print(f"Video ID {video_id} has already been processed.")
        return already_done
50 |
def load_existing_results(file_path):
    """Return the JSON content of *file_path*, creating it as ``{}`` if absent.

    When the file does not exist, an empty JSON object is written to it and
    an empty dict is returned.
    """
    if not os.path.exists(file_path):
        print(f"No existing results file found at {file_path}. Creating a new file.")
        fresh = {}
        with open(file_path, 'w') as file:
            json.dump(fresh, file)
        return fresh
    with open(file_path, 'r') as file:
        print(f"Loading existing results from {file_path}")
        return json.load(file)
62 |
@retry(wait=wait_exponential(multiplier=1, min=2, max=10), stop=stop_after_attempt(100))
def call_gpt(prompt, model_name="gpt-4-vision-preview", api_key=None):
    """Send *prompt* as one user message and return the first reply's text.

    Any exception is retried by tenacity: up to 100 attempts, with
    exponential backoff clamped to 2-10 seconds. The reply is capped at 1024
    tokens.
    """
    messages = [{"role": "user", "content": prompt}]
    response = OpenAI(api_key=api_key).chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=1024,
    )
    first_choice = response.choices[0]
    return first_choice.message.content
77 |
def save_output(video_id, prompt, output_file, api_key):
    """Summarize *video_id* with GPT unless already stored, then merge it into *output_file*.

    The JSON read-modify-write runs under ``file_lock``, so concurrent workers
    do not overwrite each other's entries. *output_file* must already exist.
    """
    if has_been_processed(video_id, output_file):
        return
    result = call_gpt(prompt, api_key=api_key)
    with file_lock, open(output_file, 'r+') as f:
        data = json.load(f)
        data[video_id] = result
        # Rewrite in place: rewind, dump, then cut off any leftover old bytes.
        f.seek(0)
        json.dump(data, f, indent=4)
        f.truncate()
    print(f"Processed and saved output for Video ID {video_id}")
90 |
def main(num_workers, all_prompts, output_file, api_key):
    """Summarize every video in *all_prompts* that is not yet in *output_file*.

    Work is fanned out over *num_workers* threads. A failure for one video is
    reported and does not stop the others.
    """
    existing_results = load_existing_results(output_file)
    pending = {vid: p for vid, p in all_prompts.items() if vid not in existing_results}
    if not pending:
        print("No unprocessed video IDs found. All prompts have already been processed.")
        return

    print(f"Processing {len(pending)} unprocessed video IDs.")
    progress_bar = tqdm(total=len(pending))

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {}
        for video_id, prompt in pending.items():
            futures[executor.submit(save_output, video_id, prompt, output_file, api_key)] = video_id
        for future in as_completed(futures):
            progress_bar.update(1)
            try:
                future.result()
            except Exception as exc:
                print(f"Error processing video ID {futures[future]}: {exc}")

    progress_bar.close()
119 |
if __name__ == "__main__":
    # Set up argument parser
    parser = argparse.ArgumentParser(description="Generate video captions using GPT4V.")
    # OpenAI keys are strings ("sk-..."); the previous type=int made argparse
    # reject every real key with "invalid int value".
    parser.add_argument("--api_key", type=str, default=None, help="OpenAI API key.")
    parser.add_argument("--num_workers", type=int, default=8, help="Number of worker threads for processing.")
    parser.add_argument("--input_file", type=str, default="2_2_final_useful_gpt_frames_caption.json", help="Path to the input JSON file.")
    parser.add_argument("--output_file", type=str, default="./3_1_gpt_video_caption.json", help="Path to save the generated video captions.")

    # Parse command-line arguments
    args = parser.parse_args()

    # Load data from the input file
    with open(args.input_file, 'r') as file:
        data = json.load(file)

    # Generate prompts
    prompts = create_prompts(txt_prompt, data)

    # Execute main processing function
    main(args.num_workers, prompts, args.output_file, args.api_key)
--------------------------------------------------------------------------------
/data_preprocess/step3_2_preprocess_video_caption.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | import argparse
4 |
5 |
def process_json(file_path):
    """Extract the ``"Video_Summary": ...`` caption from each raw GPT reply.

    Returns:
        (matched_data, unmatched_data). Matched ids map to
        ``{"Video_GPT4_Caption": summary}``. Replies with no summary line,
        with an empty summary, or that are not strings (e.g. a null reply)
        go to unmatched_data unchanged, so they can be regenerated.

    The summary is the rest of the line after the label, with surrounding
    whitespace and one pair of enclosing double quotes removed. The prompt
    template shows the summary quoted (``"Video_Summary": "{Video Summary}"``);
    previously those quote characters leaked into the caption text.
    """
    with open(file_path, 'r') as file:
        data = json.load(file)

    matched_data = {}
    unmatched_data = {}

    for key, value in data.items():
        match = re.search(r'"Video_Summary": (.*)', value) if isinstance(value, str) else None
        summary = match.group(1).strip() if match else ""
        if len(summary) >= 2 and summary[0] == summary[-1] == '"':
            summary = summary[1:-1].strip()

        if summary:
            matched_data[key] = {
                "Video_GPT4_Caption": summary,
            }
        else:
            unmatched_data[key] = value

    return matched_data, unmatched_data
24 |
def read_json_file(file_path):
    """Return the parsed JSON content of the file at *file_path*."""
    with open(file_path, 'r') as handle:
        content = json.load(handle)
    return content
29 |
def remove_by_Frame(data):
    """Split caption records by whether their caption cites a specific frame.

    A record goes to the exclude set when its ``Video_GPT4_Caption`` contains
    "by", "at", "in" or "on" followed by " Frame <digits>", matched
    case-insensitively (e.g. "at frame 12"). The captioning prompts ask the
    model not to use such wording. Records without the key are kept.

    Returns:
        (to_keep, to_exclude), both dicts keyed like *data*.
    """
    frame_reference = re.compile(r'(by|at|in|on) Frame \d+', re.IGNORECASE)
    to_keep = {}
    to_exclude = {}
    for key, value in data.items():
        caption = value.get("Video_GPT4_Caption", "")
        if frame_reference.search(caption):
            bucket = to_exclude
        else:
            bucket = to_keep
        bucket[key] = value
    return to_keep, to_exclude
48 |
def remove_unmatch_records(gpt_data, unmatched_json_data):
    """Return *gpt_data* without the ids that are keys of *unmatched_json_data*.

    :param gpt_data: dict, id -> raw GPT reply (or any record)
    :param unmatched_json_data: dict whose keys are the ids to drop
    :return: dict, a new mapping with those ids removed (order preserved)
    """
    ids_to_drop = set(unmatched_json_data.keys())
    kept = {}
    for id_, value in gpt_data.items():
        if id_ in ids_to_drop:
            continue
        kept[id_] = value
    return kept
58 |
def save_json_file(data, file_path):
    """Write *data* to *file_path* as JSON indented by 4 spaces (non-ASCII escaped)."""
    with open(file_path, mode='w') as out_file:
        json.dump(data, out_file, indent=4)
63 |
def merge_json_files(info_data, caption_data):
    """Copy the ``title`` field from *info_data* into matching *caption_data* entries.

    A caption key matches the first info key, in *info_data* order, that it
    starts with. *caption_data* is modified in place and also returned.
    """
    for caption_key, entry in caption_data.items():
        for info_key, info_entry in info_data.items():
            if not caption_key.startswith(info_key):
                continue
            selected = {'title': info_entry['title']} if 'title' in info_entry else {}
            entry.update(selected)
            break
    return caption_data
75 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process GPT4V video captions and clean up data.")
    parser.add_argument("--file_path", type=str, default="./3_1_gpt_video_caption.json", help="Path to the input JSON file.")
    # Defaults to the same path as --file_path, so the cleaned replies overwrite the input in place.
    parser.add_argument("--updated_file_path", type=str, default="./3_1_gpt_video_caption.json", help="Path to save the updated JSON file.")
    parser.add_argument("--unmatched_data_path", type=str, default="./3_2_temp_unmatched_gpt_video_caption.json", help="Path to save unmatched records.")
    parser.add_argument("--exclude_by_frame_data_path", type=str, default="./3_2_temp_exclude_by_frame_gpt_video_caption.json", help="Path to save excluded records.")
    parser.add_argument("--final_useful_data_path", type=str, default="./3_2_final_useful_gpt_video_caption.json", help="Path to save the final cleaned data.")
    args = parser.parse_args()

    # Parse summaries, then drop those that cite specific frames.
    matched_data, unmatched_data = process_json(args.file_path)
    to_keep, to_exclude = remove_by_Frame(matched_data)

    # Raw replies minus the unparsable ones and the frame-citing ones.
    without_unmatched = remove_unmatch_records(read_json_file(args.file_path), unmatched_data)
    updated_json = remove_unmatch_records(without_unmatched, to_exclude)

    # Intermediate and final outputs are always written.
    for payload, destination in (
        (unmatched_data, args.unmatched_data_path),
        (to_exclude, args.exclude_by_frame_data_path),
        (to_keep, args.final_useful_data_path),
    ):
        save_json_file(payload, destination)

    if unmatched_data or to_exclude:
        save_json_file(updated_json, args.updated_file_path)
        print(f"Found {len(unmatched_data)} unmatched_data and {len(to_exclude)} exclude_by_frame_data!")
        print(f"Updated JSON file has been saved to {args.updated_file_path}. Please rerun GPT4V for captioning.")
    else:
        print(f"No unmatched_data and exclude_by_frame_data found! You can directly use {args.final_useful_data_path} for the next step.")
--------------------------------------------------------------------------------
/data_preprocess/step4_1_create_webvid_format.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import pandas as pd
4 |
5 |
def merge_json_files_with_transmit_status(caption_file, output_file):
    """Convert the cleaned caption JSON into a WebVid-style CSV.

    Every record becomes one row with columns ``videoid`` (the JSON key),
    ``name`` (its ``Video_GPT4_Caption``) and ``is_transmit``. The last is
    always the string '1' for every row, whatever the record contains.
    The CSV is written without an index column.

    Returns a status message that contains *output_file*.
    """
    with open(caption_file, 'r', encoding='utf-8') as file:
        caption_data = json.load(file)

    rows = []
    for video_id, record in caption_data.items():
        rows.append({
            'videoid': video_id,
            'name': record['Video_GPT4_Caption'],
            'is_transmit': '1',
        })

    pd.DataFrame(rows).to_csv(output_file, index=False)
    return f"CSV file saved at: {output_file}"
26 |
if __name__ == "__main__":
    # Command-line entry point: convert the cleaned caption JSON into the CSV
    # consumed by training/evaluation.
    parser = argparse.ArgumentParser(description="Convert GPT4V video captions JSON to CSV.")
    parser.add_argument("--caption_file_path", type=str, default="./3_2_final_useful_gpt_video_caption.json", help="Path to the input JSON caption file.")
    parser.add_argument("--output_csv_file_path", type=str, default="./all_clean_data.csv", help="Path to save the output CSV file.")

    args = parser.parse_args()

    # The converter returns its status message instead of printing it; show it
    # so the user can see where the CSV was written.
    print(merge_json_files_with_transmit_status(args.caption_file_path, args.output_csv_file_path))
--------------------------------------------------------------------------------
/inference.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python inference_magictime.py \
2 | --config sample_configs/RealisticVision.yaml
--------------------------------------------------------------------------------
/inference_cli.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python inference_magictime.py \
2 | --config sample_configs/RealisticVision.yaml \
3 | --human
--------------------------------------------------------------------------------
/inference_magictime.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import inspect
5 | import argparse
6 | import pandas as pd
7 | from omegaconf import OmegaConf
8 | from transformers import CLIPTextModel, CLIPTokenizer
9 | from diffusers import AutoencoderKL, DDIMScheduler
10 | from diffusers.utils.import_utils import is_xformers_available
11 | from huggingface_hub import snapshot_download
12 |
13 | from utils.unet import UNet3DConditionModel
14 | from utils.pipeline_magictime import MagicTimePipeline
15 | from utils.util import save_videos_grid
16 | from utils.util import load_weights
17 |
@torch.no_grad()
def main(args):
    """Run MagicTime text-to-video inference as configured by ``args``.

    ``args.config`` is a YAML file holding a two-element list: element 0 has
    the model paths and sampling settings, element 1 the UNet/noise-scheduler
    kwargs. Prompts come from, in this priority order:

    * ``--human``: an interactive stdin loop, ended by typing ``exit``;
    * ``--run-csv``: the ``name`` (prompt) and ``videoid`` columns;
    * ``--run-json``: ``data['sentences']`` entries with ``caption``,
      ``video_id`` and ``sen_id``;
    * ``--run-txt``: one prompt per line;
    * otherwise: the ``prompt`` list from the config.

    Videos are written as ``.mp4`` into a new numbered sub-directory of
    ``args.save_path``, along with a copy of the model config.
    """
    # Choose the first numbered output directory that does not exist yet. The
    # counter lives in module globals so repeated calls in one process keep
    # counting up instead of re-probing from 0.
    unique_id = globals().setdefault('counter', 0)
    globals()['counter'] = unique_id + 1
    savedir = os.path.join(args.save_path, f"{unique_id}")
    while os.path.exists(savedir):
        unique_id = globals()['counter']
        globals()['counter'] += 1
        savedir = os.path.join(args.save_path, f"{unique_id}")
    os.makedirs(savedir, exist_ok=True)
    print(f"The results will be save to {savedir}")

    # Load the YAML once and split it into its two sections.
    config = OmegaConf.load(args.config)
    model_config, inference_config = config[0], config[1]

    if model_config.magic_adapter_s_path:
        print("Use MagicAdapter-S")
    if model_config.magic_adapter_t_path:
        print("Use MagicAdapter-T")
    if model_config.magic_text_encoder_path:
        print("Use Magic_Text_Encoder")

    tokenizer = CLIPTokenizer.from_pretrained(model_config.pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_config.pretrained_model_path, subfolder="text_encoder").cuda()
    vae = AutoencoderKL.from_pretrained(model_config.pretrained_model_path, subfolder="vae").cuda()
    unet = UNet3DConditionModel.from_pretrained_2d(
        model_config.pretrained_model_path,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
    ).cuda()

    if is_xformers_available() and (not args.without_xformers):
        unet.enable_xformers_memory_efficient_attention()

    pipeline = MagicTimePipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
        scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
    ).to("cuda")

    pipeline = load_weights(
        pipeline,
        motion_module_path=model_config.get("motion_module", ""),
        dreambooth_model_path=model_config.get("dreambooth_path", ""),
        magic_adapter_s_path=model_config.get("magic_adapter_s_path", ""),
        magic_adapter_t_path=model_config.get("magic_adapter_t_path", ""),
        magic_text_encoder_path=model_config.get("magic_text_encoder_path", ""),
    ).to("cuda")

    # Use a plain Python list: MagicTimePipeline._encode_prompt raises TypeError
    # when `negative_prompt` and `prompt` differ in type, and an OmegaConf
    # ListConfig is not a `list`.
    negative_prompts = list(model_config.n_prompt)

    if args.human:
        sample_idx = 0
        while True:
            user_prompt = input("Enter your prompt (or type 'exit' to quit): ")
            if user_prompt.lower() == "exit":
                break

            random_seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()
            torch.manual_seed(random_seed)

            print(f"current seed: {random_seed}")
            print(f"sampling {user_prompt} ...")

            sample = pipeline(
                user_prompt,
                negative_prompt=negative_prompts,
                num_inference_steps=model_config.steps,
                guidance_scale=model_config.guidance_scale,
                width=model_config.W,
                height=model_config.H,
                video_length=model_config.L,
            ).videos

            # First ten words of the prompt, '/' removed so it is a safe filename.
            prompt_for_filename = "-".join(user_prompt.replace("/", "").split(" ")[:10])
            out_path = f"{savedir}/sample/{sample_idx}-{random_seed}-{prompt_for_filename}.mp4"
            save_videos_grid(sample, out_path)
            print(f"save to {out_path}")

            sample_idx += 1
    else:
        batch_size = args.batch_size
        senids = None  # only populated in --run-json mode

        if args.run_csv:
            print("run csv")
            data = pd.read_csv(args.run_csv)
            prompts = data['name'].tolist()
            videoids = data['videoid'].tolist()
        elif args.run_json:
            print("run json")
            with open(args.run_json, 'r') as file:
                data = json.load(file)
            sentences = data['sentences']
            prompts = [item['caption'] for item in sentences]
            videoids = [item['video_id'] for item in sentences]
            senids = [item['sen_id'] for item in sentences]
        elif args.run_txt:
            print("run txt")
            with open(args.run_txt, 'r') as file:
                prompts = [line.strip() for line in file.readlines()]
            videoids = [f"video_{i}" for i in range(len(prompts))]
        else:
            prompts = model_config.prompt
            videoids = [f"video_{i}" for i in range(len(prompts))]

        def output_name(index):
            # JSON captions can share a video id, so the sentence id is appended
            # to keep their filenames distinct.
            if senids is not None:
                return f"{videoids[index]}-{senids[index]}.mp4"
            return f"{videoids[index]}.mp4"

        for start in range(0, len(prompts), batch_size):
            batch_prompts = list(prompts[start : start + batch_size])
            batch_indices = range(start, start + len(batch_prompts))

            # Skip a batch whose outputs all exist already.
            # NOTE(review): savedir is freshly created above, so this only fires
            # for ids repeated earlier in the same run — confirm whether resuming
            # into an existing directory was intended.
            if all(os.path.exists(os.path.join(savedir, output_name(j))) for j in batch_indices):
                print("skipping")
                continue

            # A single negative prompt is broadcast to the whole batch; otherwise
            # the list is used as given (it must then match the batch size).
            if len(negative_prompts) == 1:
                n_prompts = negative_prompts * len(batch_prompts)
            else:
                n_prompts = negative_prompts

            random_seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()
            torch.manual_seed(random_seed)

            print(f"current seed: {random_seed}")

            results = pipeline(
                batch_prompts,
                negative_prompt=n_prompts,
                num_inference_steps=model_config.steps,
                guidance_scale=model_config.guidance_scale,
                width=model_config.W,
                height=model_config.H,
                video_length=model_config.L,
            ).videos

            for j, sample in zip(batch_indices, results):
                new_filename = output_name(j)
                save_videos_grid(sample.unsqueeze(0), f"{savedir}/{new_filename}")
                print(f"save to {savedir}/{new_filename}")

    OmegaConf.save(model_config, f"{savedir}/model_config.yaml")
182 |
if __name__ == "__main__":
    # CLI entry point: parse options, make sure checkpoints are present locally,
    # then run inference.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    cli.add_argument("--without-xformers", action="store_true")
    cli.add_argument("--human", action="store_true", help="Enable human mode for interactive video generation")
    # Mutually optional prompt sources; main() checks them in this order.
    for source_flag in ("--run-csv", "--run-json", "--run-txt"):
        cli.add_argument(source_flag, type=str, default=None)
    cli.add_argument("--save-path", type=str, default="outputs")
    cli.add_argument("--batch-size", type=int, default=1)

    parsed_args = cli.parse_args()

    # Mirror the MagicTime weights from the Hugging Face Hub into ./ckpts.
    snapshot_download(repo_id="BestWishYsh/MagicTime", local_dir="ckpts")
    main(parsed_args)
197 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.2.2
2 | torchvision==0.17.2
3 | torchaudio==2.2.2
4 | xformers==0.0.25.post1
5 | imageio==2.27.0
6 | imageio[ffmpeg]
7 | imageio[pyav]
8 | peft==0.9.0
9 | numpy==1.26.4
10 | ms-swift==2.0.0
11 | accelerate==0.28.0
12 | diffusers==0.11.1
13 | transformers==4.38.2
14 | huggingface_hub==0.25.2
15 | gradio==3.50.2
16 | gdown
17 | triton
18 | einops
19 | omegaconf
20 | safetensors
21 | spaces
--------------------------------------------------------------------------------
/sample_configs/RcnzCartoon.yaml:
--------------------------------------------------------------------------------
1 | - pretrained_model_path: "./ckpts/Base_Model/stable-diffusion-v1-5"
2 | motion_module: "./ckpts/Base_Model/motion_module/motion_module.ckpt"
3 | dreambooth_path: "./ckpts/DreamBooth/RcnzCartoon.safetensors"
4 | magic_adapter_s_path: "./ckpts/Magic_Weights/magic_adapter_s/magic_adapter_s.ckpt"
5 | magic_adapter_t_path: "./ckpts/Magic_Weights/magic_adapter_t"
6 | magic_text_encoder_path: "./ckpts/Magic_Weights/magic_text_encoder"
7 |
8 | H: 512
9 | W: 512
10 | L: 16
11 | seed: [1268480012, 3480796026, 3607977321, 1601344133]
12 | steps: 25
13 | guidance_scale: 8.5
14 |
15 | prompt:
16 | - "Time-lapse of a simple modern house's construction in a Minecraft virtual environment: beginning with an avatar laying a white foundation, progressing through wall erection and interior furnishing, to adding roof and exterior details, and completed with landscaping and a tall chimney."
17 | - "Time-lapse of a simple modern house's construction in a Minecraft virtual environment: beginning with an avatar laying a white foundation, progressing through wall erection and interior furnishing, to adding roof and exterior details, and completed with landscaping and a tall chimney."
18 | - "Bean sprouts grow and mature from seeds."
19 | - "Time-lapse of a yellow ranunculus flower transitioning from a tightly closed bud to a fully bloomed state, with measured petal separation and unfurling observed across the sequence."
20 |
21 | n_prompt:
22 | - "worst quality, low quality, letterboxed"
23 |
24 | - unet_additional_kwargs:
25 | use_inflated_groupnorm: true
26 | use_motion_module: true
27 | motion_module_resolutions:
28 | - 1
29 | - 2
30 | - 4
31 | - 8
32 | motion_module_mid_block: false
33 | motion_module_type: Vanilla
34 | motion_module_kwargs:
35 | num_attention_heads: 8
36 | num_transformer_block: 1
37 | attention_block_types:
38 | - Temporal_Self
39 | - Temporal_Self
40 | temporal_position_encoding: true
41 | temporal_position_encoding_max_len: 32
42 | temporal_attention_dim_div: 1
43 | zero_initialize: true
44 | noise_scheduler_kwargs:
45 | beta_start: 0.00085
46 | beta_end: 0.012
47 | beta_schedule: linear
48 | steps_offset: 1
49 | clip_sample: false
--------------------------------------------------------------------------------
/sample_configs/RealisticVision.yaml:
--------------------------------------------------------------------------------
1 | - pretrained_model_path: "./ckpts/Base_Model/stable-diffusion-v1-5"
2 | motion_module: "./ckpts/Base_Model/motion_module/motion_module.ckpt"
3 | dreambooth_path: "./ckpts/DreamBooth/RealisticVisionV60B1_v51VAE.safetensors"
4 | magic_adapter_s_path: "./ckpts/Magic_Weights/magic_adapter_s/magic_adapter_s.ckpt"
5 | magic_adapter_t_path: "./ckpts/Magic_Weights/magic_adapter_t"
6 | magic_text_encoder_path: "./ckpts/Magic_Weights/magic_text_encoder"
7 |
8 | H: 512
9 | W: 512
10 | L: 16
11 | seed: [1587796317, 2883629116, 3068368949, 2038801077]
12 | steps: 25
13 | guidance_scale: 8.5
14 |
15 | prompt:
16 | - "Time-lapse of dough balls transforming into bread rolls: Begins with smooth, proofed dough, gradually expands in early baking, becomes taut and voluminous, and finally browns and fully expands to signal the baking's completion."
17 | - "Time-lapse of cupcakes progressing through the baking process: starting from liquid batter in cupcake liners, gradually rising with the formation of domes, to fully baked cupcakes with golden, crackled domes."
18 | - "Cherry blossoms transitioning from tightly closed buds to a peak state of bloom. The progression moves through stages of bud swelling, petal exposure, and gradual opening, culminating in a full and vibrant display of open blossoms."
19 | - "Cherry blossoms transitioning from tightly closed buds to a peak state of bloom. The progression moves through stages of bud swelling, petal exposure, and gradual opening, culminating in a full and vibrant display of open blossoms."
20 |
21 | n_prompt:
22 | - "worst quality, low quality, letterboxed"
23 |
24 | - unet_additional_kwargs:
25 | use_inflated_groupnorm: true
26 | use_motion_module: true
27 | motion_module_resolutions:
28 | - 1
29 | - 2
30 | - 4
31 | - 8
32 | motion_module_mid_block: false
33 | motion_module_type: Vanilla
34 | motion_module_kwargs:
35 | num_attention_heads: 8
36 | num_transformer_block: 1
37 | attention_block_types:
38 | - Temporal_Self
39 | - Temporal_Self
40 | temporal_position_encoding: true
41 | temporal_position_encoding_max_len: 32
42 | temporal_attention_dim_div: 1
43 | zero_initialize: true
44 | noise_scheduler_kwargs:
45 | beta_start: 0.00085
46 | beta_end: 0.012
47 | beta_schedule: linear
48 | steps_offset: 1
49 | clip_sample: false
--------------------------------------------------------------------------------
/sample_configs/ToonYou.yaml:
--------------------------------------------------------------------------------
1 | - pretrained_model_path: "./ckpts/Base_Model/stable-diffusion-v1-5"
2 | motion_module: "./ckpts/Base_Model/motion_module/motion_module.ckpt"
3 | dreambooth_path: "./ckpts/DreamBooth/ToonYou_beta6.safetensors"
4 | magic_adapter_s_path: "./ckpts/Magic_Weights/magic_adapter_s/magic_adapter_s.ckpt"
5 | magic_adapter_t_path: "./ckpts/Magic_Weights/magic_adapter_t"
6 | magic_text_encoder_path: "./ckpts/Magic_Weights/magic_text_encoder"
7 |
8 | H: 512
9 | W: 512
10 | L: 16
11 | seed: [3832738942, 153403692, 10789633, 1496541313]
12 | steps: 25
13 | guidance_scale: 8.5
14 |
15 | prompt:
16 | - "An ice cube is melting."
17 | - "A mesmerizing time-lapse showcasing the elegant unfolding of pink plum buds blossoms, capturing the gradual bloom from tightly sealed buds to fully open flowers."
18 | - "Time-lapse of a yellow ranunculus flower transitioning from a tightly closed bud to a fully bloomed state, with measured petal separation and unfurling observed across the sequence."
19 | - "Bean sprouts grow and mature from seeds."
20 |
21 | n_prompt:
22 | - "worst quality, low quality, letterboxed"
23 |
24 | - unet_additional_kwargs:
25 | use_inflated_groupnorm: true
26 | use_motion_module: true
27 | motion_module_resolutions:
28 | - 1
29 | - 2
30 | - 4
31 | - 8
32 | motion_module_mid_block: false
33 | motion_module_type: Vanilla
34 | motion_module_kwargs:
35 | num_attention_heads: 8
36 | num_transformer_block: 1
37 | attention_block_types:
38 | - Temporal_Self
39 | - Temporal_Self
40 | temporal_position_encoding: true
41 | temporal_position_encoding_max_len: 32
42 | temporal_attention_dim_div: 1
43 | zero_initialize: true
44 | noise_scheduler_kwargs:
45 | beta_start: 0.00085
46 | beta_end: 0.012
47 | beta_schedule: linear
48 | steps_offset: 1
49 | clip_sample: false
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import os, csv, random
2 | import numpy as np
3 | from decord import VideoReader
4 | import torch
5 | import torchvision.transforms as transforms
6 | from torch.utils.data.dataset import Dataset
7 |
8 |
class ChronoMagic(Dataset):
    """Video/caption dataset backed by a WebVid-style CSV.

    Each CSV row must provide ``videoid``, ``name`` (the caption) and
    ``is_transmit``; the video is read from ``<video_folder>/<videoid>.mp4``.
    Items are dicts with ``pixel_values`` (scaled to roughly [-1, 1] by the
    final Normalize), ``text`` (the caption) and ``id`` (the video id).

    Args:
        csv_path: path of the CSV index (read as UTF-8).
        video_folder: directory containing the ``.mp4`` files.
        sample_size: output size, an int or an (h, w) pair.
        sample_stride: frame stride used for random clip sampling.
        sample_n_frames: number of frames per sample.
        is_image: if true, return a single random frame instead of a clip.
        is_uniform: stored but not read anywhere in this class.
        max_retries: consecutive load failures tolerated in ``__getitem__``
            before a RuntimeError is raised.
    """

    def __init__(
        self,
        csv_path, video_folder,
        sample_size=512, sample_stride=4, sample_n_frames=16,
        is_image=False,
        is_uniform=True,
        max_retries=100,
    ):
        # newline='' lets the csv module handle quoted captions that contain
        # line breaks, as the csv documentation requires.
        with open(csv_path, 'r', encoding='utf-8', newline='') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)

        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.is_image = is_image
        self.is_uniform = is_uniform  # NOTE(review): unused in this class
        self.max_retries = max_retries

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0], interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def _get_frame_indices_adjusted(self, video_length, n_frames):
        """Pad a short video up to ``n_frames`` indices by repeating frames.

        The extra indices cycle 0, 1, 2, ... and the result is sorted, so the
        repeated frames stay in temporal order (e.g. 3 frames -> 5 gives
        ``[0, 0, 1, 1, 2]``).

        Raises:
            ValueError: if the video has no frames.
        """
        if video_length <= 0:
            raise ValueError(f"cannot sample frames from an empty video (length={video_length})")
        repeats = [i % video_length for i in range(n_frames - video_length)]
        return sorted(list(range(video_length)) + repeats)

    def _generate_frame_indices(self, video_length, n_frames, sample_stride, is_transmit):
        """Choose ``n_frames`` frame indices for one sample.

        ``is_transmit == 0`` spreads the frames evenly over the whole video,
        including both endpoints. Any other value picks a random clip whose
        frames are ``sample_stride`` apart (shrunk to fit the video). Videos
        with at most ``n_frames`` frames are padded by repetition instead.
        """
        prob_execute_original = 1 if int(is_transmit) == 0 else 0
        # Drawn unconditionally (as before) so the random stream is unchanged.
        use_uniform = random.random() < prob_execute_original

        if video_length <= n_frames:
            return self._get_frame_indices_adjusted(video_length, n_frames)

        if use_uniform:
            if n_frames == 1:
                # One frame: no interval to divide; take the first frame, as
                # np.linspace(0, video_length - 1, 1) would.
                return [0]
            interval = (video_length - 1) / (n_frames - 1)
            indices = [int(round(i * interval)) for i in range(n_frames)]
            indices[-1] = video_length - 1
            return indices

        clip_length = min(video_length, (n_frames - 1) * sample_stride + 1)
        start_idx = random.randint(0, video_length - clip_length)
        return np.linspace(start_idx, start_idx + clip_length - 1, n_frames, dtype=int).tolist()

    def get_batch(self, idx):
        """Load the frames, caption and id for row ``idx`` (no transforms applied)."""
        video_dict = self.dataset[idx]
        videoid, name, is_transmit = video_dict['videoid'], video_dict['name'], video_dict['is_transmit']

        video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
        video_reader = VideoReader(video_dir, num_threads=0)
        video_length = len(video_reader)

        if self.is_image:
            batch_index = [random.randint(0, video_length - 1)]
        else:
            batch_index = self._generate_frame_indices(video_length, self.sample_n_frames, self.sample_stride, is_transmit)

        # Reorder to (frames, channels, H, W) and scale to [0, 1]; assumes
        # decord returns (N, H, W, C) uint8 frames.
        pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2) / 255.
        del video_reader

        if self.is_image:
            pixel_values = pixel_values[0]

        return pixel_values, name, videoid

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return a transformed sample, replacing unreadable entries with random others.

        Raises:
            RuntimeError: after ``max_retries`` consecutive load failures, instead
                of retrying forever (e.g. when every sample fails to load).
        """
        last_error = None
        for _ in range(self.max_retries):
            try:
                pixel_values, name, videoid = self.get_batch(idx)
                break
            except Exception as err:  # broken/missing video: fall back to another row
                last_error = err
                idx = random.randint(0, self.length - 1)
        else:
            raise RuntimeError(
                f"Failed to load {self.max_retries} consecutive samples; last error: {last_error!r}"
            ) from last_error

        pixel_values = self.pixel_transforms(pixel_values)
        return dict(pixel_values=pixel_values, text=name, id=videoid)
--------------------------------------------------------------------------------
/utils/pipeline_magictime.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/guoyww/AnimateDiff/animatediff/pipelines/pipeline_animation.py
2 |
3 | import torch
4 | import inspect
5 | import numpy as np
6 | from tqdm import tqdm
7 | from einops import rearrange
8 | from packaging import version
9 | from dataclasses import dataclass
10 | from typing import Callable, List, Optional, Union
11 | from transformers import CLIPTextModel, CLIPTokenizer
12 |
13 | from diffusers.utils import is_accelerate_available, deprecate, logging, BaseOutput
14 | from diffusers.configuration_utils import FrozenDict
15 | from diffusers.models import AutoencoderKL
16 | from diffusers.pipeline_utils import DiffusionPipeline
17 | from diffusers.schedulers import (
18 | DDIMScheduler,
19 | DPMSolverMultistepScheduler,
20 | EulerAncestralDiscreteScheduler,
21 | EulerDiscreteScheduler,
22 | LMSDiscreteScheduler,
23 | PNDMScheduler,
24 | )
25 |
26 | from .unet import UNet3DConditionModel
27 |
28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29 |
@dataclass
class MagicTimePipelineOutput(BaseOutput):
    """Result container returned by the MagicTime pipeline.

    Attributes:
        videos: the generated videos, as a torch tensor or a numpy array.
    """

    videos: Union[torch.Tensor, np.ndarray]
33 |
34 | class MagicTimePipeline(DiffusionPipeline):
35 | _optional_components = []
36 |
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet3DConditionModel,
    scheduler: Union[
        DDIMScheduler,
        PNDMScheduler,
        LMSDiscreteScheduler,
        EulerDiscreteScheduler,
        EulerAncestralDiscreteScheduler,
        DPMSolverMultistepScheduler,
    ],
):
    """Assemble the pipeline from its components.

    Before registering the modules, legacy configs are patched, each with a
    deprecation warning:

    * a scheduler ``steps_offset`` other than 1 is reset to 1;
    * a scheduler ``clip_sample`` of True is reset to False;
    * a UNet saved by diffusers < 0.9 with ``sample_size < 64`` gets
      ``sample_size = 64``.
    """
    super().__init__()

    def _override_config(component, key, value):
        # The component's config is a FrozenDict, so install a patched copy.
        patched = dict(component.config)
        patched[key] = value
        component._internal_dict = FrozenDict(patched)

    # scheduler.config is re-read for each check because the first patch
    # replaces it.
    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", message, standard_warn=False)
        _override_config(scheduler, "steps_offset", 1)

    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", message, standard_warn=False)
        _override_config(scheduler, "clip_sample", False)

    unet_config = unet.config
    saved_by_old_diffusers = hasattr(unet_config, "_diffusers_version") and version.parse(
        version.parse(unet_config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    sample_size_too_small = hasattr(unet_config, "sample_size") and unet_config.sample_size < 64
    if saved_by_old_diffusers and sample_size_too_small:
        message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", message, standard_warn=False)
        _override_config(unet, "sample_size", 64)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
    )
    # Spatial downsampling factor of the VAE: 2 per block after the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
110 |
def enable_vae_slicing(self):
    """Switch the VAE to sliced decoding via its ``enable_slicing`` method."""
    vae = self.vae
    vae.enable_slicing()
113 |
def disable_vae_slicing(self):
    """Turn the VAE's sliced decoding back off via its ``disable_slicing`` method."""
    vae = self.vae
    vae.disable_slicing()
116 |
def enable_sequential_cpu_offload(self, gpu_id=0):
    """Attach accelerate's ``cpu_offload`` to the UNet, text encoder and VAE.

    Each non-None component is offloaded with ``cuda:{gpu_id}`` as its
    execution device.

    Raises:
        ImportError: if accelerate is not available.
    """
    if not is_accelerate_available():
        raise ImportError("Please install accelerate via `pip install accelerate`")
    from accelerate import cpu_offload

    target = torch.device(f"cuda:{gpu_id}")
    for component in (self.unet, self.text_encoder, self.vae):
        if component is None:
            continue
        cpu_offload(component, target)
129 |
130 | @property
131 | def _execution_device(self):
132 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
133 | return self.device
134 | for module in self.unet.modules():
135 | if (
136 | hasattr(module, "_hf_hook")
137 | and hasattr(module._hf_hook, "execution_device")
138 | and module._hf_hook.execution_device is not None
139 | ):
140 | return torch.device(module._hf_hook.execution_device)
141 | return self.device
142 |
143 | def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
144 | batch_size = len(prompt) if isinstance(prompt, list) else 1
145 |
146 | text_inputs = self.tokenizer(
147 | prompt,
148 | padding="max_length",
149 | max_length=self.tokenizer.model_max_length,
150 | truncation=True,
151 | return_tensors="pt",
152 | )
153 | text_input_ids = text_inputs.input_ids
154 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
155 |
156 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
157 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
158 | logger.warning(
159 | "The following part of your input was truncated because CLIP can only handle sequences up to"
160 | f" {self.tokenizer.model_max_length} tokens: {removed_text}"
161 | )
162 |
163 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
164 | attention_mask = text_inputs.attention_mask.to(device)
165 | else:
166 | attention_mask = None
167 |
168 | text_embeddings = self.text_encoder(
169 | text_input_ids.to(device),
170 | attention_mask=attention_mask,
171 | )
172 | text_embeddings = text_embeddings[0]
173 |
174 | # duplicate text embeddings for each generation per prompt, using mps friendly method
175 | bs_embed, seq_len, _ = text_embeddings.shape
176 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
177 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
178 |
179 | # get unconditional embeddings for classifier free guidance
180 | if do_classifier_free_guidance:
181 | uncond_tokens: List[str]
182 | if negative_prompt is None:
183 | uncond_tokens = [""] * batch_size
184 | elif type(prompt) is not type(negative_prompt):
185 | raise TypeError(
186 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
187 | f" {type(prompt)}."
188 | )
189 | elif isinstance(negative_prompt, str):
190 | uncond_tokens = [negative_prompt]
191 | elif batch_size != len(negative_prompt):
192 | raise ValueError(
193 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
194 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
195 | " the batch size of `prompt`."
196 | )
197 | else:
198 | uncond_tokens = negative_prompt
199 |
200 | max_length = text_input_ids.shape[-1]
201 | uncond_input = self.tokenizer(
202 | uncond_tokens,
203 | padding="max_length",
204 | max_length=max_length,
205 | truncation=True,
206 | return_tensors="pt",
207 | )
208 |
209 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
210 | attention_mask = uncond_input.attention_mask.to(device)
211 | else:
212 | attention_mask = None
213 |
214 | uncond_embeddings = self.text_encoder(
215 | uncond_input.input_ids.to(device),
216 | attention_mask=attention_mask,
217 | )
218 | uncond_embeddings = uncond_embeddings[0]
219 |
220 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
221 | seq_len = uncond_embeddings.shape[1]
222 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
223 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
224 |
225 | # For classifier free guidance, we need to do two forward passes.
226 | # Here we concatenate the unconditional and text embeddings into a single batch
227 | # to avoid doing two forward passes
228 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
229 |
230 | return text_embeddings
231 |
def decode_latents(self, latents):
    """Decode video latents into a float32 numpy array with values in [0, 1].

    ``latents`` is indexed as (batch, channels, frames, height, width), and the
    result keeps that b c f h w layout in pixel space. The VAE decodes one
    frame at a time rather than the whole (b*f) stack at once.
    """
    num_frames = latents.shape[2]
    # Undo the 0.18215 latent scaling, then fold frames into the batch axis.
    flat = rearrange((1 / 0.18215) * latents, "b c f h w -> (b f) c h w")
    decoded_frames = [self.vae.decode(flat[k:k + 1]).sample for k in tqdm(range(flat.shape[0]))]
    video = rearrange(torch.cat(decoded_frames), "(b f) c h w -> b c f h w", f=num_frames)
    # Rescale (presumably from the VAE's [-1, 1] range) to [0, 1] and clamp.
    video = (video / 2 + 0.5).clamp(0, 1)
    # Always cast to float32: negligible overhead, and it also handles bfloat16
    # tensors, which numpy cannot represent.
    return video.cpu().float().numpy()
246 |
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional keyword arguments for ``self.scheduler.step``.

    Schedulers differ in signature, so each option is passed only if ``step``
    accepts it:

    * ``eta`` — η from the DDIM paper (https://arxiv.org/abs/2010.02502),
      expected in [0, 1]; other schedulers ignore it;
    * ``generator`` — the random generator used for stochastic steps.

    Returns:
        A dict with a subset of the keys ``"eta"`` and ``"generator"``.
    """
    # Inspect the signature once rather than once per option.
    step_params = inspect.signature(self.scheduler.step).parameters

    extra_step_kwargs = {}
    if "eta" in step_params:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_params:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
263 |
264 | def check_inputs(self, prompt, height, width, callback_steps):
265 | if not isinstance(prompt, str) and not isinstance(prompt, list):
266 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
267 |
268 | if height % 8 != 0 or width % 8 != 0:
269 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
270 |
271 | if (callback_steps is None) or (
272 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
273 | ):
274 | raise ValueError(
275 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
276 | f" {type(callback_steps)}."
277 | )
278 |
279 | def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
280 | shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
281 | if isinstance(generator, list) and len(generator) != batch_size:
282 | raise ValueError(
283 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
284 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
285 | )
286 | if latents is None:
287 | rand_device = "cpu" if device.type == "mps" else device
288 |
289 | if isinstance(generator, list):
290 | shape = shape
291 | # shape = (1,) + shape[1:]
292 | latents = [
293 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
294 | for i in range(batch_size)
295 | ]
296 | latents = torch.cat(latents, dim=0).to(device)
297 | else:
298 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
299 | else:
300 | if latents.shape != shape:
301 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
302 | latents = latents.to(device)
303 |
304 | # scale the initial noise by the standard deviation required by the scheduler
305 | latents = latents * self.scheduler.init_noise_sigma
306 | return latents
307 |
308 | @torch.no_grad()
309 | def __call__(
310 | self,
311 | prompt: Union[str, List[str]],
312 | video_length: Optional[int],
313 | height: Optional[int] = None,
314 | width: Optional[int] = None,
315 | num_inference_steps: int = 50,
316 | guidance_scale: float = 7.5,
317 | negative_prompt: Optional[Union[str, List[str]]] = None,
318 | num_videos_per_prompt: Optional[int] = 1,
319 | eta: float = 0.0,
320 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
321 | latents: Optional[torch.FloatTensor] = None,
322 | output_type: Optional[str] = "tensor",
323 | return_dict: bool = True,
324 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
325 | callback_steps: Optional[int] = 1,
326 | **kwargs,
327 | ):
328 | # Default height and width to unet
329 | height = height or self.unet.config.sample_size * self.vae_scale_factor
330 | width = width or self.unet.config.sample_size * self.vae_scale_factor
331 |
332 | # Check inputs. Raise error if not correct
333 | self.check_inputs(prompt, height, width, callback_steps)
334 |
335 | # Define call parameters
336 | # batch_size = 1 if isinstance(prompt, str) else len(prompt)
337 | batch_size = 1
338 | if latents is not None:
339 | batch_size = latents.shape[0]
340 | if isinstance(prompt, list):
341 | batch_size = len(prompt)
342 |
343 | device = self._execution_device
344 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
345 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
346 | # corresponds to doing no classifier free guidance.
347 | do_classifier_free_guidance = guidance_scale > 1.0
348 |
349 | # Encode input prompt
350 | prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
351 | if negative_prompt is not None:
352 | negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
353 | text_embeddings = self._encode_prompt(
354 | prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
355 | )
356 |
357 | # Prepare timesteps
358 | self.scheduler.set_timesteps(num_inference_steps, device=device)
359 | timesteps = self.scheduler.timesteps
360 |
361 | # Prepare latent variables
362 | num_channels_latents = self.unet.in_channels
363 | latents = self.prepare_latents(
364 | batch_size * num_videos_per_prompt,
365 | num_channels_latents,
366 | video_length,
367 | height,
368 | width,
369 | text_embeddings.dtype,
370 | device,
371 | generator,
372 | latents,
373 | )
374 | latents_dtype = latents.dtype
375 |
376 | # Prepare extra step kwargs.
377 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
378 |
379 | # Denoising loop
380 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
381 | with self.progress_bar(total=num_inference_steps) as progress_bar:
382 | for i, t in enumerate(timesteps):
383 | # expand the latents if we are doing classifier free guidance
384 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
385 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
386 |
387 | down_block_additional_residuals = mid_block_additional_residual = None
388 |
389 | # predict the noise residual
390 | noise_pred = self.unet(
391 | latent_model_input, t,
392 | encoder_hidden_states=text_embeddings,
393 | down_block_additional_residuals = down_block_additional_residuals,
394 | mid_block_additional_residual = mid_block_additional_residual,
395 | ).sample.to(dtype=latents_dtype)
396 |
397 | # perform guidance
398 | if do_classifier_free_guidance:
399 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
400 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
401 |
402 | # compute the previous noisy sample x_t -> x_t-1
403 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
404 |
405 | # call the callback, if provided
406 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
407 | progress_bar.update()
408 | if callback is not None and i % callback_steps == 0:
409 | callback(i, t, latents)
410 |
411 | # Post-processing
412 | video = self.decode_latents(latents)
413 |
414 | # Convert to tensor
415 | if output_type == "tensor":
416 | video = torch.from_numpy(video)
417 |
418 | if not return_dict:
419 | return video
420 |
421 | return MagicTimePipelineOutput(videos=video)
422 |
--------------------------------------------------------------------------------
/utils/unet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/guoyww/AnimateDiff/animatediff/models/unet.py
2 | import os
3 | import json
4 | import pdb
5 | from dataclasses import dataclass
6 | from typing import List, Optional, Tuple, Union
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.utils.checkpoint
11 |
12 | from diffusers.configuration_utils import ConfigMixin, register_to_config
13 | from diffusers.modeling_utils import ModelMixin
14 | from diffusers.utils import BaseOutput, logging
15 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps
16 | from .unet_blocks import (
17 | CrossAttnDownBlock3D,
18 | CrossAttnUpBlock3D,
19 | DownBlock3D,
20 | UNetMidBlock3DCrossAttn,
21 | UpBlock3D,
22 | get_down_block,
23 | get_up_block,
24 | InflatedConv3d,
25 | InflatedGroupNorm,
26 | )
27 |
28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29 |
30 |
31 | @dataclass
32 | class UNet3DConditionOutput(BaseOutput):
33 | sample: torch.FloatTensor
34 |
35 |
36 | class UNet3DConditionModel(ModelMixin, ConfigMixin):
37 | _supports_gradient_checkpointing = True
38 |
39 | @register_to_config
40 | def __init__(
41 | self,
42 | sample_size: Optional[int] = None,
43 | in_channels: int = 4,
44 | out_channels: int = 4,
45 | center_input_sample: bool = False,
46 | flip_sin_to_cos: bool = True,
47 | freq_shift: int = 0,
48 | down_block_types: Tuple[str] = (
49 | "CrossAttnDownBlock3D",
50 | "CrossAttnDownBlock3D",
51 | "CrossAttnDownBlock3D",
52 | "DownBlock3D",
53 | ),
54 | mid_block_type: str = "UNetMidBlock3DCrossAttn",
55 | up_block_types: Tuple[str] = (
56 | "UpBlock3D",
57 | "CrossAttnUpBlock3D",
58 | "CrossAttnUpBlock3D",
59 | "CrossAttnUpBlock3D"
60 | ),
61 | only_cross_attention: Union[bool, Tuple[bool]] = False,
62 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
63 | layers_per_block: int = 2,
64 | downsample_padding: int = 1,
65 | mid_block_scale_factor: float = 1,
66 | act_fn: str = "silu",
67 | norm_num_groups: int = 32,
68 | norm_eps: float = 1e-5,
69 | cross_attention_dim: int = 1280,
70 | attention_head_dim: Union[int, Tuple[int]] = 8,
71 | dual_cross_attention: bool = False,
72 | use_linear_projection: bool = False,
73 | class_embed_type: Optional[str] = None,
74 | num_class_embeds: Optional[int] = None,
75 | upcast_attention: bool = False,
76 | resnet_time_scale_shift: str = "default",
77 |
78 | use_inflated_groupnorm=False,
79 |
80 | # Additional
81 | use_motion_module = False,
82 | motion_module_resolutions = ( 1,2,4,8 ),
83 | motion_module_mid_block = False,
84 | motion_module_decoder_only = False,
85 | motion_module_type = None,
86 | motion_module_kwargs = {},
87 | unet_use_cross_frame_attention = False,
88 | unet_use_temporal_attention = False,
89 | ):
90 | super().__init__()
91 |
92 | self.sample_size = sample_size
93 | time_embed_dim = block_out_channels[0] * 4
94 |
95 | # input
96 | self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
97 |
98 | # time
99 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
100 | timestep_input_dim = block_out_channels[0]
101 |
102 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
103 |
104 | # class embedding
105 | if class_embed_type is None and num_class_embeds is not None:
106 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
107 | elif class_embed_type == "timestep":
108 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
109 | elif class_embed_type == "identity":
110 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
111 | else:
112 | self.class_embedding = None
113 |
114 | self.down_blocks = nn.ModuleList([])
115 | self.mid_block = None
116 | self.up_blocks = nn.ModuleList([])
117 |
118 | if isinstance(only_cross_attention, bool):
119 | only_cross_attention = [only_cross_attention] * len(down_block_types)
120 |
121 | if isinstance(attention_head_dim, int):
122 | attention_head_dim = (attention_head_dim,) * len(down_block_types)
123 |
124 | # down
125 | output_channel = block_out_channels[0]
126 | for i, down_block_type in enumerate(down_block_types):
127 | res = 2 ** i
128 | input_channel = output_channel
129 | output_channel = block_out_channels[i]
130 | is_final_block = i == len(block_out_channels) - 1
131 |
132 | down_block = get_down_block(
133 | down_block_type,
134 | num_layers=layers_per_block,
135 | in_channels=input_channel,
136 | out_channels=output_channel,
137 | temb_channels=time_embed_dim,
138 | add_downsample=not is_final_block,
139 | resnet_eps=norm_eps,
140 | resnet_act_fn=act_fn,
141 | resnet_groups=norm_num_groups,
142 | cross_attention_dim=cross_attention_dim,
143 | attn_num_head_channels=attention_head_dim[i],
144 | downsample_padding=downsample_padding,
145 | dual_cross_attention=dual_cross_attention,
146 | use_linear_projection=use_linear_projection,
147 | only_cross_attention=only_cross_attention[i],
148 | upcast_attention=upcast_attention,
149 | resnet_time_scale_shift=resnet_time_scale_shift,
150 |
151 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
152 | unet_use_temporal_attention=unet_use_temporal_attention,
153 | use_inflated_groupnorm=use_inflated_groupnorm,
154 |
155 | use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
156 | motion_module_type=motion_module_type,
157 | motion_module_kwargs=motion_module_kwargs,
158 | )
159 | self.down_blocks.append(down_block)
160 |
161 | # mid
162 | if mid_block_type == "UNetMidBlock3DCrossAttn":
163 | self.mid_block = UNetMidBlock3DCrossAttn(
164 | in_channels=block_out_channels[-1],
165 | temb_channels=time_embed_dim,
166 | resnet_eps=norm_eps,
167 | resnet_act_fn=act_fn,
168 | output_scale_factor=mid_block_scale_factor,
169 | resnet_time_scale_shift=resnet_time_scale_shift,
170 | cross_attention_dim=cross_attention_dim,
171 | attn_num_head_channels=attention_head_dim[-1],
172 | resnet_groups=norm_num_groups,
173 | dual_cross_attention=dual_cross_attention,
174 | use_linear_projection=use_linear_projection,
175 | upcast_attention=upcast_attention,
176 |
177 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
178 | unet_use_temporal_attention=unet_use_temporal_attention,
179 | use_inflated_groupnorm=use_inflated_groupnorm,
180 |
181 | use_motion_module=use_motion_module and motion_module_mid_block,
182 | motion_module_type=motion_module_type,
183 | motion_module_kwargs=motion_module_kwargs,
184 | )
185 | else:
186 | raise ValueError(f"unknown mid_block_type : {mid_block_type}")
187 |
188 | # count how many layers upsample the videos
189 | self.num_upsamplers = 0
190 |
191 | # up
192 | reversed_block_out_channels = list(reversed(block_out_channels))
193 | reversed_attention_head_dim = list(reversed(attention_head_dim))
194 | only_cross_attention = list(reversed(only_cross_attention))
195 | output_channel = reversed_block_out_channels[0]
196 | for i, up_block_type in enumerate(up_block_types):
197 | res = 2 ** (3 - i)
198 | is_final_block = i == len(block_out_channels) - 1
199 |
200 | prev_output_channel = output_channel
201 | output_channel = reversed_block_out_channels[i]
202 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
203 |
204 | # add upsample block for all BUT final layer
205 | if not is_final_block:
206 | add_upsample = True
207 | self.num_upsamplers += 1
208 | else:
209 | add_upsample = False
210 |
211 | up_block = get_up_block(
212 | up_block_type,
213 | num_layers=layers_per_block + 1,
214 | in_channels=input_channel,
215 | out_channels=output_channel,
216 | prev_output_channel=prev_output_channel,
217 | temb_channels=time_embed_dim,
218 | add_upsample=add_upsample,
219 | resnet_eps=norm_eps,
220 | resnet_act_fn=act_fn,
221 | resnet_groups=norm_num_groups,
222 | cross_attention_dim=cross_attention_dim,
223 | attn_num_head_channels=reversed_attention_head_dim[i],
224 | dual_cross_attention=dual_cross_attention,
225 | use_linear_projection=use_linear_projection,
226 | only_cross_attention=only_cross_attention[i],
227 | upcast_attention=upcast_attention,
228 | resnet_time_scale_shift=resnet_time_scale_shift,
229 |
230 | unet_use_cross_frame_attention=unet_use_cross_frame_attention,
231 | unet_use_temporal_attention=unet_use_temporal_attention,
232 | use_inflated_groupnorm=use_inflated_groupnorm,
233 |
234 | use_motion_module=use_motion_module and (res in motion_module_resolutions),
235 | motion_module_type=motion_module_type,
236 | motion_module_kwargs=motion_module_kwargs,
237 | )
238 | self.up_blocks.append(up_block)
239 | prev_output_channel = output_channel
240 |
241 | # out
242 | if use_inflated_groupnorm:
243 | self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
244 | else:
245 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
246 | self.conv_act = nn.SiLU()
247 | self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
248 |
249 | def set_attention_slice(self, slice_size):
250 | r"""
251 | Enable sliced attention computation.
252 |
253 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention
254 | in several steps. This is useful to save some memory in exchange for a small speed decrease.
255 |
256 | Args:
257 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
258 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
259 | `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
260 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
261 | must be a multiple of `slice_size`.
262 | """
263 | sliceable_head_dims = []
264 |
265 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
266 | if hasattr(module, "set_attention_slice"):
267 | sliceable_head_dims.append(module.sliceable_head_dim)
268 |
269 | for child in module.children():
270 | fn_recursive_retrieve_slicable_dims(child)
271 |
272 | # retrieve number of attention layers
273 | for module in self.children():
274 | fn_recursive_retrieve_slicable_dims(module)
275 |
276 | num_slicable_layers = len(sliceable_head_dims)
277 |
278 | if slice_size == "auto":
279 | # half the attention head size is usually a good trade-off between
280 | # speed and memory
281 | slice_size = [dim // 2 for dim in sliceable_head_dims]
282 | elif slice_size == "max":
283 | # make smallest slice possible
284 | slice_size = num_slicable_layers * [1]
285 |
286 | slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
287 |
288 | if len(slice_size) != len(sliceable_head_dims):
289 | raise ValueError(
290 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
291 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
292 | )
293 |
294 | for i in range(len(slice_size)):
295 | size = slice_size[i]
296 | dim = sliceable_head_dims[i]
297 | if size is not None and size > dim:
298 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
299 |
300 | # Recursively walk through all the children.
301 | # Any children which exposes the set_attention_slice method
302 | # gets the message
303 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
304 | if hasattr(module, "set_attention_slice"):
305 | module.set_attention_slice(slice_size.pop())
306 |
307 | for child in module.children():
308 | fn_recursive_set_attention_slice(child, slice_size)
309 |
310 | reversed_slice_size = list(reversed(slice_size))
311 | for module in self.children():
312 | fn_recursive_set_attention_slice(module, reversed_slice_size)
313 |
314 | def _set_gradient_checkpointing(self, module, value=False):
315 | if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
316 | module.gradient_checkpointing = value
317 |
318 | def forward(
319 | self,
320 | sample: torch.FloatTensor,
321 | timestep: Union[torch.Tensor, float, int],
322 | encoder_hidden_states: torch.Tensor,
323 | class_labels: Optional[torch.Tensor] = None,
324 | attention_mask: Optional[torch.Tensor] = None,
325 |
326 | # support controlnet
327 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
328 | mid_block_additional_residual: Optional[torch.Tensor] = None,
329 |
330 | return_dict: bool = True,
331 | ) -> Union[UNet3DConditionOutput, Tuple]:
332 | r"""
333 | Args:
334 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
335 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
336 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
337 | return_dict (`bool`, *optional*, defaults to `True`):
338 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
339 |
340 | Returns:
341 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
342 | [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
343 | returning a tuple, the first element is the sample tensor.
344 | """
345 | # By default samples have to be AT least a multiple of the overall upsampling factor.
346 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
347 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
348 | # on the fly if necessary.
349 | default_overall_up_factor = 2**self.num_upsamplers
350 |
351 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
352 | forward_upsample_size = False
353 | upsample_size = None
354 |
355 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
356 | logger.info("Forward upsample size to force interpolation output size.")
357 | forward_upsample_size = True
358 |
359 | # prepare attention_mask
360 | if attention_mask is not None:
361 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
362 | attention_mask = attention_mask.unsqueeze(1)
363 |
364 | # center input if necessary
365 | if self.config.center_input_sample:
366 | sample = 2 * sample - 1.0
367 |
368 | # time
369 | timesteps = timestep
370 | if not torch.is_tensor(timesteps):
371 | # This would be a good case for the `match` statement (Python 3.10+)
372 | is_mps = sample.device.type == "mps"
373 | if isinstance(timestep, float):
374 | dtype = torch.float32 if is_mps else torch.float64
375 | else:
376 | dtype = torch.int32 if is_mps else torch.int64
377 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
378 | elif len(timesteps.shape) == 0:
379 | timesteps = timesteps[None].to(sample.device)
380 |
381 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
382 | timesteps = timesteps.expand(sample.shape[0])
383 |
384 | t_emb = self.time_proj(timesteps)
385 |
386 | # timesteps does not contain any weights and will always return f32 tensors
387 | # but time_embedding might actually be running in fp16. so we need to cast here.
388 | # there might be better ways to encapsulate this.
389 | t_emb = t_emb.to(dtype=self.dtype)
390 | emb = self.time_embedding(t_emb)
391 |
392 | if self.class_embedding is not None:
393 | if class_labels is None:
394 | raise ValueError("class_labels should be provided when num_class_embeds > 0")
395 |
396 | if self.config.class_embed_type == "timestep":
397 | class_labels = self.time_proj(class_labels)
398 |
399 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
400 | emb = emb + class_emb
401 |
402 | # pre-process
403 | sample = self.conv_in(sample)
404 |
405 | # down
406 | down_block_res_samples = (sample,)
407 | for downsample_block in self.down_blocks:
408 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
409 | sample, res_samples = downsample_block(
410 | hidden_states=sample,
411 | temb=emb,
412 | encoder_hidden_states=encoder_hidden_states,
413 | attention_mask=attention_mask,
414 | )
415 | else:
416 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
417 |
418 | down_block_res_samples += res_samples
419 |
420 | # support controlnet
421 | down_block_res_samples = list(down_block_res_samples)
422 | if down_block_additional_residuals is not None:
423 | for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
424 | if down_block_additional_residual.dim() == 4: # boardcast
425 | down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
426 | down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual
427 |
428 | # mid
429 | sample = self.mid_block(
430 | sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
431 | )
432 |
433 | # support controlnet
434 | if mid_block_additional_residual is not None:
435 | if mid_block_additional_residual.dim() == 4: # boardcast
436 | mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
437 | sample = sample + mid_block_additional_residual
438 |
439 | # up
440 | for i, upsample_block in enumerate(self.up_blocks):
441 | is_final_block = i == len(self.up_blocks) - 1
442 |
443 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
444 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
445 |
446 | # if we have not reached the final block and need to forward the
447 | # upsample size, we do it here
448 | if not is_final_block and forward_upsample_size:
449 | upsample_size = down_block_res_samples[-1].shape[2:]
450 |
451 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
452 | sample = upsample_block(
453 | hidden_states=sample,
454 | temb=emb,
455 | res_hidden_states_tuple=res_samples,
456 | encoder_hidden_states=encoder_hidden_states,
457 | upsample_size=upsample_size,
458 | attention_mask=attention_mask,
459 | )
460 | else:
461 | sample = upsample_block(
462 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
463 | )
464 |
465 | # post-process
466 | sample = self.conv_norm_out(sample)
467 | sample = self.conv_act(sample)
468 | sample = self.conv_out(sample)
469 |
470 | if not return_dict:
471 | return (sample,)
472 |
473 | return UNet3DConditionOutput(sample=sample)
474 |
475 | @classmethod
476 | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
477 | if subfolder is not None:
478 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
479 | print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...")
480 |
481 | config_file = os.path.join(pretrained_model_path, 'config.json')
482 | if not os.path.isfile(config_file):
483 | raise RuntimeError(f"{config_file} does not exist")
484 | with open(config_file, "r") as f:
485 | config = json.load(f)
486 | config["_class_name"] = cls.__name__
487 | config["down_block_types"] = [
488 | "CrossAttnDownBlock3D",
489 | "CrossAttnDownBlock3D",
490 | "CrossAttnDownBlock3D",
491 | "DownBlock3D"
492 | ]
493 | config["up_block_types"] = [
494 | "UpBlock3D",
495 | "CrossAttnUpBlock3D",
496 | "CrossAttnUpBlock3D",
497 | "CrossAttnUpBlock3D"
498 | ]
499 |
500 | from diffusers.utils import WEIGHTS_NAME
501 | model = cls.from_config(config, **unet_additional_kwargs)
502 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
503 | if not os.path.isfile(model_file):
504 | raise RuntimeError(f"{model_file} does not exist")
505 | state_dict = torch.load(model_file, map_location="cpu")
506 |
507 | m, u = model.load_state_dict(state_dict, strict=False)
508 | print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
509 |
510 | params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()]
511 | print(f"### Motion Module Parameters: {sum(params) / 1e6} M")
512 |
513 | return model
514 |
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import imageio
3 | import numpy as np
4 | from tqdm import tqdm
5 | from typing import Union
6 | from einops import rearrange
7 | from safetensors import safe_open
8 | from transformers import CLIPTextModel
9 | import torch
10 | import torchvision
11 | import torch.distributed as dist
12 |
13 | def zero_rank_print(s):
14 | if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
15 |
16 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
17 | videos = rearrange(videos, "b c t h w -> t b c h w")
18 | outputs = []
19 | for x in videos:
20 | x = torchvision.utils.make_grid(x, nrow=n_rows)
21 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
22 | if rescale:
23 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
24 | x = (x * 255).numpy().astype(np.uint8)
25 | outputs.append(x)
26 |
27 | os.makedirs(os.path.dirname(path), exist_ok=True)
28 | imageio.mimsave(path, outputs, fps=fps)
29 |
30 | # DDIM Inversion
31 | @torch.no_grad()
32 | def init_prompt(prompt, pipeline):
33 | uncond_input = pipeline.tokenizer(
34 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
35 | return_tensors="pt"
36 | )
37 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
38 | text_input = pipeline.tokenizer(
39 | [prompt],
40 | padding="max_length",
41 | max_length=pipeline.tokenizer.model_max_length,
42 | truncation=True,
43 | return_tensors="pt",
44 | )
45 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
46 | context = torch.cat([uncond_embeddings, text_embeddings])
47 |
48 | return context
49 |
50 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
51 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
52 | timestep, next_timestep = min(
53 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
54 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
55 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
56 | beta_prod_t = 1 - alpha_prod_t
57 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
58 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
59 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
60 | return next_sample
61 |
62 | def get_noise_pred_single(latents, t, context, unet):
63 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
64 | return noise_pred
65 |
66 | @torch.no_grad()
67 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
68 | context = init_prompt(prompt, pipeline)
69 | uncond_embeddings, cond_embeddings = context.chunk(2)
70 | all_latent = [latent]
71 | latent = latent.clone().detach()
72 | for i in tqdm(range(num_inv_steps)):
73 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
74 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
75 | latent = next_step(noise_pred, t, latent, ddim_scheduler)
76 | all_latent.append(latent)
77 | return all_latent
78 |
79 | @torch.no_grad()
80 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
81 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
82 | return ddim_latents
83 |
def load_weights(
    magictime_pipeline,
    motion_module_path="",
    dreambooth_model_path="",
    magic_adapter_s_path="",
    magic_adapter_t_path="",
    magic_text_encoder_path="",
):
    """Load optional MagicTime weight sets into ``magictime_pipeline``.

    Every path argument is optional; an empty string skips that component.

    Args:
        magictime_pipeline: pipeline exposing ``unet``, ``vae`` and ``text_encoder``.
        motion_module_path: checkpoint whose ``motion_modules.`` entries are
            loaded into the UNet with ``strict=False``. Errors are printed and
            the step is skipped (best effort).
        dreambooth_model_path: ``.safetensors`` or ``.ckpt`` LDM-format
            checkpoint whose VAE, UNet and CLIP text-encoder weights replace the
            pipeline's.
        magic_adapter_s_path: LoRA state dict merged into the UNet through
            ``load_diffusers_lora``.
        magic_adapter_t_path: path handed to ``Swift.from_pretrained`` for the UNet.
        magic_text_encoder_path: path handed to ``Swift.from_pretrained`` for
            the text encoder.

    Returns:
        The pipeline with the requested weights loaded.

    Raises:
        ValueError: if ``dreambooth_model_path`` has an unsupported extension.
    """
    # motion module
    if motion_module_path != "":
        print(f"load motion module from {motion_module_path}")
        unet_state_dict = {}
        try:
            motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
            if "state_dict" in motion_module_state_dict:
                motion_module_state_dict = motion_module_state_dict["state_dict"]
            for name, param in motion_module_state_dict.items():
                if "motion_modules." in name:
                    # strip a leading "module." (presumably from a DataParallel/DDP-wrapped save)
                    unet_state_dict[name.removeprefix("module.")] = param
        except Exception as e:
            print(f"Error loading motion module: {e}")
        try:
            _, unexpected = magictime_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
        except Exception as e:
            print(f"Error loading state dict into UNet: {e}")
        else:
            # explicit check instead of `assert`, which is stripped under `python -O`
            if unexpected:
                print(f"Error loading state dict into UNet: Unexpected keys in state_dict: {unexpected}")
        del unet_state_dict

    # base model
    if dreambooth_model_path != "":
        print(f"load dreambooth model from {dreambooth_model_path}")
        if dreambooth_model_path.endswith(".safetensors"):
            dreambooth_state_dict = {}
            with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    dreambooth_state_dict[key] = f.get_tensor(key)
        elif dreambooth_model_path.endswith(".ckpt"):
            dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu")
            # some .ckpt files nest the weights under "state_dict"; unwrap them
            # the same way the motion-module branch above does
            if "state_dict" in dreambooth_state_dict:
                dreambooth_state_dict = dreambooth_state_dict["state_dict"]
        else:
            raise ValueError(
                f"Unsupported dreambooth model format (expected .safetensors or .ckpt): {dreambooth_model_path}"
            )

        # 1. vae
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, magictime_pipeline.vae.config)
        magictime_pipeline.vae.load_state_dict(converted_vae_checkpoint)
        # 2. unet
        converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, magictime_pipeline.unet.config)
        magictime_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
        # 3. text_model
        magictime_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
        del dreambooth_state_dict

    # MagicAdapter and MagicTextEncoder
    if magic_adapter_s_path != "":
        print(f"load domain lora from {magic_adapter_s_path}")
        magic_adapter_s_state_dict = torch.load(magic_adapter_s_path, map_location="cpu")
        magictime_pipeline = load_diffusers_lora(magictime_pipeline, magic_adapter_s_state_dict, alpha=1.0)

    if magic_adapter_t_path != "" or magic_text_encoder_path != "":
        from swift import Swift

        if magic_adapter_t_path != "":
            print("load lora from swift for Unet")
            Swift.from_pretrained(magictime_pipeline.unet, magic_adapter_t_path)

        if magic_text_encoder_path != "":
            print("load lora from swift for text encoder")
            Swift.from_pretrained(magictime_pipeline.text_encoder, magic_text_encoder_path)

    return magictime_pipeline
152 |
def load_diffusers_lora(pipeline, state_dict, alpha=1.0):
    """Merge a diffusers-format LoRA ``state_dict`` into ``pipeline.unet`` in place.

    Delegates to ``load_diffusers_lora_unet``, which implements the key mapping
    and weight merge for a bare UNet, so the two loaders cannot drift apart.

    Args:
        pipeline: object exposing a ``unet`` module.
        state_dict: LoRA weights with paired ``.down.`` / ``.up.`` keys.
        alpha: scale applied to the merged delta.

    Returns:
        ``pipeline`` (its UNet weights are modified in place).
    """
    load_diffusers_lora_unet(pipeline.unet, state_dict, alpha=alpha)
    return pipeline
174 |
def load_diffusers_lora_unet(unet, state_dict, alpha=1.0):
    """Merge a diffusers-format LoRA ``state_dict`` directly into ``unet``'s weights.

    For every key that is not an ``up.`` key, the matching ``.up.`` key is
    looked up and ``alpha * (2 * up) @ (2 * down)`` is added in place to the
    weight of the module the key points at.

    Args:
        unet: root module; target layers are resolved by dotted attribute path.
        state_dict: LoRA weights keyed like
            ``<path>.processor.<proj>_lora.down.weight`` / ``...up.weight``.
        alpha: scale applied to the merged delta.

    Returns:
        ``unet`` (modified in place).
    """
    for key in state_dict:
        # only process lora down keys; the matching up key is fetched below
        if "up." in key:
            continue

        up_key = key.replace(".down.", ".up.")
        # e.g. "attn1.processor.to_q_lora.down.weight" -> "attn1.to_q.weight"
        model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
        # the output projection lives at index 0 of `to_out`
        model_key = model_key.replace("to_out.", "to_out.0.")

        curr_layer = unet
        for name in model_key.split(".")[:-1]:
            # Module.__getattr__ looks up registered children/params only (incl. numeric indices)
            curr_layer = curr_layer.__getattr__(name)

        # Multiply in float32 (as convert_lora does): half-precision matmul loses
        # precision and is not available on every CPU build. The in-place add
        # casts back to the weight's own dtype.
        # NOTE(review): both factors are doubled (x4 on the product); kept as-is,
        # presumably a LoRA scale baked in for these checkpoints - confirm.
        weight_down = state_dict[key].to(torch.float32) * 2
        weight_up = state_dict[up_key].to(torch.float32) * 2
        curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)

    return unet
196 |
def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
    """Merge a kohya-style LoRA ``state_dict`` into the pipeline's weights in place.

    Keys look like ``lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight``.
    Keys containing ``"text"`` target ``pipeline.text_encoder``; all others
    target ``pipeline.unet``. For each ``lora_down``/``lora_up`` pair,
    ``alpha * up @ down`` (computed in float32) is added to the target layer's
    weight. 4-D factors are squeezed to 2-D for the product and the delta is
    unsqueezed back. Per-module ``.alpha`` entries are skipped; the single
    ``alpha`` argument is used instead.

    Returns:
        ``pipeline`` (weights modified in place).

    Raises:
        ValueError: if a key's module path cannot be resolved.
    """
    # a set keeps the "already handled" check O(1) per key
    visited = set()

    for key in state_dict:
        # as we have set the alpha beforehand, so just skip
        if ".alpha" in key or key in visited:
            continue

        if "text" in key:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet

        curr_layer = _find_lora_target_layer(curr_layer, layer_infos, key)

        # pair_keys is always [up_key, down_key]
        if "lora_down" in key:
            pair_keys = [key.replace("lora_down", "lora_up"), key]
        else:
            pair_keys = [key, key.replace("lora_up", "lora_down")]

        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            delta = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            delta = torch.mm(weight_up, weight_down)
        curr_layer.weight.data += alpha * delta.to(curr_layer.weight.data.device)

        visited.update(pair_keys)

    return pipeline


def _find_lora_target_layer(root, name_parts, key):
    """Resolve underscore-split LoRA name parts to a submodule of ``root``.

    Module names such as ``to_q`` or ``transformer_blocks`` contain ``_``
    themselves, so parts are greedily re-joined with ``_`` until a registered
    child is found.

    Raises:
        ValueError: if the parts run out before a full path is resolved.
    """
    parts = list(name_parts)
    layer = root
    name = parts.pop(0)
    while True:
        try:
            # Module.__getattr__ (not getattr) on purpose: plain getattr would
            # match methods such as Module.to and derail the walk at "to_q".
            layer = layer.__getattr__(name)
        except AttributeError:
            if not parts:
                raise ValueError(f"Could not resolve LoRA key {key!r} to a module (stuck at {name!r})") from None
            name = f"{name}_{parts.pop(0)}" if name else parts.pop(0)
            continue
        if not parts:
            return layer
        name = parts.pop(0)
254 |
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Drop dot-separated segments from ``path``.

    A non-negative count removes that many leading segments; a negative count
    removes that many trailing segments.
    """
    segments = path.split(".")
    if n_shave_prefix_segments < 0:
        kept = segments[:n_shave_prefix_segments]
    else:
        kept = segments[n_shave_prefix_segments:]
    return ".".join(kept)
263 |
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Translate LDM UNet resnet keys into diffusers resnet names (local renaming).

    Each key receives these substring renames, in order, and is then passed
    through ``shave_segments`` with ``n_shave_prefix_segments``:
    in_layers.0 -> norm1, in_layers.2 -> conv1, out_layers.0 -> norm2,
    out_layers.3 -> conv2, emb_layers.1 -> time_emb_proj,
    skip_connection -> conv_shortcut.

    Returns a list of ``{"old": original_key, "new": renamed_key}`` dicts in
    input order.
    """
    renames = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )

    def translate(key):
        for before, after in renames:
            key = key.replace(before, after)
        return shave_segments(key, n_shave_prefix_segments=n_shave_prefix_segments)

    return [{"old": key, "new": translate(key)} for key in old_list]
284 |
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Translate LDM VAE resnet keys into diffusers names (local renaming).

    The only substring rename is ``nin_shortcut`` -> ``conv_shortcut``; the
    result is then passed through ``shave_segments``. Returns a list of
    ``{"old": ..., "new": ...}`` dicts in input order.
    """
    return [
        {
            "old": source_key,
            "new": shave_segments(
                source_key.replace("nin_shortcut", "conv_shortcut"),
                n_shave_prefix_segments=n_shave_prefix_segments,
            ),
        }
        for source_key in old_list
    ]
299 |
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Return identity ``{"old": key, "new": key}`` mappings for UNet attention keys.

    No local renaming is applied; block-level renaming happens later in
    ``assign_to_checkpoint``. ``n_shave_prefix_segments`` is accepted but not
    applied.
    """
    return [{"old": key, "new": key} for key in old_list]
309 |
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Translate LDM VAE attention keys into diffusers attention names (local renaming).

    Substring renames are applied in this order, then ``shave_segments``:
    norm -> group_norm, q/k/v -> query/key/value, proj_out -> proj_attn
    (each for both ``.weight`` and ``.bias``). Returns a list of
    ``{"old": ..., "new": ...}`` dicts in input order.
    """
    substitutions = [
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("q.weight", "query.weight"),
        ("q.bias", "query.bias"),
        ("k.weight", "key.weight"),
        ("k.bias", "key.bias"),
        ("v.weight", "value.weight"),
        ("v.bias", "value.bias"),
        ("proj_out.weight", "proj_attn.weight"),
        ("proj_out.bias", "proj_attn.bias"),
    ]
    renamed = []
    for source_key in old_list:
        target_key = source_key
        for before, after in substitutions:
            target_key = target_key.replace(before, after)
        target_key = shave_segments(target_key, n_shave_prefix_segments=n_shave_prefix_segments)
        renamed.append({"old": source_key, "new": target_key})
    return renamed
338 |
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    Final conversion step: copy locally-renamed weights into ``checkpoint`` under
    their globally-renamed keys.

    Args:
        paths: list of ``{"old": key_in_old_checkpoint, "new": locally_renamed_key}``.
        checkpoint: destination dict, mutated in place.
        old_checkpoint: source state dict.
        attention_paths_to_split: optional ``{fused_key: {"query": k, "key": k, "value": k}}``;
            each fused tensor is split per head into three tensors (reads
            ``config["num_head_channels"]``).
        additional_replacements: optional list of ``{"old": str, "new": str}``
            substring replacements applied after the ``middle_block.*`` renames.
        config: model config (only needed when splitting attention).
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    if attention_paths_to_split is not None:
        for fused_key, targets in attention_paths_to_split.items():
            fused = old_checkpoint[fused_key]
            n_channels = fused.shape[0] // 3
            out_shape = (-1, n_channels) if len(fused.shape) == 3 else -1
            heads = fused.shape[0] // config["num_head_channels"] // 3

            fused = fused.reshape((heads, 3 * n_channels // heads) + fused.shape[1:])
            q, k, v = fused.split(n_channels // heads, dim=1)
            for role, part in (("query", q), ("key", k), ("value", v)):
                checkpoint[targets[role]] = part.reshape(out_shape)

    global_renames = (
        ("middle_block.0", "mid_block.resnets.0"),
        ("middle_block.1", "mid_block.attentions.0"),
        ("middle_block.2", "mid_block.resnets.1"),
    )
    extra_renames = additional_replacements if additional_replacements is not None else []

    for entry in paths:
        target_key = entry["new"]

        # already assigned by the attention split above
        if attention_paths_to_split is not None and target_key in attention_paths_to_split:
            continue

        for before, after in global_renames:
            target_key = target_key.replace(before, after)
        for replacement in extra_renames:
            target_key = target_key.replace(replacement["old"], replacement["new"])

        value = old_checkpoint[entry["old"]]
        # proj_attn.weight is stored as a conv weight; drop the trailing kernel axis
        checkpoint[target_key] = value[:, :, 0] if "proj_attn.weight" in target_key else value
388 |
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Converts the UNet part of an LDM checkpoint (keys under
    ``model.diffusion_model.``) into diffusers UNet parameter names.

    Args:
        checkpoint: full LDM state dict. The UNet entries that are read are
            popped from it, so the caller's dict is mutated.
        config: UNet config; ``class_embed_type`` and ``layers_per_block`` are read.
        path: only used in the EMA log message.
        extract_ema: when more than 100 keys start with ``model_ema``, take the
            EMA weights instead of the regular ones.

    Returns:
        dict mapping diffusers UNet parameter names to tensors.

    Raises:
        NotImplementedError: for an unsupported ``class_embed_type``.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    unet_key = "model.diffusion_model."

    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
    has_ema = sum(k.startswith("model_ema") for k in keys) > 100

    if has_ema and extract_ema:
        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
        print(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        for key in keys:
            if key.startswith("model.diffusion_model"):
                # EMA keys are "model_ema." + the rest of the path with the dots removed
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
    else:
        if has_ema:
            print(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )

        for key in keys:
            if key.startswith(unet_key):
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] is None:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    # input_blocks.0 is conv_in (handled above); the rest map to down_blocks
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        # group "<sub_id>.<param path>" entries by their sub-layer id
        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            # a sub-layer holding only conv.bias/conv.weight is the upsampler
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            # named path_map so the `path` parameter is not shadowed
            for path_map in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path_map["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path_map["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint
571 |
def convert_ldm_clip_checkpoint(checkpoint):
    """Build a CLIP text encoder from the ``cond_stage_model.transformer*`` weights of an LDM checkpoint.

    The model is first created with
    ``CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")`` and then
    overwritten (strict ``load_state_dict``) with the checkpoint's weights,
    with the ``cond_stage_model.transformer.`` prefix stripped.

    Returns:
        the loaded ``CLIPTextModel``.
    """
    from transformers import CLIPTextModel
    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # Skipped only when present: calling list.remove() on an absent key raised
    # ValueError for checkpoints saved without it.
    # NOTE(review): presumably excluded because the installed transformers
    # version rejects it as an unexpected key under strict loading - confirm.
    skipped_key = "cond_stage_model.transformer.text_model.embeddings.position_ids"

    text_model_dict = {
        key[len("cond_stage_model.transformer.") :]: checkpoint[key]
        for key in checkpoint.keys()
        if key.startswith("cond_stage_model.transformer") and key != skipped_key
    }
    text_model.load_state_dict(text_model_dict)

    return text_model
587 |
def convert_ldm_clip_text_model(text_model, checkpoint):
    """Load the ``cond_stage_model.transformer*`` weights of an LDM checkpoint into ``text_model``.

    The ``cond_stage_model.transformer.`` prefix is stripped and the result is
    passed to ``text_model.load_state_dict``. The ``position_ids`` entry is
    skipped when present; previously a missing entry raised ValueError.

    Returns:
        ``text_model`` (loaded in place).
    """
    skipped_key = "cond_stage_model.transformer.text_model.embeddings.position_ids"

    text_model_dict = {
        key[len("cond_stage_model.transformer.") :]: checkpoint[key]
        for key in checkpoint.keys()
        if key.startswith("cond_stage_model.transformer") and key != skipped_key
    }
    text_model.load_state_dict(text_model_dict)

    return text_model
600 |
def conv_attn_to_linear(checkpoint):
    """Squeeze conv-shaped attention weights in ``checkpoint`` to linear shapes, in place.

    Keys ending in ``query/key/value.weight`` lose their last two axes
    (``[:, :, 0, 0]``); keys containing ``proj_attn.weight`` lose their last
    axis (``[:, :, 0]``). Only tensors with more than 2 dims are touched, and
    other keys are not inspected. Returns None.
    """
    qkv_suffixes = {"query.weight", "key.weight", "value.weight"}
    for name in list(checkpoint):
        if ".".join(name.split(".")[-2:]) in qkv_suffixes:
            index = (slice(None), slice(None), 0, 0)
        elif "proj_attn.weight" in name:
            index = (slice(None), slice(None), 0)
        else:
            continue
        if checkpoint[name].ndim > 2:
            checkpoint[name] = checkpoint[name][index]
611 |
def convert_ldm_vae_checkpoint(checkpoint, config):
    """
    Convert the ``first_stage_model.*`` (VAE) weights of an LDM checkpoint to diffusers names.

    ``checkpoint`` itself is only read; downsampler entries are popped from an
    internal copy of the VAE sub-dict.

    Args:
        checkpoint: full LDM state dict.
        config: VAE config, forwarded to ``assign_to_checkpoint``.

    Returns:
        dict mapping diffusers VAE parameter names (``encoder.*``,
        ``decoder.*``, ``quant_conv.*``, ``post_quant_conv.*``) to tensors.
    """
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    _convert_vae_mid_block(vae_state_dict, new_checkpoint, "encoder", config)

    for i in range(num_up_blocks):
        # source block `up.{block_id}` maps to diffusers `up_blocks.{i}` in reverse order
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    _convert_vae_mid_block(vae_state_dict, new_checkpoint, "decoder", config)

    # One pass after both mid-block attentions are assigned is enough:
    # conv_attn_to_linear only reshapes tensors that still have ndim > 2.
    conv_attn_to_linear(new_checkpoint)

    return new_checkpoint


def _convert_vae_mid_block(vae_state_dict, new_checkpoint, part, config):
    """Convert the two mid-block resnets and the mid attention of VAE ``part`` ("encoder" or "decoder")."""
    mid_resnets = [key for key in vae_state_dict if f"{part}.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"{part}.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if f"{part}.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
--------------------------------------------------------------------------------