├── .gitignore ├── LICENSE ├── README.md ├── background.sh ├── cat.sh ├── data ├── background.mp4 ├── guidance10.gif ├── guidance17.5.gif ├── guidance20.gif ├── puff.mp4 ├── source.gif ├── tiger_empty.gif ├── tiger_neg.gif └── trucks-race.mp4 ├── inference.py ├── models ├── __init__.py ├── attention.py ├── pipeline_flatten.py ├── resnet.py ├── unet.py ├── unet_blocks.py └── util.py └── truck.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FLATTEN: optical FLow-guided ATTENtion for consistent text-to-video editing 2 | [![arXiv](https://img.shields.io/badge/arXiv-2310.05922-b31b1b.svg)](https://arxiv.org/abs/2310.05922) 3 | [![Project Website](https://img.shields.io/badge/Project-Website-orange)](https://flatten-video-editing.github.io/) 4 | [![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Fyrcong%2Fflatten%2F&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=visitors&edge_flat=false)](https://hits.seeyoufarm.com) 5 | 6 | **Pytorch Implementation of "FLATTEN: optical FLow-guided ATTENtion for consistent text-to-video editing".** 7 | 8 | 🎊🎊🎊 We are proud to announce that our paper has been accepted at **ICLR 2024**! If you are interested in FLATTEN, please give us a star😬 9 | ![teaser-ezgif com-resize](https://github.com/yrcong/flatten/assets/47991543/4f92f2bd-e4e9-4710-82b3-6efd36c27f46) 10 | 11 | Thanks to @[**logtd**](https://github.com/logtd) for integrating FLATTEN into ComfyUI and the great sampled videos! **Here is the [Link](https://github.com/logtd/ComfyUI-FLATTEN?tab=readme-ov-file)!** 12 | 13 | https://github.com/yrcong/flatten/assets/47991543/1ad49092-9133-42d0-984f-38c6427bde34 14 | 15 | 16 | ## 📖Abstract 17 | 🚩**Text-to-Video** 🚩**Training-free** 🚩**Plug-and-Play**
18 |
19 | Text-to-video editing aims to edit the visual appearance of a source video conditioned on textual prompts. A major challenge in this task is to ensure that all frames in the edited video are visually consistent. In this work, for the first time, we introduce optical flow into the attention module of the diffusion model's U-Net to address the inconsistency issue in text-to-video editing. Our method, FLATTEN, enforces patches on the same flow path across different frames to attend to each other in the attention module, thus improving the visual consistency of the edited videos. Additionally, our method is training-free and can be seamlessly integrated into any diffusion-based text-to-video editing method to improve its visual consistency.
20 |
21 | ## Requirements
22 | First, download Stable Diffusion 2.1 **(base)** [here](https://huggingface.co/stabilityai/stable-diffusion-2-1-base).
23 |
24 | Install the following packages:
25 | - PyTorch == 2.1
26 | - accelerate == 0.24.1
27 | - diffusers == 0.19.0
28 | - transformers == 4.35.0
29 | - xformers == 0.0.23
30 |
31 | ## Usage
32 | For text-to-video editing, a source video and a textual prompt are required. You can run the script to get the teaser video easily:
33 | ```
34 | sh cat.sh
35 | ```
36 | or with the command:
37 | ```
38 | python inference.py \
39 | --prompt "A Tiger, high quality" \
40 | --neg_prompt "a cat with big eyes, deformed" \
41 | --guidance_scale 20 \
42 | --video_path "data/puff.mp4" \
43 | --output_path "outputs/" \
44 | --video_length 32 \
45 | --width 512 \
46 | --height 512 \
47 | --old_qk 0 \
48 | --frame_rate 2 \
49 | ```
50 |
51 | ## Editing tricks
52 | - You can use a negative prompt (NP) when there is a big gap between the edit target and the source (1st row).
53 | - You can increase the scale of classifier-free guidance to enhance the semantic alignment (2nd row).
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
1st row (left to right): Source video | NP: " " | NP: "A cat with big eyes, deformed."
2nd row (left to right): Classifier-free guidance: 10 | Classifier-free guidance: 17.5 | Classifier-free guidance: 25
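
## How flow-guided attention works (sketch)
FLATTEN's core idea, as described in the abstract above, is that patches lying on the same optical-flow trajectory across frames should attend to each other. The snippet below is a minimal, self-contained sketch of that idea for intuition only: the tensor layout, the trajectory format, and the function name are assumptions made for this illustration and do not match the repository's actual implementation (see `FullyFrameAttention` in `models/attention.py`).

```python
# Illustrative sketch of flow-guided attention -- NOT the repository's implementation.
# Assumed (hypothetical) format: per-frame patch features q, k, v of shape (F, N, D),
# and trajectories[f, p, g] = index of the patch in frame g that lies on the same
# optical-flow path as patch p of frame f, or -1 if the path leaves the image.
import math
import torch
import torch.nn.functional as F


def flow_guided_attention(q, k, v, trajectories):
    F_, N, D = q.shape
    safe_idx = trajectories.clamp(min=0)                            # (F, N, F) valid gather indices
    frame_idx = torch.arange(F_).view(1, 1, F_).expand(F_, N, F_)   # which frame each index refers to

    # Gather the key/value of the corresponding patch in every frame along the flow path.
    k_path = k[frame_idx, safe_idx]                                 # (F, N, F, D)
    v_path = v[frame_idx, safe_idx]                                 # (F, N, F, D)

    # Each query attends only to the patches on its own trajectory.
    scores = (q.unsqueeze(2) * k_path).sum(-1) / math.sqrt(D)       # (F, N, F)
    scores = scores.masked_fill(trajectories < 0, float("-inf"))    # mask invalid trajectory entries
    attn = F.softmax(scores, dim=-1)
    return (attn.unsqueeze(-1) * v_path).sum(dim=2)                 # (F, N, D)


# Toy usage: 4 frames, 16 patches, 8 channels, identity trajectories (a static scene).
frames, patches, dim = 4, 16, 8
feats = torch.randn(frames, patches, dim)
traj = torch.arange(patches).view(1, patches, 1).expand(frames, patches, frames)
out = flow_guided_attention(feats, feats, feats, traj)
print(out.shape)  # torch.Size([4, 16, 8])
```

In the repository itself, per-resolution trajectories are computed once from the source video by `sample_trajectories` in `models/util.py` and passed through the pipeline into the U-Net (see `inference.py`), where `FullyFrameAttention` combines this trajectory-restricted attention with full cross-frame attention.
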
77 | 78 | 79 | ## BibTex 80 | ``` 81 | @article{cong2023flatten, 82 | title={FLATTEN: optical FLow-guided ATTENtion for consistent text-to-video editing}, 83 | author={Cong, Yuren and Xu, Mengmeng and Simon, Christian and Chen, Shoufa and Ren, Jiawei and Xie, Yanping and Perez-Rua, Juan-Manuel and Rosenhahn, Bodo and Xiang, Tao and He, Sen}, 84 | journal={arXiv preprint arXiv:2310.05922}, 85 | year={2023} 86 | } 87 | ``` 88 | -------------------------------------------------------------------------------- /background.sh: -------------------------------------------------------------------------------- 1 | python inference.py \ 2 | --prompt "pointillism painting, detailed" \ 3 | --neg_prompt " " \ 4 | --guidance 25 \ 5 | --video_path "data/background.mp4" \ 6 | --output_path "outputs/" \ 7 | --video_length 32 \ 8 | --width 512 \ 9 | --height 512 \ 10 | --old_qk 1 \ 11 | --frame_rate 1 \ 12 | -------------------------------------------------------------------------------- /cat.sh: -------------------------------------------------------------------------------- 1 | python inference.py \ 2 | --prompt "A Tiger, high quality" \ 3 | --neg_prompt "a cat with big eyes, deformed" \ 4 | --guidance_scale 20 \ 5 | --video_path "data/puff.mp4" \ 6 | --output_path "outputs/" \ 7 | --video_length 32 \ 8 | --width 512 \ 9 | --height 512 \ 10 | --old_qk 0 \ 11 | --frame_rate 2 \ 12 | -------------------------------------------------------------------------------- /data/background.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/data/background.mp4 -------------------------------------------------------------------------------- /data/guidance10.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/data/guidance10.gif -------------------------------------------------------------------------------- /data/guidance17.5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/data/guidance17.5.gif -------------------------------------------------------------------------------- /data/guidance20.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/data/guidance20.gif -------------------------------------------------------------------------------- /data/puff.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/data/puff.mp4 -------------------------------------------------------------------------------- /data/source.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/data/source.gif -------------------------------------------------------------------------------- /data/tiger_empty.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/data/tiger_empty.gif -------------------------------------------------------------------------------- /data/tiger_neg.gif: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/data/tiger_neg.gif
--------------------------------------------------------------------------------
/data/trucks-race.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/data/trucks-race.mp4
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | import torchvision
5 | from einops import rearrange
6 | from diffusers import DDIMScheduler, AutoencoderKL, DDIMInverseScheduler
7 | from transformers import CLIPTextModel, CLIPTokenizer
8 |
9 | from models.pipeline_flatten import FlattenPipeline
10 | from models.util import save_videos_grid, read_video, sample_trajectories
11 | from models.unet import UNet3DConditionModel
12 |
13 | def get_args():
14 |     parser = argparse.ArgumentParser()
15 |     parser.add_argument("--prompt", type=str, required=True, help="Textual prompt for video editing")
16 |     parser.add_argument("--neg_prompt", type=str, required=True, help="Negative prompt for guidance")
17 |     parser.add_argument("--guidance_scale", default=10.0, type=float, help="Guidance scale")
18 |     parser.add_argument("--video_path", type=str, required=True, help="Path to a source video")
19 |     parser.add_argument("--sd_path", type=str, default="checkpoints/stable-diffusion-2-1-base", help="Path of Stable Diffusion")
20 |     parser.add_argument("--output_path", type=str, default="./outputs", help="Directory of output")
21 |     parser.add_argument("--video_length", type=int, default=15, help="Length of output video")
22 |     parser.add_argument("--old_qk", type=int, default=0, help="Whether to use old queries and keys for flow-guided attention")
23 |     parser.add_argument("--height", type=int, default=512, help="Height of synthesized video; should be a multiple of 32")
24 |     parser.add_argument("--width", type=int, default=512, help="Width of synthesized video; should be a multiple of 32")
25 |     parser.add_argument("--sample_steps", type=int, default=50, help="Number of sampling (denoising) steps")
26 |     parser.add_argument("--inject_step", type=int, default=40, help="Steps for feature injection")
27 |     parser.add_argument("--seed", type=int, default=66, help="Random seed of generator")
28 |     parser.add_argument("--frame_rate", type=int, default=None, help="Frame rate for loading the input video.
Default rate is computed according to video length.") 29 | parser.add_argument("--fps", type=int, default=15, help="FPS of the output video") 30 | args = parser.parse_args() 31 | return args 32 | 33 | if __name__ == "__main__": 34 | 35 | args = get_args() 36 | os.makedirs(args.output_path, exist_ok=True) 37 | device = "cuda" 38 | # Height and width should be 512 39 | args.height = (args.height // 32) * 32 40 | args.width = (args.width // 32) * 32 41 | 42 | tokenizer = CLIPTokenizer.from_pretrained(args.sd_path, subfolder="tokenizer") 43 | text_encoder = CLIPTextModel.from_pretrained(args.sd_path, subfolder="text_encoder").to(dtype=torch.float16) 44 | vae = AutoencoderKL.from_pretrained(args.sd_path, subfolder="vae").to(dtype=torch.float16) 45 | unet = UNet3DConditionModel.from_pretrained_2d(args.sd_path, subfolder="unet").to(dtype=torch.float16) 46 | scheduler=DDIMScheduler.from_pretrained(args.sd_path, subfolder="scheduler") 47 | inverse=DDIMInverseScheduler.from_pretrained(args.sd_path, subfolder="scheduler") 48 | 49 | pipe = FlattenPipeline( 50 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, 51 | scheduler=scheduler, inverse_scheduler=inverse) 52 | pipe.enable_vae_slicing() 53 | pipe.enable_xformers_memory_efficient_attention() 54 | pipe.to(device) 55 | 56 | generator = torch.Generator(device=device) 57 | generator.manual_seed(args.seed) 58 | 59 | # read the source video 60 | video = read_video(video_path=args.video_path, video_length=args.video_length, 61 | width=args.width, height=args.height, frame_rate=args.frame_rate) 62 | original_pixels = rearrange(video, "(b f) c h w -> b c f h w", b=1) 63 | save_videos_grid(original_pixels, os.path.join(args.output_path, "source_video.mp4"), rescale=True) 64 | 65 | t2i_transform = torchvision.transforms.ToPILImage() 66 | real_frames = [] 67 | for i, frame in enumerate(video): 68 | real_frames.append(t2i_transform(((frame+1)/2*255).to(torch.uint8))) 69 | 70 | # compute optical flows and sample trajectories 71 | trajectories = sample_trajectories(os.path.join(args.output_path, "source_video.mp4"), device) 72 | torch.cuda.empty_cache() 73 | 74 | for k in trajectories.keys(): 75 | trajectories[k] = trajectories[k].to(device) 76 | sample = pipe(args.prompt, video_length=args.video_length, frames=real_frames, 77 | num_inference_steps=args.sample_steps, generator=generator, guidance_scale=args.guidance_scale, 78 | negative_prompt=args.neg_prompt, width=args.width, height=args.height, 79 | trajs=trajectories, output_dir="tmp/", inject_step=args.inject_step, old_qk=args.old_qk).videos 80 | temp_video_name = args.prompt+"_"+args.neg_prompt+"_"+str(args.guidance_scale) 81 | save_videos_grid(sample, f"{args.output_path}/{temp_video_name}.mp4", fps=args.fps) 82 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/models/__init__.py -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional, Callable 5 | import math 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | from 
diffusers.configuration_utils import ConfigMixin, register_to_config 10 | from diffusers import ModelMixin 11 | from diffusers.utils import BaseOutput 12 | from diffusers.utils.import_utils import is_xformers_available 13 | from diffusers.models.attention import FeedForward, AdaLayerNorm 14 | from diffusers.models.cross_attention import CrossAttention 15 | from einops import rearrange, repeat 16 | 17 | @dataclass 18 | class Transformer3DModelOutput(BaseOutput): 19 | sample: torch.FloatTensor 20 | 21 | 22 | if is_xformers_available(): 23 | import xformers 24 | import xformers.ops 25 | else: 26 | xformers = None 27 | 28 | 29 | class Transformer3DModel(ModelMixin, ConfigMixin): 30 | @register_to_config 31 | def __init__( 32 | self, 33 | num_attention_heads: int = 16, 34 | attention_head_dim: int = 88, 35 | in_channels: Optional[int] = None, 36 | num_layers: int = 1, 37 | dropout: float = 0.0, 38 | norm_num_groups: int = 32, 39 | cross_attention_dim: Optional[int] = None, 40 | attention_bias: bool = False, 41 | activation_fn: str = "geglu", 42 | num_embeds_ada_norm: Optional[int] = None, 43 | use_linear_projection: bool = False, 44 | only_cross_attention: bool = False, 45 | upcast_attention: bool = False, 46 | ): 47 | super().__init__() 48 | self.use_linear_projection = use_linear_projection 49 | self.num_attention_heads = num_attention_heads 50 | self.attention_head_dim = attention_head_dim 51 | inner_dim = num_attention_heads * attention_head_dim 52 | 53 | # Define input layers 54 | self.in_channels = in_channels 55 | 56 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 57 | if use_linear_projection: 58 | self.proj_in = nn.Linear(in_channels, inner_dim) 59 | else: 60 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 61 | 62 | # Define transformers blocks 63 | self.transformer_blocks = nn.ModuleList( 64 | [ 65 | BasicTransformerBlock( 66 | inner_dim, 67 | num_attention_heads, 68 | attention_head_dim, 69 | dropout=dropout, 70 | cross_attention_dim=cross_attention_dim, 71 | activation_fn=activation_fn, 72 | num_embeds_ada_norm=num_embeds_ada_norm, 73 | attention_bias=attention_bias, 74 | only_cross_attention=only_cross_attention, 75 | upcast_attention=upcast_attention, 76 | ) 77 | for d in range(num_layers) 78 | ] 79 | ) 80 | 81 | # 4. Define output layers 82 | if use_linear_projection: 83 | self.proj_out = nn.Linear(in_channels, inner_dim) 84 | else: 85 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 86 | 87 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True, \ 88 | inter_frame=False, **kwargs): 89 | # Input 90 | 91 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
92 | video_length = hidden_states.shape[2] 93 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 94 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) 95 | 96 | batch, channel, height, weight = hidden_states.shape 97 | residual = hidden_states 98 | 99 | # check resolution 100 | resolu = hidden_states.shape[-1] 101 | trajs = {} 102 | trajs["traj"] = kwargs["trajs"]["traj{}".format(resolu)] 103 | trajs["mask"] = kwargs["trajs"]["mask{}".format(resolu)] 104 | trajs["t"] = kwargs["t"] 105 | trajs["old_qk"] = kwargs["old_qk"] 106 | 107 | hidden_states = self.norm(hidden_states) 108 | if not self.use_linear_projection: 109 | hidden_states = self.proj_in(hidden_states) 110 | inner_dim = hidden_states.shape[1] 111 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 112 | else: 113 | inner_dim = hidden_states.shape[1] 114 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 115 | hidden_states = self.proj_in(hidden_states) 116 | 117 | # Blocks 118 | for block in self.transformer_blocks: 119 | hidden_states = block( 120 | hidden_states, 121 | encoder_hidden_states=encoder_hidden_states, 122 | timestep=timestep, 123 | video_length=video_length, 124 | inter_frame=inter_frame, 125 | **trajs 126 | ) 127 | 128 | # Output 129 | if not self.use_linear_projection: 130 | hidden_states = ( 131 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 132 | ) 133 | hidden_states = self.proj_out(hidden_states) 134 | else: 135 | hidden_states = self.proj_out(hidden_states) 136 | hidden_states = ( 137 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 138 | ) 139 | 140 | output = hidden_states + residual 141 | 142 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 143 | if not return_dict: 144 | return (output,) 145 | 146 | return Transformer3DModelOutput(sample=output) 147 | 148 | 149 | class BasicTransformerBlock(nn.Module): 150 | def __init__( 151 | self, 152 | dim: int, 153 | num_attention_heads: int, 154 | attention_head_dim: int, 155 | dropout=0.0, 156 | cross_attention_dim: Optional[int] = None, 157 | activation_fn: str = "geglu", 158 | num_embeds_ada_norm: Optional[int] = None, 159 | attention_bias: bool = False, 160 | only_cross_attention: bool = False, 161 | upcast_attention: bool = False, 162 | ): 163 | super().__init__() 164 | self.only_cross_attention = only_cross_attention 165 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 166 | 167 | # Fully 168 | self.attn1 = FullyFrameAttention( 169 | query_dim=dim, 170 | heads=num_attention_heads, 171 | dim_head=attention_head_dim, 172 | dropout=dropout, 173 | bias=attention_bias, 174 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 175 | upcast_attention=upcast_attention, 176 | ) 177 | 178 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 179 | 180 | # Cross-Attn 181 | if cross_attention_dim is not None: 182 | self.attn2 = CrossAttention( 183 | query_dim=dim, 184 | cross_attention_dim=cross_attention_dim, 185 | heads=num_attention_heads, 186 | dim_head=attention_head_dim, 187 | dropout=dropout, 188 | bias=attention_bias, 189 | upcast_attention=upcast_attention, 190 | ) 191 | else: 192 | self.attn2 = None 193 | 194 | if cross_attention_dim is not None: 195 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if 
self.use_ada_layer_norm else nn.LayerNorm(dim) 196 | else: 197 | self.norm2 = None 198 | 199 | # Feed-forward 200 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 201 | self.norm3 = nn.LayerNorm(dim) 202 | 203 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None): 204 | if not is_xformers_available(): 205 | print("Here is how to install it") 206 | raise ModuleNotFoundError( 207 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 208 | " xformers", 209 | name="xformers", 210 | ) 211 | elif not torch.cuda.is_available(): 212 | raise ValueError( 213 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" 214 | " available for GPU " 215 | ) 216 | else: 217 | try: 218 | # Make sure we can run the memory efficient attention 219 | _ = xformers.ops.memory_efficient_attention( 220 | torch.randn((1, 2, 40), device="cuda"), 221 | torch.randn((1, 2, 40), device="cuda"), 222 | torch.randn((1, 2, 40), device="cuda"), 223 | ) 224 | except Exception as e: 225 | raise e 226 | self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 227 | if self.attn2 is not None: 228 | self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 229 | 230 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None, \ 231 | inter_frame=False, **kwargs): 232 | # SparseCausal-Attention 233 | norm_hidden_states = ( 234 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) 235 | ) 236 | 237 | if self.only_cross_attention: 238 | hidden_states = ( 239 | self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask, inter_frame=inter_frame, **kwargs) + hidden_states 240 | ) 241 | else: 242 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length, inter_frame=inter_frame, **kwargs) + hidden_states 243 | 244 | if self.attn2 is not None: 245 | # Cross-Attention 246 | norm_hidden_states = ( 247 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 248 | ) 249 | hidden_states = ( 250 | self.attn2( 251 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 252 | ) 253 | + hidden_states 254 | ) 255 | 256 | # Feed-forward 257 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 258 | 259 | return hidden_states 260 | 261 | class FullyFrameAttention(nn.Module): 262 | r""" 263 | A cross attention layer. 264 | 265 | Parameters: 266 | query_dim (`int`): The number of channels in the query. 267 | cross_attention_dim (`int`, *optional*): 268 | The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. 269 | heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. 270 | dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. 271 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 272 | bias (`bool`, *optional*, defaults to False): 273 | Set to `True` for the query, key, and value linear layers to contain a bias parameter. 
274 | """ 275 | 276 | def __init__( 277 | self, 278 | query_dim: int, 279 | cross_attention_dim: Optional[int] = None, 280 | heads: int = 8, 281 | dim_head: int = 64, 282 | dropout: float = 0.0, 283 | bias=False, 284 | upcast_attention: bool = False, 285 | upcast_softmax: bool = False, 286 | added_kv_proj_dim: Optional[int] = None, 287 | norm_num_groups: Optional[int] = None, 288 | ): 289 | super().__init__() 290 | inner_dim = dim_head * heads 291 | cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim 292 | self.upcast_attention = upcast_attention 293 | self.upcast_softmax = upcast_softmax 294 | 295 | self.scale = dim_head**-0.5 296 | 297 | self.heads = heads 298 | # for slice_size > 0 the attention score computation 299 | # is split across the batch axis to save memory 300 | # You can set slice_size with `set_attention_slice` 301 | self.sliceable_head_dim = heads 302 | self._slice_size = None 303 | self._use_memory_efficient_attention_xformers = False 304 | self.added_kv_proj_dim = added_kv_proj_dim 305 | 306 | if norm_num_groups is not None: 307 | self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) 308 | else: 309 | self.group_norm = None 310 | 311 | self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) 312 | self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) 313 | self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) 314 | 315 | if self.added_kv_proj_dim is not None: 316 | self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) 317 | self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) 318 | 319 | self.to_out = nn.ModuleList([]) 320 | self.to_out.append(nn.Linear(inner_dim, query_dim)) 321 | self.to_out.append(nn.Dropout(dropout)) 322 | 323 | self.q = None 324 | self.inject_q = None 325 | self.k = None 326 | self.inject_k = None 327 | 328 | 329 | def reshape_heads_to_batch_dim(self, tensor): 330 | batch_size, seq_len, dim = tensor.shape 331 | head_size = self.heads 332 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) 333 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) 334 | return tensor 335 | 336 | def reshape_heads_to_batch_dim2(self, tensor): 337 | batch_size, seq_len, dim = tensor.shape 338 | head_size = self.heads 339 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) 340 | tensor = tensor.permute(0, 2, 1, 3) 341 | return tensor 342 | 343 | def reshape_heads_to_batch_dim3(self, tensor): 344 | batch_size1, batch_size2, seq_len, dim = tensor.shape 345 | head_size = self.heads 346 | tensor = tensor.reshape(batch_size1, batch_size2, seq_len, head_size, dim // head_size) 347 | tensor = tensor.permute(0, 3, 1, 2, 4) 348 | return tensor 349 | 350 | def reshape_batch_dim_to_heads(self, tensor): 351 | batch_size, seq_len, dim = tensor.shape 352 | head_size = self.heads 353 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) 354 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) 355 | return tensor 356 | 357 | def set_attention_slice(self, slice_size): 358 | if slice_size is not None and slice_size > self.sliceable_head_dim: 359 | raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") 360 | 361 | self._slice_size = slice_size 362 | 363 | def _attention(self, query, key, value, attention_mask=None): 364 | if self.upcast_attention: 365 
| query = query.float() 366 | key = key.float() 367 | 368 | attention_scores = torch.baddbmm( 369 | torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), 370 | query, 371 | key.transpose(-1, -2), 372 | beta=0, 373 | alpha=self.scale, 374 | ) 375 | if attention_mask is not None: 376 | attention_scores = attention_scores + attention_mask 377 | 378 | if self.upcast_softmax: 379 | attention_scores = attention_scores.float() 380 | 381 | attention_probs = attention_scores.softmax(dim=-1) 382 | 383 | # cast back to the original dtype 384 | attention_probs = attention_probs.to(value.dtype) 385 | 386 | # compute attention output 387 | hidden_states = torch.bmm(attention_probs, value) 388 | 389 | # reshape hidden_states 390 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 391 | return hidden_states 392 | 393 | def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): 394 | batch_size_attention = query.shape[0] 395 | hidden_states = torch.zeros( 396 | (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype 397 | ) 398 | slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] 399 | for i in range(hidden_states.shape[0] // slice_size): 400 | start_idx = i * slice_size 401 | end_idx = (i + 1) * slice_size 402 | 403 | query_slice = query[start_idx:end_idx] 404 | key_slice = key[start_idx:end_idx] 405 | 406 | if self.upcast_attention: 407 | query_slice = query_slice.float() 408 | key_slice = key_slice.float() 409 | 410 | attn_slice = torch.baddbmm( 411 | torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), 412 | query_slice, 413 | key_slice.transpose(-1, -2), 414 | beta=0, 415 | alpha=self.scale, 416 | ) 417 | 418 | if attention_mask is not None: 419 | attn_slice = attn_slice + attention_mask[start_idx:end_idx] 420 | 421 | if self.upcast_softmax: 422 | attn_slice = attn_slice.float() 423 | 424 | attn_slice = attn_slice.softmax(dim=-1) 425 | 426 | # cast back to the original dtype 427 | attn_slice = attn_slice.to(value.dtype) 428 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) 429 | 430 | hidden_states[start_idx:end_idx] = attn_slice 431 | 432 | # reshape hidden_states 433 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 434 | return hidden_states 435 | 436 | def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): 437 | # TODO attention_mask 438 | query = query.contiguous() 439 | key = key.contiguous() 440 | value = value.contiguous() 441 | hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) 442 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 443 | return hidden_states 444 | 445 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, inter_frame=False, **kwargs): 446 | batch_size, sequence_length, _ = hidden_states.shape 447 | 448 | encoder_hidden_states = encoder_hidden_states 449 | 450 | h = w = int(math.sqrt(sequence_length)) 451 | if self.group_norm is not None: 452 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 453 | 454 | query = self.to_q(hidden_states) # (bf) x d(hw) x c 455 | self.q = query 456 | if self.inject_q is not None: 457 | query = self.inject_q 458 | dim = query.shape[-1] 459 | query_old = query.clone() 460 | 461 | # All frames 462 | query = rearrange(query, "(b f) d c -> b (f 
d) c", f=video_length) 463 | 464 | query = self.reshape_heads_to_batch_dim(query) 465 | 466 | if self.added_kv_proj_dim is not None: 467 | raise NotImplementedError 468 | 469 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 470 | key = self.to_k(encoder_hidden_states) 471 | self.k = key 472 | if self.inject_k is not None: 473 | key = self.inject_k 474 | key_old = key.clone() 475 | value = self.to_v(encoder_hidden_states) 476 | 477 | if inter_frame: 478 | key = rearrange(key, "(b f) d c -> b f d c", f=video_length)[:, [0, -1]] 479 | value = rearrange(value, "(b f) d c -> b f d c", f=video_length)[:, [0, -1]] 480 | key = rearrange(key, "b f d c -> b (f d) c",) 481 | value = rearrange(value, "b f d c -> b (f d) c") 482 | else: 483 | # All frames 484 | key = rearrange(key, "(b f) d c -> b (f d) c", f=video_length) 485 | value = rearrange(value, "(b f) d c -> b (f d) c", f=video_length) 486 | 487 | key = self.reshape_heads_to_batch_dim(key) 488 | value = self.reshape_heads_to_batch_dim(value) 489 | 490 | if attention_mask is not None: 491 | if attention_mask.shape[-1] != query.shape[1]: 492 | target_length = query.shape[1] 493 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 494 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 495 | 496 | # attention, what we cannot get enough of 497 | if self._use_memory_efficient_attention_xformers: 498 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) 499 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input 500 | hidden_states = hidden_states.to(query.dtype) 501 | else: 502 | if self._slice_size is None or query.shape[0] // self._slice_size == 1: 503 | hidden_states = self._attention(query, key, value, attention_mask) 504 | else: 505 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) 506 | 507 | if h in [64]: 508 | hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=video_length) 509 | if self.group_norm is not None: 510 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 511 | 512 | if kwargs["old_qk"] == 1: 513 | query = query_old 514 | key = key_old 515 | else: 516 | query = hidden_states 517 | key = hidden_states 518 | value = hidden_states 519 | 520 | traj = kwargs["traj"] 521 | traj = rearrange(traj, '(f n) l d -> f n l d', f=video_length, n=sequence_length) 522 | mask = rearrange(kwargs["mask"], '(f n) l -> f n l', f=video_length, n=sequence_length) 523 | mask = torch.cat([mask[:, :, 0].unsqueeze(-1), mask[:, :, -video_length+1:]], dim=-1) 524 | 525 | traj_key_sequence_inds = torch.cat([traj[:, :, 0, :].unsqueeze(-2), traj[:, :, -video_length+1:, :]], dim=-2) 526 | t_inds = traj_key_sequence_inds[:, :, :, 0] 527 | x_inds = traj_key_sequence_inds[:, :, :, 1] 528 | y_inds = traj_key_sequence_inds[:, :, :, 2] 529 | 530 | query_tempo = query.unsqueeze(-2) 531 | _key = rearrange(key, '(b f) (h w) d -> b f h w d', b=int(batch_size/video_length), f=video_length, h=h, w=w) 532 | _value = rearrange(value, '(b f) (h w) d -> b f h w d', b=int(batch_size/video_length), f=video_length, h=h, w=w) 533 | key_tempo = _key[:, t_inds, x_inds, y_inds] 534 | value_tempo = _value[:, t_inds, x_inds, y_inds] 535 | key_tempo = rearrange(key_tempo, 'b f n l d -> (b f) n l d') 536 | value_tempo = rearrange(value_tempo, 'b f n l d -> (b f) n l d') 537 | 538 | mask = rearrange(torch.stack([mask, mask]), 'b f 
n l -> (b f) n l') 539 | mask = mask[:,None].repeat(1, self.heads, 1, 1).unsqueeze(-2) 540 | attn_bias = torch.zeros_like(mask, dtype=key_tempo.dtype) # regular zeros_like 541 | attn_bias[~mask] = -torch.inf 542 | 543 | # flow attention 544 | query_tempo = self.reshape_heads_to_batch_dim3(query_tempo) 545 | key_tempo = self.reshape_heads_to_batch_dim3(key_tempo) 546 | value_tempo = self.reshape_heads_to_batch_dim3(value_tempo) 547 | 548 | attn_matrix2 = query_tempo @ key_tempo.transpose(-2, -1) / math.sqrt(query_tempo.size(-1)) + attn_bias 549 | attn_matrix2 = F.softmax(attn_matrix2, dim=-1) 550 | out = (attn_matrix2@value_tempo).squeeze(-2) 551 | 552 | hidden_states = rearrange(out,'(b f) k (h w) d -> b (f h w) (k d)', b=int(batch_size/video_length), f=video_length, h=h, w=w) 553 | 554 | # linear proj 555 | hidden_states = self.to_out[0](hidden_states) 556 | 557 | # dropout 558 | hidden_states = self.to_out[1](hidden_states) 559 | 560 | # All frames 561 | hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=video_length) 562 | 563 | return hidden_states 564 | -------------------------------------------------------------------------------- /models/pipeline_flatten.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import inspect 17 | import os 18 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 19 | from dataclasses import dataclass 20 | 21 | import numpy as np 22 | import PIL.Image 23 | import torch 24 | from transformers import CLIPTextModel, CLIPTokenizer 25 | 26 | from diffusers.models import AutoencoderKL 27 | from diffusers import ModelMixin 28 | from diffusers.schedulers import DDIMScheduler, DDIMInverseScheduler 29 | from diffusers.utils import ( 30 | PIL_INTERPOLATION, 31 | is_accelerate_available, 32 | is_accelerate_version, 33 | logging, 34 | randn_tensor, 35 | BaseOutput 36 | ) 37 | from diffusers.pipeline_utils import DiffusionPipeline 38 | from einops import rearrange 39 | from .unet import UNet3DConditionModel 40 | 41 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 42 | 43 | 44 | @dataclass 45 | class FlattenPipelineOutput(BaseOutput): 46 | videos: Union[torch.Tensor, np.ndarray] 47 | 48 | class FlattenPipeline(DiffusionPipeline): 49 | r""" 50 | pipeline for FLATTEN: optical FLow-guided ATTENtion for consistent text-to-video editing. 51 | 52 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 53 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 54 | 55 | Args: 56 | vae ([`AutoencoderKL`]): 57 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 58 | text_encoder ([`CLIPTextModel`]): 59 | Frozen text-encoder. 
Stable Diffusion uses the text portion of 60 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 61 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 62 | tokenizer (`CLIPTokenizer`): 63 | Tokenizer of class 64 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 65 | unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents. 66 | scheduler ([`SchedulerMixin`]): 67 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 68 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 69 | inverse_scheduler ([`SchedulerMixin`]): 70 | DDIM inversion scheduler . 71 | """ 72 | _optional_components = ["safety_checker", "feature_extractor"] 73 | 74 | def __init__( 75 | self, 76 | vae: AutoencoderKL, 77 | text_encoder: CLIPTextModel, 78 | tokenizer: CLIPTokenizer, 79 | unet: UNet3DConditionModel, 80 | scheduler: DDIMScheduler, 81 | inverse_scheduler: DDIMInverseScheduler 82 | ): 83 | super().__init__() 84 | 85 | self.register_modules( 86 | vae=vae, 87 | text_encoder=text_encoder, 88 | tokenizer=tokenizer, 89 | unet=unet, 90 | scheduler=scheduler, 91 | inverse_scheduler=inverse_scheduler 92 | ) 93 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 94 | 95 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing 96 | def enable_vae_slicing(self): 97 | r""" 98 | Enable sliced VAE decoding. 99 | 100 | When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several 101 | steps. This is useful to save some memory and allow larger batch sizes. 102 | """ 103 | self.vae.enable_slicing() 104 | 105 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing 106 | def disable_vae_slicing(self): 107 | r""" 108 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to 109 | computing decoding in one step. 110 | """ 111 | self.vae.disable_slicing() 112 | 113 | def enable_sequential_cpu_offload(self, gpu_id=0): 114 | r""" 115 | Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, 116 | text_encoder, vae, and safety checker have their state dicts saved to CPU and then are moved to a 117 | `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 118 | Note that offloading happens on a submodule basis. Memory savings are higher than with 119 | `enable_model_cpu_offload`, but performance is lower. 120 | """ 121 | if is_accelerate_available(): 122 | from accelerate import cpu_offload 123 | else: 124 | raise ImportError("Please install accelerate via `pip install accelerate`") 125 | 126 | device = torch.device(f"cuda:{gpu_id}") 127 | 128 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 129 | cpu_offload(cpu_offloaded_model, device) 130 | 131 | if self.safety_checker is not None: 132 | cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) 133 | 134 | def enable_model_cpu_offload(self, gpu_id=0): 135 | r""" 136 | Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. 
Compared 137 | to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` 138 | method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with 139 | `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 140 | """ 141 | if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): 142 | from accelerate import cpu_offload_with_hook 143 | else: 144 | raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") 145 | 146 | device = torch.device(f"cuda:{gpu_id}") 147 | 148 | hook = None 149 | for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: 150 | _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) 151 | 152 | if self.safety_checker is not None: 153 | # the safety checker can offload the vae again 154 | _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) 155 | 156 | # We'll offload the last model manually. 157 | self.final_offload_hook = hook 158 | 159 | @property 160 | def _execution_device(self): 161 | r""" 162 | Returns the device on which the pipeline's models will be executed. After calling 163 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module 164 | hooks. 165 | """ 166 | if not hasattr(self.unet, "_hf_hook"): 167 | return self.device 168 | for module in self.unet.modules(): 169 | if ( 170 | hasattr(module, "_hf_hook") 171 | and hasattr(module._hf_hook, "execution_device") 172 | and module._hf_hook.execution_device is not None 173 | ): 174 | return torch.device(module._hf_hook.execution_device) 175 | return self.device 176 | 177 | def _encode_prompt( 178 | self, 179 | prompt, 180 | device, 181 | num_videos_per_prompt, 182 | do_classifier_free_guidance, 183 | negative_prompt=None, 184 | prompt_embeds: Optional[torch.FloatTensor] = None, 185 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 186 | ): 187 | r""" 188 | Encodes the prompt into text encoder hidden states. 189 | 190 | Args: 191 | prompt (`str` or `List[str]`, *optional*): 192 | prompt to be encoded 193 | device: (`torch.device`): 194 | torch device 195 | num_videos_per_prompt (`int`): 196 | number of images that should be generated per prompt 197 | do_classifier_free_guidance (`bool`): 198 | whether to use classifier free guidance or not 199 | negative_prompt (`str` or `List[str]`, *optional*): 200 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 201 | `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. 202 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 203 | prompt_embeds (`torch.FloatTensor`, *optional*): 204 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 205 | provided, text embeddings will be generated from `prompt` input argument. 206 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 207 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 208 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 209 | argument. 
210 | """ 211 | if prompt is not None and isinstance(prompt, str): 212 | batch_size = 1 213 | elif prompt is not None and isinstance(prompt, list): 214 | batch_size = len(prompt) 215 | else: 216 | batch_size = prompt_embeds.shape[0] 217 | 218 | if prompt_embeds is None: 219 | text_inputs = self.tokenizer( 220 | prompt, 221 | padding="max_length", 222 | max_length=self.tokenizer.model_max_length, 223 | truncation=True, 224 | return_tensors="pt", 225 | ) 226 | text_input_ids = text_inputs.input_ids 227 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 228 | 229 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 230 | text_input_ids, untruncated_ids 231 | ): 232 | removed_text = self.tokenizer.batch_decode( 233 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 234 | ) 235 | logger.warning( 236 | "The following part of your input was truncated because CLIP can only handle sequences up to" 237 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 238 | ) 239 | 240 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 241 | attention_mask = text_inputs.attention_mask.to(device) 242 | else: 243 | attention_mask = None 244 | 245 | prompt_embeds = self.text_encoder( 246 | text_input_ids.to(device), 247 | attention_mask=attention_mask, 248 | ) 249 | prompt_embeds = prompt_embeds[0] 250 | 251 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 252 | 253 | bs_embed, seq_len, _ = prompt_embeds.shape 254 | # duplicate text embeddings for each generation per prompt, using mps friendly method 255 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) 256 | prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) 257 | 258 | # get unconditional embeddings for classifier free guidance 259 | if do_classifier_free_guidance and negative_prompt_embeds is None: 260 | uncond_tokens: List[str] 261 | if negative_prompt is None: 262 | uncond_tokens = [""] * batch_size 263 | elif type(prompt) is not type(negative_prompt): 264 | raise TypeError( 265 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 266 | f" {type(prompt)}." 267 | ) 268 | elif isinstance(negative_prompt, str): 269 | uncond_tokens = [negative_prompt] 270 | elif batch_size != len(negative_prompt): 271 | raise ValueError( 272 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 273 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 274 | " the batch size of `prompt`." 
275 | ) 276 | else: 277 | uncond_tokens = negative_prompt 278 | 279 | max_length = prompt_embeds.shape[1] 280 | uncond_input = self.tokenizer( 281 | uncond_tokens, 282 | padding="max_length", 283 | max_length=max_length, 284 | truncation=True, 285 | return_tensors="pt", 286 | ) 287 | 288 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 289 | attention_mask = uncond_input.attention_mask.to(device) 290 | else: 291 | attention_mask = None 292 | 293 | negative_prompt_embeds = self.text_encoder( 294 | uncond_input.input_ids.to(device), 295 | attention_mask=attention_mask, 296 | ) 297 | negative_prompt_embeds = negative_prompt_embeds[0] 298 | 299 | if do_classifier_free_guidance: 300 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 301 | seq_len = negative_prompt_embeds.shape[1] 302 | 303 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 304 | 305 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) 306 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) 307 | 308 | # For classifier free guidance, we need to do two forward passes. 309 | # Here we concatenate the unconditional and text embeddings into a single batch 310 | # to avoid doing two forward passes 311 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 312 | 313 | return prompt_embeds 314 | 315 | def decode_latents(self, latents, return_tensor=False): 316 | video_length = latents.shape[2] 317 | latents = 1 / 0.18215 * latents 318 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 319 | video = self.vae.decode(latents).sample 320 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 321 | video = (video / 2 + 0.5).clamp(0, 1) 322 | if return_tensor: 323 | return video 324 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 325 | video = video.cpu().float().numpy() 326 | return video 327 | 328 | def prepare_extra_step_kwargs(self, generator, eta): 329 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 330 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
331 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 332 | # and should be between [0, 1] 333 | 334 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 335 | extra_step_kwargs = {} 336 | if accepts_eta: 337 | extra_step_kwargs["eta"] = eta 338 | 339 | # check if the scheduler accepts generator 340 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 341 | if accepts_generator: 342 | extra_step_kwargs["generator"] = generator 343 | return extra_step_kwargs 344 | 345 | def check_inputs( 346 | self, 347 | prompt, 348 | # image, 349 | height, 350 | width, 351 | callback_steps, 352 | negative_prompt=None, 353 | prompt_embeds=None, 354 | negative_prompt_embeds=None, 355 | ): 356 | if height % 8 != 0 or width % 8 != 0: 357 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 358 | 359 | if (callback_steps is None) or ( 360 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 361 | ): 362 | raise ValueError( 363 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 364 | f" {type(callback_steps)}." 365 | ) 366 | 367 | if prompt is not None and prompt_embeds is not None: 368 | raise ValueError( 369 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 370 | " only forward one of the two." 371 | ) 372 | elif prompt is None and prompt_embeds is None: 373 | raise ValueError( 374 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 375 | ) 376 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 377 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 378 | 379 | if negative_prompt is not None and negative_prompt_embeds is not None: 380 | raise ValueError( 381 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" 382 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 383 | ) 384 | 385 | if prompt_embeds is not None and negative_prompt_embeds is not None: 386 | if prompt_embeds.shape != negative_prompt_embeds.shape: 387 | raise ValueError( 388 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 389 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 390 | f" {negative_prompt_embeds.shape}." 
391 | ) 392 | 393 | 394 | def check_image(self, image, prompt, prompt_embeds): 395 | image_is_pil = isinstance(image, PIL.Image.Image) 396 | image_is_tensor = isinstance(image, torch.Tensor) 397 | image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) 398 | image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) 399 | 400 | if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list: 401 | raise TypeError( 402 | "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" 403 | ) 404 | 405 | if image_is_pil: 406 | image_batch_size = 1 407 | elif image_is_tensor: 408 | image_batch_size = image.shape[0] 409 | elif image_is_pil_list: 410 | image_batch_size = len(image) 411 | elif image_is_tensor_list: 412 | image_batch_size = len(image) 413 | 414 | if prompt is not None and isinstance(prompt, str): 415 | prompt_batch_size = 1 416 | elif prompt is not None and isinstance(prompt, list): 417 | prompt_batch_size = len(prompt) 418 | elif prompt_embeds is not None: 419 | prompt_batch_size = prompt_embeds.shape[0] 420 | 421 | if image_batch_size != 1 and image_batch_size != prompt_batch_size: 422 | raise ValueError( 423 | f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" 424 | ) 425 | 426 | def prepare_image( 427 | self, image, width, height, batch_size, num_videos_per_prompt, device, dtype, do_classifier_free_guidance 428 | ): 429 | if not isinstance(image, torch.Tensor): 430 | if isinstance(image, PIL.Image.Image): 431 | image = [image] 432 | 433 | if isinstance(image[0], PIL.Image.Image): 434 | images = [] 435 | 436 | for image_ in image: 437 | image_ = image_.convert("RGB") 438 | image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) 439 | image_ = np.array(image_) 440 | image_ = image_[None, :] 441 | images.append(image_) 442 | 443 | image = images 444 | 445 | image = np.concatenate(image, axis=0) 446 | image = np.array(image).astype(np.float32) / 255.0 447 | image = image.transpose(0, 3, 1, 2) 448 | image = 2.0 * image - 1.0 449 | image = torch.from_numpy(image) 450 | elif isinstance(image[0], torch.Tensor): 451 | image = torch.cat(image, dim=0) 452 | 453 | image_batch_size = image.shape[0] 454 | 455 | if image_batch_size == 1: 456 | repeat_by = batch_size 457 | else: 458 | # image batch size is the same as prompt batch size 459 | repeat_by = num_videos_per_prompt 460 | 461 | image = image.repeat_interleave(repeat_by, dim=0) 462 | 463 | image = image.to(device=device, dtype=dtype) 464 | 465 | return image 466 | 467 | def prepare_video_latents(self, frames, batch_size, dtype, device, generator=None): 468 | if not isinstance(frames, (torch.Tensor, PIL.Image.Image, list)): 469 | raise ValueError( 470 | f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" 471 | ) 472 | 473 | frames = frames[0].to(device=device, dtype=dtype) 474 | frames = rearrange(frames, "c f h w -> f c h w" ) 475 | 476 | if isinstance(generator, list) and len(generator) != batch_size: 477 | raise ValueError( 478 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 479 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
480 | ) 481 | 482 | if isinstance(generator, list): 483 | latents = [ 484 | self.vae.encode(frames[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) 485 | ] 486 | latents = torch.cat(latents, dim=0) 487 | else: 488 | latents = self.vae.encode(frames).latent_dist.sample(generator) 489 | 490 | latents = self.vae.config.scaling_factor * latents 491 | 492 | latents = rearrange(latents, "f c h w ->c f h w" ) 493 | 494 | return latents[None] 495 | 496 | def _default_height_width(self, height, width, image): 497 | # NOTE: It is possible that a list of images have different 498 | # dimensions for each image, so just checking the first image 499 | # is not _exactly_ correct, but it is simple. 500 | while isinstance(image, list): 501 | image = image[0] 502 | 503 | if height is None: 504 | if isinstance(image, PIL.Image.Image): 505 | height = image.height 506 | elif isinstance(image, torch.Tensor): 507 | height = image.shape[3] 508 | 509 | height = (height // 8) * 8 # round down to nearest multiple of 8 510 | 511 | if width is None: 512 | if isinstance(image, PIL.Image.Image): 513 | width = image.width 514 | elif isinstance(image, torch.Tensor): 515 | width = image.shape[2] 516 | 517 | width = (width // 8) * 8 # round down to nearest multiple of 8 518 | 519 | return height, width 520 | 521 | def get_alpha_prev(self, timestep): 522 | prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps 523 | alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod 524 | return alpha_prod_t_prev 525 | 526 | def get_slide_window_indices(self, video_length, window_size): 527 | assert window_size >=3 528 | key_frame_indices = np.arange(0, video_length, window_size-1).tolist() 529 | 530 | # Append last index 531 | if key_frame_indices[-1] != (video_length-1): 532 | key_frame_indices.append(video_length-1) 533 | 534 | slices = np.split(np.arange(video_length), key_frame_indices) 535 | inter_frame_list = [] 536 | for s in slices: 537 | if len(s) < 2: 538 | continue 539 | inter_frame_list.append(s[1:].tolist()) 540 | return key_frame_indices, inter_frame_list 541 | 542 | def get_inverse_timesteps(self, num_inference_steps, strength, device): 543 | # get the original timestep using init_timestep 544 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) 545 | 546 | t_start = max(num_inference_steps - init_timestep, 0) 547 | 548 | # safety for t_start overflow to prevent empty timsteps slice 549 | if t_start == 0: 550 | return self.inverse_scheduler.timesteps, num_inference_steps 551 | timesteps = self.inverse_scheduler.timesteps[:-t_start] 552 | 553 | return timesteps, num_inference_steps - t_start 554 | 555 | def clean_features(self): 556 | self.unet.up_blocks[1].resnets[0].out_layers_inject_features = None 557 | self.unet.up_blocks[1].resnets[1].out_layers_inject_features = None 558 | self.unet.up_blocks[2].resnets[0].out_layers_inject_features = None 559 | self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.inject_q = None 560 | self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.inject_k = None 561 | self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.inject_q = None 562 | self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.inject_k = None 563 | self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.inject_q = None 564 | self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.inject_k 
= None 565 | self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.inject_q = None 566 | self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.inject_k = None 567 | self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.inject_q = None 568 | self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.inject_k = None 569 | self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.inject_q = None 570 | self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.inject_k = None 571 | 572 | @torch.no_grad() 573 | def __call__( 574 | self, 575 | prompt: Union[str, List[str]] = None, 576 | video_length: Optional[int] = 1, 577 | frames: Union[List[torch.FloatTensor], List[PIL.Image.Image], List[List[torch.FloatTensor]], List[List[PIL.Image.Image]]] = None, 578 | height: Optional[int] = None, 579 | width: Optional[int] = None, 580 | num_inference_steps: int = 50, 581 | guidance_scale: float = 7.5, 582 | negative_prompt: Optional[Union[str, List[str]]] = None, 583 | num_videos_per_prompt: Optional[int] = 1, 584 | eta: float = 0.0, 585 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 586 | latents: Optional[torch.FloatTensor] = None, 587 | prompt_embeds: Optional[torch.FloatTensor] = None, 588 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 589 | output_type: Optional[str] = "tensor", 590 | return_dict: bool = True, 591 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 592 | callback_steps: int = 1, 593 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 594 | **kwargs, 595 | ): 596 | r""" 597 | Function invoked when calling the pipeline for generation. 598 | 599 | Args: 600 | prompt (`str` or `List[str]`, *optional*): 601 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 602 | instead. 603 | frames (`List[torch.FloatTensor]`, `List[PIL.Image.Image]`, 604 | `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): 605 | The original video frames to be edited. 606 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 607 | The height in pixels of the generated image. 608 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 609 | The width in pixels of the generated image. 610 | num_inference_steps (`int`, *optional*, defaults to 50): 611 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 612 | expense of slower inference. 613 | guidance_scale (`float`, *optional*, defaults to 7.5): 614 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 615 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 616 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 617 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 618 | usually at the expense of lower image quality. 619 | negative_prompt (`str` or `List[str]`, *optional*): 620 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 621 | `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. 622 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 623 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 624 | The number of images to generate per prompt. 
625 | eta (`float`, *optional*, defaults to 0.0): 626 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 627 | [`schedulers.DDIMScheduler`], will be ignored for others. 628 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 629 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 630 | to make generation deterministic. 631 | latents (`torch.FloatTensor`, *optional*): 632 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video 633 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 634 | tensor will be generated by sampling using the supplied random `generator`. 635 | prompt_embeds (`torch.FloatTensor`, *optional*): 636 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 637 | provided, text embeddings will be generated from `prompt` input argument. 638 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 639 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 640 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 641 | argument. 642 | output_type (`str`, *optional*, defaults to `"tensor"`): 643 | The output format of the generated video. If `"tensor"`, a `torch.Tensor` is returned; 644 | otherwise the video is returned as a `np.ndarray`. 645 | return_dict (`bool`, *optional*, defaults to `True`): 646 | Whether or not to return a [`FlattenPipelineOutput`] instead of a 647 | plain tuple. 648 | callback (`Callable`, *optional*): 649 | A function that will be called every `callback_steps` steps during inference. The function will be 650 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 651 | callback_steps (`int`, *optional*, defaults to 1): 652 | The frequency at which the `callback` function will be called. If not specified, the callback will be 653 | called at every step. 654 | cross_attention_kwargs (`dict`, *optional*): 655 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 656 | `self.processor` in 657 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 658 | """ 659 | height, width = self._default_height_width(height, width, frames) 660 | 661 | self.check_inputs( 662 | prompt, 663 | height, 664 | width, 665 | callback_steps, 666 | negative_prompt, 667 | prompt_embeds, 668 | negative_prompt_embeds, 669 | ) 670 | 671 | if prompt is not None and isinstance(prompt, str): 672 | batch_size = 1 673 | elif prompt is not None and isinstance(prompt, list): 674 | batch_size = len(prompt) 675 | else: 676 | batch_size = prompt_embeds.shape[0] 677 | 678 | device = self._execution_device 679 | # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2) 680 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 681 | # corresponds to doing no classifier free guidance.
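# The combination applied later in this function is
#     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# so `guidance_scale = 1` is mathematically equivalent to keeping only the conditional
# prediction, while larger values push the result further towards the text prompt.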
682 | do_classifier_free_guidance = guidance_scale > 1.0 683 | 684 | # encode empty prompt 685 | prompt_embeds = self._encode_prompt( 686 | "", 687 | device, 688 | num_videos_per_prompt, 689 | do_classifier_free_guidance=do_classifier_free_guidance, 690 | negative_prompt=None, 691 | prompt_embeds=prompt_embeds, 692 | negative_prompt_embeds=negative_prompt_embeds, 693 | ) 694 | 695 | images = [] 696 | for i_img in frames: 697 | i_img = self.prepare_image( 698 | image=i_img, 699 | width=width, 700 | height=height, 701 | batch_size=batch_size * num_videos_per_prompt, 702 | num_videos_per_prompt=num_videos_per_prompt, 703 | device=device, 704 | dtype=self.unet.dtype, 705 | do_classifier_free_guidance=do_classifier_free_guidance, 706 | ) 707 | images.append(i_img) 708 | frames = torch.stack(images, dim=2) # b x c x f x h x w 709 | 710 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 711 | 712 | latents = self.prepare_video_latents(frames, batch_size, self.unet.dtype, device, generator=generator) 713 | 714 | saved_features0 = [] 715 | saved_features1 = [] 716 | saved_features2 = [] 717 | saved_q4 = [] 718 | saved_k4 = [] 719 | saved_q5 = [] 720 | saved_k5 = [] 721 | saved_q6 = [] 722 | saved_k6 = [] 723 | saved_q7 = [] 724 | saved_k7 = [] 725 | saved_q8 = [] 726 | saved_k8 = [] 727 | saved_q9 = [] 728 | saved_k9 = [] 729 | 730 | # ddim inverse 731 | self.scheduler.set_timesteps(num_inference_steps, device=device) 732 | timesteps = self.scheduler.timesteps 733 | 734 | num_inverse_steps = 100 735 | self.inverse_scheduler.set_timesteps(num_inverse_steps, device=device) 736 | inverse_timesteps, num_inverse_steps = self.get_inverse_timesteps(num_inverse_steps, 1, device) 737 | num_warmup_steps = len(inverse_timesteps) - num_inverse_steps * self.inverse_scheduler.order 738 | 739 | with self.progress_bar(total=num_inverse_steps-1) as progress_bar: 740 | for i, t in enumerate(inverse_timesteps[1:]): 741 | # expand the latents if we are doing classifier free guidance 742 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 743 | latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) 744 | 745 | noise_pred = self.unet( 746 | latent_model_input, 747 | t, 748 | encoder_hidden_states=prompt_embeds, 749 | cross_attention_kwargs=cross_attention_kwargs, 750 | **kwargs, 751 | ).sample 752 | 753 | if t in timesteps: 754 | saved_features0.append(self.unet.up_blocks[1].resnets[0].out_layers_features.cpu()) 755 | saved_features1.append(self.unet.up_blocks[1].resnets[1].out_layers_features.cpu()) 756 | saved_features2.append(self.unet.up_blocks[2].resnets[0].out_layers_features.cpu()) 757 | saved_q4.append(self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.q.cpu()) 758 | saved_k4.append(self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.k.cpu()) 759 | saved_q5.append(self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.q.cpu()) 760 | saved_k5.append(self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.k.cpu()) 761 | saved_q6.append(self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.q.cpu()) 762 | saved_k6.append(self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.k.cpu()) 763 | saved_q7.append(self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.q.cpu()) 764 | saved_k7.append(self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.k.cpu()) 765 | 
saved_q8.append(self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.q.cpu()) 766 | saved_k8.append(self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.k.cpu()) 767 | saved_q9.append(self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.q.cpu()) 768 | saved_k9.append(self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.k.cpu()) 769 | 770 | 771 | if do_classifier_free_guidance: 772 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 773 | noise_pred = noise_pred_uncond + 1 * (noise_pred_text - noise_pred_uncond) 774 | 775 | # compute the previous noisy sample x_t -> x_t-1 776 | latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample 777 | if i == len(inverse_timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0): 778 | progress_bar.update() 779 | 780 | saved_features0.reverse() 781 | saved_features1.reverse() 782 | saved_features2.reverse() 783 | saved_q4.reverse() 784 | saved_k4.reverse() 785 | saved_q5.reverse() 786 | saved_k5.reverse() 787 | saved_q6.reverse() 788 | saved_k6.reverse() 789 | saved_q7.reverse() 790 | saved_k7.reverse() 791 | saved_q8.reverse() 792 | saved_k8.reverse() 793 | saved_q9.reverse() 794 | saved_k9.reverse() 795 | 796 | # video sampling 797 | prompt_embeds = self._encode_prompt( 798 | prompt, 799 | device, 800 | num_videos_per_prompt, 801 | do_classifier_free_guidance, 802 | negative_prompt, 803 | prompt_embeds=None, 804 | negative_prompt_embeds=negative_prompt_embeds, 805 | ) 806 | 807 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 808 | with self.progress_bar(total=num_inference_steps) as progress_bar: 809 | for i, t in enumerate(timesteps): 810 | torch.cuda.empty_cache() 811 | 812 | # expand the latents if we are doing classifier free guidance 813 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 814 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 815 | 816 | # inject features 817 | if i < kwargs["inject_step"]: 818 | self.unet.up_blocks[1].resnets[0].out_layers_inject_features = saved_features0[i].to(device) 819 | self.unet.up_blocks[1].resnets[1].out_layers_inject_features = saved_features1[i].to(device) 820 | self.unet.up_blocks[2].resnets[0].out_layers_inject_features = saved_features2[i].to(device) 821 | self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.inject_q = saved_q4[i].to(device) 822 | self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.inject_k = saved_k4[i].to(device) 823 | self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.inject_q = saved_q5[i].to(device) 824 | self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.inject_k = saved_k5[i].to(device) 825 | self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.inject_q = saved_q6[i].to(device) 826 | self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.inject_k = saved_k6[i].to(device) 827 | self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.inject_q = saved_q7[i].to(device) 828 | self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.inject_k = saved_k7[i].to(device) 829 | self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.inject_q = saved_q8[i].to(device) 830 | self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.inject_k = saved_k8[i].to(device) 831 | self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.inject_q = saved_q9[i].to(device) 832 | 
self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.inject_k = saved_k9[i].to(device) 833 | else: 834 | self.clean_features() 835 | 836 | noise_pred = self.unet( 837 | latent_model_input, 838 | t, 839 | encoder_hidden_states=prompt_embeds, 840 | cross_attention_kwargs=cross_attention_kwargs, 841 | **kwargs, 842 | ).sample 843 | 844 | self.clean_features() 845 | 846 | # perform guidance 847 | if do_classifier_free_guidance: 848 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 849 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 850 | 851 | # compute the previous noisy sample x_t -> x_t-1 852 | step_dict = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs) 853 | latents = step_dict.prev_sample 854 | 855 | # call the callback, if provided 856 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 857 | progress_bar.update() 858 | if callback is not None and i % callback_steps == 0: 859 | callback(i, t, latents) 860 | 861 | # If we do sequential model offloading, let's offload unet 862 | # manually for max memory savings 863 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 864 | self.unet.to("cpu") 865 | torch.cuda.empty_cache() 866 | # Post-processing 867 | video = self.decode_latents(latents) 868 | 869 | # Convert to tensor 870 | if output_type == "tensor": 871 | video = torch.from_numpy(video) 872 | 873 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 874 | self.final_offload_hook.offload() 875 | 876 | if not return_dict: 877 | return video 878 | 879 | return FlattenPipelineOutput(videos=video) 880 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange 8 | 9 | 10 | class InflatedConv3d(nn.Conv2d): 11 | def forward(self, x): 12 | video_length = x.shape[2] 13 | 14 | x = rearrange(x, "b c f h w -> (b f) c h w") 15 | x = super().forward(x) 16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 17 | 18 | return x 19 | 20 | class TemporalConv1d(nn.Conv1d): 21 | def forward(self, x): 22 | b, c, f, h, w = x.shape 23 | y = rearrange(x.clone(), "b c f h w -> (b h w) c f") 24 | y = super().forward(y) 25 | y = rearrange(y, "(b h w) c f -> b c f h w", b=b, h=h, w=w) 26 | return y 27 | 28 | 29 | class Upsample3D(nn.Module): 30 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 31 | super().__init__() 32 | self.channels = channels 33 | self.out_channels = out_channels or channels 34 | self.use_conv = use_conv 35 | self.use_conv_transpose = use_conv_transpose 36 | self.name = name 37 | 38 | conv = None 39 | if use_conv_transpose: 40 | raise NotImplementedError 41 | elif use_conv: 42 | conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 43 | 44 | if name == "conv": 45 | self.conv = conv 46 | else: 47 | self.Conv2d_0 = conv 48 | 49 | def forward(self, hidden_states, output_size=None): 50 | assert hidden_states.shape[1] == self.channels 51 | 52 | if self.use_conv_transpose: 53 | raise NotImplementedError 54 | 55 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 
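# hidden_states is a 5D video tensor (batch, channels, frames, height, width); the
# interpolation below uses scale_factor=[1.0, 2.0, 2.0], so only the spatial
# dimensions are doubled while the number of frames is preserved.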
56 | dtype = hidden_states.dtype 57 | if dtype == torch.bfloat16: 58 | hidden_states = hidden_states.to(torch.float32) 59 | 60 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 61 | if hidden_states.shape[0] >= 64: 62 | hidden_states = hidden_states.contiguous() 63 | 64 | # if `output_size` is passed we force the interpolation output 65 | # size and do not make use of `scale_factor=2` 66 | if output_size is None: 67 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 68 | else: 69 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 70 | 71 | # If the input is bfloat16, we cast back to bfloat16 72 | if dtype == torch.bfloat16: 73 | hidden_states = hidden_states.to(dtype) 74 | 75 | if self.use_conv: 76 | if self.name == "conv": 77 | hidden_states = self.conv(hidden_states) 78 | else: 79 | hidden_states = self.Conv2d_0(hidden_states) 80 | 81 | return hidden_states 82 | 83 | 84 | class Downsample3D(nn.Module): 85 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 86 | super().__init__() 87 | self.channels = channels 88 | self.out_channels = out_channels or channels 89 | self.use_conv = use_conv 90 | self.padding = padding 91 | stride = 2 92 | self.name = name 93 | 94 | if use_conv: 95 | conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 96 | else: 97 | raise NotImplementedError 98 | 99 | if name == "conv": 100 | self.Conv2d_0 = conv 101 | self.conv = conv 102 | elif name == "Conv2d_0": 103 | self.conv = conv 104 | else: 105 | self.conv = conv 106 | 107 | def forward(self, hidden_states): 108 | assert hidden_states.shape[1] == self.channels 109 | if self.use_conv and self.padding == 0: 110 | raise NotImplementedError 111 | 112 | assert hidden_states.shape[1] == self.channels 113 | hidden_states = self.conv(hidden_states) 114 | 115 | return hidden_states 116 | 117 | 118 | class ResnetBlock3D(nn.Module): 119 | def __init__( 120 | self, 121 | *, 122 | in_channels, 123 | out_channels=None, 124 | conv_shortcut=False, 125 | dropout=0.0, 126 | temb_channels=512, 127 | groups=32, 128 | groups_out=None, 129 | pre_norm=True, 130 | eps=1e-6, 131 | non_linearity="swish", 132 | time_embedding_norm="default", 133 | output_scale_factor=1.0, 134 | use_in_shortcut=None, 135 | ): 136 | super().__init__() 137 | self.pre_norm = pre_norm 138 | self.pre_norm = True 139 | self.in_channels = in_channels 140 | out_channels = in_channels if out_channels is None else out_channels 141 | self.out_channels = out_channels 142 | self.use_conv_shortcut = conv_shortcut 143 | self.time_embedding_norm = time_embedding_norm 144 | self.output_scale_factor = output_scale_factor 145 | 146 | if groups_out is None: 147 | groups_out = groups 148 | 149 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 150 | 151 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 152 | 153 | if temb_channels is not None: 154 | if self.time_embedding_norm == "default": 155 | time_emb_proj_out_channels = out_channels 156 | elif self.time_embedding_norm == "scale_shift": 157 | time_emb_proj_out_channels = out_channels * 2 158 | else: 159 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 160 | 161 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 162 | else: 163 | self.time_emb_proj = None 164 | 165 
| self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 166 | self.dropout = torch.nn.Dropout(dropout) 167 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 168 | 169 | if non_linearity == "swish": 170 | self.nonlinearity = lambda x: F.silu(x) 171 | elif non_linearity == "mish": 172 | self.nonlinearity = Mish() 173 | elif non_linearity == "silu": 174 | self.nonlinearity = nn.SiLU() 175 | 176 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 177 | 178 | self.conv_shortcut = None 179 | if self.use_in_shortcut: 180 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 181 | 182 | # save features 183 | self.out_layers_features = None 184 | self.out_layers_inject_features = None 185 | 186 | def forward(self, input_tensor, temb): 187 | hidden_states = input_tensor 188 | 189 | hidden_states = self.norm1(hidden_states) 190 | hidden_states = self.nonlinearity(hidden_states) 191 | 192 | hidden_states = self.conv1(hidden_states) 193 | 194 | if temb is not None: 195 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 196 | 197 | if temb is not None and self.time_embedding_norm == "default": 198 | hidden_states = hidden_states + temb 199 | 200 | hidden_states = self.norm2(hidden_states) 201 | 202 | if temb is not None and self.time_embedding_norm == "scale_shift": 203 | scale, shift = torch.chunk(temb, 2, dim=1) 204 | hidden_states = hidden_states * (1 + scale) + shift 205 | 206 | hidden_states = self.nonlinearity(hidden_states) 207 | 208 | hidden_states = self.dropout(hidden_states) 209 | hidden_states = self.conv2(hidden_states) 210 | 211 | if self.conv_shortcut is not None: 212 | input_tensor = self.conv_shortcut(input_tensor) 213 | 214 | # save features 215 | self.out_layers_features = hidden_states 216 | if self.out_layers_inject_features is not None: 217 | hidden_states = self.out_layers_inject_features 218 | 219 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 220 | 221 | return output_tensor 222 | 223 | 224 | class Mish(torch.nn.Module): 225 | def forward(self, hidden_states): 226 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 227 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py 2 | 3 | from dataclasses import dataclass 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import os 7 | import json 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.utils.checkpoint 12 | 13 | from diffusers.configuration_utils import ConfigMixin, register_to_config 14 | from diffusers import ModelMixin 15 | from diffusers.utils import BaseOutput, logging 16 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 17 | from .unet_blocks import ( 18 | CrossAttnDownBlock3D, 19 | CrossAttnUpBlock3D, 20 | DownBlock3D, 21 | UNetMidBlock3DCrossAttn, 22 | UpBlock3D, 23 | get_down_block, 24 | get_up_block, 25 | ) 26 | from .resnet import InflatedConv3d 27 | 28 | 29 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 30 | 31 | 32 | @dataclass 33 | class UNet3DConditionOutput(BaseOutput): 34 | sample: torch.FloatTensor 35 | 36 | 37 | class 
UNet3DConditionModel(ModelMixin, ConfigMixin): 38 | _supports_gradient_checkpointing = True 39 | 40 | @register_to_config 41 | def __init__( 42 | self, 43 | sample_size: Optional[int] = None, 44 | in_channels: int = 4, 45 | out_channels: int = 4, 46 | center_input_sample: bool = False, 47 | flip_sin_to_cos: bool = True, 48 | freq_shift: int = 0, 49 | down_block_types: Tuple[str] = ( 50 | "CrossAttnDownBlock3D", 51 | "CrossAttnDownBlock3D", 52 | "CrossAttnDownBlock3D", 53 | "DownBlock3D", 54 | ), 55 | mid_block_type: str = "UNetMidBlock3DCrossAttn", 56 | up_block_types: Tuple[str] = ( 57 | "UpBlock3D", 58 | "CrossAttnUpBlock3D", 59 | "CrossAttnUpBlock3D", 60 | "CrossAttnUpBlock3D" 61 | ), 62 | only_cross_attention: Union[bool, Tuple[bool]] = False, 63 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 64 | layers_per_block: int = 2, 65 | downsample_padding: int = 1, 66 | mid_block_scale_factor: float = 1, 67 | act_fn: str = "silu", 68 | norm_num_groups: int = 32, 69 | norm_eps: float = 1e-5, 70 | cross_attention_dim: int = 1280, 71 | attention_head_dim: Union[int, Tuple[int]] = 8, 72 | dual_cross_attention: bool = False, 73 | use_linear_projection: bool = False, 74 | class_embed_type: Optional[str] = None, 75 | num_class_embeds: Optional[int] = None, 76 | upcast_attention: bool = False, 77 | resnet_time_scale_shift: str = "default", 78 | ): 79 | super().__init__() 80 | 81 | self.sample_size = sample_size 82 | time_embed_dim = block_out_channels[0] * 4 83 | 84 | # input 85 | self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) 86 | 87 | # time 88 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 89 | timestep_input_dim = block_out_channels[0] 90 | 91 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 92 | 93 | # class embedding 94 | if class_embed_type is None and num_class_embeds is not None: 95 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 96 | elif class_embed_type == "timestep": 97 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 98 | elif class_embed_type == "identity": 99 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 100 | else: 101 | self.class_embedding = None 102 | 103 | self.down_blocks = nn.ModuleList([]) 104 | self.mid_block = None 105 | self.up_blocks = nn.ModuleList([]) 106 | 107 | if isinstance(only_cross_attention, bool): 108 | only_cross_attention = [only_cross_attention] * len(down_block_types) 109 | 110 | if isinstance(attention_head_dim, int): 111 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 112 | 113 | # down 114 | output_channel = block_out_channels[0] 115 | for i, down_block_type in enumerate(down_block_types): 116 | input_channel = output_channel 117 | output_channel = block_out_channels[i] 118 | is_final_block = i == len(block_out_channels) - 1 119 | 120 | down_block = get_down_block( 121 | down_block_type, 122 | num_layers=layers_per_block, 123 | in_channels=input_channel, 124 | out_channels=output_channel, 125 | temb_channels=time_embed_dim, 126 | add_downsample=not is_final_block, 127 | resnet_eps=norm_eps, 128 | resnet_act_fn=act_fn, 129 | resnet_groups=norm_num_groups, 130 | cross_attention_dim=cross_attention_dim, 131 | attn_num_head_channels=attention_head_dim[i], 132 | downsample_padding=downsample_padding, 133 | dual_cross_attention=dual_cross_attention, 134 | use_linear_projection=use_linear_projection, 135 | 
only_cross_attention=only_cross_attention[i], 136 | upcast_attention=upcast_attention, 137 | resnet_time_scale_shift=resnet_time_scale_shift, 138 | ) 139 | self.down_blocks.append(down_block) 140 | 141 | # mid 142 | if mid_block_type == "UNetMidBlock3DCrossAttn": 143 | self.mid_block = UNetMidBlock3DCrossAttn( 144 | in_channels=block_out_channels[-1], 145 | temb_channels=time_embed_dim, 146 | resnet_eps=norm_eps, 147 | resnet_act_fn=act_fn, 148 | output_scale_factor=mid_block_scale_factor, 149 | resnet_time_scale_shift=resnet_time_scale_shift, 150 | cross_attention_dim=cross_attention_dim, 151 | attn_num_head_channels=attention_head_dim[-1], 152 | resnet_groups=norm_num_groups, 153 | dual_cross_attention=dual_cross_attention, 154 | use_linear_projection=use_linear_projection, 155 | upcast_attention=upcast_attention, 156 | ) 157 | else: 158 | raise ValueError(f"unknown mid_block_type : {mid_block_type}") 159 | 160 | # count how many layers upsample the videos 161 | self.num_upsamplers = 0 162 | 163 | # up 164 | reversed_block_out_channels = list(reversed(block_out_channels)) 165 | reversed_attention_head_dim = list(reversed(attention_head_dim)) 166 | only_cross_attention = list(reversed(only_cross_attention)) 167 | output_channel = reversed_block_out_channels[0] 168 | for i, up_block_type in enumerate(up_block_types): 169 | is_final_block = i == len(block_out_channels) - 1 170 | 171 | prev_output_channel = output_channel 172 | output_channel = reversed_block_out_channels[i] 173 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 174 | 175 | # add upsample block for all BUT final layer 176 | if not is_final_block: 177 | add_upsample = True 178 | self.num_upsamplers += 1 179 | else: 180 | add_upsample = False 181 | 182 | up_block = get_up_block( 183 | up_block_type, 184 | num_layers=layers_per_block + 1, 185 | in_channels=input_channel, 186 | out_channels=output_channel, 187 | prev_output_channel=prev_output_channel, 188 | temb_channels=time_embed_dim, 189 | add_upsample=add_upsample, 190 | resnet_eps=norm_eps, 191 | resnet_act_fn=act_fn, 192 | resnet_groups=norm_num_groups, 193 | cross_attention_dim=cross_attention_dim, 194 | attn_num_head_channels=reversed_attention_head_dim[i], 195 | dual_cross_attention=dual_cross_attention, 196 | use_linear_projection=use_linear_projection, 197 | only_cross_attention=only_cross_attention[i], 198 | upcast_attention=upcast_attention, 199 | resnet_time_scale_shift=resnet_time_scale_shift, 200 | ) 201 | self.up_blocks.append(up_block) 202 | prev_output_channel = output_channel 203 | 204 | # out 205 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) 206 | self.conv_act = nn.SiLU() 207 | self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) 208 | 209 | def set_attention_slice(self, slice_size): 210 | r""" 211 | Enable sliced attention computation. 212 | 213 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 214 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 215 | 216 | Args: 217 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 218 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If 219 | `"max"`, maxium amount of memory will be saved by running only one slice at a time. 
If a number is 220 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 221 | must be a multiple of `slice_size`. 222 | """ 223 | sliceable_head_dims = [] 224 | 225 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): 226 | if hasattr(module, "set_attention_slice"): 227 | sliceable_head_dims.append(module.sliceable_head_dim) 228 | 229 | for child in module.children(): 230 | fn_recursive_retrieve_slicable_dims(child) 231 | 232 | # retrieve number of attention layers 233 | for module in self.children(): 234 | fn_recursive_retrieve_slicable_dims(module) 235 | 236 | num_slicable_layers = len(sliceable_head_dims) 237 | 238 | if slice_size == "auto": 239 | # half the attention head size is usually a good trade-off between 240 | # speed and memory 241 | slice_size = [dim // 2 for dim in sliceable_head_dims] 242 | elif slice_size == "max": 243 | # make smallest slice possible 244 | slice_size = num_slicable_layers * [1] 245 | 246 | slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 247 | 248 | if len(slice_size) != len(sliceable_head_dims): 249 | raise ValueError( 250 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 251 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 252 | ) 253 | 254 | for i in range(len(slice_size)): 255 | size = slice_size[i] 256 | dim = sliceable_head_dims[i] 257 | if size is not None and size > dim: 258 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 259 | 260 | # Recursively walk through all the children. 261 | # Any children which exposes the set_attention_slice method 262 | # gets the message 263 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 264 | if hasattr(module, "set_attention_slice"): 265 | module.set_attention_slice(slice_size.pop()) 266 | 267 | for child in module.children(): 268 | fn_recursive_set_attention_slice(child, slice_size) 269 | 270 | reversed_slice_size = list(reversed(slice_size)) 271 | for module in self.children(): 272 | fn_recursive_set_attention_slice(module, reversed_slice_size) 273 | 274 | def _set_gradient_checkpointing(self, module, value=False): 275 | if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): 276 | module.gradient_checkpointing = value 277 | 278 | def forward( 279 | self, 280 | sample: torch.FloatTensor, 281 | timestep: Union[torch.Tensor, float, int], 282 | encoder_hidden_states: torch.Tensor, 283 | class_labels: Optional[torch.Tensor] = None, 284 | attention_mask: Optional[torch.Tensor] = None, 285 | return_dict: bool = True, 286 | cross_attention_kwargs = None, 287 | inter_frame = False, 288 | **kwargs, 289 | ) -> Union[UNet3DConditionOutput, Tuple]: 290 | r""" 291 | Args: 292 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 293 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 294 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 295 | return_dict (`bool`, *optional*, defaults to `True`): 296 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 297 | 298 | Returns: 299 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 300 | [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. 
When 301 | returning a tuple, the first element is the sample tensor. 302 | """ 303 | # By default samples have to be at least a multiple of the overall upsampling factor. 304 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). 305 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 306 | # on the fly if necessary. 307 | default_overall_up_factor = 2**self.num_upsamplers 308 | kwargs["t"] = timestep 309 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 310 | forward_upsample_size = False 311 | upsample_size = None 312 | 313 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 314 | logger.info("Forward upsample size to force interpolation output size.") 315 | forward_upsample_size = True 316 | 317 | # prepare attention_mask 318 | if attention_mask is not None: 319 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 320 | attention_mask = attention_mask.unsqueeze(1) 321 | 322 | # center input if necessary 323 | if self.config.center_input_sample: 324 | sample = 2 * sample - 1.0 325 | 326 | # time 327 | timesteps = timestep 328 | if not torch.is_tensor(timesteps): 329 | # This would be a good case for the `match` statement (Python 3.10+) 330 | is_mps = sample.device.type == "mps" 331 | if isinstance(timestep, float): 332 | dtype = torch.float32 if is_mps else torch.float64 333 | else: 334 | dtype = torch.int32 if is_mps else torch.int64 335 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 336 | elif len(timesteps.shape) == 0: 337 | timesteps = timesteps[None].to(sample.device) 338 | 339 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 340 | timesteps = timesteps.expand(sample.shape[0]) 341 | 342 | t_emb = self.time_proj(timesteps) 343 | 344 | # timesteps does not contain any weights and will always return f32 tensors 345 | # but time_embedding might actually be running in fp16. so we need to cast here. 346 | # there might be better ways to encapsulate this.
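# e.g. when the UNet weights are loaded in float16, `self.dtype` is torch.float16 and the
# float32 sinusoidal embedding produced by `self.time_proj` is cast down before the linear
# layers of `self.time_embedding` run.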
347 | t_emb = t_emb.to(dtype=self.dtype) 348 | emb = self.time_embedding(t_emb) 349 | 350 | if self.class_embedding is not None: 351 | if class_labels is None: 352 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 353 | 354 | if self.config.class_embed_type == "timestep": 355 | class_labels = self.time_proj(class_labels) 356 | 357 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 358 | emb = emb + class_emb 359 | 360 | # pre-process 361 | sample = self.conv_in(sample) 362 | 363 | # down 364 | down_block_res_samples = (sample,) 365 | for downsample_block in self.down_blocks: 366 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 367 | sample, res_samples = downsample_block( 368 | hidden_states=sample, 369 | temb=emb, 370 | encoder_hidden_states=encoder_hidden_states, 371 | attention_mask=attention_mask, 372 | inter_frame=inter_frame, 373 | **kwargs, 374 | ) 375 | else: 376 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 377 | 378 | down_block_res_samples += res_samples 379 | 380 | # mid 381 | sample = self.mid_block( 382 | sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, 383 | inter_frame=inter_frame, 384 | **kwargs, 385 | ) 386 | 387 | # up 388 | for i, upsample_block in enumerate(self.up_blocks): 389 | is_final_block = i == len(self.up_blocks) - 1 390 | 391 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 392 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 393 | 394 | # if we have not reached the final block and need to forward the 395 | # upsample size, we do it here 396 | if not is_final_block and forward_upsample_size: 397 | upsample_size = down_block_res_samples[-1].shape[2:] 398 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 399 | sample = upsample_block( 400 | hidden_states=sample, 401 | temb=emb, 402 | res_hidden_states_tuple=res_samples, 403 | encoder_hidden_states=encoder_hidden_states, 404 | upsample_size=upsample_size, 405 | attention_mask=attention_mask, 406 | inter_frame=inter_frame, 407 | **kwargs, 408 | ) 409 | else: 410 | sample = upsample_block( 411 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 412 | ) 413 | # post-process 414 | sample = self.conv_norm_out(sample) 415 | sample = self.conv_act(sample) 416 | sample = self.conv_out(sample) 417 | 418 | if not return_dict: 419 | return (sample,) 420 | 421 | return UNet3DConditionOutput(sample=sample) 422 | 423 | @classmethod 424 | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None): 425 | if subfolder is not None: 426 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder) 427 | 428 | config_file = os.path.join(pretrained_model_path, 'config.json') 429 | if not os.path.isfile(config_file): 430 | raise RuntimeError(f"{config_file} does not exist") 431 | with open(config_file, "r") as f: 432 | config = json.load(f) 433 | config["_class_name"] = cls.__name__ 434 | config["down_block_types"] = [ 435 | "CrossAttnDownBlock3D", 436 | "CrossAttnDownBlock3D", 437 | "CrossAttnDownBlock3D", 438 | "DownBlock3D" 439 | ] 440 | config["up_block_types"] = [ 441 | "UpBlock3D", 442 | "CrossAttnUpBlock3D", 443 | "CrossAttnUpBlock3D", 444 | "CrossAttnUpBlock3D" 445 | ] 446 | 447 | from diffusers.utils import WEIGHTS_NAME 448 | model = cls.from_config(config) 449 | model_file = 
os.path.join(pretrained_model_path, WEIGHTS_NAME) 450 | if not os.path.isfile(model_file): 451 | raise RuntimeError(f"{model_file} does not exist") 452 | state_dict = torch.load(model_file, map_location="cpu") 453 | # for k, v in model.state_dict().items(): 454 | # if '_temp.' in k: 455 | # state_dict.update({k: v}) 456 | model.load_state_dict(state_dict, strict=False) 457 | 458 | return model 459 | -------------------------------------------------------------------------------- /models/unet_blocks.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from .attention import Transformer3DModel 7 | from .resnet import Downsample3D, ResnetBlock3D, Upsample3D 8 | 9 | 10 | def get_down_block( 11 | down_block_type, 12 | num_layers, 13 | in_channels, 14 | out_channels, 15 | temb_channels, 16 | add_downsample, 17 | resnet_eps, 18 | resnet_act_fn, 19 | attn_num_head_channels, 20 | resnet_groups=None, 21 | cross_attention_dim=None, 22 | downsample_padding=None, 23 | dual_cross_attention=False, 24 | use_linear_projection=False, 25 | only_cross_attention=False, 26 | upcast_attention=False, 27 | resnet_time_scale_shift="default", 28 | ): 29 | down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type 30 | if down_block_type == "DownBlock3D": 31 | return DownBlock3D( 32 | num_layers=num_layers, 33 | in_channels=in_channels, 34 | out_channels=out_channels, 35 | temb_channels=temb_channels, 36 | add_downsample=add_downsample, 37 | resnet_eps=resnet_eps, 38 | resnet_act_fn=resnet_act_fn, 39 | resnet_groups=resnet_groups, 40 | downsample_padding=downsample_padding, 41 | resnet_time_scale_shift=resnet_time_scale_shift, 42 | ) 43 | elif down_block_type == "CrossAttnDownBlock3D": 44 | if cross_attention_dim is None: 45 | raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") 46 | return CrossAttnDownBlock3D( 47 | num_layers=num_layers, 48 | in_channels=in_channels, 49 | out_channels=out_channels, 50 | temb_channels=temb_channels, 51 | add_downsample=add_downsample, 52 | resnet_eps=resnet_eps, 53 | resnet_act_fn=resnet_act_fn, 54 | resnet_groups=resnet_groups, 55 | downsample_padding=downsample_padding, 56 | cross_attention_dim=cross_attention_dim, 57 | attn_num_head_channels=attn_num_head_channels, 58 | dual_cross_attention=dual_cross_attention, 59 | use_linear_projection=use_linear_projection, 60 | only_cross_attention=only_cross_attention, 61 | upcast_attention=upcast_attention, 62 | resnet_time_scale_shift=resnet_time_scale_shift, 63 | ) 64 | raise ValueError(f"{down_block_type} does not exist.") 65 | 66 | 67 | def get_up_block( 68 | up_block_type, 69 | num_layers, 70 | in_channels, 71 | out_channels, 72 | prev_output_channel, 73 | temb_channels, 74 | add_upsample, 75 | resnet_eps, 76 | resnet_act_fn, 77 | attn_num_head_channels, 78 | resnet_groups=None, 79 | cross_attention_dim=None, 80 | dual_cross_attention=False, 81 | use_linear_projection=False, 82 | only_cross_attention=False, 83 | upcast_attention=False, 84 | resnet_time_scale_shift="default", 85 | ): 86 | up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type 87 | if up_block_type == "UpBlock3D": 88 | return UpBlock3D( 89 | num_layers=num_layers, 90 | in_channels=in_channels, 91 | out_channels=out_channels, 92 | prev_output_channel=prev_output_channel, 93 
| temb_channels=temb_channels, 94 | add_upsample=add_upsample, 95 | resnet_eps=resnet_eps, 96 | resnet_act_fn=resnet_act_fn, 97 | resnet_groups=resnet_groups, 98 | resnet_time_scale_shift=resnet_time_scale_shift, 99 | ) 100 | elif up_block_type == "CrossAttnUpBlock3D": 101 | if cross_attention_dim is None: 102 | raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") 103 | return CrossAttnUpBlock3D( 104 | num_layers=num_layers, 105 | in_channels=in_channels, 106 | out_channels=out_channels, 107 | prev_output_channel=prev_output_channel, 108 | temb_channels=temb_channels, 109 | add_upsample=add_upsample, 110 | resnet_eps=resnet_eps, 111 | resnet_act_fn=resnet_act_fn, 112 | resnet_groups=resnet_groups, 113 | cross_attention_dim=cross_attention_dim, 114 | attn_num_head_channels=attn_num_head_channels, 115 | dual_cross_attention=dual_cross_attention, 116 | use_linear_projection=use_linear_projection, 117 | only_cross_attention=only_cross_attention, 118 | upcast_attention=upcast_attention, 119 | resnet_time_scale_shift=resnet_time_scale_shift, 120 | ) 121 | raise ValueError(f"{up_block_type} does not exist.") 122 | 123 | 124 | class UNetMidBlock3DCrossAttn(nn.Module): 125 | def __init__( 126 | self, 127 | in_channels: int, 128 | temb_channels: int, 129 | dropout: float = 0.0, 130 | num_layers: int = 1, 131 | resnet_eps: float = 1e-6, 132 | resnet_time_scale_shift: str = "default", 133 | resnet_act_fn: str = "swish", 134 | resnet_groups: int = 32, 135 | resnet_pre_norm: bool = True, 136 | attn_num_head_channels=1, 137 | output_scale_factor=1.0, 138 | cross_attention_dim=1280, 139 | dual_cross_attention=False, 140 | use_linear_projection=False, 141 | upcast_attention=False, 142 | ): 143 | super().__init__() 144 | 145 | self.has_cross_attention = True 146 | self.attn_num_head_channels = attn_num_head_channels 147 | resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) 148 | 149 | # there is always at least one resnet 150 | resnets = [ 151 | ResnetBlock3D( 152 | in_channels=in_channels, 153 | out_channels=in_channels, 154 | temb_channels=temb_channels, 155 | eps=resnet_eps, 156 | groups=resnet_groups, 157 | dropout=dropout, 158 | time_embedding_norm=resnet_time_scale_shift, 159 | non_linearity=resnet_act_fn, 160 | output_scale_factor=output_scale_factor, 161 | pre_norm=resnet_pre_norm, 162 | ) 163 | ] 164 | attentions = [] 165 | 166 | for _ in range(num_layers): 167 | if dual_cross_attention: 168 | raise NotImplementedError 169 | attentions.append( 170 | Transformer3DModel( 171 | attn_num_head_channels, 172 | in_channels // attn_num_head_channels, 173 | in_channels=in_channels, 174 | num_layers=1, 175 | cross_attention_dim=cross_attention_dim, 176 | norm_num_groups=resnet_groups, 177 | use_linear_projection=use_linear_projection, 178 | upcast_attention=upcast_attention, 179 | ) 180 | ) 181 | resnets.append( 182 | ResnetBlock3D( 183 | in_channels=in_channels, 184 | out_channels=in_channels, 185 | temb_channels=temb_channels, 186 | eps=resnet_eps, 187 | groups=resnet_groups, 188 | dropout=dropout, 189 | time_embedding_norm=resnet_time_scale_shift, 190 | non_linearity=resnet_act_fn, 191 | output_scale_factor=output_scale_factor, 192 | pre_norm=resnet_pre_norm, 193 | ) 194 | ) 195 | 196 | self.attentions = nn.ModuleList(attentions) 197 | self.resnets = nn.ModuleList(resnets) 198 | 199 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, inter_frame=False, **kwargs): 200 | hidden_states = 
self.resnets[0](hidden_states, temb) 201 | for attn, resnet in zip(self.attentions, self.resnets[1:]): 202 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, inter_frame=inter_frame, **kwargs).sample 203 | hidden_states = resnet(hidden_states, temb) 204 | 205 | return hidden_states 206 | 207 | 208 | class CrossAttnDownBlock3D(nn.Module): 209 | def __init__( 210 | self, 211 | in_channels: int, 212 | out_channels: int, 213 | temb_channels: int, 214 | dropout: float = 0.0, 215 | num_layers: int = 1, 216 | resnet_eps: float = 1e-6, 217 | resnet_time_scale_shift: str = "default", 218 | resnet_act_fn: str = "swish", 219 | resnet_groups: int = 32, 220 | resnet_pre_norm: bool = True, 221 | attn_num_head_channels=1, 222 | cross_attention_dim=1280, 223 | output_scale_factor=1.0, 224 | downsample_padding=1, 225 | add_downsample=True, 226 | dual_cross_attention=False, 227 | use_linear_projection=False, 228 | only_cross_attention=False, 229 | upcast_attention=False, 230 | ): 231 | super().__init__() 232 | resnets = [] 233 | attentions = [] 234 | 235 | self.has_cross_attention = True 236 | self.attn_num_head_channels = attn_num_head_channels 237 | 238 | for i in range(num_layers): 239 | in_channels = in_channels if i == 0 else out_channels 240 | resnets.append( 241 | ResnetBlock3D( 242 | in_channels=in_channels, 243 | out_channels=out_channels, 244 | temb_channels=temb_channels, 245 | eps=resnet_eps, 246 | groups=resnet_groups, 247 | dropout=dropout, 248 | time_embedding_norm=resnet_time_scale_shift, 249 | non_linearity=resnet_act_fn, 250 | output_scale_factor=output_scale_factor, 251 | pre_norm=resnet_pre_norm, 252 | ) 253 | ) 254 | if dual_cross_attention: 255 | raise NotImplementedError 256 | attentions.append( 257 | Transformer3DModel( 258 | attn_num_head_channels, 259 | out_channels // attn_num_head_channels, 260 | in_channels=out_channels, 261 | num_layers=1, 262 | cross_attention_dim=cross_attention_dim, 263 | norm_num_groups=resnet_groups, 264 | use_linear_projection=use_linear_projection, 265 | only_cross_attention=only_cross_attention, 266 | upcast_attention=upcast_attention, 267 | ) 268 | ) 269 | self.attentions = nn.ModuleList(attentions) 270 | self.resnets = nn.ModuleList(resnets) 271 | 272 | if add_downsample: 273 | self.downsamplers = nn.ModuleList( 274 | [ 275 | Downsample3D( 276 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 277 | ) 278 | ] 279 | ) 280 | else: 281 | self.downsamplers = None 282 | 283 | self.gradient_checkpointing = False 284 | 285 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, inter_frame=False, **kwargs): 286 | output_states = () 287 | 288 | for resnet, attn in zip(self.resnets, self.attentions): 289 | if self.training and self.gradient_checkpointing: 290 | 291 | def create_custom_forward(module, return_dict=None, inter_frame=None): 292 | def custom_forward(*inputs): 293 | if return_dict is not None: 294 | return module(*inputs, return_dict=return_dict, inter_frame=inter_frame) 295 | else: 296 | return module(*inputs) 297 | 298 | return custom_forward 299 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 300 | hidden_states = torch.utils.checkpoint.checkpoint( 301 | create_custom_forward(attn, return_dict=False, inter_frame=inter_frame), 302 | hidden_states, 303 | encoder_hidden_states, 304 | )[0] 305 | else: 306 | hidden_states = resnet(hidden_states, temb) 307 | hidden_states = 
attn(hidden_states, encoder_hidden_states=encoder_hidden_states, inter_frame=inter_frame, **kwargs).sample 308 | 309 | output_states += (hidden_states,) 310 | 311 | if self.downsamplers is not None: 312 | for downsampler in self.downsamplers: 313 | hidden_states = downsampler(hidden_states) 314 | 315 | output_states += (hidden_states,) 316 | 317 | return hidden_states, output_states 318 | 319 | 320 | class DownBlock3D(nn.Module): 321 | def __init__( 322 | self, 323 | in_channels: int, 324 | out_channels: int, 325 | temb_channels: int, 326 | dropout: float = 0.0, 327 | num_layers: int = 1, 328 | resnet_eps: float = 1e-6, 329 | resnet_time_scale_shift: str = "default", 330 | resnet_act_fn: str = "swish", 331 | resnet_groups: int = 32, 332 | resnet_pre_norm: bool = True, 333 | output_scale_factor=1.0, 334 | add_downsample=True, 335 | downsample_padding=1, 336 | ): 337 | super().__init__() 338 | resnets = [] 339 | 340 | for i in range(num_layers): 341 | in_channels = in_channels if i == 0 else out_channels 342 | resnets.append( 343 | ResnetBlock3D( 344 | in_channels=in_channels, 345 | out_channels=out_channels, 346 | temb_channels=temb_channels, 347 | eps=resnet_eps, 348 | groups=resnet_groups, 349 | dropout=dropout, 350 | time_embedding_norm=resnet_time_scale_shift, 351 | non_linearity=resnet_act_fn, 352 | output_scale_factor=output_scale_factor, 353 | pre_norm=resnet_pre_norm, 354 | ) 355 | ) 356 | 357 | self.resnets = nn.ModuleList(resnets) 358 | 359 | if add_downsample: 360 | self.downsamplers = nn.ModuleList( 361 | [ 362 | Downsample3D( 363 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 364 | ) 365 | ] 366 | ) 367 | else: 368 | self.downsamplers = None 369 | 370 | self.gradient_checkpointing = False 371 | 372 | def forward(self, hidden_states, temb=None): 373 | output_states = () 374 | 375 | for resnet in self.resnets: 376 | if self.training and self.gradient_checkpointing: 377 | 378 | def create_custom_forward(module): 379 | def custom_forward(*inputs): 380 | return module(*inputs) 381 | 382 | return custom_forward 383 | 384 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 385 | else: 386 | hidden_states = resnet(hidden_states, temb) 387 | 388 | output_states += (hidden_states,) 389 | 390 | if self.downsamplers is not None: 391 | for downsampler in self.downsamplers: 392 | hidden_states = downsampler(hidden_states) 393 | 394 | output_states += (hidden_states,) 395 | 396 | return hidden_states, output_states 397 | 398 | 399 | class CrossAttnUpBlock3D(nn.Module): 400 | def __init__( 401 | self, 402 | in_channels: int, 403 | out_channels: int, 404 | prev_output_channel: int, 405 | temb_channels: int, 406 | dropout: float = 0.0, 407 | num_layers: int = 1, 408 | resnet_eps: float = 1e-6, 409 | resnet_time_scale_shift: str = "default", 410 | resnet_act_fn: str = "swish", 411 | resnet_groups: int = 32, 412 | resnet_pre_norm: bool = True, 413 | attn_num_head_channels=1, 414 | cross_attention_dim=1280, 415 | output_scale_factor=1.0, 416 | add_upsample=True, 417 | dual_cross_attention=False, 418 | use_linear_projection=False, 419 | only_cross_attention=False, 420 | upcast_attention=False, 421 | ): 422 | super().__init__() 423 | resnets = [] 424 | attentions = [] 425 | 426 | self.has_cross_attention = True 427 | self.attn_num_head_channels = attn_num_head_channels 428 | 429 | for i in range(num_layers): 430 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 431 | 
resnet_in_channels = prev_output_channel if i == 0 else out_channels 432 | 433 | resnets.append( 434 | ResnetBlock3D( 435 | in_channels=resnet_in_channels + res_skip_channels, 436 | out_channels=out_channels, 437 | temb_channels=temb_channels, 438 | eps=resnet_eps, 439 | groups=resnet_groups, 440 | dropout=dropout, 441 | time_embedding_norm=resnet_time_scale_shift, 442 | non_linearity=resnet_act_fn, 443 | output_scale_factor=output_scale_factor, 444 | pre_norm=resnet_pre_norm, 445 | ) 446 | ) 447 | if dual_cross_attention: 448 | raise NotImplementedError 449 | attentions.append( 450 | Transformer3DModel( 451 | attn_num_head_channels, 452 | out_channels // attn_num_head_channels, 453 | in_channels=out_channels, 454 | num_layers=1, 455 | cross_attention_dim=cross_attention_dim, 456 | norm_num_groups=resnet_groups, 457 | use_linear_projection=use_linear_projection, 458 | only_cross_attention=only_cross_attention, 459 | upcast_attention=upcast_attention, 460 | ) 461 | ) 462 | 463 | self.attentions = nn.ModuleList(attentions) 464 | self.resnets = nn.ModuleList(resnets) 465 | 466 | if add_upsample: 467 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) 468 | else: 469 | self.upsamplers = None 470 | 471 | self.gradient_checkpointing = False 472 | 473 | def forward( 474 | self, 475 | hidden_states, 476 | res_hidden_states_tuple, 477 | temb=None, 478 | encoder_hidden_states=None, 479 | upsample_size=None, 480 | attention_mask=None, 481 | inter_frame=False, 482 | **kwargs, 483 | ): 484 | for resnet, attn in zip(self.resnets, self.attentions): 485 | # pop res hidden states 486 | res_hidden_states = res_hidden_states_tuple[-1] 487 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 488 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 489 | 490 | if self.training and self.gradient_checkpointing: 491 | 492 | def create_custom_forward(module, return_dict=None, inter_frame=None): 493 | def custom_forward(*inputs): 494 | if return_dict is not None: 495 | return module(*inputs, return_dict=return_dict, inter_frame=inter_frame) 496 | else: 497 | return module(*inputs) 498 | 499 | return custom_forward 500 | 501 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 502 | hidden_states = torch.utils.checkpoint.checkpoint( 503 | create_custom_forward(attn, return_dict=False, inter_frame=inter_frame), 504 | hidden_states, 505 | encoder_hidden_states, 506 | )[0] 507 | else: 508 | hidden_states = resnet(hidden_states, temb) 509 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, inter_frame=inter_frame, **kwargs).sample 510 | 511 | if self.upsamplers is not None: 512 | for upsampler in self.upsamplers: 513 | hidden_states = upsampler(hidden_states, upsample_size) 514 | 515 | return hidden_states 516 | 517 | 518 | class UpBlock3D(nn.Module): 519 | def __init__( 520 | self, 521 | in_channels: int, 522 | prev_output_channel: int, 523 | out_channels: int, 524 | temb_channels: int, 525 | dropout: float = 0.0, 526 | num_layers: int = 1, 527 | resnet_eps: float = 1e-6, 528 | resnet_time_scale_shift: str = "default", 529 | resnet_act_fn: str = "swish", 530 | resnet_groups: int = 32, 531 | resnet_pre_norm: bool = True, 532 | output_scale_factor=1.0, 533 | add_upsample=True, 534 | ): 535 | super().__init__() 536 | resnets = [] 537 | 538 | for i in range(num_layers): 539 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 540 | 
resnet_in_channels = prev_output_channel if i == 0 else out_channels 541 | 542 | resnets.append( 543 | ResnetBlock3D( 544 | in_channels=resnet_in_channels + res_skip_channels, 545 | out_channels=out_channels, 546 | temb_channels=temb_channels, 547 | eps=resnet_eps, 548 | groups=resnet_groups, 549 | dropout=dropout, 550 | time_embedding_norm=resnet_time_scale_shift, 551 | non_linearity=resnet_act_fn, 552 | output_scale_factor=output_scale_factor, 553 | pre_norm=resnet_pre_norm, 554 | ) 555 | ) 556 | 557 | self.resnets = nn.ModuleList(resnets) 558 | 559 | if add_upsample: 560 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) 561 | else: 562 | self.upsamplers = None 563 | 564 | self.gradient_checkpointing = False 565 | 566 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): 567 | for resnet in self.resnets: 568 | # pop res hidden states 569 | res_hidden_states = res_hidden_states_tuple[-1] 570 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 571 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 572 | 573 | if self.training and self.gradient_checkpointing: 574 | 575 | def create_custom_forward(module): 576 | def custom_forward(*inputs): 577 | return module(*inputs) 578 | 579 | return custom_forward 580 | 581 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 582 | else: 583 | hidden_states = resnet(hidden_states, temb) 584 | 585 | if self.upsamplers is not None: 586 | for upsampler in self.upsamplers: 587 | hidden_states = upsampler(hidden_states, upsample_size) 588 | 589 | return hidden_states 590 | -------------------------------------------------------------------------------- /models/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | from typing import Union 5 | import decord 6 | decord.bridge.set_bridge('torch') 7 | import torch 8 | import torchvision 9 | import PIL 10 | from typing import List 11 | from tqdm import tqdm 12 | from einops import rearrange 13 | import torchvision.transforms.functional as F 14 | import random 15 | 16 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8): 17 | videos = rearrange(videos, "b c t h w -> t b c h w") 18 | outputs = [] 19 | for x in videos: 20 | x = torchvision.utils.make_grid(x, nrow=n_rows) 21 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 22 | if rescale: 23 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 24 | x = (x * 255).numpy().astype(np.uint8) 25 | outputs.append(x) 26 | 27 | os.makedirs(os.path.dirname(path), exist_ok=True) 28 | imageio.mimsave(path, outputs, fps=fps) 29 | 30 | def save_videos_grid_pil(videos: List[PIL.Image.Image], path: str, rescale=False, n_rows=4, fps=8): 31 | videos = rearrange(videos, "b c t h w -> t b c h w") 32 | outputs = [] 33 | for x in videos: 34 | x = torchvision.utils.make_grid(x, nrow=n_rows) 35 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 36 | if rescale: 37 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 38 | x = (x * 255).numpy().astype(np.uint8) 39 | outputs.append(x) 40 | 41 | os.makedirs(os.path.dirname(path), exist_ok=True) 42 | imageio.mimsave(path, outputs, fps=fps) 43 | 44 | def read_video(video_path, video_length, width=512, height=512, frame_rate=None): 45 | vr = decord.VideoReader(video_path, width=width, height=height) 46 | if frame_rate is None: 47 | frame_rate = max(1, len(vr) // video_length) 48 | 
sample_index = list(range(0, len(vr), frame_rate))[:video_length] 49 | video = vr.get_batch(sample_index) 50 | video = rearrange(video, "f h w c -> f c h w") 51 | video = (video / 127.5 - 1.0) 52 | return video 53 | 54 | 55 | # DDIM Inversion 56 | @torch.no_grad() 57 | def init_prompt(prompt, pipeline): 58 | uncond_input = pipeline.tokenizer( 59 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, 60 | return_tensors="pt" 61 | ) 62 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] 63 | text_input = pipeline.tokenizer( 64 | [prompt], 65 | padding="max_length", 66 | max_length=pipeline.tokenizer.model_max_length, 67 | truncation=True, 68 | return_tensors="pt", 69 | ) 70 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] 71 | context = torch.cat([uncond_embeddings, text_embeddings]) 72 | 73 | return context 74 | 75 | 76 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, 77 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): 78 | timestep, next_timestep = min( 79 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep 80 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod 81 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] 82 | beta_prod_t = 1 - alpha_prod_t 83 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 84 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 85 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 86 | return next_sample 87 | 88 | 89 | def get_noise_pred_single(latents, t, context, unet): 90 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] 91 | return noise_pred 92 | 93 | 94 | @torch.no_grad() 95 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): 96 | context = init_prompt(prompt, pipeline) 97 | uncond_embeddings, cond_embeddings = context.chunk(2) 98 | all_latent = [latent] 99 | latent = latent.clone().detach() 100 | for i in tqdm(range(num_inv_steps)): 101 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] 102 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) 103 | latent = next_step(noise_pred, t, latent, ddim_scheduler) 104 | all_latent.append(latent) 105 | return all_latent 106 | 107 | 108 | @torch.no_grad() 109 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): 110 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) 111 | return ddim_latents 112 | 113 | 114 | """optical flow and trajectories sampling""" 115 | def preprocess(img1_batch, img2_batch, transforms): 116 | img1_batch = F.resize(img1_batch, size=[512, 512], antialias=False) 117 | img2_batch = F.resize(img2_batch, size=[512, 512], antialias=False) 118 | return transforms(img1_batch, img2_batch) 119 | 120 | def keys_with_same_value(dictionary): 121 | result = {} 122 | for key, value in dictionary.items(): 123 | if value not in result: 124 | result[value] = [key] 125 | else: 126 | result[value].append(key) 127 | 128 | conflict_points = {} 129 | for k in result.keys(): 130 | if len(result[k]) > 1: 131 | conflict_points[k] = result[k] 132 | return conflict_points 133 | 134 | def find_duplicates(input_list): 135 | seen = set() 136 | duplicates = set() 137 | 138 | for item in 
input_list: 139 | if item in seen: 140 | duplicates.add(item) 141 | else: 142 | seen.add(item) 143 | 144 | return list(duplicates) 145 | 146 | def neighbors_index(point, window_size, H, W): 147 | """return the spatial neighbor indices""" 148 | t, x, y = point 149 | neighbors = [] 150 | for i in range(-window_size, window_size + 1): 151 | for j in range(-window_size, window_size + 1): 152 | if i == 0 and j == 0: 153 | continue 154 | if x + i < 0 or x + i >= H or y + j < 0 or y + j >= W: 155 | continue 156 | neighbors.append((t, x + i, y + j)) 157 | return neighbors 158 | 159 | 160 | @torch.no_grad() 161 | def sample_trajectories(video_path, device): 162 | from torchvision.models.optical_flow import Raft_Large_Weights 163 | from torchvision.models.optical_flow import raft_large 164 | 165 | weights = Raft_Large_Weights.DEFAULT 166 | transforms = weights.transforms() 167 | 168 | frames, _, _ = torchvision.io.read_video(str(video_path), output_format="TCHW") 169 | 170 | clips = list(range(len(frames))) 171 | 172 | model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device) 173 | model = model.eval() 174 | 175 | finished_trajectories = [] 176 | 177 | current_frames, next_frames = preprocess(frames[clips[:-1]], frames[clips[1:]], transforms) 178 | list_of_flows = model(current_frames.to(device), next_frames.to(device)) 179 | predicted_flows = list_of_flows[-1] 180 | 181 | predicted_flows = predicted_flows/512 182 | 183 | resolutions = [64, 32, 16, 8] 184 | res = {} 185 | window_sizes = {64: 2, 186 | 32: 1, 187 | 16: 1, 188 | 8: 1} 189 | 190 | for resolution in resolutions: 191 | print("="*30) 192 | trajectories = {} 193 | predicted_flow_resolu = torch.round(resolution*torch.nn.functional.interpolate(predicted_flows, scale_factor=(resolution/512, resolution/512))) 194 | 195 | T = predicted_flow_resolu.shape[0]+1 196 | H = predicted_flow_resolu.shape[2] 197 | W = predicted_flow_resolu.shape[3] 198 | 199 | is_activated = torch.zeros([T, H, W], dtype=torch.bool) 200 | 201 | for t in range(T-1): 202 | flow = predicted_flow_resolu[t] 203 | for h in range(H): 204 | for w in range(W): 205 | 206 | if not is_activated[t, h, w]: 207 | is_activated[t, h, w] = True 208 | # this point has not been traversed, start new trajectory 209 | x = h + int(flow[1, h, w]) 210 | y = w + int(flow[0, h, w]) 211 | if x >= 0 and x < H and y >= 0 and y < W: 212 | # trajectories.append([(t, h, w), (t+1, x, y)]) 213 | trajectories[(t, h, w)]= (t+1, x, y) 214 | 215 | conflict_points = keys_with_same_value(trajectories) 216 | for k in conflict_points: 217 | index_to_pop = random.randint(0, len(conflict_points[k]) - 1) 218 | conflict_points[k].pop(index_to_pop) 219 | for point in conflict_points[k]: 220 | if point[0] != T-1: 221 | trajectories[point]= (-1, -1, -1) # stupid padding with (-1, -1, -1) 222 | 223 | active_traj = [] 224 | all_traj = [] 225 | for t in range(T): 226 | pixel_set = {(t, x//H, x%H):0 for x in range(H*W)} 227 | new_active_traj = [] 228 | for traj in active_traj: 229 | if traj[-1] in trajectories: 230 | v = trajectories[traj[-1]] 231 | new_active_traj.append(traj + [v]) 232 | pixel_set[v] = 1 233 | else: 234 | all_traj.append(traj) 235 | active_traj = new_active_traj 236 | active_traj+=[[pixel] for pixel in pixel_set if pixel_set[pixel] == 0] 237 | all_traj += active_traj 238 | 239 | useful_traj = [i for i in all_traj if len(i)>1] 240 | for idx in range(len(useful_traj)): 241 | if useful_traj[idx][-1] == (-1, -1, -1): 242 | useful_traj[idx] = useful_traj[idx][:-1] 243 | print("how many 
points in all trajectories for resolution{}?".format(resolution), sum([len(i) for i in useful_traj])) 244 | print("how many points in the video for resolution{}?".format(resolution), T*H*W) 245 | 246 | # validate if there are no duplicates in the trajectories 247 | trajs = [] 248 | for traj in useful_traj: 249 | trajs = trajs + traj 250 | assert len(find_duplicates(trajs)) == 0, "There should not be duplicates in the useful trajectories." 251 | 252 | # check if non-appearing points + appearing points = all the points in the video 253 | all_points = set([(t, x, y) for t in range(T) for x in range(H) for y in range(W)]) 254 | left_points = all_points - set(trajs) 255 | print("How many points not in the trajectories for resolution{}?".format(resolution), len(left_points)) 256 | for p in list(left_points): 257 | useful_traj.append([p]) 258 | print("how many points in all trajectories for resolution{} after appending the leftover points?".format(resolution), sum([len(i) for i in useful_traj])) 259 | 260 | 261 | longest_length = max([len(i) for i in useful_traj]) 262 | sequence_length = (window_sizes[resolution]*2+1)**2 + longest_length - 1 263 | 264 | seqs = [] 265 | masks = [] 266 | 267 | # create a dictionary to facilitate checking the trajectories to which each point belongs. 268 | point_to_traj = {} 269 | for traj in useful_traj: 270 | for p in traj: 271 | point_to_traj[p] = traj 272 | 273 | for t in range(T): 274 | for x in range(H): 275 | for y in range(W): 276 | neighbours = neighbors_index((t,x,y), window_sizes[resolution], H, W) 277 | sequence = [(t,x,y)]+neighbours + [(0,0,0) for i in range((window_sizes[resolution]*2+1)**2-1-len(neighbours))] 278 | sequence_mask = torch.zeros(sequence_length, dtype=torch.bool) 279 | sequence_mask[:len(neighbours)+1] = True 280 | 281 | traj = point_to_traj[(t,x,y)].copy() 282 | traj.remove((t,x,y)) 283 | sequence = sequence + traj + [(0,0,0) for k in range(longest_length-1-len(traj))] 284 | sequence_mask[(window_sizes[resolution]*2+1)**2: (window_sizes[resolution]*2+1)**2 + len(traj)] = True 285 | 286 | seqs.append(sequence) 287 | masks.append(sequence_mask) 288 | 289 | seqs = torch.tensor(seqs) 290 | masks = torch.stack(masks) 291 | res["traj{}".format(resolution)] = seqs 292 | res["mask{}".format(resolution)] = masks 293 | return res 294 | 295 | -------------------------------------------------------------------------------- /truck.sh: -------------------------------------------------------------------------------- 1 | python inference.py \ 2 | --prompt "Wooden trucks drive on a racetrack." \ 3 | --neg_prompt " " \ 4 | --guidance_scale 15 \ 5 | --video_path "data/trucks-race.mp4" \ 6 | --output_path "truck/" \ 7 | --video_length 32 \ 8 | --width 512 \ 9 | --height 512 \ 10 | --frame_rate 1 \ 11 | --old_qk 1 12 | --------------------------------------------------------------------------------
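
Usage note (illustrative, not a repository file): the shell scripts such as truck.sh only drive the full pipeline through inference.py, but the helpers in models/util.py can also be exercised on their own. The sketch below mirrors the truck.sh arguments; it assumes the repository root is on PYTHONPATH, that data/trucks-race.mp4 is present, and that decord, torchvision, einops and imageio are installed. The device string and the output path outputs/source_check.gif are arbitrary choices, not repository defaults.

# Minimal sketch: probe read_video, sample_trajectories and save_videos_grid from models/util.py.
# Paths, the device choice and the output file name are assumptions for illustration only.
import torch
from einops import rearrange

from models.util import read_video, sample_trajectories, save_videos_grid

device = "cuda" if torch.cuda.is_available() else "cpu"
video_path = "data/trucks-race.mp4"  # same clip that truck.sh passes to inference.py

# Load 32 frames at 512x512, matching --video_length/--width/--height/--frame_rate in truck.sh.
# read_video returns a (frames, channels, height, width) tensor scaled to [-1, 1].
video = read_video(video_path, video_length=32, width=512, height=512, frame_rate=1)
print(video.shape)  # torch.Size([32, 3, 512, 512]) if the clip has at least 32 frames

# Estimate RAFT optical flow and build the per-resolution trajectory index and mask tensors.
# sample_trajectories walks every pixel in plain Python loops, so expect this step to be slow.
trajectories = sample_trajectories(video_path, device)
for res in (64, 32, 16, 8):
    print(res, trajectories["traj{}".format(res)].shape, trajectories["mask{}".format(res)].shape)

# Write the sampled frames back out as a GIF grid to sanity-check the loading step;
# save_videos_grid expects a (batch, channels, time, height, width) tensor.
grid = rearrange(video, "f c h w -> 1 c f h w")
save_videos_grid(grid, "outputs/source_check.gif", rescale=True, fps=8)

The traj/mask tensors returned here are presumably what inference.py threads through the **kwargs of the attention blocks in models/unet_blocks.py and models/attention.py; cat.sh, background.sh and truck.sh never call these helpers directly.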