├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── animatediff └── models │ ├── attention.py │ ├── motion_module.py │ ├── resnet.py │ ├── sparse_controlnet.py │ ├── unet.py │ └── unet_blocks.py ├── configs ├── ad_unet_config.yaml ├── text_encoder_config.json ├── tokenizer │ ├── config.json │ ├── merges.txt │ ├── preprocessor_config.json │ ├── special_tokens_map.json │ ├── tokenizer.json │ ├── tokenizer_config.json │ └── vocab.json ├── tokenizer_config.json └── v1-inference.yaml ├── examples ├── magictime_example.json └── magictime_example.mp4 ├── nodes.py ├── requirements.txt └── utils ├── dataset.py ├── pipeline_magictime.py ├── unet.py ├── unet_blocks.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | checkpoints/ 3 | *.py[cod] 4 | *$py.class 5 | *.egg-info 6 | .pytest_cache -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI wrapper node for testing MagicTime 2 | 3 | # UPDATE 4 | 5 | While making this I figured out that I could just extract the LoRA and apply it to the v3 motion model, so it can be used as-is with any AnimateDiff-Evolved workflow. The merged v3 checkpoint, along with the spatial LoRA converted to .safetensors, is available here: 6 | 7 | https://huggingface.co/Kijai/MagicTime-merged-fp16 8 | 9 | **This does NOT need this repo; I will not be updating it further.** 10 | 11 | ___ 12 | 13 | ## Only use this repo and the following instructions for legacy/testing purposes: 14 | 15 | https://github.com/kijai/ComfyUI-MagicTimeWrapper/assets/40791699/c71d271d-8219-456c-891d-da9bdbd44d54 16 | 17 | # Installing 18 | Either use the Manager and its install-from-git feature, or clone this repo to custom_nodes and run: 19 | 20 | `pip install -r requirements.txt` 21 | 22 | or, if you use the portable version (run this in the ComfyUI_windows_portable folder): 23 | 24 | `python_embeded\python.exe -m pip install -r ComfyUI\custom_nodes\ComfyUI-MagicTimeWrapper\requirements.txt` 25 | 26 | You can use any SD 1.5 model. The v3 AnimateDiff motion model should be 27 | placed in `ComfyUI/models/animatediff_models`: 28 | 29 | https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_mm.ckpt 30 | 31 | The rest (131.0 MB) is **auto-downloaded** from https://huggingface.co/BestWishYsh/MagicTime/tree/main/Magic_Weights 32 | to `ComfyUI/models/magictime` (a manual-download fallback is sketched below). 33 | ___ 34 | # Original repo: 35 | https://github.com/PKU-YuanGroup/MagicTime 36 | 37 | 38 | ## 🐳 ChronoMagic Dataset 39 | ChronoMagic contains 2,265 metamorphic time-lapse videos, each accompanied by a detailed caption. We release the subset of ChronoMagic used to train MagicTime. The dataset can be downloaded at [Google Drive](https://drive.google.com/drive/folders/1WsomdkmSp3ql3ImcNsmzFuSQ9Qukuyr8?usp=sharing). Some samples can be found on our Project Page. 40 | 41 | 42 | ## 👍 Acknowledgement 43 | * [Animatediff](https://github.com/guoyww/AnimateDiff/tree/main) The codebase we built upon; a strong U-Net-based text-to-video generation model. 44 | 45 | * [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) The codebase we built upon; a simple and scalable DiT-based text-to-video generation repo that aims to reproduce [Sora](https://openai.com/sora). 46 | 47 | ## 🔒 License 48 | * The majority of this project is released under the Apache 2.0 license, as found in the [LICENSE](https://github.com/PKU-YuanGroup/MagicTime/blob/main/LICENSE) file. 49 | * The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violations. 50 | 51 | 52 | 53 | ## ✏️ Citation 54 | If you find our paper and code useful in your research, please consider giving a star :star: and citation :pencil:. 
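The Installing section above relies on the Magic_Weights being fetched automatically on first run. If that automatic download is not possible, the files can be fetched manually; the following is only a rough sketch, assuming `huggingface_hub` is installed and that the node reads the weights from `ComfyUI/models/magictime` (the exact folder layout the node expects is an assumption — check the node's console output if loading fails):

```python
# Hypothetical manual download of the MagicTime weights (normally auto-downloaded by the node).
from pathlib import Path
from huggingface_hub import snapshot_download

target = Path("ComfyUI/models/magictime")  # assumed location the wrapper node reads from
target.mkdir(parents=True, exist_ok=True)

snapshot_download(
    repo_id="BestWishYsh/MagicTime",
    allow_patterns=["Magic_Weights/*"],  # only the ~131 MB of MagicTime weights
    local_dir=str(target),
)
```
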
55 | 56 | ```BibTeX 57 | @misc{yuan2024magictime, 58 | title={MagicTime: Time-lapse Video Generation Models as Metamorphic Simulators}, 59 | author={Shenghai Yuan and Jinfa Huang and Yujun Shi and Yongqi Xu and Ruijie Zhu and Bin Lin and Xinhua Cheng and Li Yuan and Jiebo Luo}, 60 | year={2024}, 61 | eprint={2404.05014}, 62 | archivePrefix={arXiv}, 63 | primaryClass={cs.CV} 64 | } 65 | ``` 66 | 67 | ## 🤝 Contributors 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] -------------------------------------------------------------------------------- /animatediff/models/attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | from diffusers.configuration_utils import ConfigMixin, register_to_config 11 | from diffusers import ModelMixin 12 | from diffusers.utils import BaseOutput 13 | from diffusers.utils.import_utils import is_xformers_available 14 | from diffusers.models.attention import Attention, FeedForward, AdaLayerNorm 15 | 16 | from einops import rearrange, repeat 17 | import pdb 18 | 19 | @dataclass 20 | class Transformer3DModelOutput(BaseOutput): 21 | sample: torch.FloatTensor 22 | 23 | 24 | if is_xformers_available(): 25 | import xformers 26 | import xformers.ops 27 | else: 28 | xformers = None 29 | 30 | 31 | class Transformer3DModel(ModelMixin, ConfigMixin): 32 | @register_to_config 33 | def __init__( 34 | self, 35 | num_attention_heads: int = 16, 36 | attention_head_dim: int = 88, 37 | in_channels: Optional[int] = None, 38 | num_layers: int = 1, 39 | dropout: float = 0.0, 40 | norm_num_groups: int = 32, 41 | cross_attention_dim: Optional[int] = None, 42 | attention_bias: bool = False, 43 | activation_fn: str = "geglu", 44 | num_embeds_ada_norm: Optional[int] = None, 45 | use_linear_projection: bool = False, 46 | only_cross_attention: bool = False, 47 | upcast_attention: bool = False, 48 | 49 | unet_use_cross_frame_attention=None, 50 | unet_use_temporal_attention=None, 51 | ): 52 | super().__init__() 53 | self.use_linear_projection = use_linear_projection 54 | self.num_attention_heads = num_attention_heads 55 | self.attention_head_dim = attention_head_dim 56 | inner_dim = num_attention_heads * attention_head_dim 57 | 58 | # Define input layers 59 | self.in_channels = in_channels 60 | 61 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 62 | if use_linear_projection: 63 | self.proj_in = nn.Linear(in_channels, inner_dim) 64 | else: 65 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 66 | 67 | # Define transformers blocks 68 | self.transformer_blocks = nn.ModuleList( 69 | [ 70 | BasicTransformerBlock( 71 | inner_dim, 72 | num_attention_heads, 73 | attention_head_dim, 74 | dropout=dropout, 75 | cross_attention_dim=cross_attention_dim, 76 | activation_fn=activation_fn, 77 | num_embeds_ada_norm=num_embeds_ada_norm, 78 | attention_bias=attention_bias, 79 | only_cross_attention=only_cross_attention, 80 | 
upcast_attention=upcast_attention, 81 | 82 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 83 | unet_use_temporal_attention=unet_use_temporal_attention, 84 | ) 85 | for d in range(num_layers) 86 | ] 87 | ) 88 | 89 | # 4. Define output layers 90 | if use_linear_projection: 91 | self.proj_out = nn.Linear(in_channels, inner_dim) 92 | else: 93 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 94 | 95 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): 96 | # Input 97 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 98 | video_length = hidden_states.shape[2] 99 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 100 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) 101 | 102 | batch, channel, height, weight = hidden_states.shape 103 | residual = hidden_states 104 | 105 | hidden_states = self.norm(hidden_states) 106 | if not self.use_linear_projection: 107 | hidden_states = self.proj_in(hidden_states) 108 | inner_dim = hidden_states.shape[1] 109 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 110 | else: 111 | inner_dim = hidden_states.shape[1] 112 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 113 | hidden_states = self.proj_in(hidden_states) 114 | 115 | # Blocks 116 | for block in self.transformer_blocks: 117 | hidden_states = block( 118 | hidden_states, 119 | encoder_hidden_states=encoder_hidden_states, 120 | timestep=timestep, 121 | video_length=video_length 122 | ) 123 | 124 | # Output 125 | if not self.use_linear_projection: 126 | hidden_states = ( 127 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 128 | ) 129 | hidden_states = self.proj_out(hidden_states) 130 | else: 131 | hidden_states = self.proj_out(hidden_states) 132 | hidden_states = ( 133 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 134 | ) 135 | 136 | output = hidden_states + residual 137 | 138 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 139 | if not return_dict: 140 | return (output,) 141 | 142 | return Transformer3DModelOutput(sample=output) 143 | 144 | 145 | class BasicTransformerBlock(nn.Module): 146 | def __init__( 147 | self, 148 | dim: int, 149 | num_attention_heads: int, 150 | attention_head_dim: int, 151 | dropout=0.0, 152 | cross_attention_dim: Optional[int] = None, 153 | activation_fn: str = "geglu", 154 | num_embeds_ada_norm: Optional[int] = None, 155 | attention_bias: bool = False, 156 | only_cross_attention: bool = False, 157 | upcast_attention: bool = False, 158 | 159 | unet_use_cross_frame_attention = None, 160 | unet_use_temporal_attention = None, 161 | ): 162 | super().__init__() 163 | self.only_cross_attention = only_cross_attention 164 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 165 | self.unet_use_cross_frame_attention = unet_use_cross_frame_attention 166 | self.unet_use_temporal_attention = unet_use_temporal_attention 167 | 168 | # SC-Attn 169 | assert unet_use_cross_frame_attention is not None 170 | if unet_use_cross_frame_attention: 171 | self.attn1 = SparseCausalAttention2D( 172 | query_dim=dim, 173 | heads=num_attention_heads, 174 | dim_head=attention_head_dim, 175 | dropout=dropout, 176 | bias=attention_bias, 177 | 
cross_attention_dim=cross_attention_dim if only_cross_attention else None, 178 | upcast_attention=upcast_attention, 179 | ) 180 | else: 181 | self.attn1 = Attention( 182 | query_dim=dim, 183 | heads=num_attention_heads, 184 | dim_head=attention_head_dim, 185 | dropout=dropout, 186 | bias=attention_bias, 187 | upcast_attention=upcast_attention, 188 | ) 189 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 190 | 191 | # Cross-Attn 192 | if cross_attention_dim is not None: 193 | self.attn2 = Attention( 194 | query_dim=dim, 195 | cross_attention_dim=cross_attention_dim, 196 | heads=num_attention_heads, 197 | dim_head=attention_head_dim, 198 | dropout=dropout, 199 | bias=attention_bias, 200 | upcast_attention=upcast_attention, 201 | ) 202 | else: 203 | self.attn2 = None 204 | 205 | if cross_attention_dim is not None: 206 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 207 | else: 208 | self.norm2 = None 209 | 210 | # Feed-forward 211 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 212 | self.norm3 = nn.LayerNorm(dim) 213 | 214 | # Temp-Attn 215 | assert unet_use_temporal_attention is not None 216 | if unet_use_temporal_attention: 217 | self.attn_temp = Attention( 218 | query_dim=dim, 219 | heads=num_attention_heads, 220 | dim_head=attention_head_dim, 221 | dropout=dropout, 222 | bias=attention_bias, 223 | upcast_attention=upcast_attention, 224 | ) 225 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data) 226 | self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 227 | 228 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs): 229 | if not is_xformers_available(): 230 | print("Here is how to install it") 231 | raise ModuleNotFoundError( 232 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 233 | " xformers", 234 | name="xformers", 235 | ) 236 | elif not torch.cuda.is_available(): 237 | raise ValueError( 238 | "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is only" 239 | " available for GPU " 240 | ) 241 | else: 242 | try: 243 | # Make sure we can run the memory efficient attention 244 | _ = xformers.ops.memory_efficient_attention( 245 | torch.randn((1, 2, 40), device="cuda"), 246 | torch.randn((1, 2, 40), device="cuda"), 247 | torch.randn((1, 2, 40), device="cuda"), 248 | ) 249 | except Exception as e: 250 | raise e 251 | self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 252 | if self.attn2 is not None: 253 | self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 254 | # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 255 | 256 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): 257 | # SparseCausal-Attention 258 | norm_hidden_states = ( 259 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) 260 | ) 261 | 262 | # if self.only_cross_attention: 263 | # hidden_states = ( 264 | # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states 265 | # ) 266 | # else: 267 | # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states 268 | 269 | # pdb.set_trace() 270 | if self.unet_use_cross_frame_attention: 271 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states 272 | else: 273 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states 274 | 275 | if self.attn2 is not None: 276 | # Cross-Attention 277 | norm_hidden_states = ( 278 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 279 | ) 280 | hidden_states = ( 281 | self.attn2( 282 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 283 | ) 284 | + hidden_states 285 | ) 286 | 287 | # Feed-forward 288 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 289 | 290 | # Temporal-Attention 291 | if self.unet_use_temporal_attention: 292 | d = hidden_states.shape[1] 293 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 294 | norm_hidden_states = ( 295 | self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) 296 | ) 297 | hidden_states = self.attn_temp(norm_hidden_states) + hidden_states 298 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 299 | 300 | return hidden_states 301 | -------------------------------------------------------------------------------- /animatediff/models/motion_module.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import torch 5 | import numpy as np 6 | import torch.nn.functional as F 7 | from torch import nn 8 | import torchvision 9 | import diffusers 10 | from packaging import version 11 | 12 | from diffusers.configuration_utils import ConfigMixin, register_to_config 13 | from diffusers import ModelMixin 14 | from diffusers.utils import BaseOutput 15 | from diffusers.utils.import_utils import is_xformers_available 16 | from diffusers.models.attention import Attention, FeedForward 17 | 18 | from einops import rearrange, repeat 19 | import math 20 | 21 | 22 
| def zero_module(module): 23 | # Zero out the parameters of a module and return it. 24 | for p in module.parameters(): 25 | p.detach().zero_() 26 | return module 27 | 28 | 29 | @dataclass 30 | class TemporalTransformer3DModelOutput(BaseOutput): 31 | sample: torch.FloatTensor 32 | 33 | 34 | if is_xformers_available(): 35 | import xformers 36 | import xformers.ops 37 | else: 38 | xformers = None 39 | 40 | 41 | def get_motion_module( 42 | in_channels, 43 | motion_module_type: str, 44 | motion_module_kwargs: dict 45 | ): 46 | if motion_module_type == "Vanilla": 47 | return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,) 48 | else: 49 | raise ValueError 50 | 51 | 52 | class VanillaTemporalModule(nn.Module): 53 | def __init__( 54 | self, 55 | in_channels, 56 | num_attention_heads = 8, 57 | num_transformer_block = 2, 58 | attention_block_types =( "Temporal_Self", "Temporal_Self" ), 59 | cross_frame_attention_mode = None, 60 | temporal_position_encoding = False, 61 | temporal_position_encoding_max_len = 24, 62 | temporal_attention_dim_div = 1, 63 | zero_initialize = True, 64 | ): 65 | super().__init__() 66 | 67 | self.temporal_transformer = TemporalTransformer3DModel( 68 | in_channels=in_channels, 69 | num_attention_heads=num_attention_heads, 70 | attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, 71 | num_layers=num_transformer_block, 72 | attention_block_types=attention_block_types, 73 | cross_frame_attention_mode=cross_frame_attention_mode, 74 | temporal_position_encoding=temporal_position_encoding, 75 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 76 | ) 77 | 78 | if zero_initialize: 79 | self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) 80 | 81 | def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None): 82 | video_length = input_tensor.shape[2] 83 | 84 | if video_length > 1: 85 | hidden_states = input_tensor 86 | hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask) 87 | output = hidden_states 88 | else: 89 | output = input_tensor 90 | 91 | return output 92 | 93 | 94 | class TemporalTransformer3DModel(nn.Module): 95 | def __init__( 96 | self, 97 | in_channels, 98 | num_attention_heads, 99 | attention_head_dim, 100 | 101 | num_layers, 102 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ), 103 | dropout = 0.0, 104 | norm_num_groups = 32, 105 | cross_attention_dim = 768, 106 | activation_fn = "geglu", 107 | attention_bias = False, 108 | upcast_attention = False, 109 | 110 | cross_frame_attention_mode = None, 111 | temporal_position_encoding = False, 112 | temporal_position_encoding_max_len = 24, 113 | ): 114 | super().__init__() 115 | 116 | inner_dim = num_attention_heads * attention_head_dim 117 | 118 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 119 | self.proj_in = nn.Linear(in_channels, inner_dim) 120 | 121 | self.transformer_blocks = nn.ModuleList( 122 | [ 123 | TemporalTransformerBlock( 124 | dim=inner_dim, 125 | num_attention_heads=num_attention_heads, 126 | attention_head_dim=attention_head_dim, 127 | attention_block_types=attention_block_types, 128 | dropout=dropout, 129 | norm_num_groups=norm_num_groups, 130 | cross_attention_dim=cross_attention_dim, 131 | activation_fn=activation_fn, 132 | attention_bias=attention_bias, 133 | upcast_attention=upcast_attention, 134 | 
cross_frame_attention_mode=cross_frame_attention_mode, 135 | temporal_position_encoding=temporal_position_encoding, 136 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 137 | ) 138 | for d in range(num_layers) 139 | ] 140 | ) 141 | self.proj_out = nn.Linear(inner_dim, in_channels) 142 | 143 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): 144 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 145 | video_length = hidden_states.shape[2] 146 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 147 | 148 | batch, channel, height, weight = hidden_states.shape 149 | residual = hidden_states 150 | 151 | hidden_states = self.norm(hidden_states) 152 | inner_dim = hidden_states.shape[1] 153 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 154 | hidden_states = self.proj_in(hidden_states) 155 | 156 | # Transformer Blocks 157 | for block in self.transformer_blocks: 158 | hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length) 159 | 160 | # output 161 | hidden_states = self.proj_out(hidden_states) 162 | hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 163 | 164 | output = hidden_states + residual 165 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 166 | 167 | return output 168 | 169 | 170 | class TemporalTransformerBlock(nn.Module): 171 | def __init__( 172 | self, 173 | dim, 174 | num_attention_heads, 175 | attention_head_dim, 176 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ), 177 | dropout = 0.0, 178 | norm_num_groups = 32, 179 | cross_attention_dim = 768, 180 | activation_fn = "geglu", 181 | attention_bias = False, 182 | upcast_attention = False, 183 | cross_frame_attention_mode = None, 184 | temporal_position_encoding = False, 185 | temporal_position_encoding_max_len = 24, 186 | ): 187 | super().__init__() 188 | 189 | attention_blocks = [] 190 | norms = [] 191 | 192 | for block_name in attention_block_types: 193 | attention_blocks.append( 194 | VersatileAttention( 195 | attention_mode=block_name.split("_")[0], 196 | cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, 197 | 198 | query_dim=dim, 199 | heads=num_attention_heads, 200 | dim_head=attention_head_dim, 201 | dropout=dropout, 202 | bias=attention_bias, 203 | upcast_attention=upcast_attention, 204 | 205 | cross_frame_attention_mode=cross_frame_attention_mode, 206 | temporal_position_encoding=temporal_position_encoding, 207 | temporal_position_encoding_max_len=temporal_position_encoding_max_len, 208 | ) 209 | ) 210 | norms.append(nn.LayerNorm(dim)) 211 | 212 | self.attention_blocks = nn.ModuleList(attention_blocks) 213 | self.norms = nn.ModuleList(norms) 214 | 215 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 216 | self.ff_norm = nn.LayerNorm(dim) 217 | 218 | 219 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 220 | for attention_block, norm in zip(self.attention_blocks, self.norms): 221 | norm_hidden_states = norm(hidden_states) 222 | hidden_states = attention_block( 223 | norm_hidden_states, 224 | encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, 225 | video_length=video_length, 226 | ) + hidden_states 227 | 228 | hidden_states = 
self.ff(self.ff_norm(hidden_states)) + hidden_states 229 | 230 | output = hidden_states 231 | return output 232 | 233 | 234 | class PositionalEncoding(nn.Module): 235 | def __init__( 236 | self, 237 | d_model, 238 | dropout = 0., 239 | max_len = 24 240 | ): 241 | super().__init__() 242 | self.dropout = nn.Dropout(p=dropout) 243 | position = torch.arange(max_len).unsqueeze(1) 244 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 245 | pe = torch.zeros(1, max_len, d_model) 246 | pe[0, :, 0::2] = torch.sin(position * div_term) 247 | pe[0, :, 1::2] = torch.cos(position * div_term) 248 | self.register_buffer('pe', pe) 249 | 250 | def forward(self, x): 251 | x = x + self.pe[:, :x.size(1)] 252 | return self.dropout(x) 253 | 254 | 255 | class VersatileAttention(Attention): 256 | def __init__( 257 | self, 258 | attention_mode = None, 259 | cross_frame_attention_mode = None, 260 | temporal_position_encoding = False, 261 | temporal_position_encoding_max_len = 24, 262 | *args, **kwargs 263 | ): 264 | super().__init__(*args, **kwargs) 265 | assert attention_mode == "Temporal" 266 | 267 | self.attention_mode = attention_mode 268 | self.is_cross_attention = kwargs["cross_attention_dim"] is not None 269 | 270 | self.pos_encoder = PositionalEncoding( 271 | kwargs["query_dim"], 272 | dropout=0., 273 | max_len=temporal_position_encoding_max_len 274 | ) if (temporal_position_encoding and attention_mode == "Temporal") else None 275 | 276 | def extra_repr(self): 277 | return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" 278 | 279 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 280 | batch_size, sequence_length, _ = hidden_states.shape 281 | 282 | if self.attention_mode == "Temporal": 283 | d = hidden_states.shape[1] 284 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 285 | 286 | if self.pos_encoder is not None: 287 | hidden_states = self.pos_encoder(hidden_states) 288 | 289 | encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states 290 | else: 291 | raise NotImplementedError 292 | 293 | encoder_hidden_states = encoder_hidden_states 294 | 295 | if version.parse(diffusers.__version__) > version.parse("0.11.1"): 296 | hidden_states = self.processor(self, hidden_states, encoder_hidden_states) 297 | else: 298 | if self.group_norm is not None: 299 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 300 | 301 | query = self.to_q(hidden_states) 302 | dim = query.shape[-1] 303 | query = self.head_to_batch_dim(query) 304 | 305 | if self.added_kv_proj_dim is not None: 306 | raise NotImplementedError 307 | 308 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 309 | key = self.to_k(encoder_hidden_states) 310 | value = self.to_v(encoder_hidden_states) 311 | 312 | key = self.head_to_batch_dim(key) 313 | value = self.head_to_batch_dim(value) 314 | 315 | if attention_mask is not None: 316 | if attention_mask.shape[-1] != query.shape[1]: 317 | target_length = query.shape[1] 318 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 319 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 320 | 321 | # attention, what we cannot get enough of 322 | 323 | if self._use_memory_efficient_attention_xformers: 324 | hidden_states = 
self._memory_efficient_attention_xformers(query, key, value, attention_mask) 325 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input 326 | hidden_states = hidden_states.to(query.dtype) 327 | else: 328 | if self._slice_size is None or query.shape[0] // self._slice_size == 1: 329 | hidden_states = self._attention(query, key, value, attention_mask) 330 | else: 331 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) 332 | else: 333 | #if "xformers" in self.processor.__class__.__name__.lower(): 334 | # hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attention_mask) 335 | # # Some versions of xformers return output in fp32, cast it back to the dtype of the input 336 | # hidden_states = hidden_states.to(query.dtype) 337 | #else: 338 | hidden_states = F.scaled_dot_product_attention( 339 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 340 | ) 341 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 342 | hidden_states = hidden_states.to(query.dtype) 343 | 344 | # linear proj 345 | hidden_states = self.to_out[0](hidden_states) 346 | 347 | # dropout 348 | hidden_states = self.to_out[1](hidden_states) 349 | 350 | if self.attention_mode == "Temporal": 351 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 352 | 353 | return hidden_states 354 | -------------------------------------------------------------------------------- /animatediff/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange 8 | 9 | 10 | class InflatedConv3d(nn.Conv2d): 11 | def forward(self, x): 12 | video_length = x.shape[2] 13 | 14 | x = rearrange(x, "b c f h w -> (b f) c h w") 15 | x = super().forward(x) 16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 17 | 18 | return x 19 | 20 | 21 | class InflatedGroupNorm(nn.GroupNorm): 22 | def forward(self, x): 23 | video_length = x.shape[2] 24 | 25 | x = rearrange(x, "b c f h w -> (b f) c h w") 26 | x = super().forward(x) 27 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 28 | 29 | return x 30 | 31 | 32 | class Upsample3D(nn.Module): 33 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 34 | super().__init__() 35 | self.channels = channels 36 | self.out_channels = out_channels or channels 37 | self.use_conv = use_conv 38 | self.use_conv_transpose = use_conv_transpose 39 | self.name = name 40 | 41 | conv = None 42 | if use_conv_transpose: 43 | raise NotImplementedError 44 | elif use_conv: 45 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 46 | 47 | def forward(self, hidden_states, output_size=None): 48 | assert hidden_states.shape[1] == self.channels 49 | 50 | if self.use_conv_transpose: 51 | raise NotImplementedError 52 | 53 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 54 | dtype = hidden_states.dtype 55 | if dtype == torch.bfloat16: 56 | hidden_states = hidden_states.to(torch.float32) 57 | 58 | # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 59 | if hidden_states.shape[0] >= 64: 60 | hidden_states = hidden_states.contiguous() 61 | 62 | # if `output_size` is passed we force the interpolation output 63 | # size and do not make use of `scale_factor=2` 64 | if output_size is None: 65 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 66 | else: 67 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 68 | 69 | # If the input is bfloat16, we cast back to bfloat16 70 | if dtype == torch.bfloat16: 71 | hidden_states = hidden_states.to(dtype) 72 | 73 | # if self.use_conv: 74 | # if self.name == "conv": 75 | # hidden_states = self.conv(hidden_states) 76 | # else: 77 | # hidden_states = self.Conv2d_0(hidden_states) 78 | hidden_states = self.conv(hidden_states) 79 | 80 | return hidden_states 81 | 82 | 83 | class Downsample3D(nn.Module): 84 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 85 | super().__init__() 86 | self.channels = channels 87 | self.out_channels = out_channels or channels 88 | self.use_conv = use_conv 89 | self.padding = padding 90 | stride = 2 91 | self.name = name 92 | 93 | if use_conv: 94 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 95 | else: 96 | raise NotImplementedError 97 | 98 | def forward(self, hidden_states): 99 | assert hidden_states.shape[1] == self.channels 100 | if self.use_conv and self.padding == 0: 101 | raise NotImplementedError 102 | 103 | assert hidden_states.shape[1] == self.channels 104 | hidden_states = self.conv(hidden_states) 105 | 106 | return hidden_states 107 | 108 | 109 | class ResnetBlock3D(nn.Module): 110 | def __init__( 111 | self, 112 | *, 113 | in_channels, 114 | out_channels=None, 115 | conv_shortcut=False, 116 | dropout=0.0, 117 | temb_channels=512, 118 | groups=32, 119 | groups_out=None, 120 | pre_norm=True, 121 | eps=1e-6, 122 | non_linearity="swish", 123 | time_embedding_norm="default", 124 | output_scale_factor=1.0, 125 | use_in_shortcut=None, 126 | use_inflated_groupnorm=False, 127 | ): 128 | super().__init__() 129 | self.pre_norm = pre_norm 130 | self.pre_norm = True 131 | self.in_channels = in_channels 132 | out_channels = in_channels if out_channels is None else out_channels 133 | self.out_channels = out_channels 134 | self.use_conv_shortcut = conv_shortcut 135 | self.time_embedding_norm = time_embedding_norm 136 | self.output_scale_factor = output_scale_factor 137 | 138 | if groups_out is None: 139 | groups_out = groups 140 | 141 | assert use_inflated_groupnorm != None 142 | if use_inflated_groupnorm: 143 | self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 144 | else: 145 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 146 | 147 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 148 | 149 | if temb_channels is not None: 150 | if self.time_embedding_norm == "default": 151 | time_emb_proj_out_channels = out_channels 152 | elif self.time_embedding_norm == "scale_shift": 153 | time_emb_proj_out_channels = out_channels * 2 154 | else: 155 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 156 | 157 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 158 | else: 159 | self.time_emb_proj = None 160 | 161 | if use_inflated_groupnorm: 162 | self.norm2 = 
InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 163 | else: 164 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 165 | 166 | self.dropout = torch.nn.Dropout(dropout) 167 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 168 | 169 | if non_linearity == "swish": 170 | self.nonlinearity = lambda x: F.silu(x) 171 | elif non_linearity == "mish": 172 | self.nonlinearity = Mish() 173 | elif non_linearity == "silu": 174 | self.nonlinearity = nn.SiLU() 175 | 176 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 177 | 178 | self.conv_shortcut = None 179 | if self.use_in_shortcut: 180 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 181 | 182 | def forward(self, input_tensor, temb): 183 | hidden_states = input_tensor 184 | 185 | hidden_states = self.norm1(hidden_states) 186 | hidden_states = self.nonlinearity(hidden_states) 187 | 188 | hidden_states = self.conv1(hidden_states) 189 | 190 | if temb is not None: 191 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 192 | 193 | if temb is not None and self.time_embedding_norm == "default": 194 | hidden_states = hidden_states + temb 195 | 196 | hidden_states = self.norm2(hidden_states) 197 | 198 | if temb is not None and self.time_embedding_norm == "scale_shift": 199 | scale, shift = torch.chunk(temb, 2, dim=1) 200 | hidden_states = hidden_states * (1 + scale) + shift 201 | 202 | hidden_states = self.nonlinearity(hidden_states) 203 | 204 | hidden_states = self.dropout(hidden_states) 205 | hidden_states = self.conv2(hidden_states) 206 | 207 | if self.conv_shortcut is not None: 208 | input_tensor = self.conv_shortcut(input_tensor) 209 | 210 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 211 | 212 | return output_tensor 213 | 214 | 215 | class Mish(torch.nn.Module): 216 | def forward(self, hidden_states): 217 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) -------------------------------------------------------------------------------- /animatediff/models/sparse_controlnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Changes were made to this source code by Yuwei Guo. 
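# Overview: this file defines a sparse ControlNet variant for the inflated 3D UNet used in this repo.
# SparseControlNetConditioningEmbedding encodes the conditioning frames, and SparseControlNetModel
# mirrors the UNet's down/mid blocks (optionally with temporal motion modules) and emits per-block
# residuals (SparseControlNetOutput). Weights can be initialized from an existing 2D UNet via
# SparseControlNetModel.from_unet().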
16 | from dataclasses import dataclass 17 | from typing import Any, Dict, List, Optional, Tuple, Union 18 | 19 | import torch 20 | from torch import nn 21 | from torch.nn import functional as F 22 | 23 | from diffusers.configuration_utils import ConfigMixin, register_to_config 24 | from diffusers.utils import BaseOutput, logging 25 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 26 | from diffusers import ModelMixin 27 | 28 | 29 | from .unet_blocks import ( 30 | CrossAttnDownBlock3D, 31 | DownBlock3D, 32 | UNetMidBlock3DCrossAttn, 33 | get_down_block, 34 | ) 35 | from einops import repeat, rearrange 36 | from .resnet import InflatedConv3d 37 | 38 | from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel 39 | 40 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 41 | 42 | 43 | @dataclass 44 | class SparseControlNetOutput(BaseOutput): 45 | down_block_res_samples: Tuple[torch.Tensor] 46 | mid_block_res_sample: torch.Tensor 47 | 48 | 49 | class SparseControlNetConditioningEmbedding(nn.Module): 50 | def __init__( 51 | self, 52 | conditioning_embedding_channels: int, 53 | conditioning_channels: int = 3, 54 | block_out_channels: Tuple[int] = (16, 32, 96, 256), 55 | ): 56 | super().__init__() 57 | 58 | self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) 59 | 60 | self.blocks = nn.ModuleList([]) 61 | 62 | for i in range(len(block_out_channels) - 1): 63 | channel_in = block_out_channels[i] 64 | channel_out = block_out_channels[i + 1] 65 | self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)) 66 | self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) 67 | 68 | self.conv_out = zero_module( 69 | InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) 70 | ) 71 | 72 | def forward(self, conditioning): 73 | embedding = self.conv_in(conditioning) 74 | embedding = F.silu(embedding) 75 | 76 | for block in self.blocks: 77 | embedding = block(embedding) 78 | embedding = F.silu(embedding) 79 | 80 | embedding = self.conv_out(embedding) 81 | 82 | return embedding 83 | 84 | 85 | class SparseControlNetModel(ModelMixin, ConfigMixin): 86 | _supports_gradient_checkpointing = True 87 | 88 | @register_to_config 89 | def __init__( 90 | self, 91 | in_channels: int = 4, 92 | conditioning_channels: int = 3, 93 | flip_sin_to_cos: bool = True, 94 | freq_shift: int = 0, 95 | down_block_types: Tuple[str] = ( 96 | "CrossAttnDownBlock2D", 97 | "CrossAttnDownBlock2D", 98 | "CrossAttnDownBlock2D", 99 | "DownBlock2D", 100 | ), 101 | only_cross_attention: Union[bool, Tuple[bool]] = False, 102 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 103 | layers_per_block: int = 2, 104 | downsample_padding: int = 1, 105 | mid_block_scale_factor: float = 1, 106 | act_fn: str = "silu", 107 | norm_num_groups: Optional[int] = 32, 108 | norm_eps: float = 1e-5, 109 | cross_attention_dim: int = 1280, 110 | attention_head_dim: Union[int, Tuple[int]] = 8, 111 | num_attention_heads: Optional[Union[int, Tuple[int]]] = None, 112 | use_linear_projection: bool = False, 113 | class_embed_type: Optional[str] = None, 114 | num_class_embeds: Optional[int] = None, 115 | upcast_attention: bool = False, 116 | resnet_time_scale_shift: str = "default", 117 | projection_class_embeddings_input_dim: Optional[int] = None, 118 | controlnet_conditioning_channel_order: str = "rgb", 119 | conditioning_embedding_out_channels: 
Optional[Tuple[int]] = (16, 32, 96, 256), 120 | global_pool_conditions: bool = False, 121 | 122 | use_motion_module = True, 123 | motion_module_resolutions = ( 1,2,4,8 ), 124 | motion_module_mid_block = False, 125 | motion_module_type = "Vanilla", 126 | motion_module_kwargs = { 127 | "num_attention_heads": 8, 128 | "num_transformer_block": 1, 129 | "attention_block_types": ["Temporal_Self"], 130 | "temporal_position_encoding": True, 131 | "temporal_position_encoding_max_len": 32, 132 | "temporal_attention_dim_div": 1, 133 | "causal_temporal_attention": False, 134 | }, 135 | 136 | concate_conditioning_mask: bool = True, 137 | use_simplified_condition_embedding: bool = False, 138 | 139 | set_noisy_sample_input_to_zero: bool = False, 140 | ): 141 | super().__init__() 142 | 143 | # If `num_attention_heads` is not defined (which is the case for most models) 144 | # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. 145 | # The reason for this behavior is to correct for incorrectly named variables that were introduced 146 | # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 147 | # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking 148 | # which is why we correct for the naming here. 149 | num_attention_heads = num_attention_heads or attention_head_dim 150 | 151 | # Check inputs 152 | if len(block_out_channels) != len(down_block_types): 153 | raise ValueError( 154 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 155 | ) 156 | 157 | if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): 158 | raise ValueError( 159 | f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." 160 | ) 161 | 162 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): 163 | raise ValueError( 164 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
165 | ) 166 | 167 | # input 168 | self.set_noisy_sample_input_to_zero = set_noisy_sample_input_to_zero 169 | 170 | conv_in_kernel = 3 171 | conv_in_padding = (conv_in_kernel - 1) // 2 172 | self.conv_in = InflatedConv3d( 173 | in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding 174 | ) 175 | 176 | if concate_conditioning_mask: 177 | conditioning_channels = conditioning_channels + 1 178 | self.concate_conditioning_mask = concate_conditioning_mask 179 | 180 | # control net conditioning embedding 181 | if use_simplified_condition_embedding: 182 | self.controlnet_cond_embedding = zero_module( 183 | InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding) 184 | ) 185 | else: 186 | self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding( 187 | conditioning_embedding_channels=block_out_channels[0], 188 | block_out_channels=conditioning_embedding_out_channels, 189 | conditioning_channels=conditioning_channels, 190 | ) 191 | self.use_simplified_condition_embedding = use_simplified_condition_embedding 192 | 193 | # time 194 | time_embed_dim = block_out_channels[0] * 4 195 | 196 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 197 | timestep_input_dim = block_out_channels[0] 198 | 199 | self.time_embedding = TimestepEmbedding( 200 | timestep_input_dim, 201 | time_embed_dim, 202 | act_fn=act_fn, 203 | ) 204 | 205 | # class embedding 206 | if class_embed_type is None and num_class_embeds is not None: 207 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 208 | elif class_embed_type == "timestep": 209 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 210 | elif class_embed_type == "identity": 211 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 212 | elif class_embed_type == "projection": 213 | if projection_class_embeddings_input_dim is None: 214 | raise ValueError( 215 | "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" 216 | ) 217 | # The projection `class_embed_type` is the same as the timestep `class_embed_type` except 218 | # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings 219 | # 2. it projects from an arbitrary input dimension. 220 | # 221 | # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. 222 | # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. 223 | # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
224 | self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 225 | else: 226 | self.class_embedding = None 227 | 228 | 229 | self.down_blocks = nn.ModuleList([]) 230 | self.controlnet_down_blocks = nn.ModuleList([]) 231 | 232 | if isinstance(only_cross_attention, bool): 233 | only_cross_attention = [only_cross_attention] * len(down_block_types) 234 | 235 | if isinstance(attention_head_dim, int): 236 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 237 | 238 | if isinstance(num_attention_heads, int): 239 | num_attention_heads = (num_attention_heads,) * len(down_block_types) 240 | 241 | # down 242 | output_channel = block_out_channels[0] 243 | 244 | controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) 245 | controlnet_block = zero_module(controlnet_block) 246 | self.controlnet_down_blocks.append(controlnet_block) 247 | 248 | for i, down_block_type in enumerate(down_block_types): 249 | res = 2 ** i 250 | input_channel = output_channel 251 | output_channel = block_out_channels[i] 252 | is_final_block = i == len(block_out_channels) - 1 253 | 254 | down_block = get_down_block( 255 | down_block_type, 256 | num_layers=layers_per_block, 257 | in_channels=input_channel, 258 | out_channels=output_channel, 259 | temb_channels=time_embed_dim, 260 | add_downsample=not is_final_block, 261 | resnet_eps=norm_eps, 262 | resnet_act_fn=act_fn, 263 | resnet_groups=norm_num_groups, 264 | cross_attention_dim=cross_attention_dim, 265 | attn_num_head_channels=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, 266 | downsample_padding=downsample_padding, 267 | use_linear_projection=use_linear_projection, 268 | only_cross_attention=only_cross_attention[i], 269 | upcast_attention=upcast_attention, 270 | resnet_time_scale_shift=resnet_time_scale_shift, 271 | 272 | use_inflated_groupnorm=True, 273 | 274 | use_motion_module=use_motion_module and (res in motion_module_resolutions), 275 | motion_module_type=motion_module_type, 276 | motion_module_kwargs=motion_module_kwargs, 277 | ) 278 | self.down_blocks.append(down_block) 279 | 280 | for _ in range(layers_per_block): 281 | controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) 282 | controlnet_block = zero_module(controlnet_block) 283 | self.controlnet_down_blocks.append(controlnet_block) 284 | 285 | if not is_final_block: 286 | controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1) 287 | controlnet_block = zero_module(controlnet_block) 288 | self.controlnet_down_blocks.append(controlnet_block) 289 | 290 | # mid 291 | mid_block_channel = block_out_channels[-1] 292 | 293 | controlnet_block = InflatedConv3d(mid_block_channel, mid_block_channel, kernel_size=1) 294 | controlnet_block = zero_module(controlnet_block) 295 | self.controlnet_mid_block = controlnet_block 296 | 297 | self.mid_block = UNetMidBlock3DCrossAttn( 298 | in_channels=mid_block_channel, 299 | temb_channels=time_embed_dim, 300 | resnet_eps=norm_eps, 301 | resnet_act_fn=act_fn, 302 | output_scale_factor=mid_block_scale_factor, 303 | resnet_time_scale_shift=resnet_time_scale_shift, 304 | cross_attention_dim=cross_attention_dim, 305 | attn_num_head_channels=num_attention_heads[-1], 306 | resnet_groups=norm_num_groups, 307 | use_linear_projection=use_linear_projection, 308 | upcast_attention=upcast_attention, 309 | 310 | use_inflated_groupnorm=True, 311 | use_motion_module=use_motion_module and motion_module_mid_block, 312 | 
motion_module_type=motion_module_type, 313 | motion_module_kwargs=motion_module_kwargs, 314 | ) 315 | 316 | @classmethod 317 | def from_unet( 318 | cls, 319 | unet: UNet2DConditionModel, 320 | controlnet_conditioning_channel_order: str = "rgb", 321 | conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), 322 | load_weights_from_unet: bool = True, 323 | 324 | controlnet_additional_kwargs: dict = {}, 325 | ): 326 | controlnet = cls( 327 | in_channels=unet.config.in_channels, 328 | flip_sin_to_cos=unet.config.flip_sin_to_cos, 329 | freq_shift=unet.config.freq_shift, 330 | down_block_types=unet.config.down_block_types, 331 | only_cross_attention=unet.config.only_cross_attention, 332 | block_out_channels=unet.config.block_out_channels, 333 | layers_per_block=unet.config.layers_per_block, 334 | downsample_padding=unet.config.downsample_padding, 335 | mid_block_scale_factor=unet.config.mid_block_scale_factor, 336 | act_fn=unet.config.act_fn, 337 | norm_num_groups=unet.config.norm_num_groups, 338 | norm_eps=unet.config.norm_eps, 339 | cross_attention_dim=unet.config.cross_attention_dim, 340 | attention_head_dim=unet.config.attention_head_dim, 341 | num_attention_heads=unet.config.num_attention_heads, 342 | use_linear_projection=unet.config.use_linear_projection, 343 | class_embed_type=unet.config.class_embed_type, 344 | num_class_embeds=unet.config.num_class_embeds, 345 | upcast_attention=unet.config.upcast_attention, 346 | resnet_time_scale_shift=unet.config.resnet_time_scale_shift, 347 | projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, 348 | controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, 349 | conditioning_embedding_out_channels=conditioning_embedding_out_channels, 350 | 351 | **controlnet_additional_kwargs, 352 | ) 353 | 354 | if load_weights_from_unet: 355 | m, u = controlnet.conv_in.load_state_dict(cls.image_layer_filter(unet.conv_in.state_dict()), strict=False) 356 | assert len(u) == 0 357 | m, u = controlnet.time_proj.load_state_dict(cls.image_layer_filter(unet.time_proj.state_dict()), strict=False) 358 | assert len(u) == 0 359 | m, u = controlnet.time_embedding.load_state_dict(cls.image_layer_filter(unet.time_embedding.state_dict()), strict=False) 360 | assert len(u) == 0 361 | 362 | if controlnet.class_embedding: 363 | m, u = controlnet.class_embedding.load_state_dict(cls.image_layer_filter(unet.class_embedding.state_dict()), strict=False) 364 | assert len(u) == 0 365 | m, u = controlnet.down_blocks.load_state_dict(cls.image_layer_filter(unet.down_blocks.state_dict()), strict=False) 366 | assert len(u) == 0 367 | m, u = controlnet.mid_block.load_state_dict(cls.image_layer_filter(unet.mid_block.state_dict()), strict=False) 368 | assert len(u) == 0 369 | 370 | return controlnet 371 | 372 | @staticmethod 373 | def image_layer_filter(state_dict): 374 | new_state_dict = {} 375 | for name, param in state_dict.items(): 376 | if "motion_modules." in name or "lora" in name: continue 377 | new_state_dict[name] = param 378 | return new_state_dict 379 | 380 | # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice 381 | def set_attention_slice(self, slice_size): 382 | r""" 383 | Enable sliced attention computation. 384 | 385 | When this option is enabled, the attention module splits the input tensor in slices to compute attention in 386 | several steps. This is useful for saving some memory in exchange for a small decrease in speed. 
387 | 388 | Args: 389 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 390 | When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If 391 | `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is 392 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 393 | must be a multiple of `slice_size`. 394 | """ 395 | sliceable_head_dims = [] 396 | 397 | def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): 398 | if hasattr(module, "set_attention_slice"): 399 | sliceable_head_dims.append(module.sliceable_head_dim) 400 | 401 | for child in module.children(): 402 | fn_recursive_retrieve_sliceable_dims(child) 403 | 404 | # retrieve number of attention layers 405 | for module in self.children(): 406 | fn_recursive_retrieve_sliceable_dims(module) 407 | 408 | num_sliceable_layers = len(sliceable_head_dims) 409 | 410 | if slice_size == "auto": 411 | # half the attention head size is usually a good trade-off between 412 | # speed and memory 413 | slice_size = [dim // 2 for dim in sliceable_head_dims] 414 | elif slice_size == "max": 415 | # make smallest slice possible 416 | slice_size = num_sliceable_layers * [1] 417 | 418 | slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 419 | 420 | if len(slice_size) != len(sliceable_head_dims): 421 | raise ValueError( 422 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 423 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 424 | ) 425 | 426 | for i in range(len(slice_size)): 427 | size = slice_size[i] 428 | dim = sliceable_head_dims[i] 429 | if size is not None and size > dim: 430 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 431 | 432 | # Recursively walk through all the children. 
433 | # Any children which exposes the set_attention_slice method 434 | # gets the message 435 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 436 | if hasattr(module, "set_attention_slice"): 437 | module.set_attention_slice(slice_size.pop()) 438 | 439 | for child in module.children(): 440 | fn_recursive_set_attention_slice(child, slice_size) 441 | 442 | reversed_slice_size = list(reversed(slice_size)) 443 | for module in self.children(): 444 | fn_recursive_set_attention_slice(module, reversed_slice_size) 445 | 446 | def _set_gradient_checkpointing(self, module, value=False): 447 | if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): 448 | module.gradient_checkpointing = value 449 | 450 | def forward( 451 | self, 452 | sample: torch.FloatTensor, 453 | timestep: Union[torch.Tensor, float, int], 454 | encoder_hidden_states: torch.Tensor, 455 | 456 | controlnet_cond: torch.FloatTensor, 457 | conditioning_mask: Optional[torch.FloatTensor] = None, 458 | 459 | conditioning_scale: float = 1.0, 460 | class_labels: Optional[torch.Tensor] = None, 461 | attention_mask: Optional[torch.Tensor] = None, 462 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 463 | guess_mode: bool = False, 464 | return_dict: bool = True, 465 | ) -> Union[SparseControlNetOutput, Tuple]: 466 | 467 | # set input noise to zero 468 | if self.set_noisy_sample_input_to_zero: 469 | sample = torch.zeros_like(sample).to(sample.device) 470 | 471 | # prepare attention_mask 472 | if attention_mask is not None: 473 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 474 | attention_mask = attention_mask.unsqueeze(1) 475 | 476 | # 1. time 477 | timesteps = timestep 478 | if not torch.is_tensor(timesteps): 479 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 480 | # This would be a good case for the `match` statement (Python 3.10+) 481 | is_mps = sample.device.type == "mps" 482 | if isinstance(timestep, float): 483 | dtype = torch.float32 if is_mps else torch.float64 484 | else: 485 | dtype = torch.int32 if is_mps else torch.int64 486 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 487 | elif len(timesteps.shape) == 0: 488 | timesteps = timesteps[None].to(sample.device) 489 | 490 | timesteps = timesteps.repeat(sample.shape[0] // timesteps.shape[0]) 491 | encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0] // encoder_hidden_states.shape[0], 1, 1) 492 | 493 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 494 | timesteps = timesteps.expand(sample.shape[0]) 495 | 496 | t_emb = self.time_proj(timesteps) 497 | 498 | # timesteps does not contain any weights and will always return f32 tensors 499 | # but time_embedding might actually be running in fp16. so we need to cast here. 500 | # there might be better ways to encapsulate this. 501 | t_emb = t_emb.to(dtype=self.dtype) 502 | emb = self.time_embedding(t_emb) 503 | 504 | if self.class_embedding is not None: 505 | if class_labels is None: 506 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 507 | 508 | if self.config.class_embed_type == "timestep": 509 | class_labels = self.time_proj(class_labels) 510 | 511 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 512 | emb = emb + class_emb 513 | 514 | # 2. 
pre-process 515 | sample = self.conv_in(sample) 516 | 517 | if self.concate_conditioning_mask: 518 | controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1) 519 | controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) 520 | 521 | sample = sample + controlnet_cond 522 | 523 | # 3. down 524 | down_block_res_samples = (sample,) 525 | for downsample_block in self.down_blocks: 526 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 527 | sample, res_samples = downsample_block( 528 | hidden_states=sample, 529 | temb=emb, 530 | encoder_hidden_states=encoder_hidden_states, 531 | attention_mask=attention_mask, 532 | # cross_attention_kwargs=cross_attention_kwargs, 533 | ) 534 | else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 535 | 536 | down_block_res_samples += res_samples 537 | 538 | # 4. mid 539 | if self.mid_block is not None: 540 | sample = self.mid_block( 541 | sample, 542 | emb, 543 | encoder_hidden_states=encoder_hidden_states, 544 | attention_mask=attention_mask, 545 | # cross_attention_kwargs=cross_attention_kwargs, 546 | ) 547 | 548 | # 5. controlnet blocks 549 | controlnet_down_block_res_samples = () 550 | 551 | for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): 552 | down_block_res_sample = controlnet_block(down_block_res_sample) 553 | controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) 554 | 555 | down_block_res_samples = controlnet_down_block_res_samples 556 | 557 | mid_block_res_sample = self.controlnet_mid_block(sample) 558 | 559 | # 6. scaling 560 | if guess_mode and not self.config.global_pool_conditions: 561 | scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 562 | 563 | scales = scales * conditioning_scale 564 | down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] 565 | mid_block_res_sample = mid_block_res_sample * scales[-1] # last one 566 | else: 567 | down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] 568 | mid_block_res_sample = mid_block_res_sample * conditioning_scale 569 | 570 | if self.config.global_pool_conditions: 571 | down_block_res_samples = [ 572 | torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples 573 | ] 574 | mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) 575 | 576 | if not return_dict: 577 | return (down_block_res_samples, mid_block_res_sample) 578 | 579 | return SparseControlNetOutput( 580 | down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample 581 | ) 582 | 583 | 584 | def zero_module(module): 585 | for p in module.parameters(): 586 | nn.init.zeros_(p) 587 | return module 588 | -------------------------------------------------------------------------------- /animatediff/models/unet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py 2 | 3 | from dataclasses import dataclass 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import os 7 | import json 8 | import pdb 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.utils.checkpoint 13 | 14 | from diffusers.configuration_utils import ConfigMixin, register_to_config 15 | from diffusers import ModelMixin 16 | from 
diffusers.utils import BaseOutput, logging 17 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 18 | from .unet_blocks import ( 19 | CrossAttnDownBlock3D, 20 | CrossAttnUpBlock3D, 21 | DownBlock3D, 22 | UNetMidBlock3DCrossAttn, 23 | UpBlock3D, 24 | get_down_block, 25 | get_up_block, 26 | ) 27 | from .resnet import InflatedConv3d, InflatedGroupNorm 28 | 29 | 30 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 31 | 32 | 33 | @dataclass 34 | class UNet3DConditionOutput(BaseOutput): 35 | sample: torch.FloatTensor 36 | 37 | 38 | class UNet3DConditionModel(ModelMixin, ConfigMixin): 39 | _supports_gradient_checkpointing = True 40 | 41 | @register_to_config 42 | def __init__( 43 | self, 44 | sample_size: Optional[int] = None, 45 | in_channels: int = 4, 46 | out_channels: int = 4, 47 | center_input_sample: bool = False, 48 | flip_sin_to_cos: bool = True, 49 | freq_shift: int = 0, 50 | down_block_types: Tuple[str] = ( 51 | "CrossAttnDownBlock3D", 52 | "CrossAttnDownBlock3D", 53 | "CrossAttnDownBlock3D", 54 | "DownBlock3D", 55 | ), 56 | mid_block_type: str = "UNetMidBlock3DCrossAttn", 57 | up_block_types: Tuple[str] = ( 58 | "UpBlock3D", 59 | "CrossAttnUpBlock3D", 60 | "CrossAttnUpBlock3D", 61 | "CrossAttnUpBlock3D" 62 | ), 63 | only_cross_attention: Union[bool, Tuple[bool]] = False, 64 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 65 | layers_per_block: int = 2, 66 | downsample_padding: int = 1, 67 | mid_block_scale_factor: float = 1, 68 | act_fn: str = "silu", 69 | norm_num_groups: int = 32, 70 | norm_eps: float = 1e-5, 71 | cross_attention_dim: int = 1280, 72 | attention_head_dim: Union[int, Tuple[int]] = 8, 73 | dual_cross_attention: bool = False, 74 | use_linear_projection: bool = False, 75 | class_embed_type: Optional[str] = None, 76 | num_class_embeds: Optional[int] = None, 77 | upcast_attention: bool = False, 78 | resnet_time_scale_shift: str = "default", 79 | 80 | use_inflated_groupnorm=False, 81 | 82 | # Additional 83 | use_motion_module = False, 84 | motion_module_resolutions = ( 1,2,4,8 ), 85 | motion_module_mid_block = False, 86 | motion_module_decoder_only = False, 87 | motion_module_type = None, 88 | motion_module_kwargs = {}, 89 | unet_use_cross_frame_attention = False, 90 | unet_use_temporal_attention = False, 91 | ): 92 | super().__init__() 93 | 94 | self.sample_size = sample_size 95 | time_embed_dim = block_out_channels[0] * 4 96 | 97 | # input 98 | self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) 99 | 100 | # time 101 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 102 | timestep_input_dim = block_out_channels[0] 103 | 104 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 105 | 106 | # class embedding 107 | if class_embed_type is None and num_class_embeds is not None: 108 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 109 | elif class_embed_type == "timestep": 110 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 111 | elif class_embed_type == "identity": 112 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 113 | else: 114 | self.class_embedding = None 115 | 116 | self.down_blocks = nn.ModuleList([]) 117 | self.mid_block = None 118 | self.up_blocks = nn.ModuleList([]) 119 | 120 | if isinstance(only_cross_attention, bool): 121 | only_cross_attention = [only_cross_attention] * len(down_block_types) 122 | 123 | if isinstance(attention_head_dim, 
int): 124 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 125 | 126 | # down 127 | output_channel = block_out_channels[0] 128 | for i, down_block_type in enumerate(down_block_types): 129 | res = 2 ** i 130 | input_channel = output_channel 131 | output_channel = block_out_channels[i] 132 | is_final_block = i == len(block_out_channels) - 1 133 | 134 | down_block = get_down_block( 135 | down_block_type, 136 | num_layers=layers_per_block, 137 | in_channels=input_channel, 138 | out_channels=output_channel, 139 | temb_channels=time_embed_dim, 140 | add_downsample=not is_final_block, 141 | resnet_eps=norm_eps, 142 | resnet_act_fn=act_fn, 143 | resnet_groups=norm_num_groups, 144 | cross_attention_dim=cross_attention_dim, 145 | attn_num_head_channels=attention_head_dim[i], 146 | downsample_padding=downsample_padding, 147 | dual_cross_attention=dual_cross_attention, 148 | use_linear_projection=use_linear_projection, 149 | only_cross_attention=only_cross_attention[i], 150 | upcast_attention=upcast_attention, 151 | resnet_time_scale_shift=resnet_time_scale_shift, 152 | 153 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 154 | unet_use_temporal_attention=unet_use_temporal_attention, 155 | use_inflated_groupnorm=use_inflated_groupnorm, 156 | 157 | use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), 158 | motion_module_type=motion_module_type, 159 | motion_module_kwargs=motion_module_kwargs, 160 | ) 161 | self.down_blocks.append(down_block) 162 | 163 | # mid 164 | if mid_block_type == "UNetMidBlock3DCrossAttn": 165 | self.mid_block = UNetMidBlock3DCrossAttn( 166 | in_channels=block_out_channels[-1], 167 | temb_channels=time_embed_dim, 168 | resnet_eps=norm_eps, 169 | resnet_act_fn=act_fn, 170 | output_scale_factor=mid_block_scale_factor, 171 | resnet_time_scale_shift=resnet_time_scale_shift, 172 | cross_attention_dim=cross_attention_dim, 173 | attn_num_head_channels=attention_head_dim[-1], 174 | resnet_groups=norm_num_groups, 175 | dual_cross_attention=dual_cross_attention, 176 | use_linear_projection=use_linear_projection, 177 | upcast_attention=upcast_attention, 178 | 179 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 180 | unet_use_temporal_attention=unet_use_temporal_attention, 181 | use_inflated_groupnorm=use_inflated_groupnorm, 182 | 183 | use_motion_module=use_motion_module and motion_module_mid_block, 184 | motion_module_type=motion_module_type, 185 | motion_module_kwargs=motion_module_kwargs, 186 | ) 187 | else: 188 | raise ValueError(f"unknown mid_block_type : {mid_block_type}") 189 | 190 | # count how many layers upsample the videos 191 | self.num_upsamplers = 0 192 | 193 | # up 194 | reversed_block_out_channels = list(reversed(block_out_channels)) 195 | reversed_attention_head_dim = list(reversed(attention_head_dim)) 196 | only_cross_attention = list(reversed(only_cross_attention)) 197 | output_channel = reversed_block_out_channels[0] 198 | for i, up_block_type in enumerate(up_block_types): 199 | res = 2 ** (3 - i) 200 | is_final_block = i == len(block_out_channels) - 1 201 | 202 | prev_output_channel = output_channel 203 | output_channel = reversed_block_out_channels[i] 204 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 205 | 206 | # add upsample block for all BUT final layer 207 | if not is_final_block: 208 | add_upsample = True 209 | self.num_upsamplers += 1 210 | else: 211 | add_upsample = False 212 | 213 | up_block = get_up_block( 
214 | up_block_type, 215 | num_layers=layers_per_block + 1, 216 | in_channels=input_channel, 217 | out_channels=output_channel, 218 | prev_output_channel=prev_output_channel, 219 | temb_channels=time_embed_dim, 220 | add_upsample=add_upsample, 221 | resnet_eps=norm_eps, 222 | resnet_act_fn=act_fn, 223 | resnet_groups=norm_num_groups, 224 | cross_attention_dim=cross_attention_dim, 225 | attn_num_head_channels=reversed_attention_head_dim[i], 226 | dual_cross_attention=dual_cross_attention, 227 | use_linear_projection=use_linear_projection, 228 | only_cross_attention=only_cross_attention[i], 229 | upcast_attention=upcast_attention, 230 | resnet_time_scale_shift=resnet_time_scale_shift, 231 | 232 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 233 | unet_use_temporal_attention=unet_use_temporal_attention, 234 | use_inflated_groupnorm=use_inflated_groupnorm, 235 | 236 | use_motion_module=use_motion_module and (res in motion_module_resolutions), 237 | motion_module_type=motion_module_type, 238 | motion_module_kwargs=motion_module_kwargs, 239 | ) 240 | self.up_blocks.append(up_block) 241 | prev_output_channel = output_channel 242 | 243 | # out 244 | if use_inflated_groupnorm: 245 | self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) 246 | else: 247 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) 248 | self.conv_act = nn.SiLU() 249 | self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) 250 | 251 | def set_attention_slice(self, slice_size): 252 | r""" 253 | Enable sliced attention computation. 254 | 255 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 256 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 257 | 258 | Args: 259 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 260 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If 261 | `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is 262 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 263 | must be a multiple of `slice_size`. 264 | """ 265 | sliceable_head_dims = [] 266 | 267 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): 268 | if hasattr(module, "set_attention_slice"): 269 | sliceable_head_dims.append(module.sliceable_head_dim) 270 | 271 | for child in module.children(): 272 | fn_recursive_retrieve_slicable_dims(child) 273 | 274 | # retrieve number of attention layers 275 | for module in self.children(): 276 | fn_recursive_retrieve_slicable_dims(module) 277 | 278 | num_slicable_layers = len(sliceable_head_dims) 279 | 280 | if slice_size == "auto": 281 | # half the attention head size is usually a good trade-off between 282 | # speed and memory 283 | slice_size = [dim // 2 for dim in sliceable_head_dims] 284 | elif slice_size == "max": 285 | # make smallest slice possible 286 | slice_size = num_slicable_layers * [1] 287 | 288 | slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 289 | 290 | if len(slice_size) != len(sliceable_head_dims): 291 | raise ValueError( 292 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 293 | f" attention layers. 
Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 294 | ) 295 | 296 | for i in range(len(slice_size)): 297 | size = slice_size[i] 298 | dim = sliceable_head_dims[i] 299 | if size is not None and size > dim: 300 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 301 | 302 | # Recursively walk through all the children. 303 | # Any children which exposes the set_attention_slice method 304 | # gets the message 305 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 306 | if hasattr(module, "set_attention_slice"): 307 | module.set_attention_slice(slice_size.pop()) 308 | 309 | for child in module.children(): 310 | fn_recursive_set_attention_slice(child, slice_size) 311 | 312 | reversed_slice_size = list(reversed(slice_size)) 313 | for module in self.children(): 314 | fn_recursive_set_attention_slice(module, reversed_slice_size) 315 | 316 | def _set_gradient_checkpointing(self, module, value=False): 317 | if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): 318 | module.gradient_checkpointing = value 319 | 320 | def forward( 321 | self, 322 | sample: torch.FloatTensor, 323 | timestep: Union[torch.Tensor, float, int], 324 | encoder_hidden_states: torch.Tensor, 325 | class_labels: Optional[torch.Tensor] = None, 326 | attention_mask: Optional[torch.Tensor] = None, 327 | 328 | # support controlnet 329 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 330 | mid_block_additional_residual: Optional[torch.Tensor] = None, 331 | 332 | return_dict: bool = True, 333 | ) -> Union[UNet3DConditionOutput, Tuple]: 334 | r""" 335 | Args: 336 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 337 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 338 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 339 | return_dict (`bool`, *optional*, defaults to `True`): 340 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 341 | 342 | Returns: 343 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 344 | [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When 345 | returning a tuple, the first element is the sample tensor. 346 | """ 347 | # By default samples have to be AT least a multiple of the overall upsampling factor. 348 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 349 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 350 | # on the fly if necessary. 
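        # Worked example (illustrative): with the default four up blocks, every block except
        # the last adds an upsampler, so self.num_upsamplers == 3 and
        # default_overall_up_factor == 2**3 == 8. A 64x64 latent passes the divisibility
        # check below as-is, while e.g. a 60x60 latent sets forward_upsample_size = True.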
351 | default_overall_up_factor = 2**self.num_upsamplers 352 | 353 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 354 | forward_upsample_size = False 355 | upsample_size = None 356 | 357 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 358 | logger.info("Forward upsample size to force interpolation output size.") 359 | forward_upsample_size = True 360 | 361 | # prepare attention_mask 362 | if attention_mask is not None: 363 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 364 | attention_mask = attention_mask.unsqueeze(1) 365 | 366 | # center input if necessary 367 | if self.config.center_input_sample: 368 | sample = 2 * sample - 1.0 369 | 370 | # time 371 | timesteps = timestep 372 | if not torch.is_tensor(timesteps): 373 | # This would be a good case for the `match` statement (Python 3.10+) 374 | is_mps = sample.device.type == "mps" 375 | if isinstance(timestep, float): 376 | dtype = torch.float32 if is_mps else torch.float64 377 | else: 378 | dtype = torch.int32 if is_mps else torch.int64 379 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 380 | elif len(timesteps.shape) == 0: 381 | timesteps = timesteps[None].to(sample.device) 382 | 383 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 384 | timesteps = timesteps.expand(sample.shape[0]) 385 | 386 | t_emb = self.time_proj(timesteps) 387 | 388 | # timesteps does not contain any weights and will always return f32 tensors 389 | # but time_embedding might actually be running in fp16. so we need to cast here. 390 | # there might be better ways to encapsulate this. 391 | t_emb = t_emb.to(dtype=self.dtype) 392 | emb = self.time_embedding(t_emb) 393 | 394 | if self.class_embedding is not None: 395 | if class_labels is None: 396 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 397 | 398 | if self.config.class_embed_type == "timestep": 399 | class_labels = self.time_proj(class_labels) 400 | 401 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 402 | emb = emb + class_emb 403 | 404 | # pre-process 405 | sample = self.conv_in(sample) 406 | 407 | # down 408 | down_block_res_samples = (sample,) 409 | for downsample_block in self.down_blocks: 410 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 411 | sample, res_samples = downsample_block( 412 | hidden_states=sample, 413 | temb=emb, 414 | encoder_hidden_states=encoder_hidden_states, 415 | attention_mask=attention_mask, 416 | ) 417 | else: 418 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) 419 | 420 | down_block_res_samples += res_samples 421 | 422 | # support controlnet 423 | down_block_res_samples = list(down_block_res_samples) 424 | if down_block_additional_residuals is not None: 425 | for i, down_block_additional_residual in enumerate(down_block_additional_residuals): 426 | if down_block_additional_residual.dim() == 4: # boardcast 427 | down_block_additional_residual = down_block_additional_residual.unsqueeze(2) 428 | down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual 429 | 430 | # mid 431 | sample = self.mid_block( 432 | sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 433 | ) 434 | 435 | # support controlnet 436 | if mid_block_additional_residual is not None: 437 | if mid_block_additional_residual.dim() == 4: 
# boardcast 438 | mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2) 439 | sample = sample + mid_block_additional_residual 440 | 441 | # up 442 | for i, upsample_block in enumerate(self.up_blocks): 443 | is_final_block = i == len(self.up_blocks) - 1 444 | 445 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 446 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 447 | 448 | # if we have not reached the final block and need to forward the 449 | # upsample size, we do it here 450 | if not is_final_block and forward_upsample_size: 451 | upsample_size = down_block_res_samples[-1].shape[2:] 452 | 453 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 454 | sample = upsample_block( 455 | hidden_states=sample, 456 | temb=emb, 457 | res_hidden_states_tuple=res_samples, 458 | encoder_hidden_states=encoder_hidden_states, 459 | upsample_size=upsample_size, 460 | attention_mask=attention_mask, 461 | ) 462 | else: 463 | sample = upsample_block( 464 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, 465 | ) 466 | 467 | # post-process 468 | sample = self.conv_norm_out(sample) 469 | sample = self.conv_act(sample) 470 | sample = self.conv_out(sample) 471 | 472 | if not return_dict: 473 | return (sample,) 474 | 475 | return UNet3DConditionOutput(sample=sample) 476 | 477 | @classmethod 478 | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): 479 | if subfolder is not None: 480 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder) 481 | print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...") 482 | 483 | config_file = os.path.join(pretrained_model_path, 'config.json') 484 | if not os.path.isfile(config_file): 485 | raise RuntimeError(f"{config_file} does not exist") 486 | with open(config_file, "r") as f: 487 | config = json.load(f) 488 | config["_class_name"] = cls.__name__ 489 | config["down_block_types"] = [ 490 | "CrossAttnDownBlock3D", 491 | "CrossAttnDownBlock3D", 492 | "CrossAttnDownBlock3D", 493 | "DownBlock3D" 494 | ] 495 | config["up_block_types"] = [ 496 | "UpBlock3D", 497 | "CrossAttnUpBlock3D", 498 | "CrossAttnUpBlock3D", 499 | "CrossAttnUpBlock3D" 500 | ] 501 | 502 | from diffusers.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME 503 | model = cls.from_config(config, **unet_additional_kwargs) 504 | 505 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) 506 | model_file_safe = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME) 507 | 508 | if os.path.isfile(model_file_safe): 509 | model_file = model_file_safe 510 | 511 | if not os.path.isfile(model_file): 512 | raise RuntimeError(f"{model_file} does not exist") 513 | 514 | if SAFETENSORS_WEIGHTS_NAME in model_file: 515 | from safetensors.torch import load_file 516 | state_dict = load_file(model_file) 517 | else: 518 | state_dict = torch.load(model_file, map_location="cpu") 519 | 520 | m, u = model.load_state_dict(state_dict, strict=False) 521 | print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") 522 | 523 | params = [p.numel() if "motion_modules." 
in n else 0 for n, p in model.named_parameters()] 524 | print(f"### Motion Module Parameters: {sum(params) / 1e6} M") 525 | 526 | return model 527 | -------------------------------------------------------------------------------- /animatediff/models/unet_blocks.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from .attention import Transformer3DModel 7 | from .resnet import Downsample3D, ResnetBlock3D, Upsample3D 8 | from .motion_module import get_motion_module 9 | 10 | import pdb 11 | 12 | def checkpoint_no_reentrant(*args, **kwargs): 13 | kwargs['use_reentrant'] = False 14 | return torch.utils.checkpoint.checkpoint(*args, **kwargs) 15 | 16 | def get_down_block( 17 | down_block_type, 18 | num_layers, 19 | in_channels, 20 | out_channels, 21 | temb_channels, 22 | add_downsample, 23 | resnet_eps, 24 | resnet_act_fn, 25 | attn_num_head_channels, 26 | resnet_groups=None, 27 | cross_attention_dim=None, 28 | downsample_padding=None, 29 | dual_cross_attention=False, 30 | use_linear_projection=False, 31 | only_cross_attention=False, 32 | upcast_attention=False, 33 | resnet_time_scale_shift="default", 34 | 35 | unet_use_cross_frame_attention=False, 36 | unet_use_temporal_attention=False, 37 | use_inflated_groupnorm=False, 38 | 39 | use_motion_module=None, 40 | 41 | motion_module_type=None, 42 | motion_module_kwargs=None, 43 | ): 44 | down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type 45 | if down_block_type == "DownBlock3D": 46 | return DownBlock3D( 47 | num_layers=num_layers, 48 | in_channels=in_channels, 49 | out_channels=out_channels, 50 | temb_channels=temb_channels, 51 | add_downsample=add_downsample, 52 | resnet_eps=resnet_eps, 53 | resnet_act_fn=resnet_act_fn, 54 | resnet_groups=resnet_groups, 55 | downsample_padding=downsample_padding, 56 | resnet_time_scale_shift=resnet_time_scale_shift, 57 | 58 | use_inflated_groupnorm=use_inflated_groupnorm, 59 | 60 | use_motion_module=use_motion_module, 61 | motion_module_type=motion_module_type, 62 | motion_module_kwargs=motion_module_kwargs, 63 | ) 64 | elif down_block_type == "CrossAttnDownBlock3D": 65 | if cross_attention_dim is None: 66 | raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") 67 | return CrossAttnDownBlock3D( 68 | num_layers=num_layers, 69 | in_channels=in_channels, 70 | out_channels=out_channels, 71 | temb_channels=temb_channels, 72 | add_downsample=add_downsample, 73 | resnet_eps=resnet_eps, 74 | resnet_act_fn=resnet_act_fn, 75 | resnet_groups=resnet_groups, 76 | downsample_padding=downsample_padding, 77 | cross_attention_dim=cross_attention_dim, 78 | attn_num_head_channels=attn_num_head_channels, 79 | dual_cross_attention=dual_cross_attention, 80 | use_linear_projection=use_linear_projection, 81 | only_cross_attention=only_cross_attention, 82 | upcast_attention=upcast_attention, 83 | resnet_time_scale_shift=resnet_time_scale_shift, 84 | 85 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 86 | unet_use_temporal_attention=unet_use_temporal_attention, 87 | use_inflated_groupnorm=use_inflated_groupnorm, 88 | 89 | use_motion_module=use_motion_module, 90 | motion_module_type=motion_module_type, 91 | motion_module_kwargs=motion_module_kwargs, 92 | ) 93 | raise ValueError(f"{down_block_type} does not exist.") 94 | 95 | 96 | def get_up_block( 97 | 
up_block_type, 98 | num_layers, 99 | in_channels, 100 | out_channels, 101 | prev_output_channel, 102 | temb_channels, 103 | add_upsample, 104 | resnet_eps, 105 | resnet_act_fn, 106 | attn_num_head_channels, 107 | resnet_groups=None, 108 | cross_attention_dim=None, 109 | dual_cross_attention=False, 110 | use_linear_projection=False, 111 | only_cross_attention=False, 112 | upcast_attention=False, 113 | resnet_time_scale_shift="default", 114 | 115 | unet_use_cross_frame_attention=False, 116 | unet_use_temporal_attention=False, 117 | use_inflated_groupnorm=False, 118 | 119 | use_motion_module=None, 120 | motion_module_type=None, 121 | motion_module_kwargs=None, 122 | ): 123 | up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type 124 | if up_block_type == "UpBlock3D": 125 | return UpBlock3D( 126 | num_layers=num_layers, 127 | in_channels=in_channels, 128 | out_channels=out_channels, 129 | prev_output_channel=prev_output_channel, 130 | temb_channels=temb_channels, 131 | add_upsample=add_upsample, 132 | resnet_eps=resnet_eps, 133 | resnet_act_fn=resnet_act_fn, 134 | resnet_groups=resnet_groups, 135 | resnet_time_scale_shift=resnet_time_scale_shift, 136 | 137 | use_inflated_groupnorm=use_inflated_groupnorm, 138 | 139 | use_motion_module=use_motion_module, 140 | motion_module_type=motion_module_type, 141 | motion_module_kwargs=motion_module_kwargs, 142 | ) 143 | elif up_block_type == "CrossAttnUpBlock3D": 144 | if cross_attention_dim is None: 145 | raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") 146 | return CrossAttnUpBlock3D( 147 | num_layers=num_layers, 148 | in_channels=in_channels, 149 | out_channels=out_channels, 150 | prev_output_channel=prev_output_channel, 151 | temb_channels=temb_channels, 152 | add_upsample=add_upsample, 153 | resnet_eps=resnet_eps, 154 | resnet_act_fn=resnet_act_fn, 155 | resnet_groups=resnet_groups, 156 | cross_attention_dim=cross_attention_dim, 157 | attn_num_head_channels=attn_num_head_channels, 158 | dual_cross_attention=dual_cross_attention, 159 | use_linear_projection=use_linear_projection, 160 | only_cross_attention=only_cross_attention, 161 | upcast_attention=upcast_attention, 162 | resnet_time_scale_shift=resnet_time_scale_shift, 163 | 164 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 165 | unet_use_temporal_attention=unet_use_temporal_attention, 166 | use_inflated_groupnorm=use_inflated_groupnorm, 167 | 168 | use_motion_module=use_motion_module, 169 | motion_module_type=motion_module_type, 170 | motion_module_kwargs=motion_module_kwargs, 171 | ) 172 | raise ValueError(f"{up_block_type} does not exist.") 173 | 174 | 175 | class UNetMidBlock3DCrossAttn(nn.Module): 176 | def __init__( 177 | self, 178 | in_channels: int, 179 | temb_channels: int, 180 | dropout: float = 0.0, 181 | num_layers: int = 1, 182 | resnet_eps: float = 1e-6, 183 | resnet_time_scale_shift: str = "default", 184 | resnet_act_fn: str = "swish", 185 | resnet_groups: int = 32, 186 | resnet_pre_norm: bool = True, 187 | attn_num_head_channels=1, 188 | output_scale_factor=1.0, 189 | cross_attention_dim=1280, 190 | dual_cross_attention=False, 191 | use_linear_projection=False, 192 | upcast_attention=False, 193 | 194 | unet_use_cross_frame_attention=False, 195 | unet_use_temporal_attention=False, 196 | use_inflated_groupnorm=False, 197 | 198 | use_motion_module=None, 199 | 200 | motion_module_type=None, 201 | motion_module_kwargs=None, 202 | ): 203 | super().__init__() 204 | 205 | self.has_cross_attention = True 
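        # Block layout (descriptive comment): resnets[0] always runs first; each of the
        # num_layers iterations then appends one Transformer3DModel, one optional motion
        # module, and one more ResnetBlock3D, and forward() interleaves them in the order
        # attention -> motion module -> resnet.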
206 | self.attn_num_head_channels = attn_num_head_channels 207 | resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) 208 | 209 | # there is always at least one resnet 210 | resnets = [ 211 | ResnetBlock3D( 212 | in_channels=in_channels, 213 | out_channels=in_channels, 214 | temb_channels=temb_channels, 215 | eps=resnet_eps, 216 | groups=resnet_groups, 217 | dropout=dropout, 218 | time_embedding_norm=resnet_time_scale_shift, 219 | non_linearity=resnet_act_fn, 220 | output_scale_factor=output_scale_factor, 221 | pre_norm=resnet_pre_norm, 222 | 223 | use_inflated_groupnorm=use_inflated_groupnorm, 224 | ) 225 | ] 226 | attentions = [] 227 | motion_modules = [] 228 | 229 | for _ in range(num_layers): 230 | if dual_cross_attention: 231 | raise NotImplementedError 232 | attentions.append( 233 | Transformer3DModel( 234 | attn_num_head_channels, 235 | in_channels // attn_num_head_channels, 236 | in_channels=in_channels, 237 | num_layers=1, 238 | cross_attention_dim=cross_attention_dim, 239 | norm_num_groups=resnet_groups, 240 | use_linear_projection=use_linear_projection, 241 | upcast_attention=upcast_attention, 242 | 243 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 244 | unet_use_temporal_attention=unet_use_temporal_attention, 245 | ) 246 | ) 247 | motion_modules.append( 248 | get_motion_module( 249 | in_channels=in_channels, 250 | motion_module_type=motion_module_type, 251 | motion_module_kwargs=motion_module_kwargs, 252 | ) if use_motion_module else None 253 | ) 254 | resnets.append( 255 | ResnetBlock3D( 256 | in_channels=in_channels, 257 | out_channels=in_channels, 258 | temb_channels=temb_channels, 259 | eps=resnet_eps, 260 | groups=resnet_groups, 261 | dropout=dropout, 262 | time_embedding_norm=resnet_time_scale_shift, 263 | non_linearity=resnet_act_fn, 264 | output_scale_factor=output_scale_factor, 265 | pre_norm=resnet_pre_norm, 266 | 267 | use_inflated_groupnorm=use_inflated_groupnorm, 268 | ) 269 | ) 270 | 271 | self.attentions = nn.ModuleList(attentions) 272 | self.resnets = nn.ModuleList(resnets) 273 | self.motion_modules = nn.ModuleList(motion_modules) 274 | 275 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): 276 | hidden_states = self.resnets[0](hidden_states, temb) 277 | for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules): 278 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample 279 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states 280 | hidden_states = resnet(hidden_states, temb) 281 | 282 | return hidden_states 283 | 284 | 285 | class CrossAttnDownBlock3D(nn.Module): 286 | def __init__( 287 | self, 288 | in_channels: int, 289 | out_channels: int, 290 | temb_channels: int, 291 | dropout: float = 0.0, 292 | num_layers: int = 1, 293 | resnet_eps: float = 1e-6, 294 | resnet_time_scale_shift: str = "default", 295 | resnet_act_fn: str = "swish", 296 | resnet_groups: int = 32, 297 | resnet_pre_norm: bool = True, 298 | attn_num_head_channels=1, 299 | cross_attention_dim=1280, 300 | output_scale_factor=1.0, 301 | downsample_padding=1, 302 | add_downsample=True, 303 | dual_cross_attention=False, 304 | use_linear_projection=False, 305 | only_cross_attention=False, 306 | upcast_attention=False, 307 | 308 | unet_use_cross_frame_attention=False, 309 | unet_use_temporal_attention=False, 310 | 
use_inflated_groupnorm=False, 311 | 312 | use_motion_module=None, 313 | 314 | motion_module_type=None, 315 | motion_module_kwargs=None, 316 | ): 317 | super().__init__() 318 | resnets = [] 319 | attentions = [] 320 | motion_modules = [] 321 | 322 | self.has_cross_attention = True 323 | self.attn_num_head_channels = attn_num_head_channels 324 | 325 | for i in range(num_layers): 326 | in_channels = in_channels if i == 0 else out_channels 327 | resnets.append( 328 | ResnetBlock3D( 329 | in_channels=in_channels, 330 | out_channels=out_channels, 331 | temb_channels=temb_channels, 332 | eps=resnet_eps, 333 | groups=resnet_groups, 334 | dropout=dropout, 335 | time_embedding_norm=resnet_time_scale_shift, 336 | non_linearity=resnet_act_fn, 337 | output_scale_factor=output_scale_factor, 338 | pre_norm=resnet_pre_norm, 339 | 340 | use_inflated_groupnorm=use_inflated_groupnorm, 341 | ) 342 | ) 343 | if dual_cross_attention: 344 | raise NotImplementedError 345 | attentions.append( 346 | Transformer3DModel( 347 | attn_num_head_channels, 348 | out_channels // attn_num_head_channels, 349 | in_channels=out_channels, 350 | num_layers=1, 351 | cross_attention_dim=cross_attention_dim, 352 | norm_num_groups=resnet_groups, 353 | use_linear_projection=use_linear_projection, 354 | only_cross_attention=only_cross_attention, 355 | upcast_attention=upcast_attention, 356 | 357 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 358 | unet_use_temporal_attention=unet_use_temporal_attention, 359 | ) 360 | ) 361 | motion_modules.append( 362 | get_motion_module( 363 | in_channels=out_channels, 364 | motion_module_type=motion_module_type, 365 | motion_module_kwargs=motion_module_kwargs, 366 | ) if use_motion_module else None 367 | ) 368 | 369 | self.attentions = nn.ModuleList(attentions) 370 | self.resnets = nn.ModuleList(resnets) 371 | self.motion_modules = nn.ModuleList(motion_modules) 372 | 373 | if add_downsample: 374 | self.downsamplers = nn.ModuleList( 375 | [ 376 | Downsample3D( 377 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 378 | ) 379 | ] 380 | ) 381 | else: 382 | self.downsamplers = None 383 | 384 | self.gradient_checkpointing = False 385 | 386 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): 387 | output_states = () 388 | 389 | for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): 390 | if self.training and self.gradient_checkpointing: 391 | 392 | def create_custom_forward(module, return_dict=None): 393 | def custom_forward(*inputs): 394 | if return_dict is not None: 395 | return module(*inputs, return_dict=return_dict) 396 | else: 397 | return module(*inputs) 398 | 399 | return custom_forward 400 | 401 | hidden_states = checkpoint_no_reentrant(create_custom_forward(resnet), hidden_states, temb) 402 | hidden_states = checkpoint_no_reentrant( 403 | create_custom_forward(attn, return_dict=False), 404 | hidden_states, 405 | encoder_hidden_states, 406 | )[0] 407 | if motion_module is not None: 408 | hidden_states = checkpoint_no_reentrant(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) 409 | 410 | else: 411 | hidden_states = resnet(hidden_states, temb) 412 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample 413 | 414 | # add motion module 415 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else 
hidden_states 416 | 417 | output_states += (hidden_states,) 418 | 419 | if self.downsamplers is not None: 420 | for downsampler in self.downsamplers: 421 | hidden_states = downsampler(hidden_states) 422 | 423 | output_states += (hidden_states,) 424 | 425 | return hidden_states, output_states 426 | 427 | 428 | class DownBlock3D(nn.Module): 429 | def __init__( 430 | self, 431 | in_channels: int, 432 | out_channels: int, 433 | temb_channels: int, 434 | dropout: float = 0.0, 435 | num_layers: int = 1, 436 | resnet_eps: float = 1e-6, 437 | resnet_time_scale_shift: str = "default", 438 | resnet_act_fn: str = "swish", 439 | resnet_groups: int = 32, 440 | resnet_pre_norm: bool = True, 441 | output_scale_factor=1.0, 442 | add_downsample=True, 443 | downsample_padding=1, 444 | 445 | use_inflated_groupnorm=False, 446 | 447 | use_motion_module=None, 448 | motion_module_type=None, 449 | motion_module_kwargs=None, 450 | ): 451 | super().__init__() 452 | resnets = [] 453 | motion_modules = [] 454 | 455 | for i in range(num_layers): 456 | in_channels = in_channels if i == 0 else out_channels 457 | resnets.append( 458 | ResnetBlock3D( 459 | in_channels=in_channels, 460 | out_channels=out_channels, 461 | temb_channels=temb_channels, 462 | eps=resnet_eps, 463 | groups=resnet_groups, 464 | dropout=dropout, 465 | time_embedding_norm=resnet_time_scale_shift, 466 | non_linearity=resnet_act_fn, 467 | output_scale_factor=output_scale_factor, 468 | pre_norm=resnet_pre_norm, 469 | 470 | use_inflated_groupnorm=use_inflated_groupnorm, 471 | ) 472 | ) 473 | motion_modules.append( 474 | get_motion_module( 475 | in_channels=out_channels, 476 | motion_module_type=motion_module_type, 477 | motion_module_kwargs=motion_module_kwargs, 478 | ) if use_motion_module else None 479 | ) 480 | 481 | self.resnets = nn.ModuleList(resnets) 482 | self.motion_modules = nn.ModuleList(motion_modules) 483 | 484 | if add_downsample: 485 | self.downsamplers = nn.ModuleList( 486 | [ 487 | Downsample3D( 488 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 489 | ) 490 | ] 491 | ) 492 | else: 493 | self.downsamplers = None 494 | 495 | self.gradient_checkpointing = False 496 | 497 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None): 498 | output_states = () 499 | 500 | for resnet, motion_module in zip(self.resnets, self.motion_modules): 501 | if self.training and self.gradient_checkpointing: 502 | def create_custom_forward(module): 503 | def custom_forward(*inputs): 504 | return module(*inputs) 505 | 506 | return custom_forward 507 | 508 | hidden_states = checkpoint_no_reentrant(create_custom_forward(resnet), hidden_states, temb) 509 | if motion_module is not None: 510 | hidden_states = checkpoint_no_reentrant(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) 511 | else: 512 | hidden_states = resnet(hidden_states, temb) 513 | 514 | # add motion module 515 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states 516 | 517 | output_states += (hidden_states,) 518 | 519 | if self.downsamplers is not None: 520 | for downsampler in self.downsamplers: 521 | hidden_states = downsampler(hidden_states) 522 | 523 | output_states += (hidden_states,) 524 | 525 | return hidden_states, output_states 526 | 527 | 528 | class CrossAttnUpBlock3D(nn.Module): 529 | def __init__( 530 | self, 531 | in_channels: int, 532 | out_channels: int, 533 | prev_output_channel: 
int, 534 | temb_channels: int, 535 | dropout: float = 0.0, 536 | num_layers: int = 1, 537 | resnet_eps: float = 1e-6, 538 | resnet_time_scale_shift: str = "default", 539 | resnet_act_fn: str = "swish", 540 | resnet_groups: int = 32, 541 | resnet_pre_norm: bool = True, 542 | attn_num_head_channels=1, 543 | cross_attention_dim=1280, 544 | output_scale_factor=1.0, 545 | add_upsample=True, 546 | dual_cross_attention=False, 547 | use_linear_projection=False, 548 | only_cross_attention=False, 549 | upcast_attention=False, 550 | 551 | unet_use_cross_frame_attention=False, 552 | unet_use_temporal_attention=False, 553 | use_inflated_groupnorm=False, 554 | 555 | use_motion_module=None, 556 | 557 | motion_module_type=None, 558 | motion_module_kwargs=None, 559 | ): 560 | super().__init__() 561 | resnets = [] 562 | attentions = [] 563 | motion_modules = [] 564 | 565 | self.has_cross_attention = True 566 | self.attn_num_head_channels = attn_num_head_channels 567 | 568 | for i in range(num_layers): 569 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 570 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 571 | 572 | resnets.append( 573 | ResnetBlock3D( 574 | in_channels=resnet_in_channels + res_skip_channels, 575 | out_channels=out_channels, 576 | temb_channels=temb_channels, 577 | eps=resnet_eps, 578 | groups=resnet_groups, 579 | dropout=dropout, 580 | time_embedding_norm=resnet_time_scale_shift, 581 | non_linearity=resnet_act_fn, 582 | output_scale_factor=output_scale_factor, 583 | pre_norm=resnet_pre_norm, 584 | 585 | use_inflated_groupnorm=use_inflated_groupnorm, 586 | ) 587 | ) 588 | if dual_cross_attention: 589 | raise NotImplementedError 590 | attentions.append( 591 | Transformer3DModel( 592 | attn_num_head_channels, 593 | out_channels // attn_num_head_channels, 594 | in_channels=out_channels, 595 | num_layers=1, 596 | cross_attention_dim=cross_attention_dim, 597 | norm_num_groups=resnet_groups, 598 | use_linear_projection=use_linear_projection, 599 | only_cross_attention=only_cross_attention, 600 | upcast_attention=upcast_attention, 601 | 602 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 603 | unet_use_temporal_attention=unet_use_temporal_attention, 604 | ) 605 | ) 606 | motion_modules.append( 607 | get_motion_module( 608 | in_channels=out_channels, 609 | motion_module_type=motion_module_type, 610 | motion_module_kwargs=motion_module_kwargs, 611 | ) if use_motion_module else None 612 | ) 613 | 614 | self.attentions = nn.ModuleList(attentions) 615 | self.resnets = nn.ModuleList(resnets) 616 | self.motion_modules = nn.ModuleList(motion_modules) 617 | 618 | if add_upsample: 619 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) 620 | else: 621 | self.upsamplers = None 622 | 623 | self.gradient_checkpointing = False 624 | 625 | def forward( 626 | self, 627 | hidden_states, 628 | res_hidden_states_tuple, 629 | temb=None, 630 | encoder_hidden_states=None, 631 | upsample_size=None, 632 | attention_mask=None, 633 | ): 634 | for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): 635 | # pop res hidden states 636 | res_hidden_states = res_hidden_states_tuple[-1] 637 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 638 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 639 | 640 | if self.training and self.gradient_checkpointing: 641 | 642 | def create_custom_forward(module, return_dict=None): 643 | def 
custom_forward(*inputs): 644 | if return_dict is not None: 645 | return module(*inputs, return_dict=return_dict) 646 | else: 647 | return module(*inputs) 648 | 649 | return custom_forward 650 | 651 | hidden_states = checkpoint_no_reentrant(create_custom_forward(resnet), hidden_states, temb) 652 | hidden_states = checkpoint_no_reentrant( 653 | create_custom_forward(attn, return_dict=False), 654 | hidden_states, 655 | encoder_hidden_states, 656 | )[0] 657 | if motion_module is not None: 658 | hidden_states = checkpoint_no_reentrant(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) 659 | 660 | else: 661 | hidden_states = resnet(hidden_states, temb) 662 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample 663 | 664 | # add motion module 665 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states 666 | 667 | if self.upsamplers is not None: 668 | for upsampler in self.upsamplers: 669 | hidden_states = upsampler(hidden_states, upsample_size) 670 | 671 | return hidden_states 672 | 673 | 674 | class UpBlock3D(nn.Module): 675 | def __init__( 676 | self, 677 | in_channels: int, 678 | prev_output_channel: int, 679 | out_channels: int, 680 | temb_channels: int, 681 | dropout: float = 0.0, 682 | num_layers: int = 1, 683 | resnet_eps: float = 1e-6, 684 | resnet_time_scale_shift: str = "default", 685 | resnet_act_fn: str = "swish", 686 | resnet_groups: int = 32, 687 | resnet_pre_norm: bool = True, 688 | output_scale_factor=1.0, 689 | add_upsample=True, 690 | 691 | use_inflated_groupnorm=False, 692 | 693 | use_motion_module=None, 694 | motion_module_type=None, 695 | motion_module_kwargs=None, 696 | ): 697 | super().__init__() 698 | resnets = [] 699 | motion_modules = [] 700 | 701 | for i in range(num_layers): 702 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 703 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 704 | 705 | resnets.append( 706 | ResnetBlock3D( 707 | in_channels=resnet_in_channels + res_skip_channels, 708 | out_channels=out_channels, 709 | temb_channels=temb_channels, 710 | eps=resnet_eps, 711 | groups=resnet_groups, 712 | dropout=dropout, 713 | time_embedding_norm=resnet_time_scale_shift, 714 | non_linearity=resnet_act_fn, 715 | output_scale_factor=output_scale_factor, 716 | pre_norm=resnet_pre_norm, 717 | 718 | use_inflated_groupnorm=use_inflated_groupnorm, 719 | ) 720 | ) 721 | motion_modules.append( 722 | get_motion_module( 723 | in_channels=out_channels, 724 | motion_module_type=motion_module_type, 725 | motion_module_kwargs=motion_module_kwargs, 726 | ) if use_motion_module else None 727 | ) 728 | 729 | self.resnets = nn.ModuleList(resnets) 730 | self.motion_modules = nn.ModuleList(motion_modules) 731 | 732 | if add_upsample: 733 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) 734 | else: 735 | self.upsamplers = None 736 | 737 | self.gradient_checkpointing = False 738 | 739 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,): 740 | for resnet, motion_module in zip(self.resnets, self.motion_modules): 741 | # pop res hidden states 742 | res_hidden_states = res_hidden_states_tuple[-1] 743 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 744 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 745 | 746 | if 
self.training and self.gradient_checkpointing: 747 | def create_custom_forward(module): 748 | def custom_forward(*inputs): 749 | return module(*inputs) 750 | 751 | return custom_forward 752 | 753 | hidden_states = checkpoint_no_reentrant(create_custom_forward(resnet), hidden_states, temb) 754 | if motion_module is not None: 755 | hidden_states = checkpoint_no_reentrant(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) 756 | else: 757 | hidden_states = resnet(hidden_states, temb) 758 | hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states 759 | 760 | if self.upsamplers is not None: 761 | for upsampler in self.upsamplers: 762 | hidden_states = upsampler(hidden_states, upsample_size) 763 | 764 | return hidden_states 765 | -------------------------------------------------------------------------------- /configs/ad_unet_config.yaml: -------------------------------------------------------------------------------- 1 | sample_size: 64 2 | in_channels: 4 3 | out_channels: 4 4 | center_input_sample: false 5 | flip_sin_to_cos: true 6 | freq_shift: 0 7 | down_block_types: 8 | - CrossAttnDownBlock3D 9 | - CrossAttnDownBlock3D 10 | - CrossAttnDownBlock3D 11 | - DownBlock3D 12 | mid_block_type: UNetMidBlock3DCrossAttn 13 | up_block_types: 14 | - UpBlock3D 15 | - CrossAttnUpBlock3D 16 | - CrossAttnUpBlock3D 17 | - CrossAttnUpBlock3D 18 | only_cross_attention: false 19 | block_out_channels: 20 | - 320 21 | - 640 22 | - 1280 23 | - 1280 24 | layers_per_block: 2 25 | downsample_padding: 1 26 | mid_block_scale_factor: 1 27 | act_fn: silu 28 | norm_num_groups: 32 29 | norm_eps: 1e-05 30 | cross_attention_dim: 768 31 | attention_head_dim: 8 32 | dual_cross_attention: false 33 | use_linear_projection: false 34 | class_embed_type: null 35 | num_class_embeds: null 36 | upcast_attention: false 37 | resnet_time_scale_shift: default 38 | use_inflated_groupnorm: true 39 | use_motion_module: true 40 | motion_module_resolutions: 41 | - 1 42 | - 2 43 | - 4 44 | - 8 45 | motion_module_mid_block: false 46 | motion_module_decoder_only: false 47 | motion_module_type: Vanilla 48 | motion_module_kwargs: 49 | num_attention_heads: 8 50 | num_transformer_block: 1 51 | attention_block_types: 52 | - Temporal_Self 53 | - Temporal_Self 54 | temporal_position_encoding: true 55 | temporal_position_encoding_max_len: 32 56 | temporal_attention_dim_div: 1 57 | zero_initialize: true 58 | unet_use_cross_frame_attention: false 59 | unet_use_temporal_attention: false 60 | _use_default_values: 61 | - resnet_time_scale_shift 62 | - only_cross_attention 63 | - mid_block_type 64 | - unet_use_cross_frame_attention 65 | - class_embed_type 66 | - unet_use_temporal_attention 67 | - dual_cross_attention 68 | - num_class_embeds 69 | - upcast_attention 70 | - use_linear_projection 71 | - motion_module_decoder_only 72 | _class_name: UNet3DConditionModel 73 | _diffusers_version: '0.6.0' -------------------------------------------------------------------------------- /configs/text_encoder_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "openai/clip-vit-large-patch14", 3 | "architectures": [ 4 | "CLIPTextModel" 5 | ], 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 0, 8 | "dropout": 0.0, 9 | "eos_token_id": 2, 10 | "hidden_act": "quick_gelu", 11 | "hidden_size": 768, 12 | "initializer_factor": 1.0, 13 | "initializer_range": 0.02, 14 | 
"intermediate_size": 3072, 15 | "layer_norm_eps": 1e-05, 16 | "max_position_embeddings": 77, 17 | "model_type": "clip_text_model", 18 | "num_attention_heads": 12, 19 | "num_hidden_layers": 12, 20 | "pad_token_id": 1, 21 | "projection_dim": 768, 22 | "torch_dtype": "float32", 23 | "transformers_version": "4.22.0.dev0", 24 | "vocab_size": 49408 25 | } 26 | -------------------------------------------------------------------------------- /configs/tokenizer/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "clip-vit-large-patch14/", 3 | "architectures": [ 4 | "CLIPModel" 5 | ], 6 | "initializer_factor": 1.0, 7 | "logit_scale_init_value": 2.6592, 8 | "model_type": "clip", 9 | "projection_dim": 768, 10 | "text_config": { 11 | "_name_or_path": "", 12 | "add_cross_attention": false, 13 | "architectures": null, 14 | "attention_dropout": 0.0, 15 | "bad_words_ids": null, 16 | "bos_token_id": 0, 17 | "chunk_size_feed_forward": 0, 18 | "cross_attention_hidden_size": null, 19 | "decoder_start_token_id": null, 20 | "diversity_penalty": 0.0, 21 | "do_sample": false, 22 | "dropout": 0.0, 23 | "early_stopping": false, 24 | "encoder_no_repeat_ngram_size": 0, 25 | "eos_token_id": 2, 26 | "finetuning_task": null, 27 | "forced_bos_token_id": null, 28 | "forced_eos_token_id": null, 29 | "hidden_act": "quick_gelu", 30 | "hidden_size": 768, 31 | "id2label": { 32 | "0": "LABEL_0", 33 | "1": "LABEL_1" 34 | }, 35 | "initializer_factor": 1.0, 36 | "initializer_range": 0.02, 37 | "intermediate_size": 3072, 38 | "is_decoder": false, 39 | "is_encoder_decoder": false, 40 | "label2id": { 41 | "LABEL_0": 0, 42 | "LABEL_1": 1 43 | }, 44 | "layer_norm_eps": 1e-05, 45 | "length_penalty": 1.0, 46 | "max_length": 20, 47 | "max_position_embeddings": 77, 48 | "min_length": 0, 49 | "model_type": "clip_text_model", 50 | "no_repeat_ngram_size": 0, 51 | "num_attention_heads": 12, 52 | "num_beam_groups": 1, 53 | "num_beams": 1, 54 | "num_hidden_layers": 12, 55 | "num_return_sequences": 1, 56 | "output_attentions": false, 57 | "output_hidden_states": false, 58 | "output_scores": false, 59 | "pad_token_id": 1, 60 | "prefix": null, 61 | "problem_type": null, 62 | "projection_dim" : 768, 63 | "pruned_heads": {}, 64 | "remove_invalid_values": false, 65 | "repetition_penalty": 1.0, 66 | "return_dict": true, 67 | "return_dict_in_generate": false, 68 | "sep_token_id": null, 69 | "task_specific_params": null, 70 | "temperature": 1.0, 71 | "tie_encoder_decoder": false, 72 | "tie_word_embeddings": true, 73 | "tokenizer_class": null, 74 | "top_k": 50, 75 | "top_p": 1.0, 76 | "torch_dtype": null, 77 | "torchscript": false, 78 | "transformers_version": "4.16.0.dev0", 79 | "use_bfloat16": false, 80 | "vocab_size": 49408 81 | }, 82 | "text_config_dict": { 83 | "hidden_size": 768, 84 | "intermediate_size": 3072, 85 | "num_attention_heads": 12, 86 | "num_hidden_layers": 12, 87 | "projection_dim": 768 88 | }, 89 | "torch_dtype": "float32", 90 | "transformers_version": null, 91 | "vision_config": { 92 | "_name_or_path": "", 93 | "add_cross_attention": false, 94 | "architectures": null, 95 | "attention_dropout": 0.0, 96 | "bad_words_ids": null, 97 | "bos_token_id": null, 98 | "chunk_size_feed_forward": 0, 99 | "cross_attention_hidden_size": null, 100 | "decoder_start_token_id": null, 101 | "diversity_penalty": 0.0, 102 | "do_sample": false, 103 | "dropout": 0.0, 104 | "early_stopping": false, 105 | "encoder_no_repeat_ngram_size": 0, 106 | "eos_token_id": null, 107 | "finetuning_task": 
null, 108 | "forced_bos_token_id": null, 109 | "forced_eos_token_id": null, 110 | "hidden_act": "quick_gelu", 111 | "hidden_size": 1024, 112 | "id2label": { 113 | "0": "LABEL_0", 114 | "1": "LABEL_1" 115 | }, 116 | "image_size": 224, 117 | "initializer_factor": 1.0, 118 | "initializer_range": 0.02, 119 | "intermediate_size": 4096, 120 | "is_decoder": false, 121 | "is_encoder_decoder": false, 122 | "label2id": { 123 | "LABEL_0": 0, 124 | "LABEL_1": 1 125 | }, 126 | "layer_norm_eps": 1e-05, 127 | "length_penalty": 1.0, 128 | "max_length": 20, 129 | "min_length": 0, 130 | "model_type": "clip_vision_model", 131 | "no_repeat_ngram_size": 0, 132 | "num_attention_heads": 16, 133 | "num_beam_groups": 1, 134 | "num_beams": 1, 135 | "num_hidden_layers": 24, 136 | "num_return_sequences": 1, 137 | "output_attentions": false, 138 | "output_hidden_states": false, 139 | "output_scores": false, 140 | "pad_token_id": null, 141 | "patch_size": 14, 142 | "prefix": null, 143 | "problem_type": null, 144 | "projection_dim" : 768, 145 | "pruned_heads": {}, 146 | "remove_invalid_values": false, 147 | "repetition_penalty": 1.0, 148 | "return_dict": true, 149 | "return_dict_in_generate": false, 150 | "sep_token_id": null, 151 | "task_specific_params": null, 152 | "temperature": 1.0, 153 | "tie_encoder_decoder": false, 154 | "tie_word_embeddings": true, 155 | "tokenizer_class": null, 156 | "top_k": 50, 157 | "top_p": 1.0, 158 | "torch_dtype": null, 159 | "torchscript": false, 160 | "transformers_version": "4.16.0.dev0", 161 | "use_bfloat16": false 162 | }, 163 | "vision_config_dict": { 164 | "hidden_size": 1024, 165 | "intermediate_size": 4096, 166 | "num_attention_heads": 16, 167 | "num_hidden_layers": 24, 168 | "patch_size": 14, 169 | "projection_dim": 768 170 | } 171 | } 172 | -------------------------------------------------------------------------------- /configs/tokenizer/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "crop_size": 224, 3 | "do_center_crop": true, 4 | "do_normalize": true, 5 | "do_resize": true, 6 | "feature_extractor_type": "CLIPFeatureExtractor", 7 | "image_mean": [ 8 | 0.48145466, 9 | 0.4578275, 10 | 0.40821073 11 | ], 12 | "image_std": [ 13 | 0.26862954, 14 | 0.26130258, 15 | 0.27577711 16 | ], 17 | "resample": 3, 18 | "size": 224 19 | } 20 | -------------------------------------------------------------------------------- /configs/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /configs/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "unk_token": { 3 | "content": "<|endoftext|>", 4 | "single_word": false, 5 | "lstrip": false, 6 | "rstrip": false, 7 | "normalized": true, 8 | "__type": "AddedToken" 9 | }, 10 | "bos_token": { 11 | "content": "<|startoftext|>", 12 | "single_word": false, 13 | "lstrip": false, 14 | "rstrip": false, 15 | "normalized": true, 16 | "__type": "AddedToken" 17 | }, 18 | "eos_token": { 19 | 
"content": "<|endoftext|>", 20 | "single_word": false, 21 | "lstrip": false, 22 | "rstrip": false, 23 | "normalized": true, 24 | "__type": "AddedToken" 25 | }, 26 | "pad_token": "<|endoftext|>", 27 | "add_prefix_space": false, 28 | "errors": "replace", 29 | "do_lower_case": true, 30 | "name_or_path": "openai/clip-vit-base-patch32", 31 | "model_max_length": 77, 32 | "special_tokens_map_file": "./special_tokens_map.json", 33 | "tokenizer_class": "CLIPTokenizer" 34 | } 35 | -------------------------------------------------------------------------------- /configs/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "bos_token": { 4 | "__type": "AddedToken", 5 | "content": "<|startoftext|>", 6 | "lstrip": false, 7 | "normalized": true, 8 | "rstrip": false, 9 | "single_word": false 10 | }, 11 | "do_lower_case": true, 12 | "eos_token": { 13 | "__type": "AddedToken", 14 | "content": "<|endoftext|>", 15 | "lstrip": false, 16 | "normalized": true, 17 | "rstrip": false, 18 | "single_word": false 19 | }, 20 | "errors": "replace", 21 | "model_max_length": 77, 22 | "name_or_path": "openai/clip-vit-large-patch14", 23 | "pad_token": "<|endoftext|>", 24 | "special_tokens_map_file": "./special_tokens_map.json", 25 | "tokenizer_class": "CLIPTokenizer", 26 | "unk_token": { 27 | "__type": "AddedToken", 28 | "content": "<|endoftext|>", 29 | "lstrip": false, 30 | "normalized": true, 31 | "rstrip": false, 32 | "single_word": false 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /configs/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /examples/magictime_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 11, 3 | "last_link_id": 14, 4 | "nodes": [ 5 | { 6 | "id": 5, 7 | "type": "ADE_LoadAnimateDiffModel", 8 | "pos": [ 9 | 475, 10 | 535 11 | ], 12 | "size": { 13 | "0": 347.2771911621094, 14 | "1": 58 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "ad_settings", 22 | "type": "AD_SETTINGS", 23 | "link": null 24 | } 25 | ], 26 | "outputs": [ 27 | { 28 | "name": "MOTION_MODEL", 29 | "type": "MOTION_MODEL_ADE", 30 | "links": [ 31 | 6 32 | ], 33 | "shape": 3 34 | } 35 | ], 36 | "properties": { 37 | "Node name for S&R": "ADE_LoadAnimateDiffModel" 38 | }, 39 | "widgets_values": [ 40 | "v3_sd15_mm.ckpt" 41 | ] 42 | }, 43 | { 44 | "id": 3, 45 | "type": "CheckpointLoaderSimple", 46 | "pos": [ 47 | 429, 48 | 378 49 | ], 50 | "size": [ 51 | 412.6001037597656, 52 | 98 53 | ], 54 | "flags": {}, 55 | "order": 1, 56 | "mode": 0, 57 | "outputs": [ 58 | { 59 | "name": "MODEL", 60 | "type": "MODEL", 61 | "links": [ 62 | 10 63 | ], 64 | "shape": 3, 65 | "slot_index": 0 66 | }, 67 | { 68 | "name": "CLIP", 69 | "type": "CLIP", 70 | "links": [ 71 | 3 72 | ], 73 | "shape": 3, 74 | "slot_index": 1 75 | }, 76 | { 77 | "name": "VAE", 78 | "type": "VAE", 79 | "links": [ 80 | 4 81 | ], 82 | "shape": 3 83 | } 84 | ], 85 | "properties": { 86 | "Node name for S&R": "CheckpointLoaderSimple" 87 | }, 88 | "widgets_values": [ 89 | "1_5/photon_v1.safetensors" 90 | ] 91 | }, 92 | { 93 | "id": 1, 94 | "type": "magictime_sampler", 95 | "pos": [ 96 | 1203, 97 | 378 98 | ], 99 | "size": { 100 | "0": 392.1629943847656, 101 | "1": 418.54119873046875 102 | }, 103 | "flags": {}, 104 | "order": 3, 105 | "mode": 0, 106 | "inputs": [ 107 | { 108 | "name": "magictime_model", 109 | "type": "MAGICTIME", 110 | "link": 1, 111 | "slot_index": 0 112 | } 113 | ], 114 | "outputs": [ 115 | { 116 | "name": "images", 117 | "type": "IMAGE", 118 | "links": [ 119 | 7, 120 | 11 121 | ], 122 | "shape": 3, 123 | "slot_index": 0 124 | } 125 | ], 126 | "properties": { 127 | "Node name for S&R": "magictime_sampler" 128 | }, 129 | "widgets_values": [ 130 | "Dough starts smooth, swells and browns in the oven, finishing as fully expanded, baked bread.", 131 | "bad quality, worse quality, blurry, nsfw", 132 | 16, 133 | 512, 134 | 512, 135 | 20, 136 | 7, 137 | 763609894569169, 138 | "fixed", 139 | "DPMSolverMultistepScheduler" 140 | ] 141 | }, 142 | { 143 | 
"id": 2, 144 | "type": "magictime_model_loader", 145 | "pos": [ 146 | 912, 147 | 378 148 | ], 149 | "size": { 150 | "0": 236.8000030517578, 151 | "1": 86 152 | }, 153 | "flags": {}, 154 | "order": 2, 155 | "mode": 0, 156 | "inputs": [ 157 | { 158 | "name": "model", 159 | "type": "MODEL", 160 | "link": 10, 161 | "slot_index": 0 162 | }, 163 | { 164 | "name": "clip", 165 | "type": "CLIP", 166 | "link": 3 167 | }, 168 | { 169 | "name": "vae", 170 | "type": "VAE", 171 | "link": 4, 172 | "slot_index": 2 173 | }, 174 | { 175 | "name": "motion_model", 176 | "type": "MOTION_MODEL_ADE", 177 | "link": 6, 178 | "slot_index": 3 179 | } 180 | ], 181 | "outputs": [ 182 | { 183 | "name": "magictime_model", 184 | "type": "MAGICTIME", 185 | "links": [ 186 | 1 187 | ], 188 | "shape": 3, 189 | "slot_index": 0 190 | } 191 | ], 192 | "properties": { 193 | "Node name for S&R": "magictime_model_loader" 194 | } 195 | }, 196 | { 197 | "id": 6, 198 | "type": "VHS_VideoCombine", 199 | "pos": [ 200 | 1682, 201 | 74 202 | ], 203 | "size": [ 204 | 459.2300720214844, 205 | 743.2300720214844 206 | ], 207 | "flags": {}, 208 | "order": 4, 209 | "mode": 0, 210 | "inputs": [ 211 | { 212 | "name": "images", 213 | "type": "IMAGE", 214 | "link": 7 215 | }, 216 | { 217 | "name": "audio", 218 | "type": "VHS_AUDIO", 219 | "link": null 220 | }, 221 | { 222 | "name": "batch_manager", 223 | "type": "VHS_BatchManager", 224 | "link": null 225 | } 226 | ], 227 | "outputs": [ 228 | { 229 | "name": "Filenames", 230 | "type": "VHS_FILENAMES", 231 | "links": null, 232 | "shape": 3 233 | } 234 | ], 235 | "properties": { 236 | "Node name for S&R": "VHS_VideoCombine" 237 | }, 238 | "widgets_values": { 239 | "frame_rate": 8, 240 | "loop_count": 0, 241 | "filename_prefix": "MagicTime", 242 | "format": "video/h264-mp4", 243 | "pix_fmt": "yuv420p", 244 | "crf": 19, 245 | "save_metadata": true, 246 | "pingpong": false, 247 | "save_output": false, 248 | "videopreview": { 249 | "hidden": false, 250 | "paused": false, 251 | "params": { 252 | "filename": "MagicTime_00002.mp4", 253 | "subfolder": "", 254 | "type": "temp", 255 | "format": "video/h264-mp4" 256 | } 257 | } 258 | } 259 | }, 260 | { 261 | "id": 8, 262 | "type": "RIFE VFI", 263 | "pos": [ 264 | 1691, 265 | 888 266 | ], 267 | "size": { 268 | "0": 443.4000244140625, 269 | "1": 198 270 | }, 271 | "flags": {}, 272 | "order": 5, 273 | "mode": 0, 274 | "inputs": [ 275 | { 276 | "name": "frames", 277 | "type": "IMAGE", 278 | "link": 11 279 | }, 280 | { 281 | "name": "optional_interpolation_states", 282 | "type": "INTERPOLATION_STATES", 283 | "link": null 284 | } 285 | ], 286 | "outputs": [ 287 | { 288 | "name": "IMAGE", 289 | "type": "IMAGE", 290 | "links": [ 291 | 12 292 | ], 293 | "shape": 3, 294 | "slot_index": 0 295 | } 296 | ], 297 | "properties": { 298 | "Node name for S&R": "RIFE VFI" 299 | }, 300 | "widgets_values": [ 301 | "rife49.pth", 302 | 10, 303 | 3, 304 | true, 305 | true, 306 | 1 307 | ] 308 | }, 309 | { 310 | "id": 9, 311 | "type": "VHS_VideoCombine", 312 | "pos": [ 313 | 2200, 314 | 75 315 | ], 316 | "size": [ 317 | 459.2300720214844, 318 | 743.2300720214844 319 | ], 320 | "flags": {}, 321 | "order": 6, 322 | "mode": 0, 323 | "inputs": [ 324 | { 325 | "name": "images", 326 | "type": "IMAGE", 327 | "link": 12 328 | }, 329 | { 330 | "name": "audio", 331 | "type": "VHS_AUDIO", 332 | "link": null 333 | }, 334 | { 335 | "name": "batch_manager", 336 | "type": "VHS_BatchManager", 337 | "link": null 338 | } 339 | ], 340 | "outputs": [ 341 | { 342 | "name": "Filenames", 343 | "type": 
"VHS_FILENAMES", 344 | "links": null, 345 | "shape": 3 346 | } 347 | ], 348 | "properties": { 349 | "Node name for S&R": "VHS_VideoCombine" 350 | }, 351 | "widgets_values": { 352 | "frame_rate": 24, 353 | "loop_count": 0, 354 | "filename_prefix": "MagitTimeInterpolated", 355 | "format": "video/h264-mp4", 356 | "pix_fmt": "yuv420p", 357 | "crf": 19, 358 | "save_metadata": true, 359 | "pingpong": false, 360 | "save_output": false, 361 | "videopreview": { 362 | "hidden": false, 363 | "paused": false, 364 | "params": { 365 | "filename": "MagitTimeInterpolated_00001.mp4", 366 | "subfolder": "", 367 | "type": "temp", 368 | "format": "video/h264-mp4" 369 | } 370 | } 371 | } 372 | } 373 | ], 374 | "links": [ 375 | [ 376 | 1, 377 | 2, 378 | 0, 379 | 1, 380 | 0, 381 | "MAGICTIME" 382 | ], 383 | [ 384 | 3, 385 | 3, 386 | 1, 387 | 2, 388 | 1, 389 | "CLIP" 390 | ], 391 | [ 392 | 4, 393 | 3, 394 | 2, 395 | 2, 396 | 2, 397 | "VAE" 398 | ], 399 | [ 400 | 6, 401 | 5, 402 | 0, 403 | 2, 404 | 3, 405 | "MOTION_MODEL_ADE" 406 | ], 407 | [ 408 | 7, 409 | 1, 410 | 0, 411 | 6, 412 | 0, 413 | "IMAGE" 414 | ], 415 | [ 416 | 10, 417 | 3, 418 | 0, 419 | 2, 420 | 0, 421 | "MODEL" 422 | ], 423 | [ 424 | 11, 425 | 1, 426 | 0, 427 | 8, 428 | 0, 429 | "IMAGE" 430 | ], 431 | [ 432 | 12, 433 | 8, 434 | 0, 435 | 9, 436 | 0, 437 | "IMAGE" 438 | ] 439 | ], 440 | "groups": [], 441 | "config": {}, 442 | "extra": {}, 443 | "version": 0.4 444 | } -------------------------------------------------------------------------------- /examples/magictime_example.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-MagicTimeWrapper/010eef4f4ba1235e2e65dbc6cd95b1f06d760493/examples/magictime_example.mp4 -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import nullcontext 3 | import torch 4 | 5 | try: 6 | from diffusers import ( 7 | DDIMScheduler, 8 | DPMSolverMultistepScheduler, 9 | EulerDiscreteScheduler, 10 | EulerAncestralDiscreteScheduler, 11 | AutoencoderKL, 12 | LCMScheduler, 13 | DDPMScheduler, 14 | DEISMultistepScheduler, 15 | PNDMScheduler 16 | ) 17 | from diffusers.loaders.single_file_utils import ( 18 | convert_ldm_vae_checkpoint, 19 | convert_ldm_unet_checkpoint, 20 | create_text_encoder_from_ldm_clip_checkpoint, 21 | create_vae_diffusers_config, 22 | create_unet_diffusers_config 23 | ) 24 | except: 25 | raise ImportError("Diffusers version too old. 
Please update to 0.26.0 minimum.") 26 | 27 | from omegaconf import OmegaConf 28 | 29 | from transformers import CLIPTokenizer 30 | from .animatediff.models.unet import UNet3DConditionModel 31 | from .utils.pipeline_magictime import MagicTimePipeline 32 | from .utils.util import load_diffusers_lora_unet 33 | 34 | import comfy.model_management as mm 35 | import comfy.utils 36 | import folder_paths 37 | 38 | script_directory = os.path.dirname(os.path.abspath(__file__)) 39 | 40 | class magictime_model_loader: 41 | @classmethod 42 | def INPUT_TYPES(s): 43 | return {"required": { 44 | "model": ("MODEL",), 45 | "clip": ("CLIP",), 46 | "vae": ("VAE",), 47 | "motion_model":("MOTION_MODEL_ADE",), 48 | }, 49 | } 50 | 51 | RETURN_TYPES = ("MAGICTIME",) 52 | RETURN_NAMES = ("magictime_model",) 53 | FUNCTION = "loadmodel" 54 | CATEGORY = "MagicTimeWrapper" 55 | 56 | def loadmodel(self, model, clip, vae, motion_model): 57 | mm.soft_empty_cache() 58 | custom_config = { 59 | 'model': model, 60 | 'vae': vae, 61 | 'clip': clip, 62 | 'motion_model': motion_model 63 | } 64 | if not hasattr(self, 'model') or self.model == None or custom_config != self.current_config: 65 | pbar = comfy.utils.ProgressBar(7) 66 | self.current_config = custom_config 67 | # config paths 68 | original_config = OmegaConf.load(os.path.join(script_directory, f"configs/v1-inference.yaml")) 69 | ad_unet_config = OmegaConf.load(os.path.join(script_directory, f"configs/ad_unet_config.yaml")) 70 | 71 | # load models 72 | 73 | checkpoint_path = os.path.join(folder_paths.models_dir,'magictime') 74 | magic_adapter_s_path = os.path.join(checkpoint_path, 'Magic_Weights', 'magic_adapter_s', 'magic_adapter_s.ckpt') 75 | magic_adapter_t_path = os.path.join(checkpoint_path, 'Magic_Weights', 'magic_adapter_t') 76 | magic_text_encoder_path = os.path.join(checkpoint_path, 'Magic_Weights', 'magic_text_encoder') 77 | 78 | if not os.path.exists(checkpoint_path): 79 | print(f"Downloading magictime from https://huggingface.co/BestWishYsh/MagicTime to {checkpoint_path}") 80 | from huggingface_hub import snapshot_download 81 | snapshot_download(repo_id="BestWishYsh/MagicTime", local_dir=checkpoint_path, local_dir_use_symlinks=False) 82 | 83 | pbar.update(1) 84 | 85 | # get state dict from comfy models 86 | clip_sd = None 87 | load_models = [model] 88 | load_models.append(clip.load_model()) 89 | clip_sd = clip.get_sd() 90 | 91 | comfy.model_management.load_models_gpu(load_models) 92 | sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), None) 93 | 94 | pbar.update(1) 95 | 96 | # 1. vae 97 | converted_vae_config = create_vae_diffusers_config(original_config, image_size=512) 98 | converted_vae = convert_ldm_vae_checkpoint(sd, converted_vae_config) 99 | vae = AutoencoderKL(**converted_vae_config) 100 | vae.load_state_dict(converted_vae, strict=False) 101 | pbar.update(1) 102 | 103 | # 2. unet 104 | converted_unet_config = create_unet_diffusers_config(original_config, image_size=512) 105 | converted_unet = convert_ldm_unet_checkpoint(sd, converted_unet_config) 106 | pbar.update(1) 107 | 108 | # motion module 109 | motion_module_state_dict = motion_model.model.state_dict() 110 | if motion_model.model.mm_info.mm_format == "AnimateLCM": 111 | motion_module_state_dict = {k: v for k, v in motion_module_state_dict.items() if "pos_encoder" not in k} 112 | converted_unet.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." 
in name}) 113 | converted_unet.pop("animatediff_config", "") 114 | pbar.update(1) 115 | 116 | unet = UNet3DConditionModel(**ad_unet_config) 117 | unet.load_state_dict(converted_unet, strict=False) 118 | 119 | pbar.update(1) 120 | # 3. text_encoder 121 | text_encoder = create_text_encoder_from_ldm_clip_checkpoint("openai/clip-vit-large-patch14",sd) 122 | 123 | # 4. tokenizer 124 | tokenizer_path = os.path.join(script_directory, "configs/tokenizer") 125 | tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) 126 | 127 | # 5. scheduler 128 | scheduler_config = { 129 | 'num_train_timesteps': 1000, 130 | 'beta_start': 0.00085, 131 | 'beta_end': 0.012, 132 | 'beta_schedule': "linear", 133 | 'steps_offset': 1 134 | } 135 | scheduler=DPMSolverMultistepScheduler(**scheduler_config) 136 | 137 | #6. magictime 138 | from swift import Swift 139 | magic_adapter_s_state_dict = torch.load(magic_adapter_s_path, map_location="cpu") 140 | unet = load_diffusers_lora_unet(unet, magic_adapter_s_state_dict, alpha=1.0) 141 | unet = Swift.from_pretrained(unet, magic_adapter_t_path) 142 | text_encoder = Swift.from_pretrained(text_encoder, magic_text_encoder_path) 143 | del sd 144 | 145 | pbar.update(1) 146 | 147 | self.pipe = MagicTimePipeline( 148 | vae=vae, 149 | text_encoder=text_encoder, 150 | tokenizer=tokenizer, 151 | unet=unet, 152 | scheduler=scheduler 153 | ) 154 | 155 | magictime_model = { 156 | 'pipe': self.pipe, 157 | } 158 | 159 | return (magictime_model,) 160 | 161 | class magictime_sampler: 162 | @classmethod 163 | def INPUT_TYPES(s): 164 | return {"required": { 165 | "magictime_model": ("MAGICTIME",), 166 | "prompt": ("STRING", {"multiline": True, "default": "positive",}), 167 | "n_prompt": ("STRING", {"multiline": True, "default": "negative",}), 168 | "frames": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}), 169 | "width": ("INT", {"default": 512, "min": 64, "max": 2048, "step": 64}), 170 | "height": ("INT", {"default": 512, "min": 64, "max": 2048, "step": 64}), 171 | "steps": ("INT", {"default": 25, "min": 1, "max": 200, "step": 1}), 172 | "guidance_scale": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 20.0, "step": 0.01}), 173 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), 174 | "scheduler": ( 175 | [ 176 | 'DPMSolverMultistepScheduler', 177 | 'DPMSolverMultistepScheduler_SDE_karras', 178 | 'DDPMScheduler', 179 | 'DDIMScheduler', 180 | 'LCMScheduler', 181 | 'PNDMScheduler', 182 | 'DEISMultistepScheduler', 183 | 'EulerDiscreteScheduler', 184 | 'EulerAncestralDiscreteScheduler' 185 | ], { 186 | "default": 'DPMSolverMultistepScheduler' 187 | }), 188 | }, 189 | } 190 | 191 | RETURN_TYPES = ("IMAGE",) 192 | RETURN_NAMES = ("images",) 193 | FUNCTION = "process" 194 | CATEGORY = "MagicTimeWrapper" 195 | 196 | def process(self, magictime_model, prompt, n_prompt, frames, width, height, steps, guidance_scale, seed, scheduler): 197 | device = mm.get_torch_device() 198 | mm.unload_all_models() 199 | mm.soft_empty_cache() 200 | dtype = mm.unet_dtype() 201 | vae_dtype = mm.vae_dtype() 202 | device = mm.get_torch_device() 203 | offload_device = mm.unet_offload_device() 204 | 205 | pipe=magictime_model['pipe'] 206 | pipe.to(device, dtype=dtype) 207 | 208 | scheduler_config = { 209 | 'num_train_timesteps': 1000, 210 | 'beta_start': 0.00085, 211 | 'beta_end': 0.012, 212 | 'beta_schedule': "linear", 213 | 'steps_offset': 1, 214 | } 215 | if scheduler == 'DPMSolverMultistepScheduler': 216 | noise_scheduler = DPMSolverMultistepScheduler(**scheduler_config) 217 | elif scheduler 
== 'DDIMScheduler': 218 | noise_scheduler = DDIMScheduler(**scheduler_config) 219 | elif scheduler == 'DPMSolverMultistepScheduler_SDE_karras': 220 | scheduler_config.update({"algorithm_type": "sde-dpmsolver++"}) 221 | scheduler_config.update({"use_karras_sigmas": True}) 222 | noise_scheduler = DPMSolverMultistepScheduler(**scheduler_config) 223 | elif scheduler == 'DDPMScheduler': 224 | noise_scheduler = DDPMScheduler(**scheduler_config) 225 | elif scheduler == 'LCMScheduler': 226 | noise_scheduler = LCMScheduler(**scheduler_config) 227 | elif scheduler == 'PNDMScheduler': 228 | scheduler_config.update({"set_alpha_to_one": False}) 229 | scheduler_config.update({"trained_betas": None}) 230 | noise_scheduler = PNDMScheduler(**scheduler_config) 231 | elif scheduler == 'DEISMultistepScheduler': 232 | noise_scheduler = DEISMultistepScheduler(**scheduler_config) 233 | elif scheduler == 'EulerDiscreteScheduler': 234 | noise_scheduler = EulerDiscreteScheduler(**scheduler_config) 235 | elif scheduler == 'EulerAncestralDiscreteScheduler': 236 | noise_scheduler = EulerAncestralDiscreteScheduler(**scheduler_config) 237 | pipe.scheduler = noise_scheduler 238 | 239 | autocast_condition = (dtype != torch.float32) and not mm.is_device_mps(device) 240 | with torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext(): 241 | 242 | generator = torch.Generator(device=device) 243 | generator.manual_seed(seed) 244 | 245 | sample = pipe( 246 | prompt, 247 | negative_prompt = n_prompt, 248 | num_inference_steps = steps, 249 | guidance_scale = guidance_scale, 250 | width = width, 251 | height = height, 252 | video_length = frames, 253 | generator = generator, 254 | ).videos 255 | pipe.to(offload_device) 256 | image_out = sample.squeeze(0).permute(1, 2, 3, 0).cpu().float() 257 | return (image_out,) 258 | 259 | 260 | NODE_CLASS_MAPPINGS = { 261 | "magictime_model_loader": magictime_model_loader, 262 | "magictime_sampler": magictime_sampler, 263 | } 264 | NODE_DISPLAY_NAME_MAPPINGS = { 265 | "magictime_model_loader": "MagicTime Model Loader", 266 | "magictime_sampler": "MagicTime Sampler", 267 | } 268 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | omegaconf 3 | ms-swift 4 | accelerate>=0.28.0 5 | diffusers>=0.26.0 6 | transformers>=4.38.2 7 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os, csv, random 2 | import numpy as np 3 | from decord import VideoReader 4 | import torch 5 | import torchvision.transforms as transforms 6 | from torch.utils.data.dataset import Dataset 7 | 8 | 9 | class ChronoMagic(Dataset): 10 | def __init__( 11 | self, 12 | csv_path, video_folder, 13 | sample_size=512, sample_stride=4, sample_n_frames=16, 14 | is_image=False, 15 | is_uniform=True, 16 | ): 17 | with open(csv_path, 'r') as csvfile: 18 | self.dataset = list(csv.DictReader(csvfile)) 19 | self.length = len(self.dataset) 20 | 21 | self.video_folder = video_folder 22 | self.sample_stride = sample_stride 23 | self.sample_n_frames = sample_n_frames 24 | self.is_image = is_image 25 | self.is_uniform = is_uniform 26 | 27 | sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) 28 | self.pixel_transforms = transforms.Compose([ 29 |
transforms.RandomHorizontalFlip(), 30 | transforms.Resize(sample_size[0], interpolation=transforms.InterpolationMode.BICUBIC), 31 | transforms.CenterCrop(sample_size), 32 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 33 | ]) 34 | 35 | def _get_frame_indices_adjusted(self, video_length, n_frames): 36 | indices = list(range(video_length)) 37 | additional_frames_needed = n_frames - video_length 38 | 39 | repeat_indices = [] 40 | for i in range(additional_frames_needed): 41 | index_to_repeat = i % video_length 42 | repeat_indices.append(indices[index_to_repeat]) 43 | 44 | all_indices = indices + repeat_indices 45 | all_indices.sort() 46 | 47 | return all_indices 48 | 49 | def _generate_frame_indices(self, video_length, n_frames, sample_stride, is_transmit): 50 | prob_execute_original = 1 if int(is_transmit) == 0 else 0 51 | 52 | # Generate a random number to decide which block of code to execute 53 | if random.random() < prob_execute_original: 54 | if video_length <= n_frames: 55 | return self._get_frame_indices_adjusted(video_length, n_frames) 56 | else: 57 | interval = (video_length - 1) / (n_frames - 1) 58 | indices = [int(round(i * interval)) for i in range(n_frames)] 59 | indices[-1] = video_length - 1 60 | return indices 61 | else: 62 | if video_length <= n_frames: 63 | return self._get_frame_indices_adjusted(video_length, n_frames) 64 | else: 65 | clip_length = min(video_length, (n_frames - 1) * sample_stride + 1) 66 | start_idx = random.randint(0, video_length - clip_length) 67 | return np.linspace(start_idx, start_idx + clip_length - 1, n_frames, dtype=int).tolist() 68 | 69 | def get_batch(self, idx): 70 | video_dict = self.dataset[idx] 71 | videoid, name, is_transmit = video_dict['videoid'], video_dict['name'], video_dict['is_transmit'] 72 | 73 | video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") 74 | video_reader = VideoReader(video_dir, num_threads=0) 75 | video_length = len(video_reader) 76 | 77 | batch_index = self._generate_frame_indices(video_length, self.sample_n_frames, self.sample_stride, is_transmit) if not self.is_image else [random.randint(0, video_length - 1)] 78 | 79 | pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2) / 255. 
80 | del video_reader 81 | 82 | if self.is_image: 83 | pixel_values = pixel_values[0] 84 | 85 | return pixel_values, name, videoid 86 | 87 | def __len__(self): 88 | return self.length 89 | 90 | def __getitem__(self, idx): 91 | while True: 92 | try: 93 | pixel_values, name, videoid = self.get_batch(idx) 94 | break 95 | 96 | except Exception as e: 97 | idx = random.randint(0, self.length-1) 98 | 99 | pixel_values = self.pixel_transforms(pixel_values) 100 | sample = dict(pixel_values=pixel_values, text=name, id=videoid) 101 | return sample -------------------------------------------------------------------------------- /utils/pipeline_magictime.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/guoyww/AnimateDiff/animatediff/pipelines/pipeline_animation.py 2 | 3 | import torch 4 | import inspect 5 | import numpy as np 6 | from tqdm import tqdm 7 | from einops import rearrange 8 | from packaging import version 9 | from dataclasses import dataclass 10 | from typing import Callable, List, Optional, Union 11 | from transformers import CLIPTextModel, CLIPTokenizer 12 | 13 | from diffusers.utils import is_accelerate_available, deprecate, logging, BaseOutput 14 | from diffusers.configuration_utils import FrozenDict 15 | from diffusers.models import AutoencoderKL 16 | from diffusers import DiffusionPipeline 17 | from diffusers.schedulers import ( 18 | DDIMScheduler, 19 | DPMSolverMultistepScheduler, 20 | EulerAncestralDiscreteScheduler, 21 | EulerDiscreteScheduler, 22 | LMSDiscreteScheduler, 23 | PNDMScheduler, 24 | ) 25 | 26 | from .unet import UNet3DConditionModel 27 | 28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 29 | 30 | @dataclass 31 | class MagicTimePipelineOutput(BaseOutput): 32 | videos: Union[torch.Tensor, np.ndarray] 33 | 34 | class MagicTimePipeline(DiffusionPipeline): 35 | _optional_components = [] 36 | 37 | def __init__( 38 | self, 39 | vae: AutoencoderKL, 40 | text_encoder: CLIPTextModel, 41 | tokenizer: CLIPTokenizer, 42 | unet: UNet3DConditionModel, 43 | scheduler: Union[ 44 | DDIMScheduler, 45 | PNDMScheduler, 46 | LMSDiscreteScheduler, 47 | EulerDiscreteScheduler, 48 | EulerAncestralDiscreteScheduler, 49 | DPMSolverMultistepScheduler, 50 | ], 51 | ): 52 | super().__init__() 53 | 54 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 55 | deprecation_message = ( 56 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 57 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 58 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" 59 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," 60 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 61 | " file" 62 | ) 63 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) 64 | new_config = dict(scheduler.config) 65 | new_config["steps_offset"] = 1 66 | scheduler._internal_dict = FrozenDict(new_config) 67 | 68 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: 69 | deprecation_message = ( 70 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 71 | " `clip_sample` should be set to False in the configuration file. 
Please make sure to update the" 72 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" 73 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" 74 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" 75 | ) 76 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) 77 | new_config = dict(scheduler.config) 78 | new_config["clip_sample"] = False 79 | scheduler._internal_dict = FrozenDict(new_config) 80 | 81 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( 82 | version.parse(unet.config._diffusers_version).base_version 83 | ) < version.parse("0.9.0.dev0") 84 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 85 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: 86 | deprecation_message = ( 87 | "The configuration file of the unet has set the default `sample_size` to smaller than" 88 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" 89 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" 90 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" 91 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" 92 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" 93 | " in the config might lead to incorrect results in future versions. If you have downloaded this" 94 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" 95 | " the `unet/config.json` file" 96 | ) 97 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) 98 | new_config = dict(unet.config) 99 | new_config["sample_size"] = 64 100 | unet._internal_dict = FrozenDict(new_config) 101 | 102 | self.register_modules( 103 | vae=vae, 104 | text_encoder=text_encoder, 105 | tokenizer=tokenizer, 106 | unet=unet, 107 | scheduler=scheduler, 108 | ) 109 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 110 | 111 | def enable_vae_slicing(self): 112 | self.vae.enable_slicing() 113 | 114 | def disable_vae_slicing(self): 115 | self.vae.disable_slicing() 116 | 117 | def enable_sequential_cpu_offload(self, gpu_id=0): 118 | if is_accelerate_available(): 119 | from accelerate import cpu_offload 120 | else: 121 | raise ImportError("Please install accelerate via `pip install accelerate`") 122 | 123 | device = torch.device(f"cuda:{gpu_id}") 124 | 125 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 126 | if cpu_offloaded_model is not None: 127 | cpu_offload(cpu_offloaded_model, device) 128 | 129 | 130 | @property 131 | def _execution_device(self): 132 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 133 | return self.device 134 | for module in self.unet.modules(): 135 | if ( 136 | hasattr(module, "_hf_hook") 137 | and hasattr(module._hf_hook, "execution_device") 138 | and module._hf_hook.execution_device is not None 139 | ): 140 | return torch.device(module._hf_hook.execution_device) 141 | return self.device 142 | 143 | def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): 144 | batch_size = len(prompt) if isinstance(prompt, list) else 1 145 | 146 | text_inputs = 
self.tokenizer( 147 | prompt, 148 | padding="max_length", 149 | max_length=self.tokenizer.model_max_length, 150 | truncation=True, 151 | return_tensors="pt", 152 | ) 153 | text_input_ids = text_inputs.input_ids 154 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 155 | 156 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 157 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) 158 | logger.warning( 159 | "The following part of your input was truncated because CLIP can only handle sequences up to" 160 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 161 | ) 162 | 163 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 164 | attention_mask = text_inputs.attention_mask.to(device) 165 | else: 166 | attention_mask = None 167 | 168 | text_embeddings = self.text_encoder( 169 | text_input_ids.to(device), 170 | attention_mask=attention_mask, 171 | ) 172 | text_embeddings = text_embeddings[0] 173 | 174 | # duplicate text embeddings for each generation per prompt, using mps friendly method 175 | bs_embed, seq_len, _ = text_embeddings.shape 176 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) 177 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 178 | 179 | # get unconditional embeddings for classifier free guidance 180 | if do_classifier_free_guidance: 181 | uncond_tokens: List[str] 182 | if negative_prompt is None: 183 | uncond_tokens = [""] * batch_size 184 | elif type(prompt) is not type(negative_prompt): 185 | raise TypeError( 186 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 187 | f" {type(prompt)}." 188 | ) 189 | elif isinstance(negative_prompt, str): 190 | uncond_tokens = [negative_prompt] 191 | elif batch_size != len(negative_prompt): 192 | raise ValueError( 193 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 194 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 195 | " the batch size of `prompt`." 196 | ) 197 | else: 198 | uncond_tokens = negative_prompt 199 | 200 | max_length = text_input_ids.shape[-1] 201 | uncond_input = self.tokenizer( 202 | uncond_tokens, 203 | padding="max_length", 204 | max_length=max_length, 205 | truncation=True, 206 | return_tensors="pt", 207 | ) 208 | 209 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 210 | attention_mask = uncond_input.attention_mask.to(device) 211 | else: 212 | attention_mask = None 213 | 214 | uncond_embeddings = self.text_encoder( 215 | uncond_input.input_ids.to(device), 216 | attention_mask=attention_mask, 217 | ) 218 | uncond_embeddings = uncond_embeddings[0] 219 | 220 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 221 | seq_len = uncond_embeddings.shape[1] 222 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) 223 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) 224 | 225 | # For classifier free guidance, we need to do two forward passes. 
226 | # Here we concatenate the unconditional and text embeddings into a single batch 227 | # to avoid doing two forward passes 228 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 229 | 230 | return text_embeddings 231 | 232 | def decode_latents(self, latents): 233 | video_length = latents.shape[2] 234 | latents = 1 / 0.18215 * latents 235 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 236 | # video = self.vae.decode(latents).sample 237 | video = [] 238 | for frame_idx in tqdm(range(latents.shape[0])): 239 | video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) 240 | video = torch.cat(video) 241 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 242 | video = (video / 2 + 0.5).clamp(0, 1) 243 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 244 | video = video.cpu().float().numpy() 245 | return video 246 | 247 | def prepare_extra_step_kwargs(self, generator, eta): 248 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 249 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 250 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 251 | # and should be between [0, 1] 252 | 253 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 254 | extra_step_kwargs = {} 255 | if accepts_eta: 256 | extra_step_kwargs["eta"] = eta 257 | 258 | # check if the scheduler accepts generator 259 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 260 | if accepts_generator: 261 | extra_step_kwargs["generator"] = generator 262 | return extra_step_kwargs 263 | 264 | def check_inputs(self, prompt, height, width, callback_steps): 265 | if not isinstance(prompt, str) and not isinstance(prompt, list): 266 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 267 | 268 | if height % 8 != 0 or width % 8 != 0: 269 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 270 | 271 | if (callback_steps is None) or ( 272 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 273 | ): 274 | raise ValueError( 275 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 276 | f" {type(callback_steps)}." 277 | ) 278 | 279 | def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): 280 | shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) 281 | if isinstance(generator, list) and len(generator) != batch_size: 282 | raise ValueError( 283 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 284 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
285 | ) 286 | if latents is None: 287 | rand_device = "cpu" if device.type == "mps" else device 288 | 289 | if isinstance(generator, list): 290 | shape = shape 291 | # shape = (1,) + shape[1:] 292 | latents = [ 293 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) 294 | for i in range(batch_size) 295 | ] 296 | latents = torch.cat(latents, dim=0).to(device) 297 | else: 298 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) 299 | else: 300 | if latents.shape != shape: 301 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") 302 | latents = latents.to(device) 303 | 304 | # scale the initial noise by the standard deviation required by the scheduler 305 | latents = latents * self.scheduler.init_noise_sigma 306 | return latents 307 | 308 | @torch.no_grad() 309 | def __call__( 310 | self, 311 | prompt: Union[str, List[str]], 312 | video_length: Optional[int], 313 | height: Optional[int] = None, 314 | width: Optional[int] = None, 315 | num_inference_steps: int = 50, 316 | guidance_scale: float = 7.5, 317 | negative_prompt: Optional[Union[str, List[str]]] = None, 318 | num_videos_per_prompt: Optional[int] = 1, 319 | eta: float = 0.0, 320 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 321 | latents: Optional[torch.FloatTensor] = None, 322 | output_type: Optional[str] = "tensor", 323 | return_dict: bool = True, 324 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 325 | callback_steps: Optional[int] = 1, 326 | **kwargs, 327 | ): 328 | # Default height and width to unet 329 | height = height or self.unet.config.sample_size * self.vae_scale_factor 330 | width = width or self.unet.config.sample_size * self.vae_scale_factor 331 | 332 | # Check inputs. Raise error if not correct 333 | self.check_inputs(prompt, height, width, callback_steps) 334 | 335 | # Define call parameters 336 | # batch_size = 1 if isinstance(prompt, str) else len(prompt) 337 | batch_size = 1 338 | if latents is not None: 339 | batch_size = latents.shape[0] 340 | if isinstance(prompt, list): 341 | batch_size = len(prompt) 342 | 343 | device = self._execution_device 344 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 345 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 346 | # corresponds to doing no classifier free guidance. 347 | do_classifier_free_guidance = guidance_scale > 1.0 348 | 349 | # Encode input prompt 350 | prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size 351 | if negative_prompt is not None: 352 | negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size 353 | text_embeddings = self._encode_prompt( 354 | prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt 355 | ) 356 | 357 | # Prepare timesteps 358 | self.scheduler.set_timesteps(num_inference_steps, device=device) 359 | timesteps = self.scheduler.timesteps 360 | 361 | # Prepare latent variables 362 | num_channels_latents = self.unet.in_channels 363 | latents = self.prepare_latents( 364 | batch_size * num_videos_per_prompt, 365 | num_channels_latents, 366 | video_length, 367 | height, 368 | width, 369 | text_embeddings.dtype, 370 | device, 371 | generator, 372 | latents, 373 | ) 374 | latents_dtype = latents.dtype 375 | 376 | # Prepare extra step kwargs. 
377 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 378 | 379 | # Denoising loop 380 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 381 | from comfy.utils import ProgressBar 382 | comfy_pbar = ProgressBar(num_inference_steps) 383 | with self.progress_bar(total=num_inference_steps) as progress_bar: 384 | for i, t in enumerate(timesteps): 385 | # expand the latents if we are doing classifier free guidance 386 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 387 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 388 | 389 | down_block_additional_residuals = mid_block_additional_residual = None 390 | 391 | # predict the noise residual 392 | noise_pred = self.unet( 393 | latent_model_input, t, 394 | encoder_hidden_states=text_embeddings, 395 | down_block_additional_residuals = down_block_additional_residuals, 396 | mid_block_additional_residual = mid_block_additional_residual, 397 | ).sample.to(dtype=latents_dtype) 398 | 399 | # perform guidance 400 | if do_classifier_free_guidance: 401 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 402 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 403 | 404 | # compute the previous noisy sample x_t -> x_t-1 405 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 406 | 407 | # call the callback, if provided 408 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 409 | progress_bar.update() 410 | comfy_pbar.update(1) 411 | if callback is not None and i % callback_steps == 0: 412 | callback(i, t, latents) 413 | 414 | # Post-processing 415 | video = self.decode_latents(latents) 416 | 417 | # Convert to tensor 418 | if output_type == "tensor": 419 | video = torch.from_numpy(video) 420 | 421 | if not return_dict: 422 | return video 423 | 424 | return MagicTimePipelineOutput(videos=video) 425 | -------------------------------------------------------------------------------- /utils/unet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/guoyww/AnimateDiff/animatediff/models/unet.py 2 | import os 3 | import json 4 | import pdb 5 | from dataclasses import dataclass 6 | from typing import List, Optional, Tuple, Union 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.checkpoint 11 | 12 | from diffusers.configuration_utils import ConfigMixin, register_to_config 13 | from diffusers import ModelMixin 14 | from diffusers.utils import BaseOutput, logging 15 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 16 | from .unet_blocks import ( 17 | CrossAttnDownBlock3D, 18 | CrossAttnUpBlock3D, 19 | DownBlock3D, 20 | UNetMidBlock3DCrossAttn, 21 | UpBlock3D, 22 | get_down_block, 23 | get_up_block, 24 | InflatedConv3d, 25 | InflatedGroupNorm, 26 | ) 27 | 28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 29 | 30 | 31 | @dataclass 32 | class UNet3DConditionOutput(BaseOutput): 33 | sample: torch.FloatTensor 34 | 35 | 36 | class UNet3DConditionModel(ModelMixin, ConfigMixin): 37 | _supports_gradient_checkpointing = True 38 | 39 | @register_to_config 40 | def __init__( 41 | self, 42 | sample_size: Optional[int] = None, 43 | in_channels: int = 4, 44 | out_channels: int = 4, 45 | center_input_sample: bool = False, 46 | flip_sin_to_cos: bool = True, 47 | freq_shift: int = 0, 48 | 
down_block_types: Tuple[str] = ( 49 | "CrossAttnDownBlock3D", 50 | "CrossAttnDownBlock3D", 51 | "CrossAttnDownBlock3D", 52 | "DownBlock3D", 53 | ), 54 | mid_block_type: str = "UNetMidBlock3DCrossAttn", 55 | up_block_types: Tuple[str] = ( 56 | "UpBlock3D", 57 | "CrossAttnUpBlock3D", 58 | "CrossAttnUpBlock3D", 59 | "CrossAttnUpBlock3D" 60 | ), 61 | only_cross_attention: Union[bool, Tuple[bool]] = False, 62 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 63 | layers_per_block: int = 2, 64 | downsample_padding: int = 1, 65 | mid_block_scale_factor: float = 1, 66 | act_fn: str = "silu", 67 | norm_num_groups: int = 32, 68 | norm_eps: float = 1e-5, 69 | cross_attention_dim: int = 1280, 70 | attention_head_dim: Union[int, Tuple[int]] = 8, 71 | dual_cross_attention: bool = False, 72 | use_linear_projection: bool = False, 73 | class_embed_type: Optional[str] = None, 74 | num_class_embeds: Optional[int] = None, 75 | upcast_attention: bool = False, 76 | resnet_time_scale_shift: str = "default", 77 | 78 | use_inflated_groupnorm=False, 79 | 80 | # Additional 81 | use_motion_module = False, 82 | motion_module_resolutions = ( 1,2,4,8 ), 83 | motion_module_mid_block = False, 84 | motion_module_decoder_only = False, 85 | motion_module_type = None, 86 | motion_module_kwargs = {}, 87 | unet_use_cross_frame_attention = False, 88 | unet_use_temporal_attention = False, 89 | ): 90 | super().__init__() 91 | 92 | self.sample_size = sample_size 93 | time_embed_dim = block_out_channels[0] * 4 94 | 95 | # input 96 | self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) 97 | 98 | # time 99 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 100 | timestep_input_dim = block_out_channels[0] 101 | 102 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 103 | 104 | # class embedding 105 | if class_embed_type is None and num_class_embeds is not None: 106 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 107 | elif class_embed_type == "timestep": 108 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 109 | elif class_embed_type == "identity": 110 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 111 | else: 112 | self.class_embedding = None 113 | 114 | self.down_blocks = nn.ModuleList([]) 115 | self.mid_block = None 116 | self.up_blocks = nn.ModuleList([]) 117 | 118 | if isinstance(only_cross_attention, bool): 119 | only_cross_attention = [only_cross_attention] * len(down_block_types) 120 | 121 | if isinstance(attention_head_dim, int): 122 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 123 | 124 | # down 125 | output_channel = block_out_channels[0] 126 | for i, down_block_type in enumerate(down_block_types): 127 | res = 2 ** i 128 | input_channel = output_channel 129 | output_channel = block_out_channels[i] 130 | is_final_block = i == len(block_out_channels) - 1 131 | 132 | down_block = get_down_block( 133 | down_block_type, 134 | num_layers=layers_per_block, 135 | in_channels=input_channel, 136 | out_channels=output_channel, 137 | temb_channels=time_embed_dim, 138 | add_downsample=not is_final_block, 139 | resnet_eps=norm_eps, 140 | resnet_act_fn=act_fn, 141 | resnet_groups=norm_num_groups, 142 | cross_attention_dim=cross_attention_dim, 143 | attn_num_head_channels=attention_head_dim[i], 144 | downsample_padding=downsample_padding, 145 | dual_cross_attention=dual_cross_attention, 146 | 
use_linear_projection=use_linear_projection, 147 | only_cross_attention=only_cross_attention[i], 148 | upcast_attention=upcast_attention, 149 | resnet_time_scale_shift=resnet_time_scale_shift, 150 | 151 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 152 | unet_use_temporal_attention=unet_use_temporal_attention, 153 | use_inflated_groupnorm=use_inflated_groupnorm, 154 | 155 | use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), 156 | motion_module_type=motion_module_type, 157 | motion_module_kwargs=motion_module_kwargs, 158 | ) 159 | self.down_blocks.append(down_block) 160 | 161 | # mid 162 | if mid_block_type == "UNetMidBlock3DCrossAttn": 163 | self.mid_block = UNetMidBlock3DCrossAttn( 164 | in_channels=block_out_channels[-1], 165 | temb_channels=time_embed_dim, 166 | resnet_eps=norm_eps, 167 | resnet_act_fn=act_fn, 168 | output_scale_factor=mid_block_scale_factor, 169 | resnet_time_scale_shift=resnet_time_scale_shift, 170 | cross_attention_dim=cross_attention_dim, 171 | attn_num_head_channels=attention_head_dim[-1], 172 | resnet_groups=norm_num_groups, 173 | dual_cross_attention=dual_cross_attention, 174 | use_linear_projection=use_linear_projection, 175 | upcast_attention=upcast_attention, 176 | 177 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 178 | unet_use_temporal_attention=unet_use_temporal_attention, 179 | use_inflated_groupnorm=use_inflated_groupnorm, 180 | 181 | use_motion_module=use_motion_module and motion_module_mid_block, 182 | motion_module_type=motion_module_type, 183 | motion_module_kwargs=motion_module_kwargs, 184 | ) 185 | else: 186 | raise ValueError(f"unknown mid_block_type : {mid_block_type}") 187 | 188 | # count how many layers upsample the videos 189 | self.num_upsamplers = 0 190 | 191 | # up 192 | reversed_block_out_channels = list(reversed(block_out_channels)) 193 | reversed_attention_head_dim = list(reversed(attention_head_dim)) 194 | only_cross_attention = list(reversed(only_cross_attention)) 195 | output_channel = reversed_block_out_channels[0] 196 | for i, up_block_type in enumerate(up_block_types): 197 | res = 2 ** (3 - i) 198 | is_final_block = i == len(block_out_channels) - 1 199 | 200 | prev_output_channel = output_channel 201 | output_channel = reversed_block_out_channels[i] 202 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 203 | 204 | # add upsample block for all BUT final layer 205 | if not is_final_block: 206 | add_upsample = True 207 | self.num_upsamplers += 1 208 | else: 209 | add_upsample = False 210 | 211 | up_block = get_up_block( 212 | up_block_type, 213 | num_layers=layers_per_block + 1, 214 | in_channels=input_channel, 215 | out_channels=output_channel, 216 | prev_output_channel=prev_output_channel, 217 | temb_channels=time_embed_dim, 218 | add_upsample=add_upsample, 219 | resnet_eps=norm_eps, 220 | resnet_act_fn=act_fn, 221 | resnet_groups=norm_num_groups, 222 | cross_attention_dim=cross_attention_dim, 223 | attn_num_head_channels=reversed_attention_head_dim[i], 224 | dual_cross_attention=dual_cross_attention, 225 | use_linear_projection=use_linear_projection, 226 | only_cross_attention=only_cross_attention[i], 227 | upcast_attention=upcast_attention, 228 | resnet_time_scale_shift=resnet_time_scale_shift, 229 | 230 | unet_use_cross_frame_attention=unet_use_cross_frame_attention, 231 | unet_use_temporal_attention=unet_use_temporal_attention, 232 | use_inflated_groupnorm=use_inflated_groupnorm, 233 | 234 | 
use_motion_module=use_motion_module and (res in motion_module_resolutions), 235 | motion_module_type=motion_module_type, 236 | motion_module_kwargs=motion_module_kwargs, 237 | ) 238 | self.up_blocks.append(up_block) 239 | prev_output_channel = output_channel 240 | 241 | # out 242 | if use_inflated_groupnorm: 243 | self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) 244 | else: 245 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) 246 | self.conv_act = nn.SiLU() 247 | self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) 248 | 249 | def set_attention_slice(self, slice_size): 250 | r""" 251 | Enable sliced attention computation. 252 | 253 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 254 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 255 | 256 | Args: 257 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 258 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If 259 | `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is 260 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 261 | must be a multiple of `slice_size`. 262 | """ 263 | sliceable_head_dims = [] 264 | 265 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): 266 | if hasattr(module, "set_attention_slice"): 267 | sliceable_head_dims.append(module.sliceable_head_dim) 268 | 269 | for child in module.children(): 270 | fn_recursive_retrieve_slicable_dims(child) 271 | 272 | # retrieve number of attention layers 273 | for module in self.children(): 274 | fn_recursive_retrieve_slicable_dims(module) 275 | 276 | num_slicable_layers = len(sliceable_head_dims) 277 | 278 | if slice_size == "auto": 279 | # half the attention head size is usually a good trade-off between 280 | # speed and memory 281 | slice_size = [dim // 2 for dim in sliceable_head_dims] 282 | elif slice_size == "max": 283 | # make smallest slice possible 284 | slice_size = num_slicable_layers * [1] 285 | 286 | slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 287 | 288 | if len(slice_size) != len(sliceable_head_dims): 289 | raise ValueError( 290 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 291 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 292 | ) 293 | 294 | for i in range(len(slice_size)): 295 | size = slice_size[i] 296 | dim = sliceable_head_dims[i] 297 | if size is not None and size > dim: 298 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 299 | 300 | # Recursively walk through all the children. 
301 | # Any children which exposes the set_attention_slice method 302 | # gets the message 303 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 304 | if hasattr(module, "set_attention_slice"): 305 | module.set_attention_slice(slice_size.pop()) 306 | 307 | for child in module.children(): 308 | fn_recursive_set_attention_slice(child, slice_size) 309 | 310 | reversed_slice_size = list(reversed(slice_size)) 311 | for module in self.children(): 312 | fn_recursive_set_attention_slice(module, reversed_slice_size) 313 | 314 | def _set_gradient_checkpointing(self, module, value=False): 315 | if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): 316 | module.gradient_checkpointing = value 317 | 318 | def forward( 319 | self, 320 | sample: torch.FloatTensor, 321 | timestep: Union[torch.Tensor, float, int], 322 | encoder_hidden_states: torch.Tensor, 323 | class_labels: Optional[torch.Tensor] = None, 324 | attention_mask: Optional[torch.Tensor] = None, 325 | 326 | # support controlnet 327 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 328 | mid_block_additional_residual: Optional[torch.Tensor] = None, 329 | 330 | return_dict: bool = True, 331 | ) -> Union[UNet3DConditionOutput, Tuple]: 332 | r""" 333 | Args: 334 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 335 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 336 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 337 | return_dict (`bool`, *optional*, defaults to `True`): 338 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 339 | 340 | Returns: 341 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 342 | [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When 343 | returning a tuple, the first element is the sample tensor. 344 | """ 345 | # By default samples have to be AT least a multiple of the overall upsampling factor. 346 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 347 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 348 | # on the fly if necessary. 
349 | default_overall_up_factor = 2**self.num_upsamplers 350 | 351 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 352 | forward_upsample_size = False 353 | upsample_size = None 354 | 355 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 356 | logger.info("Forward upsample size to force interpolation output size.") 357 | forward_upsample_size = True 358 | 359 | # prepare attention_mask 360 | if attention_mask is not None: 361 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 362 | attention_mask = attention_mask.unsqueeze(1) 363 | 364 | # center input if necessary 365 | if self.config.center_input_sample: 366 | sample = 2 * sample - 1.0 367 | 368 | # time 369 | timesteps = timestep 370 | if not torch.is_tensor(timesteps): 371 | # This would be a good case for the `match` statement (Python 3.10+) 372 | is_mps = sample.device.type == "mps" 373 | if isinstance(timestep, float): 374 | dtype = torch.float32 if is_mps else torch.float64 375 | else: 376 | dtype = torch.int32 if is_mps else torch.int64 377 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 378 | elif len(timesteps.shape) == 0: 379 | timesteps = timesteps[None].to(sample.device) 380 | 381 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 382 | timesteps = timesteps.expand(sample.shape[0]) 383 | 384 | t_emb = self.time_proj(timesteps) 385 | 386 | # timesteps does not contain any weights and will always return f32 tensors 387 | # but time_embedding might actually be running in fp16. so we need to cast here. 388 | # there might be better ways to encapsulate this. 389 | t_emb = t_emb.to(dtype=self.dtype) 390 | emb = self.time_embedding(t_emb) 391 | 392 | if self.class_embedding is not None: 393 | if class_labels is None: 394 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 395 | 396 | if self.config.class_embed_type == "timestep": 397 | class_labels = self.time_proj(class_labels) 398 | 399 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 400 | emb = emb + class_emb 401 | 402 | # pre-process 403 | sample = self.conv_in(sample) 404 | 405 | # down 406 | down_block_res_samples = (sample,) 407 | for downsample_block in self.down_blocks: 408 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 409 | sample, res_samples = downsample_block( 410 | hidden_states=sample, 411 | temb=emb, 412 | encoder_hidden_states=encoder_hidden_states, 413 | attention_mask=attention_mask, 414 | ) 415 | else: 416 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) 417 | 418 | down_block_res_samples += res_samples 419 | 420 | # support controlnet 421 | down_block_res_samples = list(down_block_res_samples) 422 | if down_block_additional_residuals is not None: 423 | for i, down_block_additional_residual in enumerate(down_block_additional_residuals): 424 | if down_block_additional_residual.dim() == 4: # boardcast 425 | down_block_additional_residual = down_block_additional_residual.unsqueeze(2) 426 | down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual 427 | 428 | # mid 429 | sample = self.mid_block( 430 | sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 431 | ) 432 | 433 | # support controlnet 434 | if mid_block_additional_residual is not None: 435 | if mid_block_additional_residual.dim() == 4: 
# boardcast 436 | mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2) 437 | sample = sample + mid_block_additional_residual 438 | 439 | # up 440 | for i, upsample_block in enumerate(self.up_blocks): 441 | is_final_block = i == len(self.up_blocks) - 1 442 | 443 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 444 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 445 | 446 | # if we have not reached the final block and need to forward the 447 | # upsample size, we do it here 448 | if not is_final_block and forward_upsample_size: 449 | upsample_size = down_block_res_samples[-1].shape[2:] 450 | 451 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 452 | sample = upsample_block( 453 | hidden_states=sample, 454 | temb=emb, 455 | res_hidden_states_tuple=res_samples, 456 | encoder_hidden_states=encoder_hidden_states, 457 | upsample_size=upsample_size, 458 | attention_mask=attention_mask, 459 | ) 460 | else: 461 | sample = upsample_block( 462 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, 463 | ) 464 | 465 | # post-process 466 | sample = self.conv_norm_out(sample) 467 | sample = self.conv_act(sample) 468 | sample = self.conv_out(sample) 469 | 470 | if not return_dict: 471 | return (sample,) 472 | 473 | return UNet3DConditionOutput(sample=sample) 474 | 475 | @classmethod 476 | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): 477 | if subfolder is not None: 478 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder) 479 | print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...") 480 | 481 | config_file = os.path.join(pretrained_model_path, 'config.json') 482 | if not os.path.isfile(config_file): 483 | raise RuntimeError(f"{config_file} does not exist") 484 | with open(config_file, "r") as f: 485 | config = json.load(f) 486 | config["_class_name"] = cls.__name__ 487 | config["down_block_types"] = [ 488 | "CrossAttnDownBlock3D", 489 | "CrossAttnDownBlock3D", 490 | "CrossAttnDownBlock3D", 491 | "DownBlock3D" 492 | ] 493 | config["up_block_types"] = [ 494 | "UpBlock3D", 495 | "CrossAttnUpBlock3D", 496 | "CrossAttnUpBlock3D", 497 | "CrossAttnUpBlock3D" 498 | ] 499 | 500 | from diffusers.utils import WEIGHTS_NAME 501 | model = cls.from_config(config, **unet_additional_kwargs) 502 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) 503 | if not os.path.isfile(model_file): 504 | raise RuntimeError(f"{model_file} does not exist") 505 | state_dict = torch.load(model_file, map_location="cpu") 506 | 507 | m, u = model.load_state_dict(state_dict, strict=False) 508 | print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") 509 | 510 | params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()] 511 | print(f"### Motion Module Parameters: {sum(params) / 1e6} M") 512 | 513 | return model 514 | --------------------------------------------------------------------------------
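
The `from_pretrained_2d` classmethod above inflates a 2D Stable Diffusion UNet into this 3D model: it rewrites the down/up block types to their 3D counterparts and loads the 2D weights with `strict=False`, so the motion-module parameters simply appear as missing keys. A minimal loading sketch, assuming a diffusers-layout SD 1.x checkpoint directory and illustrative `unet_additional_kwargs` (the `"Vanilla"` motion-module type and the empty kwargs dict are assumptions, not values taken from this file):

import torch
from utils.unet import UNet3DConditionModel

# Assumed local checkpoint path; any diffusers-format SD 1.x checkpoint with
# unet/config.json and unet/diffusion_pytorch_model.bin matches this layout.
unet = UNet3DConditionModel.from_pretrained_2d(
    "checkpoints/stable-diffusion-v1-5",
    subfolder="unet",
    unet_additional_kwargs={
        "use_motion_module": True,              # build temporal layers in the down/up blocks
        "motion_module_resolutions": [1, 2, 4, 8],
        "motion_module_mid_block": False,
        "motion_module_type": "Vanilla",        # assumed value
        "motion_module_kwargs": {},             # assumed; actual settings live in configs/*.yaml
    },
)
unet.eval()

Motion-module and MagicTime-specific weights are loaded on top of this separately; the base checkpoint only provides the spatial layers, which is why the non-strict load reports missing keys.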
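
For orientation, `forward()` takes a 5-D latent video `(batch, channel, frame, height, width)` plus per-batch timesteps and text embeddings, and returns a prediction of the same shape, while `set_attention_slice()` trades a little speed for lower attention memory. A shape-only sketch, continuing from the loading example above; the 768 text-embedding width assumes an SD 1.x CLIP text encoder (the `cross_attention_dim` of the loaded config is what actually matters):

import torch

unet.set_attention_slice("auto")   # "auto" halves each sliceable head dim; "max" or an int divisor also work

batch, frames = 1, 16
sample = torch.randn(batch, 4, frames, 64, 64)   # 4 latent channels; 64 = 512 / VAE scale factor 8
timesteps = torch.tensor([999])                  # broadcast to the batch inside forward()
text_emb = torch.randn(batch, 77, 768)           # assumed CLIP ViT-L/14 hidden size

with torch.no_grad():
    out = unet(sample, timesteps, encoder_hidden_states=text_emb)
assert out.sample.shape == sample.shape          # noise prediction matches the input latent shape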
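
Finally, the `__call__` shown earlier in utils/pipeline_magictime.py drives generation end to end: it encodes the prompt, prepares 5-D latents, runs the classifier-free-guidance denoising loop (updating both the diffusers and the ComfyUI progress bars), and decodes the latents to video frames. The sketch below assumes an already-assembled pipeline object named `pipeline` (VAE, tokenizer, text encoder, scheduler, and the 3D UNet); its constructor is not shown in this section, and all prompt and sampling values are placeholders:

import torch

generator = torch.Generator(device="cpu").manual_seed(42)   # reproducible initial noise
result = pipeline(
    prompt="Time-lapse of an ice cube melting on a wooden table",
    video_length=16,                 # number of frames
    height=512,
    width=512,
    num_inference_steps=25,
    guidance_scale=7.5,              # > 1.0 enables classifier-free guidance
    negative_prompt="worst quality, blurry",
    generator=generator,
    output_type="tensor",            # convert the decoded numpy video to a torch tensor
)
video = result.videos                # MagicTimePipelineOutput field; frame values in [0, 1]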