├── .gitignore
├── LICENSE
├── README.md
├── background.sh
├── cat.sh
├── data
│   ├── background.mp4
│   ├── guidance10.gif
│   ├── guidance17.5.gif
│   ├── guidance20.gif
│   ├── puff.mp4
│   ├── source.gif
│   ├── tiger_empty.gif
│   ├── tiger_neg.gif
│   └── trucks-race.mp4
├── inference.py
├── models
│   ├── __init__.py
│   ├── attention.py
│   ├── pipeline_flatten.py
│   ├── resnet.py
│   ├── unet.py
│   ├── unet_blocks.py
│   └── util.py
└── truck.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FLATTEN: optical FLow-guided ATTENtion for consistent text-to-video editing
2 | [arXiv](https://arxiv.org/abs/2310.05922)
3 | [Project Page](https://flatten-video-editing.github.io/)
4 | [Hits](https://hits.seeyoufarm.com)
5 |
6 | **Pytorch Implementation of "FLATTEN: optical FLow-guided ATTENtion for consistent text-to-video editing".**
7 |
8 | 🎊🎊🎊 We are proud to announce that our paper has been accepted at **ICLR 2024**! If you are interested in FLATTEN, please give us a star😬
9 | 
10 |
11 | Thanks to @[**logtd**](https://github.com/logtd) for integrating FLATTEN into ComfyUI and for the great sample videos! **Here is the [Link](https://github.com/logtd/ComfyUI-FLATTEN?tab=readme-ov-file)!**
12 |
13 | https://github.com/yrcong/flatten/assets/47991543/1ad49092-9133-42d0-984f-38c6427bde34
14 |
15 |
16 | ## 📖Abstract
17 | 🚩**Text-to-Video** 🚩**Training-free** 🚩**Plug-and-Play**
18 |
19 | Text-to-video editing aims to edit the visual appearance of a source video conditioned on textual prompts. A major challenge in this task is ensuring that all frames in the edited video are visually consistent. In this work, for the first time, we introduce optical flow into the attention module of the diffusion model's U-Net to address the inconsistency issue in text-to-video editing. Our method, FLATTEN, enforces the patches on the same flow path across different frames to attend to each other in the attention module, thus improving the visual consistency of the edited videos. Additionally, our method is training-free and can be seamlessly integrated into any diffusion-based text-to-video editing method to improve its visual consistency.
20 |
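To make the mechanism concrete, here is a minimal, self-contained sketch of flow-guided attention on toy tensors. It is illustrative only, not the implementation in this repository (see `models/attention.py` for that); the trajectory format, the shapes, and the single-head attention are simplifying assumptions.

```python
import torch
import torch.nn.functional as F

def flow_guided_attention(q, k, v, traj):
    """Toy flow-guided attention.

    q, k, v: (frames, tokens, dim) per-frame patch features.
    traj:    (frames, tokens, frames) integer tensor; traj[f, n, j] is the token
             index in frame j lying on the same flow path as token n of frame f
             (a simplified stand-in for the real sampled trajectories).
    """
    n_frames, n_tokens, dim = q.shape
    # Gather the keys/values of the patches on the same trajectory in every frame.
    frame_idx = torch.arange(n_frames).view(1, 1, n_frames).expand(n_frames, n_tokens, n_frames)
    k_traj = k[frame_idx, traj]          # (frames, tokens, frames, dim)
    v_traj = v[frame_idx, traj]          # (frames, tokens, frames, dim)
    # Each query attends only to the patches on its own flow path.
    scores = torch.einsum("fnd,fnjd->fnj", q, k_traj) / dim ** 0.5
    attn = F.softmax(scores, dim=-1)
    return torch.einsum("fnj,fnjd->fnd", attn, v_traj)

frames, tokens, dim = 4, 16, 8
q, k, v = (torch.randn(frames, tokens, dim) for _ in range(3))
traj = torch.randint(0, tokens, (frames, tokens, frames))   # random toy trajectories
print(flow_guided_attention(q, k, v, traj).shape)           # torch.Size([4, 16, 8])
```
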
21 | ## Requirements
22 | First, download Stable Diffusion 2.1 **(base)** [here](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) and place it under `checkpoints/stable-diffusion-2-1-base` (the default `--sd_path` of `inference.py`).
23 |
24 | Then install the following packages (a quick version check is sketched right after this list):
25 | - PyTorch == 2.1
26 | - accelerate == 0.24.1
27 | - diffusers == 0.19.0
28 | - transformers == 4.35.0
29 | - xformers == 0.0.23
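
To confirm that your environment matches these pins, a small check could look like the sketch below (assuming a standard pip environment; note that the PyPI package name for PyTorch is `torch`):

```python
from importlib.metadata import PackageNotFoundError, version

# Pinned versions from the list above.
pinned = {
    "torch": "2.1",
    "accelerate": "0.24.1",
    "diffusers": "0.19.0",
    "transformers": "4.35.0",
    "xformers": "0.0.23",
}
for pkg, expected in pinned.items():
    try:
        print(f"{pkg}: installed {version(pkg)}, expected {expected}")
    except PackageNotFoundError:
        print(f"{pkg}: NOT INSTALLED, expected {expected}")
```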
30 |
31 | ## Usage
32 | For text-to-video editing, provide a source video and a textual prompt. To reproduce the teaser video, simply run:
33 | ```
34 | sh cat.sh
35 | ```
36 | or with the command:
37 | ```
38 | python inference.py \
39 | --prompt "A Tiger, high quality" \
40 | --neg_prompt "a cat with big eyes, deformed" \
41 | --guidance_scale 20 \
42 | --video_path "data/puff.mp4" \
43 | --output_path "outputs/" \
44 | --video_length 32 \
45 | --width 512 \
46 | --height 512 \
47 | --old_qk 0 \
48 | --frame_rate 2 \
49 | ```
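
If you prefer to drive the pipeline from Python rather than the shell scripts, the sketch below condenses what `inference.py` does for the command above. It assumes the Stable Diffusion 2.1 (base) checkpoint sits at the default `checkpoints/stable-diffusion-2-1-base`, that you run it from the repository root, and that a GPU is available; the output file name is arbitrary.

```python
import os

import torch
import torchvision
from diffusers import AutoencoderKL, DDIMInverseScheduler, DDIMScheduler
from einops import rearrange
from transformers import CLIPTextModel, CLIPTokenizer

from models.pipeline_flatten import FlattenPipeline
from models.unet import UNet3DConditionModel
from models.util import read_video, sample_trajectories, save_videos_grid

sd_path, device = "checkpoints/stable-diffusion-2-1-base", "cuda"

# Build the FLATTEN pipeline from the Stable Diffusion 2.1 (base) components.
tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").to(dtype=torch.float16)
vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae").to(dtype=torch.float16)
unet = UNet3DConditionModel.from_pretrained_2d(sd_path, subfolder="unet").to(dtype=torch.float16)
scheduler = DDIMScheduler.from_pretrained(sd_path, subfolder="scheduler")
inverse = DDIMInverseScheduler.from_pretrained(sd_path, subfolder="scheduler")
pipe = FlattenPipeline(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
                       unet=unet, scheduler=scheduler, inverse_scheduler=inverse)
pipe.enable_vae_slicing()
pipe.enable_xformers_memory_efficient_attention()
pipe.to(device)

# Load 32 frames at 512x512, save the resized source, and sample flow trajectories.
os.makedirs("outputs", exist_ok=True)
video = read_video(video_path="data/puff.mp4", video_length=32, width=512, height=512, frame_rate=2)
save_videos_grid(rearrange(video, "(b f) c h w -> b c f h w", b=1),
                 "outputs/source_video.mp4", rescale=True)
frames = [torchvision.transforms.ToPILImage()(((f + 1) / 2 * 255).to(torch.uint8)) for f in video]
trajectories = {k: v.to(device) for k, v in sample_trajectories("outputs/source_video.mp4", device).items()}

# Edit the video with the prompt / negative prompt used in cat.sh.
generator = torch.Generator(device=device).manual_seed(66)
result = pipe("A Tiger, high quality", video_length=32, frames=frames,
              num_inference_steps=50, generator=generator, guidance_scale=20,
              negative_prompt="a cat with big eyes, deformed", width=512, height=512,
              trajs=trajectories, output_dir="tmp/", inject_step=40, old_qk=0).videos
save_videos_grid(result, "outputs/tiger.mp4", fps=15)
```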
50 |
51 | ## Editing tricks
52 | - You can use a negative prompt (NP) when there is a large gap between the editing target and the source video (1st row).
53 | - You can increase the classifier-free guidance scale to enhance the semantic alignment with the prompt (2nd row).
54 |
55 | <table>
56 | <tr>
57 | <td><img src="data/source.gif"></td>
58 | <td><img src="data/tiger_empty.gif"></td>
59 | <td><img src="data/tiger_neg.gif"></td>
60 | </tr>
61 | <tr>
62 | <td>Source video</td>
63 | <td>NP: " "</td>
64 | <td>NP: "A cat with big eyes, deformed."</td>
65 | </tr>
66 | <tr>
67 | <td><img src="data/guidance10.gif"></td>
68 | <td><img src="data/guidance17.5.gif"></td>
69 | <td><img src="data/guidance20.gif"></td>
70 | </tr>
71 | <tr>
72 | <td>Classifier-free guidance: 10</td>
73 | <td>Classifier-free guidance: 17.5</td>
74 | <td>Classifier-free guidance: 25</td>
75 | </tr>
76 | </table>
77 |
78 |
79 | ## BibTex
80 | ```
81 | @article{cong2023flatten,
82 | title={FLATTEN: optical FLow-guided ATTENtion for consistent text-to-video editing},
83 | author={Cong, Yuren and Xu, Mengmeng and Simon, Christian and Chen, Shoufa and Ren, Jiawei and Xie, Yanping and Perez-Rua, Juan-Manuel and Rosenhahn, Bodo and Xiang, Tao and He, Sen},
84 | journal={arXiv preprint arXiv:2310.05922},
85 | year={2023}
86 | }
87 | ```
88 |
--------------------------------------------------------------------------------
/background.sh:
--------------------------------------------------------------------------------
1 | python inference.py \
2 | --prompt "pointillism painting, detailed" \
3 | --neg_prompt " " \
4 |     --guidance_scale 25 \
5 | --video_path "data/background.mp4" \
6 | --output_path "outputs/" \
7 | --video_length 32 \
8 | --width 512 \
9 | --height 512 \
10 | --old_qk 1 \
11 | --frame_rate 1 \
12 |
--------------------------------------------------------------------------------
/cat.sh:
--------------------------------------------------------------------------------
1 | python inference.py \
2 | --prompt "A Tiger, high quality" \
3 | --neg_prompt "a cat with big eyes, deformed" \
4 | --guidance_scale 20 \
5 | --video_path "data/puff.mp4" \
6 | --output_path "outputs/" \
7 | --video_length 32 \
8 | --width 512 \
9 | --height 512 \
10 | --old_qk 0 \
11 | --frame_rate 2 \
12 |
--------------------------------------------------------------------------------
/data/background.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/data/background.mp4
--------------------------------------------------------------------------------
/data/guidance10.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/data/guidance10.gif
--------------------------------------------------------------------------------
/data/guidance17.5.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/data/guidance17.5.gif
--------------------------------------------------------------------------------
/data/guidance20.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/data/guidance20.gif
--------------------------------------------------------------------------------
/data/puff.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/data/puff.mp4
--------------------------------------------------------------------------------
/data/source.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/data/source.gif
--------------------------------------------------------------------------------
/data/tiger_empty.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/data/tiger_empty.gif
--------------------------------------------------------------------------------
/data/tiger_neg.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/data/tiger_neg.gif
--------------------------------------------------------------------------------
/data/trucks-race.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/data/trucks-race.mp4
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | import torchvision
5 | from einops import rearrange
6 | from diffusers import DDIMScheduler, AutoencoderKL, DDIMInverseScheduler
7 | from transformers import CLIPTextModel, CLIPTokenizer
8 |
9 | from models.pipeline_flatten import FlattenPipeline
10 | from models.util import save_videos_grid, read_video, sample_trajectories
11 | from models.unet import UNet3DConditionModel
12 |
13 | def get_args():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--prompt", type=str, required=True, help="Textual prompt for video editing")
16 | parser.add_argument("--neg_prompt", type=str, required=True, help="Negative prompt for guidance")
17 | parser.add_argument("--guidance_scale", default=10.0, type=float, help="Guidance scale")
18 | parser.add_argument("--video_path", type=str, required=True, help="Path to a source video")
19 | parser.add_argument("--sd_path", type=str, default="checkpoints/stable-diffusion-2-1-base", help="Path of Stable Diffusion")
20 | parser.add_argument("--output_path", type=str, default="./outputs", help="Directory of output")
21 | parser.add_argument("--video_length", type=int, default=15, help="Length of output video")
22 | parser.add_argument("--old_qk", type=int, default=0, help="Whether to use old queries and keys for flow-guided attention")
23 | parser.add_argument("--height", type=int, default=512, help="Height of synthesized video, and should be a multiple of 32")
24 | parser.add_argument("--width", type=int, default=512, help="Width of synthesized video, and should be a multiple of 32")
25 |     parser.add_argument("--sample_steps", type=int, default=50, help="Number of denoising (sampling) steps")
26 | parser.add_argument("--inject_step", type=int, default=40, help="Steps for feature injection")
27 | parser.add_argument("--seed", type=int, default=66, help="Random seed of generator")
28 |     parser.add_argument("--frame_rate", type=int, default=None, help="Frame sampling rate when loading the input video. By default it is computed from the video length.")
29 | parser.add_argument("--fps", type=int, default=15, help="FPS of the output video")
30 | args = parser.parse_args()
31 | return args
32 |
33 | if __name__ == "__main__":
34 |
35 | args = get_args()
36 | os.makedirs(args.output_path, exist_ok=True)
37 | device = "cuda"
38 |     # Round height and width down to multiples of 32
39 | args.height = (args.height // 32) * 32
40 | args.width = (args.width // 32) * 32
41 |
42 | tokenizer = CLIPTokenizer.from_pretrained(args.sd_path, subfolder="tokenizer")
43 | text_encoder = CLIPTextModel.from_pretrained(args.sd_path, subfolder="text_encoder").to(dtype=torch.float16)
44 | vae = AutoencoderKL.from_pretrained(args.sd_path, subfolder="vae").to(dtype=torch.float16)
45 | unet = UNet3DConditionModel.from_pretrained_2d(args.sd_path, subfolder="unet").to(dtype=torch.float16)
46 |     scheduler = DDIMScheduler.from_pretrained(args.sd_path, subfolder="scheduler")
47 |     inverse = DDIMInverseScheduler.from_pretrained(args.sd_path, subfolder="scheduler")
48 |
49 | pipe = FlattenPipeline(
50 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
51 | scheduler=scheduler, inverse_scheduler=inverse)
52 | pipe.enable_vae_slicing()
53 | pipe.enable_xformers_memory_efficient_attention()
54 | pipe.to(device)
55 |
56 | generator = torch.Generator(device=device)
57 | generator.manual_seed(args.seed)
58 |
59 | # read the source video
60 | video = read_video(video_path=args.video_path, video_length=args.video_length,
61 | width=args.width, height=args.height, frame_rate=args.frame_rate)
62 | original_pixels = rearrange(video, "(b f) c h w -> b c f h w", b=1)
63 | save_videos_grid(original_pixels, os.path.join(args.output_path, "source_video.mp4"), rescale=True)
64 |
65 | t2i_transform = torchvision.transforms.ToPILImage()
66 | real_frames = []
67 | for i, frame in enumerate(video):
68 | real_frames.append(t2i_transform(((frame+1)/2*255).to(torch.uint8)))
69 |
70 | # compute optical flows and sample trajectories
71 | trajectories = sample_trajectories(os.path.join(args.output_path, "source_video.mp4"), device)
72 | torch.cuda.empty_cache()
73 |
74 | for k in trajectories.keys():
75 | trajectories[k] = trajectories[k].to(device)
76 | sample = pipe(args.prompt, video_length=args.video_length, frames=real_frames,
77 | num_inference_steps=args.sample_steps, generator=generator, guidance_scale=args.guidance_scale,
78 | negative_prompt=args.neg_prompt, width=args.width, height=args.height,
79 | trajs=trajectories, output_dir="tmp/", inject_step=args.inject_step, old_qk=args.old_qk).videos
80 | temp_video_name = args.prompt+"_"+args.neg_prompt+"_"+str(args.guidance_scale)
81 | save_videos_grid(sample, f"{args.output_path}/{temp_video_name}.mp4", fps=args.fps)
82 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yrcong/flatten/ac2d83bf10363c670c911b6857614b9b41ef6eb4/models/__init__.py
--------------------------------------------------------------------------------
/models/attention.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2 |
3 | from dataclasses import dataclass
4 | from typing import Optional, Callable
5 | import math
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import nn
9 | from diffusers.configuration_utils import ConfigMixin, register_to_config
10 | from diffusers import ModelMixin
11 | from diffusers.utils import BaseOutput
12 | from diffusers.utils.import_utils import is_xformers_available
13 | from diffusers.models.attention import FeedForward, AdaLayerNorm
14 | from diffusers.models.cross_attention import CrossAttention
15 | from einops import rearrange, repeat
16 |
17 | @dataclass
18 | class Transformer3DModelOutput(BaseOutput):
19 | sample: torch.FloatTensor
20 |
21 |
22 | if is_xformers_available():
23 | import xformers
24 | import xformers.ops
25 | else:
26 | xformers = None
27 |
28 |
29 | class Transformer3DModel(ModelMixin, ConfigMixin):
30 | @register_to_config
31 | def __init__(
32 | self,
33 | num_attention_heads: int = 16,
34 | attention_head_dim: int = 88,
35 | in_channels: Optional[int] = None,
36 | num_layers: int = 1,
37 | dropout: float = 0.0,
38 | norm_num_groups: int = 32,
39 | cross_attention_dim: Optional[int] = None,
40 | attention_bias: bool = False,
41 | activation_fn: str = "geglu",
42 | num_embeds_ada_norm: Optional[int] = None,
43 | use_linear_projection: bool = False,
44 | only_cross_attention: bool = False,
45 | upcast_attention: bool = False,
46 | ):
47 | super().__init__()
48 | self.use_linear_projection = use_linear_projection
49 | self.num_attention_heads = num_attention_heads
50 | self.attention_head_dim = attention_head_dim
51 | inner_dim = num_attention_heads * attention_head_dim
52 |
53 | # Define input layers
54 | self.in_channels = in_channels
55 |
56 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
57 | if use_linear_projection:
58 | self.proj_in = nn.Linear(in_channels, inner_dim)
59 | else:
60 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
61 |
62 |         # Define transformer blocks
63 | self.transformer_blocks = nn.ModuleList(
64 | [
65 | BasicTransformerBlock(
66 | inner_dim,
67 | num_attention_heads,
68 | attention_head_dim,
69 | dropout=dropout,
70 | cross_attention_dim=cross_attention_dim,
71 | activation_fn=activation_fn,
72 | num_embeds_ada_norm=num_embeds_ada_norm,
73 | attention_bias=attention_bias,
74 | only_cross_attention=only_cross_attention,
75 | upcast_attention=upcast_attention,
76 | )
77 | for d in range(num_layers)
78 | ]
79 | )
80 |
81 |         # Define output layers
82 | if use_linear_projection:
83 | self.proj_out = nn.Linear(in_channels, inner_dim)
84 | else:
85 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
86 |
87 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True, \
88 | inter_frame=False, **kwargs):
89 | # Input
90 |
91 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
92 | video_length = hidden_states.shape[2]
93 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
94 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
95 |
96 | batch, channel, height, weight = hidden_states.shape
97 | residual = hidden_states
98 |
99 | # check resolution
100 | resolu = hidden_states.shape[-1]
101 |         trajs = {}  # pack the flow trajectories/masks pre-computed for this resolution, plus timestep and the old_qk flag
102 | trajs["traj"] = kwargs["trajs"]["traj{}".format(resolu)]
103 | trajs["mask"] = kwargs["trajs"]["mask{}".format(resolu)]
104 | trajs["t"] = kwargs["t"]
105 | trajs["old_qk"] = kwargs["old_qk"]
106 |
107 | hidden_states = self.norm(hidden_states)
108 | if not self.use_linear_projection:
109 | hidden_states = self.proj_in(hidden_states)
110 | inner_dim = hidden_states.shape[1]
111 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
112 | else:
113 | inner_dim = hidden_states.shape[1]
114 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
115 | hidden_states = self.proj_in(hidden_states)
116 |
117 | # Blocks
118 | for block in self.transformer_blocks:
119 | hidden_states = block(
120 | hidden_states,
121 | encoder_hidden_states=encoder_hidden_states,
122 | timestep=timestep,
123 | video_length=video_length,
124 | inter_frame=inter_frame,
125 | **trajs
126 | )
127 |
128 | # Output
129 | if not self.use_linear_projection:
130 | hidden_states = (
131 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
132 | )
133 | hidden_states = self.proj_out(hidden_states)
134 | else:
135 | hidden_states = self.proj_out(hidden_states)
136 | hidden_states = (
137 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
138 | )
139 |
140 | output = hidden_states + residual
141 |
142 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
143 | if not return_dict:
144 | return (output,)
145 |
146 | return Transformer3DModelOutput(sample=output)
147 |
148 |
149 | class BasicTransformerBlock(nn.Module):
150 | def __init__(
151 | self,
152 | dim: int,
153 | num_attention_heads: int,
154 | attention_head_dim: int,
155 | dropout=0.0,
156 | cross_attention_dim: Optional[int] = None,
157 | activation_fn: str = "geglu",
158 | num_embeds_ada_norm: Optional[int] = None,
159 | attention_bias: bool = False,
160 | only_cross_attention: bool = False,
161 | upcast_attention: bool = False,
162 | ):
163 | super().__init__()
164 | self.only_cross_attention = only_cross_attention
165 | self.use_ada_layer_norm = num_embeds_ada_norm is not None
166 |
167 |         # Fully-frame attention (flow-guided)
168 | self.attn1 = FullyFrameAttention(
169 | query_dim=dim,
170 | heads=num_attention_heads,
171 | dim_head=attention_head_dim,
172 | dropout=dropout,
173 | bias=attention_bias,
174 | cross_attention_dim=cross_attention_dim if only_cross_attention else None,
175 | upcast_attention=upcast_attention,
176 | )
177 |
178 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
179 |
180 | # Cross-Attn
181 | if cross_attention_dim is not None:
182 | self.attn2 = CrossAttention(
183 | query_dim=dim,
184 | cross_attention_dim=cross_attention_dim,
185 | heads=num_attention_heads,
186 | dim_head=attention_head_dim,
187 | dropout=dropout,
188 | bias=attention_bias,
189 | upcast_attention=upcast_attention,
190 | )
191 | else:
192 | self.attn2 = None
193 |
194 | if cross_attention_dim is not None:
195 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
196 | else:
197 | self.norm2 = None
198 |
199 | # Feed-forward
200 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
201 | self.norm3 = nn.LayerNorm(dim)
202 |
203 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None):
204 | if not is_xformers_available():
205 |             # xformers is missing; the error below points to installation instructions
206 | raise ModuleNotFoundError(
207 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
208 | " xformers",
209 | name="xformers",
210 | )
211 | elif not torch.cuda.is_available():
212 | raise ValueError(
213 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
214 | " available for GPU "
215 | )
216 | else:
217 | try:
218 | # Make sure we can run the memory efficient attention
219 | _ = xformers.ops.memory_efficient_attention(
220 | torch.randn((1, 2, 40), device="cuda"),
221 | torch.randn((1, 2, 40), device="cuda"),
222 | torch.randn((1, 2, 40), device="cuda"),
223 | )
224 | except Exception as e:
225 | raise e
226 | self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
227 | if self.attn2 is not None:
228 | self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
229 |
230 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None, \
231 | inter_frame=False, **kwargs):
232 |         # Fully-frame self-attention (flow-guided)
233 | norm_hidden_states = (
234 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
235 | )
236 |
237 | if self.only_cross_attention:
238 | hidden_states = (
239 | self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask, inter_frame=inter_frame, **kwargs) + hidden_states
240 | )
241 | else:
242 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length, inter_frame=inter_frame, **kwargs) + hidden_states
243 |
244 | if self.attn2 is not None:
245 | # Cross-Attention
246 | norm_hidden_states = (
247 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
248 | )
249 | hidden_states = (
250 | self.attn2(
251 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
252 | )
253 | + hidden_states
254 | )
255 |
256 | # Feed-forward
257 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
258 |
259 | return hidden_states
260 |
261 | class FullyFrameAttention(nn.Module):
262 | r"""
263 |     A full-frame attention layer: patches attend to each other across all frames of a video, guided by optical-flow trajectories.
264 |
265 | Parameters:
266 | query_dim (`int`): The number of channels in the query.
267 | cross_attention_dim (`int`, *optional*):
268 | The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
269 | heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
270 | dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
271 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
272 | bias (`bool`, *optional*, defaults to False):
273 | Set to `True` for the query, key, and value linear layers to contain a bias parameter.
274 | """
275 |
276 | def __init__(
277 | self,
278 | query_dim: int,
279 | cross_attention_dim: Optional[int] = None,
280 | heads: int = 8,
281 | dim_head: int = 64,
282 | dropout: float = 0.0,
283 | bias=False,
284 | upcast_attention: bool = False,
285 | upcast_softmax: bool = False,
286 | added_kv_proj_dim: Optional[int] = None,
287 | norm_num_groups: Optional[int] = None,
288 | ):
289 | super().__init__()
290 | inner_dim = dim_head * heads
291 | cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
292 | self.upcast_attention = upcast_attention
293 | self.upcast_softmax = upcast_softmax
294 |
295 | self.scale = dim_head**-0.5
296 |
297 | self.heads = heads
298 | # for slice_size > 0 the attention score computation
299 | # is split across the batch axis to save memory
300 | # You can set slice_size with `set_attention_slice`
301 | self.sliceable_head_dim = heads
302 | self._slice_size = None
303 | self._use_memory_efficient_attention_xformers = False
304 | self.added_kv_proj_dim = added_kv_proj_dim
305 |
306 | if norm_num_groups is not None:
307 | self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
308 | else:
309 | self.group_norm = None
310 |
311 | self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
312 | self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
313 | self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
314 |
315 | if self.added_kv_proj_dim is not None:
316 | self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
317 | self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
318 |
319 | self.to_out = nn.ModuleList([])
320 | self.to_out.append(nn.Linear(inner_dim, query_dim))
321 | self.to_out.append(nn.Dropout(dropout))
322 |
323 | self.q = None
324 | self.inject_q = None
325 | self.k = None
326 | self.inject_k = None
327 |
328 |
329 | def reshape_heads_to_batch_dim(self, tensor):
330 | batch_size, seq_len, dim = tensor.shape
331 | head_size = self.heads
332 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
333 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
334 | return tensor
335 |
336 | def reshape_heads_to_batch_dim2(self, tensor):
337 | batch_size, seq_len, dim = tensor.shape
338 | head_size = self.heads
339 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
340 | tensor = tensor.permute(0, 2, 1, 3)
341 | return tensor
342 |
343 | def reshape_heads_to_batch_dim3(self, tensor):
344 | batch_size1, batch_size2, seq_len, dim = tensor.shape
345 | head_size = self.heads
346 | tensor = tensor.reshape(batch_size1, batch_size2, seq_len, head_size, dim // head_size)
347 | tensor = tensor.permute(0, 3, 1, 2, 4)
348 | return tensor
349 |
350 | def reshape_batch_dim_to_heads(self, tensor):
351 | batch_size, seq_len, dim = tensor.shape
352 | head_size = self.heads
353 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
354 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
355 | return tensor
356 |
357 | def set_attention_slice(self, slice_size):
358 | if slice_size is not None and slice_size > self.sliceable_head_dim:
359 | raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
360 |
361 | self._slice_size = slice_size
362 |
363 | def _attention(self, query, key, value, attention_mask=None):
364 | if self.upcast_attention:
365 | query = query.float()
366 | key = key.float()
367 |
368 | attention_scores = torch.baddbmm(
369 | torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
370 | query,
371 | key.transpose(-1, -2),
372 | beta=0,
373 | alpha=self.scale,
374 | )
375 | if attention_mask is not None:
376 | attention_scores = attention_scores + attention_mask
377 |
378 | if self.upcast_softmax:
379 | attention_scores = attention_scores.float()
380 |
381 | attention_probs = attention_scores.softmax(dim=-1)
382 |
383 | # cast back to the original dtype
384 | attention_probs = attention_probs.to(value.dtype)
385 |
386 | # compute attention output
387 | hidden_states = torch.bmm(attention_probs, value)
388 |
389 | # reshape hidden_states
390 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
391 | return hidden_states
392 |
393 | def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
394 | batch_size_attention = query.shape[0]
395 | hidden_states = torch.zeros(
396 | (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
397 | )
398 | slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
399 | for i in range(hidden_states.shape[0] // slice_size):
400 | start_idx = i * slice_size
401 | end_idx = (i + 1) * slice_size
402 |
403 | query_slice = query[start_idx:end_idx]
404 | key_slice = key[start_idx:end_idx]
405 |
406 | if self.upcast_attention:
407 | query_slice = query_slice.float()
408 | key_slice = key_slice.float()
409 |
410 | attn_slice = torch.baddbmm(
411 | torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
412 | query_slice,
413 | key_slice.transpose(-1, -2),
414 | beta=0,
415 | alpha=self.scale,
416 | )
417 |
418 | if attention_mask is not None:
419 | attn_slice = attn_slice + attention_mask[start_idx:end_idx]
420 |
421 | if self.upcast_softmax:
422 | attn_slice = attn_slice.float()
423 |
424 | attn_slice = attn_slice.softmax(dim=-1)
425 |
426 | # cast back to the original dtype
427 | attn_slice = attn_slice.to(value.dtype)
428 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
429 |
430 | hidden_states[start_idx:end_idx] = attn_slice
431 |
432 | # reshape hidden_states
433 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
434 | return hidden_states
435 |
436 | def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
437 | # TODO attention_mask
438 | query = query.contiguous()
439 | key = key.contiguous()
440 | value = value.contiguous()
441 | hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
442 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
443 | return hidden_states
444 |
445 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, inter_frame=False, **kwargs):
446 | batch_size, sequence_length, _ = hidden_states.shape
447 |
448 | encoder_hidden_states = encoder_hidden_states
449 |
450 | h = w = int(math.sqrt(sequence_length))
451 | if self.group_norm is not None:
452 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
453 |
454 | query = self.to_q(hidden_states) # (bf) x d(hw) x c
455 | self.q = query
456 | if self.inject_q is not None:
457 | query = self.inject_q
458 | dim = query.shape[-1]
459 | query_old = query.clone()
460 |
461 | # All frames
462 | query = rearrange(query, "(b f) d c -> b (f d) c", f=video_length)
463 |
464 | query = self.reshape_heads_to_batch_dim(query)
465 |
466 | if self.added_kv_proj_dim is not None:
467 | raise NotImplementedError
468 |
469 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
470 | key = self.to_k(encoder_hidden_states)
471 | self.k = key
472 | if self.inject_k is not None:
473 | key = self.inject_k
474 | key_old = key.clone()
475 | value = self.to_v(encoder_hidden_states)
476 |
477 | if inter_frame:
478 | key = rearrange(key, "(b f) d c -> b f d c", f=video_length)[:, [0, -1]]
479 | value = rearrange(value, "(b f) d c -> b f d c", f=video_length)[:, [0, -1]]
480 | key = rearrange(key, "b f d c -> b (f d) c",)
481 | value = rearrange(value, "b f d c -> b (f d) c")
482 | else:
483 | # All frames
484 | key = rearrange(key, "(b f) d c -> b (f d) c", f=video_length)
485 | value = rearrange(value, "(b f) d c -> b (f d) c", f=video_length)
486 |
487 | key = self.reshape_heads_to_batch_dim(key)
488 | value = self.reshape_heads_to_batch_dim(value)
489 |
490 | if attention_mask is not None:
491 | if attention_mask.shape[-1] != query.shape[1]:
492 | target_length = query.shape[1]
493 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
494 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
495 |
496 | # attention, what we cannot get enough of
497 | if self._use_memory_efficient_attention_xformers:
498 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
499 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input
500 | hidden_states = hidden_states.to(query.dtype)
501 | else:
502 | if self._slice_size is None or query.shape[0] // self._slice_size == 1:
503 | hidden_states = self._attention(query, key, value, attention_mask)
504 | else:
505 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
506 |
507 | if h in [64]:
508 | hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=video_length)
509 | if self.group_norm is not None:
510 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
511 |
512 | if kwargs["old_qk"] == 1:
513 | query = query_old
514 | key = key_old
515 | else:
516 | query = hidden_states
517 | key = hidden_states
518 | value = hidden_states
519 |
520 |         traj = kwargs["traj"]  # pre-computed flow trajectories for every query location
521 |         traj = rearrange(traj, '(f n) l d -> f n l d', f=video_length, n=sequence_length)
522 |         mask = rearrange(kwargs["mask"], '(f n) l -> f n l', f=video_length, n=sequence_length)
523 |         mask = torch.cat([mask[:, :, 0].unsqueeze(-1), mask[:, :, -video_length+1:]], dim=-1)  # keep the first entry plus the last (video_length-1) entries
524 |
525 |         traj_key_sequence_inds = torch.cat([traj[:, :, 0, :].unsqueeze(-2), traj[:, :, -video_length+1:, :]], dim=-2)
526 |         t_inds = traj_key_sequence_inds[:, :, :, 0]  # frame index of each trajectory point
527 |         x_inds = traj_key_sequence_inds[:, :, :, 1]  # row index
528 |         y_inds = traj_key_sequence_inds[:, :, :, 2]  # column index
529 |
530 | query_tempo = query.unsqueeze(-2)
531 | _key = rearrange(key, '(b f) (h w) d -> b f h w d', b=int(batch_size/video_length), f=video_length, h=h, w=w)
532 | _value = rearrange(value, '(b f) (h w) d -> b f h w d', b=int(batch_size/video_length), f=video_length, h=h, w=w)
533 | key_tempo = _key[:, t_inds, x_inds, y_inds]
534 | value_tempo = _value[:, t_inds, x_inds, y_inds]
535 | key_tempo = rearrange(key_tempo, 'b f n l d -> (b f) n l d')
536 | value_tempo = rearrange(value_tempo, 'b f n l d -> (b f) n l d')
537 |
538 | mask = rearrange(torch.stack([mask, mask]), 'b f n l -> (b f) n l')
539 | mask = mask[:,None].repeat(1, self.heads, 1, 1).unsqueeze(-2)
540 |         attn_bias = torch.zeros_like(mask, dtype=key_tempo.dtype)  # additive attention bias: 0 where the trajectory entry is valid
541 |         attn_bias[~mask] = -torch.inf  # block attention to masked-out trajectory entries
542 |
543 | # flow attention
544 | query_tempo = self.reshape_heads_to_batch_dim3(query_tempo)
545 | key_tempo = self.reshape_heads_to_batch_dim3(key_tempo)
546 | value_tempo = self.reshape_heads_to_batch_dim3(value_tempo)
547 |
548 | attn_matrix2 = query_tempo @ key_tempo.transpose(-2, -1) / math.sqrt(query_tempo.size(-1)) + attn_bias
549 | attn_matrix2 = F.softmax(attn_matrix2, dim=-1)
550 | out = (attn_matrix2@value_tempo).squeeze(-2)
551 |
552 | hidden_states = rearrange(out,'(b f) k (h w) d -> b (f h w) (k d)', b=int(batch_size/video_length), f=video_length, h=h, w=w)
553 |
554 | # linear proj
555 | hidden_states = self.to_out[0](hidden_states)
556 |
557 | # dropout
558 | hidden_states = self.to_out[1](hidden_states)
559 |
560 | # All frames
561 | hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=video_length)
562 |
563 | return hidden_states
564 |
--------------------------------------------------------------------------------
/models/pipeline_flatten.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import inspect
17 | import os
18 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19 | from dataclasses import dataclass
20 |
21 | import numpy as np
22 | import PIL.Image
23 | import torch
24 | from transformers import CLIPTextModel, CLIPTokenizer
25 |
26 | from diffusers.models import AutoencoderKL
27 | from diffusers import ModelMixin
28 | from diffusers.schedulers import DDIMScheduler, DDIMInverseScheduler
29 | from diffusers.utils import (
30 | PIL_INTERPOLATION,
31 | is_accelerate_available,
32 | is_accelerate_version,
33 | logging,
34 | randn_tensor,
35 | BaseOutput
36 | )
37 | from diffusers.pipeline_utils import DiffusionPipeline
38 | from einops import rearrange
39 | from .unet import UNet3DConditionModel
40 |
41 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42 |
43 |
44 | @dataclass
45 | class FlattenPipelineOutput(BaseOutput):
46 | videos: Union[torch.Tensor, np.ndarray]
47 |
48 | class FlattenPipeline(DiffusionPipeline):
49 | r"""
50 | pipeline for FLATTEN: optical FLow-guided ATTENtion for consistent text-to-video editing.
51 |
52 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
53 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
54 |
55 | Args:
56 | vae ([`AutoencoderKL`]):
57 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
58 | text_encoder ([`CLIPTextModel`]):
59 | Frozen text-encoder. Stable Diffusion uses the text portion of
60 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
61 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
62 | tokenizer (`CLIPTokenizer`):
63 | Tokenizer of class
64 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
65 | unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents.
66 | scheduler ([`SchedulerMixin`]):
67 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
68 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
69 | inverse_scheduler ([`SchedulerMixin`]):
70 |             DDIM inversion scheduler.
71 | """
72 | _optional_components = ["safety_checker", "feature_extractor"]
73 |
74 | def __init__(
75 | self,
76 | vae: AutoencoderKL,
77 | text_encoder: CLIPTextModel,
78 | tokenizer: CLIPTokenizer,
79 | unet: UNet3DConditionModel,
80 | scheduler: DDIMScheduler,
81 | inverse_scheduler: DDIMInverseScheduler
82 | ):
83 | super().__init__()
84 |
85 | self.register_modules(
86 | vae=vae,
87 | text_encoder=text_encoder,
88 | tokenizer=tokenizer,
89 | unet=unet,
90 | scheduler=scheduler,
91 | inverse_scheduler=inverse_scheduler
92 | )
93 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
94 |
95 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
96 | def enable_vae_slicing(self):
97 | r"""
98 | Enable sliced VAE decoding.
99 |
100 | When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
101 | steps. This is useful to save some memory and allow larger batch sizes.
102 | """
103 | self.vae.enable_slicing()
104 |
105 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
106 | def disable_vae_slicing(self):
107 | r"""
108 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
109 | computing decoding in one step.
110 | """
111 | self.vae.disable_slicing()
112 |
113 | def enable_sequential_cpu_offload(self, gpu_id=0):
114 | r"""
115 | Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
116 | text_encoder, vae, and safety checker have their state dicts saved to CPU and then are moved to a
117 |         `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
118 | Note that offloading happens on a submodule basis. Memory savings are higher than with
119 | `enable_model_cpu_offload`, but performance is lower.
120 | """
121 | if is_accelerate_available():
122 | from accelerate import cpu_offload
123 | else:
124 | raise ImportError("Please install accelerate via `pip install accelerate`")
125 |
126 | device = torch.device(f"cuda:{gpu_id}")
127 |
128 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
129 | cpu_offload(cpu_offloaded_model, device)
130 |
131 | if self.safety_checker is not None:
132 | cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
133 |
134 | def enable_model_cpu_offload(self, gpu_id=0):
135 | r"""
136 | Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
137 | to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
138 | method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
139 | `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
140 | """
141 | if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
142 | from accelerate import cpu_offload_with_hook
143 | else:
144 | raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
145 |
146 | device = torch.device(f"cuda:{gpu_id}")
147 |
148 | hook = None
149 | for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
150 | _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
151 |
152 |         if getattr(self, "safety_checker", None) is not None:
153 | # the safety checker can offload the vae again
154 | _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
155 |
156 | # We'll offload the last model manually.
157 | self.final_offload_hook = hook
158 |
159 | @property
160 | def _execution_device(self):
161 | r"""
162 | Returns the device on which the pipeline's models will be executed. After calling
163 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
164 | hooks.
165 | """
166 | if not hasattr(self.unet, "_hf_hook"):
167 | return self.device
168 | for module in self.unet.modules():
169 | if (
170 | hasattr(module, "_hf_hook")
171 | and hasattr(module._hf_hook, "execution_device")
172 | and module._hf_hook.execution_device is not None
173 | ):
174 | return torch.device(module._hf_hook.execution_device)
175 | return self.device
176 |
177 | def _encode_prompt(
178 | self,
179 | prompt,
180 | device,
181 | num_videos_per_prompt,
182 | do_classifier_free_guidance,
183 | negative_prompt=None,
184 | prompt_embeds: Optional[torch.FloatTensor] = None,
185 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
186 | ):
187 | r"""
188 | Encodes the prompt into text encoder hidden states.
189 |
190 | Args:
191 | prompt (`str` or `List[str]`, *optional*):
192 | prompt to be encoded
193 | device: (`torch.device`):
194 | torch device
195 | num_videos_per_prompt (`int`):
196 |                 number of videos that should be generated per prompt
197 | do_classifier_free_guidance (`bool`):
198 | whether to use classifier free guidance or not
199 | negative_prompt (`str` or `List[str]`, *optional*):
200 |                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
201 |                 `negative_prompt_embeds` instead.
202 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
203 | prompt_embeds (`torch.FloatTensor`, *optional*):
204 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
205 | provided, text embeddings will be generated from `prompt` input argument.
206 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
207 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
208 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
209 | argument.
210 | """
211 | if prompt is not None and isinstance(prompt, str):
212 | batch_size = 1
213 | elif prompt is not None and isinstance(prompt, list):
214 | batch_size = len(prompt)
215 | else:
216 | batch_size = prompt_embeds.shape[0]
217 |
218 | if prompt_embeds is None:
219 | text_inputs = self.tokenizer(
220 | prompt,
221 | padding="max_length",
222 | max_length=self.tokenizer.model_max_length,
223 | truncation=True,
224 | return_tensors="pt",
225 | )
226 | text_input_ids = text_inputs.input_ids
227 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
228 |
229 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
230 | text_input_ids, untruncated_ids
231 | ):
232 | removed_text = self.tokenizer.batch_decode(
233 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
234 | )
235 | logger.warning(
236 | "The following part of your input was truncated because CLIP can only handle sequences up to"
237 | f" {self.tokenizer.model_max_length} tokens: {removed_text}"
238 | )
239 |
240 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
241 | attention_mask = text_inputs.attention_mask.to(device)
242 | else:
243 | attention_mask = None
244 |
245 | prompt_embeds = self.text_encoder(
246 | text_input_ids.to(device),
247 | attention_mask=attention_mask,
248 | )
249 | prompt_embeds = prompt_embeds[0]
250 |
251 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
252 |
253 | bs_embed, seq_len, _ = prompt_embeds.shape
254 | # duplicate text embeddings for each generation per prompt, using mps friendly method
255 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
256 | prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
257 |
258 | # get unconditional embeddings for classifier free guidance
259 | if do_classifier_free_guidance and negative_prompt_embeds is None:
260 | uncond_tokens: List[str]
261 | if negative_prompt is None:
262 | uncond_tokens = [""] * batch_size
263 | elif type(prompt) is not type(negative_prompt):
264 | raise TypeError(
265 |                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
266 | f" {type(prompt)}."
267 | )
268 | elif isinstance(negative_prompt, str):
269 | uncond_tokens = [negative_prompt]
270 | elif batch_size != len(negative_prompt):
271 | raise ValueError(
272 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
273 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
274 | " the batch size of `prompt`."
275 | )
276 | else:
277 | uncond_tokens = negative_prompt
278 |
279 | max_length = prompt_embeds.shape[1]
280 | uncond_input = self.tokenizer(
281 | uncond_tokens,
282 | padding="max_length",
283 | max_length=max_length,
284 | truncation=True,
285 | return_tensors="pt",
286 | )
287 |
288 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
289 | attention_mask = uncond_input.attention_mask.to(device)
290 | else:
291 | attention_mask = None
292 |
293 | negative_prompt_embeds = self.text_encoder(
294 | uncond_input.input_ids.to(device),
295 | attention_mask=attention_mask,
296 | )
297 | negative_prompt_embeds = negative_prompt_embeds[0]
298 |
299 | if do_classifier_free_guidance:
300 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
301 | seq_len = negative_prompt_embeds.shape[1]
302 |
303 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
304 |
305 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
306 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
307 |
308 | # For classifier free guidance, we need to do two forward passes.
309 | # Here we concatenate the unconditional and text embeddings into a single batch
310 | # to avoid doing two forward passes
311 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
312 |
313 | return prompt_embeds
314 |
315 | def decode_latents(self, latents, return_tensor=False):
316 | video_length = latents.shape[2]
317 | latents = 1 / 0.18215 * latents
318 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
319 | video = self.vae.decode(latents).sample
320 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
321 | video = (video / 2 + 0.5).clamp(0, 1)
322 | if return_tensor:
323 | return video
324 |         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
325 | video = video.cpu().float().numpy()
326 | return video
327 |
328 | def prepare_extra_step_kwargs(self, generator, eta):
329 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
330 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
331 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
332 | # and should be between [0, 1]
333 |
334 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
335 | extra_step_kwargs = {}
336 | if accepts_eta:
337 | extra_step_kwargs["eta"] = eta
338 |
339 | # check if the scheduler accepts generator
340 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
341 | if accepts_generator:
342 | extra_step_kwargs["generator"] = generator
343 | return extra_step_kwargs
344 |
345 | def check_inputs(
346 | self,
347 | prompt,
348 | # image,
349 | height,
350 | width,
351 | callback_steps,
352 | negative_prompt=None,
353 | prompt_embeds=None,
354 | negative_prompt_embeds=None,
355 | ):
356 | if height % 8 != 0 or width % 8 != 0:
357 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
358 |
359 | if (callback_steps is None) or (
360 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
361 | ):
362 | raise ValueError(
363 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
364 | f" {type(callback_steps)}."
365 | )
366 |
367 | if prompt is not None and prompt_embeds is not None:
368 | raise ValueError(
369 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
370 | " only forward one of the two."
371 | )
372 | elif prompt is None and prompt_embeds is None:
373 | raise ValueError(
374 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
375 | )
376 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
377 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
378 |
379 | if negative_prompt is not None and negative_prompt_embeds is not None:
380 | raise ValueError(
381 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
382 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
383 | )
384 |
385 | if prompt_embeds is not None and negative_prompt_embeds is not None:
386 | if prompt_embeds.shape != negative_prompt_embeds.shape:
387 | raise ValueError(
388 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
389 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
390 | f" {negative_prompt_embeds.shape}."
391 | )
392 |
393 |
394 | def check_image(self, image, prompt, prompt_embeds):
395 | image_is_pil = isinstance(image, PIL.Image.Image)
396 | image_is_tensor = isinstance(image, torch.Tensor)
397 | image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
398 | image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
399 |
400 | if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
401 | raise TypeError(
402 | "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
403 | )
404 |
405 | if image_is_pil:
406 | image_batch_size = 1
407 | elif image_is_tensor:
408 | image_batch_size = image.shape[0]
409 | elif image_is_pil_list:
410 | image_batch_size = len(image)
411 | elif image_is_tensor_list:
412 | image_batch_size = len(image)
413 |
414 | if prompt is not None and isinstance(prompt, str):
415 | prompt_batch_size = 1
416 | elif prompt is not None and isinstance(prompt, list):
417 | prompt_batch_size = len(prompt)
418 | elif prompt_embeds is not None:
419 | prompt_batch_size = prompt_embeds.shape[0]
420 |
421 | if image_batch_size != 1 and image_batch_size != prompt_batch_size:
422 | raise ValueError(
423 | f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
424 | )
425 |
426 | def prepare_image(
427 | self, image, width, height, batch_size, num_videos_per_prompt, device, dtype, do_classifier_free_guidance
428 | ):
429 | if not isinstance(image, torch.Tensor):
430 | if isinstance(image, PIL.Image.Image):
431 | image = [image]
432 |
433 | if isinstance(image[0], PIL.Image.Image):
434 | images = []
435 |
436 | for image_ in image:
437 | image_ = image_.convert("RGB")
438 | image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
439 | image_ = np.array(image_)
440 | image_ = image_[None, :]
441 | images.append(image_)
442 |
443 | image = images
444 |
445 | image = np.concatenate(image, axis=0)
446 | image = np.array(image).astype(np.float32) / 255.0
447 | image = image.transpose(0, 3, 1, 2)
448 | image = 2.0 * image - 1.0
449 | image = torch.from_numpy(image)
450 | elif isinstance(image[0], torch.Tensor):
451 | image = torch.cat(image, dim=0)
452 |
453 | image_batch_size = image.shape[0]
454 |
455 | if image_batch_size == 1:
456 | repeat_by = batch_size
457 | else:
458 | # image batch size is the same as prompt batch size
459 | repeat_by = num_videos_per_prompt
460 |
461 | image = image.repeat_interleave(repeat_by, dim=0)
462 |
463 | image = image.to(device=device, dtype=dtype)
464 |
465 | return image
466 |
467 | def prepare_video_latents(self, frames, batch_size, dtype, device, generator=None):
468 | if not isinstance(frames, (torch.Tensor, PIL.Image.Image, list)):
469 | raise ValueError(
470 |                 f"`frames` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(frames)}"
471 | )
472 |
473 | frames = frames[0].to(device=device, dtype=dtype)
474 |         frames = rearrange(frames, "c f h w -> f c h w")
475 |
476 | if isinstance(generator, list) and len(generator) != batch_size:
477 | raise ValueError(
478 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
479 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
480 | )
481 |
482 | if isinstance(generator, list):
483 | latents = [
484 | self.vae.encode(frames[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
485 | ]
486 | latents = torch.cat(latents, dim=0)
487 | else:
488 | latents = self.vae.encode(frames).latent_dist.sample(generator)
489 |
490 | latents = self.vae.config.scaling_factor * latents
491 |
492 |         latents = rearrange(latents, "f c h w -> c f h w")
493 |
494 | return latents[None]
495 |
496 | def _default_height_width(self, height, width, image):
497 |         # NOTE: It is possible that the images in a list have different
498 |         # dimensions, so just checking the first image
499 | # is not _exactly_ correct, but it is simple.
500 | while isinstance(image, list):
501 | image = image[0]
502 |
503 | if height is None:
504 | if isinstance(image, PIL.Image.Image):
505 | height = image.height
506 | elif isinstance(image, torch.Tensor):
507 | height = image.shape[3]
508 |
509 | height = (height // 8) * 8 # round down to nearest multiple of 8
510 |
511 | if width is None:
512 | if isinstance(image, PIL.Image.Image):
513 | width = image.width
514 | elif isinstance(image, torch.Tensor):
515 | width = image.shape[2]
516 |
517 | width = (width // 8) * 8 # round down to nearest multiple of 8
518 |
519 | return height, width
520 |
521 | def get_alpha_prev(self, timestep):
522 | prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
523 | alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
524 | return alpha_prod_t_prev
525 |
526 | def get_slide_window_indices(self, video_length, window_size):
527 |         assert window_size >= 3
528 | key_frame_indices = np.arange(0, video_length, window_size-1).tolist()
529 |
530 | # Append last index
531 | if key_frame_indices[-1] != (video_length-1):
532 | key_frame_indices.append(video_length-1)
533 |
534 | slices = np.split(np.arange(video_length), key_frame_indices)
535 | inter_frame_list = []
536 | for s in slices:
537 | if len(s) < 2:
538 | continue
539 | inter_frame_list.append(s[1:].tolist())
540 | return key_frame_indices, inter_frame_list
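    # Worked example (the values follow directly from the logic above):
    #   get_slide_window_indices(video_length=8, window_size=4)
    #   -> key_frame_indices = [0, 3, 6, 7]
    #   -> inter_frame_list  = [[1, 2], [4, 5]]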
541 |
542 | def get_inverse_timesteps(self, num_inference_steps, strength, device):
543 | # get the original timestep using init_timestep
544 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
545 |
546 | t_start = max(num_inference_steps - init_timestep, 0)
547 |
548 |         # safety for t_start overflow to prevent empty timesteps slice
549 | if t_start == 0:
550 | return self.inverse_scheduler.timesteps, num_inference_steps
551 | timesteps = self.inverse_scheduler.timesteps[:-t_start]
552 |
553 | return timesteps, num_inference_steps - t_start
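    # Example: with num_inference_steps=100 and strength=1.0, init_timestep=100 and
    # t_start=0, so the full inverse schedule is returned. With strength=0.5, t_start=50
    # and only the first half of the schedule (timesteps[:-50]) is used.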
554 |
555 | def clean_features(self):
556 | self.unet.up_blocks[1].resnets[0].out_layers_inject_features = None
557 | self.unet.up_blocks[1].resnets[1].out_layers_inject_features = None
558 | self.unet.up_blocks[2].resnets[0].out_layers_inject_features = None
559 | self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.inject_q = None
560 | self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.inject_k = None
561 | self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.inject_q = None
562 | self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.inject_k = None
563 | self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.inject_q = None
564 | self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.inject_k = None
565 | self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.inject_q = None
566 | self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.inject_k = None
567 | self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.inject_q = None
568 | self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.inject_k = None
569 | self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.inject_q = None
570 | self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.inject_k = None
571 |
572 | @torch.no_grad()
573 | def __call__(
574 | self,
575 | prompt: Union[str, List[str]] = None,
576 | video_length: Optional[int] = 1,
577 | frames: Union[List[torch.FloatTensor], List[PIL.Image.Image], List[List[torch.FloatTensor]], List[List[PIL.Image.Image]]] = None,
578 | height: Optional[int] = None,
579 | width: Optional[int] = None,
580 | num_inference_steps: int = 50,
581 | guidance_scale: float = 7.5,
582 | negative_prompt: Optional[Union[str, List[str]]] = None,
583 | num_videos_per_prompt: Optional[int] = 1,
584 | eta: float = 0.0,
585 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
586 | latents: Optional[torch.FloatTensor] = None,
587 | prompt_embeds: Optional[torch.FloatTensor] = None,
588 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
589 | output_type: Optional[str] = "tensor",
590 | return_dict: bool = True,
591 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
592 | callback_steps: int = 1,
593 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
594 | **kwargs,
595 | ):
596 | r"""
597 | Function invoked when calling the pipeline for generation.
598 |
599 | Args:
600 | prompt (`str` or `List[str]`, *optional*):
601 |                 The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
602 | instead.
603 | frames (`List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
604 | `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
605 | The original video frames to be edited.
606 |             height (`int`, *optional*, defaults to the height of the input frames rounded down to a multiple of 8):
607 |                 The height in pixels of the generated video.
608 |             width (`int`, *optional*, defaults to the width of the input frames rounded down to a multiple of 8):
609 |                 The width in pixels of the generated video.
610 | num_inference_steps (`int`, *optional*, defaults to 50):
611 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
612 | expense of slower inference.
613 | guidance_scale (`float`, *optional*, defaults to 7.5):
614 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
615 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
616 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
617 |                 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
618 | usually at the expense of lower image quality.
619 | negative_prompt (`str` or `List[str]`, *optional*):
620 |                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
621 |                 `negative_prompt_embeds` instead.
622 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
623 | num_videos_per_prompt (`int`, *optional*, defaults to 1):
624 |                 The number of videos to generate per prompt.
625 | eta (`float`, *optional*, defaults to 0.0):
626 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
627 | [`schedulers.DDIMScheduler`], will be ignored for others.
628 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
629 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
630 | to make generation deterministic.
631 | latents (`torch.FloatTensor`, *optional*):
632 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
633 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
634 |                 tensor will be generated by sampling using the supplied random `generator`.
635 | prompt_embeds (`torch.FloatTensor`, *optional*):
636 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
637 | provided, text embeddings will be generated from `prompt` input argument.
638 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
639 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
640 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
641 | argument.
642 |             output_type (`str`, *optional*, defaults to `"tensor"`):
643 |                 The output format of the generated video. Choose `"tensor"` for a `torch.Tensor`, or any other
644 |                 value for a `np.ndarray`.
645 | return_dict (`bool`, *optional*, defaults to `True`):
646 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
647 | plain tuple.
648 | callback (`Callable`, *optional*):
649 | A function that will be called every `callback_steps` steps during inference. The function will be
650 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
651 | callback_steps (`int`, *optional*, defaults to 1):
652 | The frequency at which the `callback` function will be called. If not specified, the callback will be
653 | called at every step.
654 | cross_attention_kwargs (`dict`, *optional*):
655 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
656 | `self.processor` in
657 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
658 | """
659 | height, width = self._default_height_width(height, width, frames)
660 |
661 | self.check_inputs(
662 | prompt,
663 | height,
664 | width,
665 | callback_steps,
666 | negative_prompt,
667 | prompt_embeds,
668 | negative_prompt_embeds,
669 | )
670 |
671 | if prompt is not None and isinstance(prompt, str):
672 | batch_size = 1
673 | elif prompt is not None and isinstance(prompt, list):
674 | batch_size = len(prompt)
675 | else:
676 | batch_size = prompt_embeds.shape[0]
677 |
678 | device = self._execution_device
679 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
680 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
681 | # corresponds to doing no classifier free guidance.
682 | do_classifier_free_guidance = guidance_scale > 1.0
683 |
684 | # encode empty prompt
685 | prompt_embeds = self._encode_prompt(
686 | "",
687 | device,
688 | num_videos_per_prompt,
689 | do_classifier_free_guidance=do_classifier_free_guidance,
690 | negative_prompt=None,
691 | prompt_embeds=prompt_embeds,
692 | negative_prompt_embeds=negative_prompt_embeds,
693 | )
694 |
695 | images = []
696 | for i_img in frames:
697 | i_img = self.prepare_image(
698 | image=i_img,
699 | width=width,
700 | height=height,
701 | batch_size=batch_size * num_videos_per_prompt,
702 | num_videos_per_prompt=num_videos_per_prompt,
703 | device=device,
704 | dtype=self.unet.dtype,
705 | do_classifier_free_guidance=do_classifier_free_guidance,
706 | )
707 | images.append(i_img)
708 | frames = torch.stack(images, dim=2) # b x c x f x h x w
709 |
710 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
711 |
712 | latents = self.prepare_video_latents(frames, batch_size, self.unet.dtype, device, generator=generator)
713 |
714 | saved_features0 = []
715 | saved_features1 = []
716 | saved_features2 = []
717 | saved_q4 = []
718 | saved_k4 = []
719 | saved_q5 = []
720 | saved_k5 = []
721 | saved_q6 = []
722 | saved_k6 = []
723 | saved_q7 = []
724 | saved_k7 = []
725 | saved_q8 = []
726 | saved_k8 = []
727 | saved_q9 = []
728 | saved_k9 = []
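        # These lists buffer activations captured during the DDIM inversion pass below:
        # ResNet output features from up_blocks[1]/up_blocks[2] and the self-attention
        # queries/keys of several up-block transformer layers. After inversion they are
        # reversed so that they line up with the denoising timestep order, then
        # re-injected for the first `inject_step` denoising steps to keep the edited
        # video aligned with the structure of the source frames.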
729 |
730 | # ddim inverse
731 | self.scheduler.set_timesteps(num_inference_steps, device=device)
732 | timesteps = self.scheduler.timesteps
733 |
734 | num_inverse_steps = 100
735 | self.inverse_scheduler.set_timesteps(num_inverse_steps, device=device)
736 | inverse_timesteps, num_inverse_steps = self.get_inverse_timesteps(num_inverse_steps, 1, device)
737 | num_warmup_steps = len(inverse_timesteps) - num_inverse_steps * self.inverse_scheduler.order
738 |
739 | with self.progress_bar(total=num_inverse_steps-1) as progress_bar:
740 | for i, t in enumerate(inverse_timesteps[1:]):
741 | # expand the latents if we are doing classifier free guidance
742 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
743 | latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)
744 |
745 | noise_pred = self.unet(
746 | latent_model_input,
747 | t,
748 | encoder_hidden_states=prompt_embeds,
749 | cross_attention_kwargs=cross_attention_kwargs,
750 | **kwargs,
751 | ).sample
752 |
753 | if t in timesteps:
754 | saved_features0.append(self.unet.up_blocks[1].resnets[0].out_layers_features.cpu())
755 | saved_features1.append(self.unet.up_blocks[1].resnets[1].out_layers_features.cpu())
756 | saved_features2.append(self.unet.up_blocks[2].resnets[0].out_layers_features.cpu())
757 | saved_q4.append(self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.q.cpu())
758 | saved_k4.append(self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.k.cpu())
759 | saved_q5.append(self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.q.cpu())
760 | saved_k5.append(self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.k.cpu())
761 | saved_q6.append(self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.q.cpu())
762 | saved_k6.append(self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.k.cpu())
763 | saved_q7.append(self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.q.cpu())
764 | saved_k7.append(self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.k.cpu())
765 | saved_q8.append(self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.q.cpu())
766 | saved_k8.append(self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.k.cpu())
767 | saved_q9.append(self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.q.cpu())
768 | saved_k9.append(self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.k.cpu())
769 |
770 |
771 | if do_classifier_free_guidance:
772 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
773 | noise_pred = noise_pred_uncond + 1 * (noise_pred_text - noise_pred_uncond)
774 |
775 | # compute the previous noisy sample x_t -> x_t-1
776 | latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample
777 | if i == len(inverse_timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0):
778 | progress_bar.update()
779 |
780 | saved_features0.reverse()
781 | saved_features1.reverse()
782 | saved_features2.reverse()
783 | saved_q4.reverse()
784 | saved_k4.reverse()
785 | saved_q5.reverse()
786 | saved_k5.reverse()
787 | saved_q6.reverse()
788 | saved_k6.reverse()
789 | saved_q7.reverse()
790 | saved_k7.reverse()
791 | saved_q8.reverse()
792 | saved_k8.reverse()
793 | saved_q9.reverse()
794 | saved_k9.reverse()
795 |
796 | # video sampling
797 | prompt_embeds = self._encode_prompt(
798 | prompt,
799 | device,
800 | num_videos_per_prompt,
801 | do_classifier_free_guidance,
802 | negative_prompt,
803 | prompt_embeds=None,
804 | negative_prompt_embeds=negative_prompt_embeds,
805 | )
806 |
807 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
808 | with self.progress_bar(total=num_inference_steps) as progress_bar:
809 | for i, t in enumerate(timesteps):
810 | torch.cuda.empty_cache()
811 |
812 | # expand the latents if we are doing classifier free guidance
813 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
814 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
815 |
816 | # inject features
817 | if i < kwargs["inject_step"]:
818 | self.unet.up_blocks[1].resnets[0].out_layers_inject_features = saved_features0[i].to(device)
819 | self.unet.up_blocks[1].resnets[1].out_layers_inject_features = saved_features1[i].to(device)
820 | self.unet.up_blocks[2].resnets[0].out_layers_inject_features = saved_features2[i].to(device)
821 | self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.inject_q = saved_q4[i].to(device)
822 | self.unet.up_blocks[1].attentions[1].transformer_blocks[0].attn1.inject_k = saved_k4[i].to(device)
823 | self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.inject_q = saved_q5[i].to(device)
824 | self.unet.up_blocks[1].attentions[2].transformer_blocks[0].attn1.inject_k = saved_k5[i].to(device)
825 | self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.inject_q = saved_q6[i].to(device)
826 | self.unet.up_blocks[2].attentions[0].transformer_blocks[0].attn1.inject_k = saved_k6[i].to(device)
827 | self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.inject_q = saved_q7[i].to(device)
828 | self.unet.up_blocks[2].attentions[1].transformer_blocks[0].attn1.inject_k = saved_k7[i].to(device)
829 | self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.inject_q = saved_q8[i].to(device)
830 | self.unet.up_blocks[2].attentions[2].transformer_blocks[0].attn1.inject_k = saved_k8[i].to(device)
831 | self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.inject_q = saved_q9[i].to(device)
832 | self.unet.up_blocks[3].attentions[0].transformer_blocks[0].attn1.inject_k = saved_k9[i].to(device)
833 | else:
834 | self.clean_features()
835 |
836 | noise_pred = self.unet(
837 | latent_model_input,
838 | t,
839 | encoder_hidden_states=prompt_embeds,
840 | cross_attention_kwargs=cross_attention_kwargs,
841 | **kwargs,
842 | ).sample
843 |
844 | self.clean_features()
845 |
846 | # perform guidance
847 | if do_classifier_free_guidance:
848 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
849 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
850 |
851 | # compute the previous noisy sample x_t -> x_t-1
852 | step_dict = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
853 | latents = step_dict.prev_sample
854 |
855 | # call the callback, if provided
856 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
857 | progress_bar.update()
858 | if callback is not None and i % callback_steps == 0:
859 | callback(i, t, latents)
860 |
861 | # If we do sequential model offloading, let's offload unet
862 | # manually for max memory savings
863 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
864 | self.unet.to("cpu")
865 | torch.cuda.empty_cache()
866 | # Post-processing
867 | video = self.decode_latents(latents)
868 |
869 | # Convert to tensor
870 | if output_type == "tensor":
871 | video = torch.from_numpy(video)
872 |
873 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
874 | self.final_offload_hook.offload()
875 |
876 | if not return_dict:
877 | return video
878 |
879 | return FlattenPipelineOutput(videos=video)
880 |
--------------------------------------------------------------------------------
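
For orientation, the sketch below shows one way the components this pipeline expects could be wired together. It is an illustration only, not the repository's own entry point (`inference.py`): the pipeline class name `FlattenPipeline`, the local checkpoint path, the frame paths, the prompt, and the `inject_step` value are all assumptions, and the real script may pass additional keyword arguments (e.g. flow-guided attention inputs consumed in `models/attention.py`) that are omitted here.

import torch
from PIL import Image
from diffusers import AutoencoderKL, DDIMScheduler, DDIMInverseScheduler
from transformers import CLIPTextModel, CLIPTokenizer

from models.pipeline_flatten import FlattenPipeline
from models.unet import UNet3DConditionModel

model_path = "./checkpoints/stable-diffusion-v1-5"   # assumed local SD 1.x snapshot
device = "cuda"

vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
unet = UNet3DConditionModel.from_pretrained_2d(model_path, subfolder="unet")
scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
inverse_scheduler = DDIMInverseScheduler.from_config(scheduler.config)

pipe = FlattenPipeline(
    vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
    unet=unet, scheduler=scheduler, inverse_scheduler=inverse_scheduler,
).to(device)

frames = [Image.open(f"./frames/{i:04d}.png") for i in range(16)]   # assumed frame files
result = pipe(
    prompt="a tiger walking on the snow",
    frames=frames,
    video_length=len(frames),
    num_inference_steps=50,
    guidance_scale=7.5,
    generator=torch.Generator(device).manual_seed(0),
    inject_step=20,   # number of denoising steps that reuse the inverted features
)
video = result.videos   # (b, c, f, h, w) tensor in [0, 1] with the default output_type
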
/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from einops import rearrange
8 |
9 |
10 | class InflatedConv3d(nn.Conv2d):
11 | def forward(self, x):
12 | video_length = x.shape[2]
13 |
14 | x = rearrange(x, "b c f h w -> (b f) c h w")
15 | x = super().forward(x)
16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
17 |
18 | return x
19 |
20 | class TemporalConv1d(nn.Conv1d):
21 | def forward(self, x):
22 | b, c, f, h, w = x.shape
23 | y = rearrange(x.clone(), "b c f h w -> (b h w) c f")
24 | y = super().forward(y)
25 | y = rearrange(y, "(b h w) c f -> b c f h w", b=b, h=h, w=w)
26 | return y
27 |
28 |
29 | class Upsample3D(nn.Module):
30 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
31 | super().__init__()
32 | self.channels = channels
33 | self.out_channels = out_channels or channels
34 | self.use_conv = use_conv
35 | self.use_conv_transpose = use_conv_transpose
36 | self.name = name
37 |
38 | conv = None
39 | if use_conv_transpose:
40 | raise NotImplementedError
41 | elif use_conv:
42 | conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
43 |
44 | if name == "conv":
45 | self.conv = conv
46 | else:
47 | self.Conv2d_0 = conv
48 |
49 | def forward(self, hidden_states, output_size=None):
50 | assert hidden_states.shape[1] == self.channels
51 |
52 | if self.use_conv_transpose:
53 | raise NotImplementedError
54 |
55 |         # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
56 | dtype = hidden_states.dtype
57 | if dtype == torch.bfloat16:
58 | hidden_states = hidden_states.to(torch.float32)
59 |
60 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
61 | if hidden_states.shape[0] >= 64:
62 | hidden_states = hidden_states.contiguous()
63 |
64 | # if `output_size` is passed we force the interpolation output
65 | # size and do not make use of `scale_factor=2`
66 | if output_size is None:
67 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
68 | else:
69 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
70 |
71 | # If the input is bfloat16, we cast back to bfloat16
72 | if dtype == torch.bfloat16:
73 | hidden_states = hidden_states.to(dtype)
74 |
75 | if self.use_conv:
76 | if self.name == "conv":
77 | hidden_states = self.conv(hidden_states)
78 | else:
79 | hidden_states = self.Conv2d_0(hidden_states)
80 |
81 | return hidden_states
82 |
83 |
84 | class Downsample3D(nn.Module):
85 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
86 | super().__init__()
87 | self.channels = channels
88 | self.out_channels = out_channels or channels
89 | self.use_conv = use_conv
90 | self.padding = padding
91 | stride = 2
92 | self.name = name
93 |
94 | if use_conv:
95 | conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
96 | else:
97 | raise NotImplementedError
98 |
99 | if name == "conv":
100 | self.Conv2d_0 = conv
101 | self.conv = conv
102 | elif name == "Conv2d_0":
103 | self.conv = conv
104 | else:
105 | self.conv = conv
106 |
107 | def forward(self, hidden_states):
108 | assert hidden_states.shape[1] == self.channels
109 | if self.use_conv and self.padding == 0:
110 | raise NotImplementedError
111 |
112 | assert hidden_states.shape[1] == self.channels
113 | hidden_states = self.conv(hidden_states)
114 |
115 | return hidden_states
116 |
117 |
118 | class ResnetBlock3D(nn.Module):
119 | def __init__(
120 | self,
121 | *,
122 | in_channels,
123 | out_channels=None,
124 | conv_shortcut=False,
125 | dropout=0.0,
126 | temb_channels=512,
127 | groups=32,
128 | groups_out=None,
129 | pre_norm=True,
130 | eps=1e-6,
131 | non_linearity="swish",
132 | time_embedding_norm="default",
133 | output_scale_factor=1.0,
134 | use_in_shortcut=None,
135 | ):
136 | super().__init__()
137 | self.pre_norm = pre_norm
138 | self.pre_norm = True
139 | self.in_channels = in_channels
140 | out_channels = in_channels if out_channels is None else out_channels
141 | self.out_channels = out_channels
142 | self.use_conv_shortcut = conv_shortcut
143 | self.time_embedding_norm = time_embedding_norm
144 | self.output_scale_factor = output_scale_factor
145 |
146 | if groups_out is None:
147 | groups_out = groups
148 |
149 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
150 |
151 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
152 |
153 | if temb_channels is not None:
154 | if self.time_embedding_norm == "default":
155 | time_emb_proj_out_channels = out_channels
156 | elif self.time_embedding_norm == "scale_shift":
157 | time_emb_proj_out_channels = out_channels * 2
158 | else:
159 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
160 |
161 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
162 | else:
163 | self.time_emb_proj = None
164 |
165 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
166 | self.dropout = torch.nn.Dropout(dropout)
167 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
168 |
169 | if non_linearity == "swish":
170 | self.nonlinearity = lambda x: F.silu(x)
171 | elif non_linearity == "mish":
172 | self.nonlinearity = Mish()
173 | elif non_linearity == "silu":
174 | self.nonlinearity = nn.SiLU()
175 |
176 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
177 |
178 | self.conv_shortcut = None
179 | if self.use_in_shortcut:
180 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
181 |
182 | # save features
183 | self.out_layers_features = None
184 | self.out_layers_inject_features = None
185 |
186 | def forward(self, input_tensor, temb):
187 | hidden_states = input_tensor
188 |
189 | hidden_states = self.norm1(hidden_states)
190 | hidden_states = self.nonlinearity(hidden_states)
191 |
192 | hidden_states = self.conv1(hidden_states)
193 |
194 | if temb is not None:
195 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
196 |
197 | if temb is not None and self.time_embedding_norm == "default":
198 | hidden_states = hidden_states + temb
199 |
200 | hidden_states = self.norm2(hidden_states)
201 |
202 | if temb is not None and self.time_embedding_norm == "scale_shift":
203 | scale, shift = torch.chunk(temb, 2, dim=1)
204 | hidden_states = hidden_states * (1 + scale) + shift
205 |
206 | hidden_states = self.nonlinearity(hidden_states)
207 |
208 | hidden_states = self.dropout(hidden_states)
209 | hidden_states = self.conv2(hidden_states)
210 |
211 | if self.conv_shortcut is not None:
212 | input_tensor = self.conv_shortcut(input_tensor)
213 |
214 | # save features
215 | self.out_layers_features = hidden_states
216 | if self.out_layers_inject_features is not None:
217 | hidden_states = self.out_layers_inject_features
218 |
219 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
220 |
221 | return output_tensor
222 |
223 |
224 | class Mish(torch.nn.Module):
225 | def forward(self, hidden_states):
226 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
227 |
--------------------------------------------------------------------------------
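
As a quick sanity check on the pseudo-3D convolution above: `InflatedConv3d` folds the frame axis into the batch axis, applies an ordinary 2D convolution, and unfolds again, so the frame count and spatial sizes are preserved for a 3x3 kernel with padding 1. The snippet below is illustrative and not part of the repository.

import torch
from models.resnet import InflatedConv3d

conv = InflatedConv3d(4, 8, kernel_size=3, padding=1)
x = torch.randn(2, 4, 16, 32, 32)   # (batch, channels, frames, height, width)
y = conv(x)                          # the 2D conv is applied to every frame independently
print(y.shape)                       # torch.Size([2, 8, 16, 32, 32])
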
/models/unet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2 |
3 | from dataclasses import dataclass
4 | from typing import List, Optional, Tuple, Union
5 |
6 | import os
7 | import json
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.utils.checkpoint
12 |
13 | from diffusers.configuration_utils import ConfigMixin, register_to_config
14 | from diffusers import ModelMixin
15 | from diffusers.utils import BaseOutput, logging
16 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps
17 | from .unet_blocks import (
18 | CrossAttnDownBlock3D,
19 | CrossAttnUpBlock3D,
20 | DownBlock3D,
21 | UNetMidBlock3DCrossAttn,
22 | UpBlock3D,
23 | get_down_block,
24 | get_up_block,
25 | )
26 | from .resnet import InflatedConv3d
27 |
28 |
29 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30 |
31 |
32 | @dataclass
33 | class UNet3DConditionOutput(BaseOutput):
34 | sample: torch.FloatTensor
35 |
36 |
37 | class UNet3DConditionModel(ModelMixin, ConfigMixin):
38 | _supports_gradient_checkpointing = True
39 |
40 | @register_to_config
41 | def __init__(
42 | self,
43 | sample_size: Optional[int] = None,
44 | in_channels: int = 4,
45 | out_channels: int = 4,
46 | center_input_sample: bool = False,
47 | flip_sin_to_cos: bool = True,
48 | freq_shift: int = 0,
49 | down_block_types: Tuple[str] = (
50 | "CrossAttnDownBlock3D",
51 | "CrossAttnDownBlock3D",
52 | "CrossAttnDownBlock3D",
53 | "DownBlock3D",
54 | ),
55 | mid_block_type: str = "UNetMidBlock3DCrossAttn",
56 | up_block_types: Tuple[str] = (
57 | "UpBlock3D",
58 | "CrossAttnUpBlock3D",
59 | "CrossAttnUpBlock3D",
60 | "CrossAttnUpBlock3D"
61 | ),
62 | only_cross_attention: Union[bool, Tuple[bool]] = False,
63 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
64 | layers_per_block: int = 2,
65 | downsample_padding: int = 1,
66 | mid_block_scale_factor: float = 1,
67 | act_fn: str = "silu",
68 | norm_num_groups: int = 32,
69 | norm_eps: float = 1e-5,
70 | cross_attention_dim: int = 1280,
71 | attention_head_dim: Union[int, Tuple[int]] = 8,
72 | dual_cross_attention: bool = False,
73 | use_linear_projection: bool = False,
74 | class_embed_type: Optional[str] = None,
75 | num_class_embeds: Optional[int] = None,
76 | upcast_attention: bool = False,
77 | resnet_time_scale_shift: str = "default",
78 | ):
79 | super().__init__()
80 |
81 | self.sample_size = sample_size
82 | time_embed_dim = block_out_channels[0] * 4
83 |
84 | # input
85 | self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
86 |
87 | # time
88 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
89 | timestep_input_dim = block_out_channels[0]
90 |
91 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
92 |
93 | # class embedding
94 | if class_embed_type is None and num_class_embeds is not None:
95 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
96 | elif class_embed_type == "timestep":
97 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
98 | elif class_embed_type == "identity":
99 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
100 | else:
101 | self.class_embedding = None
102 |
103 | self.down_blocks = nn.ModuleList([])
104 | self.mid_block = None
105 | self.up_blocks = nn.ModuleList([])
106 |
107 | if isinstance(only_cross_attention, bool):
108 | only_cross_attention = [only_cross_attention] * len(down_block_types)
109 |
110 | if isinstance(attention_head_dim, int):
111 | attention_head_dim = (attention_head_dim,) * len(down_block_types)
112 |
113 | # down
114 | output_channel = block_out_channels[0]
115 | for i, down_block_type in enumerate(down_block_types):
116 | input_channel = output_channel
117 | output_channel = block_out_channels[i]
118 | is_final_block = i == len(block_out_channels) - 1
119 |
120 | down_block = get_down_block(
121 | down_block_type,
122 | num_layers=layers_per_block,
123 | in_channels=input_channel,
124 | out_channels=output_channel,
125 | temb_channels=time_embed_dim,
126 | add_downsample=not is_final_block,
127 | resnet_eps=norm_eps,
128 | resnet_act_fn=act_fn,
129 | resnet_groups=norm_num_groups,
130 | cross_attention_dim=cross_attention_dim,
131 | attn_num_head_channels=attention_head_dim[i],
132 | downsample_padding=downsample_padding,
133 | dual_cross_attention=dual_cross_attention,
134 | use_linear_projection=use_linear_projection,
135 | only_cross_attention=only_cross_attention[i],
136 | upcast_attention=upcast_attention,
137 | resnet_time_scale_shift=resnet_time_scale_shift,
138 | )
139 | self.down_blocks.append(down_block)
140 |
141 | # mid
142 | if mid_block_type == "UNetMidBlock3DCrossAttn":
143 | self.mid_block = UNetMidBlock3DCrossAttn(
144 | in_channels=block_out_channels[-1],
145 | temb_channels=time_embed_dim,
146 | resnet_eps=norm_eps,
147 | resnet_act_fn=act_fn,
148 | output_scale_factor=mid_block_scale_factor,
149 | resnet_time_scale_shift=resnet_time_scale_shift,
150 | cross_attention_dim=cross_attention_dim,
151 | attn_num_head_channels=attention_head_dim[-1],
152 | resnet_groups=norm_num_groups,
153 | dual_cross_attention=dual_cross_attention,
154 | use_linear_projection=use_linear_projection,
155 | upcast_attention=upcast_attention,
156 | )
157 | else:
158 | raise ValueError(f"unknown mid_block_type : {mid_block_type}")
159 |
160 | # count how many layers upsample the videos
161 | self.num_upsamplers = 0
162 |
163 | # up
164 | reversed_block_out_channels = list(reversed(block_out_channels))
165 | reversed_attention_head_dim = list(reversed(attention_head_dim))
166 | only_cross_attention = list(reversed(only_cross_attention))
167 | output_channel = reversed_block_out_channels[0]
168 | for i, up_block_type in enumerate(up_block_types):
169 | is_final_block = i == len(block_out_channels) - 1
170 |
171 | prev_output_channel = output_channel
172 | output_channel = reversed_block_out_channels[i]
173 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
174 |
175 | # add upsample block for all BUT final layer
176 | if not is_final_block:
177 | add_upsample = True
178 | self.num_upsamplers += 1
179 | else:
180 | add_upsample = False
181 |
182 | up_block = get_up_block(
183 | up_block_type,
184 | num_layers=layers_per_block + 1,
185 | in_channels=input_channel,
186 | out_channels=output_channel,
187 | prev_output_channel=prev_output_channel,
188 | temb_channels=time_embed_dim,
189 | add_upsample=add_upsample,
190 | resnet_eps=norm_eps,
191 | resnet_act_fn=act_fn,
192 | resnet_groups=norm_num_groups,
193 | cross_attention_dim=cross_attention_dim,
194 | attn_num_head_channels=reversed_attention_head_dim[i],
195 | dual_cross_attention=dual_cross_attention,
196 | use_linear_projection=use_linear_projection,
197 | only_cross_attention=only_cross_attention[i],
198 | upcast_attention=upcast_attention,
199 | resnet_time_scale_shift=resnet_time_scale_shift,
200 | )
201 | self.up_blocks.append(up_block)
202 | prev_output_channel = output_channel
203 |
204 | # out
205 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
206 | self.conv_act = nn.SiLU()
207 | self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
208 |
209 | def set_attention_slice(self, slice_size):
210 | r"""
211 | Enable sliced attention computation.
212 |
213 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention
214 | in several steps. This is useful to save some memory in exchange for a small speed decrease.
215 |
216 | Args:
217 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
218 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
219 |                 `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
220 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
221 | must be a multiple of `slice_size`.
222 | """
223 | sliceable_head_dims = []
224 |
225 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
226 | if hasattr(module, "set_attention_slice"):
227 | sliceable_head_dims.append(module.sliceable_head_dim)
228 |
229 | for child in module.children():
230 | fn_recursive_retrieve_slicable_dims(child)
231 |
232 | # retrieve number of attention layers
233 | for module in self.children():
234 | fn_recursive_retrieve_slicable_dims(module)
235 |
236 | num_slicable_layers = len(sliceable_head_dims)
237 |
238 | if slice_size == "auto":
239 | # half the attention head size is usually a good trade-off between
240 | # speed and memory
241 | slice_size = [dim // 2 for dim in sliceable_head_dims]
242 | elif slice_size == "max":
243 | # make smallest slice possible
244 | slice_size = num_slicable_layers * [1]
245 |
246 | slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
247 |
248 | if len(slice_size) != len(sliceable_head_dims):
249 | raise ValueError(
250 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
251 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
252 | )
253 |
254 | for i in range(len(slice_size)):
255 | size = slice_size[i]
256 | dim = sliceable_head_dims[i]
257 | if size is not None and size > dim:
258 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
259 |
260 | # Recursively walk through all the children.
261 | # Any children which exposes the set_attention_slice method
262 | # gets the message
263 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
264 | if hasattr(module, "set_attention_slice"):
265 | module.set_attention_slice(slice_size.pop())
266 |
267 | for child in module.children():
268 | fn_recursive_set_attention_slice(child, slice_size)
269 |
270 | reversed_slice_size = list(reversed(slice_size))
271 | for module in self.children():
272 | fn_recursive_set_attention_slice(module, reversed_slice_size)
273 |
274 | def _set_gradient_checkpointing(self, module, value=False):
275 | if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
276 | module.gradient_checkpointing = value
277 |
278 | def forward(
279 | self,
280 | sample: torch.FloatTensor,
281 | timestep: Union[torch.Tensor, float, int],
282 | encoder_hidden_states: torch.Tensor,
283 | class_labels: Optional[torch.Tensor] = None,
284 | attention_mask: Optional[torch.Tensor] = None,
285 | return_dict: bool = True,
286 | cross_attention_kwargs = None,
287 | inter_frame = False,
288 | **kwargs,
289 | ) -> Union[UNet3DConditionOutput, Tuple]:
290 | r"""
291 | Args:
292 |             sample (`torch.FloatTensor`): (batch, channel, frames, height, width) noisy inputs tensor
293 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
294 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
295 | return_dict (`bool`, *optional*, defaults to `True`):
296 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
297 |
298 | Returns:
299 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
300 | [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
301 | returning a tuple, the first element is the sample tensor.
302 | """
303 |         # By default samples have to be at least a multiple of the overall upsampling factor.
304 |         # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
305 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
306 | # on the fly if necessary.
307 | default_overall_up_factor = 2**self.num_upsamplers
308 | kwargs["t"] = timestep
309 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
310 | forward_upsample_size = False
311 | upsample_size = None
312 |
313 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
314 | logger.info("Forward upsample size to force interpolation output size.")
315 | forward_upsample_size = True
316 |
317 | # prepare attention_mask
318 | if attention_mask is not None:
319 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
320 | attention_mask = attention_mask.unsqueeze(1)
321 |
322 | # center input if necessary
323 | if self.config.center_input_sample:
324 | sample = 2 * sample - 1.0
325 |
326 | # time
327 | timesteps = timestep
328 | if not torch.is_tensor(timesteps):
329 | # This would be a good case for the `match` statement (Python 3.10+)
330 | is_mps = sample.device.type == "mps"
331 | if isinstance(timestep, float):
332 | dtype = torch.float32 if is_mps else torch.float64
333 | else:
334 | dtype = torch.int32 if is_mps else torch.int64
335 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
336 | elif len(timesteps.shape) == 0:
337 | timesteps = timesteps[None].to(sample.device)
338 |
339 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
340 | timesteps = timesteps.expand(sample.shape[0])
341 |
342 | t_emb = self.time_proj(timesteps)
343 |
344 | # timesteps does not contain any weights and will always return f32 tensors
345 | # but time_embedding might actually be running in fp16. so we need to cast here.
346 | # there might be better ways to encapsulate this.
347 | t_emb = t_emb.to(dtype=self.dtype)
348 | emb = self.time_embedding(t_emb)
349 |
350 | if self.class_embedding is not None:
351 | if class_labels is None:
352 | raise ValueError("class_labels should be provided when num_class_embeds > 0")
353 |
354 | if self.config.class_embed_type == "timestep":
355 | class_labels = self.time_proj(class_labels)
356 |
357 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
358 | emb = emb + class_emb
359 |
360 | # pre-process
361 | sample = self.conv_in(sample)
362 |
363 | # down
364 | down_block_res_samples = (sample,)
365 | for downsample_block in self.down_blocks:
366 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
367 | sample, res_samples = downsample_block(
368 | hidden_states=sample,
369 | temb=emb,
370 | encoder_hidden_states=encoder_hidden_states,
371 | attention_mask=attention_mask,
372 | inter_frame=inter_frame,
373 | **kwargs,
374 | )
375 | else:
376 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
377 |
378 | down_block_res_samples += res_samples
379 |
380 | # mid
381 | sample = self.mid_block(
382 | sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask,
383 | inter_frame=inter_frame,
384 | **kwargs,
385 | )
386 |
387 | # up
388 | for i, upsample_block in enumerate(self.up_blocks):
389 | is_final_block = i == len(self.up_blocks) - 1
390 |
391 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
392 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
393 |
394 | # if we have not reached the final block and need to forward the
395 | # upsample size, we do it here
396 | if not is_final_block and forward_upsample_size:
397 | upsample_size = down_block_res_samples[-1].shape[2:]
398 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
399 | sample = upsample_block(
400 | hidden_states=sample,
401 | temb=emb,
402 | res_hidden_states_tuple=res_samples,
403 | encoder_hidden_states=encoder_hidden_states,
404 | upsample_size=upsample_size,
405 | attention_mask=attention_mask,
406 | inter_frame=inter_frame,
407 | **kwargs,
408 | )
409 | else:
410 | sample = upsample_block(
411 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
412 | )
413 | # post-process
414 | sample = self.conv_norm_out(sample)
415 | sample = self.conv_act(sample)
416 | sample = self.conv_out(sample)
417 |
418 | if not return_dict:
419 | return (sample,)
420 |
421 | return UNet3DConditionOutput(sample=sample)
422 |
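    # Note on the forward pass above: `inter_frame` and any extra **kwargs are passed
    # through unchanged to the cross-attention down blocks, the mid block, and the
    # cross-attention up blocks. Judging from models/util.py, these kwargs presumably
    # carry the sampled trajectories and masks ("traj{res}"/"mask{res}"), but that is
    # an assumption not confirmed by this file alone.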
423 | @classmethod
424 | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None):
425 | if subfolder is not None:
426 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
427 |
428 | config_file = os.path.join(pretrained_model_path, 'config.json')
429 | if not os.path.isfile(config_file):
430 | raise RuntimeError(f"{config_file} does not exist")
431 | with open(config_file, "r") as f:
432 | config = json.load(f)
433 | config["_class_name"] = cls.__name__
434 | config["down_block_types"] = [
435 | "CrossAttnDownBlock3D",
436 | "CrossAttnDownBlock3D",
437 | "CrossAttnDownBlock3D",
438 | "DownBlock3D"
439 | ]
440 | config["up_block_types"] = [
441 | "UpBlock3D",
442 | "CrossAttnUpBlock3D",
443 | "CrossAttnUpBlock3D",
444 | "CrossAttnUpBlock3D"
445 | ]
446 |
447 | from diffusers.utils import WEIGHTS_NAME
448 | model = cls.from_config(config)
449 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
450 | if not os.path.isfile(model_file):
451 | raise RuntimeError(f"{model_file} does not exist")
452 | state_dict = torch.load(model_file, map_location="cpu")
453 | # for k, v in model.state_dict().items():
454 | # if '_temp.' in k:
455 | # state_dict.update({k: v})
456 | model.load_state_dict(state_dict, strict=False)
457 |
458 | return model
459 |
--------------------------------------------------------------------------------
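A minimal usage sketch for the `from_pretrained_2d` classmethod above. The class name `UNet3DConditionModel` is inferred from the `UNet3DConditionOutput` it returns, and the checkpoint path is a placeholder; neither is confirmed by this listing:

    from models.unet import UNet3DConditionModel  # assumed class name

    # Inflate 2D Stable Diffusion UNet weights into the 3D UNet. Because the weights
    # are loaded with strict=False, modules that exist only in the 3D variant simply
    # keep their freshly initialized parameters.
    unet = UNet3DConditionModel.from_pretrained_2d(
        "checkpoints/stable-diffusion-v1-5",  # placeholder path to a diffusers checkpoint
        subfolder="unet",
    )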
/models/unet_blocks.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from .attention import Transformer3DModel
7 | from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
8 |
9 |
10 | def get_down_block(
11 | down_block_type,
12 | num_layers,
13 | in_channels,
14 | out_channels,
15 | temb_channels,
16 | add_downsample,
17 | resnet_eps,
18 | resnet_act_fn,
19 | attn_num_head_channels,
20 | resnet_groups=None,
21 | cross_attention_dim=None,
22 | downsample_padding=None,
23 | dual_cross_attention=False,
24 | use_linear_projection=False,
25 | only_cross_attention=False,
26 | upcast_attention=False,
27 | resnet_time_scale_shift="default",
28 | ):
29 | down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
30 | if down_block_type == "DownBlock3D":
31 | return DownBlock3D(
32 | num_layers=num_layers,
33 | in_channels=in_channels,
34 | out_channels=out_channels,
35 | temb_channels=temb_channels,
36 | add_downsample=add_downsample,
37 | resnet_eps=resnet_eps,
38 | resnet_act_fn=resnet_act_fn,
39 | resnet_groups=resnet_groups,
40 | downsample_padding=downsample_padding,
41 | resnet_time_scale_shift=resnet_time_scale_shift,
42 | )
43 | elif down_block_type == "CrossAttnDownBlock3D":
44 | if cross_attention_dim is None:
45 | raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
46 | return CrossAttnDownBlock3D(
47 | num_layers=num_layers,
48 | in_channels=in_channels,
49 | out_channels=out_channels,
50 | temb_channels=temb_channels,
51 | add_downsample=add_downsample,
52 | resnet_eps=resnet_eps,
53 | resnet_act_fn=resnet_act_fn,
54 | resnet_groups=resnet_groups,
55 | downsample_padding=downsample_padding,
56 | cross_attention_dim=cross_attention_dim,
57 | attn_num_head_channels=attn_num_head_channels,
58 | dual_cross_attention=dual_cross_attention,
59 | use_linear_projection=use_linear_projection,
60 | only_cross_attention=only_cross_attention,
61 | upcast_attention=upcast_attention,
62 | resnet_time_scale_shift=resnet_time_scale_shift,
63 | )
64 | raise ValueError(f"{down_block_type} does not exist.")
65 |
66 |
67 | def get_up_block(
68 | up_block_type,
69 | num_layers,
70 | in_channels,
71 | out_channels,
72 | prev_output_channel,
73 | temb_channels,
74 | add_upsample,
75 | resnet_eps,
76 | resnet_act_fn,
77 | attn_num_head_channels,
78 | resnet_groups=None,
79 | cross_attention_dim=None,
80 | dual_cross_attention=False,
81 | use_linear_projection=False,
82 | only_cross_attention=False,
83 | upcast_attention=False,
84 | resnet_time_scale_shift="default",
85 | ):
86 | up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
87 | if up_block_type == "UpBlock3D":
88 | return UpBlock3D(
89 | num_layers=num_layers,
90 | in_channels=in_channels,
91 | out_channels=out_channels,
92 | prev_output_channel=prev_output_channel,
93 | temb_channels=temb_channels,
94 | add_upsample=add_upsample,
95 | resnet_eps=resnet_eps,
96 | resnet_act_fn=resnet_act_fn,
97 | resnet_groups=resnet_groups,
98 | resnet_time_scale_shift=resnet_time_scale_shift,
99 | )
100 | elif up_block_type == "CrossAttnUpBlock3D":
101 | if cross_attention_dim is None:
102 | raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
103 | return CrossAttnUpBlock3D(
104 | num_layers=num_layers,
105 | in_channels=in_channels,
106 | out_channels=out_channels,
107 | prev_output_channel=prev_output_channel,
108 | temb_channels=temb_channels,
109 | add_upsample=add_upsample,
110 | resnet_eps=resnet_eps,
111 | resnet_act_fn=resnet_act_fn,
112 | resnet_groups=resnet_groups,
113 | cross_attention_dim=cross_attention_dim,
114 | attn_num_head_channels=attn_num_head_channels,
115 | dual_cross_attention=dual_cross_attention,
116 | use_linear_projection=use_linear_projection,
117 | only_cross_attention=only_cross_attention,
118 | upcast_attention=upcast_attention,
119 | resnet_time_scale_shift=resnet_time_scale_shift,
120 | )
121 | raise ValueError(f"{up_block_type} does not exist.")
122 |
123 |
124 | class UNetMidBlock3DCrossAttn(nn.Module):
125 | def __init__(
126 | self,
127 | in_channels: int,
128 | temb_channels: int,
129 | dropout: float = 0.0,
130 | num_layers: int = 1,
131 | resnet_eps: float = 1e-6,
132 | resnet_time_scale_shift: str = "default",
133 | resnet_act_fn: str = "swish",
134 | resnet_groups: int = 32,
135 | resnet_pre_norm: bool = True,
136 | attn_num_head_channels=1,
137 | output_scale_factor=1.0,
138 | cross_attention_dim=1280,
139 | dual_cross_attention=False,
140 | use_linear_projection=False,
141 | upcast_attention=False,
142 | ):
143 | super().__init__()
144 |
145 | self.has_cross_attention = True
146 | self.attn_num_head_channels = attn_num_head_channels
147 | resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
148 |
149 | # there is always at least one resnet
150 | resnets = [
151 | ResnetBlock3D(
152 | in_channels=in_channels,
153 | out_channels=in_channels,
154 | temb_channels=temb_channels,
155 | eps=resnet_eps,
156 | groups=resnet_groups,
157 | dropout=dropout,
158 | time_embedding_norm=resnet_time_scale_shift,
159 | non_linearity=resnet_act_fn,
160 | output_scale_factor=output_scale_factor,
161 | pre_norm=resnet_pre_norm,
162 | )
163 | ]
164 | attentions = []
165 |
166 | for _ in range(num_layers):
167 | if dual_cross_attention:
168 | raise NotImplementedError
169 | attentions.append(
170 | Transformer3DModel(
171 | attn_num_head_channels,
172 | in_channels // attn_num_head_channels,
173 | in_channels=in_channels,
174 | num_layers=1,
175 | cross_attention_dim=cross_attention_dim,
176 | norm_num_groups=resnet_groups,
177 | use_linear_projection=use_linear_projection,
178 | upcast_attention=upcast_attention,
179 | )
180 | )
181 | resnets.append(
182 | ResnetBlock3D(
183 | in_channels=in_channels,
184 | out_channels=in_channels,
185 | temb_channels=temb_channels,
186 | eps=resnet_eps,
187 | groups=resnet_groups,
188 | dropout=dropout,
189 | time_embedding_norm=resnet_time_scale_shift,
190 | non_linearity=resnet_act_fn,
191 | output_scale_factor=output_scale_factor,
192 | pre_norm=resnet_pre_norm,
193 | )
194 | )
195 |
196 | self.attentions = nn.ModuleList(attentions)
197 | self.resnets = nn.ModuleList(resnets)
198 |
199 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, inter_frame=False, **kwargs):
200 | hidden_states = self.resnets[0](hidden_states, temb)
201 | for attn, resnet in zip(self.attentions, self.resnets[1:]):
202 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, inter_frame=inter_frame, **kwargs).sample
203 | hidden_states = resnet(hidden_states, temb)
204 |
205 | return hidden_states
206 |
207 |
208 | class CrossAttnDownBlock3D(nn.Module):
209 | def __init__(
210 | self,
211 | in_channels: int,
212 | out_channels: int,
213 | temb_channels: int,
214 | dropout: float = 0.0,
215 | num_layers: int = 1,
216 | resnet_eps: float = 1e-6,
217 | resnet_time_scale_shift: str = "default",
218 | resnet_act_fn: str = "swish",
219 | resnet_groups: int = 32,
220 | resnet_pre_norm: bool = True,
221 | attn_num_head_channels=1,
222 | cross_attention_dim=1280,
223 | output_scale_factor=1.0,
224 | downsample_padding=1,
225 | add_downsample=True,
226 | dual_cross_attention=False,
227 | use_linear_projection=False,
228 | only_cross_attention=False,
229 | upcast_attention=False,
230 | ):
231 | super().__init__()
232 | resnets = []
233 | attentions = []
234 |
235 | self.has_cross_attention = True
236 | self.attn_num_head_channels = attn_num_head_channels
237 |
238 | for i in range(num_layers):
239 | in_channels = in_channels if i == 0 else out_channels
240 | resnets.append(
241 | ResnetBlock3D(
242 | in_channels=in_channels,
243 | out_channels=out_channels,
244 | temb_channels=temb_channels,
245 | eps=resnet_eps,
246 | groups=resnet_groups,
247 | dropout=dropout,
248 | time_embedding_norm=resnet_time_scale_shift,
249 | non_linearity=resnet_act_fn,
250 | output_scale_factor=output_scale_factor,
251 | pre_norm=resnet_pre_norm,
252 | )
253 | )
254 | if dual_cross_attention:
255 | raise NotImplementedError
256 | attentions.append(
257 | Transformer3DModel(
258 | attn_num_head_channels,
259 | out_channels // attn_num_head_channels,
260 | in_channels=out_channels,
261 | num_layers=1,
262 | cross_attention_dim=cross_attention_dim,
263 | norm_num_groups=resnet_groups,
264 | use_linear_projection=use_linear_projection,
265 | only_cross_attention=only_cross_attention,
266 | upcast_attention=upcast_attention,
267 | )
268 | )
269 | self.attentions = nn.ModuleList(attentions)
270 | self.resnets = nn.ModuleList(resnets)
271 |
272 | if add_downsample:
273 | self.downsamplers = nn.ModuleList(
274 | [
275 | Downsample3D(
276 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
277 | )
278 | ]
279 | )
280 | else:
281 | self.downsamplers = None
282 |
283 | self.gradient_checkpointing = False
284 |
285 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, inter_frame=False, **kwargs):
286 | output_states = ()
287 |
288 | for resnet, attn in zip(self.resnets, self.attentions):
289 | if self.training and self.gradient_checkpointing:
290 |
291 | def create_custom_forward(module, return_dict=None, inter_frame=None):
292 | def custom_forward(*inputs):
293 | if return_dict is not None:
294 | return module(*inputs, return_dict=return_dict, inter_frame=inter_frame)
295 | else:
296 | return module(*inputs)
297 |
298 | return custom_forward
299 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
300 | hidden_states = torch.utils.checkpoint.checkpoint(
301 | create_custom_forward(attn, return_dict=False, inter_frame=inter_frame),
302 | hidden_states,
303 | encoder_hidden_states,
304 | )[0]
305 | else:
306 | hidden_states = resnet(hidden_states, temb)
307 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, inter_frame=inter_frame, **kwargs).sample
308 |
309 | output_states += (hidden_states,)
310 |
311 | if self.downsamplers is not None:
312 | for downsampler in self.downsamplers:
313 | hidden_states = downsampler(hidden_states)
314 |
315 | output_states += (hidden_states,)
316 |
317 | return hidden_states, output_states
318 |
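# Note on the gradient-checkpointing branch in CrossAttnDownBlock3D.forward above:
# create_custom_forward wraps each sub-module so that torch.utils.checkpoint can
# re-run it during the backward pass; the attention module is invoked with
# return_dict=False, so indexing [0] recovers the sample tensor from the returned tuple.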
319 |
320 | class DownBlock3D(nn.Module):
321 | def __init__(
322 | self,
323 | in_channels: int,
324 | out_channels: int,
325 | temb_channels: int,
326 | dropout: float = 0.0,
327 | num_layers: int = 1,
328 | resnet_eps: float = 1e-6,
329 | resnet_time_scale_shift: str = "default",
330 | resnet_act_fn: str = "swish",
331 | resnet_groups: int = 32,
332 | resnet_pre_norm: bool = True,
333 | output_scale_factor=1.0,
334 | add_downsample=True,
335 | downsample_padding=1,
336 | ):
337 | super().__init__()
338 | resnets = []
339 |
340 | for i in range(num_layers):
341 | in_channels = in_channels if i == 0 else out_channels
342 | resnets.append(
343 | ResnetBlock3D(
344 | in_channels=in_channels,
345 | out_channels=out_channels,
346 | temb_channels=temb_channels,
347 | eps=resnet_eps,
348 | groups=resnet_groups,
349 | dropout=dropout,
350 | time_embedding_norm=resnet_time_scale_shift,
351 | non_linearity=resnet_act_fn,
352 | output_scale_factor=output_scale_factor,
353 | pre_norm=resnet_pre_norm,
354 | )
355 | )
356 |
357 | self.resnets = nn.ModuleList(resnets)
358 |
359 | if add_downsample:
360 | self.downsamplers = nn.ModuleList(
361 | [
362 | Downsample3D(
363 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
364 | )
365 | ]
366 | )
367 | else:
368 | self.downsamplers = None
369 |
370 | self.gradient_checkpointing = False
371 |
372 | def forward(self, hidden_states, temb=None):
373 | output_states = ()
374 |
375 | for resnet in self.resnets:
376 | if self.training and self.gradient_checkpointing:
377 |
378 | def create_custom_forward(module):
379 | def custom_forward(*inputs):
380 | return module(*inputs)
381 |
382 | return custom_forward
383 |
384 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
385 | else:
386 | hidden_states = resnet(hidden_states, temb)
387 |
388 | output_states += (hidden_states,)
389 |
390 | if self.downsamplers is not None:
391 | for downsampler in self.downsamplers:
392 | hidden_states = downsampler(hidden_states)
393 |
394 | output_states += (hidden_states,)
395 |
396 | return hidden_states, output_states
397 |
398 |
399 | class CrossAttnUpBlock3D(nn.Module):
400 | def __init__(
401 | self,
402 | in_channels: int,
403 | out_channels: int,
404 | prev_output_channel: int,
405 | temb_channels: int,
406 | dropout: float = 0.0,
407 | num_layers: int = 1,
408 | resnet_eps: float = 1e-6,
409 | resnet_time_scale_shift: str = "default",
410 | resnet_act_fn: str = "swish",
411 | resnet_groups: int = 32,
412 | resnet_pre_norm: bool = True,
413 | attn_num_head_channels=1,
414 | cross_attention_dim=1280,
415 | output_scale_factor=1.0,
416 | add_upsample=True,
417 | dual_cross_attention=False,
418 | use_linear_projection=False,
419 | only_cross_attention=False,
420 | upcast_attention=False,
421 | ):
422 | super().__init__()
423 | resnets = []
424 | attentions = []
425 |
426 | self.has_cross_attention = True
427 | self.attn_num_head_channels = attn_num_head_channels
428 |
429 | for i in range(num_layers):
430 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
431 | resnet_in_channels = prev_output_channel if i == 0 else out_channels
432 |
433 | resnets.append(
434 | ResnetBlock3D(
435 | in_channels=resnet_in_channels + res_skip_channels,
436 | out_channels=out_channels,
437 | temb_channels=temb_channels,
438 | eps=resnet_eps,
439 | groups=resnet_groups,
440 | dropout=dropout,
441 | time_embedding_norm=resnet_time_scale_shift,
442 | non_linearity=resnet_act_fn,
443 | output_scale_factor=output_scale_factor,
444 | pre_norm=resnet_pre_norm,
445 | )
446 | )
447 | if dual_cross_attention:
448 | raise NotImplementedError
449 | attentions.append(
450 | Transformer3DModel(
451 | attn_num_head_channels,
452 | out_channels // attn_num_head_channels,
453 | in_channels=out_channels,
454 | num_layers=1,
455 | cross_attention_dim=cross_attention_dim,
456 | norm_num_groups=resnet_groups,
457 | use_linear_projection=use_linear_projection,
458 | only_cross_attention=only_cross_attention,
459 | upcast_attention=upcast_attention,
460 | )
461 | )
462 |
463 | self.attentions = nn.ModuleList(attentions)
464 | self.resnets = nn.ModuleList(resnets)
465 |
466 | if add_upsample:
467 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
468 | else:
469 | self.upsamplers = None
470 |
471 | self.gradient_checkpointing = False
472 |
473 | def forward(
474 | self,
475 | hidden_states,
476 | res_hidden_states_tuple,
477 | temb=None,
478 | encoder_hidden_states=None,
479 | upsample_size=None,
480 | attention_mask=None,
481 | inter_frame=False,
482 | **kwargs,
483 | ):
484 | for resnet, attn in zip(self.resnets, self.attentions):
485 | # pop res hidden states
486 | res_hidden_states = res_hidden_states_tuple[-1]
487 | res_hidden_states_tuple = res_hidden_states_tuple[:-1]
488 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
489 |
490 | if self.training and self.gradient_checkpointing:
491 |
492 | def create_custom_forward(module, return_dict=None, inter_frame=None):
493 | def custom_forward(*inputs):
494 | if return_dict is not None:
495 | return module(*inputs, return_dict=return_dict, inter_frame=inter_frame)
496 | else:
497 | return module(*inputs)
498 |
499 | return custom_forward
500 |
501 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
502 | hidden_states = torch.utils.checkpoint.checkpoint(
503 | create_custom_forward(attn, return_dict=False, inter_frame=inter_frame),
504 | hidden_states,
505 | encoder_hidden_states,
506 | )[0]
507 | else:
508 | hidden_states = resnet(hidden_states, temb)
509 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, inter_frame=inter_frame, **kwargs).sample
510 |
511 | if self.upsamplers is not None:
512 | for upsampler in self.upsamplers:
513 | hidden_states = upsampler(hidden_states, upsample_size)
514 |
515 | return hidden_states
516 |
517 |
518 | class UpBlock3D(nn.Module):
519 | def __init__(
520 | self,
521 | in_channels: int,
522 | prev_output_channel: int,
523 | out_channels: int,
524 | temb_channels: int,
525 | dropout: float = 0.0,
526 | num_layers: int = 1,
527 | resnet_eps: float = 1e-6,
528 | resnet_time_scale_shift: str = "default",
529 | resnet_act_fn: str = "swish",
530 | resnet_groups: int = 32,
531 | resnet_pre_norm: bool = True,
532 | output_scale_factor=1.0,
533 | add_upsample=True,
534 | ):
535 | super().__init__()
536 | resnets = []
537 |
538 | for i in range(num_layers):
539 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
540 | resnet_in_channels = prev_output_channel if i == 0 else out_channels
541 |
542 | resnets.append(
543 | ResnetBlock3D(
544 | in_channels=resnet_in_channels + res_skip_channels,
545 | out_channels=out_channels,
546 | temb_channels=temb_channels,
547 | eps=resnet_eps,
548 | groups=resnet_groups,
549 | dropout=dropout,
550 | time_embedding_norm=resnet_time_scale_shift,
551 | non_linearity=resnet_act_fn,
552 | output_scale_factor=output_scale_factor,
553 | pre_norm=resnet_pre_norm,
554 | )
555 | )
556 |
557 | self.resnets = nn.ModuleList(resnets)
558 |
559 | if add_upsample:
560 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
561 | else:
562 | self.upsamplers = None
563 |
564 | self.gradient_checkpointing = False
565 |
566 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
567 | for resnet in self.resnets:
568 | # pop res hidden states
569 | res_hidden_states = res_hidden_states_tuple[-1]
570 | res_hidden_states_tuple = res_hidden_states_tuple[:-1]
571 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
572 |
573 | if self.training and self.gradient_checkpointing:
574 |
575 | def create_custom_forward(module):
576 | def custom_forward(*inputs):
577 | return module(*inputs)
578 |
579 | return custom_forward
580 |
581 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
582 | else:
583 | hidden_states = resnet(hidden_states, temb)
584 |
585 | if self.upsamplers is not None:
586 | for upsampler in self.upsamplers:
587 | hidden_states = upsampler(hidden_states, upsample_size)
588 |
589 | return hidden_states
590 |
--------------------------------------------------------------------------------
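The two factory functions at the top of `unet_blocks.py` are what the UNet uses to assemble its encoder and decoder from block-type strings. Below is a minimal sketch of calling `get_down_block` directly; the numeric values are illustrative (roughly Stable Diffusion v1 sizes), not taken from this repository's config:

    from models.unet_blocks import get_down_block

    # cross_attention_dim must be provided for the cross-attention block types.
    block = get_down_block(
        "CrossAttnDownBlock3D",
        num_layers=2,
        in_channels=320,
        out_channels=320,
        temb_channels=1280,
        add_downsample=True,
        resnet_eps=1e-6,
        resnet_act_fn="swish",
        attn_num_head_channels=8,
        resnet_groups=32,
        cross_attention_dim=768,
        downsample_padding=1,
    )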
/models/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import imageio
3 | import numpy as np
4 | from typing import Union
5 | import decord
6 | decord.bridge.set_bridge('torch')
7 | import torch
8 | import torchvision
9 | import PIL
10 | from typing import List
11 | from tqdm import tqdm
12 | from einops import rearrange
13 | import torchvision.transforms.functional as F
14 | import random
15 |
16 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
17 | videos = rearrange(videos, "b c t h w -> t b c h w")
18 | outputs = []
19 | for x in videos:
20 | x = torchvision.utils.make_grid(x, nrow=n_rows)
21 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
22 | if rescale:
23 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
24 | x = (x * 255).numpy().astype(np.uint8)
25 | outputs.append(x)
26 |
27 | os.makedirs(os.path.dirname(path), exist_ok=True)
28 | imageio.mimsave(path, outputs, fps=fps)
29 |
30 | def save_videos_grid_pil(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
31 | videos = rearrange(videos, "b c t h w -> t b c h w")
32 | outputs = []
33 | for x in videos:
34 | x = torchvision.utils.make_grid(x, nrow=n_rows)
35 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
36 | if rescale:
37 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
38 | x = (x * 255).numpy().astype(np.uint8)
39 | outputs.append(x)
40 |
41 | os.makedirs(os.path.dirname(path), exist_ok=True)
42 | imageio.mimsave(path, outputs, fps=fps)
43 |
44 | def read_video(video_path, video_length, width=512, height=512, frame_rate=None):
45 | vr = decord.VideoReader(video_path, width=width, height=height)
46 | if frame_rate is None:
47 | frame_rate = max(1, len(vr) // video_length)
48 | sample_index = list(range(0, len(vr), frame_rate))[:video_length]
49 | video = vr.get_batch(sample_index)
50 | video = rearrange(video, "f h w c -> f c h w")
51 | video = (video / 127.5 - 1.0)
52 | return video
53 |
54 |
55 | # DDIM Inversion
56 | @torch.no_grad()
57 | def init_prompt(prompt, pipeline):
58 | uncond_input = pipeline.tokenizer(
59 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
60 | return_tensors="pt"
61 | )
62 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
63 | text_input = pipeline.tokenizer(
64 | [prompt],
65 | padding="max_length",
66 | max_length=pipeline.tokenizer.model_max_length,
67 | truncation=True,
68 | return_tensors="pt",
69 | )
70 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
71 | context = torch.cat([uncond_embeddings, text_embeddings])
72 |
73 | return context
74 |
75 |
76 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
77 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
78 | timestep, next_timestep = min(
79 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
80 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
81 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
82 | beta_prod_t = 1 - alpha_prod_t
83 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
84 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
85 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
86 | return next_sample
87 |
88 |
89 | def get_noise_pred_single(latents, t, context, unet):
90 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
91 | return noise_pred
92 |
93 |
94 | @torch.no_grad()
95 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
96 | context = init_prompt(prompt, pipeline)
97 | uncond_embeddings, cond_embeddings = context.chunk(2)
98 | all_latent = [latent]
99 | latent = latent.clone().detach()
100 | for i in tqdm(range(num_inv_steps)):
101 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
102 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
103 | latent = next_step(noise_pred, t, latent, ddim_scheduler)
104 | all_latent.append(latent)
105 | return all_latent
106 |
107 |
108 | @torch.no_grad()
109 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
110 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
111 | return ddim_latents
112 |
113 |
114 | """optical flow and trajectories sampling"""
115 | def preprocess(img1_batch, img2_batch, transforms):
116 | img1_batch = F.resize(img1_batch, size=[512, 512], antialias=False)
117 | img2_batch = F.resize(img2_batch, size=[512, 512], antialias=False)
118 | return transforms(img1_batch, img2_batch)
119 |
120 | def keys_with_same_value(dictionary):
121 | result = {}
122 | for key, value in dictionary.items():
123 | if value not in result:
124 | result[value] = [key]
125 | else:
126 | result[value].append(key)
127 |
128 | conflict_points = {}
129 | for k in result.keys():
130 | if len(result[k]) > 1:
131 | conflict_points[k] = result[k]
132 | return conflict_points
133 |
134 | def find_duplicates(input_list):
135 | seen = set()
136 | duplicates = set()
137 |
138 | for item in input_list:
139 | if item in seen:
140 | duplicates.add(item)
141 | else:
142 | seen.add(item)
143 |
144 | return list(duplicates)
145 |
146 | def neighbors_index(point, window_size, H, W):
147 | """return the spatial neighbor indices"""
148 | t, x, y = point
149 | neighbors = []
150 | for i in range(-window_size, window_size + 1):
151 | for j in range(-window_size, window_size + 1):
152 | if i == 0 and j == 0:
153 | continue
154 | if x + i < 0 or x + i >= H or y + j < 0 or y + j >= W:
155 | continue
156 | neighbors.append((t, x + i, y + j))
157 | return neighbors
158 |
159 |
160 | @torch.no_grad()
161 | def sample_trajectories(video_path, device):
162 | from torchvision.models.optical_flow import Raft_Large_Weights
163 | from torchvision.models.optical_flow import raft_large
164 |
165 | weights = Raft_Large_Weights.DEFAULT
166 | transforms = weights.transforms()
167 |
168 | frames, _, _ = torchvision.io.read_video(str(video_path), output_format="TCHW")
169 |
170 | clips = list(range(len(frames)))
171 |
172 |     model = raft_large(weights=weights, progress=False).to(device)
173 | model = model.eval()
174 |
175 | finished_trajectories = []
176 |
177 | current_frames, next_frames = preprocess(frames[clips[:-1]], frames[clips[1:]], transforms)
178 | list_of_flows = model(current_frames.to(device), next_frames.to(device))
179 | predicted_flows = list_of_flows[-1]
180 |
181 | predicted_flows = predicted_flows/512
182 |
183 | resolutions = [64, 32, 16, 8]
184 | res = {}
185 | window_sizes = {64: 2,
186 | 32: 1,
187 | 16: 1,
188 | 8: 1}
189 |
190 | for resolution in resolutions:
191 | print("="*30)
192 | trajectories = {}
193 | predicted_flow_resolu = torch.round(resolution*torch.nn.functional.interpolate(predicted_flows, scale_factor=(resolution/512, resolution/512)))
194 |
195 | T = predicted_flow_resolu.shape[0]+1
196 | H = predicted_flow_resolu.shape[2]
197 | W = predicted_flow_resolu.shape[3]
198 |
199 | is_activated = torch.zeros([T, H, W], dtype=torch.bool)
200 |
201 | for t in range(T-1):
202 | flow = predicted_flow_resolu[t]
203 | for h in range(H):
204 | for w in range(W):
205 |
206 | if not is_activated[t, h, w]:
207 | is_activated[t, h, w] = True
208 | # this point has not been traversed, start new trajectory
209 | x = h + int(flow[1, h, w])
210 | y = w + int(flow[0, h, w])
211 | if x >= 0 and x < H and y >= 0 and y < W:
212 | # trajectories.append([(t, h, w), (t+1, x, y)])
213 | trajectories[(t, h, w)]= (t+1, x, y)
214 |
215 | conflict_points = keys_with_same_value(trajectories)
216 | for k in conflict_points:
217 | index_to_pop = random.randint(0, len(conflict_points[k]) - 1)
218 | conflict_points[k].pop(index_to_pop)
219 | for point in conflict_points[k]:
220 | if point[0] != T-1:
221 |                     trajectories[point] = (-1, -1, -1)  # terminate with a (-1, -1, -1) sentinel
222 |
223 | active_traj = []
224 | all_traj = []
225 | for t in range(T):
226 |             pixel_set = {(t, x // W, x % W): 0 for x in range(H * W)}  # (row, col) indexing; identical to x // H, x % H since H == W here
227 | new_active_traj = []
228 | for traj in active_traj:
229 | if traj[-1] in trajectories:
230 | v = trajectories[traj[-1]]
231 | new_active_traj.append(traj + [v])
232 | pixel_set[v] = 1
233 | else:
234 | all_traj.append(traj)
235 | active_traj = new_active_traj
236 | active_traj+=[[pixel] for pixel in pixel_set if pixel_set[pixel] == 0]
237 | all_traj += active_traj
238 |
239 | useful_traj = [i for i in all_traj if len(i)>1]
240 | for idx in range(len(useful_traj)):
241 | if useful_traj[idx][-1] == (-1, -1, -1):
242 | useful_traj[idx] = useful_traj[idx][:-1]
243 |         print("how many points in all trajectories for resolution {}?".format(resolution), sum([len(i) for i in useful_traj]))
244 |         print("how many points in the video for resolution {}?".format(resolution), T*H*W)
245 |
246 | # validate if there are no duplicates in the trajectories
247 | trajs = []
248 | for traj in useful_traj:
249 | trajs = trajs + traj
250 | assert len(find_duplicates(trajs)) == 0, "There should not be duplicates in the useful trajectories."
251 |
252 | # check if non-appearing points + appearing points = all the points in the video
253 | all_points = set([(t, x, y) for t in range(T) for x in range(H) for y in range(W)])
254 | left_points = all_points- set(trajs)
255 |         print("how many points are not in the trajectories for resolution {}?".format(resolution), len(left_points))
256 | for p in list(left_points):
257 | useful_traj.append([p])
258 |         print("how many points in all trajectories for resolution {} after appending the leftover points?".format(resolution), sum([len(i) for i in useful_traj]))
259 |
260 |
261 | longest_length = max([len(i) for i in useful_traj])
262 | sequence_length = (window_sizes[resolution]*2+1)**2 + longest_length - 1
263 |
264 | seqs = []
265 | masks = []
266 |
267 | # create a dictionary to facilitate checking the trajectories to which each point belongs.
268 | point_to_traj = {}
269 | for traj in useful_traj:
270 | for p in traj:
271 | point_to_traj[p] = traj
272 |
273 | for t in range(T):
274 | for x in range(H):
275 | for y in range(W):
276 | neighbours = neighbors_index((t,x,y), window_sizes[resolution], H, W)
277 | sequence = [(t,x,y)]+neighbours + [(0,0,0) for i in range((window_sizes[resolution]*2+1)**2-1-len(neighbours))]
278 | sequence_mask = torch.zeros(sequence_length, dtype=torch.bool)
279 | sequence_mask[:len(neighbours)+1] = True
280 |
281 | traj = point_to_traj[(t,x,y)].copy()
282 | traj.remove((t,x,y))
283 | sequence = sequence + traj + [(0,0,0) for k in range(longest_length-1-len(traj))]
284 | sequence_mask[(window_sizes[resolution]*2+1)**2: (window_sizes[resolution]*2+1)**2 + len(traj)] = True
285 |
286 | seqs.append(sequence)
287 | masks.append(sequence_mask)
288 |
289 | seqs = torch.tensor(seqs)
290 | masks = torch.stack(masks)
291 | res["traj{}".format(resolution)] = seqs
292 | res["mask{}".format(resolution)] = masks
293 | return res
294 |
295 |
--------------------------------------------------------------------------------
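`sample_trajectories` is the entry point of the optical-flow preprocessing above: it runs RAFT on consecutive frames, links the per-pixel flow into trajectories at resolutions 64/32/16/8, and packs each pixel's spatial neighbours plus its trajectory into fixed-length index sequences with boolean masks. A minimal sketch of calling it (the clip is one shipped under `data/`; how the result is consumed downstream is not shown in this file):

    import torch
    from models.util import sample_trajectories

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    res = sample_trajectories("data/puff.mp4", device)

    # One index tensor and one mask per resolution: "traj64"/"mask64", ..., "traj8"/"mask8".
    for k, v in res.items():
        print(k, tuple(v.shape))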
/truck.sh:
--------------------------------------------------------------------------------
1 | python inference.py \
2 | --prompt "Wooden trucks drive on a racetrack." \
3 | --neg_prompt " " \
4 | --guidance_scale 15 \
5 | --video_path "data/trucks-race.mp4" \
6 | --output_path "truck/" \
7 | --video_length 32 \
8 | --width 512 \
9 | --height 512 \
10 | --frame_rate 1 \
11 | --old_qk 1
12 |
--------------------------------------------------------------------------------
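As a quick sanity check on the input to `truck.sh`, the helpers in `models/util.py` can re-read the source clip and write a preview GIF. The output path and fps below are arbitrary illustrative choices, not values used by the repository:

    from models.util import read_video, save_videos_grid

    # read_video returns a float tensor shaped (f, c, h, w), scaled to [-1, 1].
    video = read_video("data/trucks-race.mp4", video_length=32, width=512, height=512, frame_rate=1)

    # save_videos_grid expects (b, c, t, h, w); rescale=True maps [-1, 1] back to [0, 1].
    video = video.permute(1, 0, 2, 3).unsqueeze(0)
    save_videos_grid(video, "truck/source_preview.gif", rescale=True, n_rows=1, fps=8)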