├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── controlvideo
│   │   ├── building1-a wooden building, at night-80_merge.gif
│   │   ├── building1-a wooden building, at night-80_merge_0.gif
│   │   ├── building1-a wooden building, at night-80_merge_1.gif
│   │   ├── building1-a wooden building, at night-80_merge_2.gif
│   │   ├── building1-a wooden building, at night-80_merge_3.gif
│   │   ├── dance26-Michael Jackson is dancing-1500_merge.gif
│   │   ├── dance26-Michael Jackson is dancing-1500_merge_0.gif
│   │   ├── dance26-Michael Jackson is dancing-1500_merge_1.gif
│   │   ├── dance26-Michael Jackson is dancing-1500_merge_2.gif
│   │   ├── dance26-Michael Jackson is dancing-1500_merge_3.gif
│   │   ├── dance5-a person is dancing, Makoto Shinkai style-1000_merge.gif
│   │   ├── dance5-a person is dancing, Makoto Shinkai style-1000_merge_0.gif
│   │   ├── dance5-a person is dancing, Makoto Shinkai style-1000_merge_1.gif
│   │   ├── dance5-a person is dancing, Makoto Shinkai style-1000_merge_2.gif
│   │   ├── dance5-a person is dancing, Makoto Shinkai style-1000_merge_3.gif
│   │   ├── girl8-a girl, Krenz Cushart style-300_merge.gif
│   │   ├── girl8-a girl, Krenz Cushart style-300_merge_0.gif
│   │   ├── girl8-a girl, Krenz Cushart style-300_merge_1.gif
│   │   ├── girl8-a girl, Krenz Cushart style-300_merge_2.gif
│   │   ├── girl8-a girl, Krenz Cushart style-300_merge_3.gif
│   │   ├── girlface9_6-a girl with rich makeup-300_merge.gif
│   │   ├── girlface9_6-a girl with rich makeup-300_merge_0.gif
│   │   ├── girlface9_6-a girl with rich makeup-300_merge_1.gif
│   │   ├── girlface9_6-a girl with rich makeup-300_merge_2.gif
│   │   ├── girlface9_6-a girl with rich makeup-300_merge_3.gif
│   │   ├── ink1-gentle green ink diffuses in water, beautiful light-200_merge.gif
│   │   ├── ink1-gentle green ink diffuses in water, beautiful light-200_merge_0.gif
│   │   ├── ink1-gentle green ink diffuses in water, beautiful light-200_merge_1.gif
│   │   ├── ink1-gentle green ink diffuses in water, beautiful light-200_merge_2.gif
│   │   └── ink1-gentle green ink diffuses in water, beautiful light-200_merge_3.gif
│   ├── make-a-protagonist
│   │   ├── 0-A man is playing a basketball on the beach, anime style-org.gif
│   │   ├── 0-A man is playing a basketball on the beach, anime style.gif
│   │   ├── 0-a jeep driving down a mountain road in the rain-org.gif
│   │   ├── 0-a jeep driving down a mountain road in the rain.gif
│   │   ├── 0-a panda walking down the snowy street-org.gif
│   │   ├── 0-a panda walking down the snowy street.gif
│   │   ├── 0-elon musk walking down the street-org.gif
│   │   ├── 0-elon musk walking down the street.gif
│   │   ├── car-turn.gif
│   │   ├── huaqiang.gif
│   │   ├── ikun.gif
│   │   └── yanzi.gif
│   ├── tune-a-video
│   │   ├── car-turn.gif
│   │   ├── car-turn_a jeep car is moving on the beach.gif
│   │   ├── car-turn_a jeep car is moving on the road, cartoon style.gif
│   │   ├── car-turn_a jeep car is moving on the snow.gif
│   │   ├── car-turn_a sports car is moving on the road.gif
│   │   ├── car-turn_smooth_a jeep car is moving on the beach.gif
│   │   ├── car-turn_smooth_a jeep car is moving on the road, cartoon style.gif
│   │   ├── car-turn_smooth_a jeep car is moving on the snow.gif
│   │   ├── car-turn_smooth_a sports car is moving on the road.gif
│   │   ├── man-skiing.gif
│   │   ├── man-skiing_a man, wearing pink clothes, is skiing at sunset.gif
│   │   ├── man-skiing_mickey mouse is skiing on the snow.gif
│   │   ├── man-skiing_smooth_a man, wearing pink clothes, is skiing at sunset.gif
│   │   ├── man-skiing_smooth_mickey mouse is skiing on the snow.gif
│   │   ├── man-skiing_smooth_spider man is skiing on the beach, cartoon style.gif
│   │   ├── man-skiing_smooth_wonder woman, wearing a cowboy hat, is skiing.gif
│   │   ├── man-skiing_spider man is skiing on the beach, cartoon style.gif
│   │   ├── man-skiing_wonder woman, wearing a cowboy hat, is skiing.gif
│   │   ├── rabbit-watermelon.gif
│   │   ├── rabbit-watermelon_a puppy is eating an orange.gif
│   │   ├── rabbit-watermelon_a rabbit is eating a pizza.gif
│   │   ├── rabbit-watermelon_a rabbit is eating an orange.gif
│   │   ├── rabbit-watermelon_a tiger is eating a watermelon.gif
│   │   ├── rabbit-watermelon_smooth_a puppy is eating an orange.gif
│   │   ├── rabbit-watermelon_smooth_a rabbit is eating a pizza.gif
│   │   ├── rabbit-watermelon_smooth_a rabbit is eating an orange.gif
│   │   └── rabbit-watermelon_smooth_a tiger is eating a watermelon.gif
│   └── video2video-zero
│       ├── mini-cooper.gif
│       ├── mini-cooper_make it animation_InstructVideo2Video-zero.gif
│       ├── mini-cooper_make it animation_InstructVideo2Video-zero_noise_cons.gif
│       ├── mini-cooper_make it animation_Video-InstructPix2Pix.gif
│       └── mini-cooper_make it animation_Video-InstructPix2Pix_noise_cons.gif
├── configs
│   ├── car-turn.yaml
│   ├── man-skiing.yaml
│   ├── man-surfing.yaml
│   └── rabbit-watermelon.yaml
├── data
│   ├── car-turn.mp4
│   ├── man-skiing.mp4
│   ├── man-surfing.mp4
│   └── rabbit-watermelon.mp4
├── requirements.txt
├── train_tuneavideo.py
└── tuneavideo
    ├── data
    │   └── dataset.py
    ├── models
    │   ├── attention.py
    │   ├── resnet.py
    │   ├── unet.py
    │   └── unet_blocks.py
    ├── pipelines
    │   └── pipeline_tuneavideo.py
    └── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Distribution / packaging
7 | .Python
8 | env/
9 | build/
10 | dist/
11 | *.log
12 |
13 | # pyenv
14 | .python-version
15 |
16 | # dotenv
17 | .env
18 |
19 | # virtualenv
20 | .venv/
21 | venv/
22 | ENV/
23 |
24 | # VSCode settings
25 | .vscode
26 |
27 | # IDEA files
28 | .idea
29 |
30 | # OSX dir files
31 | .DS_Store
32 |
33 | # Sublime Text settings
34 | *.sublime-workspace
35 | *.sublime-project
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SmoothVideo
2 |
3 | This repository is the official implementation of **Smooth Video Synthesis with Noise Constraints on Diffusion Models for One-shot Video Tuning**.
4 |
5 |
6 |
7 |
8 | ## Setup
9 | This implementation is based on [Tune-A-Video](https://github.com/showlab/Tune-A-Video).
10 |
11 | ### Requirements
12 |
13 | ```shell
14 | pip install -r requirements.txt
15 | ```
16 |
17 | Installing [xformers](https://github.com/facebookresearch/xformers) is highly recommended for better memory efficiency and speed on GPU.
18 | To enable xformers, set `enable_xformers_memory_efficient_attention=True` (the default).
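
A typical installation is a single pip command; the exact wheel to use depends on your PyTorch/CUDA versions, so treat this as a starting point rather than a pinned recipe:

```shell
# Install a prebuilt xformers wheel; pick a version matching your torch build.
pip install xformers
```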
19 |
20 | ### Weights
21 |
22 | **[Stable Diffusion]** [Stable Diffusion](https://arxiv.org/abs/2112.10752) is a latent text-to-image diffusion model capable of generating photo-realistic images from any text input. The pre-trained Stable Diffusion weights can be downloaded from Hugging Face (e.g., [Stable Diffusion v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)).
23 |
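The configs and the inference example below expect the weights under `./checkpoints/stable-diffusion-v1-5`. One way to fetch them (assuming `git lfs` is available; any equivalent Hugging Face download works as well):

```shell
git lfs install
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 ./checkpoints/stable-diffusion-v1-5
```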
24 |
25 | ## Usage
26 |
27 | ### Training
28 |
29 | To fine-tune the text-to-image diffusion models for text-to-video generation, run this command for the baseline model:
30 | ```bash
31 | accelerate launch train_tuneavideo.py --config="configs/man-skiing.yaml"
32 | ```
33 |
34 | Run this command for the baseline model with the proposed smooth loss:
35 | ```bash
36 | accelerate launch train_tuneavideo.py --config="configs/man-skiing.yaml" --smooth_loss
37 | ```
38 |
39 | Run this command for the baseline model with the proposed simple smooth loss:
40 | ```bash
41 | accelerate launch train_tuneavideo.py --config="configs/man-skiing.yaml" --smooth_loss --simple_manner
42 | ```
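
The actual loss is implemented in `train_tuneavideo.py` (not shown in this listing). Purely as an illustration of the kind of temporal noise constraint such flags could add, the sketch below augments the standard noise-prediction MSE with a term on frame-to-frame differences; the weighting `lambda_smooth`, the tensor layout, and both variants are assumptions for illustration, not the repository's definition:

```python
import torch
import torch.nn.functional as F

def diffusion_loss_with_smoothness(noise_pred, noise, lambda_smooth=0.1, simple=False):
    """Hypothetical sketch, not the repository's --smooth_loss implementation.

    noise_pred, noise: latent noise tensors of shape (batch, channels, frames, height, width).
    """
    # Standard diffusion objective: per-element MSE between predicted and target noise.
    base = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

    residual = noise_pred.float() - noise.float()
    if simple:
        # "Simple" variant (assumed): penalize frame-to-frame changes of the residual directly.
        smooth = (residual[:, :, 1:] - residual[:, :, :-1]).pow(2).mean()
    else:
        # Match frame-to-frame differences of the prediction to those of the target noise.
        diff_pred = noise_pred.float()[:, :, 1:] - noise_pred.float()[:, :, :-1]
        diff_gt = noise.float()[:, :, 1:] - noise.float()[:, :, :-1]
        smooth = F.mse_loss(diff_pred, diff_gt, reduction="mean")
    return base + lambda_smooth * smooth
```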
43 |
44 |
45 |
46 | Note: Tuning a 24-frame video usually takes `300~500` steps, about `10~15` minutes using one A100 GPU.
47 | Reduce `n_sample_frames` if your GPU memory is limited.
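
For example, the `train_data` block of any config (see `configs/man-skiing.yaml`) can be edited to sample fewer frames; the value `8` below is only an illustration:

```yaml
train_data:
  video_path: "data/man-skiing.mp4"
  prompt: "a man is skiing"
  n_sample_frames: 8   # reduced from 24 to lower GPU memory usage
  width: 512
  height: 512
  sample_start_idx: 0
  sample_frame_rate: 2
```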
48 |
49 | ### Inference
50 |
51 | Once the training is done, run inference:
52 |
53 | ```python
54 | from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
55 | from tuneavideo.models.unet import UNet3DConditionModel
56 | from tuneavideo.util import save_videos_grid
57 | import torch
58 |
59 | pretrained_model_path = "./checkpoints/stable-diffusion-v1-5"
60 | my_model_path = "./outputs/man-skiing"
61 | unet = UNet3DConditionModel.from_pretrained(my_model_path, subfolder='unet', torch_dtype=torch.float16).to('cuda')
62 | pipe = TuneAVideoPipeline.from_pretrained(pretrained_model_path, unet=unet, torch_dtype=torch.float16).to("cuda")
63 | pipe.enable_xformers_memory_efficient_attention()
64 | pipe.enable_vae_slicing()
65 |
66 | prompt = "spider man is skiing"
67 | ddim_inv_latent = torch.load(f"{my_model_path}/inv_latents/ddim_latent-500.pt").to(torch.float16)
68 | video = pipe(prompt, latents=ddim_inv_latent, video_length=24, height=512, width=512, num_inference_steps=50, guidance_scale=7.5).videos
69 |
70 | save_videos_grid(video, f"./{prompt}.gif")
71 | ```
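
Continuing from the snippet above, the same DDIM-inverted latent can be reused for several edited prompts (the prompts below are the validation prompts from `configs/man-skiing.yaml`, used here purely as an example):

```python
# Reuse `pipe`, `ddim_inv_latent`, and `save_videos_grid` from the snippet above.
for prompt in [
    "mickey mouse is skiing on the snow",
    "wonder woman, wearing a cowboy hat, is skiing",
    "a man, wearing pink clothes, is skiing at sunset",
]:
    video = pipe(prompt, latents=ddim_inv_latent, video_length=24, height=512, width=512,
                 num_inference_steps=50, guidance_scale=7.5).videos
    save_videos_grid(video, f"./{prompt}.gif")
```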
72 |
73 |
74 |
75 | **We provide comparisons with different baselines, as follows:**
76 |
77 |
78 |
79 | ## Results
80 |
81 |
82 |
83 | ### Tune-A-Video
84 |
85 | Comparisons to [Tune-A-Video](https://github.com/showlab/Tune-A-Video).
86 |
87 | The corresponding GIFs are in `assets/tune-a-video/`; files whose names contain `_smooth_` are the results with the smooth loss. The compared prompts are:
88 |
89 | | Input video | Edited prompts (Tune-A-Video vs. Tune-A-Video + smooth loss) |
90 | | --- | --- |
91 | | A jeep car is moving on the road | A jeep car is moving on the beach; A jeep car is moving on the snow; A jeep car is moving on the road, cartoon style; A sports car is moving on the road |
92 | | A rabbit is eating a watermelon | A tiger is eating a watermelon; A rabbit is eating an orange; A rabbit is eating a pizza; A puppy is eating an orange |
93 | | A man is skiing | Mickey mouse is skiing on the snow; Spider man is skiing on the beach, cartoon style; Wonder woman, wearing a cowboy hat, is skiing; A man, wearing pink clothes, is skiing at sunset |
94 |
183 | ### Make-A-Protagonist
184 |
185 | Comparisons to [Make-A-Protagonist](https://github.com/HeliosZhao/Make-A-Protagonist).
186 |
187 | The corresponding GIFs are in `assets/make-a-protagonist/`. The compared prompts are:
188 |
189 | | Input video | Edited prompt (Make-A-Protagonist vs. Make-A-Protagonist + smooth loss) |
190 | | --- | --- |
191 | | A jeep driving down a mountain road | A jeep driving down a mountain road in the rain |
192 | | A man is playing basketball | A man is playing a basketball on the beach, anime style |
193 | | A man walking down the street at night | A panda walking down the snowy street |
194 | | A man walking down the street | Elon musk walking down the street |
195 |
237 | ### ControlVideo
238 |
239 | Comparisons to [ControlVideo](https://github.com/thu-ml/controlvideo).
240 |
241 | The corresponding GIFs are in `assets/controlvideo/`. Each example compares the input video, the control condition, ControlVideo, and ControlVideo + smooth loss:
242 |
243 | | Input video | Condition | Edited prompt |
244 | | --- | --- | --- |
245 | | A person is dancing | Pose | Michael Jackson is dancing |
246 | | A person is dancing | Pose | A person is dancing, Makoto Shinkai style |
247 | | A building | Canny edge | A wooden building, at night |
248 | | A girl | HED edge | A girl, Krenz Cushart style |
249 | | A girl | HED edge | A girl with rich makeup |
250 | | Ink diffuses in water | Depth | Gentle green ink diffuses in water, beautiful light |
321 |
322 | ### Video2Video-zero
323 |
324 | Comparisons to training-free methods ([Text2Video-Zero](https://github.com/Picsart-AI-Research/Text2Video-Zero)).
325 |
326 | The corresponding GIFs are in `assets/video2video-zero/`. On the mini-cooper input video, with the instruction "Make it animation", the comparison covers Instruct Video2Video-zero, Instruct Video2Video-zero + noise constraint, Video InstructPix2Pix, and Video InstructPix2Pix + noise constraint.
327 |
--------------------------------------------------------------------------------
/assets/controlvideo/building1-a wooden building, at night-80_merge.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/building1-a wooden building, at night-80_merge.gif
--------------------------------------------------------------------------------
/assets/controlvideo/building1-a wooden building, at night-80_merge_0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/building1-a wooden building, at night-80_merge_0.gif
--------------------------------------------------------------------------------
/assets/controlvideo/building1-a wooden building, at night-80_merge_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/building1-a wooden building, at night-80_merge_1.gif
--------------------------------------------------------------------------------
/assets/controlvideo/building1-a wooden building, at night-80_merge_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/building1-a wooden building, at night-80_merge_2.gif
--------------------------------------------------------------------------------
/assets/controlvideo/building1-a wooden building, at night-80_merge_3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/building1-a wooden building, at night-80_merge_3.gif
--------------------------------------------------------------------------------
/assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge.gif
--------------------------------------------------------------------------------
/assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge_0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge_0.gif
--------------------------------------------------------------------------------
/assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge_1.gif
--------------------------------------------------------------------------------
/assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge_2.gif
--------------------------------------------------------------------------------
/assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge_3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge_3.gif
--------------------------------------------------------------------------------
/assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge.gif
--------------------------------------------------------------------------------
/assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge_0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge_0.gif
--------------------------------------------------------------------------------
/assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge_1.gif
--------------------------------------------------------------------------------
/assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge_2.gif
--------------------------------------------------------------------------------
/assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge_3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge_3.gif
--------------------------------------------------------------------------------
/assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge.gif
--------------------------------------------------------------------------------
/assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge_0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge_0.gif
--------------------------------------------------------------------------------
/assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge_1.gif
--------------------------------------------------------------------------------
/assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge_2.gif
--------------------------------------------------------------------------------
/assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge_3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge_3.gif
--------------------------------------------------------------------------------
/assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge.gif
--------------------------------------------------------------------------------
/assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge_0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge_0.gif
--------------------------------------------------------------------------------
/assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge_1.gif
--------------------------------------------------------------------------------
/assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge_2.gif
--------------------------------------------------------------------------------
/assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge_3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge_3.gif
--------------------------------------------------------------------------------
/assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge.gif
--------------------------------------------------------------------------------
/assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge_0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge_0.gif
--------------------------------------------------------------------------------
/assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge_1.gif
--------------------------------------------------------------------------------
/assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge_2.gif
--------------------------------------------------------------------------------
/assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge_3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge_3.gif
--------------------------------------------------------------------------------
/assets/make-a-protagonist/0-A man is playing a basketball on the beach, anime style-org.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/0-A man is playing a basketball on the beach, anime style-org.gif
--------------------------------------------------------------------------------
/assets/make-a-protagonist/0-A man is playing a basketball on the beach, anime style.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/0-A man is playing a basketball on the beach, anime style.gif
--------------------------------------------------------------------------------
/assets/make-a-protagonist/0-a jeep driving down a mountain road in the rain-org.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/0-a jeep driving down a mountain road in the rain-org.gif
--------------------------------------------------------------------------------
/assets/make-a-protagonist/0-a jeep driving down a mountain road in the rain.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/0-a jeep driving down a mountain road in the rain.gif
--------------------------------------------------------------------------------
/assets/make-a-protagonist/0-a panda walking down the snowy street-org.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/0-a panda walking down the snowy street-org.gif
--------------------------------------------------------------------------------
/assets/make-a-protagonist/0-a panda walking down the snowy street.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/0-a panda walking down the snowy street.gif
--------------------------------------------------------------------------------
/assets/make-a-protagonist/0-elon musk walking down the street-org.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/0-elon musk walking down the street-org.gif
--------------------------------------------------------------------------------
/assets/make-a-protagonist/0-elon musk walking down the street.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/0-elon musk walking down the street.gif
--------------------------------------------------------------------------------
/assets/make-a-protagonist/car-turn.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/car-turn.gif
--------------------------------------------------------------------------------
/assets/make-a-protagonist/huaqiang.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/huaqiang.gif
--------------------------------------------------------------------------------
/assets/make-a-protagonist/ikun.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/ikun.gif
--------------------------------------------------------------------------------
/assets/make-a-protagonist/yanzi.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/yanzi.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/car-turn.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/car-turn.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/car-turn_a jeep car is moving on the beach.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/car-turn_a jeep car is moving on the beach.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/car-turn_a jeep car is moving on the road, cartoon style.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/car-turn_a jeep car is moving on the road, cartoon style.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/car-turn_a jeep car is moving on the snow.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/car-turn_a jeep car is moving on the snow.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/car-turn_a sports car is moving on the road.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/car-turn_a sports car is moving on the road.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/car-turn_smooth_a jeep car is moving on the beach.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/car-turn_smooth_a jeep car is moving on the beach.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/car-turn_smooth_a jeep car is moving on the road, cartoon style.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/car-turn_smooth_a jeep car is moving on the road, cartoon style.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/car-turn_smooth_a jeep car is moving on the snow.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/car-turn_smooth_a jeep car is moving on the snow.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/car-turn_smooth_a sports car is moving on the road.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/car-turn_smooth_a sports car is moving on the road.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/man-skiing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/man-skiing.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/man-skiing_a man, wearing pink clothes, is skiing at sunset.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/man-skiing_a man, wearing pink clothes, is skiing at sunset.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/man-skiing_mickey mouse is skiing on the snow.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/man-skiing_mickey mouse is skiing on the snow.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/man-skiing_smooth_a man, wearing pink clothes, is skiing at sunset.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/man-skiing_smooth_a man, wearing pink clothes, is skiing at sunset.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/man-skiing_smooth_mickey mouse is skiing on the snow.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/man-skiing_smooth_mickey mouse is skiing on the snow.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/man-skiing_smooth_spider man is skiing on the beach, cartoon style.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/man-skiing_smooth_spider man is skiing on the beach, cartoon style.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/man-skiing_smooth_wonder woman, wearing a cowboy hat, is skiing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/man-skiing_smooth_wonder woman, wearing a cowboy hat, is skiing.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/man-skiing_spider man is skiing on the beach, cartoon style.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/man-skiing_spider man is skiing on the beach, cartoon style.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/man-skiing_wonder woman, wearing a cowboy hat, is skiing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/man-skiing_wonder woman, wearing a cowboy hat, is skiing.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/rabbit-watermelon.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/rabbit-watermelon.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/rabbit-watermelon_a puppy is eating an orange.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/rabbit-watermelon_a puppy is eating an orange.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/rabbit-watermelon_a rabbit is eating a pizza.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/rabbit-watermelon_a rabbit is eating a pizza.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/rabbit-watermelon_a rabbit is eating an orange.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/rabbit-watermelon_a rabbit is eating an orange.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/rabbit-watermelon_a tiger is eating a watermelon.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/rabbit-watermelon_a tiger is eating a watermelon.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/rabbit-watermelon_smooth_a puppy is eating an orange.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/rabbit-watermelon_smooth_a puppy is eating an orange.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/rabbit-watermelon_smooth_a rabbit is eating a pizza.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/rabbit-watermelon_smooth_a rabbit is eating a pizza.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/rabbit-watermelon_smooth_a rabbit is eating an orange.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/rabbit-watermelon_smooth_a rabbit is eating an orange.gif
--------------------------------------------------------------------------------
/assets/tune-a-video/rabbit-watermelon_smooth_a tiger is eating a watermelon.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/rabbit-watermelon_smooth_a tiger is eating a watermelon.gif
--------------------------------------------------------------------------------
/assets/video2video-zero/mini-cooper.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/video2video-zero/mini-cooper.gif
--------------------------------------------------------------------------------
/assets/video2video-zero/mini-cooper_make it animation_InstructVideo2Video-zero.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/video2video-zero/mini-cooper_make it animation_InstructVideo2Video-zero.gif
--------------------------------------------------------------------------------
/assets/video2video-zero/mini-cooper_make it animation_InstructVideo2Video-zero_noise_cons.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/video2video-zero/mini-cooper_make it animation_InstructVideo2Video-zero_noise_cons.gif
--------------------------------------------------------------------------------
/assets/video2video-zero/mini-cooper_make it animation_Video-InstructPix2Pix.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/video2video-zero/mini-cooper_make it animation_Video-InstructPix2Pix.gif
--------------------------------------------------------------------------------
/assets/video2video-zero/mini-cooper_make it animation_Video-InstructPix2Pix_noise_cons.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/video2video-zero/mini-cooper_make it animation_Video-InstructPix2Pix_noise_cons.gif
--------------------------------------------------------------------------------
/configs/car-turn.yaml:
--------------------------------------------------------------------------------
1 | pretrained_model_path: "./checkpoints/stable-diffusion-v1-5"
2 | output_dir: "./outputs/car-turn"
3 |
4 | train_data:
5 | video_path: "data/car-turn.mp4"
6 | prompt: "a jeep car is moving on the road"
7 | n_sample_frames: 24
8 | width: 512
9 | height: 512
10 | sample_start_idx: 0
11 | sample_frame_rate: 2
12 |
13 | validation_data:
14 | prompts:
15 | - "a jeep car is moving on the beach"
16 | - "a jeep car is moving on the snow"
17 | - "a jeep car is moving on the road, cartoon style"
18 | - "a sports car is moving on the road"
19 | video_length: 24
20 | width: 512
21 | height: 512
22 | num_inference_steps: 50
23 | guidance_scale: 7.5
24 | use_inv_latent: True
25 | num_inv_steps: 50
26 |
27 | learning_rate: 3e-5
28 | train_batch_size: 1
29 | max_train_steps: 500
30 | checkpointing_steps: 1000
31 | validation_steps: 100
32 | trainable_modules:
33 | - "attn1.to_q"
34 | - "attn2.to_q"
35 | - "attn_temp"
36 |
37 | seed: 33
38 | mixed_precision: fp16
39 | use_8bit_adam: False
40 | gradient_checkpointing: True
41 | enable_xformers_memory_efficient_attention: True
42 |
--------------------------------------------------------------------------------
/configs/man-skiing.yaml:
--------------------------------------------------------------------------------
1 | pretrained_model_path: "./checkpoints/stable-diffusion-v1-5"
2 | output_dir: "./outputs/man-skiing"
3 |
4 | train_data:
5 | video_path: "data/man-skiing.mp4"
6 | prompt: "a man is skiing"
7 | n_sample_frames: 24
8 | width: 512
9 | height: 512
10 | sample_start_idx: 0
11 | sample_frame_rate: 2
12 |
13 | validation_data:
14 | prompts:
15 | - "mickey mouse is skiing on the snow"
16 | - "spider man is skiing on the beach, cartoon style"
17 | - "wonder woman, wearing a cowboy hat, is skiing"
18 | - "a man, wearing pink clothes, is skiing at sunset"
19 | video_length: 24
20 | width: 512
21 | height: 512
22 | num_inference_steps: 50
23 | guidance_scale: 7.5
24 | use_inv_latent: True
25 | num_inv_steps: 50
26 |
27 | learning_rate: 3e-5
28 | train_batch_size: 1
29 | max_train_steps: 500
30 | checkpointing_steps: 1000
31 | validation_steps: 100
32 | trainable_modules:
33 | - "attn1.to_q"
34 | - "attn2.to_q"
35 | - "attn_temp"
36 |
37 | seed: 33
38 | mixed_precision: fp16
39 | use_8bit_adam: False
40 | gradient_checkpointing: True
41 | enable_xformers_memory_efficient_attention: True
42 |
--------------------------------------------------------------------------------
/configs/man-surfing.yaml:
--------------------------------------------------------------------------------
1 | pretrained_model_path: "./checkpoints/stable-diffusion-v1-5"
2 | output_dir: "./outputs/man-surfing"
3 |
4 | train_data:
5 | video_path: "data/man-surfing.mp4"
6 | prompt: "a man is surfing"
7 | n_sample_frames: 24
8 | width: 512
9 | height: 512
10 | sample_start_idx: 0
11 | sample_frame_rate: 1
12 |
13 | validation_data:
14 | prompts:
15 | - "a panda is surfing"
16 | - "a boy, wearing a birthday hat, is surfing"
17 | - "a raccoon is surfing, cartoon style"
18 | - "Iron Man is surfing in the desert"
19 | video_length: 24
20 | width: 512
21 | height: 512
22 | num_inference_steps: 50
23 | guidance_scale: 7.5
24 | use_inv_latent: True
25 | num_inv_steps: 50
26 |
27 | learning_rate: 3e-5
28 | train_batch_size: 1
29 | max_train_steps: 500
30 | checkpointing_steps: 1000
31 | validation_steps: 100
32 | trainable_modules:
33 | - "attn1.to_q"
34 | - "attn2.to_q"
35 | - "attn_temp"
36 |
37 | seed: 33
38 | mixed_precision: fp16
39 | use_8bit_adam: False
40 | gradient_checkpointing: True
41 | enable_xformers_memory_efficient_attention: True
42 |
--------------------------------------------------------------------------------
/configs/rabbit-watermelon.yaml:
--------------------------------------------------------------------------------
1 | pretrained_model_path: "./checkpoints/stable-diffusion-v1-5"
2 | output_dir: "./outputs/rabbit-watermelon"
3 |
4 | train_data:
5 | video_path: "data/rabbit-watermelon.mp4"
6 | prompt: "a rabbit is eating a watermelon"
7 | n_sample_frames: 24
8 | width: 512
9 | height: 512
10 | sample_start_idx: 0
11 | sample_frame_rate: 2
12 |
13 | validation_data:
14 | prompts:
15 | - "a tiger is eating a watermelon"
16 | - "a rabbit is eating an orange"
17 | - "a rabbit is eating a pizza"
18 | - "a puppy is eating an orange"
19 | video_length: 24
20 | width: 512
21 | height: 512
22 | num_inference_steps: 50
23 | guidance_scale: 7.5
24 | use_inv_latent: True
25 | num_inv_steps: 50
26 |
27 | learning_rate: 3e-5
28 | train_batch_size: 1
29 | max_train_steps: 500
30 | checkpointing_steps: 1000
31 | validation_steps: 100
32 | trainable_modules:
33 | - "attn1.to_q"
34 | - "attn2.to_q"
35 | - "attn_temp"
36 |
37 | seed: 33
38 | mixed_precision: fp16
39 | use_8bit_adam: False
40 | gradient_checkpointing: True
41 | enable_xformers_memory_efficient_attention: True
42 |
--------------------------------------------------------------------------------
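
The four configs above share one schema and are consumed as keyword arguments by train_tuneavideo.py. As a minimal sketch (assuming the repository root as the working directory and the checkpoint path from the YAML), a config can be loaded with OmegaConf, tweaked, and passed straight to main(); the overridden values below are illustrative only:

from omegaconf import OmegaConf
from train_tuneavideo import main

# Load a training config and override a couple of fields before launching a run.
config = OmegaConf.load("./configs/car-turn.yaml")
config.max_train_steps = 300                     # shorter run than the default 500
config.validation_data.prompts = [               # hypothetical custom validation prompt
    "a jeep car is moving in the desert",
]

# Mirrors the __main__ block of train_tuneavideo.py, with the smoothness loss enabled.
main(**config, enable_smooth_loss=True, smooth_weight=0.2, lambda_factor=1000)
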
/data/car-turn.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/data/car-turn.mp4
--------------------------------------------------------------------------------
/data/man-skiing.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/data/man-skiing.mp4
--------------------------------------------------------------------------------
/data/man-surfing.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/data/man-surfing.mp4
--------------------------------------------------------------------------------
/data/rabbit-watermelon.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/data/rabbit-watermelon.mp4
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.12.1
2 | torchvision==0.13.1
3 | diffusers[torch]==0.11.1
4 | transformers>=4.25.1
5 | bitsandbytes==0.35.4
6 | decord==0.6.0
7 | accelerate
8 | tensorboard
9 | modelcards
10 | omegaconf
11 | einops
12 | imageio
13 | ftfy
--------------------------------------------------------------------------------
/train_tuneavideo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import logging
4 | import inspect
5 | import math
6 | import os
7 | from typing import Dict, Optional, Tuple
8 | from omegaconf import OmegaConf
9 |
10 | import torch
11 | import torch.nn.functional as F
12 | import torch.utils.checkpoint
13 |
14 | import diffusers
15 | import transformers
16 | from accelerate import Accelerator
17 | from accelerate.logging import get_logger
18 | from accelerate.utils import set_seed
19 | from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
20 | from diffusers.optimization import get_scheduler
21 | from diffusers.utils import check_min_version
22 | from diffusers.utils.import_utils import is_xformers_available
23 | from tqdm.auto import tqdm
24 | from transformers import CLIPTextModel, CLIPTokenizer
25 |
26 | from tuneavideo.models.unet import UNet3DConditionModel
27 | from tuneavideo.data.dataset import TuneAVideoDataset
28 | from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
29 | from tuneavideo.util import save_videos_grid, ddim_inversion
30 | from einops import rearrange
31 |
32 |
33 | # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
34 | check_min_version("0.10.0.dev0")
35 |
36 | logger = get_logger(__name__, log_level="INFO")
37 |
38 |
39 | def main(
40 | pretrained_model_path: str,
41 | output_dir: str,
42 | train_data: Dict,
43 | validation_data: Dict,
44 | validation_steps: int = 100,
45 | trainable_modules: Tuple[str] = (
46 | "attn1.to_q",
47 | "attn2.to_q",
48 | "attn_temp",
49 | ),
50 | train_batch_size: int = 1,
51 | max_train_steps: int = 500,
52 | learning_rate: float = 3e-5,
53 | scale_lr: bool = False,
54 | lr_scheduler: str = "constant",
55 | lr_warmup_steps: int = 0,
56 | adam_beta1: float = 0.9,
57 | adam_beta2: float = 0.999,
58 | adam_weight_decay: float = 1e-2,
59 | adam_epsilon: float = 1e-08,
60 | max_grad_norm: float = 1.0,
61 | gradient_accumulation_steps: int = 1,
62 | gradient_checkpointing: bool = True,
63 | checkpointing_steps: int = 500,
64 | resume_from_checkpoint: Optional[str] = None,
65 | mixed_precision: Optional[str] = "fp16",
66 | use_8bit_adam: bool = False,
67 | enable_xformers_memory_efficient_attention: bool = True,
68 | seed: Optional[int] = None,
69 | enable_smooth_loss=False,
70 | smooth_weight=0.2,
71 | lambda_factor=1000,
72 | simple_manner=False
73 | ):
74 | *_, config = inspect.getargvalues(inspect.currentframe())
75 |
76 | accelerator = Accelerator(
77 | gradient_accumulation_steps=gradient_accumulation_steps,
78 | mixed_precision=mixed_precision,
79 | )
80 |
81 | # Make one log on every process with the configuration for debugging.
82 | logging.basicConfig(
83 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
84 | datefmt="%m/%d/%Y %H:%M:%S",
85 | level=logging.INFO,
86 | )
87 | logger.info(accelerator.state, main_process_only=False)
88 | if accelerator.is_local_main_process:
89 | transformers.utils.logging.set_verbosity_warning()
90 | diffusers.utils.logging.set_verbosity_info()
91 | else:
92 | transformers.utils.logging.set_verbosity_error()
93 | diffusers.utils.logging.set_verbosity_error()
94 |
95 | # If passed along, set the training seed now.
96 | if seed is not None:
97 | set_seed(seed)
98 |
99 | # Handle the output folder creation
100 | if accelerator.is_main_process:
101 | # now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
102 | # output_dir = os.path.join(output_dir, now)
103 | os.makedirs(output_dir, exist_ok=True)
104 | os.makedirs(f"{output_dir}/samples", exist_ok=True)
105 | os.makedirs(f"{output_dir}/inv_latents", exist_ok=True)
106 | OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
107 |
108 | # Load scheduler, tokenizer and models.
109 | noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
110 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
111 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
112 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
113 | unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet")
114 |
115 | # Freeze vae and text_encoder
116 | vae.requires_grad_(False)
117 | text_encoder.requires_grad_(False)
118 |
119 | unet.requires_grad_(False)
120 | for name, module in unet.named_modules():
121 | if name.endswith(tuple(trainable_modules)):
122 | for params in module.parameters():
123 | params.requires_grad = True
124 |
125 | if enable_xformers_memory_efficient_attention:
126 | if is_xformers_available():
127 | unet.enable_xformers_memory_efficient_attention()
128 | else:
129 | raise ValueError("xformers is not available. Make sure it is installed correctly")
130 |
131 | if gradient_checkpointing:
132 | unet.enable_gradient_checkpointing()
133 |
134 | if scale_lr:
135 | learning_rate = (
136 | learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
137 | )
138 |
139 | # Initialize the optimizer
140 | if use_8bit_adam:
141 | try:
142 | import bitsandbytes as bnb
143 | except ImportError:
144 | raise ImportError(
145 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
146 | )
147 |
148 | optimizer_cls = bnb.optim.AdamW8bit
149 | else:
150 | optimizer_cls = torch.optim.AdamW
151 |
152 | optimizer = optimizer_cls(
153 | unet.parameters(),
154 | lr=learning_rate,
155 | betas=(adam_beta1, adam_beta2),
156 | weight_decay=adam_weight_decay,
157 | eps=adam_epsilon,
158 | )
159 |
160 | # Get the training dataset
161 | train_dataset = TuneAVideoDataset(**train_data)
162 |
163 | # Preprocessing the dataset
164 | train_dataset.prompt_ids = tokenizer(
165 | train_dataset.prompt, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
166 | ).input_ids[0]
167 |
168 | # DataLoaders creation:
169 | train_dataloader = torch.utils.data.DataLoader(
170 | train_dataset, batch_size=train_batch_size
171 | )
172 |
173 | # Get the validation pipeline
174 | validation_pipeline = TuneAVideoPipeline(
175 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
176 | scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
177 | )
178 | validation_pipeline.enable_vae_slicing()
179 | ddim_inv_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder='scheduler')
180 | ddim_inv_scheduler.set_timesteps(validation_data.num_inv_steps)
181 |
182 | # Scheduler
183 | lr_scheduler = get_scheduler(
184 | lr_scheduler,
185 | optimizer=optimizer,
186 | num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
187 | num_training_steps=max_train_steps * gradient_accumulation_steps,
188 | )
189 |
190 | # Prepare everything with our `accelerator`.
191 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
192 | unet, optimizer, train_dataloader, lr_scheduler
193 | )
194 |
195 | # For mixed precision training we cast the text_encoder and vae weights to half-precision
196 | # as these models are only used for inference, keeping weights in full precision is not required.
197 | weight_dtype = torch.float32
198 | if accelerator.mixed_precision == "fp16":
199 | weight_dtype = torch.float16
200 | elif accelerator.mixed_precision == "bf16":
201 | weight_dtype = torch.bfloat16
202 |
203 | # Move text_encoder and vae to gpu and cast to weight_dtype
204 | text_encoder.to(accelerator.device, dtype=weight_dtype)
205 | vae.to(accelerator.device, dtype=weight_dtype)
206 |
207 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
208 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
209 | # Afterwards we recalculate our number of training epochs
210 | num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
211 |
212 | # We need to initialize the trackers we use, and also store our configuration.
213 | # The trackers initialize automatically on the main process.
214 | if accelerator.is_main_process:
215 | accelerator.init_trackers("text2video-fine-tune")
216 |
217 | # Train!
218 | total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
219 |
220 | logger.info("***** Running training *****")
221 | logger.info(f" Num examples = {len(train_dataset)}")
222 | logger.info(f" Num Epochs = {num_train_epochs}")
223 | logger.info(f" Instantaneous batch size per device = {train_batch_size}")
224 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
225 | logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
226 | logger.info(f" Total optimization steps = {max_train_steps}")
227 | global_step = 0
228 | first_epoch = 0
229 |
230 | # Potentially load in the weights and states from a previous save
231 | if resume_from_checkpoint:
232 | if resume_from_checkpoint != "latest":
233 | path = os.path.basename(resume_from_checkpoint)
234 | else:
235 | # Get the most recent checkpoint
236 | dirs = os.listdir(output_dir)
237 | dirs = [d for d in dirs if d.startswith("checkpoint")]
238 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
239 | path = dirs[-1]
240 | accelerator.print(f"Resuming from checkpoint {path}")
241 | accelerator.load_state(os.path.join(output_dir, path))
242 | global_step = int(path.split("-")[1])
243 |
244 | first_epoch = global_step // num_update_steps_per_epoch
245 | resume_step = global_step % num_update_steps_per_epoch
246 |
247 | # Only show the progress bar once on each machine.
248 | progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
249 | progress_bar.set_description("Steps")
250 |
251 | for epoch in range(first_epoch, num_train_epochs):
252 | unet.train()
253 | train_loss = 0.0
254 | for step, batch in enumerate(train_dataloader):
255 | # Skip steps until we reach the resumed step
256 | if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
257 | if step % gradient_accumulation_steps == 0:
258 | progress_bar.update(1)
259 | continue
260 |
261 | with accelerator.accumulate(unet):
262 | # Convert videos to latent space
263 | pixel_values = batch["pixel_values"].to(weight_dtype)
264 | video_length = pixel_values.shape[1]
265 | pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
266 | latents = vae.encode(pixel_values).latent_dist.sample()
267 | latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
268 | latents = latents * 0.18215
269 |
270 | # Sample noise that we'll add to the latents
271 | noise = torch.randn_like(latents)
272 | bsz = latents.shape[0]
273 | # Sample a random timestep for each video
274 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
275 | timesteps = timesteps.long()
276 |
277 | # Add noise to the latents according to the noise magnitude at each timestep
278 | # (this is the forward diffusion process)
279 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
280 |
281 | # Get the text embedding for conditioning
282 | encoder_hidden_states = text_encoder(batch["prompt_ids"])[0]
283 |
284 | # Get the target for loss depending on the prediction type
285 | if noise_scheduler.prediction_type == "epsilon":
286 | target = noise
287 | elif noise_scheduler.prediction_type == "v_prediction":
288 | target = noise_scheduler.get_velocity(latents, noise, timesteps)
289 | else:
290 | raise ValueError(f"Unknown prediction type {noise_scheduler.prediction_type}")
291 |
292 | # Predict the noise residual and compute loss
293 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
294 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
295 |
296 | # Temporal smoothness regularization: penalize the second-order frame-to-frame difference of the predicted noise.
297 | if enable_smooth_loss:
298 | noisy_pred_delta = (model_pred[:, :, 1:-1, ...] - model_pred[:, :, :-2, ...]) + \
299 | (model_pred[:, :, 1:-1, ...] - model_pred[:, :, 2:, ...])
300 | if simple_manner:
301 | smooth_loss = torch.abs(noisy_pred_delta).mean()
302 | else:
303 | noisy_latents_delta = (noisy_latents[:, :, 1:-1, ...] - noisy_latents[:, :, :-2, ...]) + \
304 | (noisy_latents[:, :, 1:-1, ...] - noisy_latents[:, :, 2:, ...])
305 |
306 | alphas_cumprod = noise_scheduler.alphas_cumprod
307 | t = timesteps
308 | C = torch.sqrt(alphas_cumprod[t]) / (torch.sqrt(alphas_cumprod[t])*torch.sqrt(1-alphas_cumprod[t-1]) -
309 | torch.sqrt(alphas_cumprod[t-1])*torch.sqrt(1-alphas_cumprod[t]))
310 | smooth_loss = torch.abs(noisy_pred_delta - (noisy_latents_delta * C/lambda_factor)).mean()
311 | loss += smooth_weight * smooth_loss
312 |
313 |
314 | # Gather the losses across all processes for logging (if we use distributed training).
315 | avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
316 | train_loss += avg_loss.item() / gradient_accumulation_steps
317 |
318 | # Backpropagate
319 | accelerator.backward(loss)
320 | if accelerator.sync_gradients:
321 | accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)
322 | optimizer.step()
323 | lr_scheduler.step()
324 | optimizer.zero_grad()
325 |
326 | # Checks if the accelerator has performed an optimization step behind the scenes
327 | if accelerator.sync_gradients:
328 | progress_bar.update(1)
329 | global_step += 1
330 | accelerator.log({"train_loss": train_loss}, step=global_step)
331 | train_loss = 0.0
332 |
333 | if global_step % checkpointing_steps == 0:
334 | if accelerator.is_main_process:
335 | save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
336 | accelerator.save_state(save_path)
337 | logger.info(f"Saved state to {save_path}")
338 |
339 | if global_step % validation_steps == 0:
340 | if accelerator.is_main_process:
341 | samples = []
342 | generator = torch.Generator(device=latents.device)
343 | generator.manual_seed(seed)
344 |
345 | ddim_inv_latent = None
346 | if validation_data.use_inv_latent:
347 | inv_latents_path = os.path.join(output_dir, f"inv_latents/ddim_latent-{global_step}.pt")
348 | ddim_inv_latent = ddim_inversion(
349 | validation_pipeline, ddim_inv_scheduler, video_latent=latents,
350 | num_inv_steps=validation_data.num_inv_steps, prompt="")[-1].to(weight_dtype)
351 | torch.save(ddim_inv_latent, inv_latents_path)
352 |
353 | for idx, prompt in enumerate(validation_data.prompts):
354 | sample = validation_pipeline(prompt, generator=generator, latents=ddim_inv_latent,
355 | **validation_data).videos
356 | save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{prompt}.gif")
357 | samples.append(sample)
358 | samples = torch.concat(samples)
359 | save_path = f"{output_dir}/samples/sample-{global_step}.gif"
360 | save_videos_grid(samples, save_path)
361 | logger.info(f"Saved samples to {save_path}")
362 |
363 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
364 | progress_bar.set_postfix(**logs)
365 |
366 | if global_step >= max_train_steps:
367 | break
368 |
369 | # Create the pipeline using the trained modules and save it.
370 | accelerator.wait_for_everyone()
371 | if accelerator.is_main_process:
372 | unet = accelerator.unwrap_model(unet)
373 | pipeline = TuneAVideoPipeline.from_pretrained(
374 | pretrained_model_path,
375 | text_encoder=text_encoder,
376 | vae=vae,
377 | unet=unet,
378 | )
379 | pipeline.save_pretrained(output_dir)
380 |
381 | accelerator.end_training()
382 |
383 |
384 | if __name__ == "__main__":
385 | parser = argparse.ArgumentParser()
386 | parser.add_argument("--config", type=str, default="./configs/tuneavideo.yaml")
387 | parser.add_argument("--smooth_loss", action="store_true")
388 | parser.add_argument("--smooth_weight", type=float, default=0.2)
389 | parser.add_argument("--lambda_factor", type=float, default=1000)
390 | parser.add_argument("--simple_manner", action="store_true")
391 | args = parser.parse_args()
392 |
393 | main(**OmegaConf.load(args.config),
394 | enable_smooth_loss=args.smooth_loss,
395 | smooth_weight=args.smooth_weight,
396 | lambda_factor=args.lambda_factor,
397 | simple_manner=args.simple_manner
398 | )
399 |
--------------------------------------------------------------------------------
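
The enable_smooth_loss branch above adds the temporal smoothness term on top of the standard Tune-A-Video objective. Below is a minimal standalone sketch of that term on dummy tensors; shapes match a 24-frame, 512x512 run after VAE encoding, and the explicit reshape of C is an assumption that generalizes the script's batch-size-1 broadcast:

import torch

def temporal_smooth_loss(model_pred, noisy_latents, alphas_cumprod, t,
                         lambda_factor=1000.0, simple_manner=False):
    """Sketch of the smoothness term from train_tuneavideo.py on (b, c, f, h, w) tensors."""
    # Second-order temporal difference of the predicted noise:
    # 2 * x_f - x_{f-1} - x_{f+1} for every interior frame f (frame axis is dim 2).
    pred_delta = (model_pred[:, :, 1:-1] - model_pred[:, :, :-2]) + \
                 (model_pred[:, :, 1:-1] - model_pred[:, :, 2:])
    if simple_manner:
        return pred_delta.abs().mean()

    # Same second-order difference on the noisy latents, rescaled by the
    # timestep-dependent coefficient C used in the training script.
    latents_delta = (noisy_latents[:, :, 1:-1] - noisy_latents[:, :, :-2]) + \
                    (noisy_latents[:, :, 1:-1] - noisy_latents[:, :, 2:])
    a_t, a_prev = alphas_cumprod[t], alphas_cumprod[t - 1]
    C = torch.sqrt(a_t) / (torch.sqrt(a_t) * torch.sqrt(1 - a_prev)
                           - torch.sqrt(a_prev) * torch.sqrt(1 - a_t))
    # The script relies on train_batch_size == 1 for this broadcast; the view makes it explicit.
    C = C.view(-1, 1, 1, 1, 1)
    return (pred_delta - latents_delta * C / lambda_factor).abs().mean()

model_pred = torch.randn(1, 4, 24, 64, 64)      # predicted noise, (b, c, f, h, w)
noisy_latents = torch.randn(1, 4, 24, 64, 64)   # noised VAE latents, same shape
alphas_cumprod = torch.linspace(0.9999, 0.01, 1000)
timesteps = torch.randint(1, 1000, (1,))        # t >= 1 so that t - 1 is a valid index
print(temporal_smooth_loss(model_pred, noisy_latents, alphas_cumprod, timesteps))
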
/tuneavideo/data/dataset.py:
--------------------------------------------------------------------------------
1 | import decord
2 | decord.bridge.set_bridge('torch')
3 |
4 | from torch.utils.data import Dataset
5 | from einops import rearrange
6 |
7 |
8 | class TuneAVideoDataset(Dataset):
9 | def __init__(
10 | self,
11 | video_path: str,
12 | prompt: str,
13 | width: int = 512,
14 | height: int = 512,
15 | n_sample_frames: int = 8,
16 | sample_start_idx: int = 0,
17 | sample_frame_rate: int = 1,
18 | ):
19 | self.video_path = video_path
20 | self.prompt = prompt
21 | self.prompt_ids = None
22 |
23 | self.width = width
24 | self.height = height
25 | self.n_sample_frames = n_sample_frames
26 | self.sample_start_idx = sample_start_idx
27 | self.sample_frame_rate = sample_frame_rate
28 |
29 | def __len__(self):
30 | return 1
31 |
32 | def __getitem__(self, index):
33 | # load and sample video frames
34 | vr = decord.VideoReader(self.video_path, width=self.width, height=self.height)
35 | sample_index = list(range(self.sample_start_idx, len(vr), self.sample_frame_rate))[:self.n_sample_frames]
36 | video = vr.get_batch(sample_index)
37 | video = rearrange(video, "f h w c -> f c h w")
38 |
39 | example = {
40 | "pixel_values": (video / 127.5 - 1.0),
41 | "prompt_ids": self.prompt_ids
42 | }
43 |
44 | return example
45 |
--------------------------------------------------------------------------------
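
A minimal usage sketch of the dataset above, mirroring how the training script tokenizes the prompt; the video path, prompt, and checkpoint path are taken from configs/car-turn.yaml and assumed to be in place:

from transformers import CLIPTokenizer
from tuneavideo.data.dataset import TuneAVideoDataset

dataset = TuneAVideoDataset(
    video_path="data/car-turn.mp4",
    prompt="a jeep car is moving on the road",
    width=512, height=512,
    n_sample_frames=24, sample_start_idx=0, sample_frame_rate=2,
)

# The training script attaches the tokenized prompt before building the DataLoader.
tokenizer = CLIPTokenizer.from_pretrained("./checkpoints/stable-diffusion-v1-5", subfolder="tokenizer")
dataset.prompt_ids = tokenizer(
    dataset.prompt, max_length=tokenizer.model_max_length,
    padding="max_length", truncation=True, return_tensors="pt",
).input_ids[0]

example = dataset[0]
print(example["pixel_values"].shape)  # (24, 3, 512, 512), values scaled to [-1, 1]
print(example["prompt_ids"].shape)    # (77,) CLIP token ids
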
/tuneavideo/models/attention.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2 |
3 | from dataclasses import dataclass
4 | from typing import Optional
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import nn
9 |
10 | from diffusers.configuration_utils import ConfigMixin, register_to_config
11 | from diffusers.modeling_utils import ModelMixin
12 | from diffusers.utils import BaseOutput
13 | from diffusers.utils.import_utils import is_xformers_available
14 | from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
15 |
16 | from einops import rearrange, repeat
17 |
18 |
19 | @dataclass
20 | class Transformer3DModelOutput(BaseOutput):
21 | sample: torch.FloatTensor
22 |
23 |
24 | if is_xformers_available():
25 | import xformers
26 | import xformers.ops
27 | else:
28 | xformers = None
29 |
30 |
31 | class Transformer3DModel(ModelMixin, ConfigMixin):
32 | @register_to_config
33 | def __init__(
34 | self,
35 | num_attention_heads: int = 16,
36 | attention_head_dim: int = 88,
37 | in_channels: Optional[int] = None,
38 | num_layers: int = 1,
39 | dropout: float = 0.0,
40 | norm_num_groups: int = 32,
41 | cross_attention_dim: Optional[int] = None,
42 | attention_bias: bool = False,
43 | activation_fn: str = "geglu",
44 | num_embeds_ada_norm: Optional[int] = None,
45 | use_linear_projection: bool = False,
46 | only_cross_attention: bool = False,
47 | upcast_attention: bool = False,
48 | ):
49 | super().__init__()
50 | self.use_linear_projection = use_linear_projection
51 | self.num_attention_heads = num_attention_heads
52 | self.attention_head_dim = attention_head_dim
53 | inner_dim = num_attention_heads * attention_head_dim
54 |
55 | # Define input layers
56 | self.in_channels = in_channels
57 |
58 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
59 | if use_linear_projection:
60 | self.proj_in = nn.Linear(in_channels, inner_dim)
61 | else:
62 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
63 |
64 | # Define transformers blocks
65 | self.transformer_blocks = nn.ModuleList(
66 | [
67 | BasicTransformerBlock(
68 | inner_dim,
69 | num_attention_heads,
70 | attention_head_dim,
71 | dropout=dropout,
72 | cross_attention_dim=cross_attention_dim,
73 | activation_fn=activation_fn,
74 | num_embeds_ada_norm=num_embeds_ada_norm,
75 | attention_bias=attention_bias,
76 | only_cross_attention=only_cross_attention,
77 | upcast_attention=upcast_attention,
78 | )
79 | for d in range(num_layers)
80 | ]
81 | )
82 |
83 | # 4. Define output layers
84 | if use_linear_projection:
85 | self.proj_out = nn.Linear(inner_dim, in_channels)  # project back to in_channels so the residual add below matches
86 | else:
87 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
88 |
89 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
90 | # Input
91 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
92 | video_length = hidden_states.shape[2]
93 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
94 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
95 |
96 | batch, channel, height, weight = hidden_states.shape
97 | residual = hidden_states
98 |
99 | hidden_states = self.norm(hidden_states)
100 | if not self.use_linear_projection:
101 | hidden_states = self.proj_in(hidden_states)
102 | inner_dim = hidden_states.shape[1]
103 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
104 | else:
105 | inner_dim = hidden_states.shape[1]
106 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
107 | hidden_states = self.proj_in(hidden_states)
108 |
109 | # Blocks
110 | for block in self.transformer_blocks:
111 | hidden_states = block(
112 | hidden_states,
113 | encoder_hidden_states=encoder_hidden_states,
114 | timestep=timestep,
115 | video_length=video_length
116 | )
117 |
118 | # Output
119 | if not self.use_linear_projection:
120 | hidden_states = (
121 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
122 | )
123 | hidden_states = self.proj_out(hidden_states)
124 | else:
125 | hidden_states = self.proj_out(hidden_states)
126 | hidden_states = (
127 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
128 | )
129 |
130 | output = hidden_states + residual
131 |
132 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
133 | if not return_dict:
134 | return (output,)
135 |
136 | return Transformer3DModelOutput(sample=output)
137 |
138 |
139 | class BasicTransformerBlock(nn.Module):
140 | def __init__(
141 | self,
142 | dim: int,
143 | num_attention_heads: int,
144 | attention_head_dim: int,
145 | dropout=0.0,
146 | cross_attention_dim: Optional[int] = None,
147 | activation_fn: str = "geglu",
148 | num_embeds_ada_norm: Optional[int] = None,
149 | attention_bias: bool = False,
150 | only_cross_attention: bool = False,
151 | upcast_attention: bool = False,
152 | ):
153 | super().__init__()
154 | self.only_cross_attention = only_cross_attention
155 | self.use_ada_layer_norm = num_embeds_ada_norm is not None
156 |
157 | # SC-Attn
158 | self.attn1 = SparseCausalAttention(
159 | query_dim=dim,
160 | heads=num_attention_heads,
161 | dim_head=attention_head_dim,
162 | dropout=dropout,
163 | bias=attention_bias,
164 | cross_attention_dim=cross_attention_dim if only_cross_attention else None,
165 | upcast_attention=upcast_attention,
166 | )
167 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
168 |
169 | # Cross-Attn
170 | if cross_attention_dim is not None:
171 | self.attn2 = CrossAttention(
172 | query_dim=dim,
173 | cross_attention_dim=cross_attention_dim,
174 | heads=num_attention_heads,
175 | dim_head=attention_head_dim,
176 | dropout=dropout,
177 | bias=attention_bias,
178 | upcast_attention=upcast_attention,
179 | )
180 | else:
181 | self.attn2 = None
182 |
183 | if cross_attention_dim is not None:
184 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
185 | else:
186 | self.norm2 = None
187 |
188 | # Feed-forward
189 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
190 | self.norm3 = nn.LayerNorm(dim)
191 |
192 | # Temp-Attn
193 | self.attn_temp = CrossAttention(
194 | query_dim=dim,
195 | heads=num_attention_heads,
196 | dim_head=attention_head_dim,
197 | dropout=dropout,
198 | bias=attention_bias,
199 | upcast_attention=upcast_attention,
200 | )
201 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
202 | self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
203 |
204 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
205 | if not is_xformers_available():
206 | print("Here is how to install it")
207 | raise ModuleNotFoundError(
208 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
209 | " xformers",
210 | name="xformers",
211 | )
212 | elif not torch.cuda.is_available():
213 | raise ValueError(
214 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
215 | " available for GPU "
216 | )
217 | else:
218 | try:
219 | # Make sure we can run the memory efficient attention
220 | _ = xformers.ops.memory_efficient_attention(
221 | torch.randn((1, 2, 40), device="cuda"),
222 | torch.randn((1, 2, 40), device="cuda"),
223 | torch.randn((1, 2, 40), device="cuda"),
224 | )
225 | except Exception as e:
226 | raise e
227 | self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
228 | if self.attn2 is not None:
229 | self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
230 | # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
231 |
232 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
233 | # SparseCausal-Attention
234 | norm_hidden_states = (
235 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
236 | )
237 |
238 | if self.only_cross_attention:
239 | hidden_states = (
240 | self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
241 | )
242 | else:
243 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
244 |
245 | if self.attn2 is not None:
246 | # Cross-Attention
247 | norm_hidden_states = (
248 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
249 | )
250 | hidden_states = (
251 | self.attn2(
252 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
253 | )
254 | + hidden_states
255 | )
256 |
257 | # Feed-forward
258 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
259 |
260 | # Temporal-Attention
261 | d = hidden_states.shape[1]
262 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
263 | norm_hidden_states = (
264 | self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
265 | )
266 | hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
267 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
268 |
269 | return hidden_states
270 |
271 |
272 | class SparseCausalAttention(CrossAttention):
273 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
274 | batch_size, sequence_length, _ = hidden_states.shape
275 |
276 | encoder_hidden_states = encoder_hidden_states
277 |
278 | if self.group_norm is not None:
279 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
280 |
281 | query = self.to_q(hidden_states)
282 | dim = query.shape[-1]
283 | query = self.reshape_heads_to_batch_dim(query)
284 |
285 | if self.added_kv_proj_dim is not None:
286 | raise NotImplementedError
287 |
288 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
289 | key = self.to_k(encoder_hidden_states)
290 | value = self.to_v(encoder_hidden_states)
291 |
292 | former_frame_index = torch.arange(video_length) - 1
293 | former_frame_index[0] = 0
294 |
295 | key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
296 | key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
297 | key = rearrange(key, "b f d c -> (b f) d c")
298 |
299 | value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
300 | value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
301 | value = rearrange(value, "b f d c -> (b f) d c")
302 |
303 | key = self.reshape_heads_to_batch_dim(key)
304 | value = self.reshape_heads_to_batch_dim(value)
305 |
306 | if attention_mask is not None:
307 | if attention_mask.shape[-1] != query.shape[1]:
308 | target_length = query.shape[1]
309 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
310 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
311 |
312 | # attention, what we cannot get enough of
313 | if self._use_memory_efficient_attention_xformers:
314 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
315 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input
316 | hidden_states = hidden_states.to(query.dtype)
317 | else:
318 | if self._slice_size is None or query.shape[0] // self._slice_size == 1:
319 | hidden_states = self._attention(query, key, value, attention_mask)
320 | else:
321 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
322 |
323 | # linear proj
324 | hidden_states = self.to_out[0](hidden_states)
325 |
326 | # dropout
327 | hidden_states = self.to_out[1](hidden_states)
328 | return hidden_states
329 |
--------------------------------------------------------------------------------
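
The key/value construction in SparseCausalAttention is what makes every frame attend to the first frame and to its previous frame. A minimal sketch of just that indexing on a dummy tensor (one batch, four frames, two spatial tokens, three channels):

import torch
from einops import rearrange

video_length, tokens, channels = 4, 2, 3
# Dummy keys shaped (b*f, d, c) with b = 1; every token of frame f holds the value f.
key = (torch.arange(video_length, dtype=torch.float32)
       .repeat_interleave(tokens * channels)
       .view(video_length, tokens, channels))

former_frame_index = torch.arange(video_length) - 1
former_frame_index[0] = 0  # frame 0 falls back to itself

key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
key = torch.cat([key[:, [0] * video_length],          # tokens of the first frame
                 key[:, former_frame_index]], dim=2)  # tokens of the previous frame
key = rearrange(key, "b f d c -> (b f) d c")
print(key[..., 0])  # row f shows frames [0, 0, max(f-1, 0), max(f-1, 0)]
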
/tuneavideo/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from einops import rearrange
8 |
9 |
10 | class InflatedConv3d(nn.Conv2d):
11 | def forward(self, x):
12 | video_length = x.shape[2]
13 |
14 | x = rearrange(x, "b c f h w -> (b f) c h w")
15 | x = super().forward(x)
16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
17 |
18 | return x
19 |
20 |
21 | class Upsample3D(nn.Module):
22 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
23 | super().__init__()
24 | self.channels = channels
25 | self.out_channels = out_channels or channels
26 | self.use_conv = use_conv
27 | self.use_conv_transpose = use_conv_transpose
28 | self.name = name
29 |
30 | conv = None
31 | if use_conv_transpose:
32 | raise NotImplementedError
33 | elif use_conv:
34 | conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
35 |
36 | if name == "conv":
37 | self.conv = conv
38 | else:
39 | self.Conv2d_0 = conv
40 |
41 | def forward(self, hidden_states, output_size=None):
42 | assert hidden_states.shape[1] == self.channels
43 |
44 | if self.use_conv_transpose:
45 | raise NotImplementedError
46 |
47 | # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
48 | dtype = hidden_states.dtype
49 | if dtype == torch.bfloat16:
50 | hidden_states = hidden_states.to(torch.float32)
51 |
52 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
53 | if hidden_states.shape[0] >= 64:
54 | hidden_states = hidden_states.contiguous()
55 |
56 | # if `output_size` is passed we force the interpolation output
57 | # size and do not make use of `scale_factor=2`
58 | if output_size is None:
59 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
60 | else:
61 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
62 |
63 | # If the input is bfloat16, we cast back to bfloat16
64 | if dtype == torch.bfloat16:
65 | hidden_states = hidden_states.to(dtype)
66 |
67 | if self.use_conv:
68 | if self.name == "conv":
69 | hidden_states = self.conv(hidden_states)
70 | else:
71 | hidden_states = self.Conv2d_0(hidden_states)
72 |
73 | return hidden_states
74 |
75 |
76 | class Downsample3D(nn.Module):
77 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
78 | super().__init__()
79 | self.channels = channels
80 | self.out_channels = out_channels or channels
81 | self.use_conv = use_conv
82 | self.padding = padding
83 | stride = 2
84 | self.name = name
85 |
86 | if use_conv:
87 | conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
88 | else:
89 | raise NotImplementedError
90 |
91 | if name == "conv":
92 | self.Conv2d_0 = conv
93 | self.conv = conv
94 | elif name == "Conv2d_0":
95 | self.conv = conv
96 | else:
97 | self.conv = conv
98 |
99 | def forward(self, hidden_states):
100 | assert hidden_states.shape[1] == self.channels
101 | if self.use_conv and self.padding == 0:
102 | raise NotImplementedError
103 |
104 | assert hidden_states.shape[1] == self.channels
105 | hidden_states = self.conv(hidden_states)
106 |
107 | return hidden_states
108 |
109 |
110 | class ResnetBlock3D(nn.Module):
111 | def __init__(
112 | self,
113 | *,
114 | in_channels,
115 | out_channels=None,
116 | conv_shortcut=False,
117 | dropout=0.0,
118 | temb_channels=512,
119 | groups=32,
120 | groups_out=None,
121 | pre_norm=True,
122 | eps=1e-6,
123 | non_linearity="swish",
124 | time_embedding_norm="default",
125 | output_scale_factor=1.0,
126 | use_in_shortcut=None,
127 | ):
128 | super().__init__()
129 | self.pre_norm = pre_norm
130 | self.pre_norm = True
131 | self.in_channels = in_channels
132 | out_channels = in_channels if out_channels is None else out_channels
133 | self.out_channels = out_channels
134 | self.use_conv_shortcut = conv_shortcut
135 | self.time_embedding_norm = time_embedding_norm
136 | self.output_scale_factor = output_scale_factor
137 |
138 | if groups_out is None:
139 | groups_out = groups
140 |
141 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
142 |
143 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
144 |
145 | if temb_channels is not None:
146 | if self.time_embedding_norm == "default":
147 | time_emb_proj_out_channels = out_channels
148 | elif self.time_embedding_norm == "scale_shift":
149 | time_emb_proj_out_channels = out_channels * 2
150 | else:
151 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
152 |
153 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
154 | else:
155 | self.time_emb_proj = None
156 |
157 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
158 | self.dropout = torch.nn.Dropout(dropout)
159 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
160 |
161 | if non_linearity == "swish":
162 | self.nonlinearity = lambda x: F.silu(x)
163 | elif non_linearity == "mish":
164 | self.nonlinearity = Mish()
165 | elif non_linearity == "silu":
166 | self.nonlinearity = nn.SiLU()
167 |
168 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
169 |
170 | self.conv_shortcut = None
171 | if self.use_in_shortcut:
172 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
173 |
174 | def forward(self, input_tensor, temb):
175 | hidden_states = input_tensor
176 |
177 | hidden_states = self.norm1(hidden_states)
178 | hidden_states = self.nonlinearity(hidden_states)
179 |
180 | hidden_states = self.conv1(hidden_states)
181 |
182 | if temb is not None:
183 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
184 |
185 | if temb is not None and self.time_embedding_norm == "default":
186 | hidden_states = hidden_states + temb
187 |
188 | hidden_states = self.norm2(hidden_states)
189 |
190 | if temb is not None and self.time_embedding_norm == "scale_shift":
191 | scale, shift = torch.chunk(temb, 2, dim=1)
192 | hidden_states = hidden_states * (1 + scale) + shift
193 |
194 | hidden_states = self.nonlinearity(hidden_states)
195 |
196 | hidden_states = self.dropout(hidden_states)
197 | hidden_states = self.conv2(hidden_states)
198 |
199 | if self.conv_shortcut is not None:
200 | input_tensor = self.conv_shortcut(input_tensor)
201 |
202 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
203 |
204 | return output_tensor
205 |
206 |
207 | class Mish(torch.nn.Module):
208 | def forward(self, hidden_states):
209 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
--------------------------------------------------------------------------------
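
A minimal sketch of the shape handling in InflatedConv3d above: the frame axis is folded into the batch, a standard 2-D convolution runs per frame, and the video layout is restored.

import torch
from tuneavideo.models.resnet import InflatedConv3d

conv = InflatedConv3d(4, 8, kernel_size=3, padding=1)   # regular nn.Conv2d arguments
x = torch.randn(2, 4, 24, 64, 64)                       # (batch, channels, frames, h, w)
print(conv(x).shape)                                    # torch.Size([2, 8, 24, 64, 64])
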
/tuneavideo/models/unet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2 |
3 | from dataclasses import dataclass
4 | from typing import List, Optional, Tuple, Union
5 |
6 | import os
7 | import json
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.utils.checkpoint
12 |
13 | from diffusers.configuration_utils import ConfigMixin, register_to_config
14 | from diffusers.modeling_utils import ModelMixin
15 | from diffusers.utils import BaseOutput, logging
16 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps
17 | from .unet_blocks import (
18 | CrossAttnDownBlock3D,
19 | CrossAttnUpBlock3D,
20 | DownBlock3D,
21 | UNetMidBlock3DCrossAttn,
22 | UpBlock3D,
23 | get_down_block,
24 | get_up_block,
25 | )
26 | from .resnet import InflatedConv3d
27 |
28 |
29 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30 |
31 |
32 | @dataclass
33 | class UNet3DConditionOutput(BaseOutput):
34 | sample: torch.FloatTensor
35 |
36 |
37 | class UNet3DConditionModel(ModelMixin, ConfigMixin):
38 | _supports_gradient_checkpointing = True
39 |
40 | @register_to_config
41 | def __init__(
42 | self,
43 | sample_size: Optional[int] = None,
44 | in_channels: int = 4,
45 | out_channels: int = 4,
46 | center_input_sample: bool = False,
47 | flip_sin_to_cos: bool = True,
48 | freq_shift: int = 0,
49 | down_block_types: Tuple[str] = (
50 | "CrossAttnDownBlock3D",
51 | "CrossAttnDownBlock3D",
52 | "CrossAttnDownBlock3D",
53 | "DownBlock3D",
54 | ),
55 | mid_block_type: str = "UNetMidBlock3DCrossAttn",
56 | up_block_types: Tuple[str] = (
57 | "UpBlock3D",
58 | "CrossAttnUpBlock3D",
59 | "CrossAttnUpBlock3D",
60 | "CrossAttnUpBlock3D"
61 | ),
62 | only_cross_attention: Union[bool, Tuple[bool]] = False,
63 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
64 | layers_per_block: int = 2,
65 | downsample_padding: int = 1,
66 | mid_block_scale_factor: float = 1,
67 | act_fn: str = "silu",
68 | norm_num_groups: int = 32,
69 | norm_eps: float = 1e-5,
70 | cross_attention_dim: int = 1280,
71 | attention_head_dim: Union[int, Tuple[int]] = 8,
72 | dual_cross_attention: bool = False,
73 | use_linear_projection: bool = False,
74 | class_embed_type: Optional[str] = None,
75 | num_class_embeds: Optional[int] = None,
76 | upcast_attention: bool = False,
77 | resnet_time_scale_shift: str = "default",
78 | ):
79 | super().__init__()
80 |
81 | self.sample_size = sample_size
82 | time_embed_dim = block_out_channels[0] * 4
83 |
84 | # input
85 | self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
86 |
87 | # time
88 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
89 | timestep_input_dim = block_out_channels[0]
90 |
91 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
92 |
93 | # class embedding
94 | if class_embed_type is None and num_class_embeds is not None:
95 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
96 | elif class_embed_type == "timestep":
97 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
98 | elif class_embed_type == "identity":
99 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
100 | else:
101 | self.class_embedding = None
102 |
103 | self.down_blocks = nn.ModuleList([])
104 | self.mid_block = None
105 | self.up_blocks = nn.ModuleList([])
106 |
107 | if isinstance(only_cross_attention, bool):
108 | only_cross_attention = [only_cross_attention] * len(down_block_types)
109 |
110 | if isinstance(attention_head_dim, int):
111 | attention_head_dim = (attention_head_dim,) * len(down_block_types)
112 |
113 | # down
114 | output_channel = block_out_channels[0]
115 | for i, down_block_type in enumerate(down_block_types):
116 | input_channel = output_channel
117 | output_channel = block_out_channels[i]
118 | is_final_block = i == len(block_out_channels) - 1
119 |
120 | down_block = get_down_block(
121 | down_block_type,
122 | num_layers=layers_per_block,
123 | in_channels=input_channel,
124 | out_channels=output_channel,
125 | temb_channels=time_embed_dim,
126 | add_downsample=not is_final_block,
127 | resnet_eps=norm_eps,
128 | resnet_act_fn=act_fn,
129 | resnet_groups=norm_num_groups,
130 | cross_attention_dim=cross_attention_dim,
131 | attn_num_head_channels=attention_head_dim[i],
132 | downsample_padding=downsample_padding,
133 | dual_cross_attention=dual_cross_attention,
134 | use_linear_projection=use_linear_projection,
135 | only_cross_attention=only_cross_attention[i],
136 | upcast_attention=upcast_attention,
137 | resnet_time_scale_shift=resnet_time_scale_shift,
138 | )
139 | self.down_blocks.append(down_block)
140 |
141 | # mid
142 | if mid_block_type == "UNetMidBlock3DCrossAttn":
143 | self.mid_block = UNetMidBlock3DCrossAttn(
144 | in_channels=block_out_channels[-1],
145 | temb_channels=time_embed_dim,
146 | resnet_eps=norm_eps,
147 | resnet_act_fn=act_fn,
148 | output_scale_factor=mid_block_scale_factor,
149 | resnet_time_scale_shift=resnet_time_scale_shift,
150 | cross_attention_dim=cross_attention_dim,
151 | attn_num_head_channels=attention_head_dim[-1],
152 | resnet_groups=norm_num_groups,
153 | dual_cross_attention=dual_cross_attention,
154 | use_linear_projection=use_linear_projection,
155 | upcast_attention=upcast_attention,
156 | )
157 | else:
158 | raise ValueError(f"unknown mid_block_type : {mid_block_type}")
159 |
160 | # count how many layers upsample the videos
161 | self.num_upsamplers = 0
162 |
163 | # up
164 | reversed_block_out_channels = list(reversed(block_out_channels))
165 | reversed_attention_head_dim = list(reversed(attention_head_dim))
166 | only_cross_attention = list(reversed(only_cross_attention))
167 | output_channel = reversed_block_out_channels[0]
168 | for i, up_block_type in enumerate(up_block_types):
169 | is_final_block = i == len(block_out_channels) - 1
170 |
171 | prev_output_channel = output_channel
172 | output_channel = reversed_block_out_channels[i]
173 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
174 |
175 | # add upsample block for all BUT final layer
176 | if not is_final_block:
177 | add_upsample = True
178 | self.num_upsamplers += 1
179 | else:
180 | add_upsample = False
181 |
182 | up_block = get_up_block(
183 | up_block_type,
184 | num_layers=layers_per_block + 1,
185 | in_channels=input_channel,
186 | out_channels=output_channel,
187 | prev_output_channel=prev_output_channel,
188 | temb_channels=time_embed_dim,
189 | add_upsample=add_upsample,
190 | resnet_eps=norm_eps,
191 | resnet_act_fn=act_fn,
192 | resnet_groups=norm_num_groups,
193 | cross_attention_dim=cross_attention_dim,
194 | attn_num_head_channels=reversed_attention_head_dim[i],
195 | dual_cross_attention=dual_cross_attention,
196 | use_linear_projection=use_linear_projection,
197 | only_cross_attention=only_cross_attention[i],
198 | upcast_attention=upcast_attention,
199 | resnet_time_scale_shift=resnet_time_scale_shift,
200 | )
201 | self.up_blocks.append(up_block)
202 | prev_output_channel = output_channel
203 |
204 | # out
205 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
206 | self.conv_act = nn.SiLU()
207 | self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
208 |
209 | def set_attention_slice(self, slice_size):
210 | r"""
211 | Enable sliced attention computation.
212 |
213 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention
214 | in several steps. This is useful to save some memory in exchange for a small speed decrease.
215 |
216 | Args:
217 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
218 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
219 | `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
220 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
221 | must be a multiple of `slice_size`.
222 | """
223 | sliceable_head_dims = []
224 |
225 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
226 | if hasattr(module, "set_attention_slice"):
227 | sliceable_head_dims.append(module.sliceable_head_dim)
228 |
229 | for child in module.children():
230 | fn_recursive_retrieve_slicable_dims(child)
231 |
232 | # retrieve number of attention layers
233 | for module in self.children():
234 | fn_recursive_retrieve_slicable_dims(module)
235 |
236 | num_slicable_layers = len(sliceable_head_dims)
237 |
238 | if slice_size == "auto":
239 | # half the attention head size is usually a good trade-off between
240 | # speed and memory
241 | slice_size = [dim // 2 for dim in sliceable_head_dims]
242 | elif slice_size == "max":
243 | # make smallest slice possible
244 | slice_size = num_slicable_layers * [1]
245 |
246 | slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
247 |
248 | if len(slice_size) != len(sliceable_head_dims):
249 | raise ValueError(
250 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
251 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
252 | )
253 |
254 | for i in range(len(slice_size)):
255 | size = slice_size[i]
256 | dim = sliceable_head_dims[i]
257 | if size is not None and size > dim:
258 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
259 |
260 | # Recursively walk through all the children.
261 | # Any child that exposes the set_attention_slice method
262 | # gets the message
263 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
264 | if hasattr(module, "set_attention_slice"):
265 | module.set_attention_slice(slice_size.pop())
266 |
267 | for child in module.children():
268 | fn_recursive_set_attention_slice(child, slice_size)
269 |
270 | reversed_slice_size = list(reversed(slice_size))
271 | for module in self.children():
272 | fn_recursive_set_attention_slice(module, reversed_slice_size)
273 |
274 | def _set_gradient_checkpointing(self, module, value=False):
275 | if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
276 | module.gradient_checkpointing = value
277 |
278 | def forward(
279 | self,
280 | sample: torch.FloatTensor,
281 | timestep: Union[torch.Tensor, float, int],
282 | encoder_hidden_states: torch.Tensor,
283 | class_labels: Optional[torch.Tensor] = None,
284 | attention_mask: Optional[torch.Tensor] = None,
285 | return_dict: bool = True,
286 | ) -> Union[UNet3DConditionOutput, Tuple]:
287 | r"""
288 | Args:
289 | sample (`torch.FloatTensor`): (batch, channel, num_frames, height, width) noisy inputs tensor
290 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
291 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
292 | return_dict (`bool`, *optional*, defaults to `True`):
293 | Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
294 |
295 | Returns:
296 | [`UNet3DConditionOutput`] or `tuple`:
297 | [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
298 | returning a tuple, the first element is the sample tensor.
299 | """
300 | # By default samples have to be at least a multiple of the overall upsampling factor.
301 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
302 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
303 | # on the fly if necessary.
304 | default_overall_up_factor = 2**self.num_upsamplers
305 |
306 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
307 | forward_upsample_size = False
308 | upsample_size = None
309 |
310 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
311 | logger.info("Forward upsample size to force interpolation output size.")
312 | forward_upsample_size = True
313 |
314 | # prepare attention_mask
315 | if attention_mask is not None:
316 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
317 | attention_mask = attention_mask.unsqueeze(1)
318 |
319 | # center input if necessary
320 | if self.config.center_input_sample:
321 | sample = 2 * sample - 1.0
322 |
323 | # time
324 | timesteps = timestep
325 | if not torch.is_tensor(timesteps):
326 | # This would be a good case for the `match` statement (Python 3.10+)
327 | is_mps = sample.device.type == "mps"
328 | if isinstance(timestep, float):
329 | dtype = torch.float32 if is_mps else torch.float64
330 | else:
331 | dtype = torch.int32 if is_mps else torch.int64
332 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
333 | elif len(timesteps.shape) == 0:
334 | timesteps = timesteps[None].to(sample.device)
335 |
336 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
337 | timesteps = timesteps.expand(sample.shape[0])
338 |
339 | t_emb = self.time_proj(timesteps)
340 |
341 | # timesteps does not contain any weights and will always return f32 tensors
342 | # but time_embedding might actually be running in fp16. so we need to cast here.
343 | # there might be better ways to encapsulate this.
344 | t_emb = t_emb.to(dtype=self.dtype)
345 | emb = self.time_embedding(t_emb)
346 |
347 | if self.class_embedding is not None:
348 | if class_labels is None:
349 | raise ValueError("class_labels should be provided when num_class_embeds > 0")
350 |
351 | if self.config.class_embed_type == "timestep":
352 | class_labels = self.time_proj(class_labels)
353 |
354 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
355 | emb = emb + class_emb
356 |
357 | # pre-process
358 | sample = self.conv_in(sample)
359 |
360 | # down
361 | down_block_res_samples = (sample,)
362 | for downsample_block in self.down_blocks:
363 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
364 | sample, res_samples = downsample_block(
365 | hidden_states=sample,
366 | temb=emb,
367 | encoder_hidden_states=encoder_hidden_states,
368 | attention_mask=attention_mask,
369 | )
370 | else:
371 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
372 |
373 | down_block_res_samples += res_samples
374 |
375 | # mid
376 | sample = self.mid_block(
377 | sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
378 | )
379 |
380 | # up
381 | for i, upsample_block in enumerate(self.up_blocks):
382 | is_final_block = i == len(self.up_blocks) - 1
383 |
384 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
385 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
386 |
387 | # if we have not reached the final block and need to forward the
388 | # upsample size, we do it here
389 | if not is_final_block and forward_upsample_size:
390 | upsample_size = down_block_res_samples[-1].shape[2:]
391 |
392 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
393 | sample = upsample_block(
394 | hidden_states=sample,
395 | temb=emb,
396 | res_hidden_states_tuple=res_samples,
397 | encoder_hidden_states=encoder_hidden_states,
398 | upsample_size=upsample_size,
399 | attention_mask=attention_mask,
400 | )
401 | else:
402 | sample = upsample_block(
403 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
404 | )
405 | # post-process
406 | sample = self.conv_norm_out(sample)
407 | sample = self.conv_act(sample)
408 | sample = self.conv_out(sample)
409 |
410 | if not return_dict:
411 | return (sample,)
412 |
413 | return UNet3DConditionOutput(sample=sample)
414 |
415 | @classmethod
416 | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None):
417 | if subfolder is not None:
418 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
419 |
420 | config_file = os.path.join(pretrained_model_path, 'config.json')
421 | if not os.path.isfile(config_file):
422 | raise RuntimeError(f"{config_file} does not exist")
423 | with open(config_file, "r") as f:
424 | config = json.load(f)
425 | config["_class_name"] = cls.__name__
426 | config["down_block_types"] = [
427 | "CrossAttnDownBlock3D",
428 | "CrossAttnDownBlock3D",
429 | "CrossAttnDownBlock3D",
430 | "DownBlock3D"
431 | ]
432 | config["up_block_types"] = [
433 | "UpBlock3D",
434 | "CrossAttnUpBlock3D",
435 | "CrossAttnUpBlock3D",
436 | "CrossAttnUpBlock3D"
437 | ]
438 |
439 | from diffusers.utils import WEIGHTS_NAME
440 | model = cls.from_config(config)
441 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
442 | if not os.path.isfile(model_file):
443 | raise RuntimeError(f"{model_file} does not exist")
444 | state_dict = torch.load(model_file, map_location="cpu")
445 | for k, v in model.state_dict().items():
446 | if '_temp.' in k:
447 | state_dict.update({k: v})
448 | model.load_state_dict(state_dict)
449 |
450 | return model
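
A minimal usage sketch for the `from_pretrained_2d` classmethod above, assuming the 2D Stable Diffusion weights have already been downloaded to a local directory (the method reads `config.json` and the weight file directly from that path rather than from the Hugging Face Hub); the checkpoint path below is a placeholder.

    from tuneavideo.models.unet import UNet3DConditionModel

    # Placeholder local directory containing the 2D UNet's config.json and weights.
    unet = UNet3DConditionModel.from_pretrained_2d(
        "./checkpoints/stable-diffusion-v1-4", subfolder="unet"
    )
    # Parameters whose names contain "_temp." have no 2D counterpart, so the loop
    # above keeps them at the values of the freshly constructed 3D model, while
    # all remaining weights come from the 2D checkpoint.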
--------------------------------------------------------------------------------
/tuneavideo/models/unet_blocks.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from .attention import Transformer3DModel
7 | from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
8 |
9 |
10 | def get_down_block(
11 | down_block_type,
12 | num_layers,
13 | in_channels,
14 | out_channels,
15 | temb_channels,
16 | add_downsample,
17 | resnet_eps,
18 | resnet_act_fn,
19 | attn_num_head_channels,
20 | resnet_groups=None,
21 | cross_attention_dim=None,
22 | downsample_padding=None,
23 | dual_cross_attention=False,
24 | use_linear_projection=False,
25 | only_cross_attention=False,
26 | upcast_attention=False,
27 | resnet_time_scale_shift="default",
28 | ):
29 | down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
30 | if down_block_type == "DownBlock3D":
31 | return DownBlock3D(
32 | num_layers=num_layers,
33 | in_channels=in_channels,
34 | out_channels=out_channels,
35 | temb_channels=temb_channels,
36 | add_downsample=add_downsample,
37 | resnet_eps=resnet_eps,
38 | resnet_act_fn=resnet_act_fn,
39 | resnet_groups=resnet_groups,
40 | downsample_padding=downsample_padding,
41 | resnet_time_scale_shift=resnet_time_scale_shift,
42 | )
43 | elif down_block_type == "CrossAttnDownBlock3D":
44 | if cross_attention_dim is None:
45 | raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
46 | return CrossAttnDownBlock3D(
47 | num_layers=num_layers,
48 | in_channels=in_channels,
49 | out_channels=out_channels,
50 | temb_channels=temb_channels,
51 | add_downsample=add_downsample,
52 | resnet_eps=resnet_eps,
53 | resnet_act_fn=resnet_act_fn,
54 | resnet_groups=resnet_groups,
55 | downsample_padding=downsample_padding,
56 | cross_attention_dim=cross_attention_dim,
57 | attn_num_head_channels=attn_num_head_channels,
58 | dual_cross_attention=dual_cross_attention,
59 | use_linear_projection=use_linear_projection,
60 | only_cross_attention=only_cross_attention,
61 | upcast_attention=upcast_attention,
62 | resnet_time_scale_shift=resnet_time_scale_shift,
63 | )
64 | raise ValueError(f"{down_block_type} does not exist.")
65 |
66 |
67 | def get_up_block(
68 | up_block_type,
69 | num_layers,
70 | in_channels,
71 | out_channels,
72 | prev_output_channel,
73 | temb_channels,
74 | add_upsample,
75 | resnet_eps,
76 | resnet_act_fn,
77 | attn_num_head_channels,
78 | resnet_groups=None,
79 | cross_attention_dim=None,
80 | dual_cross_attention=False,
81 | use_linear_projection=False,
82 | only_cross_attention=False,
83 | upcast_attention=False,
84 | resnet_time_scale_shift="default",
85 | ):
86 | up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
87 | if up_block_type == "UpBlock3D":
88 | return UpBlock3D(
89 | num_layers=num_layers,
90 | in_channels=in_channels,
91 | out_channels=out_channels,
92 | prev_output_channel=prev_output_channel,
93 | temb_channels=temb_channels,
94 | add_upsample=add_upsample,
95 | resnet_eps=resnet_eps,
96 | resnet_act_fn=resnet_act_fn,
97 | resnet_groups=resnet_groups,
98 | resnet_time_scale_shift=resnet_time_scale_shift,
99 | )
100 | elif up_block_type == "CrossAttnUpBlock3D":
101 | if cross_attention_dim is None:
102 | raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
103 | return CrossAttnUpBlock3D(
104 | num_layers=num_layers,
105 | in_channels=in_channels,
106 | out_channels=out_channels,
107 | prev_output_channel=prev_output_channel,
108 | temb_channels=temb_channels,
109 | add_upsample=add_upsample,
110 | resnet_eps=resnet_eps,
111 | resnet_act_fn=resnet_act_fn,
112 | resnet_groups=resnet_groups,
113 | cross_attention_dim=cross_attention_dim,
114 | attn_num_head_channels=attn_num_head_channels,
115 | dual_cross_attention=dual_cross_attention,
116 | use_linear_projection=use_linear_projection,
117 | only_cross_attention=only_cross_attention,
118 | upcast_attention=upcast_attention,
119 | resnet_time_scale_shift=resnet_time_scale_shift,
120 | )
121 | raise ValueError(f"{up_block_type} does not exist.")
122 |
123 |
124 | class UNetMidBlock3DCrossAttn(nn.Module):
125 | def __init__(
126 | self,
127 | in_channels: int,
128 | temb_channels: int,
129 | dropout: float = 0.0,
130 | num_layers: int = 1,
131 | resnet_eps: float = 1e-6,
132 | resnet_time_scale_shift: str = "default",
133 | resnet_act_fn: str = "swish",
134 | resnet_groups: int = 32,
135 | resnet_pre_norm: bool = True,
136 | attn_num_head_channels=1,
137 | output_scale_factor=1.0,
138 | cross_attention_dim=1280,
139 | dual_cross_attention=False,
140 | use_linear_projection=False,
141 | upcast_attention=False,
142 | ):
143 | super().__init__()
144 |
145 | self.has_cross_attention = True
146 | self.attn_num_head_channels = attn_num_head_channels
147 | resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
148 |
149 | # there is always at least one resnet
150 | resnets = [
151 | ResnetBlock3D(
152 | in_channels=in_channels,
153 | out_channels=in_channels,
154 | temb_channels=temb_channels,
155 | eps=resnet_eps,
156 | groups=resnet_groups,
157 | dropout=dropout,
158 | time_embedding_norm=resnet_time_scale_shift,
159 | non_linearity=resnet_act_fn,
160 | output_scale_factor=output_scale_factor,
161 | pre_norm=resnet_pre_norm,
162 | )
163 | ]
164 | attentions = []
165 |
166 | for _ in range(num_layers):
167 | if dual_cross_attention:
168 | raise NotImplementedError
169 | attentions.append(
170 | Transformer3DModel(
171 | attn_num_head_channels,
172 | in_channels // attn_num_head_channels,
173 | in_channels=in_channels,
174 | num_layers=1,
175 | cross_attention_dim=cross_attention_dim,
176 | norm_num_groups=resnet_groups,
177 | use_linear_projection=use_linear_projection,
178 | upcast_attention=upcast_attention,
179 | )
180 | )
181 | resnets.append(
182 | ResnetBlock3D(
183 | in_channels=in_channels,
184 | out_channels=in_channels,
185 | temb_channels=temb_channels,
186 | eps=resnet_eps,
187 | groups=resnet_groups,
188 | dropout=dropout,
189 | time_embedding_norm=resnet_time_scale_shift,
190 | non_linearity=resnet_act_fn,
191 | output_scale_factor=output_scale_factor,
192 | pre_norm=resnet_pre_norm,
193 | )
194 | )
195 |
196 | self.attentions = nn.ModuleList(attentions)
197 | self.resnets = nn.ModuleList(resnets)
198 |
199 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
200 | hidden_states = self.resnets[0](hidden_states, temb)
201 | for attn, resnet in zip(self.attentions, self.resnets[1:]):
202 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
203 | hidden_states = resnet(hidden_states, temb)
204 |
205 | return hidden_states
206 |
207 |
208 | class CrossAttnDownBlock3D(nn.Module):
209 | def __init__(
210 | self,
211 | in_channels: int,
212 | out_channels: int,
213 | temb_channels: int,
214 | dropout: float = 0.0,
215 | num_layers: int = 1,
216 | resnet_eps: float = 1e-6,
217 | resnet_time_scale_shift: str = "default",
218 | resnet_act_fn: str = "swish",
219 | resnet_groups: int = 32,
220 | resnet_pre_norm: bool = True,
221 | attn_num_head_channels=1,
222 | cross_attention_dim=1280,
223 | output_scale_factor=1.0,
224 | downsample_padding=1,
225 | add_downsample=True,
226 | dual_cross_attention=False,
227 | use_linear_projection=False,
228 | only_cross_attention=False,
229 | upcast_attention=False,
230 | ):
231 | super().__init__()
232 | resnets = []
233 | attentions = []
234 |
235 | self.has_cross_attention = True
236 | self.attn_num_head_channels = attn_num_head_channels
237 |
238 | for i in range(num_layers):
239 | in_channels = in_channels if i == 0 else out_channels
240 | resnets.append(
241 | ResnetBlock3D(
242 | in_channels=in_channels,
243 | out_channels=out_channels,
244 | temb_channels=temb_channels,
245 | eps=resnet_eps,
246 | groups=resnet_groups,
247 | dropout=dropout,
248 | time_embedding_norm=resnet_time_scale_shift,
249 | non_linearity=resnet_act_fn,
250 | output_scale_factor=output_scale_factor,
251 | pre_norm=resnet_pre_norm,
252 | )
253 | )
254 | if dual_cross_attention:
255 | raise NotImplementedError
256 | attentions.append(
257 | Transformer3DModel(
258 | attn_num_head_channels,
259 | out_channels // attn_num_head_channels,
260 | in_channels=out_channels,
261 | num_layers=1,
262 | cross_attention_dim=cross_attention_dim,
263 | norm_num_groups=resnet_groups,
264 | use_linear_projection=use_linear_projection,
265 | only_cross_attention=only_cross_attention,
266 | upcast_attention=upcast_attention,
267 | )
268 | )
269 | self.attentions = nn.ModuleList(attentions)
270 | self.resnets = nn.ModuleList(resnets)
271 |
272 | if add_downsample:
273 | self.downsamplers = nn.ModuleList(
274 | [
275 | Downsample3D(
276 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
277 | )
278 | ]
279 | )
280 | else:
281 | self.downsamplers = None
282 |
283 | self.gradient_checkpointing = False
284 |
285 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
286 | output_states = ()
287 |
288 | for resnet, attn in zip(self.resnets, self.attentions):
289 | if self.training and self.gradient_checkpointing:
290 |
291 | def create_custom_forward(module, return_dict=None):
292 | def custom_forward(*inputs):
293 | if return_dict is not None:
294 | return module(*inputs, return_dict=return_dict)
295 | else:
296 | return module(*inputs)
297 |
298 | return custom_forward
299 |
300 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
301 | hidden_states = torch.utils.checkpoint.checkpoint(
302 | create_custom_forward(attn, return_dict=False),
303 | hidden_states,
304 | encoder_hidden_states,
305 | )[0]
306 | else:
307 | hidden_states = resnet(hidden_states, temb)
308 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
309 |
310 | output_states += (hidden_states,)
311 |
312 | if self.downsamplers is not None:
313 | for downsampler in self.downsamplers:
314 | hidden_states = downsampler(hidden_states)
315 |
316 | output_states += (hidden_states,)
317 |
318 | return hidden_states, output_states
319 |
320 |
321 | class DownBlock3D(nn.Module):
322 | def __init__(
323 | self,
324 | in_channels: int,
325 | out_channels: int,
326 | temb_channels: int,
327 | dropout: float = 0.0,
328 | num_layers: int = 1,
329 | resnet_eps: float = 1e-6,
330 | resnet_time_scale_shift: str = "default",
331 | resnet_act_fn: str = "swish",
332 | resnet_groups: int = 32,
333 | resnet_pre_norm: bool = True,
334 | output_scale_factor=1.0,
335 | add_downsample=True,
336 | downsample_padding=1,
337 | ):
338 | super().__init__()
339 | resnets = []
340 |
341 | for i in range(num_layers):
342 | in_channels = in_channels if i == 0 else out_channels
343 | resnets.append(
344 | ResnetBlock3D(
345 | in_channels=in_channels,
346 | out_channels=out_channels,
347 | temb_channels=temb_channels,
348 | eps=resnet_eps,
349 | groups=resnet_groups,
350 | dropout=dropout,
351 | time_embedding_norm=resnet_time_scale_shift,
352 | non_linearity=resnet_act_fn,
353 | output_scale_factor=output_scale_factor,
354 | pre_norm=resnet_pre_norm,
355 | )
356 | )
357 |
358 | self.resnets = nn.ModuleList(resnets)
359 |
360 | if add_downsample:
361 | self.downsamplers = nn.ModuleList(
362 | [
363 | Downsample3D(
364 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
365 | )
366 | ]
367 | )
368 | else:
369 | self.downsamplers = None
370 |
371 | self.gradient_checkpointing = False
372 |
373 | def forward(self, hidden_states, temb=None):
374 | output_states = ()
375 |
376 | for resnet in self.resnets:
377 | if self.training and self.gradient_checkpointing:
378 |
379 | def create_custom_forward(module):
380 | def custom_forward(*inputs):
381 | return module(*inputs)
382 |
383 | return custom_forward
384 |
385 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
386 | else:
387 | hidden_states = resnet(hidden_states, temb)
388 |
389 | output_states += (hidden_states,)
390 |
391 | if self.downsamplers is not None:
392 | for downsampler in self.downsamplers:
393 | hidden_states = downsampler(hidden_states)
394 |
395 | output_states += (hidden_states,)
396 |
397 | return hidden_states, output_states
398 |
399 |
400 | class CrossAttnUpBlock3D(nn.Module):
401 | def __init__(
402 | self,
403 | in_channels: int,
404 | out_channels: int,
405 | prev_output_channel: int,
406 | temb_channels: int,
407 | dropout: float = 0.0,
408 | num_layers: int = 1,
409 | resnet_eps: float = 1e-6,
410 | resnet_time_scale_shift: str = "default",
411 | resnet_act_fn: str = "swish",
412 | resnet_groups: int = 32,
413 | resnet_pre_norm: bool = True,
414 | attn_num_head_channels=1,
415 | cross_attention_dim=1280,
416 | output_scale_factor=1.0,
417 | add_upsample=True,
418 | dual_cross_attention=False,
419 | use_linear_projection=False,
420 | only_cross_attention=False,
421 | upcast_attention=False,
422 | ):
423 | super().__init__()
424 | resnets = []
425 | attentions = []
426 |
427 | self.has_cross_attention = True
428 | self.attn_num_head_channels = attn_num_head_channels
429 |
430 | for i in range(num_layers):
431 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
432 | resnet_in_channels = prev_output_channel if i == 0 else out_channels
433 |
434 | resnets.append(
435 | ResnetBlock3D(
436 | in_channels=resnet_in_channels + res_skip_channels,
437 | out_channels=out_channels,
438 | temb_channels=temb_channels,
439 | eps=resnet_eps,
440 | groups=resnet_groups,
441 | dropout=dropout,
442 | time_embedding_norm=resnet_time_scale_shift,
443 | non_linearity=resnet_act_fn,
444 | output_scale_factor=output_scale_factor,
445 | pre_norm=resnet_pre_norm,
446 | )
447 | )
448 | if dual_cross_attention:
449 | raise NotImplementedError
450 | attentions.append(
451 | Transformer3DModel(
452 | attn_num_head_channels,
453 | out_channels // attn_num_head_channels,
454 | in_channels=out_channels,
455 | num_layers=1,
456 | cross_attention_dim=cross_attention_dim,
457 | norm_num_groups=resnet_groups,
458 | use_linear_projection=use_linear_projection,
459 | only_cross_attention=only_cross_attention,
460 | upcast_attention=upcast_attention,
461 | )
462 | )
463 |
464 | self.attentions = nn.ModuleList(attentions)
465 | self.resnets = nn.ModuleList(resnets)
466 |
467 | if add_upsample:
468 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
469 | else:
470 | self.upsamplers = None
471 |
472 | self.gradient_checkpointing = False
473 |
474 | def forward(
475 | self,
476 | hidden_states,
477 | res_hidden_states_tuple,
478 | temb=None,
479 | encoder_hidden_states=None,
480 | upsample_size=None,
481 | attention_mask=None,
482 | ):
483 | for resnet, attn in zip(self.resnets, self.attentions):
484 | # pop res hidden states
485 | res_hidden_states = res_hidden_states_tuple[-1]
486 | res_hidden_states_tuple = res_hidden_states_tuple[:-1]
487 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
488 |
489 | if self.training and self.gradient_checkpointing:
490 |
491 | def create_custom_forward(module, return_dict=None):
492 | def custom_forward(*inputs):
493 | if return_dict is not None:
494 | return module(*inputs, return_dict=return_dict)
495 | else:
496 | return module(*inputs)
497 |
498 | return custom_forward
499 |
500 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
501 | hidden_states = torch.utils.checkpoint.checkpoint(
502 | create_custom_forward(attn, return_dict=False),
503 | hidden_states,
504 | encoder_hidden_states,
505 | )[0]
506 | else:
507 | hidden_states = resnet(hidden_states, temb)
508 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
509 |
510 | if self.upsamplers is not None:
511 | for upsampler in self.upsamplers:
512 | hidden_states = upsampler(hidden_states, upsample_size)
513 |
514 | return hidden_states
515 |
516 |
517 | class UpBlock3D(nn.Module):
518 | def __init__(
519 | self,
520 | in_channels: int,
521 | prev_output_channel: int,
522 | out_channels: int,
523 | temb_channels: int,
524 | dropout: float = 0.0,
525 | num_layers: int = 1,
526 | resnet_eps: float = 1e-6,
527 | resnet_time_scale_shift: str = "default",
528 | resnet_act_fn: str = "swish",
529 | resnet_groups: int = 32,
530 | resnet_pre_norm: bool = True,
531 | output_scale_factor=1.0,
532 | add_upsample=True,
533 | ):
534 | super().__init__()
535 | resnets = []
536 |
537 | for i in range(num_layers):
538 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
539 | resnet_in_channels = prev_output_channel if i == 0 else out_channels
540 |
541 | resnets.append(
542 | ResnetBlock3D(
543 | in_channels=resnet_in_channels + res_skip_channels,
544 | out_channels=out_channels,
545 | temb_channels=temb_channels,
546 | eps=resnet_eps,
547 | groups=resnet_groups,
548 | dropout=dropout,
549 | time_embedding_norm=resnet_time_scale_shift,
550 | non_linearity=resnet_act_fn,
551 | output_scale_factor=output_scale_factor,
552 | pre_norm=resnet_pre_norm,
553 | )
554 | )
555 |
556 | self.resnets = nn.ModuleList(resnets)
557 |
558 | if add_upsample:
559 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
560 | else:
561 | self.upsamplers = None
562 |
563 | self.gradient_checkpointing = False
564 |
565 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
566 | for resnet in self.resnets:
567 | # pop res hidden states
568 | res_hidden_states = res_hidden_states_tuple[-1]
569 | res_hidden_states_tuple = res_hidden_states_tuple[:-1]
570 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
571 |
572 | if self.training and self.gradient_checkpointing:
573 |
574 | def create_custom_forward(module):
575 | def custom_forward(*inputs):
576 | return module(*inputs)
577 |
578 | return custom_forward
579 |
580 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
581 | else:
582 | hidden_states = resnet(hidden_states, temb)
583 |
584 | if self.upsamplers is not None:
585 | for upsampler in self.upsamplers:
586 | hidden_states = upsampler(hidden_states, upsample_size)
587 |
588 | return hidden_states
589 |
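
As an illustration of how these factory functions are wired together, here is a minimal sketch that builds a single cross-attention down block; the channel counts, head count and cross-attention dimension are assumed SD-1.x-style values chosen for the example, not settings prescribed by this file.

    from tuneavideo.models.unet_blocks import get_down_block

    # Assumed hyperparameters: 320 base channels, 8 attention heads, CLIP text
    # embedding dimension 768; eps/activation follow the defaults in this module.
    block = get_down_block(
        "CrossAttnDownBlock3D",
        num_layers=2,
        in_channels=320,
        out_channels=320,
        temb_channels=1280,
        add_downsample=True,
        resnet_eps=1e-6,
        resnet_act_fn="swish",
        attn_num_head_channels=8,
        resnet_groups=32,
        cross_attention_dim=768,
        downsample_padding=1,
    )
    # Like the UNet forward pass, the block consumes 5D video features of shape
    # (batch, channels, frames, height, width) plus a time embedding.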
--------------------------------------------------------------------------------
/tuneavideo/pipelines/pipeline_tuneavideo.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
2 |
3 | import inspect
4 | from typing import Callable, List, Optional, Union
5 | from dataclasses import dataclass
6 |
7 | import numpy as np
8 | import torch
9 |
10 | from diffusers.utils import is_accelerate_available
11 | from packaging import version
12 | from transformers import CLIPTextModel, CLIPTokenizer
13 |
14 | from diffusers.configuration_utils import FrozenDict
15 | from diffusers.models import AutoencoderKL
16 | from diffusers.pipeline_utils import DiffusionPipeline
17 | from diffusers.schedulers import (
18 | DDIMScheduler,
19 | DPMSolverMultistepScheduler,
20 | EulerAncestralDiscreteScheduler,
21 | EulerDiscreteScheduler,
22 | LMSDiscreteScheduler,
23 | PNDMScheduler,
24 | )
25 | from diffusers.utils import deprecate, logging, BaseOutput
26 |
27 | from einops import rearrange
28 |
29 | from ..models.unet import UNet3DConditionModel
30 |
31 |
32 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33 |
34 |
35 | @dataclass
36 | class TuneAVideoPipelineOutput(BaseOutput):
37 | videos: Union[torch.Tensor, np.ndarray]
38 |
39 |
40 | class TuneAVideoPipeline(DiffusionPipeline):
41 | _optional_components = []
42 |
43 | def __init__(
44 | self,
45 | vae: AutoencoderKL,
46 | text_encoder: CLIPTextModel,
47 | tokenizer: CLIPTokenizer,
48 | unet: UNet3DConditionModel,
49 | scheduler: Union[
50 | DDIMScheduler,
51 | PNDMScheduler,
52 | LMSDiscreteScheduler,
53 | EulerDiscreteScheduler,
54 | EulerAncestralDiscreteScheduler,
55 | DPMSolverMultistepScheduler,
56 | ],
57 | ):
58 | super().__init__()
59 |
60 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
61 | deprecation_message = (
62 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
63 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
64 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
65 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
66 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
67 | " file"
68 | )
69 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
70 | new_config = dict(scheduler.config)
71 | new_config["steps_offset"] = 1
72 | scheduler._internal_dict = FrozenDict(new_config)
73 |
74 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
75 | deprecation_message = (
76 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
77 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
78 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
79 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
80 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
81 | )
82 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
83 | new_config = dict(scheduler.config)
84 | new_config["clip_sample"] = False
85 | scheduler._internal_dict = FrozenDict(new_config)
86 |
87 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
88 | version.parse(unet.config._diffusers_version).base_version
89 | ) < version.parse("0.9.0.dev0")
90 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
91 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
92 | deprecation_message = (
93 | "The configuration file of the unet has set the default `sample_size` to smaller than"
94 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
95 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
96 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
97 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
98 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
99 | " in the config might lead to incorrect results in future versions. If you have downloaded this"
100 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
101 | " the `unet/config.json` file"
102 | )
103 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
104 | new_config = dict(unet.config)
105 | new_config["sample_size"] = 64
106 | unet._internal_dict = FrozenDict(new_config)
107 |
108 | self.register_modules(
109 | vae=vae,
110 | text_encoder=text_encoder,
111 | tokenizer=tokenizer,
112 | unet=unet,
113 | scheduler=scheduler,
114 | )
115 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
116 |
117 | def enable_vae_slicing(self):
118 | self.vae.enable_slicing()
119 |
120 | def disable_vae_slicing(self):
121 | self.vae.disable_slicing()
122 |
123 | def enable_sequential_cpu_offload(self, gpu_id=0):
124 | if is_accelerate_available():
125 | from accelerate import cpu_offload
126 | else:
127 | raise ImportError("Please install accelerate via `pip install accelerate`")
128 |
129 | device = torch.device(f"cuda:{gpu_id}")
130 |
131 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
132 | if cpu_offloaded_model is not None:
133 | cpu_offload(cpu_offloaded_model, device)
134 |
135 |
136 | @property
137 | def _execution_device(self):
138 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
139 | return self.device
140 | for module in self.unet.modules():
141 | if (
142 | hasattr(module, "_hf_hook")
143 | and hasattr(module._hf_hook, "execution_device")
144 | and module._hf_hook.execution_device is not None
145 | ):
146 | return torch.device(module._hf_hook.execution_device)
147 | return self.device
148 |
149 | def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
150 | batch_size = len(prompt) if isinstance(prompt, list) else 1
151 |
152 | text_inputs = self.tokenizer(
153 | prompt,
154 | padding="max_length",
155 | max_length=self.tokenizer.model_max_length,
156 | truncation=True,
157 | return_tensors="pt",
158 | )
159 | text_input_ids = text_inputs.input_ids
160 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
161 |
162 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
163 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
164 | logger.warning(
165 | "The following part of your input was truncated because CLIP can only handle sequences up to"
166 | f" {self.tokenizer.model_max_length} tokens: {removed_text}"
167 | )
168 |
169 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
170 | attention_mask = text_inputs.attention_mask.to(device)
171 | else:
172 | attention_mask = None
173 |
174 | text_embeddings = self.text_encoder(
175 | text_input_ids.to(device),
176 | attention_mask=attention_mask,
177 | )
178 | text_embeddings = text_embeddings[0]
179 |
180 | # duplicate text embeddings for each generation per prompt, using mps friendly method
181 | bs_embed, seq_len, _ = text_embeddings.shape
182 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
183 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
184 |
185 | # get unconditional embeddings for classifier free guidance
186 | if do_classifier_free_guidance:
187 | uncond_tokens: List[str]
188 | if negative_prompt is None:
189 | uncond_tokens = [""] * batch_size
190 | elif type(prompt) is not type(negative_prompt):
191 | raise TypeError(
192 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
193 | f" {type(prompt)}."
194 | )
195 | elif isinstance(negative_prompt, str):
196 | uncond_tokens = [negative_prompt]
197 | elif batch_size != len(negative_prompt):
198 | raise ValueError(
199 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
200 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
201 | " the batch size of `prompt`."
202 | )
203 | else:
204 | uncond_tokens = negative_prompt
205 |
206 | max_length = text_input_ids.shape[-1]
207 | uncond_input = self.tokenizer(
208 | uncond_tokens,
209 | padding="max_length",
210 | max_length=max_length,
211 | truncation=True,
212 | return_tensors="pt",
213 | )
214 |
215 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
216 | attention_mask = uncond_input.attention_mask.to(device)
217 | else:
218 | attention_mask = None
219 |
220 | uncond_embeddings = self.text_encoder(
221 | uncond_input.input_ids.to(device),
222 | attention_mask=attention_mask,
223 | )
224 | uncond_embeddings = uncond_embeddings[0]
225 |
226 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
227 | seq_len = uncond_embeddings.shape[1]
228 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
229 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
230 |
231 | # For classifier free guidance, we need to do two forward passes.
232 | # Here we concatenate the unconditional and text embeddings into a single batch
233 | # to avoid doing two forward passes
234 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
235 |
236 | return text_embeddings
237 |
238 | def decode_latents(self, latents):
239 | video_length = latents.shape[2]
240 | latents = 1 / 0.18215 * latents
241 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
242 | video = self.vae.decode(latents).sample
243 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
244 | video = (video / 2 + 0.5).clamp(0, 1)
245 |         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
246 | video = video.cpu().float().numpy()
247 | return video
248 |
249 | def prepare_extra_step_kwargs(self, generator, eta):
250 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
251 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
252 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
253 | # and should be between [0, 1]
254 |
255 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
256 | extra_step_kwargs = {}
257 | if accepts_eta:
258 | extra_step_kwargs["eta"] = eta
259 |
260 | # check if the scheduler accepts generator
261 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
262 | if accepts_generator:
263 | extra_step_kwargs["generator"] = generator
264 | return extra_step_kwargs
265 |
266 | def check_inputs(self, prompt, height, width, callback_steps):
267 | if not isinstance(prompt, str) and not isinstance(prompt, list):
268 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
269 |
270 | if height % 8 != 0 or width % 8 != 0:
271 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
272 |
273 | if (callback_steps is None) or (
274 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
275 | ):
276 | raise ValueError(
277 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
278 | f" {type(callback_steps)}."
279 | )
280 |
281 | def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
282 | shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
283 | if isinstance(generator, list) and len(generator) != batch_size:
284 | raise ValueError(
285 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
286 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
287 | )
288 |
289 | if latents is None:
290 | rand_device = "cpu" if device.type == "mps" else device
291 |
292 | if isinstance(generator, list):
293 | shape = (1,) + shape[1:]
294 | latents = [
295 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
296 | for i in range(batch_size)
297 | ]
298 | latents = torch.cat(latents, dim=0).to(device)
299 | else:
300 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
301 | else:
302 | if latents.shape != shape:
303 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
304 | latents = latents.to(device)
305 |
306 | # scale the initial noise by the standard deviation required by the scheduler
307 | latents = latents * self.scheduler.init_noise_sigma
308 | return latents
309 |
310 | @torch.no_grad()
311 | def __call__(
312 | self,
313 | prompt: Union[str, List[str]],
314 | video_length: Optional[int],
315 | height: Optional[int] = None,
316 | width: Optional[int] = None,
317 | num_inference_steps: int = 50,
318 | guidance_scale: float = 7.5,
319 | negative_prompt: Optional[Union[str, List[str]]] = None,
320 | num_videos_per_prompt: Optional[int] = 1,
321 | eta: float = 0.0,
322 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
323 | latents: Optional[torch.FloatTensor] = None,
324 | output_type: Optional[str] = "tensor",
325 | return_dict: bool = True,
326 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
327 | callback_steps: Optional[int] = 1,
328 | **kwargs,
329 | ):
330 | # Default height and width to unet
331 | height = height or self.unet.config.sample_size * self.vae_scale_factor
332 | width = width or self.unet.config.sample_size * self.vae_scale_factor
333 |
334 | # Check inputs. Raise error if not correct
335 | self.check_inputs(prompt, height, width, callback_steps)
336 |
337 | # Define call parameters
338 | batch_size = 1 if isinstance(prompt, str) else len(prompt)
339 | device = self._execution_device
340 |         # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
341 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
342 | # corresponds to doing no classifier free guidance.
343 | do_classifier_free_guidance = guidance_scale > 1.0
344 |
345 | # Encode input prompt
346 | text_embeddings = self._encode_prompt(
347 | prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
348 | )
349 |
350 | # Prepare timesteps
351 | self.scheduler.set_timesteps(num_inference_steps, device=device)
352 | timesteps = self.scheduler.timesteps
353 |
354 | # Prepare latent variables
355 | num_channels_latents = self.unet.in_channels
356 | latents = self.prepare_latents(
357 | batch_size * num_videos_per_prompt,
358 | num_channels_latents,
359 | video_length,
360 | height,
361 | width,
362 | text_embeddings.dtype,
363 | device,
364 | generator,
365 | latents,
366 | )
367 | latents_dtype = latents.dtype
368 |
369 | # Prepare extra step kwargs.
370 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
371 |
372 | # Denoising loop
373 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
374 | with self.progress_bar(total=num_inference_steps) as progress_bar:
375 | for i, t in enumerate(timesteps):
376 | # expand the latents if we are doing classifier free guidance
377 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
378 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
379 |
380 | # predict the noise residual
381 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
382 |
383 | # perform guidance
384 | if do_classifier_free_guidance:
385 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
386 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
387 |
388 | # compute the previous noisy sample x_t -> x_t-1
389 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
390 |
391 | # call the callback, if provided
392 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
393 | progress_bar.update()
394 | if callback is not None and i % callback_steps == 0:
395 | callback(i, t, latents)
396 |
397 | # Post-processing
398 | video = self.decode_latents(latents)
399 |
400 | # Convert to tensor
401 | if output_type == "tensor":
402 | video = torch.from_numpy(video)
403 |
404 | if not return_dict:
405 | return video
406 |
407 | return TuneAVideoPipelineOutput(videos=video)
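
A hedged end-to-end inference sketch for this pipeline; the checkpoint paths, prompt and sampling hyperparameters are placeholders chosen for illustration rather than values taken from this file.

    import torch

    from tuneavideo.models.unet import UNet3DConditionModel
    from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
    from tuneavideo.util import save_videos_grid

    pretrained_model_path = "./checkpoints/stable-diffusion-v1-4"  # placeholder
    finetuned_unet_path = "./outputs/man-skiing"                   # placeholder

    unet = UNet3DConditionModel.from_pretrained(
        finetuned_unet_path, subfolder="unet", torch_dtype=torch.float16
    )
    pipe = TuneAVideoPipeline.from_pretrained(
        pretrained_model_path, unet=unet, torch_dtype=torch.float16
    ).to("cuda")
    pipe.enable_vae_slicing()

    video = pipe(
        "spider man is skiing on the beach, cartoon style",
        video_length=24, height=512, width=512,
        num_inference_steps=50, guidance_scale=12.5,
    ).videos  # tensor of shape (batch, channels, frames, height, width) in [0, 1]

    save_videos_grid(video, "./outputs/sample.gif")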
--------------------------------------------------------------------------------
/tuneavideo/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import imageio
3 | import numpy as np
4 | from typing import Union
5 |
6 | import torch
7 | import torchvision
8 |
9 | from tqdm import tqdm
10 | from einops import rearrange
11 |
12 |
13 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
14 | videos = rearrange(videos, "b c t h w -> t b c h w")
15 | outputs = []
16 | for x in videos:
17 | x = torchvision.utils.make_grid(x, nrow=n_rows)
18 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
19 | if rescale:
20 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
21 | x = (x * 255).numpy().astype(np.uint8)
22 | outputs.append(x)
23 |
24 | os.makedirs(os.path.dirname(path), exist_ok=True)
25 | imageio.mimsave(path, outputs, fps=fps)
26 |
27 |
28 | # DDIM Inversion
29 | @torch.no_grad()
30 | def init_prompt(prompt, pipeline):
31 | uncond_input = pipeline.tokenizer(
32 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
33 | return_tensors="pt"
34 | )
35 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
36 | text_input = pipeline.tokenizer(
37 | [prompt],
38 | padding="max_length",
39 | max_length=pipeline.tokenizer.model_max_length,
40 | truncation=True,
41 | return_tensors="pt",
42 | )
43 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
44 | context = torch.cat([uncond_embeddings, text_embeddings])
45 |
46 | return context
47 |
48 |
49 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
50 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
51 | timestep, next_timestep = min(
52 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
53 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
54 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
55 | beta_prod_t = 1 - alpha_prod_t
56 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
57 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
58 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
59 | return next_sample
60 |
61 |
62 | def get_noise_pred_single(latents, t, context, unet):
63 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
64 | return noise_pred
65 |
66 |
67 | @torch.no_grad()
68 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
69 | context = init_prompt(prompt, pipeline)
70 | uncond_embeddings, cond_embeddings = context.chunk(2)
71 | all_latent = [latent]
72 | latent = latent.clone().detach()
73 | for i in tqdm(range(num_inv_steps)):
74 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
75 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
76 | latent = next_step(noise_pred, t, latent, ddim_scheduler)
77 | all_latent.append(latent)
78 | return all_latent
79 |
80 |
81 | @torch.no_grad()
82 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
83 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
84 | return ddim_latents
85 |
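
A short sketch of how the DDIM-inversion helpers above can feed the pipeline: the inverted latent of the source clip replaces random noise as the starting point. `pipe`, `pretrained_model_path` and `video_latents` (the VAE-encoded source video, shape `b x c x f x h x w`) are assumed to exist already, and the editing prompt is a placeholder.

    import torch
    from diffusers import DDIMScheduler

    from tuneavideo.util import ddim_inversion

    ddim_inv_scheduler = DDIMScheduler.from_pretrained(
        pretrained_model_path, subfolder="scheduler"
    )
    ddim_inv_scheduler.set_timesteps(50)

    # ddim_loop returns the full latent trajectory; the last element is the
    # fully inverted latent used as the initial noise.
    ddim_inv_latent = ddim_inversion(
        pipe, ddim_inv_scheduler, video_latent=video_latents,
        num_inv_steps=50, prompt=""
    )[-1].to(torch.float16)

    video = pipe(
        "a jeep car is moving on the snow",  # placeholder editing prompt
        latents=ddim_inv_latent,
        video_length=video_latents.shape[2],
        num_inference_steps=50, guidance_scale=12.5,
    ).videos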
--------------------------------------------------------------------------------