├── .github └── ISSUE_TEMPLATE │ └── bug_report.md ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── lvdm ├── __init__.py ├── basics.py ├── common.py ├── distributions.py ├── ema.py ├── models │ ├── autoencoder.py │ ├── ddpm3d.py │ ├── samplers │ │ ├── ddim.py │ │ ├── ddim_multiplecond.py │ │ └── unipc │ │ │ ├── __init__.py │ │ │ ├── sampler.py │ │ │ └── uni_pc.py │ └── utils_diffusion.py └── modules │ ├── attention.py │ ├── encoders │ ├── condition.py │ └── resampler.py │ ├── networks │ ├── ae_modules.py │ └── openaimodel3d.py │ └── x_transformer.py ├── nodes.py ├── utils ├── model_utils.py └── utils.py └── workflows ├── dynamicrafter_512_basic.json └── dynamicrafter_512_interp_gen.json /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: If your issue is with the current release, please try this release first. https://github.com/ExponentialML/ComfyUI_Native_DynamiCrafter/tree/43ae6bebccb141c6d85b5075f4fed57a5ddea3c1. If you have issues with the main branch, but not with the old release, please proceed to post below. 4 | --- 5 | 6 | **Describe the bug** 7 | Describe the issue 8 | 9 | **To Reproduce** 10 | Steps to reproduce the behavior: 11 | 1. Go to '...' 12 | 2. Click on '....' 13 | 3. Scroll down to '....' 14 | 4. See error 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Screenshots** 20 | If applicable, add screenshots to help explain your problem. 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ComfyUI - Native DynamiCrafter
DynamiCrafter that works natively with ComfyUI's nodes, optimizations, ControlNet, and more.

![image](https://github.com/ExponentialML/ComfyUI_Native_DynamiCrafter/assets/59846140/fd1008ed-7660-454a-8253-1e032c9d054f)

| | |
| ------------- | ------------- |
| ![DynamiCrafter_00298](https://github.com/ExponentialML/ComfyUI_Native_DynamiCrafter/assets/59846140/e66a2559-b973-4a63-bc97-1a0701ab7dd3) | ![DynamiCrafter_00327](https://github.com/ExponentialML/ComfyUI_Native_DynamiCrafter/assets/59846140/81b2b681-ef44-4966-8cb3-fa04692710a8) |

> [!NOTE]
> While this is still considered WIP (or beta), everything should be fully functional and adaptable to various workflows.

# Getting Started

Go to your `custom_nodes` directory in ComfyUI and install with:

`git clone https://github.com/ExponentialML/ComfyUI_Native_DynamiCrafter.git`

> [!IMPORTANT]
> This is a rapid release project. If there are any issues installing from main, the last stable branch is [here](https://github.com/ExponentialML/ComfyUI_Native_DynamiCrafter/tree/43ae6bebccb141c6d85b5075f4fed57a5ddea3c1).
> If everything is working fine, you can ignore this, but you will miss out on the latest features.

# Installation

The pruned UNet checkpoints have been uploaded to HuggingFace. Each variant is working and fully functional.

https://huggingface.co/ExponentialML/DynamiCrafterUNet

## Instructions
You will also need a VAE, the CLIP model used with Stable Diffusion 2.1, and the Open CLIP Vision Model. All of the necessary model downloads are at that link.

If you already have the base SD models, you do not need to download them (just use the CheckpointLoaderSimple node without the model part).

Place the **DynamiCrafter** models inside `ComfyUI_Path/models/dynamicrafter_models`.

If you are downloading the CLIP and VAE models separately, place them under their respective paths in the `ComfyUI_Path/models/` directory.

# Usage

- **model**: The loaded DynamiCrafter model.

- **clip_vision**: The CLIP Vision checkpoint.

- **vae**: A Stable Diffusion VAE. If it works with SD 2.1 or earlier, it will work with this.

- **image_proj_model**: The Image Projection Model that is in the DynamiCrafter model file.

- **images**: The input images necessary for inference. If you are doing interpolation, you can simply batch two images together, check the toggle (see below), and everything will be handled automatically.

- **use_interpolation**: Use the interpolation mode with the interpolation model variant. You can interpolate any two frames (images), or predict the rest using one input.

- **fps**: Controls the speed of the video. If you're using a 256-based model, the highest plausible value is **4**.

- **frames**: The number of frames to use. If you're doing interpolation, the max is **16**. This is strictly enforced, as it doesn't work properly (blurry results) if set higher.

- **model (output)**: The model output, which connects to a Sampler.

- **empty_latent**: An empty latent with the same size and frames as the processed ones.

- **latent_img**: If you're doing Img2Img-based workflows, this is the necessary one to use.

# ControlNet Support

You can now use DynamiCrafter by applying ControlNet to the Spatial (image) portion to guide video generations in various ways.
The ControlNets are based on SD 2.1, so you must download them at the link below (thanks @thibaud!).

**ControlNet 2.1**: https://huggingface.co/thibaud/controlnet-sd21

After you download them, you can use them as you would with any other workflow.

# Tips

> [!TIP]
> You don't have to use the latent outputs. As long as you use the same frame length (as your batch size) and the same height and width as your image inputs, you can use your own latents.
> This means that you can experiment with inpainting and so on.

> [!TIP]
> You can choose which frame is used as the init frame by using VAE Encode Inpaint or Set Latent Noise Mask: set the mask for the beginning of the batch to full black, while the rest stay at full white (see the sketch below). This also means you can do interpolation with regular models.
> As these workflows are more advanced, examples will arrive at a future date.
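
As a rough illustration of the tip above, here is a minimal, PyTorch-only sketch of such a per-frame noise mask: the first frame is pinned (full black / 0.0) and the remaining frames are left for the sampler to regenerate (full white / 1.0). The tensor shape and the exact convention your masking node expects are assumptions, so treat this as a starting point rather than a drop-in snippet.

```python
import torch

# Noise mask for a 16-frame latent batch (latent-space height/width are placeholders).
frames, height, width = 16, 64, 64

mask = torch.ones((frames, 1, height, width))  # 1.0 (white): the sampler regenerates this frame
mask[0] = 0.0                                  # 0.0 (black): keep the first frame as the init frame

# For two-image interpolation, pin the last frame as well:
# mask[-1] = 0.0
```
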
# TODO
- [x] Add various workflows.
- [ ] Add advanced workflows.
- [x] Add support for Spatial Transformer options.
- [x] Add ControlNet support.
- [x] Ensure attention optimizations are working properly.
- [ ] Add autoregressive nodes (this may be a separate repository).
- [x] Add examples. (For more, [check here](https://github.com/Doubiiu/DynamiCrafter?tab=readme-ov-file#11-showcases-576x1024)).

# Credits

Thanks to @Doubiiu for open-sourcing [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter)! Please support their work, and please follow any license terms they may uphold.

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
2 | 
3 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
4 | 
--------------------------------------------------------------------------------
/lvdm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExponentialML/ComfyUI_Native_DynamiCrafter/1a91d71103307c430da93d0942534fffa83d915b/lvdm/__init__.py
--------------------------------------------------------------------------------
/lvdm/basics.py:
--------------------------------------------------------------------------------
1 | # adopted from
2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3 | # and
4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5 | # and
6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7 | #
8 | # thanks!
9 | 
10 | import torch.nn as nn
11 | import comfy.ops
12 | ops = comfy.ops.disable_weight_init
13 | 
14 | from ..utils.utils import instantiate_from_config
15 | 
16 | def disabled_train(self, mode=True):
17 |     """Overwrite model.train with this function to make sure train/eval mode
18 |     does not change anymore."""
19 |     return self
20 | 
21 | def zero_module(module):
22 |     """
23 |     Zero out the parameters of a module and return it.
24 | """ 25 | for p in module.parameters(): 26 | p.detach().zero_() 27 | return module 28 | 29 | def scale_module(module, scale): 30 | """ 31 | Scale the parameters of a module and return it. 32 | """ 33 | for p in module.parameters(): 34 | p.detach().mul_(scale) 35 | return module 36 | 37 | 38 | def conv_nd(dims, *args, **kwargs): 39 | """ 40 | Create a 1D, 2D, or 3D convolution module. 41 | """ 42 | if dims == 1: 43 | return nn.Conv1d(*args, **kwargs) 44 | elif dims == 2: 45 | return ops.Conv2d(*args, **kwargs) 46 | elif dims == 3: 47 | return ops.Conv3d(*args, **kwargs) 48 | raise ValueError(f"unsupported dimensions: {dims}") 49 | 50 | 51 | def linear(*args, **kwargs): 52 | """ 53 | Create a linear module. 54 | """ 55 | return ops.Linear(*args, **kwargs) 56 | 57 | 58 | def avg_pool_nd(dims, *args, **kwargs): 59 | """ 60 | Create a 1D, 2D, or 3D average pooling module. 61 | """ 62 | if dims == 1: 63 | return nn.AvgPool1d(*args, **kwargs) 64 | elif dims == 2: 65 | return nn.AvgPool2d(*args, **kwargs) 66 | elif dims == 3: 67 | return nn.AvgPool3d(*args, **kwargs) 68 | raise ValueError(f"unsupported dimensions: {dims}") 69 | 70 | 71 | def nonlinearity(type='silu'): 72 | if type == 'silu': 73 | return nn.SiLU() 74 | elif type == 'leaky_relu': 75 | return nn.LeakyReLU() 76 | 77 | 78 | class GroupNormSpecific(ops.GroupNorm): 79 | def forward(self, x): 80 | return super().forward(x.float()).type(x.dtype) 81 | 82 | 83 | def normalization(channels, num_groups=32, dtype=None, device=None): 84 | """ 85 | Make a standard normalization layer. 86 | :param channels: number of input channels. 87 | :return: an nn.Module for normalization. 88 | """ 89 | return GroupNormSpecific(num_groups, channels, dtype=dtype, device=device) 90 | 91 | 92 | class HybridConditioner(nn.Module): 93 | 94 | def __init__(self, c_concat_config, c_crossattn_config): 95 | super().__init__() 96 | self.concat_conditioner = instantiate_from_config(c_concat_config) 97 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 98 | 99 | def forward(self, c_concat, c_crossattn): 100 | c_concat = self.concat_conditioner(c_concat) 101 | c_crossattn = self.crossattn_conditioner(c_crossattn) 102 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} -------------------------------------------------------------------------------- /lvdm/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | from inspect import isfunction 3 | import torch 4 | from torch import nn 5 | import torch.distributed as dist 6 | 7 | 8 | def gather_data(data, return_np=True): 9 | ''' gather data from multiple processes to one list ''' 10 | data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())] 11 | dist.all_gather(data_list, data) # gather not supported with NCCL 12 | if return_np: 13 | data_list = [data.cpu().numpy() for data in data_list] 14 | return data_list 15 | 16 | def autocast(f): 17 | def do_autocast(*args, **kwargs): 18 | with torch.cuda.amp.autocast(enabled=True, 19 | dtype=torch.get_autocast_gpu_dtype(), 20 | cache_enabled=torch.is_autocast_cache_enabled()): 21 | return f(*args, **kwargs) 22 | return do_autocast 23 | 24 | 25 | def extract_into_tensor(a, t, x_shape): 26 | b, *_ = t.shape 27 | out = a.gather(-1, t) 28 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 29 | 30 | 31 | def noise_like(shape, device, repeat=False): 32 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 33 | noise = 
lambda: torch.randn(shape, device=device) 34 | return repeat_noise() if repeat else noise() 35 | 36 | 37 | def default(val, d): 38 | if exists(val): 39 | return val 40 | return d() if isfunction(d) else d 41 | 42 | def exists(val): 43 | return val is not None 44 | 45 | def identity(*args, **kwargs): 46 | return nn.Identity() 47 | 48 | def uniq(arr): 49 | return{el: True for el in arr}.keys() 50 | 51 | def mean_flat(tensor): 52 | """ 53 | Take the mean over all non-batch dimensions. 54 | """ 55 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 56 | 57 | def ismap(x): 58 | if not isinstance(x, torch.Tensor): 59 | return False 60 | return (len(x.shape) == 4) and (x.shape[1] > 3) 61 | 62 | def isimage(x): 63 | if not isinstance(x,torch.Tensor): 64 | return False 65 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 66 | 67 | def max_neg_value(t): 68 | return -torch.finfo(t.dtype).max 69 | 70 | def shape_to_str(x): 71 | shape_str = "x".join([str(x) for x in x.shape]) 72 | return shape_str 73 | 74 | def init_(tensor): 75 | dim = tensor.shape[-1] 76 | std = 1 / math.sqrt(dim) 77 | tensor.uniform_(-std, std) 78 | return tensor 79 | 80 | ckpt = torch.utils.checkpoint.checkpoint 81 | def checkpoint(func, inputs, params, flag): 82 | """ 83 | Evaluate a function without caching intermediate activations, allowing for 84 | reduced memory at the expense of extra compute in the backward pass. 85 | :param func: the function to evaluate. 86 | :param inputs: the argument sequence to pass to `func`. 87 | :param params: a sequence of parameters `func` depends on but does not 88 | explicitly take as arguments. 89 | :param flag: if False, disable gradient checkpointing. 90 | """ 91 | if flag: 92 | return ckpt(func, *inputs, use_reentrant=False) 93 | else: 94 | return func(*inputs) -------------------------------------------------------------------------------- /lvdm/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self, noise=None): 36 | if noise is None: 37 | noise = torch.randn(self.mean.shape) 38 | 39 | x = self.mean + self.std * noise.to(device=self.parameters.device) 40 | return x 41 | 42 | def kl(self, other=None): 43 | if self.deterministic: 44 | return torch.Tensor([0.]) 45 | else: 46 | if other is None: 47 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 48 | + self.var - 1.0 - self.logvar, 49 | dim=[1, 2, 3]) 50 | else: 51 | return 0.5 * torch.sum( 52 | torch.pow(self.mean - other.mean, 2) / other.var 53 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 54 | dim=[1, 2, 3]) 55 | 56 | def 
nll(self, sample, dims=[1,2,3]): 57 | if self.deterministic: 58 | return torch.Tensor([0.]) 59 | logtwopi = np.log(2.0 * np.pi) 60 | return 0.5 * torch.sum( 61 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 62 | dim=dims) 63 | 64 | def mode(self): 65 | return self.mean 66 | 67 | 68 | def normal_kl(mean1, logvar1, mean2, logvar2): 69 | """ 70 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 71 | Compute the KL divergence between two gaussians. 72 | Shapes are automatically broadcasted, so batches can be compared to 73 | scalars, among other use cases. 74 | """ 75 | tensor = None 76 | for obj in (mean1, logvar1, mean2, logvar2): 77 | if isinstance(obj, torch.Tensor): 78 | tensor = obj 79 | break 80 | assert tensor is not None, "at least one argument must be a Tensor" 81 | 82 | # Force variances to be Tensors. Broadcasting helps convert scalars to 83 | # Tensors, but it does not work for torch.exp(). 84 | logvar1, logvar2 = [ 85 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 86 | for x in (logvar1, logvar2) 87 | ] 88 | 89 | return 0.5 * ( 90 | -1.0 91 | + logvar2 92 | - logvar1 93 | + torch.exp(logvar1 - logvar2) 94 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 95 | ) -------------------------------------------------------------------------------- /lvdm/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 
61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) -------------------------------------------------------------------------------- /lvdm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import contextmanager 3 | import torch 4 | import numpy as np 5 | from einops import rearrange 6 | import torch.nn.functional as F 7 | import pytorch_lightning as pl 8 | from ...modules.networks.ae_modules import Encoder, Decoder 9 | from ...distributions import DiagonalGaussianDistribution 10 | from utils.utils import instantiate_from_config 11 | 12 | 13 | class AutoencoderKL(pl.LightningModule): 14 | def __init__(self, 15 | ddconfig, 16 | lossconfig, 17 | embed_dim, 18 | ckpt_path=None, 19 | ignore_keys=[], 20 | image_key="image", 21 | colorize_nlabels=None, 22 | monitor=None, 23 | test=False, 24 | logdir=None, 25 | input_dim=4, 26 | test_args=None, 27 | ): 28 | super().__init__() 29 | self.image_key = image_key 30 | self.encoder = Encoder(**ddconfig) 31 | self.decoder = Decoder(**ddconfig) 32 | self.loss = instantiate_from_config(lossconfig) 33 | assert ddconfig["double_z"] 34 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 35 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 36 | self.embed_dim = embed_dim 37 | self.input_dim = input_dim 38 | self.test = test 39 | self.test_args = test_args 40 | self.logdir = logdir 41 | if colorize_nlabels is not None: 42 | assert type(colorize_nlabels)==int 43 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 44 | if monitor is not None: 45 | self.monitor = monitor 46 | if ckpt_path is not None: 47 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 48 | if self.test: 49 | self.init_test() 50 | 51 | def init_test(self,): 52 | self.test = True 53 | save_dir = os.path.join(self.logdir, "test") 54 | if 'ckpt' in self.test_args: 55 | ckpt_name = os.path.basename(self.test_args.ckpt).split('.ckpt')[0] + f'_epoch{self._cur_epoch}' 56 | self.root = os.path.join(save_dir, ckpt_name) 57 | else: 58 | self.root = save_dir 59 | if 'test_subdir' in self.test_args: 60 | self.root = os.path.join(save_dir, self.test_args.test_subdir) 61 | 62 | self.root_zs = os.path.join(self.root, "zs") 63 | self.root_dec = os.path.join(self.root, "reconstructions") 64 | self.root_inputs = os.path.join(self.root, "inputs") 65 | os.makedirs(self.root, exist_ok=True) 66 | 67 | if self.test_args.save_z: 68 | os.makedirs(self.root_zs, exist_ok=True) 69 | if self.test_args.save_reconstruction: 70 | os.makedirs(self.root_dec, exist_ok=True) 71 | if self.test_args.save_input: 72 | os.makedirs(self.root_inputs, exist_ok=True) 73 | assert(self.test_args is not None) 74 | self.test_maximum = getattr(self.test_args, 'test_maximum', None) 75 | self.count = 0 76 | self.eval_metrics = {} 77 | self.decodes = [] 78 | 
self.save_decode_samples = 2048 79 | 80 | def init_from_ckpt(self, path, ignore_keys=list()): 81 | sd = torch.load(path, map_location="cpu") 82 | try: 83 | self._cur_epoch = sd['epoch'] 84 | sd = sd["state_dict"] 85 | except: 86 | self._cur_epoch = 'null' 87 | keys = list(sd.keys()) 88 | for k in keys: 89 | for ik in ignore_keys: 90 | if k.startswith(ik): 91 | print("Deleting key {} from state_dict.".format(k)) 92 | del sd[k] 93 | self.load_state_dict(sd, strict=False) 94 | # self.load_state_dict(sd, strict=True) 95 | print(f"Restored from {path}") 96 | 97 | def encode(self, x, **kwargs): 98 | 99 | h = self.encoder(x) 100 | moments = self.quant_conv(h) 101 | posterior = DiagonalGaussianDistribution(moments) 102 | return posterior 103 | 104 | def decode(self, z, **kwargs): 105 | z = self.post_quant_conv(z) 106 | dec = self.decoder(z) 107 | return dec 108 | 109 | def forward(self, input, sample_posterior=True): 110 | posterior = self.encode(input) 111 | if sample_posterior: 112 | z = posterior.sample() 113 | else: 114 | z = posterior.mode() 115 | dec = self.decode(z) 116 | return dec, posterior 117 | 118 | def get_input(self, batch, k): 119 | x = batch[k] 120 | if x.dim() == 5 and self.input_dim == 4: 121 | b,c,t,h,w = x.shape 122 | self.b = b 123 | self.t = t 124 | x = rearrange(x, 'b c t h w -> (b t) c h w') 125 | 126 | return x 127 | 128 | def training_step(self, batch, batch_idx, optimizer_idx): 129 | inputs = self.get_input(batch, self.image_key) 130 | reconstructions, posterior = self(inputs) 131 | 132 | if optimizer_idx == 0: 133 | # train encoder+decoder+logvar 134 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 135 | last_layer=self.get_last_layer(), split="train") 136 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 137 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 138 | return aeloss 139 | 140 | if optimizer_idx == 1: 141 | # train the discriminator 142 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 143 | last_layer=self.get_last_layer(), split="train") 144 | 145 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 146 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 147 | return discloss 148 | 149 | def validation_step(self, batch, batch_idx): 150 | inputs = self.get_input(batch, self.image_key) 151 | reconstructions, posterior = self(inputs) 152 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 153 | last_layer=self.get_last_layer(), split="val") 154 | 155 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, 156 | last_layer=self.get_last_layer(), split="val") 157 | 158 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) 159 | self.log_dict(log_dict_ae) 160 | self.log_dict(log_dict_disc) 161 | return self.log_dict 162 | 163 | def configure_optimizers(self): 164 | lr = self.learning_rate 165 | opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 166 | list(self.decoder.parameters())+ 167 | list(self.quant_conv.parameters())+ 168 | list(self.post_quant_conv.parameters()), 169 | lr=lr, betas=(0.5, 0.9)) 170 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 171 | lr=lr, betas=(0.5, 0.9)) 172 | return [opt_ae, opt_disc], [] 173 | 174 | def get_last_layer(self): 175 | return self.decoder.conv_out.weight 176 | 
177 | @torch.no_grad() 178 | def log_images(self, batch, only_inputs=False, **kwargs): 179 | log = dict() 180 | x = self.get_input(batch, self.image_key) 181 | x = x.to(self.device) 182 | if not only_inputs: 183 | xrec, posterior = self(x) 184 | if x.shape[1] > 3: 185 | # colorize with random projection 186 | assert xrec.shape[1] > 3 187 | x = self.to_rgb(x) 188 | xrec = self.to_rgb(xrec) 189 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 190 | log["reconstructions"] = xrec 191 | log["inputs"] = x 192 | return log 193 | 194 | def to_rgb(self, x): 195 | assert self.image_key == "segmentation" 196 | if not hasattr(self, "colorize"): 197 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 198 | x = F.conv2d(x, weight=self.colorize) 199 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 200 | return x 201 | 202 | class IdentityFirstStage(torch.nn.Module): 203 | def __init__(self, *args, vq_interface=False, **kwargs): 204 | self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff 205 | super().__init__() 206 | 207 | def encode(self, x, *args, **kwargs): 208 | return x 209 | 210 | def decode(self, x, *args, **kwargs): 211 | return x 212 | 213 | def quantize(self, x, *args, **kwargs): 214 | if self.vq_interface: 215 | return x, None, [None, None, None] 216 | return x 217 | 218 | def forward(self, x, *args, **kwargs): 219 | return x -------------------------------------------------------------------------------- /lvdm/models/samplers/ddim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import torch 4 | from ..models.utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg 5 | from ..common import noise_like 6 | from ..common import extract_into_tensor 7 | import copy 8 | 9 | 10 | class DDIMSampler(object): 11 | def __init__(self, model, schedule="linear", **kwargs): 12 | super().__init__() 13 | self.model = model 14 | self.ddpm_num_timesteps = model.num_timesteps 15 | self.schedule = schedule 16 | self.counter = 0 17 | 18 | def register_buffer(self, name, attr): 19 | if type(attr) == torch.Tensor: 20 | if attr.device != torch.device("cuda"): 21 | attr = attr.to(torch.device("cuda")) 22 | setattr(self, name, attr) 23 | 24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 27 | alphas_cumprod = self.model.alphas_cumprod 28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 29 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 30 | 31 | if self.model.use_dynamic_rescale: 32 | self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps] 33 | self.ddim_scale_arr_prev = torch.cat([self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]]) 34 | 35 | self.register_buffer('betas', to_torch(self.model.betas)) 36 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 37 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 38 | 39 | # calculations for diffusion q(x_t | x_{t-1}) and others 40 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 41 | self.register_buffer('sqrt_one_minus_alphas_cumprod', 
to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 42 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 44 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 45 | 46 | # ddim sampling parameters 47 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 48 | ddim_timesteps=self.ddim_timesteps, 49 | eta=ddim_eta,verbose=verbose) 50 | self.register_buffer('ddim_sigmas', ddim_sigmas) 51 | self.register_buffer('ddim_alphas', ddim_alphas) 52 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 53 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 54 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 55 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 56 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 57 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 58 | 59 | @torch.no_grad() 60 | def sample(self, 61 | S, 62 | batch_size, 63 | shape, 64 | conditioning=None, 65 | callback=None, 66 | normals_sequence=None, 67 | img_callback=None, 68 | quantize_x0=False, 69 | eta=0., 70 | mask=None, 71 | x0=None, 72 | temperature=1., 73 | noise_dropout=0., 74 | score_corrector=None, 75 | corrector_kwargs=None, 76 | verbose=True, 77 | schedule_verbose=False, 78 | x_T=None, 79 | log_every_t=100, 80 | unconditional_guidance_scale=1., 81 | unconditional_conditioning=None, 82 | precision=None, 83 | fs=None, 84 | timestep_spacing='uniform', #uniform_trailing for starting from last timestep 85 | guidance_rescale=0.0, 86 | **kwargs 87 | ): 88 | 89 | # check condition bs 90 | if conditioning is not None: 91 | if isinstance(conditioning, dict): 92 | try: 93 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 94 | except: 95 | cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] 96 | 97 | if cbs != batch_size: 98 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 99 | else: 100 | if conditioning.shape[0] != batch_size: 101 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 102 | 103 | self.make_schedule(ddim_num_steps=S, ddim_discretize=timestep_spacing, ddim_eta=eta, verbose=schedule_verbose) 104 | 105 | # make shape 106 | if len(shape) == 3: 107 | C, H, W = shape 108 | size = (batch_size, C, H, W) 109 | elif len(shape) == 4: 110 | C, T, H, W = shape 111 | size = (batch_size, C, T, H, W) 112 | 113 | samples, intermediates = self.ddim_sampling(conditioning, size, 114 | callback=callback, 115 | img_callback=img_callback, 116 | quantize_denoised=quantize_x0, 117 | mask=mask, x0=x0, 118 | ddim_use_original_steps=False, 119 | noise_dropout=noise_dropout, 120 | temperature=temperature, 121 | score_corrector=score_corrector, 122 | corrector_kwargs=corrector_kwargs, 123 | x_T=x_T, 124 | log_every_t=log_every_t, 125 | unconditional_guidance_scale=unconditional_guidance_scale, 126 | unconditional_conditioning=unconditional_conditioning, 127 | verbose=verbose, 128 | precision=precision, 129 | fs=fs, 130 | guidance_rescale=guidance_rescale, 131 | **kwargs) 132 | return samples, intermediates 133 | 134 | @torch.no_grad() 135 | def ddim_sampling(self, cond, shape, 136 | x_T=None, ddim_use_original_steps=False, 137 | callback=None, timesteps=None, quantize_denoised=False, 138 | mask=None, 
x0=None, img_callback=None, log_every_t=100, 139 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 140 | unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True,precision=None,fs=None,guidance_rescale=0.0, 141 | **kwargs): 142 | device = self.model.betas.device 143 | b = shape[0] 144 | if x_T is None: 145 | img = torch.randn(shape, device=device) 146 | else: 147 | img = x_T 148 | if precision is not None: 149 | if precision == 16: 150 | img = img.to(dtype=torch.float16) 151 | 152 | if timesteps is None: 153 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 154 | elif timesteps is not None and not ddim_use_original_steps: 155 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 156 | timesteps = self.ddim_timesteps[:subset_end] 157 | 158 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 159 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 160 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 161 | if verbose: 162 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 163 | else: 164 | iterator = time_range 165 | 166 | clean_cond = kwargs.pop("clean_cond", False) 167 | 168 | # cond_copy, unconditional_conditioning_copy = copy.deepcopy(cond), copy.deepcopy(unconditional_conditioning) 169 | for i, step in enumerate(iterator): 170 | index = total_steps - i - 1 171 | ts = torch.full((b,), step, device=device, dtype=torch.long) 172 | 173 | ## use mask to blend noised original latent (img_orig) & new sampled latent (img) 174 | if mask is not None: 175 | assert x0 is not None 176 | if clean_cond: 177 | img_orig = x0 178 | else: 179 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 180 | img = img_orig * mask + (1. 
- mask) * img # keep original & modify use img 181 | 182 | 183 | 184 | 185 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 186 | quantize_denoised=quantize_denoised, temperature=temperature, 187 | noise_dropout=noise_dropout, score_corrector=score_corrector, 188 | corrector_kwargs=corrector_kwargs, 189 | unconditional_guidance_scale=unconditional_guidance_scale, 190 | unconditional_conditioning=unconditional_conditioning, 191 | mask=mask,x0=x0,fs=fs,guidance_rescale=guidance_rescale, 192 | **kwargs) 193 | 194 | 195 | img, pred_x0 = outs 196 | if callback: callback(i) 197 | if img_callback: img_callback(pred_x0, i) 198 | 199 | if index % log_every_t == 0 or index == total_steps - 1: 200 | intermediates['x_inter'].append(img) 201 | intermediates['pred_x0'].append(pred_x0) 202 | 203 | return img, intermediates 204 | 205 | @torch.no_grad() 206 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 207 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 208 | unconditional_guidance_scale=1., unconditional_conditioning=None, 209 | uc_type=None, conditional_guidance_scale_temporal=None,mask=None,x0=None,guidance_rescale=0.0,**kwargs): 210 | b, *_, device = *x.shape, x.device 211 | if x.dim() == 5: 212 | is_video = True 213 | else: 214 | is_video = False 215 | 216 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 217 | model_output = self.model.apply_model(x, t, c, **kwargs) # unet denoiser 218 | else: 219 | ### do_classifier_free_guidance 220 | if isinstance(c, torch.Tensor) or isinstance(c, dict): 221 | e_t_cond = self.model.apply_model(x, t, c, **kwargs) 222 | e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs) 223 | else: 224 | raise NotImplementedError 225 | 226 | model_output = e_t_uncond + unconditional_guidance_scale * (e_t_cond - e_t_uncond) 227 | 228 | if guidance_rescale > 0.0: 229 | model_output = rescale_noise_cfg(model_output, e_t_cond, guidance_rescale=guidance_rescale) 230 | 231 | if self.model.parameterization == "v": 232 | e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) 233 | else: 234 | e_t = model_output 235 | 236 | if score_corrector is not None: 237 | assert self.model.parameterization == "eps", 'not implemented' 238 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 239 | 240 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 241 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 242 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 243 | # sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 244 | sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 245 | # select parameters corresponding to the currently considered timestep 246 | 247 | if is_video: 248 | size = (b, 1, 1, 1, 1) 249 | else: 250 | size = (b, 1, 1, 1) 251 | a_t = torch.full(size, alphas[index], device=device) 252 | a_prev = torch.full(size, alphas_prev[index], device=device) 253 | sigma_t = torch.full(size, sigmas[index], device=device) 254 | sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device) 255 | 256 | # current prediction for x_0 257 | if self.model.parameterization != "v": 258 | pred_x0 = (x - 
sqrt_one_minus_at * e_t) / a_t.sqrt() 259 | else: 260 | pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) 261 | 262 | if self.model.use_dynamic_rescale: 263 | scale_t = torch.full(size, self.ddim_scale_arr[index], device=device) 264 | prev_scale_t = torch.full(size, self.ddim_scale_arr_prev[index], device=device) 265 | rescale = (prev_scale_t / scale_t) 266 | pred_x0 *= rescale 267 | 268 | if quantize_denoised: 269 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 270 | # direction pointing to x_t 271 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 272 | 273 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 274 | if noise_dropout > 0.: 275 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 276 | 277 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 278 | 279 | return x_prev, pred_x0 280 | 281 | @torch.no_grad() 282 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, 283 | use_original_steps=False, callback=None): 284 | 285 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps 286 | timesteps = timesteps[:t_start] 287 | 288 | time_range = np.flip(timesteps) 289 | total_steps = timesteps.shape[0] 290 | print(f"Running DDIM Sampling with {total_steps} timesteps") 291 | 292 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps) 293 | x_dec = x_latent 294 | for i, step in enumerate(iterator): 295 | index = total_steps - i - 1 296 | ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) 297 | x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, 298 | unconditional_guidance_scale=unconditional_guidance_scale, 299 | unconditional_conditioning=unconditional_conditioning) 300 | if callback: callback(i) 301 | return x_dec 302 | 303 | @torch.no_grad() 304 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): 305 | # fast, but does not allow for exact reconstruction 306 | # t serves as an index to gather the correct alphas 307 | if use_original_steps: 308 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod 309 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod 310 | else: 311 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) 312 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas 313 | 314 | if noise is None: 315 | noise = torch.randn_like(x0) 316 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + 317 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) 318 | -------------------------------------------------------------------------------- /lvdm/models/samplers/ddim_multiplecond.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import torch 4 | from ...models.utils_diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg 5 | from ..common import noise_like 6 | from ..common import extract_into_tensor 7 | import copy 8 | 9 | 10 | class DDIMSampler(object): 11 | def __init__(self, model, schedule="linear", **kwargs): 12 | super().__init__() 13 | self.model = model 14 | self.ddpm_num_timesteps = model.num_timesteps 15 | self.schedule = schedule 16 | self.counter = 0 17 | 18 | def register_buffer(self, name, attr): 19 | if type(attr) == torch.Tensor: 20 | if attr.device != torch.device("cuda"): 21 | attr = 
attr.to(torch.device("cuda")) 22 | setattr(self, name, attr) 23 | 24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 25 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 26 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 27 | alphas_cumprod = self.model.alphas_cumprod 28 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 29 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 30 | 31 | if self.model.use_dynamic_rescale: 32 | self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps] 33 | self.ddim_scale_arr_prev = torch.cat([self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]]) 34 | 35 | self.register_buffer('betas', to_torch(self.model.betas)) 36 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 37 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 38 | 39 | # calculations for diffusion q(x_t | x_{t-1}) and others 40 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 41 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 42 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 44 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 45 | 46 | # ddim sampling parameters 47 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 48 | ddim_timesteps=self.ddim_timesteps, 49 | eta=ddim_eta,verbose=verbose) 50 | self.register_buffer('ddim_sigmas', ddim_sigmas) 51 | self.register_buffer('ddim_alphas', ddim_alphas) 52 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 53 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 54 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 55 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 56 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 57 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 58 | 59 | @torch.no_grad() 60 | def sample(self, 61 | S, 62 | batch_size, 63 | shape, 64 | conditioning=None, 65 | callback=None, 66 | normals_sequence=None, 67 | img_callback=None, 68 | quantize_x0=False, 69 | eta=0., 70 | mask=None, 71 | x0=None, 72 | temperature=1., 73 | noise_dropout=0., 74 | score_corrector=None, 75 | corrector_kwargs=None, 76 | verbose=True, 77 | schedule_verbose=False, 78 | x_T=None, 79 | log_every_t=100, 80 | unconditional_guidance_scale=1., 81 | unconditional_conditioning=None, 82 | precision=None, 83 | fs=None, 84 | timestep_spacing='uniform', #uniform_trailing for starting from last timestep 85 | guidance_rescale=0.0, 86 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
87 | **kwargs 88 | ): 89 | 90 | # check condition bs 91 | if conditioning is not None: 92 | if isinstance(conditioning, dict): 93 | try: 94 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 95 | except: 96 | cbs = conditioning[list(conditioning.keys())[0]][0].shape[0] 97 | 98 | if cbs != batch_size: 99 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 100 | else: 101 | if conditioning.shape[0] != batch_size: 102 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 103 | 104 | # print('==> timestep_spacing: ', timestep_spacing, guidance_rescale) 105 | self.make_schedule(ddim_num_steps=S, ddim_discretize=timestep_spacing, ddim_eta=eta, verbose=schedule_verbose) 106 | 107 | # make shape 108 | if len(shape) == 3: 109 | C, H, W = shape 110 | size = (batch_size, C, H, W) 111 | elif len(shape) == 4: 112 | C, T, H, W = shape 113 | size = (batch_size, C, T, H, W) 114 | # print(f'Data shape for DDIM sampling is {size}, eta {eta}') 115 | 116 | samples, intermediates = self.ddim_sampling(conditioning, size, 117 | callback=callback, 118 | img_callback=img_callback, 119 | quantize_denoised=quantize_x0, 120 | mask=mask, x0=x0, 121 | ddim_use_original_steps=False, 122 | noise_dropout=noise_dropout, 123 | temperature=temperature, 124 | score_corrector=score_corrector, 125 | corrector_kwargs=corrector_kwargs, 126 | x_T=x_T, 127 | log_every_t=log_every_t, 128 | unconditional_guidance_scale=unconditional_guidance_scale, 129 | unconditional_conditioning=unconditional_conditioning, 130 | verbose=verbose, 131 | precision=precision, 132 | fs=fs, 133 | guidance_rescale=guidance_rescale, 134 | **kwargs) 135 | return samples, intermediates 136 | 137 | @torch.no_grad() 138 | def ddim_sampling(self, cond, shape, 139 | x_T=None, ddim_use_original_steps=False, 140 | callback=None, timesteps=None, quantize_denoised=False, 141 | mask=None, x0=None, img_callback=None, log_every_t=100, 142 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 143 | unconditional_guidance_scale=1., unconditional_conditioning=None, verbose=True,precision=None,fs=None,guidance_rescale=0.0, 144 | **kwargs): 145 | device = self.model.betas.device 146 | b = shape[0] 147 | if x_T is None: 148 | img = torch.randn(shape, device=device) 149 | else: 150 | img = x_T 151 | if precision is not None: 152 | if precision == 16: 153 | img = img.to(dtype=torch.float16) 154 | 155 | 156 | if timesteps is None: 157 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 158 | elif timesteps is not None and not ddim_use_original_steps: 159 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 160 | timesteps = self.ddim_timesteps[:subset_end] 161 | 162 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 163 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 164 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 165 | if verbose: 166 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 167 | else: 168 | iterator = time_range 169 | 170 | clean_cond = kwargs.pop("clean_cond", False) 171 | 172 | # cond_copy, unconditional_conditioning_copy = copy.deepcopy(cond), copy.deepcopy(unconditional_conditioning) 173 | for i, step in enumerate(iterator): 174 | index = total_steps - i - 1 175 | ts = torch.full((b,), step, device=device, dtype=torch.long) 176 | 177 | ## use mask to blend noised original 
latent (img_orig) & new sampled latent (img) 178 | if mask is not None: 179 | assert x0 is not None 180 | if clean_cond: 181 | img_orig = x0 182 | else: 183 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 184 | img = img_orig * mask + (1. - mask) * img # keep original & modify use img 185 | 186 | 187 | 188 | 189 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 190 | quantize_denoised=quantize_denoised, temperature=temperature, 191 | noise_dropout=noise_dropout, score_corrector=score_corrector, 192 | corrector_kwargs=corrector_kwargs, 193 | unconditional_guidance_scale=unconditional_guidance_scale, 194 | unconditional_conditioning=unconditional_conditioning, 195 | mask=mask,x0=x0,fs=fs,guidance_rescale=guidance_rescale, 196 | **kwargs) 197 | 198 | 199 | 200 | img, pred_x0 = outs 201 | if callback: callback(i) 202 | if img_callback: img_callback(pred_x0, i) 203 | 204 | if index % log_every_t == 0 or index == total_steps - 1: 205 | intermediates['x_inter'].append(img) 206 | intermediates['pred_x0'].append(pred_x0) 207 | 208 | return img, intermediates 209 | 210 | @torch.no_grad() 211 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 212 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 213 | unconditional_guidance_scale=1., unconditional_conditioning=None, 214 | uc_type=None, cfg_img=None,mask=None,x0=None,guidance_rescale=0.0, **kwargs): 215 | b, *_, device = *x.shape, x.device 216 | if x.dim() == 5: 217 | is_video = True 218 | else: 219 | is_video = False 220 | if cfg_img is None: 221 | cfg_img = unconditional_guidance_scale 222 | 223 | unconditional_conditioning_img_nonetext = kwargs['unconditional_conditioning_img_nonetext'] 224 | 225 | 226 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 227 | model_output = self.model.apply_model(x, t, c, **kwargs) # unet denoiser 228 | else: 229 | ### with unconditional condition 230 | e_t_cond = self.model.apply_model(x, t, c, **kwargs) 231 | e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning, **kwargs) 232 | e_t_uncond_img = self.model.apply_model(x, t, unconditional_conditioning_img_nonetext, **kwargs) 233 | # text cfg 234 | model_output = e_t_uncond + cfg_img * (e_t_uncond_img - e_t_uncond) + unconditional_guidance_scale * (e_t_cond - e_t_uncond_img) 235 | if guidance_rescale > 0.0: 236 | model_output = rescale_noise_cfg(model_output, e_t_cond, guidance_rescale=guidance_rescale) 237 | 238 | if self.model.parameterization == "v": 239 | e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) 240 | else: 241 | e_t = model_output 242 | 243 | if score_corrector is not None: 244 | assert self.model.parameterization == "eps", 'not implemented' 245 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 246 | 247 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 248 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 249 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 250 | sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 251 | # select parameters corresponding to the currently considered timestep 252 | 253 | if is_video: 254 | size = (b, 1, 1, 1, 1) 255 | else: 256 | size = (b, 1, 1, 1) 257 | a_t = 
torch.full(size, alphas[index], device=device) 258 | a_prev = torch.full(size, alphas_prev[index], device=device) 259 | sigma_t = torch.full(size, sigmas[index], device=device) 260 | sqrt_one_minus_at = torch.full(size, sqrt_one_minus_alphas[index],device=device) 261 | 262 | # current prediction for x_0 263 | if self.model.parameterization != "v": 264 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 265 | else: 266 | pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) 267 | 268 | if self.model.use_dynamic_rescale: 269 | scale_t = torch.full(size, self.ddim_scale_arr[index], device=device) 270 | prev_scale_t = torch.full(size, self.ddim_scale_arr_prev[index], device=device) 271 | rescale = (prev_scale_t / scale_t) 272 | pred_x0 *= rescale 273 | 274 | if quantize_denoised: 275 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 276 | # direction pointing to x_t 277 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 278 | 279 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 280 | if noise_dropout > 0.: 281 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 282 | 283 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 284 | 285 | return x_prev, pred_x0 286 | 287 | @torch.no_grad() 288 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, 289 | use_original_steps=False, callback=None): 290 | 291 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps 292 | timesteps = timesteps[:t_start] 293 | 294 | time_range = np.flip(timesteps) 295 | total_steps = timesteps.shape[0] 296 | print(f"Running DDIM Sampling with {total_steps} timesteps") 297 | 298 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps) 299 | x_dec = x_latent 300 | for i, step in enumerate(iterator): 301 | index = total_steps - i - 1 302 | ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) 303 | x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, 304 | unconditional_guidance_scale=unconditional_guidance_scale, 305 | unconditional_conditioning=unconditional_conditioning) 306 | if callback: callback(i) 307 | return x_dec 308 | 309 | @torch.no_grad() 310 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): 311 | # fast, but does not allow for exact reconstruction 312 | # t serves as an index to gather the correct alphas 313 | if use_original_steps: 314 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod 315 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod 316 | else: 317 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) 318 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas 319 | 320 | if noise is None: 321 | noise = torch.randn_like(x0) 322 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + 323 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) -------------------------------------------------------------------------------- /lvdm/models/samplers/unipc/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import UniPCSampler -------------------------------------------------------------------------------- /lvdm/models/samplers/unipc/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | 5 | from .uni_pc import NoiseScheduleVP, 
model_wrapper, UniPC 6 | 7 | class UniPCSampler(object): 8 | def __init__(self, model, **kwargs): 9 | super().__init__() 10 | self.model = model 11 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 12 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 13 | 14 | def register_buffer(self, name, attr): 15 | if type(attr) == torch.Tensor: 16 | if attr.device != torch.device("cuda"): 17 | attr = attr.to(torch.device("cuda")) 18 | setattr(self, name, attr) 19 | 20 | @torch.no_grad() 21 | def sample(self, 22 | S, 23 | batch_size, 24 | shape, 25 | conditioning=None, 26 | callback=None, 27 | normals_sequence=None, 28 | img_callback=None, 29 | quantize_x0=False, 30 | eta=0., 31 | mask=None, 32 | x0=None, 33 | temperature=1., 34 | noise_dropout=0., 35 | score_corrector=None, 36 | corrector_kwargs=None, 37 | verbose=True, 38 | x_T=None, 39 | log_every_t=100, 40 | unconditional_guidance_scale=1., 41 | unconditional_conditioning=None, 42 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 43 | **kwargs 44 | ): 45 | if conditioning is not None: 46 | if isinstance(conditioning, dict): 47 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 48 | if cbs != batch_size: 49 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 50 | else: 51 | if conditioning.shape[0] != batch_size: 52 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 53 | 54 | # sampling 55 | C, F, H, W = shape 56 | size = (batch_size, C, H, W) 57 | 58 | device = self.model.betas.device 59 | if x_T is None: 60 | img = torch.randn(size, device=device) 61 | else: 62 | img = x_T 63 | 64 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 65 | 66 | model_fn = model_wrapper( 67 | lambda x, t, c: self.model.apply_model(x, t, c), 68 | ns, 69 | model_type="noise", 70 | guidance_type="classifier-free", 71 | condition=conditioning, 72 | unconditional_condition=unconditional_conditioning, 73 | guidance_scale=unconditional_guidance_scale, 74 | ) 75 | 76 | uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False) 77 | x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True) 78 | 79 | return x.to(device), None -------------------------------------------------------------------------------- /lvdm/models/utils_diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import repeat 6 | 7 | 8 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, dtype=None): 9 | """ 10 | Create sinusoidal timestep embeddings. 11 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 12 | These may be fractional. 13 | :param dim: the dimension of the output. 14 | :param max_period: controls the minimum frequency of the embeddings. 15 | :return: an [N x dim] Tensor of positional embeddings. 
16 | """ 17 | if not repeat_only: 18 | half = dim // 2 19 | freqs = torch.exp( 20 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=dtype) / half 21 | ).to(device=timesteps.device) 22 | args = timesteps[:, None].float() * freqs[None] 23 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 24 | if dim % 2: 25 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 26 | else: 27 | embedding = repeat(timesteps, 'b -> b d', d=dim) 28 | return embedding.to(dtype) 29 | 30 | 31 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 32 | if schedule == "linear": 33 | betas = ( 34 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 35 | ) 36 | 37 | elif schedule == "cosine": 38 | timesteps = ( 39 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 40 | ) 41 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 42 | alphas = torch.cos(alphas).pow(2) 43 | alphas = alphas / alphas[0] 44 | betas = 1 - alphas[1:] / alphas[:-1] 45 | betas = np.clip(betas, a_min=0, a_max=0.999) 46 | 47 | elif schedule == "sqrt_linear": 48 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 49 | elif schedule == "sqrt": 50 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 51 | else: 52 | raise ValueError(f"schedule '{schedule}' unknown.") 53 | return betas.numpy() 54 | 55 | 56 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 57 | if ddim_discr_method == 'uniform': 58 | c = num_ddpm_timesteps // num_ddim_timesteps 59 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 60 | steps_out = ddim_timesteps + 1 61 | elif ddim_discr_method == 'uniform_trailing': 62 | c = num_ddpm_timesteps / num_ddim_timesteps 63 | ddim_timesteps = np.flip(np.round(np.arange(num_ddpm_timesteps, 0, -c))).astype(np.int64) 64 | steps_out = ddim_timesteps - 1 65 | elif ddim_discr_method == 'quad': 66 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 67 | steps_out = ddim_timesteps + 1 68 | else: 69 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 70 | 71 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 72 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 73 | # steps_out = ddim_timesteps + 1 74 | if verbose: 75 | print(f'Selected timesteps for ddim sampler: {steps_out}') 76 | return steps_out 77 | 78 | 79 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 80 | # select alphas for computing the variance schedule 81 | # print(f'ddim_timesteps={ddim_timesteps}, len_alphacums={len(alphacums)}') 82 | alphas = alphacums[ddim_timesteps] 83 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 84 | 85 | # according the the formula provided in https://arxiv.org/abs/2010.02502 86 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 87 | if verbose: 88 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 89 | print(f'For the chosen value of eta, which is {eta}, ' 90 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 91 | return sigmas, alphas, alphas_prev 92 | 93 | 94 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 95 
| """ 96 | Create a beta schedule that discretizes the given alpha_t_bar function, 97 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 98 | :param num_diffusion_timesteps: the number of betas to produce. 99 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 100 | produces the cumulative product of (1-beta) up to that 101 | part of the diffusion process. 102 | :param max_beta: the maximum beta to use; use values lower than 1 to 103 | prevent singularities. 104 | """ 105 | betas = [] 106 | for i in range(num_diffusion_timesteps): 107 | t1 = i / num_diffusion_timesteps 108 | t2 = (i + 1) / num_diffusion_timesteps 109 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 110 | return np.array(betas) 111 | 112 | def rescale_zero_terminal_snr(betas): 113 | """ 114 | Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) 115 | 116 | Args: 117 | betas (`numpy.ndarray`): 118 | the betas that the scheduler is being initialized with. 119 | 120 | Returns: 121 | `numpy.ndarray`: rescaled betas with zero terminal SNR 122 | """ 123 | # Convert betas to alphas_bar_sqrt 124 | alphas = 1.0 - betas 125 | alphas_cumprod = np.cumprod(alphas, axis=0) 126 | alphas_bar_sqrt = np.sqrt(alphas_cumprod) 127 | 128 | # Store old values. 129 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy() 130 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy() 131 | 132 | # Shift so the last timestep is zero. 133 | alphas_bar_sqrt -= alphas_bar_sqrt_T 134 | 135 | # Scale so the first timestep is back to the old value. 136 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) 137 | 138 | # Convert alphas_bar_sqrt to betas 139 | alphas_bar = alphas_bar_sqrt**2 # Revert sqrt 140 | alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod 141 | alphas = np.concatenate([alphas_bar[0:1], alphas]) 142 | betas = 1 - alphas 143 | 144 | return betas 145 | 146 | 147 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): 148 | """ 149 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and 150 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 151 | """ 152 | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) 153 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) 154 | # rescale the results from guidance (fixes overexposure) 155 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) 156 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images 157 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg 158 | return noise_cfg -------------------------------------------------------------------------------- /lvdm/modules/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | from einops import rearrange, repeat 5 | from functools import partial 6 | from ..common import ( 7 | checkpoint, 8 | exists, 9 | default, 10 | ) 11 | from ..basics import zero_module 12 | import comfy.ops 13 | ops = comfy.ops.disable_weight_init 14 | from comfy import model_management 15 | from comfy.ldm.modules.attention import optimized_attention, optimized_attention_masked 16 | 17 | if model_management.xformers_enabled(): 18 | import xformers 19 | import xformers.ops 20 | XFORMERS_IS_AVAILBLE = True 21 | else: 22 | XFORMERS_IS_AVAILBLE = False 23 | 24 | class RelativePosition(nn.Module): 25 | """ https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """ 26 | 27 | def __init__(self, num_units, max_relative_position): 28 | super().__init__() 29 | self.num_units = num_units 30 | self.max_relative_position = max_relative_position 31 | self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units)) 32 | nn.init.xavier_uniform_(self.embeddings_table) 33 | 34 | def forward(self, length_q, length_k): 35 | device = self.embeddings_table.device 36 | range_vec_q = torch.arange(length_q, device=device) 37 | range_vec_k = torch.arange(length_k, device=device) 38 | distance_mat = range_vec_k[None, :] - range_vec_q[:, None] 39 | distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position) 40 | final_mat = distance_mat_clipped + self.max_relative_position 41 | final_mat = final_mat.long() 42 | embeddings = self.embeddings_table[final_mat] 43 | return embeddings 44 | 45 | 46 | # TODO Add native Comfy optimized attention. 
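# Illustrative sketch of how the RelativePosition module above builds its
# embedding lookup, assuming length_q == length_k == 4 and
# max_relative_position == 2:
#
#   q_idx = torch.arange(4); k_idx = torch.arange(4)
#   dist  = k_idx[None, :] - q_idx[:, None]      # offsets in [-3, 3]
#   dist  = dist.clamp(-2, 2) + 2                # shifted into [0, 4]
#   emb   = embeddings_table[dist]               # (4, 4, num_units)
#
# CrossAttention below consumes these embeddings through an extra einsum term
# ('b t d, t s d -> b t s') that is added onto the attention logits when
# relative_position=True (temporal attention only).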
47 | class CrossAttention(nn.Module): 48 | 49 | def __init__( 50 | self, 51 | query_dim, 52 | context_dim=None, 53 | heads=8, 54 | dim_head=64, 55 | dropout=0., 56 | relative_position=False, 57 | temporal_length=None, 58 | video_length=None, 59 | image_cross_attention=False, 60 | image_cross_attention_scale=1.0, 61 | image_cross_attention_scale_learnable=False, 62 | text_context_len=77, 63 | device=None, 64 | dtype=None, 65 | operations=ops 66 | ): 67 | super().__init__() 68 | inner_dim = dim_head * heads 69 | context_dim = default(context_dim, query_dim) 70 | self.scale = dim_head**-0.5 71 | self.heads = heads 72 | self.dim_head = dim_head 73 | self.to_q = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype) 74 | self.to_k = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) 75 | self.to_v = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) 76 | 77 | self.to_out = nn.Sequential( 78 | operations.Linear(inner_dim, query_dim, device=device, dtype=dtype), 79 | nn.Dropout(dropout) 80 | ) 81 | 82 | self.relative_position = relative_position 83 | if self.relative_position: 84 | assert(temporal_length is not None) 85 | self.relative_position_k = RelativePosition(num_units=dim_head, max_relative_position=temporal_length) 86 | self.relative_position_v = RelativePosition(num_units=dim_head, max_relative_position=temporal_length) 87 | else: 88 | ## only used for spatial attention, while NOT for temporal attention 89 | if XFORMERS_IS_AVAILBLE and temporal_length is None: 90 | self.forward = self.efficient_forward 91 | else: 92 | self.forward = self.comfy_efficient_forward 93 | 94 | self.video_length = video_length 95 | self.image_cross_attention = image_cross_attention 96 | self.image_cross_attention_scale = image_cross_attention_scale 97 | self.text_context_len = text_context_len 98 | self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable 99 | if self.image_cross_attention: 100 | self.to_k_ip = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) 101 | self.to_v_ip = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) 102 | if image_cross_attention_scale_learnable: 103 | self.register_parameter('alpha', nn.Parameter(torch.tensor(0.)) ) 104 | 105 | def comfy_efficient_forward(self, x, context=None, mask=None, *args, **kwargs): 106 | spatial_self_attn = (context is None) 107 | k_ip, v_ip, out_ip = None, None, None 108 | 109 | h = self.heads 110 | q = self.to_q(x) 111 | context = default(context, x) 112 | 113 | if self.image_cross_attention and not spatial_self_attn: 114 | context, context_image = context[:,:self.text_context_len,:], context[:,self.text_context_len:,:] 115 | k = self.to_k(context) 116 | v = self.to_v(context) 117 | k_ip = self.to_k_ip(context_image) 118 | v_ip = self.to_v_ip(context_image) 119 | else: 120 | if not spatial_self_attn: 121 | context = context[:,:self.text_context_len,:] 122 | k = self.to_k(context) 123 | v = self.to_v(context) 124 | 125 | out = optimized_attention(q, k, v, h) 126 | 127 | if exists(mask): 128 | ## feasible for causal attention mask only 129 | out = optimized_attention_masked(q, k, v, h) 130 | 131 | ## for image cross-attention 132 | if k_ip is not None: 133 | q = rearrange(q, 'b n (h d) -> (b h) n d', h=h) 134 | k_ip, v_ip = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (k_ip, v_ip)) 135 | sim_ip = torch.einsum('b i d, b j d -> b i j', q, k_ip) * self.scale 136 | del 
k_ip 137 | sim_ip = sim_ip.softmax(dim=-1) 138 | out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip) 139 | out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h) 140 | 141 | if out_ip is not None: 142 | if self.image_cross_attention_scale_learnable: 143 | out = out + self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha)+1) 144 | else: 145 | out = out + self.image_cross_attention_scale * out_ip 146 | 147 | return self.to_out(out) 148 | 149 | def forward(self, x, context=None, mask=None): 150 | spatial_self_attn = (context is None) 151 | k_ip, v_ip, out_ip = None, None, None 152 | 153 | h = self.heads 154 | q = self.to_q(x) 155 | context = default(context, x) 156 | 157 | if self.image_cross_attention and not spatial_self_attn: 158 | context, context_image = context[:,:self.text_context_len,:], context[:,self.text_context_len:,:] 159 | k = self.to_k(context) 160 | v = self.to_v(context) 161 | k_ip = self.to_k_ip(context_image) 162 | v_ip = self.to_v_ip(context_image) 163 | else: 164 | 165 | # Assumed Spatial Attention (b c h w) 166 | if not spatial_self_attn: 167 | context = context[:,:self.text_context_len,:] 168 | k = self.to_k(context) 169 | v = self.to_v(context) 170 | 171 | 172 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 173 | 174 | sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale 175 | if self.relative_position: 176 | len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1] 177 | k2 = self.relative_position_k(len_q, len_k) 178 | sim2 = einsum('b t d, t s d -> b t s', q, k2) * self.scale # TODO check 179 | sim += sim2 180 | del k 181 | 182 | if exists(mask): 183 | ## feasible for causal attention mask only 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b i j -> (b h) i j', h=h) 186 | sim.masked_fill_(~(mask>0.5), max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | sim = sim.softmax(dim=-1) 190 | 191 | out = torch.einsum('b i j, b j d -> b i d', sim, v) 192 | if self.relative_position: 193 | v2 = self.relative_position_v(len_q, len_v) 194 | out2 = einsum('b t s, t s d -> b t d', sim, v2) # TODO check 195 | out += out2 196 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 197 | 198 | 199 | ## for image cross-attention 200 | if k_ip is not None: 201 | k_ip, v_ip = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (k_ip, v_ip)) 202 | sim_ip = torch.einsum('b i d, b j d -> b i j', q, k_ip) * self.scale 203 | del k_ip 204 | sim_ip = sim_ip.softmax(dim=-1) 205 | out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip) 206 | out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h) 207 | 208 | 209 | if out_ip is not None: 210 | if self.image_cross_attention_scale_learnable: 211 | out = out + self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha)+1) 212 | else: 213 | out = out + self.image_cross_attention_scale * out_ip 214 | 215 | return self.to_out(out) 216 | 217 | def efficient_forward(self, x, context=None, mask=None): 218 | spatial_self_attn = (context is None) 219 | k_ip, v_ip, out_ip = None, None, None 220 | 221 | q = self.to_q(x) 222 | context = default(context, x) 223 | 224 | if self.image_cross_attention and not spatial_self_attn: 225 | context, context_image = context[:,:self.text_context_len,:], context[:,self.text_context_len:,:] 226 | k = self.to_k(context) 227 | v = self.to_v(context) 228 | k_ip = self.to_k_ip(context_image) 229 | v_ip = self.to_v_ip(context_image) 230 | else: 231 | if not spatial_self_attn: 232 | context = 
context[:,:self.text_context_len,:] 233 | k = self.to_k(context) 234 | v = self.to_v(context) 235 | 236 | b, _, _ = q.shape 237 | q, k, v = map( 238 | lambda t: t.unsqueeze(3) 239 | .reshape(b, t.shape[1], self.heads, self.dim_head) 240 | .permute(0, 2, 1, 3) 241 | .reshape(b * self.heads, t.shape[1], self.dim_head) 242 | .contiguous(), 243 | (q, k, v), 244 | ) 245 | # actually compute the attention, what we cannot get enough of 246 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None) 247 | 248 | ## for image cross-attention 249 | if k_ip is not None: 250 | k_ip, v_ip = map( 251 | lambda t: t.unsqueeze(3) 252 | .reshape(b, t.shape[1], self.heads, self.dim_head) 253 | .permute(0, 2, 1, 3) 254 | .reshape(b * self.heads, t.shape[1], self.dim_head) 255 | .contiguous(), 256 | (k_ip, v_ip), 257 | ) 258 | out_ip = xformers.ops.memory_efficient_attention(q, k_ip, v_ip, attn_bias=None, op=None) 259 | out_ip = ( 260 | out_ip.unsqueeze(0) 261 | .reshape(b, self.heads, out.shape[1], self.dim_head) 262 | .permute(0, 2, 1, 3) 263 | .reshape(b, out.shape[1], self.heads * self.dim_head) 264 | ) 265 | 266 | if exists(mask): 267 | raise NotImplementedError 268 | out = ( 269 | out.unsqueeze(0) 270 | .reshape(b, self.heads, out.shape[1], self.dim_head) 271 | .permute(0, 2, 1, 3) 272 | .reshape(b, out.shape[1], self.heads * self.dim_head) 273 | ) 274 | if out_ip is not None: 275 | if self.image_cross_attention_scale_learnable: 276 | out = out + self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha)+1) 277 | else: 278 | out = out + self.image_cross_attention_scale * out_ip 279 | 280 | return self.to_out(out) 281 | 282 | 283 | class BasicTransformerBlock(nn.Module): 284 | 285 | def __init__( 286 | self, 287 | dim, 288 | n_heads, 289 | d_head, 290 | dropout=0., 291 | context_dim=None, 292 | gated_ff=True, 293 | checkpoint=True, 294 | disable_self_attn=False, 295 | attention_cls=None, 296 | video_length=None, 297 | inner_dim=None, 298 | image_cross_attention=False, 299 | image_cross_attention_scale=1.0, 300 | image_cross_attention_scale_learnable=False, 301 | switch_temporal_ca_to_sa=False, 302 | text_context_len=77, 303 | ff_in=None, 304 | device=None, 305 | dtype=None, 306 | operations=ops 307 | ): 308 | super().__init__() 309 | attn_cls = CrossAttention if attention_cls is None else attention_cls 310 | 311 | self.ff_in = ff_in or inner_dim is not None 312 | if self.ff_in: 313 | self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device) 314 | self.ff_in = FeedForward( 315 | dim, 316 | dim_out=inner_dim, 317 | dropout=dropout, 318 | glu=gated_ff, 319 | dtype=dtype, 320 | device=device, 321 | operations=operations 322 | ) 323 | if inner_dim is None: 324 | inner_dim = dim 325 | 326 | self.is_res = inner_dim == dim 327 | self.disable_self_attn = disable_self_attn 328 | self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, 329 | context_dim=None, device=device, dtype=dtype if self.disable_self_attn else None) 330 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, device=device, dtype=dtype) 331 | self.attn2 = attn_cls( 332 | query_dim=dim, 333 | context_dim=context_dim, 334 | heads=n_heads, 335 | dim_head=d_head, 336 | dropout=dropout, 337 | video_length=video_length, 338 | image_cross_attention=image_cross_attention, 339 | image_cross_attention_scale=image_cross_attention_scale, 340 | image_cross_attention_scale_learnable=image_cross_attention_scale_learnable, 341 | text_context_len=text_context_len, 342 | device=device, 343 
| dtype=dtype 344 | ) 345 | self.image_cross_attention = image_cross_attention 346 | 347 | self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype) 348 | self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype) 349 | self.norm3 = operations.LayerNorm(dim, device=device, dtype=dtype) 350 | 351 | self.n_heads = n_heads 352 | self.d_head = d_head 353 | self.checkpoint = checkpoint 354 | self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa 355 | 356 | def forward(self, x, context=None, mask=None, **kwargs): 357 | ## implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments 358 | input_tuple = (x,) ## should not be (x), otherwise *input_tuple will decouple x into multiple arguments 359 | if context is not None: 360 | input_tuple = (x, context) 361 | if mask is not None: 362 | forward_mask = partial(self._forward, mask=mask) 363 | return checkpoint(forward_mask, (x,), self.parameters(), self.checkpoint) 364 | return checkpoint(self._forward, input_tuple, self.parameters(), self.checkpoint) 365 | 366 | 367 | def _forward(self, x, context=None, mask=None, transformer_options={}): 368 | extra_options = {} 369 | block = transformer_options.get("block", None) 370 | block_index = transformer_options.get("block_index", 0) 371 | transformer_patches = {} 372 | transformer_patches_replace = {} 373 | 374 | for k in transformer_options: 375 | if k == "patches": 376 | transformer_patches = transformer_options[k] 377 | elif k == "patches_replace": 378 | transformer_patches_replace = transformer_options[k] 379 | else: 380 | extra_options[k] = transformer_options[k] 381 | 382 | extra_options["n_heads"] = self.n_heads 383 | extra_options["dim_head"] = self.d_head 384 | 385 | if self.ff_in: 386 | x_skip = x 387 | x = self.ff_in(self.norm_in(x)) 388 | if self.is_res: 389 | x += x_skip 390 | 391 | n = self.norm1(x) 392 | if self.disable_self_attn: 393 | context_attn1 = context 394 | else: 395 | context_attn1 = None 396 | value_attn1 = None 397 | 398 | if "attn1_patch" in transformer_patches: 399 | patch = transformer_patches["attn1_patch"] 400 | if context_attn1 is None: 401 | context_attn1 = n 402 | value_attn1 = context_attn1 403 | for p in patch: 404 | n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options) 405 | 406 | if block is not None: 407 | transformer_block = (block[0], block[1], block_index) 408 | else: 409 | transformer_block = None 410 | attn1_replace_patch = transformer_patches_replace.get("attn1", {}) 411 | block_attn1 = transformer_block 412 | if block_attn1 not in attn1_replace_patch: 413 | block_attn1 = block 414 | 415 | if block_attn1 in attn1_replace_patch: 416 | if context_attn1 is None: 417 | context_attn1 = n 418 | value_attn1 = n 419 | n = self.attn1.to_q(n) 420 | context_attn1 = self.attn1.to_k(context_attn1) 421 | value_attn1 = self.attn1.to_v(value_attn1) 422 | n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options) 423 | n = self.attn1.to_out(n) 424 | else: 425 | n = self.attn1(n, context=context_attn1, value=value_attn1) 426 | 427 | if "attn1_output_patch" in transformer_patches: 428 | patch = transformer_patches["attn1_output_patch"] 429 | for p in patch: 430 | n = p(n, extra_options) 431 | 432 | x += n 433 | if "middle_patch" in transformer_patches: 434 | patch = transformer_patches["middle_patch"] 435 | for p in patch: 436 | x = p(x, extra_options) 437 | 438 | if self.attn2 is not None: 439 | n = self.norm2(x) 440 | if self.switch_temporal_ca_to_sa: 441 | 
context_attn2 = n 442 | else: 443 | context_attn2 = context 444 | value_attn2 = None 445 | if "attn2_patch" in transformer_patches: 446 | patch = transformer_patches["attn2_patch"] 447 | value_attn2 = context_attn2 448 | for p in patch: 449 | n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options) 450 | 451 | attn2_replace_patch = transformer_patches_replace.get("attn2", {}) 452 | block_attn2 = transformer_block 453 | if block_attn2 not in attn2_replace_patch: 454 | block_attn2 = block 455 | 456 | if block_attn2 in attn2_replace_patch: 457 | if value_attn2 is None: 458 | value_attn2 = context_attn2 459 | n = self.attn2.to_q(n) 460 | context_attn2 = self.attn2.to_k(context_attn2) 461 | value_attn2 = self.attn2.to_v(value_attn2) 462 | n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) 463 | n = self.attn2.to_out(n) 464 | else: 465 | n = self.attn2(n, context=context_attn2, value=value_attn2) 466 | 467 | if "attn2_output_patch" in transformer_patches: 468 | patch = transformer_patches["attn2_output_patch"] 469 | for p in patch: 470 | n = p(n, extra_options) 471 | 472 | x += n 473 | if self.is_res: 474 | x_skip = x 475 | x = self.ff(self.norm3(x)) 476 | if self.is_res: 477 | x += x_skip 478 | 479 | return x 480 | 481 | 482 | class SpatialTransformer(nn.Module): 483 | """ 484 | Transformer block for image-like data in spatial axis. 485 | First, project the input (aka embedding) 486 | and reshape to b, t, d. 487 | Then apply standard transformer action. 488 | Finally, reshape to image 489 | NEW: use_linear for more efficiency instead of the 1x1 convs 490 | """ 491 | 492 | def __init__( 493 | self, 494 | in_channels, 495 | n_heads, 496 | d_head, 497 | depth=1, 498 | dropout=0., 499 | context_dim=None, 500 | use_checkpoint=True, 501 | disable_self_attn=False, 502 | use_linear=False, 503 | video_length=None, 504 | image_cross_attention=False, 505 | image_cross_attention_scale_learnable=False, 506 | device=None, 507 | dtype=None, 508 | operations=ops 509 | ): 510 | super().__init__() 511 | self.in_channels = in_channels 512 | inner_dim = n_heads * d_head 513 | self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, device=device, dtype=dtype) 514 | if not use_linear: 515 | self.proj_in = opeations.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype) 516 | else: 517 | self.proj_in = operations.Linear(in_channels, inner_dim, device=device, dtype=dtype) 518 | 519 | attention_cls = None 520 | self.transformer_blocks = nn.ModuleList([ 521 | BasicTransformerBlock( 522 | inner_dim, 523 | n_heads, 524 | d_head, 525 | dropout=dropout, 526 | context_dim=context_dim, 527 | disable_self_attn=disable_self_attn, 528 | checkpoint=use_checkpoint, 529 | attention_cls=attention_cls, 530 | video_length=video_length, 531 | image_cross_attention=image_cross_attention, 532 | image_cross_attention_scale_learnable=image_cross_attention_scale_learnable, 533 | device=device, 534 | dtype=dtype 535 | ) for d in range(depth) 536 | ]) 537 | if not use_linear: 538 | self.proj_out = zero_module(operations.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype)) 539 | else: 540 | self.proj_out = zero_module(operations.Linear(inner_dim, in_channels, device=device, dtype=dtype)) 541 | self.use_linear = use_linear 542 | 543 | def forward(self, x, context=None, transformer_options={}, **kwargs): 544 | b, c, h, w = x.shape 545 | x_in = x 546 | x = self.norm(x) 
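# Shape sketch for the projection steps below (illustrative, assuming
# use_linear=True and input x of shape (b, in_channels, h, w)):
#   norm                                  -> (b, in_channels, h, w)
#   rearrange 'b c h w -> b (h w) c'      -> (b, h*w, in_channels)
#   proj_in (Linear in_channels -> n_heads*d_head) -> (b, h*w, inner_dim)
# With use_linear=False the 1x1 Conv2d projection runs first, while the tensor
# is still (b, in_channels, h, w), and the result is then flattened to tokens.
# (The Conv2d branch in __init__ is written as `opeations.Conv2d`, apparently a
# typo for `operations.Conv2d`.)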
547 | if not self.use_linear: 548 | x = self.proj_in(x) 549 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 550 | if self.use_linear: 551 | x = self.proj_in(x) 552 | for i, block in enumerate(self.transformer_blocks): 553 | transformer_options['block_index'] = i 554 | x = block(x, context=context, **kwargs) 555 | if self.use_linear: 556 | x = self.proj_out(x) 557 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() 558 | if not self.use_linear: 559 | x = self.proj_out(x) 560 | return x + x_in 561 | 562 | 563 | class TemporalTransformer(nn.Module): 564 | """ 565 | Transformer block for image-like data in temporal axis. 566 | First, reshape to b, t, d. 567 | Then apply standard transformer action. 568 | Finally, reshape to image 569 | """ 570 | def __init__( 571 | self, 572 | in_channels, 573 | n_heads, 574 | d_head, 575 | depth=1, 576 | dropout=0., 577 | context_dim=None, 578 | use_checkpoint=True, 579 | use_linear=False, 580 | only_self_att=True, 581 | causal_attention=False, 582 | causal_block_size=1, 583 | relative_position=False, 584 | temporal_length=None, 585 | device=None, 586 | dtype=None, 587 | operations=ops 588 | ): 589 | super().__init__() 590 | self.only_self_att = only_self_att 591 | self.relative_position = relative_position 592 | self.causal_attention = causal_attention 593 | self.causal_block_size = causal_block_size 594 | 595 | if only_self_att: 596 | context_dim = None 597 | 598 | self.in_channels = in_channels 599 | inner_dim = n_heads * d_head 600 | self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, device=device, dtype=dtype) 601 | self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0).to(device, dtype) 602 | if not use_linear: 603 | self.proj_in = nn.Conv1d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0).to(device, dtype) 604 | else: 605 | self.proj_in = operations.Linear(in_channels, inner_dim, device=device, dtype=dtype) 606 | 607 | if relative_position: 608 | assert(temporal_length is not None) 609 | attention_cls = partial(CrossAttention, relative_position=True, temporal_length=temporal_length, device=device, dtype=dtype) 610 | else: 611 | attention_cls = partial(CrossAttention, temporal_length=temporal_length, device=device, dtype=dtype) 612 | if self.causal_attention: 613 | assert(temporal_length is not None) 614 | self.mask = torch.tril(torch.ones([1, temporal_length, temporal_length])) 615 | 616 | if self.only_self_att: 617 | context_dim = None 618 | self.transformer_blocks = nn.ModuleList([ 619 | BasicTransformerBlock( 620 | inner_dim, 621 | n_heads, 622 | d_head, 623 | dropout=dropout, 624 | context_dim=context_dim, 625 | attention_cls=attention_cls, 626 | checkpoint=use_checkpoint, 627 | device=device, 628 | dtype=dtype 629 | ) for d in range(depth) 630 | ]) 631 | if not use_linear: 632 | self.proj_out = zero_module(nn.Conv1d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0).to(device, dtype)) 633 | else: 634 | self.proj_out = zero_module(operations.Linear(inner_dim, in_channels, device=device, dtype=dtype)) 635 | self.use_linear = use_linear 636 | 637 | def forward(self, x, context=None): 638 | b, c, t, h, w = x.shape 639 | x_in = x 640 | x = self.norm(x) 641 | x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous() 642 | if not self.use_linear: 643 | x = self.proj_in(x) 644 | x = rearrange(x, 'bhw c t -> bhw t c').contiguous() 645 | if self.use_linear: 646 | x = self.proj_in(x) 647 | 648 | temp_mask = None 649 | if 
self.causal_attention: 650 | # slice the from mask map 651 | temp_mask = self.mask[:,:t,:t].to(x.device) 652 | 653 | if temp_mask is not None: 654 | mask = temp_mask.to(x.device) 655 | mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b*h*w) 656 | else: 657 | mask = None 658 | 659 | if self.only_self_att: 660 | ## note: if no context is given, cross-attention defaults to self-attention 661 | for i, block in enumerate(self.transformer_blocks): 662 | x = block(x, mask=mask) 663 | x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous() 664 | else: 665 | x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous() 666 | context = rearrange(context, '(b t) l con -> b t l con', t=t).contiguous() 667 | for i, block in enumerate(self.transformer_blocks): 668 | # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) 669 | for j in range(b): 670 | context_j = repeat( 671 | context[j], 672 | 't l con -> (t r) l con', r=(h * w) // t, t=t).contiguous() 673 | ## note: causal mask will not applied in cross-attention case 674 | x[j] = block(x[j], context=context_j) 675 | 676 | if self.use_linear: 677 | x = self.proj_out(x) 678 | x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous() 679 | if not self.use_linear: 680 | x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous() 681 | x = self.proj_out(x) 682 | x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h, w=w).contiguous() 683 | 684 | return x + x_in 685 | 686 | 687 | class GEGLU(nn.Module): 688 | def __init__(self, dim_in, dim_out, device=None, dtype=None, operations=ops): 689 | super().__init__() 690 | self.proj = operations.Linear(dim_in, dim_out * 2, device=device, dtype=dtype) 691 | 692 | def forward(self, x): 693 | x, gate = self.proj(x).chunk(2, dim=-1) 694 | return x * F.gelu(gate) 695 | 696 | 697 | class FeedForward(nn.Module): 698 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., device=None, dtype=None, operations=ops): 699 | super().__init__() 700 | inner_dim = int(dim * mult) 701 | dim_out = default(dim_out, dim) 702 | project_in = nn.Sequential( 703 | operations.Linear(dim, inner_dim, device=device, dtype=dtype), 704 | nn.GELU() 705 | ) if not glu else GEGLU(dim, inner_dim) 706 | 707 | self.net = nn.Sequential( 708 | project_in, 709 | nn.Dropout(dropout), 710 | operations.Linear(inner_dim, dim_out, device=device, dtype=dtype) 711 | ) 712 | 713 | def forward(self, x): 714 | return self.net(x) 715 | 716 | 717 | class LinearAttention(nn.Module): 718 | def __init__(self, dim, heads=4, dim_head=32, device=None, dtype=None, operations=ops): 719 | super().__init__() 720 | self.heads = heads 721 | hidden_dim = dim_head * heads 722 | self.to_qkv = operations.Conv2d(dim, hidden_dim * 3, 1, bias = False, device=device, dtype=dtype) 723 | self.to_out = operations.Conv2d(hidden_dim, dim, 1, device=device, dtype=dtype) 724 | 725 | def forward(self, x): 726 | b, c, h, w = x.shape 727 | qkv = self.to_qkv(x) 728 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 729 | k = k.softmax(dim=-1) 730 | context = torch.einsum('bhdn,bhen->bhde', k, v) 731 | out = torch.einsum('bhde,bhdn->bhen', context, q) 732 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 733 | return self.to_out(out) 734 | 735 | 736 | class SpatialSelfAttention(nn.Module): 737 | def __init__(self, in_channels, device=None, dtype=None, operations=ops): 738 | super().__init__() 739 | self.in_channels = in_channels 740 | 
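# Descriptive sketch (illustrative) of the attention computed by this class:
# the 1x1 convolutions defined below act as per-pixel linear projections for
# q, k and v, and forward() flattens the spatial grid to h*w tokens:
#   w_ = softmax((q @ k.T) * c**-0.5)    # full (b, h*w, h*w) similarity map
#   h_ = v @ w_.T                        # attended values, reshaped to (b, c, h, w)
# This is quadratic in h*w, unlike LinearAttention above, which softmaxes the
# keys and aggregates context = k @ v.T first, keeping the cost linear in h*w.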
741 | self.norm = operations.GroupNorm( 742 | num_groups=32, 743 | num_channels=in_channels, 744 | eps=1e-6, 745 | affine=True, 746 | device=device, 747 | dtype=dtype 748 | ) 749 | self.q = operations.Conv2d( 750 | in_channels, 751 | in_channels, 752 | kernel_size=1, 753 | stride=1, 754 | padding=0, 755 | device=device, 756 | dtype=dtype 757 | ) 758 | self.k = operations.Conv2d( 759 | in_channels, 760 | in_channels, 761 | kernel_size=1, 762 | stride=1, 763 | padding=0, 764 | device=device, 765 | dtype=dtype 766 | ) 767 | self.v = operations.Conv2d( 768 | in_channels, 769 | in_channels, 770 | kernel_size=1, 771 | stride=1, 772 | padding=0, 773 | device=device, 774 | dtype=dtype 775 | ) 776 | self.proj_out = operations.Conv2d( 777 | in_channels, 778 | in_channels, 779 | kernel_size=1, 780 | stride=1, 781 | padding=0, 782 | device=device, 783 | dtype=dtype 784 | ) 785 | 786 | def forward(self, x): 787 | h_ = x 788 | h_ = self.norm(h_) 789 | q = self.q(h_) 790 | k = self.k(h_) 791 | v = self.v(h_) 792 | 793 | # compute attention 794 | b,c,h,w = q.shape 795 | q = rearrange(q, 'b c h w -> b (h w) c') 796 | k = rearrange(k, 'b c h w -> b c (h w)') 797 | w_ = torch.einsum('bij,bjk->bik', q, k) 798 | 799 | w_ = w_ * (int(c)**(-0.5)) 800 | w_ = torch.nn.functional.softmax(w_, dim=2) 801 | 802 | # attend to values 803 | v = rearrange(v, 'b c h w -> b c (h w)') 804 | w_ = rearrange(w_, 'b i j -> b j i') 805 | h_ = torch.einsum('bij,bjk->bik', v, w_) 806 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 807 | h_ = self.proj_out(h_) 808 | 809 | return x+h_ 810 | -------------------------------------------------------------------------------- /lvdm/modules/encoders/condition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import kornia 4 | import open_clip 5 | from torch.utils.checkpoint import checkpoint 6 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel 7 | from ..common import autocast 8 | from utils.utils import count_params 9 | 10 | 11 | class AbstractEncoder(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def encode(self, *args, **kwargs): 16 | raise NotImplementedError 17 | 18 | 19 | class IdentityEncoder(AbstractEncoder): 20 | def encode(self, x): 21 | return x 22 | 23 | 24 | class ClassEmbedder(nn.Module): 25 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): 26 | super().__init__() 27 | self.key = key 28 | self.embedding = nn.Embedding(n_classes, embed_dim) 29 | self.n_classes = n_classes 30 | self.ucg_rate = ucg_rate 31 | 32 | def forward(self, batch, key=None, disable_dropout=False): 33 | if key is None: 34 | key = self.key 35 | # this is for use in crossattn 36 | c = batch[key][:, None] 37 | if self.ucg_rate > 0. and not disable_dropout: 38 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) 39 | c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) 40 | c = c.long() 41 | c = self.embedding(c) 42 | return c 43 | 44 | def get_unconditional_conditioning(self, bs, device="cuda"): 45 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 
999, one extra class for ucg (class 1000) 46 | uc = torch.ones((bs,), device=device) * uc_class 47 | uc = {self.key: uc} 48 | return uc 49 | 50 | 51 | def disabled_train(self, mode=True): 52 | """Overwrite model.train with this function to make sure train/eval mode 53 | does not change anymore.""" 54 | return self 55 | 56 | 57 | class FrozenT5Embedder(AbstractEncoder): 58 | """Uses the T5 transformer encoder for text""" 59 | 60 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, 61 | freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 62 | super().__init__() 63 | self.tokenizer = T5Tokenizer.from_pretrained(version) 64 | self.transformer = T5EncoderModel.from_pretrained(version) 65 | self.device = device 66 | self.max_length = max_length # TODO: typical value? 67 | if freeze: 68 | self.freeze() 69 | 70 | def freeze(self): 71 | self.transformer = self.transformer.eval() 72 | # self.train = disabled_train 73 | for param in self.parameters(): 74 | param.requires_grad = False 75 | 76 | def forward(self, text): 77 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 78 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 79 | tokens = batch_encoding["input_ids"].to(self.device) 80 | outputs = self.transformer(input_ids=tokens) 81 | 82 | z = outputs.last_hidden_state 83 | return z 84 | 85 | def encode(self, text): 86 | return self(text) 87 | 88 | 89 | class FrozenCLIPEmbedder(AbstractEncoder): 90 | """Uses the CLIP transformer encoder for text (from huggingface)""" 91 | LAYERS = [ 92 | "last", 93 | "pooled", 94 | "hidden" 95 | ] 96 | 97 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, 98 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 99 | super().__init__() 100 | assert layer in self.LAYERS 101 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 102 | self.transformer = CLIPTextModel.from_pretrained(version) 103 | self.device = device 104 | self.max_length = max_length 105 | if freeze: 106 | self.freeze() 107 | self.layer = layer 108 | self.layer_idx = layer_idx 109 | if layer == "hidden": 110 | assert layer_idx is not None 111 | assert 0 <= abs(layer_idx) <= 12 112 | 113 | def freeze(self): 114 | self.transformer = self.transformer.eval() 115 | # self.train = disabled_train 116 | for param in self.parameters(): 117 | param.requires_grad = False 118 | 119 | def forward(self, text): 120 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 121 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 122 | tokens = batch_encoding["input_ids"].to(self.device) 123 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden") 124 | if self.layer == "last": 125 | z = outputs.last_hidden_state 126 | elif self.layer == "pooled": 127 | z = outputs.pooler_output[:, None, :] 128 | else: 129 | z = outputs.hidden_states[self.layer_idx] 130 | return z 131 | 132 | def encode(self, text): 133 | return self(text) 134 | 135 | 136 | class ClipImageEmbedder(nn.Module): 137 | def __init__( 138 | self, 139 | model, 140 | jit=False, 141 | device='cuda' if torch.cuda.is_available() else 'cpu', 142 | antialias=True, 143 | ucg_rate=0. 
144 | ): 145 | super().__init__() 146 | from clip import load as load_clip 147 | self.model, _ = load_clip(name=model, device=device, jit=jit) 148 | 149 | self.antialias = antialias 150 | 151 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 152 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 153 | self.ucg_rate = ucg_rate 154 | 155 | def preprocess(self, x): 156 | # normalize to [0,1] 157 | x = kornia.geometry.resize(x, (224, 224), 158 | interpolation='bicubic', align_corners=True, 159 | antialias=self.antialias) 160 | x = (x + 1.) / 2. 161 | # re-normalize according to clip 162 | x = kornia.enhance.normalize(x, self.mean, self.std) 163 | return x 164 | 165 | def forward(self, x, no_dropout=False): 166 | # x is assumed to be in range [-1,1] 167 | out = self.model.encode_image(self.preprocess(x)) 168 | out = out.to(x.dtype) 169 | if self.ucg_rate > 0. and not no_dropout: 170 | out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out 171 | return out 172 | 173 | 174 | class FrozenOpenCLIPEmbedder(AbstractEncoder): 175 | """ 176 | Uses the OpenCLIP transformer encoder for text 177 | """ 178 | LAYERS = [ 179 | # "pooled", 180 | "last", 181 | "penultimate" 182 | ] 183 | 184 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 185 | freeze=True, layer="last"): 186 | super().__init__() 187 | assert layer in self.LAYERS 188 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) 189 | del model.visual 190 | self.model = model 191 | 192 | self.device = device 193 | self.max_length = max_length 194 | if freeze: 195 | self.freeze() 196 | self.layer = layer 197 | if self.layer == "last": 198 | self.layer_idx = 0 199 | elif self.layer == "penultimate": 200 | self.layer_idx = 1 201 | else: 202 | raise NotImplementedError() 203 | 204 | def freeze(self): 205 | self.model = self.model.eval() 206 | for param in self.parameters(): 207 | param.requires_grad = False 208 | 209 | def forward(self, text): 210 | tokens = open_clip.tokenize(text) ## all clip models use 77 as context length 211 | z = self.encode_with_transformer(tokens.to(self.device)) 212 | return z 213 | 214 | def encode_with_transformer(self, text): 215 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 216 | x = x + self.model.positional_embedding 217 | x = x.permute(1, 0, 2) # NLD -> LND 218 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 219 | x = x.permute(1, 0, 2) # LND -> NLD 220 | x = self.model.ln_final(x) 221 | return x 222 | 223 | def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): 224 | for i, r in enumerate(self.model.transformer.resblocks): 225 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 226 | break 227 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 228 | x = checkpoint(r, x, attn_mask) 229 | else: 230 | x = r(x, attn_mask=attn_mask) 231 | return x 232 | 233 | def encode(self, text): 234 | return self(text) 235 | 236 | 237 | class FrozenOpenCLIPImageEmbedder(AbstractEncoder): 238 | """ 239 | Uses the OpenCLIP vision transformer encoder for images 240 | """ 241 | 242 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 243 | freeze=True, layer="pooled", antialias=True, ucg_rate=0.): 244 | super().__init__() 245 | model, _, _ = 
open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), 246 | pretrained=version, ) 247 | del model.transformer 248 | self.model = model 249 | # self.mapper = torch.nn.Linear(1280, 1024) 250 | self.device = device 251 | self.max_length = max_length 252 | if freeze: 253 | self.freeze() 254 | self.layer = layer 255 | if self.layer == "penultimate": 256 | raise NotImplementedError() 257 | self.layer_idx = 1 258 | 259 | self.antialias = antialias 260 | 261 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 262 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 263 | self.ucg_rate = ucg_rate 264 | 265 | def preprocess(self, x): 266 | # normalize to [0,1] 267 | x = kornia.geometry.resize(x, (224, 224), 268 | interpolation='bicubic', align_corners=True, 269 | antialias=self.antialias) 270 | x = (x + 1.) / 2. 271 | # renormalize according to clip 272 | x = kornia.enhance.normalize(x, self.mean, self.std) 273 | return x 274 | 275 | def freeze(self): 276 | self.model = self.model.eval() 277 | for param in self.model.parameters(): 278 | param.requires_grad = False 279 | 280 | @autocast 281 | def forward(self, image, no_dropout=False): 282 | z = self.encode_with_vision_transformer(image) 283 | if self.ucg_rate > 0. and not no_dropout: 284 | z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z 285 | return z 286 | 287 | def encode_with_vision_transformer(self, img): 288 | img = self.preprocess(img) 289 | x = self.model.visual(img) 290 | return x 291 | 292 | def encode(self, text): 293 | return self(text) 294 | 295 | class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder): 296 | """ 297 | Uses the OpenCLIP vision transformer encoder for images 298 | """ 299 | 300 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", 301 | freeze=True, layer="pooled", antialias=True): 302 | super().__init__() 303 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), 304 | pretrained=version, ) 305 | del model.transformer 306 | self.model = model 307 | self.device = device 308 | 309 | if freeze: 310 | self.freeze() 311 | self.layer = layer 312 | if self.layer == "penultimate": 313 | raise NotImplementedError() 314 | self.layer_idx = 1 315 | 316 | self.antialias = antialias 317 | 318 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 319 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 320 | 321 | 322 | def preprocess(self, x): 323 | # normalize to [0,1] 324 | x = kornia.geometry.resize(x, (224, 224), 325 | interpolation='bicubic', align_corners=True, 326 | antialias=self.antialias) 327 | x = (x + 1.) / 2. 
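# A minimal, kornia-free sketch of the preprocess() contract used by the CLIP
# image embedders in this module. Assumptions: input is a (b, 3, h, w) tensor
# in [-1, 1]; torch's bicubic resize approximates the kornia call above closely
# but not bit-exactly. Names here are illustrative only.
import torch
import torch.nn.functional as F

_CLIP_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
_CLIP_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)

def clip_preprocess_sketch(x: torch.Tensor) -> torch.Tensor:
    # resize to the 224x224 resolution the CLIP ViT expects
    x = F.interpolate(x, size=(224, 224), mode="bicubic", align_corners=True)
    # map [-1, 1] -> [0, 1], then standardize with the CLIP statistics above
    x = (x + 1.0) / 2.0
    return (x - _CLIP_MEAN.to(x)) / _CLIP_STD.to(x)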
328 | # renormalize according to clip 329 | x = kornia.enhance.normalize(x, self.mean, self.std) 330 | return x 331 | 332 | def freeze(self): 333 | self.model = self.model.eval() 334 | for param in self.model.parameters(): 335 | param.requires_grad = False 336 | 337 | def forward(self, image, no_dropout=False): 338 | ## image: b c h w 339 | z = self.encode_with_vision_transformer(image) 340 | return z 341 | 342 | def encode_with_vision_transformer(self, x): 343 | x = self.preprocess(x) 344 | 345 | # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 346 | if self.model.visual.input_patchnorm: 347 | # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') 348 | x = x.reshape(x.shape[0], x.shape[1], self.model.visual.grid_size[0], self.model.visual.patch_size[0], self.model.visual.grid_size[1], self.model.visual.patch_size[1]) 349 | x = x.permute(0, 2, 4, 1, 3, 5) 350 | x = x.reshape(x.shape[0], self.model.visual.grid_size[0] * self.model.visual.grid_size[1], -1) 351 | x = self.model.visual.patchnorm_pre_ln(x) 352 | x = self.model.visual.conv1(x) 353 | else: 354 | x = self.model.visual.conv1(x) # shape = [*, width, grid, grid] 355 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 356 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 357 | 358 | # class embeddings and positional embeddings 359 | x = torch.cat( 360 | [self.model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 361 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 362 | x = x + self.model.visual.positional_embedding.to(x.dtype) 363 | 364 | # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in 365 | x = self.model.visual.patch_dropout(x) 366 | x = self.model.visual.ln_pre(x) 367 | 368 | x = x.permute(1, 0, 2) # NLD -> LND 369 | x = self.model.visual.transformer(x) 370 | x = x.permute(1, 0, 2) # LND -> NLD 371 | 372 | return x 373 | 374 | class FrozenCLIPT5Encoder(AbstractEncoder): 375 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", 376 | clip_max_length=77, t5_max_length=77): 377 | super().__init__() 378 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) 379 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) 380 | print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " 381 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.") 382 | 383 | def encode(self, text): 384 | return self(text) 385 | 386 | def forward(self, text): 387 | clip_z = self.clip_encoder.encode(text) 388 | t5_z = self.t5_encoder.encode(text) 389 | return [clip_z, t5_z] 390 | -------------------------------------------------------------------------------- /lvdm/modules/encoders/resampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py 3 | # and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class ImageProjModel(nn.Module): 10 | """Projection Model""" 11 | def __init__(self, 
cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 12 | super().__init__() 13 | self.cross_attention_dim = cross_attention_dim 14 | self.clip_extra_context_tokens = clip_extra_context_tokens 15 | self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 16 | self.norm = nn.LayerNorm(cross_attention_dim) 17 | 18 | def forward(self, image_embeds): 19 | #embeds = image_embeds 20 | embeds = image_embeds.type(list(self.proj.parameters())[0].dtype) 21 | clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) 22 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 23 | return clip_extra_context_tokens 24 | 25 | 26 | # FFN 27 | def FeedForward(dim, mult=4): 28 | inner_dim = int(dim * mult) 29 | return nn.Sequential( 30 | nn.LayerNorm(dim), 31 | nn.Linear(dim, inner_dim, bias=False), 32 | nn.GELU(), 33 | nn.Linear(inner_dim, dim, bias=False), 34 | ) 35 | 36 | 37 | def reshape_tensor(x, heads): 38 | bs, length, width = x.shape 39 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head) 40 | x = x.view(bs, length, heads, -1) 41 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 42 | x = x.transpose(1, 2) 43 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 44 | x = x.reshape(bs, heads, length, -1) 45 | return x 46 | 47 | 48 | class PerceiverAttention(nn.Module): 49 | def __init__(self, *, dim, dim_head=64, heads=8): 50 | super().__init__() 51 | self.scale = dim_head**-0.5 52 | self.dim_head = dim_head 53 | self.heads = heads 54 | inner_dim = dim_head * heads 55 | 56 | self.norm1 = nn.LayerNorm(dim) 57 | self.norm2 = nn.LayerNorm(dim) 58 | 59 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 60 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 61 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 62 | 63 | 64 | def forward(self, x, latents): 65 | """ 66 | Args: 67 | x (torch.Tensor): image features 68 | shape (b, n1, D) 69 | latent (torch.Tensor): latent features 70 | shape (b, n2, D) 71 | """ 72 | x = self.norm1(x) 73 | latents = self.norm2(latents) 74 | 75 | b, l, _ = latents.shape 76 | 77 | q = self.to_q(latents) 78 | kv_input = torch.cat((x, latents), dim=-2) 79 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 80 | 81 | q = reshape_tensor(q, self.heads) 82 | k = reshape_tensor(k, self.heads) 83 | v = reshape_tensor(v, self.heads) 84 | 85 | # attention 86 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 87 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 88 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 89 | out = weight @ v 90 | 91 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 92 | 93 | return self.to_out(out) 94 | 95 | 96 | class Resampler(nn.Module): 97 | def __init__( 98 | self, 99 | dim=1024, 100 | depth=8, 101 | dim_head=64, 102 | heads=16, 103 | num_queries=8, 104 | embedding_dim=768, 105 | output_dim=1024, 106 | ff_mult=4, 107 | video_length=None, # using frame-wise version or not 108 | ): 109 | super().__init__() 110 | ## queries for a single frame / image 111 | self.num_queries = num_queries 112 | self.video_length = video_length 113 | 114 | ## queries for each frame 115 | if video_length is not None: 116 | num_queries = num_queries * video_length 117 | 118 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 119 | self.proj_in = nn.Linear(embedding_dim, dim) 120 | 
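# Shape note (assuming the IMAGE_PROJ_CONFIG used by this repo's loader:
# num_queries=16, video_length=16, dim=1024, embedding_dim=1280, output_dim=1024):
# the learned query bank above holds 16 * 16 = 256 tokens of width 1024, the
# CLIP image tokens of width 1280 are mapped to width 1024 by proj_in, and
# proj_out below maps the attended queries to the UNet cross-attention width
# (context_dim = 1024).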
self.proj_out = nn.Linear(dim, output_dim) 121 | self.norm_out = nn.LayerNorm(output_dim) 122 | 123 | self.layers = nn.ModuleList([]) 124 | for _ in range(depth): 125 | self.layers.append( 126 | nn.ModuleList( 127 | [ 128 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 129 | FeedForward(dim=dim, mult=ff_mult), 130 | ] 131 | ) 132 | ) 133 | 134 | def forward(self, x): 135 | latents = self.latents.repeat(x.size(0), 1, 1) ## B (T L) C 136 | x = self.proj_in(x) 137 | 138 | for attn, ff in self.layers: 139 | latents = attn(x, latents) + latents 140 | latents = ff(latents) + latents 141 | 142 | latents = self.proj_out(latents) 143 | latents = self.norm_out(latents) # B L C or B (T L) C 144 | 145 | return latents -------------------------------------------------------------------------------- /lvdm/modules/x_transformer.py: -------------------------------------------------------------------------------- 1 | """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" 2 | from functools import partial 3 | from inspect import isfunction 4 | from collections import namedtuple 5 | from einops import rearrange, repeat 6 | import torch 7 | from torch import nn, einsum 8 | import torch.nn.functional as F 9 | 10 | # constants 11 | DEFAULT_DIM_HEAD = 64 12 | 13 | Intermediates = namedtuple('Intermediates', [ 14 | 'pre_softmax_attn', 15 | 'post_softmax_attn' 16 | ]) 17 | 18 | LayerIntermediates = namedtuple('Intermediates', [ 19 | 'hiddens', 20 | 'attn_intermediates' 21 | ]) 22 | 23 | 24 | class AbsolutePositionalEmbedding(nn.Module): 25 | def __init__(self, dim, max_seq_len): 26 | super().__init__() 27 | self.emb = nn.Embedding(max_seq_len, dim) 28 | self.init_() 29 | 30 | def init_(self): 31 | nn.init.normal_(self.emb.weight, std=0.02) 32 | 33 | def forward(self, x): 34 | n = torch.arange(x.shape[1], device=x.device) 35 | return self.emb(n)[None, :, :] 36 | 37 | 38 | class FixedPositionalEmbedding(nn.Module): 39 | def __init__(self, dim): 40 | super().__init__() 41 | inv_freq = 1. 
/ (10000 ** (torch.arange(0, dim, 2).float() / dim)) 42 | self.register_buffer('inv_freq', inv_freq) 43 | 44 | def forward(self, x, seq_dim=1, offset=0): 45 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset 46 | sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) 47 | emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) 48 | return emb[None, :, :] 49 | 50 | 51 | # helpers 52 | 53 | def exists(val): 54 | return val is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def always(val): 64 | def inner(*args, **kwargs): 65 | return val 66 | return inner 67 | 68 | 69 | def not_equals(val): 70 | def inner(x): 71 | return x != val 72 | return inner 73 | 74 | 75 | def equals(val): 76 | def inner(x): 77 | return x == val 78 | return inner 79 | 80 | 81 | def max_neg_value(tensor): 82 | return -torch.finfo(tensor.dtype).max 83 | 84 | 85 | # keyword argument helpers 86 | 87 | def pick_and_pop(keys, d): 88 | values = list(map(lambda key: d.pop(key), keys)) 89 | return dict(zip(keys, values)) 90 | 91 | 92 | def group_dict_by_key(cond, d): 93 | return_val = [dict(), dict()] 94 | for key in d.keys(): 95 | match = bool(cond(key)) 96 | ind = int(not match) 97 | return_val[ind][key] = d[key] 98 | return (*return_val,) 99 | 100 | 101 | def string_begins_with(prefix, str): 102 | return str.startswith(prefix) 103 | 104 | 105 | def group_by_key_prefix(prefix, d): 106 | return group_dict_by_key(partial(string_begins_with, prefix), d) 107 | 108 | 109 | def groupby_prefix_and_trim(prefix, d): 110 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) 111 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) 112 | return kwargs_without_prefix, kwargs 113 | 114 | 115 | # classes 116 | class Scale(nn.Module): 117 | def __init__(self, value, fn): 118 | super().__init__() 119 | self.value = value 120 | self.fn = fn 121 | 122 | def forward(self, x, **kwargs): 123 | x, *rest = self.fn(x, **kwargs) 124 | return (x * self.value, *rest) 125 | 126 | 127 | class Rezero(nn.Module): 128 | def __init__(self, fn): 129 | super().__init__() 130 | self.fn = fn 131 | self.g = nn.Parameter(torch.zeros(1)) 132 | 133 | def forward(self, x, **kwargs): 134 | x, *rest = self.fn(x, **kwargs) 135 | return (x * self.g, *rest) 136 | 137 | 138 | class ScaleNorm(nn.Module): 139 | def __init__(self, dim, eps=1e-5): 140 | super().__init__() 141 | self.scale = dim ** -0.5 142 | self.eps = eps 143 | self.g = nn.Parameter(torch.ones(1)) 144 | 145 | def forward(self, x): 146 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale 147 | return x / norm.clamp(min=self.eps) * self.g 148 | 149 | 150 | class RMSNorm(nn.Module): 151 | def __init__(self, dim, eps=1e-8): 152 | super().__init__() 153 | self.scale = dim ** -0.5 154 | self.eps = eps 155 | self.g = nn.Parameter(torch.ones(dim)) 156 | 157 | def forward(self, x): 158 | norm = torch.norm(x, dim=-1, keepdim=True) * self.scale 159 | return x / norm.clamp(min=self.eps) * self.g 160 | 161 | 162 | class Residual(nn.Module): 163 | def forward(self, x, residual): 164 | return x + residual 165 | 166 | 167 | class GRUGating(nn.Module): 168 | def __init__(self, dim): 169 | super().__init__() 170 | self.gru = nn.GRUCell(dim, dim) 171 | 172 | def forward(self, x, residual): 173 | gated_output = self.gru( 174 | rearrange(x, 'b n d -> (b n) d'), 175 | rearrange(residual, 'b n d -> 
(b n) d') 176 | ) 177 | 178 | return gated_output.reshape_as(x) 179 | 180 | 181 | # feedforward 182 | 183 | class GEGLU(nn.Module): 184 | def __init__(self, dim_in, dim_out): 185 | super().__init__() 186 | self.proj = nn.Linear(dim_in, dim_out * 2) 187 | 188 | def forward(self, x): 189 | x, gate = self.proj(x).chunk(2, dim=-1) 190 | return x * F.gelu(gate) 191 | 192 | 193 | class FeedForward(nn.Module): 194 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 195 | super().__init__() 196 | inner_dim = int(dim * mult) 197 | dim_out = default(dim_out, dim) 198 | project_in = nn.Sequential( 199 | nn.Linear(dim, inner_dim), 200 | nn.GELU() 201 | ) if not glu else GEGLU(dim, inner_dim) 202 | 203 | self.net = nn.Sequential( 204 | project_in, 205 | nn.Dropout(dropout), 206 | nn.Linear(inner_dim, dim_out) 207 | ) 208 | 209 | def forward(self, x): 210 | return self.net(x) 211 | 212 | 213 | # attention. 214 | class Attention(nn.Module): 215 | def __init__( 216 | self, 217 | dim, 218 | dim_head=DEFAULT_DIM_HEAD, 219 | heads=8, 220 | causal=False, 221 | mask=None, 222 | talking_heads=False, 223 | sparse_topk=None, 224 | use_entmax15=False, 225 | num_mem_kv=0, 226 | dropout=0., 227 | on_attn=False 228 | ): 229 | super().__init__() 230 | if use_entmax15: 231 | raise NotImplementedError("Check out entmax activation instead of softmax activation!") 232 | self.scale = dim_head ** -0.5 233 | self.heads = heads 234 | self.causal = causal 235 | self.mask = mask 236 | 237 | inner_dim = dim_head * heads 238 | 239 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 240 | self.to_k = nn.Linear(dim, inner_dim, bias=False) 241 | self.to_v = nn.Linear(dim, inner_dim, bias=False) 242 | self.dropout = nn.Dropout(dropout) 243 | 244 | # talking heads 245 | self.talking_heads = talking_heads 246 | if talking_heads: 247 | self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) 248 | self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) 249 | 250 | # explicit topk sparse attention 251 | self.sparse_topk = sparse_topk 252 | 253 | # entmax 254 | #self.attn_fn = entmax15 if use_entmax15 else F.softmax 255 | self.attn_fn = F.softmax 256 | 257 | # add memory key / values 258 | self.num_mem_kv = num_mem_kv 259 | if num_mem_kv > 0: 260 | self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) 261 | self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) 262 | 263 | # attention on attention 264 | self.attn_on_attn = on_attn 265 | self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) 266 | 267 | def forward( 268 | self, 269 | x, 270 | context=None, 271 | mask=None, 272 | context_mask=None, 273 | rel_pos=None, 274 | sinusoidal_emb=None, 275 | prev_attn=None, 276 | mem=None 277 | ): 278 | b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device 279 | kv_input = default(context, x) 280 | 281 | q_input = x 282 | k_input = kv_input 283 | v_input = kv_input 284 | 285 | if exists(mem): 286 | k_input = torch.cat((mem, k_input), dim=-2) 287 | v_input = torch.cat((mem, v_input), dim=-2) 288 | 289 | if exists(sinusoidal_emb): 290 | # in shortformer, the query would start at a position offset depending on the past cached memory 291 | offset = k_input.shape[-2] - q_input.shape[-2] 292 | q_input = q_input + sinusoidal_emb(q_input, offset=offset) 293 | k_input = k_input + sinusoidal_emb(k_input) 294 | 295 | q = self.to_q(q_input) 296 | k = self.to_k(k_input) 297 | v = self.to_v(v_input) 298 | 299 | 
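# Standalone sketch (toy sizes, torch + einops only) of the head split and
# scaled dot-product scoring performed on the next few lines; the names here
# are illustrative and not part of this module.
import torch
from einops import rearrange
from torch import einsum

def toy_attention_scores(feats: torch.Tensor, heads: int = 8, dim_head: int = 64):
    # (b, n, heads*dim_head) -> (b, heads, n, dim_head), as done just below
    q = k = rearrange(feats, 'b n (h d) -> b h n d', h=heads)
    # scaled dot-product logits, shape (b, heads, n, n)
    return einsum('b h i d, b h j d -> b h i j', q, k) * dim_head ** -0.5

_ = toy_attention_scores(torch.randn(2, 5, 8 * 64))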
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) 300 | 301 | input_mask = None 302 | if any(map(exists, (mask, context_mask))): 303 | q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) 304 | k_mask = q_mask if not exists(context) else context_mask 305 | k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) 306 | q_mask = rearrange(q_mask, 'b i -> b () i ()') 307 | k_mask = rearrange(k_mask, 'b j -> b () () j') 308 | input_mask = q_mask * k_mask 309 | 310 | if self.num_mem_kv > 0: 311 | mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) 312 | k = torch.cat((mem_k, k), dim=-2) 313 | v = torch.cat((mem_v, v), dim=-2) 314 | if exists(input_mask): 315 | input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) 316 | 317 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 318 | mask_value = max_neg_value(dots) 319 | 320 | if exists(prev_attn): 321 | dots = dots + prev_attn 322 | 323 | pre_softmax_attn = dots 324 | 325 | if talking_heads: 326 | dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() 327 | 328 | if exists(rel_pos): 329 | dots = rel_pos(dots) 330 | 331 | if exists(input_mask): 332 | dots.masked_fill_(~input_mask, mask_value) 333 | del input_mask 334 | 335 | if self.causal: 336 | i, j = dots.shape[-2:] 337 | r = torch.arange(i, device=device) 338 | mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') 339 | mask = F.pad(mask, (j - i, 0), value=False) 340 | dots.masked_fill_(mask, mask_value) 341 | del mask 342 | 343 | if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: 344 | top, _ = dots.topk(self.sparse_topk, dim=-1) 345 | vk = top[..., -1].unsqueeze(-1).expand_as(dots) 346 | mask = dots < vk 347 | dots.masked_fill_(mask, mask_value) 348 | del mask 349 | 350 | attn = self.attn_fn(dots, dim=-1) 351 | post_softmax_attn = attn 352 | 353 | attn = self.dropout(attn) 354 | 355 | if talking_heads: 356 | attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() 357 | 358 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 359 | out = rearrange(out, 'b h n d -> b n (h d)') 360 | 361 | intermediates = Intermediates( 362 | pre_softmax_attn=pre_softmax_attn, 363 | post_softmax_attn=post_softmax_attn 364 | ) 365 | 366 | return self.to_out(out), intermediates 367 | 368 | 369 | class AttentionLayers(nn.Module): 370 | def __init__( 371 | self, 372 | dim, 373 | depth, 374 | heads=8, 375 | causal=False, 376 | cross_attend=False, 377 | only_cross=False, 378 | use_scalenorm=False, 379 | use_rmsnorm=False, 380 | use_rezero=False, 381 | rel_pos_num_buckets=32, 382 | rel_pos_max_distance=128, 383 | position_infused_attn=False, 384 | custom_layers=None, 385 | sandwich_coef=None, 386 | par_ratio=None, 387 | residual_attn=False, 388 | cross_residual_attn=False, 389 | macaron=False, 390 | pre_norm=True, 391 | gate_residual=False, 392 | **kwargs 393 | ): 394 | super().__init__() 395 | ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) 396 | attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) 397 | 398 | dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) 399 | 400 | self.dim = dim 401 | self.depth = depth 402 | self.layers = nn.ModuleList([]) 403 | 404 | self.has_pos_emb = position_infused_attn 405 | self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None 406 | self.rotary_pos_emb = always(None) 407 | 408 | assert rel_pos_num_buckets 
<= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' 409 | self.rel_pos = None 410 | 411 | self.pre_norm = pre_norm 412 | 413 | self.residual_attn = residual_attn 414 | self.cross_residual_attn = cross_residual_attn 415 | 416 | norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm 417 | norm_class = RMSNorm if use_rmsnorm else norm_class 418 | norm_fn = partial(norm_class, dim) 419 | 420 | norm_fn = nn.Identity if use_rezero else norm_fn 421 | branch_fn = Rezero if use_rezero else None 422 | 423 | if cross_attend and not only_cross: 424 | default_block = ('a', 'c', 'f') 425 | elif cross_attend and only_cross: 426 | default_block = ('c', 'f') 427 | else: 428 | default_block = ('a', 'f') 429 | 430 | if macaron: 431 | default_block = ('f',) + default_block 432 | 433 | if exists(custom_layers): 434 | layer_types = custom_layers 435 | elif exists(par_ratio): 436 | par_depth = depth * len(default_block) 437 | assert 1 < par_ratio <= par_depth, 'par ratio out of range' 438 | default_block = tuple(filter(not_equals('f'), default_block)) 439 | par_attn = par_depth // par_ratio 440 | depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper 441 | par_width = (depth_cut + depth_cut // par_attn) // par_attn 442 | assert len(default_block) <= par_width, 'default block is too large for par_ratio' 443 | par_block = default_block + ('f',) * (par_width - len(default_block)) 444 | par_head = par_block * par_attn 445 | layer_types = par_head + ('f',) * (par_depth - len(par_head)) 446 | elif exists(sandwich_coef): 447 | assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' 448 | layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef 449 | else: 450 | layer_types = default_block * depth 451 | 452 | self.layer_types = layer_types 453 | self.num_attn_layers = len(list(filter(equals('a'), layer_types))) 454 | 455 | for layer_type in self.layer_types: 456 | if layer_type == 'a': 457 | layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) 458 | elif layer_type == 'c': 459 | layer = Attention(dim, heads=heads, **attn_kwargs) 460 | elif layer_type == 'f': 461 | layer = FeedForward(dim, **ff_kwargs) 462 | layer = layer if not macaron else Scale(0.5, layer) 463 | else: 464 | raise Exception(f'invalid layer type {layer_type}') 465 | 466 | if isinstance(layer, Attention) and exists(branch_fn): 467 | layer = branch_fn(layer) 468 | 469 | if gate_residual: 470 | residual_fn = GRUGating(dim) 471 | else: 472 | residual_fn = Residual() 473 | 474 | self.layers.append(nn.ModuleList([ 475 | norm_fn(), 476 | layer, 477 | residual_fn 478 | ])) 479 | 480 | def forward( 481 | self, 482 | x, 483 | context=None, 484 | mask=None, 485 | context_mask=None, 486 | mems=None, 487 | return_hiddens=False 488 | ): 489 | hiddens = [] 490 | intermediates = [] 491 | prev_attn = None 492 | prev_cross_attn = None 493 | 494 | mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers 495 | 496 | for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): 497 | is_last = ind == (len(self.layers) - 1) 498 | 499 | if layer_type == 'a': 500 | hiddens.append(x) 501 | layer_mem = mems.pop(0) 502 | 503 | residual = x 504 | 505 | if self.pre_norm: 506 | x = norm(x) 507 | 508 | if layer_type == 'a': 509 | out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, 510 | 
prev_attn=prev_attn, mem=layer_mem) 511 | elif layer_type == 'c': 512 | out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) 513 | elif layer_type == 'f': 514 | out = block(x) 515 | 516 | x = residual_fn(out, residual) 517 | 518 | if layer_type in ('a', 'c'): 519 | intermediates.append(inter) 520 | 521 | if layer_type == 'a' and self.residual_attn: 522 | prev_attn = inter.pre_softmax_attn 523 | elif layer_type == 'c' and self.cross_residual_attn: 524 | prev_cross_attn = inter.pre_softmax_attn 525 | 526 | if not self.pre_norm and not is_last: 527 | x = norm(x) 528 | 529 | if return_hiddens: 530 | intermediates = LayerIntermediates( 531 | hiddens=hiddens, 532 | attn_intermediates=intermediates 533 | ) 534 | 535 | return x, intermediates 536 | 537 | return x 538 | 539 | 540 | class Encoder(AttentionLayers): 541 | def __init__(self, **kwargs): 542 | assert 'causal' not in kwargs, 'cannot set causality on encoder' 543 | super().__init__(causal=False, **kwargs) 544 | 545 | 546 | 547 | class TransformerWrapper(nn.Module): 548 | def __init__( 549 | self, 550 | *, 551 | num_tokens, 552 | max_seq_len, 553 | attn_layers, 554 | emb_dim=None, 555 | max_mem_len=0., 556 | emb_dropout=0., 557 | num_memory_tokens=None, 558 | tie_embedding=False, 559 | use_pos_emb=True 560 | ): 561 | super().__init__() 562 | assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' 563 | 564 | dim = attn_layers.dim 565 | emb_dim = default(emb_dim, dim) 566 | 567 | self.max_seq_len = max_seq_len 568 | self.max_mem_len = max_mem_len 569 | self.num_tokens = num_tokens 570 | 571 | self.token_emb = nn.Embedding(num_tokens, emb_dim) 572 | self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( 573 | use_pos_emb and not attn_layers.has_pos_emb) else always(0) 574 | self.emb_dropout = nn.Dropout(emb_dropout) 575 | 576 | self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() 577 | self.attn_layers = attn_layers 578 | self.norm = nn.LayerNorm(dim) 579 | 580 | self.init_() 581 | 582 | self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() 583 | 584 | # memory tokens (like [cls]) from Memory Transformers paper 585 | num_memory_tokens = default(num_memory_tokens, 0) 586 | self.num_memory_tokens = num_memory_tokens 587 | if num_memory_tokens > 0: 588 | self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) 589 | 590 | # let funnel encoder know number of memory tokens, if specified 591 | if hasattr(attn_layers, 'num_memory_tokens'): 592 | attn_layers.num_memory_tokens = num_memory_tokens 593 | 594 | def init_(self): 595 | nn.init.normal_(self.token_emb.weight, std=0.02) 596 | 597 | def forward( 598 | self, 599 | x, 600 | return_embeddings=False, 601 | mask=None, 602 | return_mems=False, 603 | return_attn=False, 604 | mems=None, 605 | **kwargs 606 | ): 607 | b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens 608 | x = self.token_emb(x) 609 | x += self.pos_emb(x) 610 | x = self.emb_dropout(x) 611 | 612 | x = self.project_emb(x) 613 | 614 | if num_mem > 0: 615 | mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) 616 | x = torch.cat((mem, x), dim=1) 617 | 618 | # auto-handle masking after appending memory tokens 619 | if exists(mask): 620 | mask = F.pad(mask, (num_mem, 0), value=True) 621 | 622 | x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) 623 | x = self.norm(x) 624 | 625 | 
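# The slice below undoes the earlier torch.cat((mem, x), dim=1): the first
# num_mem positions are the learned memory tokens prepended in this forward
# pass, so they are split off before the logit projection / embedding return.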
mem, x = x[:, :num_mem], x[:, num_mem:] 626 | 627 | out = self.to_logits(x) if not return_embeddings else x 628 | 629 | if return_mems: 630 | hiddens = intermediates.hiddens 631 | new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens 632 | new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) 633 | return out, new_mems 634 | 635 | if return_attn: 636 | attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) 637 | return out, attn_maps 638 | 639 | return out -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import comfy 4 | import yaml 5 | import folder_paths 6 | 7 | from einops import rearrange 8 | from comfy import model_base, model_management, model_detection, latent_formats, model_sampling 9 | from .lvdm.modules.networks.openaimodel3d import UNetModel as DynamiCrafterUNetModel 10 | 11 | from .utils.model_utils import DynamiCrafterBase, DYNAMICRAFTER_CONFIG, \ 12 | load_image_proj_dict, load_dynamicrafter_dict, get_image_proj_model, load_vae_dict 13 | 14 | from .utils.utils import get_models_directory 15 | 16 | MODEL_DIR= "dynamicrafter_models" 17 | MODEL_DIR_PATH = os.path.join(folder_paths.models_dir, MODEL_DIR) 18 | 19 | class DynamiCrafterProcessor: 20 | @classmethod 21 | def INPUT_TYPES(s): 22 | return { 23 | "required": { 24 | "model": ("MODEL", ), 25 | "clip_vision": ("CLIP_VISION", ), 26 | "vae": ("VAE", ), 27 | "image_proj_model": ("IMAGE_PROJ_MODEL", ), 28 | "images": ("IMAGE", ), 29 | "use_interpolate": ("BOOLEAN", {"default": False}), 30 | "fps": ("INT", {"default": 15, "min": 1, "max": 30, "step": 1}, ), 31 | "frames": ("INT", {"default": 16}), 32 | "scale_latents": ("BOOLEAN", {"default": False}) 33 | }, 34 | } 35 | 36 | CATEGORY = "Native_DynamiCrafter/Processing" 37 | RETURN_TYPES = ("MODEL", "LATENT", "LATENT", ) 38 | RETURN_NAMES = ("model", "empty_latent", "latent_img", ) 39 | 40 | FUNCTION = "process_image_conditioning" 41 | 42 | def __init__(self): 43 | self.model_patcher = None 44 | 45 | # There is probably a better way to do this, but with the apply_model callback, this seems necessary. 46 | # The model gets wrapped around a CFG Denoiser class, and handles the conditioning parts there. 47 | # We cannot access it, so we must find the conditioning according to how ComfyUI handles it. 48 | def get_conditioning_pair(self, c_crossattn, use_cfg: bool): 49 | if not use_cfg: 50 | return c_crossattn 51 | 52 | conditioning_group = [] 53 | 54 | for i in range(c_crossattn.shape[0]): 55 | # Get the positive and negative conditioning. 
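# In other words: walk adjacent rows of c_crossattn and keep the first pair of
# rows whose embeddings actually differ. With CFG active, ComfyUI batches the
# cond and uncond text embeddings together, and (per the class comment above)
# this wrapper cannot see the CFG denoiser directly, so the first differing
# pair is taken as the (positive, negative) prompt conditioning.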
56 | positive_idx = i + 1 57 | negative_idx = i 58 | 59 | if positive_idx >= c_crossattn.shape[0]: 60 | break 61 | 62 | if not torch.equal(c_crossattn[[positive_idx]], c_crossattn[[negative_idx]]): 63 | conditioning_group = [ 64 | c_crossattn[[positive_idx]], 65 | c_crossattn[[negative_idx]] 66 | ] 67 | break 68 | 69 | if len(conditioning_group) == 0: 70 | raise ValueError("Could not get the appropriate conditioning group.") 71 | 72 | return torch.cat(conditioning_group) 73 | 74 | # apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond} 75 | def _forward(self, *args): 76 | transformer_options = self.model_patcher.model_options['transformer_options'] 77 | conditioning = transformer_options['conditioning'] 78 | 79 | apply_model = args[0] 80 | 81 | # forward_dict 82 | fd = args[1] 83 | 84 | x, t, model_in_kwargs, _ = fd['input'], fd['timestep'], fd['c'], fd['cond_or_uncond'] 85 | 86 | c_crossattn = model_in_kwargs.pop("c_crossattn") 87 | c_concat = conditioning['c_concat'] 88 | num_video_frames = conditioning['num_video_frames'] 89 | fs = conditioning['fs'] 90 | 91 | 92 | original_num_frames = num_video_frames 93 | 94 | # Better way to determine if we're using CFG 95 | # The cond batch will always be num_frames >= 2 since we're doing video, 96 | # so we need get this condition differently here. 97 | if x.shape[0] > num_video_frames: 98 | num_video_frames *= 2 99 | batch_size = 2 100 | use_cfg = True 101 | else: 102 | use_cfg = False 103 | batch_size = 1 104 | 105 | if use_cfg: 106 | c_concat = torch.cat([c_concat] * 2) 107 | 108 | self.validate_forwardable_latent(x, c_concat, num_video_frames, use_cfg) 109 | 110 | x_in, c_concat = map(lambda xc: rearrange(xc, '(b t) c h w -> b c t h w', b=batch_size), (x, c_concat)) 111 | 112 | # We always assume video, so there will always be batched conditionings. 113 | c_crossattn = self.get_conditioning_pair(c_crossattn, use_cfg) 114 | c_crossattn = c_crossattn[:2] if use_cfg else c_crossattn[:1] 115 | context_in = c_crossattn 116 | 117 | img_embs = conditioning['image_emb'] 118 | 119 | if use_cfg: 120 | img_emb_uncond = conditioning['image_emb_uncond'] 121 | img_embs = torch.cat([img_embs, img_emb_uncond]) 122 | 123 | fs = torch.cat([fs] * x_in.shape[0]) 124 | 125 | outs = [] 126 | for i in range(batch_size): 127 | model_in_kwargs['transformer_options']['cond_idx'] = i 128 | x_out = apply_model( 129 | x_in[[i]], 130 | t=torch.cat([t[:1]]), 131 | context_in=context_in[[i]], 132 | c_crossattn=c_crossattn, 133 | cc_concat=c_concat[[i]], # "cc" is to handle naming conflict with apply_model wrapper. 134 | # We want to handle this in the UNet forward. 
135 | num_video_frames=num_video_frames // 2 if batch_size > 1 else num_video_frames, 136 | img_emb=img_embs[[i]], 137 | fs=fs[[i]], 138 | **model_in_kwargs 139 | ) 140 | outs.append(x_out) 141 | 142 | x_out = torch.cat(list(reversed(outs))) 143 | x_out = rearrange(x_out, 'b c t h w -> (b t) c h w') 144 | 145 | return x_out 146 | 147 | def assign_forward_args( 148 | self, 149 | model, 150 | c_concat, 151 | image_emb, 152 | image_emb_uncond, 153 | fs, 154 | frames, 155 | ): 156 | model.model_options['transformer_options']['conditioning'] = { 157 | "c_concat": c_concat, 158 | "image_emb": image_emb, 159 | 'image_emb_uncond': image_emb_uncond, 160 | "fs": fs, 161 | "num_video_frames": frames, 162 | } 163 | 164 | def validate_forwardable_latent(self, latent, c_concat, num_video_frames, use_cfg): 165 | check_no_cfg = latent.shape[0] != num_video_frames 166 | check_with_cfg = latent.shape[0] != (num_video_frames * 2) 167 | 168 | latent_batch_size = latent.shape[0] if not use_cfg else latent.shape[0] // 2 169 | num_frames = num_video_frames if not use_cfg else num_video_frames // 2 170 | 171 | if all([check_no_cfg, check_with_cfg]): 172 | raise ValueError( 173 | "Please make sure your latent inputs match the number of frames in the DynamiCrafter Processor." 174 | f"Got a latent batch size of ({latent_batch_size}) with number of frames being ({num_frames})." 175 | ) 176 | 177 | latent_h, latent_w = latent.shape[-2:] 178 | c_concat_h, c_concat_w = c_concat.shape[-2:] 179 | 180 | if not all([latent_h == c_concat_h, latent_w == c_concat_w]): 181 | raise ValueError( 182 | "Please make sure that your input latent and image frames are the same height and width.", 183 | f"Image Size: {c_concat_w * 8}, {c_concat_h * 8}, Latent Size: {latent_h * 8}, {latent_w * 8}" 184 | ) 185 | 186 | def process_image_conditioning( 187 | self, 188 | model, 189 | clip_vision, 190 | vae, 191 | image_proj_model, 192 | images, 193 | use_interpolate, 194 | fps: int, 195 | frames: int, 196 | scale_latents: bool 197 | ): 198 | self.model_patcher = model 199 | encoded_latent = vae.encode(images[:, :, :, :3]) 200 | 201 | encoded_image = clip_vision.encode_image(images[:1])['last_hidden_state'] 202 | image_emb = image_proj_model(encoded_image) 203 | 204 | encoded_image_uncond = clip_vision.encode_image(torch.zeros_like(images)[:1])['last_hidden_state'] 205 | image_emb_uncond = image_proj_model(encoded_image_uncond) 206 | 207 | c_concat = encoded_latent 208 | 209 | if scale_latents: 210 | vae_process_input = vae.process_input 211 | vae.process_input = lambda image: (image - .5) * 2 212 | c_concat = vae.encode(images[:, :, :, :3]) 213 | vae.process_input = vae_process_input 214 | c_concat = model.model.process_latent_in(c_concat) * 1.3 215 | else: 216 | c_concat = model.model.process_latent_in(c_concat) 217 | 218 | fs = torch.tensor([fps], dtype=torch.long, device=model_management.intermediate_device()) 219 | 220 | model.set_model_unet_function_wrapper(self._forward) 221 | 222 | used_interpolate_processing = False 223 | 224 | if use_interpolate and frames > 16: 225 | raise ValueError( 226 | "When using interpolation mode, the maximum amount of frames are 16." 227 | "If you're doing long video generation, consider using the last frame\ 228 | from the first generation for the next one (autoregressive)." 
229 | ) 230 | if encoded_latent.shape[0] == 1: 231 | c_concat = torch.cat([c_concat] * frames, dim=0)[:frames] 232 | 233 | if use_interpolate: 234 | mask = torch.zeros_like(c_concat) 235 | mask[:1] = c_concat[:1] 236 | c_concat = mask 237 | 238 | used_interpolate_processing = True 239 | else: 240 | if use_interpolate and c_concat.shape[0] in [2, 3]: 241 | input_frame_count = c_concat.shape[0] 242 | 243 | # We're just padding to the same type an size of the concat 244 | masked_frames = torch.zeros_like(torch.cat([c_concat[:1]] * frames))[:frames] 245 | 246 | # Start frame 247 | masked_frames[:1] = c_concat[:1] 248 | 249 | end_frame_idx = -1 250 | 251 | # TODO 252 | speed = 1.0 253 | if speed < 1.0: 254 | possible_speeds = list(torch.linspace(0, 1.0, c_concat.shape[0])) 255 | speed_from_frames = enumerate(possible_speeds) 256 | speed_idx = min(speed_from_frames, key=lambda n: n[1] - speed)[0] 257 | end_frame_idx = speed_idx 258 | 259 | # End frame 260 | masked_frames[-1:] = c_concat[[end_frame_idx]] 261 | 262 | # Possible middle frame, but not working at the moment. 263 | if input_frame_count == 3: 264 | middle_idx = masked_frames.shape[0] // 2 265 | middle_idx_frame = c_concat.shape[0] // 2 266 | masked_frames[[middle_idx]] = c_concat[[middle_idx_frame]] 267 | 268 | c_concat = masked_frames 269 | used_interpolate_processing = True 270 | 271 | print(f"Using interpolation mode with {input_frame_count} frames.") 272 | 273 | if c_concat.shape[0] < frames and not used_interpolate_processing: 274 | print( 275 | "Multiple images found, but interpolation mode is unset. Using the first frame as condition.", 276 | ) 277 | c_concat = torch.cat([c_concat[:1]] * frames) 278 | 279 | c_concat = c_concat[:frames] 280 | 281 | if encoded_latent.shape[0] == 1: 282 | encoded_latent = torch.cat([encoded_latent] * frames)[:frames] 283 | 284 | if encoded_latent.shape[0] < frames and encoded_latent.shape[0] != 1: 285 | encoded_latent = torch.cat( 286 | [encoded_latent] + [encoded_latent[-1:]] * abs(encoded_latent.shape[0] - frames) 287 | )[:frames] 288 | 289 | # We could store this as a state in this Node Class Instance, but to prevent any weird edge cases, 290 | # this should always be passed through the 'stateless' way, and let ComfyUI handle the transformer_options state. 
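# Downstream contract (see RETURN_TYPES/RETURN_NAMES above): this node returns
# the patched MODEL with the unet-function wrapper installed, an all-zero
# "empty_latent" shaped like c_concat for the sampler, and the VAE-encoded
# input image(s) as "latent_img".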
291 | self.assign_forward_args(model, c_concat, image_emb, image_emb_uncond, fs, frames) 292 | 293 | return (model, {"samples": torch.zeros_like(c_concat)}, {"samples": encoded_latent},) 294 | 295 | class DynamiCrafterLoader: 296 | @classmethod 297 | def INPUT_TYPES(s): 298 | return { 299 | "required": { 300 | "model_path": (get_models_directory(os.listdir(MODEL_DIR_PATH)), ), 301 | }, 302 | } 303 | 304 | CATEGORY = "Native_DynamiCrafter/Loaders" 305 | RETURN_TYPES = ("MODEL", "IMAGE_PROJ_MODEL", ) 306 | RETURN_NAMES = ("model", "image_proj_model", ) 307 | FUNCTION = "load_dynamicrafter" 308 | 309 | def load_model_sicts(self, model_path: str): 310 | model_state_dict = comfy.utils.load_torch_file(model_path) 311 | dynamicrafter_dict = load_dynamicrafter_dict(model_state_dict) 312 | image_proj_dict = load_image_proj_dict(model_state_dict) 313 | 314 | return dynamicrafter_dict, image_proj_dict 315 | 316 | def get_prediction_type(self, is_eps: bool, model_config): 317 | if not is_eps and "image_cross_attention_scale_learnable" in model_config.unet_config.keys(): 318 | model_config.unet_config["image_cross_attention_scale_learnable"] = False 319 | 320 | return model_base.ModelType.EPS if is_eps else model_base.ModelType.V_PREDICTION 321 | 322 | def handle_model_management(self, dynamicrafter_dict: dict, model_config): 323 | parameters = comfy.utils.calculate_parameters(dynamicrafter_dict, "model.diffusion_model.") 324 | load_device = model_management.get_torch_device() 325 | unet_dtype = model_management.unet_dtype( 326 | model_params=parameters, 327 | supported_dtypes=model_config.supported_inference_dtypes 328 | ) 329 | manual_cast_dtype = model_management.unet_manual_cast( 330 | unet_dtype, 331 | load_device, 332 | model_config.supported_inference_dtypes 333 | ) 334 | model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) 335 | inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) 336 | offload_device = model_management.unet_offload_device() 337 | 338 | return load_device, inital_load_device 339 | 340 | def check_leftover_keys(self, state_dict: dict): 341 | left_over = state_dict.keys() 342 | if len(left_over) > 0: 343 | print("left over keys:", left_over) 344 | 345 | def load_dynamicrafter(self, model_path): 346 | model_path = os.path.join(MODEL_DIR_PATH, model_path) 347 | 348 | if os.path.exists(model_path): 349 | dynamicrafter_dict, image_proj_dict = self.load_model_sicts(model_path) 350 | model_config = DynamiCrafterBase(DYNAMICRAFTER_CONFIG) 351 | 352 | dynamicrafter_dict, is_eps = model_config.process_dict_version(state_dict=dynamicrafter_dict) 353 | 354 | MODEL_TYPE = self.get_prediction_type(is_eps, model_config) 355 | load_device, inital_load_device = self.handle_model_management(dynamicrafter_dict, model_config) 356 | 357 | model = model_base.BaseModel( 358 | model_config, 359 | model_type=MODEL_TYPE, 360 | device=inital_load_device, 361 | unet_model=DynamiCrafterUNetModel 362 | ) 363 | 364 | image_proj_model = get_image_proj_model(image_proj_dict) 365 | model.load_model_weights(dynamicrafter_dict, "model.diffusion_model.") 366 | self.check_leftover_keys(dynamicrafter_dict) 367 | 368 | model_patcher = comfy.model_patcher.ModelPatcher( 369 | model, 370 | load_device=load_device, 371 | offload_device=model_management.unet_offload_device(), 372 | current_device=inital_load_device 373 | ) 374 | 375 | return (model_patcher, image_proj_model, ) 376 | 377 | NODE_CLASS_MAPPINGS = { 378 | "DynamiCrafterLoader": DynamiCrafterLoader, 379 | 
"DynamiCrafterProcessor": DynamiCrafterProcessor, 380 | } 381 | 382 | NODE_DISPLAY_NAME_MAPPINGS = { 383 | "DynamiCrafterLoader": "Load a DynamiCrafter Checkpoint", 384 | "DynamiCrafterProcessor": "Apply DynamiCrafter", 385 | } 386 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from collections import OrderedDict 5 | 6 | from comfy import model_base 7 | from comfy import utils 8 | 9 | from comfy import sd2_clip 10 | 11 | from comfy import supported_models_base 12 | from comfy import latent_formats 13 | 14 | from ..lvdm.modules.encoders.resampler import Resampler 15 | 16 | DYNAMICRAFTER_CONFIG = { 17 | 'in_channels': 8, 18 | 'out_channels': 4, 19 | 'model_channels': 320, 20 | 'attention_resolutions': [4, 2, 1], 21 | 'num_res_blocks': 2, 22 | 'channel_mult': [1, 2, 4, 4], 23 | 'num_head_channels': 64, 24 | 'transformer_depth': 1, 25 | 'context_dim': 1024, 26 | 'use_linear': True, 27 | 'use_checkpoint': False, 28 | 'temporal_conv': True, 29 | 'temporal_attention': True, 30 | 'temporal_selfatt_only': True, 31 | 'use_relative_position': False, 32 | 'use_causal_attention': False, 33 | 'temporal_length': 16, 34 | 'addition_attention': True, 35 | 'image_cross_attention': True, 36 | 'image_cross_attention_scale_learnable': True, 37 | 'default_fs': 3, 38 | 'fs_condition': True 39 | } 40 | 41 | IMAGE_PROJ_CONFIG = { 42 | "dim": 1024, 43 | "depth": 4, 44 | "dim_head": 64, 45 | "heads": 12, 46 | "num_queries": 16, 47 | "embedding_dim": 1280, 48 | "output_dim": 1024, 49 | "ff_mult": 4, 50 | "video_length": 16 51 | } 52 | 53 | def process_list_or_str(target_key_or_keys, k): 54 | if isinstance(target_key_or_keys, list): 55 | return any([list_k in k for list_k in target_key_or_keys]) 56 | else: 57 | return target_key_or_keys in k 58 | 59 | def simple_state_dict_loader(state_dict: dict, target_key: str, target_dict: dict = None): 60 | out_dict = {} 61 | 62 | if target_dict is None: 63 | for k, v in state_dict.items(): 64 | if process_list_or_str(target_key, k): 65 | out_dict[k] = v 66 | else: 67 | for k, v in target_dict.items(): 68 | out_dict[k] = state_dict[k] 69 | 70 | return out_dict 71 | 72 | def load_image_proj_dict(state_dict: dict): 73 | return simple_state_dict_loader(state_dict, 'image_proj') 74 | 75 | def load_dynamicrafter_dict(state_dict: dict): 76 | return simple_state_dict_loader(state_dict, 'model.diffusion_model') 77 | 78 | def load_vae_dict(state_dict: dict): 79 | return simple_state_dict_loader(state_dict, 'first_stage_model') 80 | 81 | def get_base_model(state_dict: dict, version_checker=False): 82 | 83 | is_256_model = False 84 | 85 | for k in state_dict.keys(): 86 | if "framestride_embed" in k: 87 | is_256_model = True 88 | break 89 | 90 | def get_image_proj_model(state_dict: dict): 91 | 92 | state_dict = {k.replace('image_proj_model.', ''): v for k, v in state_dict.items()} 93 | #target_dict = Resampler().state_dict() 94 | 95 | ImageProjModel = Resampler(**IMAGE_PROJ_CONFIG) 96 | ImageProjModel.load_state_dict(state_dict) 97 | 98 | print("Image Projection Model loaded successfully") 99 | #del target_dict 100 | return ImageProjModel 101 | 102 | class DynamiCrafterBase(supported_models_base.BASE): 103 | unet_config = {} 104 | unet_extra_config = {} 105 | 106 | latent_format = latent_formats.SD15 107 | 108 | def process_clip_state_dict(self, state_dict): 109 | replace_prefix = {} 110 | 
replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format 111 | replace_prefix["cond_stage_model.model."] = "clip_h." 112 | state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True) 113 | state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.") 114 | return state_dict 115 | 116 | def process_clip_state_dict_for_saving(self, state_dict): 117 | replace_prefix = {} 118 | replace_prefix["clip_h"] = "cond_stage_model.model" 119 | state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) 120 | state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict) 121 | return state_dict 122 | 123 | def clip_target(self): 124 | return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel) 125 | 126 | def process_dict_version(self, state_dict: dict): 127 | processed_dict = OrderedDict() 128 | is_eps = False 129 | 130 | for k in list(state_dict.keys()): 131 | if "framestride_embed" in k: 132 | new_key = k.replace("framestride_embed", "fps_embedding") 133 | processed_dict[new_key] = state_dict[k] 134 | is_eps = True 135 | continue 136 | 137 | processed_dict[k] = state_dict[k] 138 | 139 | return processed_dict, is_eps 140 | 141 | 142 | 143 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import cv2 4 | import torch 5 | import torch.distributed as dist 6 | 7 | MODEL_EXTS = ['ckpt', 'safetensors', 'bin'] 8 | 9 | def get_models_directory(directory: list): 10 | files_list = list(filter(lambda f: f.split(".")[-1] in MODEL_EXTS, directory)) 11 | return files_list 12 | 13 | def count_params(model, verbose=False): 14 | total_params = sum(p.numel() for p in model.parameters()) 15 | if verbose: 16 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 17 | return total_params 18 | 19 | 20 | def check_istarget(name, para_list): 21 | """ 22 | name: full name of source para 23 | para_list: partial name of target para 24 | """ 25 | istarget=False 26 | for para in para_list: 27 | if para in name: 28 | return True 29 | return istarget 30 | 31 | 32 | def instantiate_from_config(config): 33 | if not "target" in config: 34 | if config == '__is_first_stage__': 35 | return None 36 | elif config == "__is_unconditional__": 37 | return None 38 | raise KeyError("Expected key `target` to instantiate.") 39 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 40 | 41 | 42 | def get_obj_from_str(string, reload=False): 43 | module, cls = string.rsplit(".", 1) 44 | if reload: 45 | module_imp = importlib.import_module(module) 46 | importlib.reload(module_imp) 47 | return getattr(importlib.import_module(module, package=None), cls) 48 | 49 | 50 | def load_npz_from_dir(data_dir): 51 | data = [np.load(os.path.join(data_dir, data_name))['arr_0'] for data_name in os.listdir(data_dir)] 52 | data = np.concatenate(data, axis=0) 53 | return data 54 | 55 | 56 | def load_npz_from_paths(data_paths): 57 | data = [np.load(data_path)['arr_0'] for data_path in data_paths] 58 | data = np.concatenate(data, axis=0) 59 | return data 60 | 61 | 62 | def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None): 63 | h, w = image.shape[:2] 64 | if resize_short_edge is not None: 65 | k = resize_short_edge / min(h, w) 66 | else: 67 | k = max_resolution / (h * w) 68 | k = k**0.5 69 | 
h = int(np.round(h * k / 64)) * 64 70 | w = int(np.round(w * k / 64)) * 64 71 | image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) 72 | return image 73 | 74 | 75 | def setup_dist(args): 76 | if dist.is_initialized(): 77 | return 78 | torch.cuda.set_device(args.local_rank) 79 | torch.distributed.init_process_group( 80 | 'nccl', 81 | init_method='env://' 82 | ) -------------------------------------------------------------------------------- /workflows/dynamicrafter_512_basic.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 52, 3 | "last_link_id": 102, 4 | "nodes": [ 5 | { 6 | "id": 34, 7 | "type": "CLIPTextEncode", 8 | "pos": [ 9 | 780, 10 | 800 11 | ], 12 | "size": { 13 | "0": 400, 14 | "1": 200 15 | }, 16 | "flags": {}, 17 | "order": 6, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "clip", 22 | "type": "CLIP", 23 | "link": 78 24 | } 25 | ], 26 | "outputs": [ 27 | { 28 | "name": "CONDITIONING", 29 | "type": "CONDITIONING", 30 | "links": [ 31 | 74 32 | ], 33 | "shape": 3 34 | } 35 | ], 36 | "properties": { 37 | "Node name for S&R": "CLIPTextEncode" 38 | }, 39 | "widgets_values": [ 40 | "bad quality, blurry" 41 | ] 42 | }, 43 | { 44 | "id": 35, 45 | "type": "EmptyLatentImage", 46 | "pos": [ 47 | 1335, 48 | 1037 49 | ], 50 | "size": { 51 | "0": 315, 52 | "1": 106 53 | }, 54 | "flags": {}, 55 | "order": 0, 56 | "mode": 0, 57 | "outputs": [ 58 | { 59 | "name": "LATENT", 60 | "type": "LATENT", 61 | "links": [ 62 | 75 63 | ], 64 | "shape": 3 65 | } 66 | ], 67 | "properties": { 68 | "Node name for S&R": "EmptyLatentImage" 69 | }, 70 | "widgets_values": [ 71 | 1024, 72 | 640, 73 | 1 74 | ] 75 | }, 76 | { 77 | "id": 37, 78 | "type": "VAEDecode", 79 | "pos": [ 80 | 1721, 81 | 519 82 | ], 83 | "size": { 84 | "0": 210, 85 | "1": 46 86 | }, 87 | "flags": {}, 88 | "order": 13, 89 | "mode": 0, 90 | "inputs": [ 91 | { 92 | "name": "samples", 93 | "type": "LATENT", 94 | "link": 79 95 | }, 96 | { 97 | "name": "vae", 98 | "type": "VAE", 99 | "link": 80 100 | } 101 | ], 102 | "outputs": [ 103 | { 104 | "name": "IMAGE", 105 | "type": "IMAGE", 106 | "links": [ 107 | 81, 108 | 84 109 | ], 110 | "shape": 3, 111 | "slot_index": 0 112 | } 113 | ], 114 | "properties": { 115 | "Node name for S&R": "VAEDecode" 116 | } 117 | }, 118 | { 119 | "id": 48, 120 | "type": "VAEDecode", 121 | "pos": [ 122 | 3840, 123 | 490 124 | ], 125 | "size": { 126 | "0": 210, 127 | "1": 46 128 | }, 129 | "flags": {}, 130 | "order": 18, 131 | "mode": 0, 132 | "inputs": [ 133 | { 134 | "name": "samples", 135 | "type": "LATENT", 136 | "link": 94 137 | }, 138 | { 139 | "name": "vae", 140 | "type": "VAE", 141 | "link": 95 142 | } 143 | ], 144 | "outputs": [ 145 | { 146 | "name": "IMAGE", 147 | "type": "IMAGE", 148 | "links": [ 149 | 96 150 | ], 151 | "shape": 3, 152 | "slot_index": 0 153 | } 154 | ], 155 | "properties": { 156 | "Node name for S&R": "VAEDecode" 157 | } 158 | }, 159 | { 160 | "id": 51, 161 | "type": "CLIPTextEncode", 162 | "pos": [ 163 | 3030, 164 | 920 165 | ], 166 | "size": { 167 | "0": 400, 168 | "1": 200 169 | }, 170 | "flags": {}, 171 | "order": 12, 172 | "mode": 0, 173 | "inputs": [ 174 | { 175 | "name": "clip", 176 | "type": "CLIP", 177 | "link": 98 178 | } 179 | ], 180 | "outputs": [ 181 | { 182 | "name": "CONDITIONING", 183 | "type": "CONDITIONING", 184 | "links": [ 185 | 100 186 | ], 187 | "shape": 3, 188 | "slot_index": 0 189 | } 190 | ], 191 | "properties": { 192 | "Node name for S&R": "CLIPTextEncode" 193 | }, 194 | "widgets_values": 
[ 195 | "bad quality" 196 | ] 197 | }, 198 | { 199 | "id": 41, 200 | "type": "ImageScale", 201 | "pos": [ 202 | 1954, 203 | 517 204 | ], 205 | "size": { 206 | "0": 315, 207 | "1": 130 208 | }, 209 | "flags": {}, 210 | "order": 15, 211 | "mode": 0, 212 | "inputs": [ 213 | { 214 | "name": "image", 215 | "type": "IMAGE", 216 | "link": 84 217 | } 218 | ], 219 | "outputs": [ 220 | { 221 | "name": "IMAGE", 222 | "type": "IMAGE", 223 | "links": [ 224 | 85 225 | ], 226 | "shape": 3, 227 | "slot_index": 0 228 | } 229 | ], 230 | "properties": { 231 | "Node name for S&R": "ImageScale" 232 | }, 233 | "widgets_values": [ 234 | "nearest-exact", 235 | 512, 236 | 320, 237 | "disabled" 238 | ] 239 | }, 240 | { 241 | "id": 38, 242 | "type": "PreviewImage", 243 | "pos": [ 244 | 1696, 245 | 716 246 | ], 247 | "size": [ 248 | 587.2334716796877, 249 | 436.29796447753915 250 | ], 251 | "flags": {}, 252 | "order": 14, 253 | "mode": 0, 254 | "inputs": [ 255 | { 256 | "name": "images", 257 | "type": "IMAGE", 258 | "link": 81 259 | } 260 | ], 261 | "properties": { 262 | "Node name for S&R": "PreviewImage" 263 | } 264 | }, 265 | { 266 | "id": 50, 267 | "type": "CLIPTextEncode", 268 | "pos": [ 269 | 3030, 270 | 670 271 | ], 272 | "size": { 273 | "0": 400, 274 | "1": 200 275 | }, 276 | "flags": {}, 277 | "order": 11, 278 | "mode": 0, 279 | "inputs": [ 280 | { 281 | "name": "clip", 282 | "type": "CLIP", 283 | "link": 97 284 | } 285 | ], 286 | "outputs": [ 287 | { 288 | "name": "CONDITIONING", 289 | "type": "CONDITIONING", 290 | "links": [ 291 | 99 292 | ], 293 | "shape": 3, 294 | "slot_index": 0 295 | } 296 | ], 297 | "properties": { 298 | "Node name for S&R": "CLIPTextEncode" 299 | }, 300 | "widgets_values": [ 301 | "a car is drifting on the road, tire smoke, beautiful scenery, slow motion" 302 | ] 303 | }, 304 | { 305 | "id": 32, 306 | "type": "KSampler", 307 | "pos": [ 308 | 1340, 309 | 520 310 | ], 311 | "size": { 312 | "0": 315, 313 | "1": 474 314 | }, 315 | "flags": {}, 316 | "order": 10, 317 | "mode": 0, 318 | "inputs": [ 319 | { 320 | "name": "model", 321 | "type": "MODEL", 322 | "link": 76, 323 | "slot_index": 0 324 | }, 325 | { 326 | "name": "positive", 327 | "type": "CONDITIONING", 328 | "link": 73, 329 | "slot_index": 1 330 | }, 331 | { 332 | "name": "negative", 333 | "type": "CONDITIONING", 334 | "link": 74, 335 | "slot_index": 2 336 | }, 337 | { 338 | "name": "latent_image", 339 | "type": "LATENT", 340 | "link": 75, 341 | "slot_index": 3 342 | } 343 | ], 344 | "outputs": [ 345 | { 346 | "name": "LATENT", 347 | "type": "LATENT", 348 | "links": [ 349 | 79 350 | ], 351 | "shape": 3, 352 | "slot_index": 0 353 | } 354 | ], 355 | "properties": { 356 | "Node name for S&R": "KSampler" 357 | }, 358 | "widgets_values": [ 359 | 665295996397955, 360 | "randomize", 361 | 20, 362 | 8, 363 | "dpmpp_2m", 364 | "karras", 365 | 1 366 | ] 367 | }, 368 | { 369 | "id": 33, 370 | "type": "CLIPTextEncode", 371 | "pos": [ 372 | 780, 373 | 540 374 | ], 375 | "size": { 376 | "0": 400, 377 | "1": 200 378 | }, 379 | "flags": {}, 380 | "order": 5, 381 | "mode": 0, 382 | "inputs": [ 383 | { 384 | "name": "clip", 385 | "type": "CLIP", 386 | "link": 77 387 | } 388 | ], 389 | "outputs": [ 390 | { 391 | "name": "CONDITIONING", 392 | "type": "CONDITIONING", 393 | "links": [ 394 | 73 395 | ], 396 | "shape": 3 397 | } 398 | ], 399 | "properties": { 400 | "Node name for S&R": "CLIPTextEncode" 401 | }, 402 | "widgets_values": [ 403 | "a high quality photo of a car drifting at sunset, depth of field, best quality, deep blacks, rendered in unreal 
engine 5" 404 | ] 405 | }, 406 | { 407 | "id": 36, 408 | "type": "CheckpointLoaderSimple", 409 | "pos": [ 410 | 303, 411 | 656 412 | ], 413 | "size": { 414 | "0": 315, 415 | "1": 98 416 | }, 417 | "flags": {}, 418 | "order": 1, 419 | "mode": 0, 420 | "outputs": [ 421 | { 422 | "name": "MODEL", 423 | "type": "MODEL", 424 | "links": [ 425 | 76 426 | ], 427 | "shape": 3 428 | }, 429 | { 430 | "name": "CLIP", 431 | "type": "CLIP", 432 | "links": [ 433 | 77, 434 | 78 435 | ], 436 | "shape": 3, 437 | "slot_index": 1 438 | }, 439 | { 440 | "name": "VAE", 441 | "type": "VAE", 442 | "links": [ 443 | 80, 444 | 86 445 | ], 446 | "shape": 3, 447 | "slot_index": 2 448 | } 449 | ], 450 | "properties": { 451 | "Node name for S&R": "CheckpointLoaderSimple" 452 | }, 453 | "widgets_values": [ 454 | "realisticVisionV60B1_v51VAE.safetensors" 455 | ] 456 | }, 457 | { 458 | "id": 42, 459 | "type": "Reroute", 460 | "pos": [ 461 | 2688, 462 | 625 463 | ], 464 | "size": [ 465 | 75, 466 | 26 467 | ], 468 | "flags": {}, 469 | "order": 7, 470 | "mode": 0, 471 | "inputs": [ 472 | { 473 | "name": "", 474 | "type": "*", 475 | "link": 86 476 | } 477 | ], 478 | "outputs": [ 479 | { 480 | "name": "", 481 | "type": "VAE", 482 | "links": [ 483 | 87, 484 | 95 485 | ], 486 | "slot_index": 0 487 | } 488 | ], 489 | "properties": { 490 | "showOutputText": false, 491 | "horizontal": false 492 | } 493 | }, 494 | { 495 | "id": 39, 496 | "type": "DynamiCrafterLoader", 497 | "pos": [ 498 | 2327, 499 | 488 500 | ], 501 | "size": { 502 | "0": 315, 503 | "1": 78 504 | }, 505 | "flags": {}, 506 | "order": 2, 507 | "mode": 0, 508 | "outputs": [ 509 | { 510 | "name": "model", 511 | "type": "MODEL", 512 | "links": [ 513 | 101 514 | ], 515 | "shape": 3, 516 | "slot_index": 0 517 | }, 518 | { 519 | "name": "image_proj_model", 520 | "type": "IMAGE_PROJ_MODEL", 521 | "links": [ 522 | 89 523 | ], 524 | "shape": 3, 525 | "slot_index": 1 526 | } 527 | ], 528 | "properties": { 529 | "Node name for S&R": "DynamiCrafterLoader" 530 | }, 531 | "widgets_values": [ 532 | "dynamicrafter_512.safetensors" 533 | ] 534 | }, 535 | { 536 | "id": 52, 537 | "type": "RescaleCFG", 538 | "pos": [ 539 | 2682, 540 | 495 541 | ], 542 | "size": { 543 | "0": 315, 544 | "1": 58 545 | }, 546 | "flags": {}, 547 | "order": 8, 548 | "mode": 0, 549 | "inputs": [ 550 | { 551 | "name": "model", 552 | "type": "MODEL", 553 | "link": 101 554 | } 555 | ], 556 | "outputs": [ 557 | { 558 | "name": "MODEL", 559 | "type": "MODEL", 560 | "links": [ 561 | 102 562 | ], 563 | "shape": 3, 564 | "slot_index": 0 565 | } 566 | ], 567 | "properties": { 568 | "Node name for S&R": "RescaleCFG" 569 | }, 570 | "widgets_values": [ 571 | 0.7 572 | ] 573 | }, 574 | { 575 | "id": 43, 576 | "type": "CLIPVisionLoader", 577 | "pos": [ 578 | 2684, 579 | 391 580 | ], 581 | "size": { 582 | "0": 315, 583 | "1": 58 584 | }, 585 | "flags": {}, 586 | "order": 3, 587 | "mode": 0, 588 | "outputs": [ 589 | { 590 | "name": "CLIP_VISION", 591 | "type": "CLIP_VISION", 592 | "links": [ 593 | 88 594 | ], 595 | "shape": 3, 596 | "slot_index": 0 597 | } 598 | ], 599 | "properties": { 600 | "Node name for S&R": "CLIPVisionLoader" 601 | }, 602 | "widgets_values": [ 603 | "open_clip_pytorch_model.bin" 604 | ] 605 | }, 606 | { 607 | "id": 44, 608 | "type": "CLIPSetLastLayer", 609 | "pos": [ 610 | 2672, 611 | 863 612 | ], 613 | "size": { 614 | "0": 315, 615 | "1": 58 616 | }, 617 | "flags": {}, 618 | "order": 9, 619 | "mode": 0, 620 | "inputs": [ 621 | { 622 | "name": "clip", 623 | "type": "CLIP", 624 | "link": 90, 625 | 
"slot_index": 0 626 | } 627 | ], 628 | "outputs": [ 629 | { 630 | "name": "CLIP", 631 | "type": "CLIP", 632 | "links": [ 633 | 97, 634 | 98 635 | ], 636 | "shape": 3, 637 | "slot_index": 0 638 | } 639 | ], 640 | "properties": { 641 | "Node name for S&R": "CLIPSetLastLayer" 642 | }, 643 | "widgets_values": [ 644 | -2 645 | ] 646 | }, 647 | { 648 | "id": 45, 649 | "type": "CLIPLoader", 650 | "pos": [ 651 | 2670, 652 | 733 653 | ], 654 | "size": { 655 | "0": 315, 656 | "1": 82 657 | }, 658 | "flags": {}, 659 | "order": 4, 660 | "mode": 0, 661 | "outputs": [ 662 | { 663 | "name": "CLIP", 664 | "type": "CLIP", 665 | "links": [ 666 | 90 667 | ], 668 | "shape": 3 669 | } 670 | ], 671 | "properties": { 672 | "Node name for S&R": "CLIPLoader" 673 | }, 674 | "widgets_values": [ 675 | "model.safetensors", 676 | "stable_diffusion" 677 | ] 678 | }, 679 | { 680 | "id": 49, 681 | "type": "VHS_VideoCombine", 682 | "pos": [ 683 | 4160, 684 | 500 685 | ], 686 | "size": [ 687 | 561.6313574218739, 688 | 586.2695983886712 689 | ], 690 | "flags": {}, 691 | "order": 19, 692 | "mode": 0, 693 | "inputs": [ 694 | { 695 | "name": "images", 696 | "type": "IMAGE", 697 | "link": 96 698 | } 699 | ], 700 | "outputs": [], 701 | "properties": { 702 | "Node name for S&R": "VHS_VideoCombine" 703 | }, 704 | "widgets_values": { 705 | "frame_rate": 8, 706 | "loop_count": 0, 707 | "filename_prefix": "DynamiCrafter", 708 | "format": "video/h264-mp4", 709 | "pingpong": false, 710 | "save_image": true, 711 | "crf": 20, 712 | "save_metadata": true, 713 | "audio_file": "", 714 | "videopreview": { 715 | "hidden": false, 716 | "paused": false, 717 | "params": { 718 | "filename": "DynamiCrafter_00013.mp4", 719 | "subfolder": "", 720 | "type": "output", 721 | "format": "video/h264-mp4" 722 | } 723 | } 724 | } 725 | }, 726 | { 727 | "id": 47, 728 | "type": "KSampler", 729 | "pos": [ 730 | 3470, 731 | 490 732 | ], 733 | "size": [ 734 | 320, 735 | 474 736 | ], 737 | "flags": {}, 738 | "order": 17, 739 | "mode": 0, 740 | "inputs": [ 741 | { 742 | "name": "model", 743 | "type": "MODEL", 744 | "link": 92 745 | }, 746 | { 747 | "name": "positive", 748 | "type": "CONDITIONING", 749 | "link": 99 750 | }, 751 | { 752 | "name": "negative", 753 | "type": "CONDITIONING", 754 | "link": 100 755 | }, 756 | { 757 | "name": "latent_image", 758 | "type": "LATENT", 759 | "link": 93 760 | } 761 | ], 762 | "outputs": [ 763 | { 764 | "name": "LATENT", 765 | "type": "LATENT", 766 | "links": [ 767 | 94 768 | ], 769 | "shape": 3, 770 | "slot_index": 0 771 | } 772 | ], 773 | "properties": { 774 | "Node name for S&R": "KSampler" 775 | }, 776 | "widgets_values": [ 777 | 513654128154819, 778 | "randomize", 779 | 15, 780 | 7.5, 781 | "euler_ancestral", 782 | "karras", 783 | 1 784 | ] 785 | }, 786 | { 787 | "id": 40, 788 | "type": "DynamiCrafterProcessor", 789 | "pos": [ 790 | 3040, 791 | 420 792 | ], 793 | "size": { 794 | "0": 367.79998779296875, 795 | "1": 186 796 | }, 797 | "flags": {}, 798 | "order": 16, 799 | "mode": 0, 800 | "inputs": [ 801 | { 802 | "name": "model", 803 | "type": "MODEL", 804 | "link": 102 805 | }, 806 | { 807 | "name": "clip_vision", 808 | "type": "CLIP_VISION", 809 | "link": 88 810 | }, 811 | { 812 | "name": "vae", 813 | "type": "VAE", 814 | "link": 87 815 | }, 816 | { 817 | "name": "image_proj_model", 818 | "type": "IMAGE_PROJ_MODEL", 819 | "link": 89 820 | }, 821 | { 822 | "name": "images", 823 | "type": "IMAGE", 824 | "link": 85 825 | } 826 | ], 827 | "outputs": [ 828 | { 829 | "name": "model", 830 | "type": "MODEL", 831 | "links": [ 832 | 
92 833 | ], 834 | "shape": 3, 835 | "slot_index": 0 836 | }, 837 | { 838 | "name": "empty_latent", 839 | "type": "LATENT", 840 | "links": [ 841 | 93 842 | ], 843 | "shape": 3, 844 | "slot_index": 1 845 | }, 846 | { 847 | "name": "latent_img", 848 | "type": "LATENT", 849 | "links": null, 850 | "shape": 3 851 | } 852 | ], 853 | "properties": { 854 | "Node name for S&R": "DynamiCrafterProcessor" 855 | }, 856 | "widgets_values": [ 857 | false, 858 | 12, 859 | 16 860 | ] 861 | } 862 | ], 863 | "links": [ 864 | [ 865 | 73, 866 | 33, 867 | 0, 868 | 32, 869 | 1, 870 | "CONDITIONING" 871 | ], 872 | [ 873 | 74, 874 | 34, 875 | 0, 876 | 32, 877 | 2, 878 | "CONDITIONING" 879 | ], 880 | [ 881 | 75, 882 | 35, 883 | 0, 884 | 32, 885 | 3, 886 | "LATENT" 887 | ], 888 | [ 889 | 76, 890 | 36, 891 | 0, 892 | 32, 893 | 0, 894 | "MODEL" 895 | ], 896 | [ 897 | 77, 898 | 36, 899 | 1, 900 | 33, 901 | 0, 902 | "CLIP" 903 | ], 904 | [ 905 | 78, 906 | 36, 907 | 1, 908 | 34, 909 | 0, 910 | "CLIP" 911 | ], 912 | [ 913 | 79, 914 | 32, 915 | 0, 916 | 37, 917 | 0, 918 | "LATENT" 919 | ], 920 | [ 921 | 80, 922 | 36, 923 | 2, 924 | 37, 925 | 1, 926 | "VAE" 927 | ], 928 | [ 929 | 81, 930 | 37, 931 | 0, 932 | 38, 933 | 0, 934 | "IMAGE" 935 | ], 936 | [ 937 | 84, 938 | 37, 939 | 0, 940 | 41, 941 | 0, 942 | "IMAGE" 943 | ], 944 | [ 945 | 85, 946 | 41, 947 | 0, 948 | 40, 949 | 4, 950 | "IMAGE" 951 | ], 952 | [ 953 | 86, 954 | 36, 955 | 2, 956 | 42, 957 | 0, 958 | "*" 959 | ], 960 | [ 961 | 87, 962 | 42, 963 | 0, 964 | 40, 965 | 2, 966 | "VAE" 967 | ], 968 | [ 969 | 88, 970 | 43, 971 | 0, 972 | 40, 973 | 1, 974 | "CLIP_VISION" 975 | ], 976 | [ 977 | 89, 978 | 39, 979 | 1, 980 | 40, 981 | 3, 982 | "IMAGE_PROJ_MODEL" 983 | ], 984 | [ 985 | 90, 986 | 45, 987 | 0, 988 | 44, 989 | 0, 990 | "CLIP" 991 | ], 992 | [ 993 | 92, 994 | 40, 995 | 0, 996 | 47, 997 | 0, 998 | "MODEL" 999 | ], 1000 | [ 1001 | 93, 1002 | 40, 1003 | 1, 1004 | 47, 1005 | 3, 1006 | "LATENT" 1007 | ], 1008 | [ 1009 | 94, 1010 | 47, 1011 | 0, 1012 | 48, 1013 | 0, 1014 | "LATENT" 1015 | ], 1016 | [ 1017 | 95, 1018 | 42, 1019 | 0, 1020 | 48, 1021 | 1, 1022 | "VAE" 1023 | ], 1024 | [ 1025 | 96, 1026 | 48, 1027 | 0, 1028 | 49, 1029 | 0, 1030 | "IMAGE" 1031 | ], 1032 | [ 1033 | 97, 1034 | 44, 1035 | 0, 1036 | 50, 1037 | 0, 1038 | "CLIP" 1039 | ], 1040 | [ 1041 | 98, 1042 | 44, 1043 | 0, 1044 | 51, 1045 | 0, 1046 | "CLIP" 1047 | ], 1048 | [ 1049 | 99, 1050 | 50, 1051 | 0, 1052 | 47, 1053 | 1, 1054 | "CONDITIONING" 1055 | ], 1056 | [ 1057 | 100, 1058 | 51, 1059 | 0, 1060 | 47, 1061 | 2, 1062 | "CONDITIONING" 1063 | ], 1064 | [ 1065 | 101, 1066 | 39, 1067 | 0, 1068 | 52, 1069 | 0, 1070 | "MODEL" 1071 | ], 1072 | [ 1073 | 102, 1074 | 52, 1075 | 0, 1076 | 40, 1077 | 0, 1078 | "MODEL" 1079 | ] 1080 | ], 1081 | "groups": [], 1082 | "config": {}, 1083 | "extra": {}, 1084 | "version": 0.4 1085 | } --------------------------------------------------------------------------------
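
The JSON above is a stock ComfyUI workflow export in the "UI" graph format: a `nodes` array holding each node's type, canvas position, `widgets_values` and input/output slots, plus a flat `links` table whose entries are `[link_id, source_node, source_slot, target_node, target_slot, type]`. Reading the links, a first KSampler (node 32) renders a still with the Realistic Vision checkpoint, the decoded image is resized to 512x320 (ImageScale, node 41) and fed to DynamiCrafterProcessor (node 40) together with the DynamiCrafter model chain (DynamiCrafterLoader, node 39, through RescaleCFG, node 52) and the CLIP vision / CLIP text conditioning; a second KSampler (node 47) then samples the video latents and VHS_VideoCombine (node 49) writes them out as an 8 fps h264 mp4. As a rough aid for reading such exports, below is a minimal Python sketch, not part of the node pack, that parses the file and prints the nodes in their recorded execution order along with the resolved connections. The `workflows/dynamicrafter_512_basic.json` path is an assumption; point it at whichever export you want to inspect.

```python
import json
from pathlib import Path

# NOTE: this path is an assumption; swap in whichever export under workflows/ you want to inspect.
WORKFLOW_PATH = Path("workflows/dynamicrafter_512_basic.json")


def summarize(workflow: dict) -> None:
    """Print the nodes and resolved connections of a ComfyUI UI-format workflow export."""
    nodes = {n["id"]: n for n in workflow["nodes"]}

    # Nodes, in the execution order ComfyUI recorded when the graph was exported.
    for node in sorted(workflow["nodes"], key=lambda n: n.get("order", 0)):
        widgets = node.get("widgets_values", [])
        print(f"[{node.get('order', 0):>2}] #{node['id']:<3} {node['type']:<26} widgets={widgets}")

    # Each link is stored flat as [link_id, src_node, src_slot, dst_node, dst_slot, type].
    print("\nconnections:")
    for link_id, src, src_slot, dst, dst_slot, link_type in workflow["links"]:
        src_name = nodes.get(src, {}).get("type", f"#{src}")
        dst_name = nodes.get(dst, {}).get("type", f"#{dst}")
        print(f"  {src_name}[{src_slot}] --{link_type}--> {dst_name}[{dst_slot}]  (link {link_id})")


if __name__ == "__main__":
    with WORKFLOW_PATH.open() as f:
        summarize(json.load(f))
```

The sketch only reads the exported JSON and does not touch the node pack's Python API. To actually run the graph, load the JSON in ComfyUI (Load button or drag-and-drop onto the canvas) with the referenced checkpoint, DynamiCrafter weights and CLIP vision model in place, then queue it as usual.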