├── .gitignore ├── LICENSE ├── README.md ├── assets ├── controlvideo │ ├── building1-a wooden building, at night-80_merge.gif │ ├── building1-a wooden building, at night-80_merge_0.gif │ ├── building1-a wooden building, at night-80_merge_1.gif │ ├── building1-a wooden building, at night-80_merge_2.gif │ ├── building1-a wooden building, at night-80_merge_3.gif │ ├── dance26-Michael Jackson is dancing-1500_merge.gif │ ├── dance26-Michael Jackson is dancing-1500_merge_0.gif │ ├── dance26-Michael Jackson is dancing-1500_merge_1.gif │ ├── dance26-Michael Jackson is dancing-1500_merge_2.gif │ ├── dance26-Michael Jackson is dancing-1500_merge_3.gif │ ├── dance5-a person is dancing, Makoto Shinkai style-1000_merge.gif │ ├── dance5-a person is dancing, Makoto Shinkai style-1000_merge_0.gif │ ├── dance5-a person is dancing, Makoto Shinkai style-1000_merge_1.gif │ ├── dance5-a person is dancing, Makoto Shinkai style-1000_merge_2.gif │ ├── dance5-a person is dancing, Makoto Shinkai style-1000_merge_3.gif │ ├── girl8-a girl, Krenz Cushart style-300_merge.gif │ ├── girl8-a girl, Krenz Cushart style-300_merge_0.gif │ ├── girl8-a girl, Krenz Cushart style-300_merge_1.gif │ ├── girl8-a girl, Krenz Cushart style-300_merge_2.gif │ ├── girl8-a girl, Krenz Cushart style-300_merge_3.gif │ ├── girlface9_6-a girl with rich makeup-300_merge.gif │ ├── girlface9_6-a girl with rich makeup-300_merge_0.gif │ ├── girlface9_6-a girl with rich makeup-300_merge_1.gif │ ├── girlface9_6-a girl with rich makeup-300_merge_2.gif │ ├── girlface9_6-a girl with rich makeup-300_merge_3.gif │ ├── ink1-gentle green ink diffuses in water, beautiful light-200_merge.gif │ ├── ink1-gentle green ink diffuses in water, beautiful light-200_merge_0.gif │ ├── ink1-gentle green ink diffuses in water, beautiful light-200_merge_1.gif │ ├── ink1-gentle green ink diffuses in water, beautiful light-200_merge_2.gif │ └── ink1-gentle green ink diffuses in water, beautiful light-200_merge_3.gif ├── make-a-protagonist │ ├── 0-A man is playing a basketball on the beach, anime style-org.gif │ ├── 0-A man is playing a basketball on the beach, anime style.gif │ ├── 0-a jeep driving down a mountain road in the rain-org.gif │ ├── 0-a jeep driving down a mountain road in the rain.gif │ ├── 0-a panda walking down the snowy street-org.gif │ ├── 0-a panda walking down the snowy street.gif │ ├── 0-elon musk walking down the street-org.gif │ ├── 0-elon musk walking down the street.gif │ ├── car-turn.gif │ ├── huaqiang.gif │ ├── ikun.gif │ └── yanzi.gif ├── tune-a-video │ ├── car-turn.gif │ ├── car-turn_a jeep car is moving on the beach.gif │ ├── car-turn_a jeep car is moving on the road, cartoon style.gif │ ├── car-turn_a jeep car is moving on the snow.gif │ ├── car-turn_a sports car is moving on the road.gif │ ├── car-turn_smooth_a jeep car is moving on the beach.gif │ ├── car-turn_smooth_a jeep car is moving on the road, cartoon style.gif │ ├── car-turn_smooth_a jeep car is moving on the snow.gif │ ├── car-turn_smooth_a sports car is moving on the road.gif │ ├── man-skiing.gif │ ├── man-skiing_a man, wearing pink clothes, is skiing at sunset.gif │ ├── man-skiing_mickey mouse is skiing on the snow.gif │ ├── man-skiing_smooth_a man, wearing pink clothes, is skiing at sunset.gif │ ├── man-skiing_smooth_mickey mouse is skiing on the snow.gif │ ├── man-skiing_smooth_spider man is skiing on the beach, cartoon style.gif │ ├── man-skiing_smooth_wonder woman, wearing a cowboy hat, is skiing.gif │ ├── man-skiing_spider man is skiing on the beach, cartoon style.gif │ ├── 
man-skiing_wonder woman, wearing a cowboy hat, is skiing.gif │ ├── rabbit-watermelon.gif │ ├── rabbit-watermelon_a puppy is eating an orange.gif │ ├── rabbit-watermelon_a rabbit is eating a pizza.gif │ ├── rabbit-watermelon_a rabbit is eating an orange.gif │ ├── rabbit-watermelon_a tiger is eating a watermelon.gif │ ├── rabbit-watermelon_smooth_a puppy is eating an orange.gif │ ├── rabbit-watermelon_smooth_a rabbit is eating a pizza.gif │ ├── rabbit-watermelon_smooth_a rabbit is eating an orange.gif │ └── rabbit-watermelon_smooth_a tiger is eating a watermelon.gif └── video2video-zero │ ├── mini-cooper.gif │ ├── mini-cooper_make it animation_InstructVideo2Video-zero.gif │ ├── mini-cooper_make it animation_InstructVideo2Video-zero_noise_cons.gif │ ├── mini-cooper_make it animation_Video-InstructPix2Pix.gif │ └── mini-cooper_make it animation_Video-InstructPix2Pix_noise_cons.gif ├── configs ├── car-turn.yaml ├── man-skiing.yaml ├── man-surfing.yaml └── rabbit-watermelon.yaml ├── data ├── car-turn.mp4 ├── man-skiing.mp4 ├── man-surfing.mp4 └── rabbit-watermelon.mp4 ├── requirements.txt ├── train_tuneavideo.py └── tuneavideo ├── data └── dataset.py ├── models ├── attention.py ├── resnet.py ├── unet.py └── unet_blocks.py ├── pipelines └── pipeline_tuneavideo.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | led / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | env/ 9 | build/ 10 | dist/ 11 | *.log 12 | 13 | # pyenv 14 | .python-version 15 | 16 | # dotenv 17 | .env 18 | 19 | # virtualenv 20 | .venv/ 21 | venv/ 22 | ENV/ 23 | 24 | # VSCode settings 25 | .vscode 26 | 27 | # IDEA files 28 | .idea 29 | 30 | # OSX dir files 31 | .DS_Store 32 | 33 | # Sublime Text settings 34 | *.sublime-workspace 35 | *.sublime-project 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SmoothVideo 2 | 3 | This repository is the official implementation of **Smooth Video Synthesis with Noise Constraints on Diffusion Models for One-shot Video Tuning**. 4 | 5 | 6 | 7 | 8 | ## Setup 9 | This implementation is based on [Tune-A-Video](https://github.com/showlab/Tune-A-Video). 10 | 11 | ### Requirements 12 | 13 | ```shell 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | Installing [xformers](https://github.com/facebookresearch/xformers) is highly recommended for more efficiency and speed on GPUs. 18 | To enable xformers, set `enable_xformers_memory_efficient_attention=True` (default). 19 | 20 | ### Weights 21 | 22 | **[Stable Diffusion]** [Stable Diffusion](https://arxiv.org/abs/2112.10752) is a latent text-to-image diffusion model capable of generating photo-realistic images given any text input. The pre-trained Stable Diffusion models can be downloaded from Hugging Face (e.g., [Stable Diffusion v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5))). 23 | 24 | 25 | ## Usage 26 | 27 | ### Training 28 | 29 | To fine-tune the text-to-image diffusion models for text-to-video generation, run this command for the baseline model: 30 | ```bash 31 | accelerate launch train_tuneavideo.py --config="configs/man-skiing.yaml" 32 | ``` 33 | 34 | Run this command for the baseline model with the proposed smooth loss: 35 | ```bash 36 | accelerate launch train_tuneavideo.py --config="configs/man-skiing.yaml" --smooth_loss 37 | ``` 38 | 39 | Run this command for the baseline model with the proposed simple smooth loss: 40 | ```bash 41 | accelerate launch train_tuneavideo.py --config="configs/man-skiing.yaml" --smooth_loss --simple_manner 42 | ``` 43 | 44 | 45 | 46 | Note: Tuning a 24-frame video usually takes `300~500` steps, about `10~15` minutes using one A100 GPU. 47 | Reduce `n_sample_frames` if your GPU memory is limited. 48 | 49 | ### Inference 50 | 51 | Once the training is done, run inference: 52 | 53 | ```python 54 | from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline 55 | from tuneavideo.models.unet import UNet3DConditionModel 56 | from tuneavideo.util import save_videos_grid 57 | import torch 58 | 59 | pretrained_model_path = "./checkpoints/stable-diffusion-v1-5" 60 | my_model_path = "./outputs/man-skiing" 61 | unet = UNet3DConditionModel.from_pretrained(my_model_path, subfolder='unet', torch_dtype=torch.float16).to('cuda') 62 | pipe = TuneAVideoPipeline.from_pretrained(pretrained_model_path, unet=unet, torch_dtype=torch.float16).to("cuda") 63 | pipe.enable_xformers_memory_efficient_attention() 64 | pipe.enable_vae_slicing() 65 | 66 | prompt = "spider man is skiing" 67 | ddim_inv_latent = torch.load(f"{my_model_path}/inv_latents/ddim_latent-500.pt").to(torch.float16) 68 | video = pipe(prompt, latents=ddim_inv_latent, video_length=24, height=512, width=512, num_inference_steps=50, guidance_scale=7.5).videos 69 | 70 | save_videos_grid(video, f"./{prompt}.gif") 71 | ``` 72 | 73 | 74 | 75 | **We provide comparisons with different baselines, as follows:** 76 | 77 | 78 | 79 | ## Results 80 | 81 | 82 | 83 | ### Tune-A-Video 84 | 85 | Comparisons to [Tune-A-Video](https://github.com/showlab/Tune-A-Video). 
| Input video | Tune-A-Video vs. Tune-A-Video + smooth loss (edited prompts) |
|---|---|
| A jeep car is moving on the road | A jeep car is moving on the beach · A jeep car is moving on the snow · A jeep car is moving on the road, cartoon style · A sports car is moving on the road |
| A rabbit is eating a watermelon | A tiger is eating a watermelon · A rabbit is eating an orange · A rabbit is eating a pizza · A puppy is eating an orange |
| A man is skiing | Mickey mouse is skiing on the snow · Spider man is skiing on the beach, cartoon style · Wonder woman, wearing a cowboy hat, is skiing · A man, wearing pink clothes, is skiing at sunset |

(The corresponding GIF grids are under `assets/tune-a-video/`.)
### Make-A-Protagonist

Comparisons to [Make-A-Protagonist](https://github.com/HeliosZhao/Make-A-Protagonist).

| Input video | Make-A-Protagonist vs. Make-A-Protagonist + smooth loss (edited prompt) |
|---|---|
| A jeep driving down a mountain road | A jeep driving down a mountain road in the rain |
| A man is playing basketball | A man is playing a basketball on the beach, anime style |
| A man walking down the street at night | A panda walking down the snowy street |
| A man walking down the street | Elon Musk walking down the street |

(The corresponding GIFs are under `assets/make-a-protagonist/`.)
### ControlVideo

Comparisons to [ControlVideo](https://github.com/thu-ml/controlvideo).

| Input video | Condition | ControlVideo vs. ControlVideo + smooth loss (edited prompt) |
|---|---|---|
| A person is dancing | Pose | Michael Jackson is dancing |
| A person is dancing | Pose | A person is dancing, Makoto Shinkai style |
| A building | Canny edge | A wooden building, at night |
| A girl | HED edge | A girl, Krenz Cushart style |
| A girl | HED edge | A girl with rich makeup |
| Ink diffuses in water | Depth | Gentle green ink diffuses in water, beautiful light |

(The corresponding GIFs are under `assets/controlvideo/`.)
### Video2Video-zero

Comparisons to [training-free methods](https://github.com/Picsart-AI-Research/Text2Video-Zero), using the instruction "Make it animation". The compared outputs are: the input video, Instruct Video2Video-zero, Instruct Video2Video-zero + noise constraint, Video InstructPix2Pix, and Video InstructPix2Pix + noise constraint.

(The corresponding GIFs are under `assets/video2video-zero/`.)
346 | 347 | 348 | -------------------------------------------------------------------------------- /assets/controlvideo/building1-a wooden building, at night-80_merge.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/building1-a wooden building, at night-80_merge.gif -------------------------------------------------------------------------------- /assets/controlvideo/building1-a wooden building, at night-80_merge_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/building1-a wooden building, at night-80_merge_0.gif -------------------------------------------------------------------------------- /assets/controlvideo/building1-a wooden building, at night-80_merge_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/building1-a wooden building, at night-80_merge_1.gif -------------------------------------------------------------------------------- /assets/controlvideo/building1-a wooden building, at night-80_merge_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/building1-a wooden building, at night-80_merge_2.gif -------------------------------------------------------------------------------- /assets/controlvideo/building1-a wooden building, at night-80_merge_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/building1-a wooden building, at night-80_merge_3.gif -------------------------------------------------------------------------------- /assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge.gif -------------------------------------------------------------------------------- /assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge_0.gif -------------------------------------------------------------------------------- /assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge_1.gif -------------------------------------------------------------------------------- /assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge_2.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge_2.gif -------------------------------------------------------------------------------- /assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance26-Michael Jackson is dancing-1500_merge_3.gif -------------------------------------------------------------------------------- /assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge.gif -------------------------------------------------------------------------------- /assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge_0.gif -------------------------------------------------------------------------------- /assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge_1.gif -------------------------------------------------------------------------------- /assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge_2.gif -------------------------------------------------------------------------------- /assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/dance5-a person is dancing, Makoto Shinkai style-1000_merge_3.gif -------------------------------------------------------------------------------- /assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge.gif -------------------------------------------------------------------------------- /assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge_0.gif 
-------------------------------------------------------------------------------- /assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge_1.gif -------------------------------------------------------------------------------- /assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge_2.gif -------------------------------------------------------------------------------- /assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girl8-a girl, Krenz Cushart style-300_merge_3.gif -------------------------------------------------------------------------------- /assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge.gif -------------------------------------------------------------------------------- /assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge_0.gif -------------------------------------------------------------------------------- /assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge_1.gif -------------------------------------------------------------------------------- /assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge_2.gif -------------------------------------------------------------------------------- /assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/girlface9_6-a girl with rich makeup-300_merge_3.gif -------------------------------------------------------------------------------- /assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge.gif -------------------------------------------------------------------------------- /assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge_0.gif -------------------------------------------------------------------------------- /assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge_1.gif -------------------------------------------------------------------------------- /assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge_2.gif -------------------------------------------------------------------------------- /assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/controlvideo/ink1-gentle green ink diffuses in water, beautiful light-200_merge_3.gif -------------------------------------------------------------------------------- /assets/make-a-protagonist/0-A man is playing a basketball on the beach, anime style-org.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/0-A man is playing a basketball on the beach, anime style-org.gif -------------------------------------------------------------------------------- /assets/make-a-protagonist/0-A man is playing a basketball on the beach, anime style.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/0-A man is playing a basketball on the beach, anime style.gif -------------------------------------------------------------------------------- /assets/make-a-protagonist/0-a jeep driving down a mountain road in the rain-org.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/0-a jeep driving down a mountain road in the rain-org.gif -------------------------------------------------------------------------------- /assets/make-a-protagonist/0-a jeep driving down a mountain road in the rain.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/0-a jeep driving down a mountain road in the rain.gif -------------------------------------------------------------------------------- /assets/make-a-protagonist/0-a panda walking down the snowy street-org.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/0-a panda walking down the snowy street-org.gif -------------------------------------------------------------------------------- /assets/make-a-protagonist/0-a panda walking down the snowy street.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/0-a panda walking down the snowy street.gif -------------------------------------------------------------------------------- /assets/make-a-protagonist/0-elon musk walking down the street-org.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/0-elon musk walking down the street-org.gif -------------------------------------------------------------------------------- /assets/make-a-protagonist/0-elon musk walking down the street.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/0-elon musk walking down the street.gif -------------------------------------------------------------------------------- /assets/make-a-protagonist/car-turn.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/car-turn.gif -------------------------------------------------------------------------------- /assets/make-a-protagonist/huaqiang.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/huaqiang.gif -------------------------------------------------------------------------------- /assets/make-a-protagonist/ikun.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/ikun.gif -------------------------------------------------------------------------------- /assets/make-a-protagonist/yanzi.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/make-a-protagonist/yanzi.gif -------------------------------------------------------------------------------- /assets/tune-a-video/car-turn.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/car-turn.gif -------------------------------------------------------------------------------- 
/assets/tune-a-video/car-turn_a jeep car is moving on the beach.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/car-turn_a jeep car is moving on the beach.gif -------------------------------------------------------------------------------- /assets/tune-a-video/car-turn_a jeep car is moving on the road, cartoon style.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/car-turn_a jeep car is moving on the road, cartoon style.gif -------------------------------------------------------------------------------- /assets/tune-a-video/car-turn_a jeep car is moving on the snow.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/car-turn_a jeep car is moving on the snow.gif -------------------------------------------------------------------------------- /assets/tune-a-video/car-turn_a sports car is moving on the road.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/car-turn_a sports car is moving on the road.gif -------------------------------------------------------------------------------- /assets/tune-a-video/car-turn_smooth_a jeep car is moving on the beach.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/car-turn_smooth_a jeep car is moving on the beach.gif -------------------------------------------------------------------------------- /assets/tune-a-video/car-turn_smooth_a jeep car is moving on the road, cartoon style.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/car-turn_smooth_a jeep car is moving on the road, cartoon style.gif -------------------------------------------------------------------------------- /assets/tune-a-video/car-turn_smooth_a jeep car is moving on the snow.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/car-turn_smooth_a jeep car is moving on the snow.gif -------------------------------------------------------------------------------- /assets/tune-a-video/car-turn_smooth_a sports car is moving on the road.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/car-turn_smooth_a sports car is moving on the road.gif -------------------------------------------------------------------------------- /assets/tune-a-video/man-skiing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/man-skiing.gif 
-------------------------------------------------------------------------------- /assets/tune-a-video/man-skiing_a man, wearing pink clothes, is skiing at sunset.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/man-skiing_a man, wearing pink clothes, is skiing at sunset.gif -------------------------------------------------------------------------------- /assets/tune-a-video/man-skiing_mickey mouse is skiing on the snow.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/man-skiing_mickey mouse is skiing on the snow.gif -------------------------------------------------------------------------------- /assets/tune-a-video/man-skiing_smooth_a man, wearing pink clothes, is skiing at sunset.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/man-skiing_smooth_a man, wearing pink clothes, is skiing at sunset.gif -------------------------------------------------------------------------------- /assets/tune-a-video/man-skiing_smooth_mickey mouse is skiing on the snow.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/man-skiing_smooth_mickey mouse is skiing on the snow.gif -------------------------------------------------------------------------------- /assets/tune-a-video/man-skiing_smooth_spider man is skiing on the beach, cartoon style.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/man-skiing_smooth_spider man is skiing on the beach, cartoon style.gif -------------------------------------------------------------------------------- /assets/tune-a-video/man-skiing_smooth_wonder woman, wearing a cowboy hat, is skiing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/man-skiing_smooth_wonder woman, wearing a cowboy hat, is skiing.gif -------------------------------------------------------------------------------- /assets/tune-a-video/man-skiing_spider man is skiing on the beach, cartoon style.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/man-skiing_spider man is skiing on the beach, cartoon style.gif -------------------------------------------------------------------------------- /assets/tune-a-video/man-skiing_wonder woman, wearing a cowboy hat, is skiing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/man-skiing_wonder woman, wearing a cowboy hat, is skiing.gif -------------------------------------------------------------------------------- /assets/tune-a-video/rabbit-watermelon.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/rabbit-watermelon.gif -------------------------------------------------------------------------------- /assets/tune-a-video/rabbit-watermelon_a puppy is eating an orange.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/rabbit-watermelon_a puppy is eating an orange.gif -------------------------------------------------------------------------------- /assets/tune-a-video/rabbit-watermelon_a rabbit is eating a pizza.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/rabbit-watermelon_a rabbit is eating a pizza.gif -------------------------------------------------------------------------------- /assets/tune-a-video/rabbit-watermelon_a rabbit is eating an orange.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/rabbit-watermelon_a rabbit is eating an orange.gif -------------------------------------------------------------------------------- /assets/tune-a-video/rabbit-watermelon_a tiger is eating a watermelon.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/rabbit-watermelon_a tiger is eating a watermelon.gif -------------------------------------------------------------------------------- /assets/tune-a-video/rabbit-watermelon_smooth_a puppy is eating an orange.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/rabbit-watermelon_smooth_a puppy is eating an orange.gif -------------------------------------------------------------------------------- /assets/tune-a-video/rabbit-watermelon_smooth_a rabbit is eating a pizza.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/rabbit-watermelon_smooth_a rabbit is eating a pizza.gif -------------------------------------------------------------------------------- /assets/tune-a-video/rabbit-watermelon_smooth_a rabbit is eating an orange.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/rabbit-watermelon_smooth_a rabbit is eating an orange.gif -------------------------------------------------------------------------------- /assets/tune-a-video/rabbit-watermelon_smooth_a tiger is eating a watermelon.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/tune-a-video/rabbit-watermelon_smooth_a tiger is eating a watermelon.gif 
-------------------------------------------------------------------------------- /assets/video2video-zero/mini-cooper.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/video2video-zero/mini-cooper.gif -------------------------------------------------------------------------------- /assets/video2video-zero/mini-cooper_make it animation_InstructVideo2Video-zero.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/video2video-zero/mini-cooper_make it animation_InstructVideo2Video-zero.gif -------------------------------------------------------------------------------- /assets/video2video-zero/mini-cooper_make it animation_InstructVideo2Video-zero_noise_cons.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/video2video-zero/mini-cooper_make it animation_InstructVideo2Video-zero_noise_cons.gif -------------------------------------------------------------------------------- /assets/video2video-zero/mini-cooper_make it animation_Video-InstructPix2Pix.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/video2video-zero/mini-cooper_make it animation_Video-InstructPix2Pix.gif -------------------------------------------------------------------------------- /assets/video2video-zero/mini-cooper_make it animation_Video-InstructPix2Pix_noise_cons.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/assets/video2video-zero/mini-cooper_make it animation_Video-InstructPix2Pix_noise_cons.gif -------------------------------------------------------------------------------- /configs/car-turn.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: "./checkpoints/stable-diffusion-v1-5" 2 | output_dir: "./outputs/car-turn" 3 | 4 | train_data: 5 | video_path: "data/car-turn.mp4" 6 | prompt: "a jeep car is moving on the road" 7 | n_sample_frames: 24 8 | width: 512 9 | height: 512 10 | sample_start_idx: 0 11 | sample_frame_rate: 2 12 | 13 | validation_data: 14 | prompts: 15 | - "a jeep car is moving on the beach" 16 | - "a jeep car is moving on the snow" 17 | - "a jeep car is moving on the road, cartoon style" 18 | - "a sports car is moving on the road" 19 | video_length: 24 20 | width: 512 21 | height: 512 22 | num_inference_steps: 50 23 | guidance_scale: 7.5 24 | use_inv_latent: True 25 | num_inv_steps: 50 26 | 27 | learning_rate: 3e-5 28 | train_batch_size: 1 29 | max_train_steps: 500 30 | checkpointing_steps: 1000 31 | validation_steps: 100 32 | trainable_modules: 33 | - "attn1.to_q" 34 | - "attn2.to_q" 35 | - "attn_temp" 36 | 37 | seed: 33 38 | mixed_precision: fp16 39 | use_8bit_adam: False 40 | gradient_checkpointing: True 41 | enable_xformers_memory_efficient_attention: True 42 | -------------------------------------------------------------------------------- /configs/man-skiing.yaml: -------------------------------------------------------------------------------- 1 | 
pretrained_model_path: "./checkpoints/stable-diffusion-v1-5" 2 | output_dir: "./outputs/man-skiing" 3 | 4 | train_data: 5 | video_path: "data/man-skiing.mp4" 6 | prompt: "a man is skiing" 7 | n_sample_frames: 24 8 | width: 512 9 | height: 512 10 | sample_start_idx: 0 11 | sample_frame_rate: 2 12 | 13 | validation_data: 14 | prompts: 15 | - "mickey mouse is skiing on the snow" 16 | - "spider man is skiing on the beach, cartoon style" 17 | - "wonder woman, wearing a cowboy hat, is skiing" 18 | - "a man, wearing pink clothes, is skiing at sunset" 19 | video_length: 24 20 | width: 512 21 | height: 512 22 | num_inference_steps: 50 23 | guidance_scale: 7.5 24 | use_inv_latent: True 25 | num_inv_steps: 50 26 | 27 | learning_rate: 3e-5 28 | train_batch_size: 1 29 | max_train_steps: 500 30 | checkpointing_steps: 1000 31 | validation_steps: 100 32 | trainable_modules: 33 | - "attn1.to_q" 34 | - "attn2.to_q" 35 | - "attn_temp" 36 | 37 | seed: 33 38 | mixed_precision: fp16 39 | use_8bit_adam: False 40 | gradient_checkpointing: True 41 | enable_xformers_memory_efficient_attention: True 42 | -------------------------------------------------------------------------------- /configs/man-surfing.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: "./checkpoints/stable-diffusion-v1-5" 2 | output_dir: "./outputs/man-surfing" 3 | 4 | train_data: 5 | video_path: "data/man-surfing.mp4" 6 | prompt: "a man is surfing" 7 | n_sample_frames: 24 8 | width: 512 9 | height: 512 10 | sample_start_idx: 0 11 | sample_frame_rate: 1 12 | 13 | validation_data: 14 | prompts: 15 | - "a panda is surfing" 16 | - "a boy, wearing a birthday hat, is surfing" 17 | - "a raccoon is surfing, cartoon style" 18 | - "Iron Man is surfing in the desert" 19 | video_length: 24 20 | width: 512 21 | height: 512 22 | num_inference_steps: 50 23 | guidance_scale: 7.5 24 | use_inv_latent: True 25 | num_inv_steps: 50 26 | 27 | learning_rate: 3e-5 28 | train_batch_size: 1 29 | max_train_steps: 500 30 | checkpointing_steps: 1000 31 | validation_steps: 100 32 | trainable_modules: 33 | - "attn1.to_q" 34 | - "attn2.to_q" 35 | - "attn_temp" 36 | 37 | seed: 33 38 | mixed_precision: fp16 39 | use_8bit_adam: False 40 | gradient_checkpointing: True 41 | enable_xformers_memory_efficient_attention: True 42 | -------------------------------------------------------------------------------- /configs/rabbit-watermelon.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: "./checkpoints/stable-diffusion-v1-5" 2 | output_dir: "./outputs/rabbit-watermelon" 3 | 4 | train_data: 5 | video_path: "data/rabbit-watermelon.mp4" 6 | prompt: "a rabbit is eating a watermelon" 7 | n_sample_frames: 24 8 | width: 512 9 | height: 512 10 | sample_start_idx: 0 11 | sample_frame_rate: 2 12 | 13 | validation_data: 14 | prompts: 15 | - "a tiger is eating a watermelon" 16 | - "a rabbit is eating an orange" 17 | - "a rabbit is eating a pizza" 18 | - "a puppy is eating an orange" 19 | video_length: 24 20 | width: 512 21 | height: 512 22 | num_inference_steps: 50 23 | guidance_scale: 7.5 24 | use_inv_latent: True 25 | num_inv_steps: 50 26 | 27 | learning_rate: 3e-5 28 | train_batch_size: 1 29 | max_train_steps: 500 30 | checkpointing_steps: 1000 31 | validation_steps: 100 32 | trainable_modules: 33 | - "attn1.to_q" 34 | - "attn2.to_q" 35 | - "attn_temp" 36 | 37 | seed: 33 38 | mixed_precision: fp16 39 | use_8bit_adam: False 40 | gradient_checkpointing: True 41 | 
enable_xformers_memory_efficient_attention: True 42 | -------------------------------------------------------------------------------- /data/car-turn.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/data/car-turn.mp4 -------------------------------------------------------------------------------- /data/man-skiing.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/data/man-skiing.mp4 -------------------------------------------------------------------------------- /data/man-surfing.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/data/man-surfing.mp4 -------------------------------------------------------------------------------- /data/rabbit-watermelon.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPengLiang/SmoothVideo/363d28ec4604246004f168c589e3c4b0fcd707e7/data/rabbit-watermelon.mp4 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.12.1 2 | torchvision==0.13.1 3 | diffusers[torch]==0.11.1 4 | transformers>=4.25.1 5 | bitsandbytes==0.35.4 6 | decord==0.6.0 7 | accelerate 8 | tensorboard 9 | modelcards 10 | omegaconf 11 | einops 12 | imageio 13 | ftfy -------------------------------------------------------------------------------- /train_tuneavideo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import logging 4 | import inspect 5 | import math 6 | import os 7 | from typing import Dict, Optional, Tuple 8 | from omegaconf import OmegaConf 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | import torch.utils.checkpoint 13 | 14 | import diffusers 15 | import transformers 16 | from accelerate import Accelerator 17 | from accelerate.logging import get_logger 18 | from accelerate.utils import set_seed 19 | from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler 20 | from diffusers.optimization import get_scheduler 21 | from diffusers.utils import check_min_version 22 | from diffusers.utils.import_utils import is_xformers_available 23 | from tqdm.auto import tqdm 24 | from transformers import CLIPTextModel, CLIPTokenizer 25 | 26 | from tuneavideo.models.unet import UNet3DConditionModel 27 | from tuneavideo.data.dataset import TuneAVideoDataset 28 | from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline 29 | from tuneavideo.util import save_videos_grid, ddim_inversion 30 | from einops import rearrange 31 | 32 | 33 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
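# NOTE: requirements.txt pins diffusers[torch]==0.11.1, which satisfies this check. The model
# code under tuneavideo/ (e.g. `from diffusers.modeling_utils import ModelMixin`) also relies on
# that 0.11-era module layout, so substantially newer diffusers releases will likely require
# import adjustments.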
34 | check_min_version("0.10.0.dev0") 35 | 36 | logger = get_logger(__name__, log_level="INFO") 37 | 38 | 39 | def main( 40 | pretrained_model_path: str, 41 | output_dir: str, 42 | train_data: Dict, 43 | validation_data: Dict, 44 | validation_steps: int = 100, 45 | trainable_modules: Tuple[str] = ( 46 | "attn1.to_q", 47 | "attn2.to_q", 48 | "attn_temp", 49 | ), 50 | train_batch_size: int = 1, 51 | max_train_steps: int = 500, 52 | learning_rate: float = 3e-5, 53 | scale_lr: bool = False, 54 | lr_scheduler: str = "constant", 55 | lr_warmup_steps: int = 0, 56 | adam_beta1: float = 0.9, 57 | adam_beta2: float = 0.999, 58 | adam_weight_decay: float = 1e-2, 59 | adam_epsilon: float = 1e-08, 60 | max_grad_norm: float = 1.0, 61 | gradient_accumulation_steps: int = 1, 62 | gradient_checkpointing: bool = True, 63 | checkpointing_steps: int = 500, 64 | resume_from_checkpoint: Optional[str] = None, 65 | mixed_precision: Optional[str] = "fp16", 66 | use_8bit_adam: bool = False, 67 | enable_xformers_memory_efficient_attention: bool = True, 68 | seed: Optional[int] = None, 69 | enable_smooth_loss=False, 70 | smooth_weight=0.2, 71 | lambda_factor=1000, 72 | simple_manner=False 73 | ): 74 | *_, config = inspect.getargvalues(inspect.currentframe()) 75 | 76 | accelerator = Accelerator( 77 | gradient_accumulation_steps=gradient_accumulation_steps, 78 | mixed_precision=mixed_precision, 79 | ) 80 | 81 | # Make one log on every process with the configuration for debugging. 82 | logging.basicConfig( 83 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 84 | datefmt="%m/%d/%Y %H:%M:%S", 85 | level=logging.INFO, 86 | ) 87 | logger.info(accelerator.state, main_process_only=False) 88 | if accelerator.is_local_main_process: 89 | transformers.utils.logging.set_verbosity_warning() 90 | diffusers.utils.logging.set_verbosity_info() 91 | else: 92 | transformers.utils.logging.set_verbosity_error() 93 | diffusers.utils.logging.set_verbosity_error() 94 | 95 | # If passed along, set the training seed now. 96 | if seed is not None: 97 | set_seed(seed) 98 | 99 | # Handle the output folder creation 100 | if accelerator.is_main_process: 101 | # now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 102 | # output_dir = os.path.join(output_dir, now) 103 | os.makedirs(output_dir, exist_ok=True) 104 | os.makedirs(f"{output_dir}/samples", exist_ok=True) 105 | os.makedirs(f"{output_dir}/inv_latents", exist_ok=True) 106 | OmegaConf.save(config, os.path.join(output_dir, 'config.yaml')) 107 | 108 | # Load scheduler, tokenizer and models. 
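# NOTE: pretrained_model_path (./checkpoints/stable-diffusion-v1-5 in the provided configs) is
# expected to be a local diffusers-format Stable Diffusion checkpoint containing the usual
# scheduler/, tokenizer/, text_encoder/, vae/ and unet/ subfolders; one way to obtain it (an
# example, not part of this script) is to download the v1-5 weights from the Hugging Face Hub
# into that directory, e.g. with `git lfs` or `huggingface_hub.snapshot_download`.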
109 | noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler") 110 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") 111 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") 112 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") 113 | unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet") 114 | 115 | # Freeze vae and text_encoder 116 | vae.requires_grad_(False) 117 | text_encoder.requires_grad_(False) 118 | 119 | unet.requires_grad_(False) 120 | for name, module in unet.named_modules(): 121 | if name.endswith(tuple(trainable_modules)): 122 | for params in module.parameters(): 123 | params.requires_grad = True 124 | 125 | if enable_xformers_memory_efficient_attention: 126 | if is_xformers_available(): 127 | unet.enable_xformers_memory_efficient_attention() 128 | else: 129 | raise ValueError("xformers is not available. Make sure it is installed correctly") 130 | 131 | if gradient_checkpointing: 132 | unet.enable_gradient_checkpointing() 133 | 134 | if scale_lr: 135 | learning_rate = ( 136 | learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes 137 | ) 138 | 139 | # Initialize the optimizer 140 | if use_8bit_adam: 141 | try: 142 | import bitsandbytes as bnb 143 | except ImportError: 144 | raise ImportError( 145 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" 146 | ) 147 | 148 | optimizer_cls = bnb.optim.AdamW8bit 149 | else: 150 | optimizer_cls = torch.optim.AdamW 151 | 152 | optimizer = optimizer_cls( 153 | unet.parameters(), 154 | lr=learning_rate, 155 | betas=(adam_beta1, adam_beta2), 156 | weight_decay=adam_weight_decay, 157 | eps=adam_epsilon, 158 | ) 159 | 160 | # Get the training dataset 161 | train_dataset = TuneAVideoDataset(**train_data) 162 | 163 | # Preprocessing the dataset 164 | train_dataset.prompt_ids = tokenizer( 165 | train_dataset.prompt, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" 166 | ).input_ids[0] 167 | 168 | # DataLoaders creation: 169 | train_dataloader = torch.utils.data.DataLoader( 170 | train_dataset, batch_size=train_batch_size 171 | ) 172 | 173 | # Get the validation pipeline 174 | validation_pipeline = TuneAVideoPipeline( 175 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, 176 | scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler") 177 | ) 178 | validation_pipeline.enable_vae_slicing() 179 | ddim_inv_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder='scheduler') 180 | ddim_inv_scheduler.set_timesteps(validation_data.num_inv_steps) 181 | 182 | # Scheduler 183 | lr_scheduler = get_scheduler( 184 | lr_scheduler, 185 | optimizer=optimizer, 186 | num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps, 187 | num_training_steps=max_train_steps * gradient_accumulation_steps, 188 | ) 189 | 190 | # Prepare everything with our `accelerator`. 191 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 192 | unet, optimizer, train_dataloader, lr_scheduler 193 | ) 194 | 195 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 196 | # as these models are only used for inference, keeping weights in full precision is not required. 
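# Only the UNet (with its handful of trainable attention projections) went through
# accelerator.prepare above and keeps fp32 master weights; the frozen text encoder and VAE are
# moved and cast by hand below so they run in the reduced precision.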
197 | weight_dtype = torch.float32 198 | if accelerator.mixed_precision == "fp16": 199 | weight_dtype = torch.float16 200 | elif accelerator.mixed_precision == "bf16": 201 | weight_dtype = torch.bfloat16 202 | 203 | # Move text_encode and vae to gpu and cast to weight_dtype 204 | text_encoder.to(accelerator.device, dtype=weight_dtype) 205 | vae.to(accelerator.device, dtype=weight_dtype) 206 | 207 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 208 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) 209 | # Afterwards we recalculate our number of training epochs 210 | num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) 211 | 212 | # We need to initialize the trackers we use, and also store our configuration. 213 | # The trackers initializes automatically on the main process. 214 | if accelerator.is_main_process: 215 | accelerator.init_trackers("text2video-fine-tune") 216 | 217 | # Train! 218 | total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps 219 | 220 | logger.info("***** Running training *****") 221 | logger.info(f" Num examples = {len(train_dataset)}") 222 | logger.info(f" Num Epochs = {num_train_epochs}") 223 | logger.info(f" Instantaneous batch size per device = {train_batch_size}") 224 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 225 | logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}") 226 | logger.info(f" Total optimization steps = {max_train_steps}") 227 | global_step = 0 228 | first_epoch = 0 229 | 230 | # Potentially load in the weights and states from a previous save 231 | if resume_from_checkpoint: 232 | if resume_from_checkpoint != "latest": 233 | path = os.path.basename(resume_from_checkpoint) 234 | else: 235 | # Get the most recent checkpoint 236 | dirs = os.listdir(output_dir) 237 | dirs = [d for d in dirs if d.startswith("checkpoint")] 238 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 239 | path = dirs[-1] 240 | accelerator.print(f"Resuming from checkpoint {path}") 241 | accelerator.load_state(os.path.join(output_dir, path)) 242 | global_step = int(path.split("-")[1]) 243 | 244 | first_epoch = global_step // num_update_steps_per_epoch 245 | resume_step = global_step % num_update_steps_per_epoch 246 | 247 | # Only show the progress bar once on each machine. 
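# NOTE: TuneAVideoDataset has length 1 (a single clip), so with train_batch_size=1 and no
# gradient accumulation every epoch is exactly one optimizer step; max_train_steps (500 in the
# provided configs) is therefore the number of passes over the same sampled frames, and the bar
# below simply counts those steps.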
248 | progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process) 249 | progress_bar.set_description("Steps") 250 | 251 | for epoch in range(first_epoch, num_train_epochs): 252 | unet.train() 253 | train_loss = 0.0 254 | for step, batch in enumerate(train_dataloader): 255 | # Skip steps until we reach the resumed step 256 | if resume_from_checkpoint and epoch == first_epoch and step < resume_step: 257 | if step % gradient_accumulation_steps == 0: 258 | progress_bar.update(1) 259 | continue 260 | 261 | with accelerator.accumulate(unet): 262 | # Convert videos to latent space 263 | pixel_values = batch["pixel_values"].to(weight_dtype) 264 | video_length = pixel_values.shape[1] 265 | pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w") 266 | latents = vae.encode(pixel_values).latent_dist.sample() 267 | latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length) 268 | latents = latents * 0.18215 269 | 270 | # Sample noise that we'll add to the latents 271 | noise = torch.randn_like(latents) 272 | bsz = latents.shape[0] 273 | # Sample a random timestep for each video 274 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) 275 | timesteps = timesteps.long() 276 | 277 | # Add noise to the latents according to the noise magnitude at each timestep 278 | # (this is the forward diffusion process) 279 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 280 | 281 | # Get the text embedding for conditioning 282 | encoder_hidden_states = text_encoder(batch["prompt_ids"])[0] 283 | 284 | # Get the target for loss depending on the prediction type 285 | if noise_scheduler.prediction_type == "epsilon": 286 | target = noise 287 | elif noise_scheduler.prediction_type == "v_prediction": 288 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 289 | else: 290 | raise ValueError(f"Unknown prediction type {noise_scheduler.prediction_type}") 291 | 292 | # Predict the noise residual and compute loss 293 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 294 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 295 | 296 | 297 | if enable_smooth_loss: 298 | noisy_pred_delta = (model_pred[:, :, 1:-1, ...] - model_pred[:, :, :-2, ...]) + \ 299 | (model_pred[:, :, 1:-1, ...] - model_pred[:, :, 2:, ...]) 300 | if simple_manner: 301 | smooth_loss = torch.abs(noisy_pred_delta).mean() 302 | else: 303 | noisy_latents_delta = (noisy_latents[:, :, 1:-1, ...] - noisy_latents[:, :, :-2, ...]) + \ 304 | (noisy_latents[:, :, 1:-1, ...] - noisy_latents[:, :, 2:, ...]) 305 | 306 | alphas_cumprod = noise_scheduler.alphas_cumprod 307 | t = timesteps 308 | C = torch.sqrt(alphas_cumprod[t]) / (torch.sqrt(alphas_cumprod[t])*torch.sqrt(1-alphas_cumprod[t-1]) - 309 | torch.sqrt(alphas_cumprod[t-1])*torch.sqrt(1-alphas_cumprod[t])) 310 | smooth_loss = torch.abs(noisy_pred_delta - (noisy_latents_delta * C/lambda_factor)).mean() 311 | loss += smooth_weight * smooth_loss 312 | 313 | 314 | # Gather the losses across all processes for logging (if we use distributed training). 
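# (The loss gathered below already includes the optional smoothness term: with --smooth_loss,
# noisy_pred_delta above is the discrete second difference of the prediction along the frame
# axis, 2*p_i - p_{i-1} - p_{i+1}. In the simple_manner variant its L1 norm is penalized
# directly; otherwise it is matched against the same second difference of noisy_latents
# rescaled by C / lambda_factor, with C built from the scheduler's alphas_cumprod at t and t-1,
# and the result is added to the MSE loss with weight smooth_weight.)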
315 | avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean() 316 | train_loss += avg_loss.item() / gradient_accumulation_steps 317 | 318 | # Backpropagate 319 | accelerator.backward(loss) 320 | if accelerator.sync_gradients: 321 | accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm) 322 | optimizer.step() 323 | lr_scheduler.step() 324 | optimizer.zero_grad() 325 | 326 | # Checks if the accelerator has performed an optimization step behind the scenes 327 | if accelerator.sync_gradients: 328 | progress_bar.update(1) 329 | global_step += 1 330 | accelerator.log({"train_loss": train_loss}, step=global_step) 331 | train_loss = 0.0 332 | 333 | if global_step % checkpointing_steps == 0: 334 | if accelerator.is_main_process: 335 | save_path = os.path.join(output_dir, f"checkpoint-{global_step}") 336 | accelerator.save_state(save_path) 337 | logger.info(f"Saved state to {save_path}") 338 | 339 | if global_step % validation_steps == 0: 340 | if accelerator.is_main_process: 341 | samples = [] 342 | generator = torch.Generator(device=latents.device) 343 | generator.manual_seed(seed) 344 | 345 | ddim_inv_latent = None 346 | if validation_data.use_inv_latent: 347 | inv_latents_path = os.path.join(output_dir, f"inv_latents/ddim_latent-{global_step}.pt") 348 | ddim_inv_latent = ddim_inversion( 349 | validation_pipeline, ddim_inv_scheduler, video_latent=latents, 350 | num_inv_steps=validation_data.num_inv_steps, prompt="")[-1].to(weight_dtype) 351 | torch.save(ddim_inv_latent, inv_latents_path) 352 | 353 | for idx, prompt in enumerate(validation_data.prompts): 354 | sample = validation_pipeline(prompt, generator=generator, latents=ddim_inv_latent, 355 | **validation_data).videos 356 | save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{prompt}.gif") 357 | samples.append(sample) 358 | samples = torch.concat(samples) 359 | save_path = f"{output_dir}/samples/sample-{global_step}.gif" 360 | save_videos_grid(samples, save_path) 361 | logger.info(f"Saved samples to {save_path}") 362 | 363 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 364 | progress_bar.set_postfix(**logs) 365 | 366 | if global_step >= max_train_steps: 367 | break 368 | 369 | # Create the pipeline using the trained modules and save it. 
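# Once saved, the fine-tuned pipeline can be reloaded for inference. A minimal sketch (not part
# of this script; paths and prompt assumed from configs/car-turn.yaml, CUDA device assumed):
#
#   import torch
#   from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
#   from tuneavideo.models.unet import UNet3DConditionModel
#   from tuneavideo.util import save_videos_grid
#
#   unet = UNet3DConditionModel.from_pretrained("./outputs/car-turn", subfolder="unet",
#                                               torch_dtype=torch.float16).to("cuda")
#   pipe = TuneAVideoPipeline.from_pretrained("./checkpoints/stable-diffusion-v1-5", unet=unet,
#                                             torch_dtype=torch.float16).to("cuda")
#   pipe.enable_vae_slicing()
#   video = pipe("a jeep car is moving on the snow", video_length=24, height=512, width=512,
#                num_inference_steps=50, guidance_scale=7.5).videos
#   save_videos_grid(video, "./a jeep car is moving on the snow.gif")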
370 | accelerator.wait_for_everyone() 371 | if accelerator.is_main_process: 372 | unet = accelerator.unwrap_model(unet) 373 | pipeline = TuneAVideoPipeline.from_pretrained( 374 | pretrained_model_path, 375 | text_encoder=text_encoder, 376 | vae=vae, 377 | unet=unet, 378 | ) 379 | pipeline.save_pretrained(output_dir) 380 | 381 | accelerator.end_training() 382 | 383 | 384 | if __name__ == "__main__": 385 | parser = argparse.ArgumentParser() 386 | parser.add_argument("--config", type=str, default="./configs/tuneavideo.yaml") 387 | parser.add_argument("--smooth_loss", action="store_true") 388 | parser.add_argument("--smooth_weight", type=float, default=0.2) 389 | parser.add_argument("--lambda_factor", type=float, default=1000) 390 | parser.add_argument("--simple_manner", action="store_true") 391 | args = parser.parse_args() 392 | 393 | main(**OmegaConf.load(args.config), 394 | enable_smooth_loss=args.smooth_loss, 395 | smooth_weight=args.smooth_weight, 396 | lambda_factor=args.lambda_factor, 397 | simple_manner=args.simple_manner 398 | ) 399 | -------------------------------------------------------------------------------- /tuneavideo/data/dataset.py: -------------------------------------------------------------------------------- 1 | import decord 2 | decord.bridge.set_bridge('torch') 3 | 4 | from torch.utils.data import Dataset 5 | from einops import rearrange 6 | 7 | 8 | class TuneAVideoDataset(Dataset): 9 | def __init__( 10 | self, 11 | video_path: str, 12 | prompt: str, 13 | width: int = 512, 14 | height: int = 512, 15 | n_sample_frames: int = 8, 16 | sample_start_idx: int = 0, 17 | sample_frame_rate: int = 1, 18 | ): 19 | self.video_path = video_path 20 | self.prompt = prompt 21 | self.prompt_ids = None 22 | 23 | self.width = width 24 | self.height = height 25 | self.n_sample_frames = n_sample_frames 26 | self.sample_start_idx = sample_start_idx 27 | self.sample_frame_rate = sample_frame_rate 28 | 29 | def __len__(self): 30 | return 1 31 | 32 | def __getitem__(self, index): 33 | # load and sample video frames 34 | vr = decord.VideoReader(self.video_path, width=self.width, height=self.height) 35 | sample_index = list(range(self.sample_start_idx, len(vr), self.sample_frame_rate))[:self.n_sample_frames] 36 | video = vr.get_batch(sample_index) 37 | video = rearrange(video, "f h w c -> f c h w") 38 | 39 | example = { 40 | "pixel_values": (video / 127.5 - 1.0), 41 | "prompt_ids": self.prompt_ids 42 | } 43 | 44 | return example 45 | -------------------------------------------------------------------------------- /tuneavideo/models/attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | from diffusers.configuration_utils import ConfigMixin, register_to_config 11 | from diffusers.modeling_utils import ModelMixin 12 | from diffusers.utils import BaseOutput 13 | from diffusers.utils.import_utils import is_xformers_available 14 | from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm 15 | 16 | from einops import rearrange, repeat 17 | 18 | 19 | @dataclass 20 | class Transformer3DModelOutput(BaseOutput): 21 | sample: torch.FloatTensor 22 | 23 | 24 | if is_xformers_available(): 25 | import xformers 26 | import xformers.ops 27 | else: 28 | xformers = None 29 | 30 | 31 | 
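# Transformer3DModel is the video-inflated counterpart of diffusers' Transformer2DModel: the
# frame axis is folded into the batch so the spatial layers run per frame, while each
# BasicTransformerBlock below stacks a sparse-causal self-attention (attn1), the usual text
# cross-attention (attn2), and a temporal attention over frames (attn_temp) whose output
# projection weight is zero-initialized so the block starts close to an identity mapping
# through its residual connection.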
class Transformer3DModel(ModelMixin, ConfigMixin): 32 | @register_to_config 33 | def __init__( 34 | self, 35 | num_attention_heads: int = 16, 36 | attention_head_dim: int = 88, 37 | in_channels: Optional[int] = None, 38 | num_layers: int = 1, 39 | dropout: float = 0.0, 40 | norm_num_groups: int = 32, 41 | cross_attention_dim: Optional[int] = None, 42 | attention_bias: bool = False, 43 | activation_fn: str = "geglu", 44 | num_embeds_ada_norm: Optional[int] = None, 45 | use_linear_projection: bool = False, 46 | only_cross_attention: bool = False, 47 | upcast_attention: bool = False, 48 | ): 49 | super().__init__() 50 | self.use_linear_projection = use_linear_projection 51 | self.num_attention_heads = num_attention_heads 52 | self.attention_head_dim = attention_head_dim 53 | inner_dim = num_attention_heads * attention_head_dim 54 | 55 | # Define input layers 56 | self.in_channels = in_channels 57 | 58 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 59 | if use_linear_projection: 60 | self.proj_in = nn.Linear(in_channels, inner_dim) 61 | else: 62 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 63 | 64 | # Define transformers blocks 65 | self.transformer_blocks = nn.ModuleList( 66 | [ 67 | BasicTransformerBlock( 68 | inner_dim, 69 | num_attention_heads, 70 | attention_head_dim, 71 | dropout=dropout, 72 | cross_attention_dim=cross_attention_dim, 73 | activation_fn=activation_fn, 74 | num_embeds_ada_norm=num_embeds_ada_norm, 75 | attention_bias=attention_bias, 76 | only_cross_attention=only_cross_attention, 77 | upcast_attention=upcast_attention, 78 | ) 79 | for d in range(num_layers) 80 | ] 81 | ) 82 | 83 | # 4. Define output layers 84 | if use_linear_projection: 85 | self.proj_out = nn.Linear(in_channels, inner_dim) 86 | else: 87 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 88 | 89 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): 90 | # Input 91 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
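# hidden_states arrives as (batch, channels, frames, height, width); the frame axis is folded
# into the batch below so the 2D spatial layers see one frame at a time, and the text embedding
# is repeated once per frame to match the expanded batch.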
92 | video_length = hidden_states.shape[2] 93 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 94 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) 95 | 96 | batch, channel, height, weight = hidden_states.shape 97 | residual = hidden_states 98 | 99 | hidden_states = self.norm(hidden_states) 100 | if not self.use_linear_projection: 101 | hidden_states = self.proj_in(hidden_states) 102 | inner_dim = hidden_states.shape[1] 103 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 104 | else: 105 | inner_dim = hidden_states.shape[1] 106 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 107 | hidden_states = self.proj_in(hidden_states) 108 | 109 | # Blocks 110 | for block in self.transformer_blocks: 111 | hidden_states = block( 112 | hidden_states, 113 | encoder_hidden_states=encoder_hidden_states, 114 | timestep=timestep, 115 | video_length=video_length 116 | ) 117 | 118 | # Output 119 | if not self.use_linear_projection: 120 | hidden_states = ( 121 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 122 | ) 123 | hidden_states = self.proj_out(hidden_states) 124 | else: 125 | hidden_states = self.proj_out(hidden_states) 126 | hidden_states = ( 127 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 128 | ) 129 | 130 | output = hidden_states + residual 131 | 132 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 133 | if not return_dict: 134 | return (output,) 135 | 136 | return Transformer3DModelOutput(sample=output) 137 | 138 | 139 | class BasicTransformerBlock(nn.Module): 140 | def __init__( 141 | self, 142 | dim: int, 143 | num_attention_heads: int, 144 | attention_head_dim: int, 145 | dropout=0.0, 146 | cross_attention_dim: Optional[int] = None, 147 | activation_fn: str = "geglu", 148 | num_embeds_ada_norm: Optional[int] = None, 149 | attention_bias: bool = False, 150 | only_cross_attention: bool = False, 151 | upcast_attention: bool = False, 152 | ): 153 | super().__init__() 154 | self.only_cross_attention = only_cross_attention 155 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 156 | 157 | # SC-Attn 158 | self.attn1 = SparseCausalAttention( 159 | query_dim=dim, 160 | heads=num_attention_heads, 161 | dim_head=attention_head_dim, 162 | dropout=dropout, 163 | bias=attention_bias, 164 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 165 | upcast_attention=upcast_attention, 166 | ) 167 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 168 | 169 | # Cross-Attn 170 | if cross_attention_dim is not None: 171 | self.attn2 = CrossAttention( 172 | query_dim=dim, 173 | cross_attention_dim=cross_attention_dim, 174 | heads=num_attention_heads, 175 | dim_head=attention_head_dim, 176 | dropout=dropout, 177 | bias=attention_bias, 178 | upcast_attention=upcast_attention, 179 | ) 180 | else: 181 | self.attn2 = None 182 | 183 | if cross_attention_dim is not None: 184 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 185 | else: 186 | self.norm2 = None 187 | 188 | # Feed-forward 189 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 190 | self.norm3 = nn.LayerNorm(dim) 191 | 192 | # Temp-Attn 193 | self.attn_temp = CrossAttention( 194 | query_dim=dim, 195 | heads=num_attention_heads, 196 | 
dim_head=attention_head_dim, 197 | dropout=dropout, 198 | bias=attention_bias, 199 | upcast_attention=upcast_attention, 200 | ) 201 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data) 202 | self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 203 | 204 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): 205 | if not is_xformers_available(): 206 | print("Here is how to install it") 207 | raise ModuleNotFoundError( 208 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 209 | " xformers", 210 | name="xformers", 211 | ) 212 | elif not torch.cuda.is_available(): 213 | raise ValueError( 214 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" 215 | " available for GPU " 216 | ) 217 | else: 218 | try: 219 | # Make sure we can run the memory efficient attention 220 | _ = xformers.ops.memory_efficient_attention( 221 | torch.randn((1, 2, 40), device="cuda"), 222 | torch.randn((1, 2, 40), device="cuda"), 223 | torch.randn((1, 2, 40), device="cuda"), 224 | ) 225 | except Exception as e: 226 | raise e 227 | self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 228 | if self.attn2 is not None: 229 | self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 230 | # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 231 | 232 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): 233 | # SparseCausal-Attention 234 | norm_hidden_states = ( 235 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) 236 | ) 237 | 238 | if self.only_cross_attention: 239 | hidden_states = ( 240 | self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states 241 | ) 242 | else: 243 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states 244 | 245 | if self.attn2 is not None: 246 | # Cross-Attention 247 | norm_hidden_states = ( 248 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 249 | ) 250 | hidden_states = ( 251 | self.attn2( 252 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 253 | ) 254 | + hidden_states 255 | ) 256 | 257 | # Feed-forward 258 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 259 | 260 | # Temporal-Attention 261 | d = hidden_states.shape[1] 262 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 263 | norm_hidden_states = ( 264 | self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) 265 | ) 266 | hidden_states = self.attn_temp(norm_hidden_states) + hidden_states 267 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 268 | 269 | return hidden_states 270 | 271 | 272 | class SparseCausalAttention(CrossAttention): 273 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 274 | batch_size, sequence_length, _ = hidden_states.shape 275 | 276 | encoder_hidden_states = encoder_hidden_states 277 | 278 | if self.group_norm is not None: 279 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 280 | 
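# Sparse-causal attention: the queries below come from the current frame only, while the keys
# and values are rebuilt as the concatenation of frame 0 and the immediately preceding frame
# (former_frame_index clamps the first frame to itself), so each frame attends to the anchor
# frame and its predecessor rather than to every frame in the clip.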
281 | query = self.to_q(hidden_states) 282 | dim = query.shape[-1] 283 | query = self.reshape_heads_to_batch_dim(query) 284 | 285 | if self.added_kv_proj_dim is not None: 286 | raise NotImplementedError 287 | 288 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 289 | key = self.to_k(encoder_hidden_states) 290 | value = self.to_v(encoder_hidden_states) 291 | 292 | former_frame_index = torch.arange(video_length) - 1 293 | former_frame_index[0] = 0 294 | 295 | key = rearrange(key, "(b f) d c -> b f d c", f=video_length) 296 | key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2) 297 | key = rearrange(key, "b f d c -> (b f) d c") 298 | 299 | value = rearrange(value, "(b f) d c -> b f d c", f=video_length) 300 | value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2) 301 | value = rearrange(value, "b f d c -> (b f) d c") 302 | 303 | key = self.reshape_heads_to_batch_dim(key) 304 | value = self.reshape_heads_to_batch_dim(value) 305 | 306 | if attention_mask is not None: 307 | if attention_mask.shape[-1] != query.shape[1]: 308 | target_length = query.shape[1] 309 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 310 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 311 | 312 | # attention, what we cannot get enough of 313 | if self._use_memory_efficient_attention_xformers: 314 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) 315 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input 316 | hidden_states = hidden_states.to(query.dtype) 317 | else: 318 | if self._slice_size is None or query.shape[0] // self._slice_size == 1: 319 | hidden_states = self._attention(query, key, value, attention_mask) 320 | else: 321 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) 322 | 323 | # linear proj 324 | hidden_states = self.to_out[0](hidden_states) 325 | 326 | # dropout 327 | hidden_states = self.to_out[1](hidden_states) 328 | return hidden_states 329 | -------------------------------------------------------------------------------- /tuneavideo/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange 8 | 9 | 10 | class InflatedConv3d(nn.Conv2d): 11 | def forward(self, x): 12 | video_length = x.shape[2] 13 | 14 | x = rearrange(x, "b c f h w -> (b f) c h w") 15 | x = super().forward(x) 16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 17 | 18 | return x 19 | 20 | 21 | class Upsample3D(nn.Module): 22 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 23 | super().__init__() 24 | self.channels = channels 25 | self.out_channels = out_channels or channels 26 | self.use_conv = use_conv 27 | self.use_conv_transpose = use_conv_transpose 28 | self.name = name 29 | 30 | conv = None 31 | if use_conv_transpose: 32 | raise NotImplementedError 33 | elif use_conv: 34 | conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 35 | 36 | if name == "conv": 37 | self.conv = conv 38 | else: 39 | self.Conv2d_0 = conv 40 | 41 | def forward(self, hidden_states, output_size=None): 42 | assert 
hidden_states.shape[1] == self.channels 43 | 44 | if self.use_conv_transpose: 45 | raise NotImplementedError 46 | 47 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 48 | dtype = hidden_states.dtype 49 | if dtype == torch.bfloat16: 50 | hidden_states = hidden_states.to(torch.float32) 51 | 52 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 53 | if hidden_states.shape[0] >= 64: 54 | hidden_states = hidden_states.contiguous() 55 | 56 | # if `output_size` is passed we force the interpolation output 57 | # size and do not make use of `scale_factor=2` 58 | if output_size is None: 59 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 60 | else: 61 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 62 | 63 | # If the input is bfloat16, we cast back to bfloat16 64 | if dtype == torch.bfloat16: 65 | hidden_states = hidden_states.to(dtype) 66 | 67 | if self.use_conv: 68 | if self.name == "conv": 69 | hidden_states = self.conv(hidden_states) 70 | else: 71 | hidden_states = self.Conv2d_0(hidden_states) 72 | 73 | return hidden_states 74 | 75 | 76 | class Downsample3D(nn.Module): 77 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 78 | super().__init__() 79 | self.channels = channels 80 | self.out_channels = out_channels or channels 81 | self.use_conv = use_conv 82 | self.padding = padding 83 | stride = 2 84 | self.name = name 85 | 86 | if use_conv: 87 | conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 88 | else: 89 | raise NotImplementedError 90 | 91 | if name == "conv": 92 | self.Conv2d_0 = conv 93 | self.conv = conv 94 | elif name == "Conv2d_0": 95 | self.conv = conv 96 | else: 97 | self.conv = conv 98 | 99 | def forward(self, hidden_states): 100 | assert hidden_states.shape[1] == self.channels 101 | if self.use_conv and self.padding == 0: 102 | raise NotImplementedError 103 | 104 | assert hidden_states.shape[1] == self.channels 105 | hidden_states = self.conv(hidden_states) 106 | 107 | return hidden_states 108 | 109 | 110 | class ResnetBlock3D(nn.Module): 111 | def __init__( 112 | self, 113 | *, 114 | in_channels, 115 | out_channels=None, 116 | conv_shortcut=False, 117 | dropout=0.0, 118 | temb_channels=512, 119 | groups=32, 120 | groups_out=None, 121 | pre_norm=True, 122 | eps=1e-6, 123 | non_linearity="swish", 124 | time_embedding_norm="default", 125 | output_scale_factor=1.0, 126 | use_in_shortcut=None, 127 | ): 128 | super().__init__() 129 | self.pre_norm = pre_norm 130 | self.pre_norm = True 131 | self.in_channels = in_channels 132 | out_channels = in_channels if out_channels is None else out_channels 133 | self.out_channels = out_channels 134 | self.use_conv_shortcut = conv_shortcut 135 | self.time_embedding_norm = time_embedding_norm 136 | self.output_scale_factor = output_scale_factor 137 | 138 | if groups_out is None: 139 | groups_out = groups 140 | 141 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 142 | 143 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 144 | 145 | if temb_channels is not None: 146 | if self.time_embedding_norm == "default": 147 | time_emb_proj_out_channels = out_channels 148 | elif self.time_embedding_norm == "scale_shift": 149 | time_emb_proj_out_channels = out_channels * 2 150 | else: 151 | raise ValueError(f"unknown 
time_embedding_norm : {self.time_embedding_norm} ") 152 | 153 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 154 | else: 155 | self.time_emb_proj = None 156 | 157 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 158 | self.dropout = torch.nn.Dropout(dropout) 159 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 160 | 161 | if non_linearity == "swish": 162 | self.nonlinearity = lambda x: F.silu(x) 163 | elif non_linearity == "mish": 164 | self.nonlinearity = Mish() 165 | elif non_linearity == "silu": 166 | self.nonlinearity = nn.SiLU() 167 | 168 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 169 | 170 | self.conv_shortcut = None 171 | if self.use_in_shortcut: 172 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 173 | 174 | def forward(self, input_tensor, temb): 175 | hidden_states = input_tensor 176 | 177 | hidden_states = self.norm1(hidden_states) 178 | hidden_states = self.nonlinearity(hidden_states) 179 | 180 | hidden_states = self.conv1(hidden_states) 181 | 182 | if temb is not None: 183 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 184 | 185 | if temb is not None and self.time_embedding_norm == "default": 186 | hidden_states = hidden_states + temb 187 | 188 | hidden_states = self.norm2(hidden_states) 189 | 190 | if temb is not None and self.time_embedding_norm == "scale_shift": 191 | scale, shift = torch.chunk(temb, 2, dim=1) 192 | hidden_states = hidden_states * (1 + scale) + shift 193 | 194 | hidden_states = self.nonlinearity(hidden_states) 195 | 196 | hidden_states = self.dropout(hidden_states) 197 | hidden_states = self.conv2(hidden_states) 198 | 199 | if self.conv_shortcut is not None: 200 | input_tensor = self.conv_shortcut(input_tensor) 201 | 202 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 203 | 204 | return output_tensor 205 | 206 | 207 | class Mish(torch.nn.Module): 208 | def forward(self, hidden_states): 209 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) -------------------------------------------------------------------------------- /tuneavideo/models/unet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py 2 | 3 | from dataclasses import dataclass 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import os 7 | import json 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.utils.checkpoint 12 | 13 | from diffusers.configuration_utils import ConfigMixin, register_to_config 14 | from diffusers.modeling_utils import ModelMixin 15 | from diffusers.utils import BaseOutput, logging 16 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 17 | from .unet_blocks import ( 18 | CrossAttnDownBlock3D, 19 | CrossAttnUpBlock3D, 20 | DownBlock3D, 21 | UNetMidBlock3DCrossAttn, 22 | UpBlock3D, 23 | get_down_block, 24 | get_up_block, 25 | ) 26 | from .resnet import InflatedConv3d 27 | 28 | 29 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 30 | 31 | 32 | @dataclass 33 | class UNet3DConditionOutput(BaseOutput): 34 | sample: torch.FloatTensor 35 | 36 | 37 | class UNet3DConditionModel(ModelMixin, ConfigMixin): 38 | _supports_gradient_checkpointing 
= True 39 | 40 | @register_to_config 41 | def __init__( 42 | self, 43 | sample_size: Optional[int] = None, 44 | in_channels: int = 4, 45 | out_channels: int = 4, 46 | center_input_sample: bool = False, 47 | flip_sin_to_cos: bool = True, 48 | freq_shift: int = 0, 49 | down_block_types: Tuple[str] = ( 50 | "CrossAttnDownBlock3D", 51 | "CrossAttnDownBlock3D", 52 | "CrossAttnDownBlock3D", 53 | "DownBlock3D", 54 | ), 55 | mid_block_type: str = "UNetMidBlock3DCrossAttn", 56 | up_block_types: Tuple[str] = ( 57 | "UpBlock3D", 58 | "CrossAttnUpBlock3D", 59 | "CrossAttnUpBlock3D", 60 | "CrossAttnUpBlock3D" 61 | ), 62 | only_cross_attention: Union[bool, Tuple[bool]] = False, 63 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 64 | layers_per_block: int = 2, 65 | downsample_padding: int = 1, 66 | mid_block_scale_factor: float = 1, 67 | act_fn: str = "silu", 68 | norm_num_groups: int = 32, 69 | norm_eps: float = 1e-5, 70 | cross_attention_dim: int = 1280, 71 | attention_head_dim: Union[int, Tuple[int]] = 8, 72 | dual_cross_attention: bool = False, 73 | use_linear_projection: bool = False, 74 | class_embed_type: Optional[str] = None, 75 | num_class_embeds: Optional[int] = None, 76 | upcast_attention: bool = False, 77 | resnet_time_scale_shift: str = "default", 78 | ): 79 | super().__init__() 80 | 81 | self.sample_size = sample_size 82 | time_embed_dim = block_out_channels[0] * 4 83 | 84 | # input 85 | self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) 86 | 87 | # time 88 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 89 | timestep_input_dim = block_out_channels[0] 90 | 91 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 92 | 93 | # class embedding 94 | if class_embed_type is None and num_class_embeds is not None: 95 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 96 | elif class_embed_type == "timestep": 97 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 98 | elif class_embed_type == "identity": 99 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 100 | else: 101 | self.class_embedding = None 102 | 103 | self.down_blocks = nn.ModuleList([]) 104 | self.mid_block = None 105 | self.up_blocks = nn.ModuleList([]) 106 | 107 | if isinstance(only_cross_attention, bool): 108 | only_cross_attention = [only_cross_attention] * len(down_block_types) 109 | 110 | if isinstance(attention_head_dim, int): 111 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 112 | 113 | # down 114 | output_channel = block_out_channels[0] 115 | for i, down_block_type in enumerate(down_block_types): 116 | input_channel = output_channel 117 | output_channel = block_out_channels[i] 118 | is_final_block = i == len(block_out_channels) - 1 119 | 120 | down_block = get_down_block( 121 | down_block_type, 122 | num_layers=layers_per_block, 123 | in_channels=input_channel, 124 | out_channels=output_channel, 125 | temb_channels=time_embed_dim, 126 | add_downsample=not is_final_block, 127 | resnet_eps=norm_eps, 128 | resnet_act_fn=act_fn, 129 | resnet_groups=norm_num_groups, 130 | cross_attention_dim=cross_attention_dim, 131 | attn_num_head_channels=attention_head_dim[i], 132 | downsample_padding=downsample_padding, 133 | dual_cross_attention=dual_cross_attention, 134 | use_linear_projection=use_linear_projection, 135 | only_cross_attention=only_cross_attention[i], 136 | upcast_attention=upcast_attention, 137 | 
resnet_time_scale_shift=resnet_time_scale_shift, 138 | ) 139 | self.down_blocks.append(down_block) 140 | 141 | # mid 142 | if mid_block_type == "UNetMidBlock3DCrossAttn": 143 | self.mid_block = UNetMidBlock3DCrossAttn( 144 | in_channels=block_out_channels[-1], 145 | temb_channels=time_embed_dim, 146 | resnet_eps=norm_eps, 147 | resnet_act_fn=act_fn, 148 | output_scale_factor=mid_block_scale_factor, 149 | resnet_time_scale_shift=resnet_time_scale_shift, 150 | cross_attention_dim=cross_attention_dim, 151 | attn_num_head_channels=attention_head_dim[-1], 152 | resnet_groups=norm_num_groups, 153 | dual_cross_attention=dual_cross_attention, 154 | use_linear_projection=use_linear_projection, 155 | upcast_attention=upcast_attention, 156 | ) 157 | else: 158 | raise ValueError(f"unknown mid_block_type : {mid_block_type}") 159 | 160 | # count how many layers upsample the videos 161 | self.num_upsamplers = 0 162 | 163 | # up 164 | reversed_block_out_channels = list(reversed(block_out_channels)) 165 | reversed_attention_head_dim = list(reversed(attention_head_dim)) 166 | only_cross_attention = list(reversed(only_cross_attention)) 167 | output_channel = reversed_block_out_channels[0] 168 | for i, up_block_type in enumerate(up_block_types): 169 | is_final_block = i == len(block_out_channels) - 1 170 | 171 | prev_output_channel = output_channel 172 | output_channel = reversed_block_out_channels[i] 173 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 174 | 175 | # add upsample block for all BUT final layer 176 | if not is_final_block: 177 | add_upsample = True 178 | self.num_upsamplers += 1 179 | else: 180 | add_upsample = False 181 | 182 | up_block = get_up_block( 183 | up_block_type, 184 | num_layers=layers_per_block + 1, 185 | in_channels=input_channel, 186 | out_channels=output_channel, 187 | prev_output_channel=prev_output_channel, 188 | temb_channels=time_embed_dim, 189 | add_upsample=add_upsample, 190 | resnet_eps=norm_eps, 191 | resnet_act_fn=act_fn, 192 | resnet_groups=norm_num_groups, 193 | cross_attention_dim=cross_attention_dim, 194 | attn_num_head_channels=reversed_attention_head_dim[i], 195 | dual_cross_attention=dual_cross_attention, 196 | use_linear_projection=use_linear_projection, 197 | only_cross_attention=only_cross_attention[i], 198 | upcast_attention=upcast_attention, 199 | resnet_time_scale_shift=resnet_time_scale_shift, 200 | ) 201 | self.up_blocks.append(up_block) 202 | prev_output_channel = output_channel 203 | 204 | # out 205 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) 206 | self.conv_act = nn.SiLU() 207 | self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) 208 | 209 | def set_attention_slice(self, slice_size): 210 | r""" 211 | Enable sliced attention computation. 212 | 213 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 214 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 215 | 216 | Args: 217 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 218 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If 219 | `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is 220 | provided, uses as many slices as `attention_head_dim // slice_size`. 
In this case, `attention_head_dim` 221 | must be a multiple of `slice_size`. 222 | """ 223 | sliceable_head_dims = [] 224 | 225 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): 226 | if hasattr(module, "set_attention_slice"): 227 | sliceable_head_dims.append(module.sliceable_head_dim) 228 | 229 | for child in module.children(): 230 | fn_recursive_retrieve_slicable_dims(child) 231 | 232 | # retrieve number of attention layers 233 | for module in self.children(): 234 | fn_recursive_retrieve_slicable_dims(module) 235 | 236 | num_slicable_layers = len(sliceable_head_dims) 237 | 238 | if slice_size == "auto": 239 | # half the attention head size is usually a good trade-off between 240 | # speed and memory 241 | slice_size = [dim // 2 for dim in sliceable_head_dims] 242 | elif slice_size == "max": 243 | # make smallest slice possible 244 | slice_size = num_slicable_layers * [1] 245 | 246 | slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 247 | 248 | if len(slice_size) != len(sliceable_head_dims): 249 | raise ValueError( 250 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 251 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 252 | ) 253 | 254 | for i in range(len(slice_size)): 255 | size = slice_size[i] 256 | dim = sliceable_head_dims[i] 257 | if size is not None and size > dim: 258 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 259 | 260 | # Recursively walk through all the children. 261 | # Any children which exposes the set_attention_slice method 262 | # gets the message 263 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 264 | if hasattr(module, "set_attention_slice"): 265 | module.set_attention_slice(slice_size.pop()) 266 | 267 | for child in module.children(): 268 | fn_recursive_set_attention_slice(child, slice_size) 269 | 270 | reversed_slice_size = list(reversed(slice_size)) 271 | for module in self.children(): 272 | fn_recursive_set_attention_slice(module, reversed_slice_size) 273 | 274 | def _set_gradient_checkpointing(self, module, value=False): 275 | if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): 276 | module.gradient_checkpointing = value 277 | 278 | def forward( 279 | self, 280 | sample: torch.FloatTensor, 281 | timestep: Union[torch.Tensor, float, int], 282 | encoder_hidden_states: torch.Tensor, 283 | class_labels: Optional[torch.Tensor] = None, 284 | attention_mask: Optional[torch.Tensor] = None, 285 | return_dict: bool = True, 286 | ) -> Union[UNet3DConditionOutput, Tuple]: 287 | r""" 288 | Args: 289 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 290 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 291 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 292 | return_dict (`bool`, *optional*, defaults to `True`): 293 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 294 | 295 | Returns: 296 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 297 | [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When 298 | returning a tuple, the first element is the sample tensor. 299 | """ 300 | # By default samples have to be AT least a multiple of the overall upsampling factor. 
301 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 302 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 303 | # on the fly if necessary. 304 | default_overall_up_factor = 2**self.num_upsamplers 305 | 306 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 307 | forward_upsample_size = False 308 | upsample_size = None 309 | 310 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 311 | logger.info("Forward upsample size to force interpolation output size.") 312 | forward_upsample_size = True 313 | 314 | # prepare attention_mask 315 | if attention_mask is not None: 316 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 317 | attention_mask = attention_mask.unsqueeze(1) 318 | 319 | # center input if necessary 320 | if self.config.center_input_sample: 321 | sample = 2 * sample - 1.0 322 | 323 | # time 324 | timesteps = timestep 325 | if not torch.is_tensor(timesteps): 326 | # This would be a good case for the `match` statement (Python 3.10+) 327 | is_mps = sample.device.type == "mps" 328 | if isinstance(timestep, float): 329 | dtype = torch.float32 if is_mps else torch.float64 330 | else: 331 | dtype = torch.int32 if is_mps else torch.int64 332 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 333 | elif len(timesteps.shape) == 0: 334 | timesteps = timesteps[None].to(sample.device) 335 | 336 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 337 | timesteps = timesteps.expand(sample.shape[0]) 338 | 339 | t_emb = self.time_proj(timesteps) 340 | 341 | # timesteps does not contain any weights and will always return f32 tensors 342 | # but time_embedding might actually be running in fp16. so we need to cast here. 343 | # there might be better ways to encapsulate this. 
344 | t_emb = t_emb.to(dtype=self.dtype) 345 | emb = self.time_embedding(t_emb) 346 | 347 | if self.class_embedding is not None: 348 | if class_labels is None: 349 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 350 | 351 | if self.config.class_embed_type == "timestep": 352 | class_labels = self.time_proj(class_labels) 353 | 354 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 355 | emb = emb + class_emb 356 | 357 | # pre-process 358 | sample = self.conv_in(sample) 359 | 360 | # down 361 | down_block_res_samples = (sample,) 362 | for downsample_block in self.down_blocks: 363 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 364 | sample, res_samples = downsample_block( 365 | hidden_states=sample, 366 | temb=emb, 367 | encoder_hidden_states=encoder_hidden_states, 368 | attention_mask=attention_mask, 369 | ) 370 | else: 371 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 372 | 373 | down_block_res_samples += res_samples 374 | 375 | # mid 376 | sample = self.mid_block( 377 | sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 378 | ) 379 | 380 | # up 381 | for i, upsample_block in enumerate(self.up_blocks): 382 | is_final_block = i == len(self.up_blocks) - 1 383 | 384 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 385 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 386 | 387 | # if we have not reached the final block and need to forward the 388 | # upsample size, we do it here 389 | if not is_final_block and forward_upsample_size: 390 | upsample_size = down_block_res_samples[-1].shape[2:] 391 | 392 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 393 | sample = upsample_block( 394 | hidden_states=sample, 395 | temb=emb, 396 | res_hidden_states_tuple=res_samples, 397 | encoder_hidden_states=encoder_hidden_states, 398 | upsample_size=upsample_size, 399 | attention_mask=attention_mask, 400 | ) 401 | else: 402 | sample = upsample_block( 403 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 404 | ) 405 | # post-process 406 | sample = self.conv_norm_out(sample) 407 | sample = self.conv_act(sample) 408 | sample = self.conv_out(sample) 409 | 410 | if not return_dict: 411 | return (sample,) 412 | 413 | return UNet3DConditionOutput(sample=sample) 414 | 415 | @classmethod 416 | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None): 417 | if subfolder is not None: 418 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder) 419 | 420 | config_file = os.path.join(pretrained_model_path, 'config.json') 421 | if not os.path.isfile(config_file): 422 | raise RuntimeError(f"{config_file} does not exist") 423 | with open(config_file, "r") as f: 424 | config = json.load(f) 425 | config["_class_name"] = cls.__name__ 426 | config["down_block_types"] = [ 427 | "CrossAttnDownBlock3D", 428 | "CrossAttnDownBlock3D", 429 | "CrossAttnDownBlock3D", 430 | "DownBlock3D" 431 | ] 432 | config["up_block_types"] = [ 433 | "UpBlock3D", 434 | "CrossAttnUpBlock3D", 435 | "CrossAttnUpBlock3D", 436 | "CrossAttnUpBlock3D" 437 | ] 438 | 439 | from diffusers.utils import WEIGHTS_NAME 440 | model = cls.from_config(config) 441 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) 442 | if not os.path.isfile(model_file): 443 | raise RuntimeError(f"{model_file} does not exist") 444 | state_dict = 
torch.load(model_file, map_location="cpu") 445 | for k, v in model.state_dict().items(): 446 | if '_temp.' in k: 447 | state_dict.update({k: v}) 448 | model.load_state_dict(state_dict) 449 | 450 | return model -------------------------------------------------------------------------------- /tuneavideo/models/unet_blocks.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from .attention import Transformer3DModel 7 | from .resnet import Downsample3D, ResnetBlock3D, Upsample3D 8 | 9 | 10 | def get_down_block( 11 | down_block_type, 12 | num_layers, 13 | in_channels, 14 | out_channels, 15 | temb_channels, 16 | add_downsample, 17 | resnet_eps, 18 | resnet_act_fn, 19 | attn_num_head_channels, 20 | resnet_groups=None, 21 | cross_attention_dim=None, 22 | downsample_padding=None, 23 | dual_cross_attention=False, 24 | use_linear_projection=False, 25 | only_cross_attention=False, 26 | upcast_attention=False, 27 | resnet_time_scale_shift="default", 28 | ): 29 | down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type 30 | if down_block_type == "DownBlock3D": 31 | return DownBlock3D( 32 | num_layers=num_layers, 33 | in_channels=in_channels, 34 | out_channels=out_channels, 35 | temb_channels=temb_channels, 36 | add_downsample=add_downsample, 37 | resnet_eps=resnet_eps, 38 | resnet_act_fn=resnet_act_fn, 39 | resnet_groups=resnet_groups, 40 | downsample_padding=downsample_padding, 41 | resnet_time_scale_shift=resnet_time_scale_shift, 42 | ) 43 | elif down_block_type == "CrossAttnDownBlock3D": 44 | if cross_attention_dim is None: 45 | raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") 46 | return CrossAttnDownBlock3D( 47 | num_layers=num_layers, 48 | in_channels=in_channels, 49 | out_channels=out_channels, 50 | temb_channels=temb_channels, 51 | add_downsample=add_downsample, 52 | resnet_eps=resnet_eps, 53 | resnet_act_fn=resnet_act_fn, 54 | resnet_groups=resnet_groups, 55 | downsample_padding=downsample_padding, 56 | cross_attention_dim=cross_attention_dim, 57 | attn_num_head_channels=attn_num_head_channels, 58 | dual_cross_attention=dual_cross_attention, 59 | use_linear_projection=use_linear_projection, 60 | only_cross_attention=only_cross_attention, 61 | upcast_attention=upcast_attention, 62 | resnet_time_scale_shift=resnet_time_scale_shift, 63 | ) 64 | raise ValueError(f"{down_block_type} does not exist.") 65 | 66 | 67 | def get_up_block( 68 | up_block_type, 69 | num_layers, 70 | in_channels, 71 | out_channels, 72 | prev_output_channel, 73 | temb_channels, 74 | add_upsample, 75 | resnet_eps, 76 | resnet_act_fn, 77 | attn_num_head_channels, 78 | resnet_groups=None, 79 | cross_attention_dim=None, 80 | dual_cross_attention=False, 81 | use_linear_projection=False, 82 | only_cross_attention=False, 83 | upcast_attention=False, 84 | resnet_time_scale_shift="default", 85 | ): 86 | up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type 87 | if up_block_type == "UpBlock3D": 88 | return UpBlock3D( 89 | num_layers=num_layers, 90 | in_channels=in_channels, 91 | out_channels=out_channels, 92 | prev_output_channel=prev_output_channel, 93 | temb_channels=temb_channels, 94 | add_upsample=add_upsample, 95 | resnet_eps=resnet_eps, 96 | resnet_act_fn=resnet_act_fn, 97 | resnet_groups=resnet_groups, 98 | 
resnet_time_scale_shift=resnet_time_scale_shift, 99 | ) 100 | elif up_block_type == "CrossAttnUpBlock3D": 101 | if cross_attention_dim is None: 102 | raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") 103 | return CrossAttnUpBlock3D( 104 | num_layers=num_layers, 105 | in_channels=in_channels, 106 | out_channels=out_channels, 107 | prev_output_channel=prev_output_channel, 108 | temb_channels=temb_channels, 109 | add_upsample=add_upsample, 110 | resnet_eps=resnet_eps, 111 | resnet_act_fn=resnet_act_fn, 112 | resnet_groups=resnet_groups, 113 | cross_attention_dim=cross_attention_dim, 114 | attn_num_head_channels=attn_num_head_channels, 115 | dual_cross_attention=dual_cross_attention, 116 | use_linear_projection=use_linear_projection, 117 | only_cross_attention=only_cross_attention, 118 | upcast_attention=upcast_attention, 119 | resnet_time_scale_shift=resnet_time_scale_shift, 120 | ) 121 | raise ValueError(f"{up_block_type} does not exist.") 122 | 123 | 124 | class UNetMidBlock3DCrossAttn(nn.Module): 125 | def __init__( 126 | self, 127 | in_channels: int, 128 | temb_channels: int, 129 | dropout: float = 0.0, 130 | num_layers: int = 1, 131 | resnet_eps: float = 1e-6, 132 | resnet_time_scale_shift: str = "default", 133 | resnet_act_fn: str = "swish", 134 | resnet_groups: int = 32, 135 | resnet_pre_norm: bool = True, 136 | attn_num_head_channels=1, 137 | output_scale_factor=1.0, 138 | cross_attention_dim=1280, 139 | dual_cross_attention=False, 140 | use_linear_projection=False, 141 | upcast_attention=False, 142 | ): 143 | super().__init__() 144 | 145 | self.has_cross_attention = True 146 | self.attn_num_head_channels = attn_num_head_channels 147 | resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) 148 | 149 | # there is always at least one resnet 150 | resnets = [ 151 | ResnetBlock3D( 152 | in_channels=in_channels, 153 | out_channels=in_channels, 154 | temb_channels=temb_channels, 155 | eps=resnet_eps, 156 | groups=resnet_groups, 157 | dropout=dropout, 158 | time_embedding_norm=resnet_time_scale_shift, 159 | non_linearity=resnet_act_fn, 160 | output_scale_factor=output_scale_factor, 161 | pre_norm=resnet_pre_norm, 162 | ) 163 | ] 164 | attentions = [] 165 | 166 | for _ in range(num_layers): 167 | if dual_cross_attention: 168 | raise NotImplementedError 169 | attentions.append( 170 | Transformer3DModel( 171 | attn_num_head_channels, 172 | in_channels // attn_num_head_channels, 173 | in_channels=in_channels, 174 | num_layers=1, 175 | cross_attention_dim=cross_attention_dim, 176 | norm_num_groups=resnet_groups, 177 | use_linear_projection=use_linear_projection, 178 | upcast_attention=upcast_attention, 179 | ) 180 | ) 181 | resnets.append( 182 | ResnetBlock3D( 183 | in_channels=in_channels, 184 | out_channels=in_channels, 185 | temb_channels=temb_channels, 186 | eps=resnet_eps, 187 | groups=resnet_groups, 188 | dropout=dropout, 189 | time_embedding_norm=resnet_time_scale_shift, 190 | non_linearity=resnet_act_fn, 191 | output_scale_factor=output_scale_factor, 192 | pre_norm=resnet_pre_norm, 193 | ) 194 | ) 195 | 196 | self.attentions = nn.ModuleList(attentions) 197 | self.resnets = nn.ModuleList(resnets) 198 | 199 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): 200 | hidden_states = self.resnets[0](hidden_states, temb) 201 | for attn, resnet in zip(self.attentions, self.resnets[1:]): 202 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample 203 | 
hidden_states = resnet(hidden_states, temb) 204 | 205 | return hidden_states 206 | 207 | 208 | class CrossAttnDownBlock3D(nn.Module): 209 | def __init__( 210 | self, 211 | in_channels: int, 212 | out_channels: int, 213 | temb_channels: int, 214 | dropout: float = 0.0, 215 | num_layers: int = 1, 216 | resnet_eps: float = 1e-6, 217 | resnet_time_scale_shift: str = "default", 218 | resnet_act_fn: str = "swish", 219 | resnet_groups: int = 32, 220 | resnet_pre_norm: bool = True, 221 | attn_num_head_channels=1, 222 | cross_attention_dim=1280, 223 | output_scale_factor=1.0, 224 | downsample_padding=1, 225 | add_downsample=True, 226 | dual_cross_attention=False, 227 | use_linear_projection=False, 228 | only_cross_attention=False, 229 | upcast_attention=False, 230 | ): 231 | super().__init__() 232 | resnets = [] 233 | attentions = [] 234 | 235 | self.has_cross_attention = True 236 | self.attn_num_head_channels = attn_num_head_channels 237 | 238 | for i in range(num_layers): 239 | in_channels = in_channels if i == 0 else out_channels 240 | resnets.append( 241 | ResnetBlock3D( 242 | in_channels=in_channels, 243 | out_channels=out_channels, 244 | temb_channels=temb_channels, 245 | eps=resnet_eps, 246 | groups=resnet_groups, 247 | dropout=dropout, 248 | time_embedding_norm=resnet_time_scale_shift, 249 | non_linearity=resnet_act_fn, 250 | output_scale_factor=output_scale_factor, 251 | pre_norm=resnet_pre_norm, 252 | ) 253 | ) 254 | if dual_cross_attention: 255 | raise NotImplementedError 256 | attentions.append( 257 | Transformer3DModel( 258 | attn_num_head_channels, 259 | out_channels // attn_num_head_channels, 260 | in_channels=out_channels, 261 | num_layers=1, 262 | cross_attention_dim=cross_attention_dim, 263 | norm_num_groups=resnet_groups, 264 | use_linear_projection=use_linear_projection, 265 | only_cross_attention=only_cross_attention, 266 | upcast_attention=upcast_attention, 267 | ) 268 | ) 269 | self.attentions = nn.ModuleList(attentions) 270 | self.resnets = nn.ModuleList(resnets) 271 | 272 | if add_downsample: 273 | self.downsamplers = nn.ModuleList( 274 | [ 275 | Downsample3D( 276 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 277 | ) 278 | ] 279 | ) 280 | else: 281 | self.downsamplers = None 282 | 283 | self.gradient_checkpointing = False 284 | 285 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): 286 | output_states = () 287 | 288 | for resnet, attn in zip(self.resnets, self.attentions): 289 | if self.training and self.gradient_checkpointing: 290 | 291 | def create_custom_forward(module, return_dict=None): 292 | def custom_forward(*inputs): 293 | if return_dict is not None: 294 | return module(*inputs, return_dict=return_dict) 295 | else: 296 | return module(*inputs) 297 | 298 | return custom_forward 299 | 300 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 301 | hidden_states = torch.utils.checkpoint.checkpoint( 302 | create_custom_forward(attn, return_dict=False), 303 | hidden_states, 304 | encoder_hidden_states, 305 | )[0] 306 | else: 307 | hidden_states = resnet(hidden_states, temb) 308 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample 309 | 310 | output_states += (hidden_states,) 311 | 312 | if self.downsamplers is not None: 313 | for downsampler in self.downsamplers: 314 | hidden_states = downsampler(hidden_states) 315 | 316 | output_states += (hidden_states,) 317 | 318 | return hidden_states, 
output_states 319 | 320 | 321 | class DownBlock3D(nn.Module): 322 | def __init__( 323 | self, 324 | in_channels: int, 325 | out_channels: int, 326 | temb_channels: int, 327 | dropout: float = 0.0, 328 | num_layers: int = 1, 329 | resnet_eps: float = 1e-6, 330 | resnet_time_scale_shift: str = "default", 331 | resnet_act_fn: str = "swish", 332 | resnet_groups: int = 32, 333 | resnet_pre_norm: bool = True, 334 | output_scale_factor=1.0, 335 | add_downsample=True, 336 | downsample_padding=1, 337 | ): 338 | super().__init__() 339 | resnets = [] 340 | 341 | for i in range(num_layers): 342 | in_channels = in_channels if i == 0 else out_channels 343 | resnets.append( 344 | ResnetBlock3D( 345 | in_channels=in_channels, 346 | out_channels=out_channels, 347 | temb_channels=temb_channels, 348 | eps=resnet_eps, 349 | groups=resnet_groups, 350 | dropout=dropout, 351 | time_embedding_norm=resnet_time_scale_shift, 352 | non_linearity=resnet_act_fn, 353 | output_scale_factor=output_scale_factor, 354 | pre_norm=resnet_pre_norm, 355 | ) 356 | ) 357 | 358 | self.resnets = nn.ModuleList(resnets) 359 | 360 | if add_downsample: 361 | self.downsamplers = nn.ModuleList( 362 | [ 363 | Downsample3D( 364 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 365 | ) 366 | ] 367 | ) 368 | else: 369 | self.downsamplers = None 370 | 371 | self.gradient_checkpointing = False 372 | 373 | def forward(self, hidden_states, temb=None): 374 | output_states = () 375 | 376 | for resnet in self.resnets: 377 | if self.training and self.gradient_checkpointing: 378 | 379 | def create_custom_forward(module): 380 | def custom_forward(*inputs): 381 | return module(*inputs) 382 | 383 | return custom_forward 384 | 385 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 386 | else: 387 | hidden_states = resnet(hidden_states, temb) 388 | 389 | output_states += (hidden_states,) 390 | 391 | if self.downsamplers is not None: 392 | for downsampler in self.downsamplers: 393 | hidden_states = downsampler(hidden_states) 394 | 395 | output_states += (hidden_states,) 396 | 397 | return hidden_states, output_states 398 | 399 | 400 | class CrossAttnUpBlock3D(nn.Module): 401 | def __init__( 402 | self, 403 | in_channels: int, 404 | out_channels: int, 405 | prev_output_channel: int, 406 | temb_channels: int, 407 | dropout: float = 0.0, 408 | num_layers: int = 1, 409 | resnet_eps: float = 1e-6, 410 | resnet_time_scale_shift: str = "default", 411 | resnet_act_fn: str = "swish", 412 | resnet_groups: int = 32, 413 | resnet_pre_norm: bool = True, 414 | attn_num_head_channels=1, 415 | cross_attention_dim=1280, 416 | output_scale_factor=1.0, 417 | add_upsample=True, 418 | dual_cross_attention=False, 419 | use_linear_projection=False, 420 | only_cross_attention=False, 421 | upcast_attention=False, 422 | ): 423 | super().__init__() 424 | resnets = [] 425 | attentions = [] 426 | 427 | self.has_cross_attention = True 428 | self.attn_num_head_channels = attn_num_head_channels 429 | 430 | for i in range(num_layers): 431 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 432 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 433 | 434 | resnets.append( 435 | ResnetBlock3D( 436 | in_channels=resnet_in_channels + res_skip_channels, 437 | out_channels=out_channels, 438 | temb_channels=temb_channels, 439 | eps=resnet_eps, 440 | groups=resnet_groups, 441 | dropout=dropout, 442 | time_embedding_norm=resnet_time_scale_shift, 443 | 
non_linearity=resnet_act_fn, 444 | output_scale_factor=output_scale_factor, 445 | pre_norm=resnet_pre_norm, 446 | ) 447 | ) 448 | if dual_cross_attention: 449 | raise NotImplementedError 450 | attentions.append( 451 | Transformer3DModel( 452 | attn_num_head_channels, 453 | out_channels // attn_num_head_channels, 454 | in_channels=out_channels, 455 | num_layers=1, 456 | cross_attention_dim=cross_attention_dim, 457 | norm_num_groups=resnet_groups, 458 | use_linear_projection=use_linear_projection, 459 | only_cross_attention=only_cross_attention, 460 | upcast_attention=upcast_attention, 461 | ) 462 | ) 463 | 464 | self.attentions = nn.ModuleList(attentions) 465 | self.resnets = nn.ModuleList(resnets) 466 | 467 | if add_upsample: 468 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) 469 | else: 470 | self.upsamplers = None 471 | 472 | self.gradient_checkpointing = False 473 | 474 | def forward( 475 | self, 476 | hidden_states, 477 | res_hidden_states_tuple, 478 | temb=None, 479 | encoder_hidden_states=None, 480 | upsample_size=None, 481 | attention_mask=None, 482 | ): 483 | for resnet, attn in zip(self.resnets, self.attentions): 484 | # pop res hidden states 485 | res_hidden_states = res_hidden_states_tuple[-1] 486 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 487 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 488 | 489 | if self.training and self.gradient_checkpointing: 490 | 491 | def create_custom_forward(module, return_dict=None): 492 | def custom_forward(*inputs): 493 | if return_dict is not None: 494 | return module(*inputs, return_dict=return_dict) 495 | else: 496 | return module(*inputs) 497 | 498 | return custom_forward 499 | 500 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 501 | hidden_states = torch.utils.checkpoint.checkpoint( 502 | create_custom_forward(attn, return_dict=False), 503 | hidden_states, 504 | encoder_hidden_states, 505 | )[0] 506 | else: 507 | hidden_states = resnet(hidden_states, temb) 508 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample 509 | 510 | if self.upsamplers is not None: 511 | for upsampler in self.upsamplers: 512 | hidden_states = upsampler(hidden_states, upsample_size) 513 | 514 | return hidden_states 515 | 516 | 517 | class UpBlock3D(nn.Module): 518 | def __init__( 519 | self, 520 | in_channels: int, 521 | prev_output_channel: int, 522 | out_channels: int, 523 | temb_channels: int, 524 | dropout: float = 0.0, 525 | num_layers: int = 1, 526 | resnet_eps: float = 1e-6, 527 | resnet_time_scale_shift: str = "default", 528 | resnet_act_fn: str = "swish", 529 | resnet_groups: int = 32, 530 | resnet_pre_norm: bool = True, 531 | output_scale_factor=1.0, 532 | add_upsample=True, 533 | ): 534 | super().__init__() 535 | resnets = [] 536 | 537 | for i in range(num_layers): 538 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 539 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 540 | 541 | resnets.append( 542 | ResnetBlock3D( 543 | in_channels=resnet_in_channels + res_skip_channels, 544 | out_channels=out_channels, 545 | temb_channels=temb_channels, 546 | eps=resnet_eps, 547 | groups=resnet_groups, 548 | dropout=dropout, 549 | time_embedding_norm=resnet_time_scale_shift, 550 | non_linearity=resnet_act_fn, 551 | output_scale_factor=output_scale_factor, 552 | pre_norm=resnet_pre_norm, 553 | ) 554 | ) 555 | 556 | self.resnets = 
nn.ModuleList(resnets) 557 | 558 | if add_upsample: 559 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) 560 | else: 561 | self.upsamplers = None 562 | 563 | self.gradient_checkpointing = False 564 | 565 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): 566 | for resnet in self.resnets: 567 | # pop res hidden states 568 | res_hidden_states = res_hidden_states_tuple[-1] 569 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 570 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 571 | 572 | if self.training and self.gradient_checkpointing: 573 | 574 | def create_custom_forward(module): 575 | def custom_forward(*inputs): 576 | return module(*inputs) 577 | 578 | return custom_forward 579 | 580 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 581 | else: 582 | hidden_states = resnet(hidden_states, temb) 583 | 584 | if self.upsamplers is not None: 585 | for upsampler in self.upsamplers: 586 | hidden_states = upsampler(hidden_states, upsample_size) 587 | 588 | return hidden_states 589 | -------------------------------------------------------------------------------- /tuneavideo/pipelines/pipeline_tuneavideo.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py 2 | 3 | import inspect 4 | from typing import Callable, List, Optional, Union 5 | from dataclasses import dataclass 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from diffusers.utils import is_accelerate_available 11 | from packaging import version 12 | from transformers import CLIPTextModel, CLIPTokenizer 13 | 14 | from diffusers.configuration_utils import FrozenDict 15 | from diffusers.models import AutoencoderKL 16 | from diffusers.pipeline_utils import DiffusionPipeline 17 | from diffusers.schedulers import ( 18 | DDIMScheduler, 19 | DPMSolverMultistepScheduler, 20 | EulerAncestralDiscreteScheduler, 21 | EulerDiscreteScheduler, 22 | LMSDiscreteScheduler, 23 | PNDMScheduler, 24 | ) 25 | from diffusers.utils import deprecate, logging, BaseOutput 26 | 27 | from einops import rearrange 28 | 29 | from ..models.unet import UNet3DConditionModel 30 | 31 | 32 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 33 | 34 | 35 | @dataclass 36 | class TuneAVideoPipelineOutput(BaseOutput): 37 | videos: Union[torch.Tensor, np.ndarray] 38 | 39 | 40 | class TuneAVideoPipeline(DiffusionPipeline): 41 | _optional_components = [] 42 | 43 | def __init__( 44 | self, 45 | vae: AutoencoderKL, 46 | text_encoder: CLIPTextModel, 47 | tokenizer: CLIPTokenizer, 48 | unet: UNet3DConditionModel, 49 | scheduler: Union[ 50 | DDIMScheduler, 51 | PNDMScheduler, 52 | LMSDiscreteScheduler, 53 | EulerDiscreteScheduler, 54 | EulerAncestralDiscreteScheduler, 55 | DPMSolverMultistepScheduler, 56 | ], 57 | ): 58 | super().__init__() 59 | 60 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 61 | deprecation_message = ( 62 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 63 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 64 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" 65 | " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," 66 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 67 | " file" 68 | ) 69 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) 70 | new_config = dict(scheduler.config) 71 | new_config["steps_offset"] = 1 72 | scheduler._internal_dict = FrozenDict(new_config) 73 | 74 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: 75 | deprecation_message = ( 76 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 77 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the" 78 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" 79 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" 80 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" 81 | ) 82 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) 83 | new_config = dict(scheduler.config) 84 | new_config["clip_sample"] = False 85 | scheduler._internal_dict = FrozenDict(new_config) 86 | 87 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( 88 | version.parse(unet.config._diffusers_version).base_version 89 | ) < version.parse("0.9.0.dev0") 90 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 91 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: 92 | deprecation_message = ( 93 | "The configuration file of the unet has set the default `sample_size` to smaller than" 94 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" 95 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" 96 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" 97 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" 98 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" 99 | " in the config might lead to incorrect results in future versions. 
If you have downloaded this" 100 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" 101 | " the `unet/config.json` file" 102 | ) 103 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) 104 | new_config = dict(unet.config) 105 | new_config["sample_size"] = 64 106 | unet._internal_dict = FrozenDict(new_config) 107 | 108 | self.register_modules( 109 | vae=vae, 110 | text_encoder=text_encoder, 111 | tokenizer=tokenizer, 112 | unet=unet, 113 | scheduler=scheduler, 114 | ) 115 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 116 | 117 | def enable_vae_slicing(self): 118 | self.vae.enable_slicing() 119 | 120 | def disable_vae_slicing(self): 121 | self.vae.disable_slicing() 122 | 123 | def enable_sequential_cpu_offload(self, gpu_id=0): 124 | if is_accelerate_available(): 125 | from accelerate import cpu_offload 126 | else: 127 | raise ImportError("Please install accelerate via `pip install accelerate`") 128 | 129 | device = torch.device(f"cuda:{gpu_id}") 130 | 131 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 132 | if cpu_offloaded_model is not None: 133 | cpu_offload(cpu_offloaded_model, device) 134 | 135 | 136 | @property 137 | def _execution_device(self): 138 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 139 | return self.device 140 | for module in self.unet.modules(): 141 | if ( 142 | hasattr(module, "_hf_hook") 143 | and hasattr(module._hf_hook, "execution_device") 144 | and module._hf_hook.execution_device is not None 145 | ): 146 | return torch.device(module._hf_hook.execution_device) 147 | return self.device 148 | 149 | def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): 150 | batch_size = len(prompt) if isinstance(prompt, list) else 1 151 | 152 | text_inputs = self.tokenizer( 153 | prompt, 154 | padding="max_length", 155 | max_length=self.tokenizer.model_max_length, 156 | truncation=True, 157 | return_tensors="pt", 158 | ) 159 | text_input_ids = text_inputs.input_ids 160 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 161 | 162 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 163 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) 164 | logger.warning( 165 | "The following part of your input was truncated because CLIP can only handle sequences up to" 166 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 167 | ) 168 | 169 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 170 | attention_mask = text_inputs.attention_mask.to(device) 171 | else: 172 | attention_mask = None 173 | 174 | text_embeddings = self.text_encoder( 175 | text_input_ids.to(device), 176 | attention_mask=attention_mask, 177 | ) 178 | text_embeddings = text_embeddings[0] 179 | 180 | # duplicate text embeddings for each generation per prompt, using mps friendly method 181 | bs_embed, seq_len, _ = text_embeddings.shape 182 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) 183 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 184 | 185 | # get unconditional embeddings for classifier free guidance 186 | if do_classifier_free_guidance: 187 | uncond_tokens: List[str] 188 | if negative_prompt is None: 189 | 
uncond_tokens = [""] * batch_size 190 | elif type(prompt) is not type(negative_prompt): 191 | raise TypeError( 192 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 193 | f" {type(prompt)}." 194 | ) 195 | elif isinstance(negative_prompt, str): 196 | uncond_tokens = [negative_prompt] 197 | elif batch_size != len(negative_prompt): 198 | raise ValueError( 199 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 200 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 201 | " the batch size of `prompt`." 202 | ) 203 | else: 204 | uncond_tokens = negative_prompt 205 | 206 | max_length = text_input_ids.shape[-1] 207 | uncond_input = self.tokenizer( 208 | uncond_tokens, 209 | padding="max_length", 210 | max_length=max_length, 211 | truncation=True, 212 | return_tensors="pt", 213 | ) 214 | 215 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 216 | attention_mask = uncond_input.attention_mask.to(device) 217 | else: 218 | attention_mask = None 219 | 220 | uncond_embeddings = self.text_encoder( 221 | uncond_input.input_ids.to(device), 222 | attention_mask=attention_mask, 223 | ) 224 | uncond_embeddings = uncond_embeddings[0] 225 | 226 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 227 | seq_len = uncond_embeddings.shape[1] 228 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) 229 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) 230 | 231 | # For classifier free guidance, we need to do two forward passes. 232 | # Here we concatenate the unconditional and text embeddings into a single batch 233 | # to avoid doing two forward passes 234 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 235 | 236 | return text_embeddings 237 | 238 | def decode_latents(self, latents): 239 | video_length = latents.shape[2] 240 | latents = 1 / 0.18215 * latents 241 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 242 | video = self.vae.decode(latents).sample 243 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 244 | video = (video / 2 + 0.5).clamp(0, 1) 245 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 246 | video = video.cpu().float().numpy() 247 | return video 248 | 249 | def prepare_extra_step_kwargs(self, generator, eta): 250 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 251 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
252 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 253 | # and should be between [0, 1] 254 | 255 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 256 | extra_step_kwargs = {} 257 | if accepts_eta: 258 | extra_step_kwargs["eta"] = eta 259 | 260 | # check if the scheduler accepts generator 261 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 262 | if accepts_generator: 263 | extra_step_kwargs["generator"] = generator 264 | return extra_step_kwargs 265 | 266 | def check_inputs(self, prompt, height, width, callback_steps): 267 | if not isinstance(prompt, str) and not isinstance(prompt, list): 268 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 269 | 270 | if height % 8 != 0 or width % 8 != 0: 271 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 272 | 273 | if (callback_steps is None) or ( 274 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 275 | ): 276 | raise ValueError( 277 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 278 | f" {type(callback_steps)}." 279 | ) 280 | 281 | def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): 282 | shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) 283 | if isinstance(generator, list) and len(generator) != batch_size: 284 | raise ValueError( 285 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 286 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
287 | ) 288 | 289 | if latents is None: 290 | rand_device = "cpu" if device.type == "mps" else device 291 | 292 | if isinstance(generator, list): 293 | shape = (1,) + shape[1:] 294 | latents = [ 295 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) 296 | for i in range(batch_size) 297 | ] 298 | latents = torch.cat(latents, dim=0).to(device) 299 | else: 300 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) 301 | else: 302 | if latents.shape != shape: 303 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") 304 | latents = latents.to(device) 305 | 306 | # scale the initial noise by the standard deviation required by the scheduler 307 | latents = latents * self.scheduler.init_noise_sigma 308 | return latents 309 | 310 | @torch.no_grad() 311 | def __call__( 312 | self, 313 | prompt: Union[str, List[str]], 314 | video_length: Optional[int], 315 | height: Optional[int] = None, 316 | width: Optional[int] = None, 317 | num_inference_steps: int = 50, 318 | guidance_scale: float = 7.5, 319 | negative_prompt: Optional[Union[str, List[str]]] = None, 320 | num_videos_per_prompt: Optional[int] = 1, 321 | eta: float = 0.0, 322 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 323 | latents: Optional[torch.FloatTensor] = None, 324 | output_type: Optional[str] = "tensor", 325 | return_dict: bool = True, 326 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 327 | callback_steps: Optional[int] = 1, 328 | **kwargs, 329 | ): 330 | # Default height and width to unet 331 | height = height or self.unet.config.sample_size * self.vae_scale_factor 332 | width = width or self.unet.config.sample_size * self.vae_scale_factor 333 | 334 | # Check inputs. Raise error if not correct 335 | self.check_inputs(prompt, height, width, callback_steps) 336 | 337 | # Define call parameters 338 | batch_size = 1 if isinstance(prompt, str) else len(prompt) 339 | device = self._execution_device 340 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 341 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 342 | # corresponds to doing no classifier free guidance. 343 | do_classifier_free_guidance = guidance_scale > 1.0 344 | 345 | # Encode input prompt 346 | text_embeddings = self._encode_prompt( 347 | prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt 348 | ) 349 | 350 | # Prepare timesteps 351 | self.scheduler.set_timesteps(num_inference_steps, device=device) 352 | timesteps = self.scheduler.timesteps 353 | 354 | # Prepare latent variables 355 | num_channels_latents = self.unet.in_channels 356 | latents = self.prepare_latents( 357 | batch_size * num_videos_per_prompt, 358 | num_channels_latents, 359 | video_length, 360 | height, 361 | width, 362 | text_embeddings.dtype, 363 | device, 364 | generator, 365 | latents, 366 | ) 367 | latents_dtype = latents.dtype 368 | 369 | # Prepare extra step kwargs. 
370 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 371 | 372 | # Denoising loop 373 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 374 | with self.progress_bar(total=num_inference_steps) as progress_bar: 375 | for i, t in enumerate(timesteps): 376 | # expand the latents if we are doing classifier free guidance 377 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 378 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 379 | 380 | # predict the noise residual 381 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype) 382 | 383 | # perform guidance 384 | if do_classifier_free_guidance: 385 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 386 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 387 | 388 | # compute the previous noisy sample x_t -> x_t-1 389 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 390 | 391 | # call the callback, if provided 392 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 393 | progress_bar.update() 394 | if callback is not None and i % callback_steps == 0: 395 | callback(i, t, latents) 396 | 397 | # Post-processing 398 | video = self.decode_latents(latents) 399 | 400 | # Convert to tensor 401 | if output_type == "tensor": 402 | video = torch.from_numpy(video) 403 | 404 | if not return_dict: 405 | return video 406 | 407 | return TuneAVideoPipelineOutput(videos=video) -------------------------------------------------------------------------------- /tuneavideo/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | from typing import Union 5 | 6 | import torch 7 | import torchvision 8 | 9 | from tqdm import tqdm 10 | from einops import rearrange 11 | 12 | 13 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8): 14 | videos = rearrange(videos, "b c t h w -> t b c h w") 15 | outputs = [] 16 | for x in videos: 17 | x = torchvision.utils.make_grid(x, nrow=n_rows) 18 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 19 | if rescale: 20 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 21 | x = (x * 255).numpy().astype(np.uint8) 22 | outputs.append(x) 23 | 24 | os.makedirs(os.path.dirname(path), exist_ok=True) 25 | imageio.mimsave(path, outputs, fps=fps) 26 | 27 | 28 | # DDIM Inversion 29 | @torch.no_grad() 30 | def init_prompt(prompt, pipeline): 31 | uncond_input = pipeline.tokenizer( 32 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, 33 | return_tensors="pt" 34 | ) 35 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] 36 | text_input = pipeline.tokenizer( 37 | [prompt], 38 | padding="max_length", 39 | max_length=pipeline.tokenizer.model_max_length, 40 | truncation=True, 41 | return_tensors="pt", 42 | ) 43 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] 44 | context = torch.cat([uncond_embeddings, text_embeddings]) 45 | 46 | return context 47 | 48 | 49 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, 50 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): 51 | timestep, next_timestep = min( 52 | timestep - ddim_scheduler.config.num_train_timesteps // 
ddim_scheduler.num_inference_steps, 999), timestep 53 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod 54 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] 55 | beta_prod_t = 1 - alpha_prod_t 56 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 57 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 58 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 59 | return next_sample 60 | 61 | 62 | def get_noise_pred_single(latents, t, context, unet): 63 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] 64 | return noise_pred 65 | 66 | 67 | @torch.no_grad() 68 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): 69 | context = init_prompt(prompt, pipeline) 70 | uncond_embeddings, cond_embeddings = context.chunk(2) 71 | all_latent = [latent] 72 | latent = latent.clone().detach() 73 | for i in tqdm(range(num_inv_steps)): 74 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] 75 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) 76 | latent = next_step(noise_pred, t, latent, ddim_scheduler) 77 | all_latent.append(latent) 78 | return all_latent 79 | 80 | 81 | @torch.no_grad() 82 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): 83 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) 84 | return ddim_latents 85 | --------------------------------------------------------------------------------
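The following is a minimal, hypothetical inference sketch rather than a file from the repository: it wires the modules above together by inflating a 2D Stable Diffusion UNet into the `UNet3DConditionModel` with `from_pretrained_2d`, assembling a `TuneAVideoPipeline` through the constructor signature shown earlier, and writing the result out with `save_videos_grid`. The local checkpoint path, prompt, and sampling hyperparameters are placeholder assumptions; `from_pretrained_2d` reads `config.json` and the weights from disk, so the base model is assumed to have been downloaded locally.

# Hypothetical usage sketch (not part of the repository); the checkpoint path,
# prompt, and sampling hyperparameters below are placeholders.
from diffusers import AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

from tuneavideo.models.unet import UNet3DConditionModel
from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
from tuneavideo.util import save_videos_grid

pretrained_model_path = "./checkpoints/stable-diffusion-v1-4"  # assumed local SD checkpoint

# Inflate the pretrained 2D UNet into the 3D variant defined in tuneavideo/models/unet.py;
# the temporal ("_temp.") parameters keep their freshly initialized values, so a fine-tuned
# UNet checkpoint (e.g. one produced by train_tuneavideo.py) could be loaded here instead.
unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet")

pipe = TuneAVideoPipeline(
    vae=AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae"),
    text_encoder=CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder"),
    tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer"),
    unet=unet,
    scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler"),
).to("cuda")
pipe.enable_vae_slicing()  # decode frames in slices to reduce peak memory

prompt = "a jeep car is moving on the beach"  # placeholder prompt
video = pipe(
    prompt,
    video_length=8,
    height=512,
    width=512,
    num_inference_steps=50,
    guidance_scale=7.5,
).videos  # tensor of shape (batch, channels, frames, height, width), values in [0, 1]

save_videos_grid(video, f"./outputs/{prompt}.gif")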
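A second hedged sketch covers the DDIM-inversion helpers defined in tuneavideo/util.py: a source clip is VAE-encoded (the inverse of `decode_latents`, including the 0.18215 scaling factor), inverted with `ddim_inversion`, and the resulting latent is fed back to the pipeline through its `latents` argument as a structure-preserving initialization. Here `pipe` is assumed to be the pipeline from the sketch above, and `source_frames` is an assumed 512x512 clip tensor of shape (batch, channels, frames, height, width) with values already scaled to [-1, 1].

# Hypothetical DDIM-inversion sketch (not part of the repository).
# Assumes `pipe` from the previous sketch and a preprocessed clip `source_frames`
# of shape (b, c, f, 512, 512) with pixel values in [-1, 1].
import torch
from einops import rearrange
from diffusers import DDIMScheduler

from tuneavideo.util import ddim_inversion

ddim_inv_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
ddim_inv_scheduler.set_timesteps(50)  # ddim_loop/next_step read timesteps and num_inference_steps

with torch.no_grad():
    # VAE-encode the frames; this mirrors the inverse of decode_latents() above.
    frames = rearrange(source_frames, "b c f h w -> (b f) c h w").to(pipe.device)
    latents = pipe.vae.encode(frames).latent_dist.sample() * 0.18215
    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=source_frames.shape[2])

# Deterministic forward DDIM loop; the last element is the fully inverted latent.
inv_latent = ddim_inversion(pipe, ddim_inv_scheduler, video_latent=latents,
                            num_inv_steps=50, prompt="")[-1]

video = pipe(
    "a jeep car is moving on the snow",  # placeholder editing prompt
    latents=inv_latent,
    video_length=source_frames.shape[2],
    height=512,
    width=512,
    num_inference_steps=50,
    guidance_scale=7.5,
).videos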