├── LICENSE ├── README.md ├── __init__.py ├── configs ├── T5_tokenizer │ ├── special_tokens_map.json │ ├── spiece.model │ ├── tokenizer.json │ └── tokenizer_config.json └── transformer_config_i2v.json ├── context.py ├── diffsynth └── vram_management │ ├── LICENSE │ ├── __init__.py │ ├── layers.py │ └── utils.py ├── enhance_a_video ├── LICENSE ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── enhance.cpython-310.pyc │ └── globals.cpython-310.pyc ├── enhance.py └── globals.py ├── example_workflows └── Wan2.1_sei_workflow.json ├── fp8_optimization.py ├── nodes.py ├── pyproject.toml ├── requirements.txt ├── utils.py └── wanvideo ├── LICENSE ├── __pycache__ └── wan_video_vae.cpython-310.pyc ├── configs ├── __init__.py ├── shared_config.py ├── wan_i2v_14B.py ├── wan_t2v_14B.py └── wan_t2v_1_3B.py ├── model.py ├── modules ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── attention.cpython-310.pyc │ ├── clip.cpython-310.pyc │ ├── model.cpython-310.pyc │ ├── t5.cpython-310.pyc │ └── tokenizers.cpython-310.pyc ├── attention.py ├── clip.py ├── model.py ├── t5.py ├── tokenizers.py └── vae.py ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── fm_solvers.cpython-310.pyc │ └── fm_solvers_unipc.cpython-310.pyc ├── fm_solvers.py └── fm_solvers_unipc.py └── wan_video_vae_SE.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-WanVideoStartEndFrames 2 | ComfyUI nodes that support video generation by start and end frames 3 | 4 | # Start 5 | This project is a node-based implementation for video generation using the Wan2.1 model, with a focus on start and end frame guidance. 
The source code is a modification of Kijai's node code, so for model download and installation instructions, please refer to [ComfyUI-WanVideoWrapper](https://github.com/kijai/ComfyUI-WanVideoWrapper). This project specifically adds start and end frame guided video generation. 6 | 7 | The nodes support both the 720P and 480P versions of the Wan2.1 models. It is recommended to generate videos with a frame count of 25 or higher, as a lower frame count may affect the consistency of character identity. 8 | 9 | The start and end frame approach is currently in its early stages: it implements the guidance purely at the code level and does not yet involve model or LoRA fine-tuning, which is planned for future work. Additionally, incorporating end frame guidance in image-to-video (I2V) generation seems to degrade video quality, which is another area for future improvement. 10 | 11 | I welcome discussion in the issues section and extend my gratitude to Kijai for the open-source nodes. 12 | 13 | Note: Video generation should ideally be accompanied by a positive prompt; currently, the absence of a positive prompt can result in severe video distortion. 14 | 15 | 16 | # Changelog 17 | - 2025.3.20: Added start and end frame weight controls for video transitions 18 | - 2025.3.22: Added compatibility with the SLG functionality in KJ's nodes 19 | 20 | 21 | # Examples 22 | Start Frame: 23 | ![start_frame_](https://github.com/user-attachments/assets/6c301578-56ae-45c7-8d1c-9ac5f727bf53) 24 | End Frame: 25 | ![end_frame](https://github.com/user-attachments/assets/97de3844-e974-4be9-9157-0785c564574d) 26 | Prompt: 27 | 两个角色搀扶着往前走,并看向彼此 (Two characters walk forward supporting each other and look at each other) 28 | 29 | Frame count: 81+1=82 30 | 31 | Video output: 32 | 33 | - 720P (w/o SLG): 34 | 35 | https://github.com/user-attachments/assets/948b70c7-172b-4754-8453-cd6f78b0338a 36 | 37 | 38 | - 480P (w/o SLG): 39 | 40 | https://github.com/user-attachments/assets/09c224e3-ac17-4621-bfcd-a8d449b8720e 41 | 42 | - 480P (w SLG): 43 | 44 | https://github.com/user-attachments/assets/3db1b9ed-589d-4ece-805a-c2dd8b651ff6 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] -------------------------------------------------------------------------------- /configs/T5_tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "additional_special_tokens": [ 3 | "", 4 | "", 5 | "", 6 | "", 7 | "", 8 | "", 9 | "", 10 | "", 11 | "", 12 | "", 13 | "", 14 | "", 15 | "", 16 | "", 17 | "", 18 | "", 19 | "", 20 | "", 21 | "", 22 | "", 23 | "", 24 | "", 25 | "", 26 | "", 27 | "", 28 | "", 29 | "", 30 | "", 31 | "", 32 | "", 33 | "", 34 | "", 35 | "", 36 | "", 37 | "", 38 | "", 39 | "", 40 | "", 41 | "", 42 | "", 43 | "", 44 | "", 45 | "", 46 | "", 47 | "", 48 | "", 49 | "", 50 | "", 51 | "", 52 | "", 53 | "", 54 | "", 55 | "", 56 | "", 57 | "", 58 | "", 59 | "", 60 | "", 61 | "", 62 | "", 63 | "", 64 | "", 65 | "", 66 | "", 67 | "", 68 | "", 69 | "", 70 | "", 71 | "", 72 | "", 73 | "", 74 | "", 75 | "", 76 | "", 77 | "", 78 | "", 79 | "", 80 | "", 81 | "", 82 | "", 83 | "", 84 | "", 85 | "", 86 | "", 87 | "", 88 | "", 89 | "", 90 | "", 91 | "", 92 | "", 93 | "", 94 | "", 95
| "", 96 | "", 97 | "", 98 | "", 99 | "", 100 | "", 101 | "", 102 | "", 103 | "", 104 | "", 105 | "", 106 | "", 107 | "", 108 | "", 109 | "", 110 | "", 111 | "", 112 | "", 113 | "", 114 | "", 115 | "", 116 | "", 117 | "", 118 | "", 119 | "", 120 | "", 121 | "", 122 | "", 123 | "", 124 | "", 125 | "", 126 | "", 127 | "", 128 | "", 129 | "", 130 | "", 131 | "", 132 | "", 133 | "", 134 | "", 135 | "", 136 | "", 137 | "", 138 | "", 139 | "", 140 | "", 141 | "", 142 | "", 143 | "", 144 | "", 145 | "", 146 | "", 147 | "", 148 | "", 149 | "", 150 | "", 151 | "", 152 | "", 153 | "", 154 | "", 155 | "", 156 | "", 157 | "", 158 | "", 159 | "", 160 | "", 161 | "", 162 | "", 163 | "", 164 | "", 165 | "", 166 | "", 167 | "", 168 | "", 169 | "", 170 | "", 171 | "", 172 | "", 173 | "", 174 | "", 175 | "", 176 | "", 177 | "", 178 | "", 179 | "", 180 | "", 181 | "", 182 | "", 183 | "", 184 | "", 185 | "", 186 | "", 187 | "", 188 | "", 189 | "", 190 | "", 191 | "", 192 | "", 193 | "", 194 | "", 195 | "", 196 | "", 197 | "", 198 | "", 199 | "", 200 | "", 201 | "", 202 | "", 203 | "", 204 | "", 205 | "", 206 | "", 207 | "", 208 | "", 209 | "", 210 | "", 211 | "", 212 | "", 213 | "", 214 | "", 215 | "", 216 | "", 217 | "", 218 | "", 219 | "", 220 | "", 221 | "", 222 | "", 223 | "", 224 | "", 225 | "", 226 | "", 227 | "", 228 | "", 229 | "", 230 | "", 231 | "", 232 | "", 233 | "", 234 | "", 235 | "", 236 | "", 237 | "", 238 | "", 239 | "", 240 | "", 241 | "", 242 | "", 243 | "", 244 | "", 245 | "", 246 | "", 247 | "", 248 | "", 249 | "", 250 | "", 251 | "", 252 | "", 253 | "", 254 | "", 255 | "", 256 | "", 257 | "", 258 | "", 259 | "", 260 | "", 261 | "", 262 | "", 263 | "", 264 | "", 265 | "", 266 | "", 267 | "", 268 | "", 269 | "", 270 | "", 271 | "", 272 | "", 273 | "", 274 | "", 275 | "", 276 | "", 277 | "", 278 | "", 279 | "", 280 | "", 281 | "", 282 | "", 283 | "", 284 | "", 285 | "", 286 | "", 287 | "", 288 | "", 289 | "", 290 | "", 291 | "", 292 | "", 293 | "", 294 | "", 295 | "", 296 | "", 297 | "", 298 | "", 299 | "", 300 | "", 301 | "", 302 | "" 303 | ], 304 | "bos_token": "", 305 | "eos_token": "", 306 | "pad_token": "", 307 | "unk_token": "" 308 | } 309 | -------------------------------------------------------------------------------- /configs/T5_tokenizer/spiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raindrop313/ComfyUI-WanVideoStartEndFrames/340b4f071bccc54ccf69a22864050ff44a0ff8fb/configs/T5_tokenizer/spiece.model -------------------------------------------------------------------------------- /configs/transformer_config_i2v.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "WanModel", 3 | "_diffusers_version": "0.30.0", 4 | "dim": 5120, 5 | "eps": 1e-06, 6 | "ffn_dim": 13824, 7 | "freq_dim": 256, 8 | "in_dim": 36, 9 | "model_type": "i2v", 10 | "num_heads": 40, 11 | "num_layers": 40, 12 | "out_dim": 16, 13 | "text_len": 512 14 | } -------------------------------------------------------------------------------- /context.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Callable, Optional, List 3 | 4 | 5 | def ordered_halving(val): 6 | bin_str = f"{val:064b}" 7 | bin_flip = bin_str[::-1] 8 | as_int = int(bin_flip, 2) 9 | 10 | return as_int / (1 << 64) 11 | 12 | def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]: 13 | prev_val = -1 14 | for i, 
val in enumerate(window): 15 | val = val % num_frames 16 | if val < prev_val: 17 | return True, i 18 | prev_val = val 19 | return False, -1 20 | 21 | def shift_window_to_start(window: list[int], num_frames: int): 22 | start_val = window[0] 23 | for i in range(len(window)): 24 | # 1) subtract each element by start_val to move vals relative to the start of all frames 25 | # 2) add num_frames and take modulus to get adjusted vals 26 | window[i] = ((window[i] - start_val) + num_frames) % num_frames 27 | 28 | def shift_window_to_end(window: list[int], num_frames: int): 29 | # 1) shift window to start 30 | shift_window_to_start(window, num_frames) 31 | end_val = window[-1] 32 | end_delta = num_frames - end_val - 1 33 | for i in range(len(window)): 34 | # 2) add end_delta to each val to slide windows to end 35 | window[i] = window[i] + end_delta 36 | 37 | def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]: 38 | all_indexes = list(range(num_frames)) 39 | for w in windows: 40 | for val in w: 41 | try: 42 | all_indexes.remove(val) 43 | except ValueError: 44 | pass 45 | return all_indexes 46 | 47 | def uniform_looped( 48 | step: int = ..., 49 | num_steps: Optional[int] = None, 50 | num_frames: int = ..., 51 | context_size: Optional[int] = None, 52 | context_stride: int = 3, 53 | context_overlap: int = 4, 54 | closed_loop: bool = True, 55 | ): 56 | if num_frames <= context_size: 57 | yield list(range(num_frames)) 58 | return 59 | 60 | context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1) 61 | 62 | for context_step in 1 << np.arange(context_stride): 63 | pad = int(round(num_frames * ordered_halving(step))) 64 | for j in range( 65 | int(ordered_halving(step) * context_step) + pad, 66 | num_frames + pad + (0 if closed_loop else -context_overlap), 67 | (context_size * context_step - context_overlap), 68 | ): 69 | yield [e % num_frames for e in range(j, j + context_size * context_step, context_step)] 70 | 71 | #from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) 72 | def uniform_standard( 73 | step: int = ..., 74 | num_steps: Optional[int] = None, 75 | num_frames: int = ..., 76 | context_size: Optional[int] = None, 77 | context_stride: int = 3, 78 | context_overlap: int = 4, 79 | closed_loop: bool = True, 80 | ): 81 | windows = [] 82 | if num_frames <= context_size: 83 | windows.append(list(range(num_frames))) 84 | return windows 85 | 86 | context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1) 87 | 88 | for context_step in 1 << np.arange(context_stride): 89 | pad = int(round(num_frames * ordered_halving(step))) 90 | for j in range( 91 | int(ordered_halving(step) * context_step) + pad, 92 | num_frames + pad + (0 if closed_loop else -context_overlap), 93 | (context_size * context_step - context_overlap), 94 | ): 95 | windows.append([e % num_frames for e in range(j, j + context_size * context_step, context_step)]) 96 | 97 | # now that windows are created, shift any windows that loop, and delete duplicate windows 98 | delete_idxs = [] 99 | win_i = 0 100 | while win_i < len(windows): 101 | # if window is rolls over itself, need to shift it 102 | is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames) 103 | if is_roll: 104 | roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides 105 | shift_window_to_end(windows[win_i], num_frames=num_frames) 106 | # check if next window (cyclical) is missing roll_val 107 | if 
roll_val not in windows[(win_i+1) % len(windows)]: 108 | # need to insert new window here - just insert window starting at roll_val 109 | windows.insert(win_i+1, list(range(roll_val, roll_val + context_size))) 110 | # delete window if it's not unique 111 | for pre_i in range(0, win_i): 112 | if windows[win_i] == windows[pre_i]: 113 | delete_idxs.append(win_i) 114 | break 115 | win_i += 1 116 | 117 | # reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation 118 | delete_idxs.reverse() 119 | for i in delete_idxs: 120 | windows.pop(i) 121 | return windows 122 | 123 | def static_standard( 124 | step: int = ..., 125 | num_steps: Optional[int] = None, 126 | num_frames: int = ..., 127 | context_size: Optional[int] = None, 128 | context_stride: int = 3, 129 | context_overlap: int = 4, 130 | closed_loop: bool = True, 131 | ): 132 | windows = [] 133 | if num_frames <= context_size: 134 | windows.append(list(range(num_frames))) 135 | return windows 136 | # always return the same set of windows 137 | delta = context_size - context_overlap 138 | for start_idx in range(0, num_frames, delta): 139 | # if past the end of frames, move start_idx back to allow same context_length 140 | ending = start_idx + context_size 141 | if ending >= num_frames: 142 | final_delta = ending - num_frames 143 | final_start_idx = start_idx - final_delta 144 | windows.append(list(range(final_start_idx, final_start_idx + context_size))) 145 | break 146 | windows.append(list(range(start_idx, start_idx + context_size))) 147 | return windows 148 | 149 | def get_context_scheduler(name: str) -> Callable: 150 | if name == "uniform_looped": 151 | return uniform_looped 152 | elif name == "uniform_standard": 153 | return uniform_standard 154 | elif name == "static_standard": 155 | return static_standard 156 | else: 157 | raise ValueError(f"Unknown context_overlap policy {name}") 158 | 159 | 160 | def get_total_steps( 161 | scheduler, 162 | timesteps: List[int], 163 | num_steps: Optional[int] = None, 164 | num_frames: int = ..., 165 | context_size: Optional[int] = None, 166 | context_stride: int = 3, 167 | context_overlap: int = 4, 168 | closed_loop: bool = True, 169 | ): 170 | return sum( 171 | len( 172 | list( 173 | scheduler( 174 | i, 175 | num_steps, 176 | num_frames, 177 | context_size, 178 | context_stride, 179 | context_overlap, 180 | ) 181 | ) 182 | ) 183 | for i in range(len(timesteps)) 184 | ) 185 | -------------------------------------------------------------------------------- /diffsynth/vram_management/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2023] [Zhongjie Duan] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /diffsynth/vram_management/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import * -------------------------------------------------------------------------------- /diffsynth/vram_management/layers.py: -------------------------------------------------------------------------------- 1 | import torch, copy 2 | from .utils import init_weights_on_device 3 | 4 | 5 | def cast_to(weight, dtype, device): 6 | r = torch.empty_like(weight, dtype=dtype, device=device) 7 | r.copy_(weight) 8 | return r 9 | 10 | 11 | class AutoWrappedModule(torch.nn.Module): 12 | def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device): 13 | super().__init__() 14 | self.module = module.to(dtype=offload_dtype, device=offload_device) 15 | self.offload_dtype = offload_dtype 16 | self.offload_device = offload_device 17 | self.onload_dtype = onload_dtype 18 | self.onload_device = onload_device 19 | self.computation_dtype = computation_dtype 20 | self.computation_device = computation_device 21 | self.state = 0 22 | 23 | def offload(self): 24 | if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): 25 | self.module.to(dtype=self.offload_dtype, device=self.offload_device) 26 | self.state = 0 27 | 28 | def onload(self): 29 | if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): 30 | self.module.to(dtype=self.onload_dtype, device=self.onload_device) 31 | self.state = 1 32 | 33 | def forward(self, *args, **kwargs): 34 | if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: 35 | module = self.module 36 | else: 37 | module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device) 38 | return module(*args, **kwargs) 39 | 40 | 41 | class AutoWrappedLinear(torch.nn.Linear): 42 | def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device): 43 | with init_weights_on_device(device=torch.device("meta")): 44 | super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device) 45 | self.weight = module.weight 46 | self.bias = module.bias 47 | self.offload_dtype = offload_dtype 48 | self.offload_device = offload_device 49 | self.onload_dtype = onload_dtype 50 | self.onload_device = onload_device 51 | self.computation_dtype = computation_dtype 52 | self.computation_device = computation_device 53 | self.state = 0 54 | 55 | def offload(self): 56 | if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device): 57 | self.to(dtype=self.offload_dtype, device=self.offload_device) 58 | self.state = 0 59 | 60 | def onload(self): 61 | if self.state == 0 and (self.offload_dtype != 
self.onload_dtype or self.offload_device != self.onload_device): 62 | self.to(dtype=self.onload_dtype, device=self.onload_device) 63 | self.state = 1 64 | 65 | def forward(self, x, *args, **kwargs): 66 | if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device: 67 | weight, bias = self.weight, self.bias 68 | else: 69 | weight = cast_to(self.weight, self.computation_dtype, self.computation_device) 70 | bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device) 71 | return torch.nn.functional.linear(x, weight, bias) 72 | 73 | 74 | def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0): 75 | for name, module in model.named_children(): 76 | for source_module, target_module in module_map.items(): 77 | if isinstance(module, source_module): 78 | num_param = sum(p.numel() for p in module.parameters()) 79 | if max_num_param is not None and total_num_param + num_param > max_num_param: 80 | module_config_ = overflow_module_config 81 | else: 82 | module_config_ = module_config 83 | module_ = target_module(module, **module_config_) 84 | setattr(model, name, module_) 85 | total_num_param += num_param 86 | break 87 | else: 88 | total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param) 89 | return total_num_param 90 | 91 | 92 | def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None): 93 | enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0) 94 | model.vram_management_enabled = True -------------------------------------------------------------------------------- /diffsynth/vram_management/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from contextlib import contextmanager 3 | 4 | @contextmanager 5 | def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False): 6 | 7 | old_register_parameter = torch.nn.Module.register_parameter 8 | if include_buffers: 9 | old_register_buffer = torch.nn.Module.register_buffer 10 | 11 | def register_empty_parameter(module, name, param): 12 | old_register_parameter(module, name, param) 13 | if param is not None: 14 | param_cls = type(module._parameters[name]) 15 | kwargs = module._parameters[name].__dict__ 16 | kwargs["requires_grad"] = param.requires_grad 17 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) 18 | 19 | def register_empty_buffer(module, name, buffer, persistent=True): 20 | old_register_buffer(module, name, buffer, persistent=persistent) 21 | if buffer is not None: 22 | module._buffers[name] = module._buffers[name].to(device) 23 | 24 | def patch_tensor_constructor(fn): 25 | def wrapper(*args, **kwargs): 26 | kwargs["device"] = device 27 | return fn(*args, **kwargs) 28 | 29 | return wrapper 30 | 31 | if include_buffers: 32 | tensor_constructors_to_patch = { 33 | torch_function_name: getattr(torch, torch_function_name) 34 | for torch_function_name in ["empty", "zeros", "ones", "full"] 35 | } 36 | else: 37 | tensor_constructors_to_patch = {} 38 | 39 | try: 40 | torch.nn.Module.register_parameter = register_empty_parameter 41 | if include_buffers: 42 | torch.nn.Module.register_buffer = 
register_empty_buffer 43 | for torch_function_name in tensor_constructors_to_patch.keys(): 44 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) 45 | yield 46 | finally: 47 | torch.nn.Module.register_parameter = old_register_parameter 48 | if include_buffers: 49 | torch.nn.Module.register_buffer = old_register_buffer 50 | for torch_function_name, old_torch_function in tensor_constructors_to_patch.items(): 51 | setattr(torch, torch_function_name, old_torch_function) -------------------------------------------------------------------------------- /enhance_a_video/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raindrop313/ComfyUI-WanVideoStartEndFrames/340b4f071bccc54ccf69a22864050ff44a0ff8fb/enhance_a_video/__init__.py -------------------------------------------------------------------------------- /enhance_a_video/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raindrop313/ComfyUI-WanVideoStartEndFrames/340b4f071bccc54ccf69a22864050ff44a0ff8fb/enhance_a_video/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /enhance_a_video/__pycache__/enhance.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raindrop313/ComfyUI-WanVideoStartEndFrames/340b4f071bccc54ccf69a22864050ff44a0ff8fb/enhance_a_video/__pycache__/enhance.cpython-310.pyc -------------------------------------------------------------------------------- /enhance_a_video/__pycache__/globals.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raindrop313/ComfyUI-WanVideoStartEndFrames/340b4f071bccc54ccf69a22864050ff44a0ff8fb/enhance_a_video/__pycache__/globals.cpython-310.pyc -------------------------------------------------------------------------------- /enhance_a_video/enhance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from .globals import get_enhance_weight, get_num_frames 4 | 5 | @torch.compiler.disable() 6 | def get_feta_scores(query, key): 7 | img_q, img_k = query, key 8 | 9 | num_frames = get_num_frames() 10 | 11 | B, S, N, C = img_q.shape 12 | 13 | # Calculate spatial dimension 14 | spatial_dim = S // num_frames 15 | 16 | # Add time dimension between spatial and head dims 17 | query_image = img_q.reshape(B, spatial_dim, num_frames, N, C) 18 | key_image = img_k.reshape(B, spatial_dim, num_frames, N, C) 19 | 20 | # Expand time dimension 21 | query_image = query_image.expand(-1, -1, num_frames, -1, -1) # [B, S, T, N, C] 22 | key_image = key_image.expand(-1, -1, num_frames, -1, -1) # [B, S, T, N, C] 23 | 24 | # Reshape to match feta_score input format: [(B S) N T C] 25 | query_image = rearrange(query_image, "b s t n c -> (b s) n t c") #torch.Size([3200, 24, 5, 128]) 26 | key_image = rearrange(key_image, "b s t n c -> (b s) n t c") 27 | 28 | return feta_score(query_image, key_image, C, num_frames) 29 | 30 | @torch.compiler.disable() 31 | def feta_score(query_image, key_image, head_dim, num_frames): 32 | scale = head_dim**-0.5 33 | query_image = query_image * scale 34 | attn_temp = query_image @ key_image.transpose(-2, -1) # translate attn to float32 35 | attn_temp = 
attn_temp.to(torch.float32) 36 | attn_temp = attn_temp.softmax(dim=-1) 37 | 38 | # Reshape to [batch_size * num_tokens, num_frames, num_frames] 39 | attn_temp = attn_temp.reshape(-1, num_frames, num_frames) 40 | 41 | # Create a mask for diagonal elements 42 | diag_mask = torch.eye(num_frames, device=attn_temp.device).bool() 43 | diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1) 44 | 45 | # Zero out diagonal elements 46 | attn_wo_diag = attn_temp.masked_fill(diag_mask, 0) 47 | 48 | # Calculate mean for each token's attention matrix 49 | # Number of off-diagonal elements per matrix is n*n - n 50 | num_off_diag = num_frames * num_frames - num_frames 51 | mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag 52 | 53 | enhance_scores = mean_scores.mean() * (num_frames + get_enhance_weight()) 54 | enhance_scores = enhance_scores.clamp(min=1) 55 | return enhance_scores 56 | -------------------------------------------------------------------------------- /enhance_a_video/globals.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | NUM_FRAMES = None 4 | FETA_WEIGHT = None 5 | ENABLE_FETA= False 6 | 7 | @torch.compiler.disable() 8 | def set_num_frames(num_frames: int): 9 | global NUM_FRAMES 10 | NUM_FRAMES = num_frames 11 | 12 | @torch.compiler.disable() 13 | def get_num_frames() -> int: 14 | return NUM_FRAMES 15 | 16 | 17 | def enable_enhance(): 18 | global ENABLE_FETA 19 | ENABLE_FETA = True 20 | 21 | def disable_enhance(): 22 | global ENABLE_FETA 23 | ENABLE_FETA = False 24 | 25 | @torch.compiler.disable() 26 | def is_enhance_enabled() -> bool: 27 | return ENABLE_FETA 28 | 29 | @torch.compiler.disable() 30 | def set_enhance_weight(feta_weight: float): 31 | global FETA_WEIGHT 32 | FETA_WEIGHT = feta_weight 33 | 34 | @torch.compiler.disable() 35 | def get_enhance_weight() -> float: 36 | return FETA_WEIGHT 37 | -------------------------------------------------------------------------------- /fp8_optimization.py: -------------------------------------------------------------------------------- 1 | #based on ComfyUI's and MinusZoneAI's fp8_linear optimization 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | def fp8_linear_forward(cls, original_dtype, input): 7 | weight_dtype = cls.weight.dtype 8 | if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: 9 | if len(input.shape) == 3: 10 | target_dtype = torch.float8_e5m2 if weight_dtype == torch.float8_e4m3fn else torch.float8_e4m3fn 11 | inn = input.reshape(-1, input.shape[2]).to(target_dtype) 12 | w = cls.weight.t() 13 | 14 | scale = torch.ones((1), device=input.device, dtype=torch.float32) 15 | bias = cls.bias.to(original_dtype) if cls.bias is not None else None 16 | 17 | if bias is not None: 18 | o = torch._scaled_mm(inn, w, out_dtype=original_dtype, bias=bias, scale_a=scale, scale_b=scale) 19 | else: 20 | o = torch._scaled_mm(inn, w, out_dtype=original_dtype, scale_a=scale, scale_b=scale) 21 | 22 | if isinstance(o, tuple): 23 | o = o[0] 24 | 25 | return o.reshape((-1, input.shape[1], cls.weight.shape[0])) 26 | else: 27 | return cls.original_forward(input.to(original_dtype)) 28 | else: 29 | return cls.original_forward(input) 30 | 31 | def convert_fp8_linear(module, original_dtype, params_to_keep={}): 32 | setattr(module, "fp8_matmul_enabled", True) 33 | 34 | for name, module in module.named_modules(): 35 | if not any(keyword in name for keyword in params_to_keep): 36 | if isinstance(module, nn.Linear): 37 | original_forward = module.forward 38 | 
setattr(module, "original_forward", original_forward) 39 | setattr(module, "forward", lambda input, m=module: fp8_linear_forward(m, original_dtype, input)) 40 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "ComfyUI-WanVideoStartEndFrames" 3 | description = "ComfyUI nodes that support video generation by start and end frames" 4 | version = "1.0.0" 5 | license = {file = "LICENSE"} 6 | dependencies = ["accelerate >= 1.2.1", "diffusers >= 0.32.0", "ftfy"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/raindrop313/ComfyUI-WanVideoStartEndFrames" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "raindrop-313" 14 | DisplayName = "ComfyUI-WanVideoStartEndFrames" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | accelerate>=1.2.1 3 | einops 4 | diffusers>=0.32.0 5 | sentencepiece>=0.2.0 6 | protobuf -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | import torch 3 | import logging 4 | from contextlib import contextmanager 5 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 6 | log = logging.getLogger(__name__) 7 | 8 | def check_diffusers_version(): 9 | try: 10 | version = importlib.metadata.version('diffusers') 11 | required_version = '0.31.0' 12 | if version < required_version: 13 | raise AssertionError(f"diffusers version {version} is installed, but version {required_version} or higher is required.") 14 | except importlib.metadata.PackageNotFoundError: 15 | raise AssertionError("diffusers is not installed.") 16 | 17 | def print_memory(device): 18 | memory = torch.cuda.memory_allocated(device) / 1024**3 19 | max_memory = torch.cuda.max_memory_allocated(device) / 1024**3 20 | max_reserved = torch.cuda.max_memory_reserved(device) / 1024**3 21 | log.info(f"Allocated memory: {memory=:.3f} GB") 22 | log.info(f"Max allocated memory: {max_memory=:.3f} GB") 23 | log.info(f"Max reserved memory: {max_reserved=:.3f} GB") 24 | #memory_summary = torch.cuda.memory_summary(device=device, abbreviated=False) 25 | #log.info(f"Memory Summary:\n{memory_summary}") 26 | 27 | def get_module_memory_mb(module): 28 | memory = 0 29 | for param in module.parameters(): 30 | if param.data is not None: 31 | memory += param.nelement() * param.element_size() 32 | return memory / (1024 * 1024) # Convert to MB -------------------------------------------------------------------------------- /wanvideo/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /wanvideo/__pycache__/wan_video_vae.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raindrop313/ComfyUI-WanVideoStartEndFrames/340b4f071bccc54ccf69a22864050ff44a0ff8fb/wanvideo/__pycache__/wan_video_vae.cpython-310.pyc -------------------------------------------------------------------------------- /wanvideo/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import copy 3 | import os 4 | 5 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 6 | 7 | from .wan_i2v_14B import i2v_14B 8 | from .wan_t2v_1_3B import t2v_1_3B 9 | from .wan_t2v_14B import t2v_14B 10 | 11 | # the config of t2i_14B is the same as t2v_14B 12 | t2i_14B = copy.deepcopy(t2v_14B) 13 | t2i_14B.__name__ = 'Config: Wan T2I 14B' 14 | 15 | WAN_CONFIGS = { 16 | 't2v-14B': t2v_14B, 17 | 't2v-1.3B': t2v_1_3B, 18 | 'i2v-14B': i2v_14B, 19 | 't2i-14B': t2i_14B, 20 | } 21 | 22 | SIZE_CONFIGS = { 23 | '720*1280': (720, 1280), 24 | '1280*720': (1280, 720), 25 | '480*832': (480, 832), 26 | '832*480': (832, 480), 27 | '1024*1024': (1024, 1024), 28 | } 29 | 30 | MAX_AREA_CONFIGS = { 31 | '720*1280': 720 * 1280, 32 | '1280*720': 1280 * 720, 33 | '480*832': 480 * 832, 34 | '832*480': 832 * 480, 35 | } 36 | 37 | SUPPORTED_SIZES = { 38 | 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 39 | 't2v-1.3B': ('480*832', '832*480'), 40 | 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 41 | 't2i-14B': tuple(SIZE_CONFIGS.keys()), 42 | } 43 | -------------------------------------------------------------------------------- /wanvideo/configs/shared_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import torch 3 | from easydict import EasyDict 4 | 5 | #------------------------ Wan shared config ------------------------# 6 | wan_shared_cfg = EasyDict() 7 | 8 | # t5 9 | wan_shared_cfg.t5_model = 'umt5_xxl' 10 | wan_shared_cfg.t5_dtype = torch.bfloat16 11 | wan_shared_cfg.text_len = 512 12 | 13 | # transformer 14 | wan_shared_cfg.param_dtype = torch.bfloat16 15 | 16 | # inference 17 | wan_shared_cfg.num_train_timesteps = 1000 18 | wan_shared_cfg.sample_fps = 16 19 | wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走' 20 | -------------------------------------------------------------------------------- /wanvideo/configs/wan_i2v_14B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
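# Configuration for the Wan2.1 image-to-video 14B model. It copies `wan_shared_cfg` and then
# overrides checkpoint names, CLIP/VAE settings and the transformer hyperparameters. The dict
# is exposed through `WAN_CONFIGS` in wanvideo/configs/__init__.py, so an illustrative lookup
# (the exact import path is an assumption and depends on how this repo is mounted as a
# package) would be:
#
#     from wanvideo.configs import WAN_CONFIGS, SUPPORTED_SIZES
#     cfg = WAN_CONFIGS['i2v-14B']        # the EasyDict defined below
#     print(cfg.dim, cfg.num_layers)      # 5120, 40
#     print(SUPPORTED_SIZES['i2v-14B'])   # ('720*1280', '1280*720', '480*832', '832*480')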
2 | import torch 3 | from easydict import EasyDict 4 | 5 | from .shared_config import wan_shared_cfg 6 | 7 | #------------------------ Wan I2V 14B ------------------------# 8 | 9 | i2v_14B = EasyDict(__name__='Config: Wan I2V 14B') 10 | i2v_14B.update(wan_shared_cfg) 11 | 12 | i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 13 | i2v_14B.t5_tokenizer = 'google/umt5-xxl' 14 | 15 | # clip 16 | i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14' 17 | i2v_14B.clip_dtype = torch.float16 18 | i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth' 19 | i2v_14B.clip_tokenizer = 'xlm-roberta-large' 20 | 21 | # vae 22 | i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' 23 | i2v_14B.vae_stride = (4, 8, 8) 24 | 25 | # transformer 26 | i2v_14B.patch_size = (1, 2, 2) 27 | i2v_14B.dim = 5120 28 | i2v_14B.ffn_dim = 13824 29 | i2v_14B.freq_dim = 256 30 | i2v_14B.num_heads = 40 31 | i2v_14B.num_layers = 40 32 | i2v_14B.window_size = (-1, -1) 33 | i2v_14B.qk_norm = True 34 | i2v_14B.cross_attn_norm = True 35 | i2v_14B.eps = 1e-6 36 | -------------------------------------------------------------------------------- /wanvideo/configs/wan_t2v_14B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | from easydict import EasyDict 3 | 4 | from .shared_config import wan_shared_cfg 5 | 6 | #------------------------ Wan T2V 14B ------------------------# 7 | 8 | t2v_14B = EasyDict(__name__='Config: Wan T2V 14B') 9 | t2v_14B.update(wan_shared_cfg) 10 | 11 | # t5 12 | t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 13 | t2v_14B.t5_tokenizer = 'google/umt5-xxl' 14 | 15 | # vae 16 | t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' 17 | t2v_14B.vae_stride = (4, 8, 8) 18 | 19 | # transformer 20 | t2v_14B.patch_size = (1, 2, 2) 21 | t2v_14B.dim = 5120 22 | t2v_14B.ffn_dim = 13824 23 | t2v_14B.freq_dim = 256 24 | t2v_14B.num_heads = 40 25 | t2v_14B.num_layers = 40 26 | t2v_14B.window_size = (-1, -1) 27 | t2v_14B.qk_norm = True 28 | t2v_14B.cross_attn_norm = True 29 | t2v_14B.eps = 1e-6 30 | -------------------------------------------------------------------------------- /wanvideo/configs/wan_t2v_1_3B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | from easydict import EasyDict 3 | 4 | from .shared_config import wan_shared_cfg 5 | 6 | #------------------------ Wan T2V 1.3B ------------------------# 7 | 8 | t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B') 9 | t2v_1_3B.update(wan_shared_cfg) 10 | 11 | # t5 12 | t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 13 | t2v_1_3B.t5_tokenizer = 'google/umt5-xxl' 14 | 15 | # vae 16 | t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth' 17 | t2v_1_3B.vae_stride = (4, 8, 8) 18 | 19 | # transformer 20 | t2v_1_3B.patch_size = (1, 2, 2) 21 | t2v_1_3B.dim = 1536 22 | t2v_1_3B.ffn_dim = 8960 23 | t2v_1_3B.freq_dim = 256 24 | t2v_1_3B.num_heads = 12 25 | t2v_1_3B.num_layers = 30 26 | t2v_1_3B.window_size = (-1, -1) 27 | t2v_1_3B.qk_norm = True 28 | t2v_1_3B.cross_attn_norm = True 29 | t2v_1_3B.eps = 1e-6 30 | -------------------------------------------------------------------------------- /wanvideo/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
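# WanModel: the Wan2.1 diffusion transformer backbone. It combines 3D rotary position
# embeddings (rope_params / rope_apply), RMS-normalized self-attention, t2v or i2v
# cross-attention, and per-block modulation driven by the timestep embedding. Two
# memory/speed helpers are built in:
#   * block_swap(): parks the first `blocks_to_swap + 1` transformer blocks on
#     `offload_device` and moves each one to `main_device` only while it executes.
#   * TeaCache: optionally skips a denoising step when the accumulated relative L1 change
#     of the modulated input stays below `rel_l1_thresh`, reusing the cached residual.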
2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | from diffusers.configuration_utils import ConfigMixin, register_to_config 7 | from diffusers.models.modeling_utils import ModelMixin 8 | 9 | from ...enhance_a_video.enhance import get_feta_scores 10 | from ...enhance_a_video.globals import is_enhance_enabled 11 | 12 | from .attention import attention 13 | import numpy as np 14 | __all__ = ['WanModel'] 15 | 16 | from tqdm import tqdm 17 | import gc 18 | import comfy.model_management as mm 19 | from ...utils import log, get_module_memory_mb 20 | 21 | def poly1d(coefficients, x): 22 | result = torch.zeros_like(x) 23 | for i, coeff in enumerate(coefficients): 24 | result += coeff * (x ** (len(coefficients) - 1 - i)) 25 | return result.abs() 26 | 27 | def sinusoidal_embedding_1d(dim, position): 28 | # preprocess 29 | assert dim % 2 == 0 30 | half = dim // 2 31 | position = position.type(torch.float64) 32 | 33 | # calculation 34 | sinusoid = torch.outer( 35 | position, torch.pow(10000, -torch.arange(half).to(position).div(half))) 36 | x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) 37 | return x 38 | 39 | 40 | def rope_params(max_seq_len, dim, theta=10000, L_test=25, k=0): 41 | assert dim % 2 == 0 42 | exponents = torch.arange(0, dim, 2, dtype=torch.float64).div(dim) 43 | inv_theta_pow = 1.0 / torch.pow(theta, exponents) 44 | 45 | if k > 0: 46 | print(f"RifleX: Using {k}th freq") 47 | inv_theta_pow[k-1] = 0.9 * 2 * torch.pi / L_test 48 | 49 | freqs = torch.outer(torch.arange(max_seq_len), inv_theta_pow) 50 | freqs = torch.polar(torch.ones_like(freqs), freqs) 51 | return freqs 52 | 53 | from comfy.model_management import get_torch_device, get_autocast_device 54 | @torch.autocast(device_type=get_autocast_device(get_torch_device()), enabled=False) 55 | @torch.compiler.disable() 56 | def rope_apply(x, grid_sizes, freqs): 57 | n, c = x.size(2), x.size(3) // 2 58 | 59 | # split freqs 60 | freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) 61 | 62 | # loop over samples 63 | output = [] 64 | for i, (f, h, w) in enumerate(grid_sizes.tolist()): 65 | seq_len = f * h * w 66 | 67 | # precompute multipliers 68 | x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape( 69 | seq_len, n, -1, 2)) 70 | freqs_i = torch.cat([ 71 | freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), 72 | freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), 73 | freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) 74 | ], 75 | dim=-1).reshape(seq_len, 1, -1) 76 | 77 | # apply rotary embedding 78 | x_i = torch.view_as_real(x_i * freqs_i).flatten(2) 79 | x_i = torch.cat([x_i, x[i, seq_len:]]) 80 | 81 | # append to collection 82 | output.append(x_i) 83 | return torch.stack(output).float() 84 | 85 | 86 | class WanRMSNorm(nn.Module): 87 | 88 | def __init__(self, dim, eps=1e-5): 89 | super().__init__() 90 | self.dim = dim 91 | self.eps = eps 92 | self.weight = nn.Parameter(torch.ones(dim)) 93 | 94 | def forward(self, x): 95 | r""" 96 | Args: 97 | x(Tensor): Shape [B, L, C] 98 | """ 99 | return self._norm(x.float()).type_as(x) * self.weight 100 | 101 | def _norm(self, x): 102 | return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) 103 | 104 | 105 | class WanLayerNorm(nn.LayerNorm): 106 | 107 | def __init__(self, dim, eps=1e-6, elementwise_affine=False): 108 | super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) 109 | 110 | def forward(self, x): 111 | r""" 112 | Args: 113 | x(Tensor): Shape [B, L, C] 114 | """ 115 | return 
super().forward(x.float()).type_as(x) 116 | 117 | 118 | class WanSelfAttention(nn.Module): 119 | 120 | def __init__(self, 121 | dim, 122 | num_heads, 123 | window_size=(-1, -1), 124 | qk_norm=True, 125 | eps=1e-6, 126 | attention_mode='sdpa'): 127 | assert dim % num_heads == 0 128 | super().__init__() 129 | self.dim = dim 130 | self.num_heads = num_heads 131 | self.head_dim = dim // num_heads 132 | self.window_size = window_size 133 | self.qk_norm = qk_norm 134 | self.eps = eps 135 | self.attention_mode = attention_mode 136 | 137 | # layers 138 | self.q = nn.Linear(dim, dim) 139 | self.k = nn.Linear(dim, dim) 140 | self.v = nn.Linear(dim, dim) 141 | self.o = nn.Linear(dim, dim) 142 | self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() 143 | self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() 144 | 145 | def forward(self, x, seq_lens, grid_sizes, freqs): 146 | r""" 147 | Args: 148 | x(Tensor): Shape [B, L, num_heads, C / num_heads] 149 | seq_lens(Tensor): Shape [B] 150 | grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) 151 | freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] 152 | """ 153 | b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim 154 | 155 | # query, key, value function 156 | def qkv_fn(x): 157 | q = self.norm_q(self.q(x)).view(b, s, n, d) 158 | k = self.norm_k(self.k(x)).view(b, s, n, d) 159 | v = self.v(x).view(b, s, n, d) 160 | return q, k, v 161 | 162 | q, k, v = qkv_fn(x) 163 | 164 | if self.attention_mode == 'spargeattn_tune' or self.attention_mode == 'spargeattn': 165 | tune_mode = False 166 | if self.attention_mode == 'spargeattn_tune': 167 | tune_mode = True 168 | 169 | if hasattr(self, 'inner_attention'): 170 | #print("has inner attention") 171 | q=rope_apply(q, grid_sizes, freqs) 172 | k=rope_apply(k, grid_sizes, freqs) 173 | q = q.permute(0, 2, 1, 3) 174 | k = k.permute(0, 2, 1, 3) 175 | v = v.permute(0, 2, 1, 3) 176 | x = self.inner_attention( 177 | q=q, 178 | k=k, 179 | v=v, 180 | is_causal=False, 181 | tune_mode=tune_mode 182 | ).permute(0, 2, 1, 3) 183 | #print("inner attention", x.shape) #inner attention torch.Size([1, 12, 32760, 128]) 184 | else: 185 | q=rope_apply(q, grid_sizes, freqs) 186 | k=rope_apply(k, grid_sizes, freqs) 187 | if is_enhance_enabled(): 188 | feta_scores = get_feta_scores(q, k) 189 | 190 | x = attention( 191 | q=q, 192 | k=k, 193 | v=v, 194 | k_lens=seq_lens, 195 | window_size=self.window_size, 196 | attention_mode=self.attention_mode) 197 | 198 | # output 199 | x = x.flatten(2) 200 | x = self.o(x) 201 | 202 | if is_enhance_enabled(): 203 | x *= feta_scores 204 | 205 | return x 206 | 207 | 208 | class WanT2VCrossAttention(WanSelfAttention): 209 | 210 | def forward(self, x, context, context_lens): 211 | r""" 212 | Args: 213 | x(Tensor): Shape [B, L1, C] 214 | context(Tensor): Shape [B, L2, C] 215 | context_lens(Tensor): Shape [B] 216 | """ 217 | b, n, d = x.size(0), self.num_heads, self.head_dim 218 | 219 | # compute query, key, value 220 | q = self.norm_q(self.q(x)).view(b, -1, n, d) 221 | k = self.norm_k(self.k(context)).view(b, -1, n, d) 222 | v = self.v(context).view(b, -1, n, d) 223 | 224 | # compute attention 225 | x = attention(q, k, v, k_lens=context_lens, attention_mode=self.attention_mode) 226 | 227 | # output 228 | x = x.flatten(2) 229 | x = self.o(x) 230 | return x 231 | 232 | 233 | class WanI2VCrossAttention(WanSelfAttention): 234 | 235 | def __init__(self, 236 | dim, 237 | num_heads, 238 | window_size=(-1, -1), 239 | qk_norm=True, 240 | eps=1e-6, 241 | 
attention_mode='sdpa'): 242 | super().__init__(dim, num_heads, window_size, qk_norm, eps) 243 | 244 | self.k_img = nn.Linear(dim, dim) 245 | self.v_img = nn.Linear(dim, dim) 246 | # self.alpha = nn.Parameter(torch.zeros((1, ))) 247 | self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() 248 | self.attention_mode = attention_mode 249 | 250 | def forward(self, x, context, context_lens): 251 | r""" 252 | Args: 253 | x(Tensor): Shape [B, L1, C] 254 | context(Tensor): Shape [B, L2, C] 255 | context_lens(Tensor): Shape [B] 256 | """ 257 | context_img = context[:, :257] 258 | context = context[:, 257:] 259 | b, n, d = x.size(0), self.num_heads, self.head_dim 260 | 261 | # compute query, key, value 262 | q = self.norm_q(self.q(x)).view(b, -1, n, d) 263 | k = self.norm_k(self.k(context)).view(b, -1, n, d) 264 | v = self.v(context).view(b, -1, n, d) 265 | k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d) 266 | v_img = self.v_img(context_img).view(b, -1, n, d) 267 | img_x = attention(q, k_img, v_img, k_lens=None, attention_mode=self.attention_mode) 268 | # compute attention 269 | x = attention(q, k, v, k_lens=context_lens, attention_mode=self.attention_mode) 270 | 271 | # output 272 | x = x.flatten(2) 273 | img_x = img_x.flatten(2) 274 | x = x + img_x 275 | x = self.o(x) 276 | return x 277 | 278 | 279 | WAN_CROSSATTENTION_CLASSES = { 280 | 't2v_cross_attn': WanT2VCrossAttention, 281 | 'i2v_cross_attn': WanI2VCrossAttention, 282 | } 283 | 284 | 285 | class WanAttentionBlock(nn.Module): 286 | 287 | def __init__(self, 288 | cross_attn_type, 289 | dim, 290 | ffn_dim, 291 | num_heads, 292 | window_size=(-1, -1), 293 | qk_norm=True, 294 | cross_attn_norm=False, 295 | eps=1e-6, 296 | attention_mode='sdpa'): 297 | super().__init__() 298 | self.dim = dim 299 | self.ffn_dim = ffn_dim 300 | self.num_heads = num_heads 301 | self.window_size = window_size 302 | self.qk_norm = qk_norm 303 | self.cross_attn_norm = cross_attn_norm 304 | self.eps = eps 305 | self.attention_mode = attention_mode 306 | 307 | # layers 308 | self.norm1 = WanLayerNorm(dim, eps) 309 | self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, 310 | eps, self.attention_mode) 311 | self.norm3 = WanLayerNorm( 312 | dim, eps, 313 | elementwise_affine=True) if cross_attn_norm else nn.Identity() 314 | self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, 315 | num_heads, 316 | (-1, -1), 317 | qk_norm, 318 | eps,#attention_mode=attention_mode sageattn doesn't seem faster here 319 | ) 320 | self.norm2 = WanLayerNorm(dim, eps) 321 | self.ffn = nn.Sequential( 322 | nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), 323 | nn.Linear(ffn_dim, dim)) 324 | 325 | # modulation 326 | self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) 327 | 328 | def forward( 329 | self, 330 | x, 331 | e, 332 | seq_lens, 333 | grid_sizes, 334 | freqs, 335 | context, 336 | context_lens, 337 | ): 338 | r""" 339 | Args: 340 | x(Tensor): Shape [B, L, C] 341 | e(Tensor): Shape [B, 6, C] 342 | seq_lens(Tensor): Shape [B], length of each sequence in batch 343 | grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) 344 | freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] 345 | """ 346 | assert e.dtype == torch.float32 347 | e = (self.modulation.to(torch.float32).to(e.device) + e.to(torch.float32)).chunk(6, dim=1) 348 | assert e[0].dtype == torch.float32 349 | 350 | # self-attention 351 | y = self.self_attn( 352 | self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, 
353 | freqs) 354 | x = x.to(torch.float32) + (y.to(torch.float32) * e[2].to(torch.float32)) 355 | 356 | # cross-attention & ffn function 357 | def cross_attn_ffn(x, context, context_lens, e): 358 | x = x + self.cross_attn(self.norm3(x), context, context_lens) 359 | y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) 360 | x = x.to(torch.float32) + (y.to(torch.float32) * e[5].to(torch.float32)) 361 | return x 362 | 363 | x = cross_attn_ffn(x, context, context_lens, e) 364 | return x 365 | 366 | 367 | class Head(nn.Module): 368 | 369 | def __init__(self, dim, out_dim, patch_size, eps=1e-6): 370 | super().__init__() 371 | self.dim = dim 372 | self.out_dim = out_dim 373 | self.patch_size = patch_size 374 | self.eps = eps 375 | 376 | # layers 377 | out_dim = math.prod(patch_size) * out_dim 378 | self.norm = WanLayerNorm(dim, eps) 379 | self.head = nn.Linear(dim, out_dim) 380 | 381 | # modulation 382 | self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) 383 | 384 | def forward(self, x, e): 385 | r""" 386 | Args: 387 | x(Tensor): Shape [B, L1, C] 388 | e(Tensor): Shape [B, C] 389 | """ 390 | assert e.dtype == torch.float32 391 | e_unsqueezed = e.unsqueeze(1).to(torch.float32) 392 | e = (self.modulation.to(torch.float32).to(e.device) + e_unsqueezed).chunk(2, dim=1) 393 | normed = self.norm(x).to(torch.float32) 394 | x = self.head(normed * (1 + e[1].to(torch.float32)) + e[0].to(torch.float32)) 395 | return x 396 | 397 | 398 | class MLPProj(torch.nn.Module): 399 | 400 | def __init__(self, in_dim, out_dim): 401 | super().__init__() 402 | 403 | self.proj = torch.nn.Sequential( 404 | torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), 405 | torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), 406 | torch.nn.LayerNorm(out_dim)) 407 | 408 | def forward(self, image_embeds): 409 | clip_extra_context_tokens = self.proj(image_embeds) 410 | return clip_extra_context_tokens 411 | 412 | 413 | class WanModel(ModelMixin, ConfigMixin): 414 | r""" 415 | Wan diffusion backbone supporting both text-to-video and image-to-video. 416 | """ 417 | 418 | ignore_for_config = [ 419 | 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size' 420 | ] 421 | _no_split_modules = ['WanAttentionBlock'] 422 | 423 | @register_to_config 424 | def __init__(self, 425 | model_type='t2v', 426 | patch_size=(1, 2, 2), 427 | text_len=512, 428 | in_dim=16, 429 | dim=2048, 430 | ffn_dim=8192, 431 | freq_dim=256, 432 | text_dim=4096, 433 | out_dim=16, 434 | num_heads=16, 435 | num_layers=32, 436 | window_size=(-1, -1), 437 | qk_norm=True, 438 | cross_attn_norm=True, 439 | eps=1e-6, 440 | attention_mode='sdpa', 441 | main_device=torch.device('cuda'), 442 | offload_device=torch.device('cpu'), 443 | teacache_coefficients=[],): 444 | r""" 445 | Initialize the diffusion model backbone. 
446 | 447 | Args: 448 | model_type (`str`, *optional*, defaults to 't2v'): 449 | Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) 450 | patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): 451 | 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) 452 | text_len (`int`, *optional*, defaults to 512): 453 | Fixed length for text embeddings 454 | in_dim (`int`, *optional*, defaults to 16): 455 | Input video channels (C_in) 456 | dim (`int`, *optional*, defaults to 2048): 457 | Hidden dimension of the transformer 458 | ffn_dim (`int`, *optional*, defaults to 8192): 459 | Intermediate dimension in feed-forward network 460 | freq_dim (`int`, *optional*, defaults to 256): 461 | Dimension for sinusoidal time embeddings 462 | text_dim (`int`, *optional*, defaults to 4096): 463 | Input dimension for text embeddings 464 | out_dim (`int`, *optional*, defaults to 16): 465 | Output video channels (C_out) 466 | num_heads (`int`, *optional*, defaults to 16): 467 | Number of attention heads 468 | num_layers (`int`, *optional*, defaults to 32): 469 | Number of transformer blocks 470 | window_size (`tuple`, *optional*, defaults to (-1, -1)): 471 | Window size for local attention (-1 indicates global attention) 472 | qk_norm (`bool`, *optional*, defaults to True): 473 | Enable query/key normalization 474 | cross_attn_norm (`bool`, *optional*, defaults to False): 475 | Enable cross-attention normalization 476 | eps (`float`, *optional*, defaults to 1e-6): 477 | Epsilon value for normalization layers 478 | """ 479 | 480 | super().__init__() 481 | 482 | assert model_type in ['t2v', 'i2v'] 483 | self.model_type = model_type 484 | 485 | self.patch_size = patch_size 486 | self.text_len = text_len 487 | self.in_dim = in_dim 488 | self.dim = dim 489 | self.ffn_dim = ffn_dim 490 | self.freq_dim = freq_dim 491 | self.text_dim = text_dim 492 | self.out_dim = out_dim 493 | self.num_heads = num_heads 494 | self.num_layers = num_layers 495 | self.window_size = window_size 496 | self.qk_norm = qk_norm 497 | self.cross_attn_norm = cross_attn_norm 498 | self.eps = eps 499 | self.attention_mode = attention_mode 500 | self.main_device = main_device 501 | self.offload_device = offload_device 502 | 503 | self.blocks_to_swap = -1 504 | self.offload_txt_emb = False 505 | self.offload_img_emb = False 506 | 507 | #init TeaCache variables 508 | self.enable_teacache = False 509 | self.rel_l1_thresh = 0.15 510 | self.teacache_start_step= 0 511 | self.teacache_end_step = -1 512 | self.teacache_cache_device = main_device 513 | self.teacache_state = TeaCacheState() 514 | self.teacache_coefficients = teacache_coefficients 515 | self.teacache_use_coefficients = False 516 | # self.l1_history_x = [] 517 | # self.l1_history_temb = [] 518 | # self.l1_history_rescaled = [] 519 | 520 | # embeddings 521 | self.patch_embedding = nn.Conv3d( 522 | in_dim, dim, kernel_size=patch_size, stride=patch_size) 523 | self.text_embedding = nn.Sequential( 524 | nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), 525 | nn.Linear(dim, dim)) 526 | 527 | self.time_embedding = nn.Sequential( 528 | nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) 529 | self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) 530 | 531 | # blocks 532 | cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' 533 | self.blocks = nn.ModuleList([ 534 | WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, 535 | window_size, qk_norm, cross_attn_norm, eps, 536 | attention_mode=self.attention_mode) 537 
| for _ in range(num_layers) 538 | ]) 539 | 540 | # head 541 | self.head = Head(dim, out_dim, patch_size, eps) 542 | 543 | # buffers (don't use register_buffer otherwise dtype will be changed in to()) 544 | assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 545 | 546 | 547 | if model_type == 'i2v': 548 | self.img_emb = MLPProj(1280, dim) 549 | 550 | # initialize weights 551 | #self.init_weights() 552 | 553 | def block_swap(self, blocks_to_swap, offload_txt_emb=False, offload_img_emb=False): 554 | print(f"Swapping {blocks_to_swap + 1} transformer blocks") 555 | self.blocks_to_swap = blocks_to_swap 556 | self.offload_img_emb = offload_img_emb 557 | self.offload_txt_emb = offload_txt_emb 558 | 559 | total_offload_memory = 0 560 | total_main_memory = 0 561 | 562 | for b, block in tqdm(enumerate(self.blocks), total=len(self.blocks), desc="Initializing block swap"): 563 | block_memory = get_module_memory_mb(block) 564 | 565 | if b > self.blocks_to_swap: 566 | block.to(self.main_device) 567 | total_main_memory += block_memory 568 | else: 569 | block.to(self.offload_device) 570 | total_offload_memory += block_memory 571 | 572 | mm.soft_empty_cache() 573 | gc.collect() 574 | 575 | #print(f"Block {b}: {block_memory:.2f}MB on {block.parameters().__next__().device}") 576 | log.info("----------------------") 577 | log.info(f"Block swap memory summary:") 578 | log.info(f"Transformer blocks on {self.offload_device}: {total_offload_memory:.2f}MB") 579 | log.info(f"Transformer blocks on {self.main_device}: {total_main_memory:.2f}MB") 580 | log.info(f"Total memory used by transformer blocks: {(total_offload_memory + total_main_memory):.2f}MB") 581 | log.info("----------------------") 582 | 583 | def forward( 584 | self, 585 | x, 586 | t, 587 | context, 588 | seq_len, 589 | clip_fea=None, 590 | y=None, 591 | device=torch.device('cuda'), 592 | freqs=None, 593 | current_step=0, 594 | pred_id=None 595 | ): 596 | r""" 597 | Forward pass through the diffusion model 598 | 599 | Args: 600 | x (List[Tensor]): 601 | List of input video tensors, each with shape [C_in, F, H, W] 602 | t (Tensor): 603 | Diffusion timesteps tensor of shape [B] 604 | context (List[Tensor]): 605 | List of text embeddings each with shape [L, C] 606 | seq_len (`int`): 607 | Maximum sequence length for positional encoding 608 | clip_fea (Tensor, *optional*): 609 | CLIP image features for image-to-video mode 610 | y (List[Tensor], *optional*): 611 | Conditional video inputs for image-to-video mode, same shape as x 612 | 613 | Returns: 614 | List[Tensor]: 615 | List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] 616 | """ 617 | if self.model_type == 'i2v': 618 | assert clip_fea is not None and y is not None 619 | # params 620 | #device = self.patch_embedding.weight.device 621 | if freqs.device != device: 622 | freqs = freqs.to(device) 623 | 624 | if y is not None: 625 | #torch.Size([20, 17, 58, 104]) torch.Size([16, 17, 58, 104]) 626 | #c ,t,h,w 627 | x = torch.cat([x, y], dim=0) 628 | 629 | # embeddings 630 | x = [self.patch_embedding(x.unsqueeze(0))] 631 | #print(x[0].shape, self.patch_embedding) # debug print; x is a list holding a single tensor here 632 | grid_sizes = torch.stack( 633 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) 634 | x = [u.flatten(2).transpose(1, 2) for u in x] 635 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) 636 | assert seq_lens.max() <= seq_len 637 | x = torch.cat([ 638 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], 639 | dim=1) for u in x 640 | ]) 641 | 642 | # time embeddings 643 |
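# The diffusion timestep t is turned into a sinusoidal vector of size freq_dim, passed through
# the time_embedding MLP to give `e` of shape [B, dim], and then expanded by time_projection
# into `e0` of shape [B, 6, dim]: the six per-block modulation vectors (shift/scale/gate for
# self-attention and for the FFN) consumed by every WanAttentionBlock. Both are kept in
# float32, hence the float32 autocast context below and the dtype asserts that follow.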
with torch.autocast(device_type='cuda', dtype=torch.float32): 644 | e = self.time_embedding( 645 | sinusoidal_embedding_1d(self.freq_dim, t).float()) 646 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) 647 | assert e.dtype == torch.float32 and e0.dtype == torch.float32 648 | 649 | # context 650 | context_lens = None 651 | if self.offload_txt_emb: 652 | self.text_embedding.to(self.main_device) 653 | context = self.text_embedding( 654 | torch.stack([ 655 | torch.cat( 656 | [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) 657 | for u in context 658 | ])) 659 | if self.offload_txt_emb: 660 | self.text_embedding.to(self.offload_device, non_blocking=True) 661 | 662 | if clip_fea is not None: 663 | if self.offload_img_emb: 664 | self.img_emb.to(self.main_device) 665 | context_clip = self.img_emb(clip_fea) # bs x 257 x dim 666 | context = torch.concat([context_clip, context], dim=1) 667 | if self.offload_img_emb: 668 | self.img_emb.to(self.offload_device, non_blocking=True) 669 | 670 | should_calc = True 671 | accumulated_rel_l1_distance = torch.tensor(0.0, dtype=torch.float32, device=device) 672 | if self.enable_teacache and self.teacache_start_step <= current_step <= self.teacache_end_step: 673 | if pred_id is None: 674 | pred_id = self.teacache_state.new_prediction() 675 | #log.info(current_step) 676 | #log.info(f"TeaCache: Initializing TeaCache variables for model pred: {pred_id}") 677 | should_calc = True 678 | else: 679 | previous_modulated_input = self.teacache_state.get(pred_id)['previous_modulated_input'] 680 | previous_modulated_input = previous_modulated_input.to(device) 681 | previous_residual = self.teacache_state.get(pred_id)['previous_residual'] 682 | accumulated_rel_l1_distance = self.teacache_state.get(pred_id)['accumulated_rel_l1_distance'] 683 | 684 | if self.teacache_use_coefficients: 685 | rescale_func = np.poly1d(self.teacache_coefficients) 686 | accumulated_rel_l1_distance += rescale_func(((e-previous_modulated_input).abs().mean() / previous_modulated_input.abs().mean()).cpu().item()) 687 | else: 688 | temb_relative_l1 = relative_l1_distance(previous_modulated_input, e0) 689 | accumulated_rel_l1_distance = accumulated_rel_l1_distance.to(e0.device) + temb_relative_l1 690 | 691 | #print("accumulated_rel_l1_distance", accumulated_rel_l1_distance) 692 | 693 | if accumulated_rel_l1_distance < self.rel_l1_thresh: 694 | should_calc = False 695 | else: 696 | should_calc = True 697 | accumulated_rel_l1_distance = torch.tensor(0.0, dtype=torch.float32, device=device) 698 | 699 | previous_modulated_input = e.clone() if self.teacache_use_coefficients else e0.clone() 700 | if not should_calc: 701 | x += previous_residual.to(x.device) 702 | #log.info(f"TeaCache: Skipping uncond step {current_step+1}") 703 | self.teacache_state.update( 704 | pred_id, 705 | accumulated_rel_l1_distance=accumulated_rel_l1_distance, 706 | skipped_steps=self.teacache_state.get(pred_id)['skipped_steps'] + 1, 707 | ) 708 | 709 | if not self.enable_teacache or (self.enable_teacache and should_calc): 710 | if self.enable_teacache: 711 | original_x = x.clone() 712 | # arguments 713 | kwargs = dict( 714 | e=e0, 715 | seq_lens=seq_lens, 716 | grid_sizes=grid_sizes, 717 | freqs=freqs, 718 | context=context, 719 | context_lens=context_lens) 720 | 721 | for b, block in enumerate(self.blocks): 722 | if b <= self.blocks_to_swap and self.blocks_to_swap >= 0: 723 | block.to(self.main_device) 724 | x = block(x, **kwargs) 725 | if b <= self.blocks_to_swap and self.blocks_to_swap >= 0: 726 | 
block.to(self.offload_device, non_blocking=True) 727 | 728 | if self.enable_teacache and pred_id is not None: 729 | self.teacache_state.update( 730 | pred_id, 731 | previous_residual=(x - original_x), 732 | accumulated_rel_l1_distance=accumulated_rel_l1_distance, 733 | previous_modulated_input=previous_modulated_input 734 | ) 735 | #self.teacache_state.report() 736 | 737 | # head 738 | x = self.head(x, e) 739 | # unpatchify 740 | x = self.unpatchify(x, grid_sizes) 741 | return x, pred_id 742 | 743 | def unpatchify(self, x, grid_sizes): 744 | r""" 745 | Reconstruct video tensors from patch embeddings. 746 | 747 | Args: 748 | x (List[Tensor]): 749 | List of patchified features, each with shape [L, C_out * prod(patch_size)] 750 | grid_sizes (Tensor): 751 | Original spatial-temporal grid dimensions before patching, 752 | shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) 753 | 754 | Returns: 755 | List[Tensor]: 756 | Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] 757 | """ 758 | 759 | c = self.out_dim 760 | for v in grid_sizes.tolist(): 761 | x = x[:math.prod(v)].view(*v, *self.patch_size, c) 762 | x = torch.einsum('fhwpqrc->cfphqwr', x) 763 | x = x.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) 764 | return x 765 | 766 | class TeaCacheState: 767 | def __init__(self, cache_device='cpu'): 768 | self.cache_device = cache_device 769 | self.states = {} 770 | self._next_pred_id = 0 771 | 772 | def new_prediction(self): 773 | """Create new prediction state and return its ID""" 774 | pred_id = self._next_pred_id 775 | self._next_pred_id += 1 776 | self.states[pred_id] = { 777 | 'previous_residual': None, 778 | 'accumulated_rel_l1_distance': 0, 779 | 'previous_modulated_input': None, 780 | 'skipped_steps': 0 781 | } 782 | return pred_id 783 | 784 | def update(self, pred_id, **kwargs): 785 | """Update state for specific prediction""" 786 | if pred_id not in self.states: 787 | return None 788 | for key, value in kwargs.items(): 789 | if isinstance(value, torch.Tensor): 790 | value = value.to(self.cache_device) 791 | self.states[pred_id][key] = value 792 | 793 | def get(self, pred_id): 794 | return self.states.get(pred_id, {}) 795 | 796 | def report(self): 797 | for pred_id in self.states: 798 | log.info(f"Prediction {pred_id}: {self.states[pred_id]}") 799 | 800 | def clear_prediction(self, pred_id): 801 | if pred_id in self.states: 802 | del self.states[pred_id] 803 | 804 | def clear_all(self): 805 | self.states.clear() 806 | self._next_pred_id = 0 807 | 808 | def relative_l1_distance(last_tensor, current_tensor): 809 | l1_distance = torch.abs(last_tensor.to(current_tensor.device) - current_tensor).mean() 810 | norm = torch.abs(last_tensor).mean() 811 | relative_l1_distance = l1_distance / norm 812 | return relative_l1_distance.to(torch.float32).to(current_tensor.device) 813 | 814 | def normalize_values(values): 815 | min_val = min(values) 816 | max_val = max(values) 817 | if max_val == min_val: 818 | return [0.0] * len(values) 819 | return [(x - min_val) / (max_val - min_val) for x in values] 820 | 821 | def rescale_differences(input_diffs, output_diffs): 822 | """Polynomial fitting between input and output differences""" 823 | poly_degree = 4 824 | if len(input_diffs) < 2: 825 | return input_diffs 826 | 827 | x = np.array([x.item() for x in input_diffs]) 828 | y = np.array([y.item() for y in output_diffs]) 829 | print("x ", x) 830 | print("y ", y) 831 | 832 | # Fit polynomial 833 | coeffs = np.polyfit(x, y, poly_degree) 834 | 835 | # Apply polynomial 
transformation 836 | return np.polyval(coeffs, x) -------------------------------------------------------------------------------- /wanvideo/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import WanModel 2 | from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model 3 | from .tokenizers import HuggingfaceTokenizer 4 | 5 | __all__ = [ 6 | 'WanModel', 7 | 'T5Model', 8 | 'T5Encoder', 9 | 'T5Decoder', 10 | 'T5EncoderModel', 11 | 'HuggingfaceTokenizer', 12 | ] 13 | -------------------------------------------------------------------------------- /wanvideo/modules/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raindrop313/ComfyUI-WanVideoStartEndFrames/340b4f071bccc54ccf69a22864050ff44a0ff8fb/wanvideo/modules/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /wanvideo/modules/__pycache__/attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raindrop313/ComfyUI-WanVideoStartEndFrames/340b4f071bccc54ccf69a22864050ff44a0ff8fb/wanvideo/modules/__pycache__/attention.cpython-310.pyc -------------------------------------------------------------------------------- /wanvideo/modules/__pycache__/clip.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raindrop313/ComfyUI-WanVideoStartEndFrames/340b4f071bccc54ccf69a22864050ff44a0ff8fb/wanvideo/modules/__pycache__/clip.cpython-310.pyc -------------------------------------------------------------------------------- /wanvideo/modules/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raindrop313/ComfyUI-WanVideoStartEndFrames/340b4f071bccc54ccf69a22864050ff44a0ff8fb/wanvideo/modules/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /wanvideo/modules/__pycache__/t5.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raindrop313/ComfyUI-WanVideoStartEndFrames/340b4f071bccc54ccf69a22864050ff44a0ff8fb/wanvideo/modules/__pycache__/t5.cpython-310.pyc -------------------------------------------------------------------------------- /wanvideo/modules/__pycache__/tokenizers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raindrop313/ComfyUI-WanVideoStartEndFrames/340b4f071bccc54ccf69a22864050ff44a0ff8fb/wanvideo/modules/__pycache__/tokenizers.cpython-310.pyc -------------------------------------------------------------------------------- /wanvideo/modules/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
2 | import torch 3 | 4 | try: 5 | import flash_attn_interface 6 | FLASH_ATTN_3_AVAILABLE = True 7 | except ModuleNotFoundError: 8 | FLASH_ATTN_3_AVAILABLE = False 9 | 10 | try: 11 | import flash_attn 12 | FLASH_ATTN_2_AVAILABLE = True 13 | except ModuleNotFoundError: 14 | FLASH_ATTN_2_AVAILABLE = False 15 | 16 | try: 17 | from sageattention import sageattn 18 | @torch.compiler.disable() 19 | def sageattn_func(q, k, v, attn_mask=None, dropout_p=0, is_causal=False): 20 | return sageattn(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) 21 | except Exception as e: 22 | print(f"Warning: Could not load sageattention: {str(e)}") 23 | if isinstance(e, ModuleNotFoundError): 24 | print("sageattention package is not installed") 25 | elif isinstance(e, ImportError) and "DLL" in str(e): 26 | print("sageattention DLL loading error") 27 | sageattn_func = None 28 | import warnings 29 | 30 | __all__ = [ 31 | 'flash_attention', 32 | 'attention', 33 | ] 34 | 35 | 36 | def flash_attention( 37 | q, 38 | k, 39 | v, 40 | q_lens=None, 41 | k_lens=None, 42 | dropout_p=0., 43 | softmax_scale=None, 44 | q_scale=None, 45 | causal=False, 46 | window_size=(-1, -1), 47 | deterministic=False, 48 | dtype=torch.bfloat16, 49 | version=None, 50 | ): 51 | """ 52 | q: [B, Lq, Nq, C1]. 53 | k: [B, Lk, Nk, C1]. 54 | v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. 55 | q_lens: [B]. 56 | k_lens: [B]. 57 | dropout_p: float. Dropout probability. 58 | softmax_scale: float. The scaling of QK^T before applying softmax. 59 | causal: bool. Whether to apply causal attention mask. 60 | window_size: (left right). If not (-1, -1), apply sliding window local attention. 61 | deterministic: bool. If True, slightly slower and uses more memory. 62 | dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16. 63 | """ 64 | half_dtypes = (torch.float16, torch.bfloat16) 65 | #assert dtype in half_dtypes 66 | #assert q.device.type == 'cuda' and q.size(-1) <= 256 67 | 68 | # params 69 | b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype 70 | 71 | def half(x): 72 | return x if x.dtype in half_dtypes else x.to(dtype) 73 | 74 | # preprocess query 75 | if q_lens is None: 76 | q = half(q.flatten(0, 1)) 77 | q_lens = torch.tensor( 78 | [lq] * b, dtype=torch.int32).to( 79 | device=q.device, non_blocking=True) 80 | else: 81 | q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) 82 | 83 | # preprocess key, value 84 | if k_lens is None: 85 | k = half(k.flatten(0, 1)) 86 | v = half(v.flatten(0, 1)) 87 | k_lens = torch.tensor( 88 | [lk] * b, dtype=torch.int32).to( 89 | device=k.device, non_blocking=True) 90 | else: 91 | k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) 92 | v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) 93 | 94 | q = q.to(v.dtype) 95 | k = k.to(v.dtype) 96 | 97 | if q_scale is not None: 98 | q = q * q_scale 99 | 100 | if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: 101 | warnings.warn( 102 | 'Flash attention 3 is not available, use flash attention 2 instead.' 103 | ) 104 | 105 | # apply attention 106 | if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE: 107 | # Note: dropout_p, window_size are not supported in FA3 now. 
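# q/k/v were packed above into (total_tokens, num_heads, head_dim) by flattening away the
# batch dimension (or by concatenating only the valid tokens when q_lens/k_lens are given);
# cu_seqlens_q / cu_seqlens_k are the zero-prefixed cumulative lengths that tell the varlen
# kernel where each sample starts and ends.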
108 | x = flash_attn_interface.flash_attn_varlen_func( 109 | q=q, 110 | k=k, 111 | v=v, 112 | cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( 113 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 114 | cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( 115 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 116 | seqused_q=None, 117 | seqused_k=None, 118 | max_seqlen_q=lq, 119 | max_seqlen_k=lk, 120 | softmax_scale=softmax_scale, 121 | causal=causal, 122 | deterministic=deterministic)[0].unflatten(0, (b, lq)) 123 | else: 124 | assert FLASH_ATTN_2_AVAILABLE 125 | x = flash_attn.flash_attn_varlen_func( 126 | q=q, 127 | k=k, 128 | v=v, 129 | cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( 130 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 131 | cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( 132 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 133 | max_seqlen_q=lq, 134 | max_seqlen_k=lk, 135 | dropout_p=dropout_p, 136 | softmax_scale=softmax_scale, 137 | causal=causal, 138 | window_size=window_size, 139 | deterministic=deterministic).unflatten(0, (b, lq)) 140 | 141 | # output 142 | return x.type(out_dtype) 143 | 144 | 145 | def attention( 146 | q, 147 | k, 148 | v, 149 | q_lens=None, 150 | k_lens=None, 151 | dropout_p=0., 152 | softmax_scale=None, 153 | q_scale=None, 154 | causal=False, 155 | window_size=(-1, -1), 156 | deterministic=False, 157 | dtype=torch.bfloat16, 158 | attention_mode='sdpa', 159 | ): 160 | if "flash" in attention_mode: 161 | if attention_mode == 'flash_attn_2': 162 | fa_version = 2 163 | elif attention_mode == 'flash_attn_3': 164 | fa_version = 3 165 | return flash_attention( 166 | q=q, 167 | k=k, 168 | v=v, 169 | q_lens=q_lens, 170 | k_lens=k_lens, 171 | dropout_p=dropout_p, 172 | softmax_scale=softmax_scale, 173 | q_scale=q_scale, 174 | causal=causal, 175 | window_size=window_size, 176 | deterministic=deterministic, 177 | dtype=dtype, 178 | version=fa_version, 179 | ) 180 | elif attention_mode == 'sdpa': 181 | # if q_lens is not None or k_lens is not None: 182 | # warnings.warn( 183 | # 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.' 184 | # ) 185 | attn_mask = None 186 | 187 | q = q.transpose(1, 2).to(dtype) 188 | k = k.transpose(1, 2).to(dtype) 189 | v = v.transpose(1, 2).to(dtype) 190 | 191 | out = torch.nn.functional.scaled_dot_product_attention( 192 | q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) 193 | 194 | out = out.transpose(1, 2).contiguous() 195 | return out 196 | elif attention_mode == 'sageattn': 197 | attn_mask = None 198 | 199 | q = q.transpose(1, 2).to(dtype) 200 | k = k.transpose(1, 2).to(dtype) 201 | v = v.transpose(1, 2).to(dtype) 202 | 203 | out = sageattn_func( 204 | q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) 205 | 206 | out = out.transpose(1, 2).contiguous() 207 | return out 208 | -------------------------------------------------------------------------------- /wanvideo/modules/clip.py: -------------------------------------------------------------------------------- 1 | # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip'' 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
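# CLIP vision tower (open-clip XLM-RoBERTa ViT-H/14) used to embed the conditioning image in
# i2v mode. In this repo only the visual branch is used: `CLIPModel.visual()` runs the ViT
# with use_31_block=True, i.e. it skips the final transformer block and returns the full
# 257-token sequence instead of a pooled embedding; those tokens are later projected by the
# model's img_emb (MLPProj) and prepended to the text context.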
3 | import logging 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision.transforms as T 10 | 11 | from .attention import attention 12 | 13 | __all__ = [ 14 | 'XLMRobertaCLIP', 15 | 'clip_xlm_roberta_vit_h_14', 16 | 'CLIPModel', 17 | ] 18 | from accelerate import init_empty_weights 19 | from accelerate.utils import set_module_tensor_to_device 20 | 21 | import comfy.model_management as mm 22 | 23 | def pos_interpolate(pos, seq_len): 24 | if pos.size(1) == seq_len: 25 | return pos 26 | else: 27 | src_grid = int(math.sqrt(pos.size(1))) 28 | tar_grid = int(math.sqrt(seq_len)) 29 | n = pos.size(1) - src_grid * src_grid 30 | return torch.cat([ 31 | pos[:, :n], 32 | F.interpolate( 33 | pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute( 34 | 0, 3, 1, 2), 35 | size=(tar_grid, tar_grid), 36 | mode='bicubic', 37 | align_corners=False).flatten(2).transpose(1, 2) 38 | ], 39 | dim=1) 40 | 41 | 42 | class QuickGELU(nn.Module): 43 | 44 | def forward(self, x): 45 | return x * torch.sigmoid(1.702 * x) 46 | 47 | 48 | class LayerNorm(nn.LayerNorm): 49 | 50 | def forward(self, x): 51 | return super().forward(x.float()).type_as(x) 52 | 53 | 54 | class SelfAttention(nn.Module): 55 | 56 | def __init__(self, 57 | dim, 58 | num_heads, 59 | causal=False, 60 | attn_dropout=0.0, 61 | proj_dropout=0.0): 62 | assert dim % num_heads == 0 63 | super().__init__() 64 | self.dim = dim 65 | self.num_heads = num_heads 66 | self.head_dim = dim // num_heads 67 | self.causal = causal 68 | self.attn_dropout = attn_dropout 69 | self.proj_dropout = proj_dropout 70 | 71 | # layers 72 | self.to_qkv = nn.Linear(dim, dim * 3) 73 | self.proj = nn.Linear(dim, dim) 74 | 75 | def forward(self, x): 76 | """ 77 | x: [B, L, C]. 
78 | """ 79 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim 80 | 81 | # compute query, key, value 82 | q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2) 83 | 84 | # compute attention 85 | p = self.attn_dropout if self.training else 0.0 86 | x = attention(q, k, v, dropout_p=p, causal=self.causal, attention_mode="sdpa") 87 | x = x.reshape(b, s, c) 88 | 89 | # output 90 | x = self.proj(x) 91 | x = F.dropout(x, self.proj_dropout, self.training) 92 | return x 93 | 94 | 95 | class SwiGLU(nn.Module): 96 | 97 | def __init__(self, dim, mid_dim): 98 | super().__init__() 99 | self.dim = dim 100 | self.mid_dim = mid_dim 101 | 102 | # layers 103 | self.fc1 = nn.Linear(dim, mid_dim) 104 | self.fc2 = nn.Linear(dim, mid_dim) 105 | self.fc3 = nn.Linear(mid_dim, dim) 106 | 107 | def forward(self, x): 108 | x = F.silu(self.fc1(x)) * self.fc2(x) 109 | x = self.fc3(x) 110 | return x 111 | 112 | 113 | class AttentionBlock(nn.Module): 114 | 115 | def __init__(self, 116 | dim, 117 | mlp_ratio, 118 | num_heads, 119 | post_norm=False, 120 | causal=False, 121 | activation='quick_gelu', 122 | attn_dropout=0.0, 123 | proj_dropout=0.0, 124 | norm_eps=1e-5): 125 | assert activation in ['quick_gelu', 'gelu', 'swi_glu'] 126 | super().__init__() 127 | self.dim = dim 128 | self.mlp_ratio = mlp_ratio 129 | self.num_heads = num_heads 130 | self.post_norm = post_norm 131 | self.causal = causal 132 | self.norm_eps = norm_eps 133 | 134 | # layers 135 | self.norm1 = LayerNorm(dim, eps=norm_eps) 136 | self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, 137 | proj_dropout) 138 | self.norm2 = LayerNorm(dim, eps=norm_eps) 139 | if activation == 'swi_glu': 140 | self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) 141 | else: 142 | self.mlp = nn.Sequential( 143 | nn.Linear(dim, int(dim * mlp_ratio)), 144 | QuickGELU() if activation == 'quick_gelu' else nn.GELU(), 145 | nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout)) 146 | 147 | def forward(self, x): 148 | if self.post_norm: 149 | x = x + self.norm1(self.attn(x)) 150 | x = x + self.norm2(self.mlp(x)) 151 | else: 152 | x = x + self.attn(self.norm1(x)) 153 | x = x + self.mlp(self.norm2(x)) 154 | return x 155 | 156 | class VisionTransformer(nn.Module): 157 | 158 | def __init__(self, 159 | image_size=224, 160 | patch_size=16, 161 | dim=768, 162 | mlp_ratio=4, 163 | out_dim=512, 164 | num_heads=12, 165 | num_layers=12, 166 | pool_type='token', 167 | pre_norm=True, 168 | post_norm=False, 169 | activation='quick_gelu', 170 | attn_dropout=0.0, 171 | proj_dropout=0.0, 172 | embedding_dropout=0.0, 173 | norm_eps=1e-5): 174 | if image_size % patch_size != 0: 175 | print( 176 | '[WARNING] image_size is not divisible by patch_size', 177 | flush=True) 178 | assert pool_type in ('token', 'token_fc', 'attn_pool') 179 | out_dim = out_dim or dim 180 | super().__init__() 181 | self.image_size = image_size 182 | self.patch_size = patch_size 183 | self.num_patches = (image_size // patch_size)**2 184 | self.dim = dim 185 | self.mlp_ratio = mlp_ratio 186 | self.out_dim = out_dim 187 | self.num_heads = num_heads 188 | self.num_layers = num_layers 189 | self.pool_type = pool_type 190 | self.post_norm = post_norm 191 | self.norm_eps = norm_eps 192 | 193 | # embeddings 194 | gain = 1.0 / math.sqrt(dim) 195 | self.patch_embedding = nn.Conv2d( 196 | 3, 197 | dim, 198 | kernel_size=patch_size, 199 | stride=patch_size, 200 | bias=not pre_norm) 201 | if pool_type in ('token', 'token_fc'): 202 | self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) 203 | 
self.pos_embedding = nn.Parameter(gain * torch.randn( 204 | 1, self.num_patches + 205 | (1 if pool_type in ('token', 'token_fc') else 0), dim)) 206 | self.dropout = nn.Dropout(embedding_dropout) 207 | 208 | # transformer 209 | self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None 210 | self.transformer = nn.Sequential(*[ 211 | AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, 212 | activation, attn_dropout, proj_dropout, norm_eps) 213 | for _ in range(num_layers) 214 | ]) 215 | self.post_norm = LayerNorm(dim, eps=norm_eps) 216 | 217 | # head 218 | if pool_type == 'token': 219 | self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) 220 | elif pool_type == 'token_fc': 221 | self.head = nn.Linear(dim, out_dim) 222 | 223 | def forward(self, x, interpolation=False, use_31_block=False): 224 | b = x.size(0) 225 | 226 | # embeddings 227 | x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) 228 | if self.pool_type in ('token', 'token_fc'): 229 | x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1) 230 | if interpolation: 231 | e = pos_interpolate(self.pos_embedding, x.size(1)) 232 | else: 233 | e = self.pos_embedding 234 | x = self.dropout(x + e) 235 | if self.pre_norm is not None: 236 | x = self.pre_norm(x) 237 | 238 | # transformer 239 | if use_31_block: 240 | x = self.transformer[:-1](x) 241 | return x 242 | else: 243 | x = self.transformer(x) 244 | return x 245 | 246 | 247 | class XLMRobertaCLIP(nn.Module): 248 | 249 | def __init__(self, 250 | embed_dim=1024, 251 | image_size=224, 252 | patch_size=14, 253 | vision_dim=1280, 254 | vision_mlp_ratio=4, 255 | vision_heads=16, 256 | vision_layers=32, 257 | vision_pool='token', 258 | vision_pre_norm=True, 259 | vision_post_norm=False, 260 | activation='gelu', 261 | vocab_size=250002, 262 | max_text_len=514, 263 | type_size=1, 264 | pad_id=1, 265 | text_dim=1024, 266 | text_heads=16, 267 | text_layers=24, 268 | text_post_norm=True, 269 | text_dropout=0.1, 270 | attn_dropout=0.0, 271 | proj_dropout=0.0, 272 | embedding_dropout=0.0, 273 | norm_eps=1e-5): 274 | super().__init__() 275 | self.embed_dim = embed_dim 276 | self.image_size = image_size 277 | self.patch_size = patch_size 278 | self.vision_dim = vision_dim 279 | self.vision_mlp_ratio = vision_mlp_ratio 280 | self.vision_heads = vision_heads 281 | self.vision_layers = vision_layers 282 | self.vision_pre_norm = vision_pre_norm 283 | self.vision_post_norm = vision_post_norm 284 | self.activation = activation 285 | self.vocab_size = vocab_size 286 | self.max_text_len = max_text_len 287 | self.type_size = type_size 288 | self.pad_id = pad_id 289 | self.text_dim = text_dim 290 | self.text_heads = text_heads 291 | self.text_layers = text_layers 292 | self.text_post_norm = text_post_norm 293 | self.norm_eps = norm_eps 294 | 295 | # models 296 | self.visual = VisionTransformer( 297 | image_size=image_size, 298 | patch_size=patch_size, 299 | dim=vision_dim, 300 | mlp_ratio=vision_mlp_ratio, 301 | out_dim=embed_dim, 302 | num_heads=vision_heads, 303 | num_layers=vision_layers, 304 | pool_type=vision_pool, 305 | pre_norm=vision_pre_norm, 306 | post_norm=vision_post_norm, 307 | activation=activation, 308 | attn_dropout=attn_dropout, 309 | proj_dropout=proj_dropout, 310 | embedding_dropout=embedding_dropout, 311 | norm_eps=norm_eps) 312 | self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) 313 | 314 | def forward(self, imgs, txt_ids): 315 | """ 316 | imgs: [B, 3, H, W] of torch.float32. 
317 | - mean: [0.48145466, 0.4578275, 0.40821073] 318 | - std: [0.26862954, 0.26130258, 0.27577711] 319 | txt_ids: [B, L] of torch.long. 320 | Encoded by data.CLIPTokenizer. 321 | """ 322 | xi = self.visual(imgs) 323 | xt = self.textual(txt_ids) 324 | return xi, xt 325 | 326 | def param_groups(self): 327 | groups = [{ 328 | 'params': [ 329 | p for n, p in self.named_parameters() 330 | if 'norm' in n or n.endswith('bias') 331 | ], 332 | 'weight_decay': 0.0 333 | }, { 334 | 'params': [ 335 | p for n, p in self.named_parameters() 336 | if not ('norm' in n or n.endswith('bias')) 337 | ] 338 | }] 339 | return groups 340 | 341 | 342 | def _clip(pretrained=False, 343 | pretrained_name=None, 344 | model_cls=XLMRobertaCLIP, 345 | return_transforms=False, 346 | return_tokenizer=False, 347 | tokenizer_padding='eos', 348 | dtype=torch.float32, 349 | device='cpu', 350 | **kwargs): 351 | # init a model on device 352 | with torch.device(device): 353 | model = model_cls(**kwargs) 354 | 355 | # set device 356 | #model = model.to(dtype=dtype, device=device) 357 | output = (model,) 358 | 359 | # init transforms 360 | if return_transforms: 361 | # mean and std 362 | if 'siglip' in pretrained_name.lower(): 363 | mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] 364 | else: 365 | mean = [0.48145466, 0.4578275, 0.40821073] 366 | std = [0.26862954, 0.26130258, 0.27577711] 367 | 368 | # transforms 369 | transforms = T.Compose([ 370 | T.Resize((model.image_size, model.image_size), 371 | interpolation=T.InterpolationMode.BICUBIC), 372 | T.ToTensor(), 373 | T.Normalize(mean=mean, std=std) 374 | ]) 375 | output += (transforms,) 376 | return output[0] if len(output) == 1 else output 377 | 378 | 379 | def clip_xlm_roberta_vit_h_14( 380 | pretrained=False, 381 | pretrained_name='open-clip-xlm-roberta-large-vit-huge-14', 382 | **kwargs): 383 | cfg = dict( 384 | embed_dim=1024, 385 | image_size=224, 386 | patch_size=14, 387 | vision_dim=1280, 388 | vision_mlp_ratio=4, 389 | vision_heads=16, 390 | vision_layers=32, 391 | vision_pool='token', 392 | activation='gelu', 393 | vocab_size=250002, 394 | max_text_len=514, 395 | type_size=1, 396 | pad_id=1, 397 | text_dim=1024, 398 | text_heads=16, 399 | text_layers=24, 400 | text_post_norm=True, 401 | text_dropout=0.1, 402 | attn_dropout=0.0, 403 | proj_dropout=0.0, 404 | embedding_dropout=0.0) 405 | cfg.update(**kwargs) 406 | return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg) 407 | 408 | 409 | class CLIPModel: 410 | 411 | def __init__(self, dtype, device, state_dict): 412 | self.dtype = dtype 413 | self.device = device 414 | 415 | # init model 416 | with init_empty_weights(): 417 | self.model, self.transforms = clip_xlm_roberta_vit_h_14( 418 | pretrained=False, 419 | return_transforms=True, 420 | return_tokenizer=False, 421 | dtype=dtype, 422 | device=device 423 | ) 424 | self.model = self.model.eval().requires_grad_(False) 425 | 426 | for name, param in self.model.named_parameters(): 427 | set_module_tensor_to_device(self.model, name, device=device, dtype=dtype, value=state_dict[name]) 428 | 429 | def visual(self, image): 430 | # forward 431 | with torch.autocast(device_type=mm.get_autocast_device(self.device), dtype=self.dtype): 432 | out = self.model.visual(image, use_31_block=True) 433 | return out 434 | -------------------------------------------------------------------------------- /wanvideo/modules/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
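# --- Hedged usage sketch (editor's addition, not part of the original files) --
# The CLIPModel wrapper in clip.py above pairs the open-clip XLM-R ViT-H/14
# vision tower with the preprocessing transforms returned by
# _clip(..., return_transforms=True); .visual() runs the penultimate
# transformer block (use_31_block=True) to produce conditioning tokens for
# image-to-video. The checkpoint path and image file below are illustrative
# assumptions only.
import torch
from PIL import Image

sd = torch.load("open-clip-xlm-roberta-large-vit-huge-14.pth", map_location="cpu")  # assumed checkpoint
clip = CLIPModel(dtype=torch.float16, device=torch.device("cuda"), state_dict=sd)

img = Image.open("reference.png").convert("RGB")         # assumed input image
pixels = clip.transforms(img).unsqueeze(0).to("cuda")    # [1, 3, 224, 224], CLIP mean/std
clip_fea = clip.visual(pixels)                           # roughly [1, 257, 1280] cls+patch tokens
# ------------------------------------------------------------------------------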
2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | from diffusers.configuration_utils import ConfigMixin, register_to_config 7 | from diffusers.models.modeling_utils import ModelMixin 8 | 9 | from ...enhance_a_video.enhance import get_feta_scores 10 | from ...enhance_a_video.globals import is_enhance_enabled 11 | 12 | from .attention import attention 13 | import numpy as np 14 | __all__ = ['WanModel'] 15 | 16 | from tqdm import tqdm 17 | import gc 18 | import comfy.model_management as mm 19 | from ...utils import log, get_module_memory_mb 20 | 21 | def poly1d(coefficients, x): 22 | result = torch.zeros_like(x) 23 | for i, coeff in enumerate(coefficients): 24 | result += coeff * (x ** (len(coefficients) - 1 - i)) 25 | return result.abs() 26 | 27 | def sinusoidal_embedding_1d(dim, position): 28 | # preprocess 29 | assert dim % 2 == 0 30 | half = dim // 2 31 | position = position.type(torch.float64) 32 | 33 | # calculation 34 | sinusoid = torch.outer( 35 | position, torch.pow(10000, -torch.arange(half).to(position).div(half))) 36 | x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) 37 | return x 38 | 39 | 40 | def rope_params(max_seq_len, dim, theta=10000, L_test=25, k=0): 41 | assert dim % 2 == 0 42 | exponents = torch.arange(0, dim, 2, dtype=torch.float64).div(dim) 43 | inv_theta_pow = 1.0 / torch.pow(theta, exponents) 44 | 45 | if k > 0: 46 | print(f"RifleX: Using {k}th freq") 47 | inv_theta_pow[k-1] = 0.9 * 2 * torch.pi / L_test 48 | 49 | freqs = torch.outer(torch.arange(max_seq_len), inv_theta_pow) 50 | freqs = torch.polar(torch.ones_like(freqs), freqs) 51 | return freqs 52 | 53 | from comfy.model_management import get_torch_device, get_autocast_device 54 | @torch.autocast(device_type=get_autocast_device(get_torch_device()), enabled=False) 55 | @torch.compiler.disable() 56 | def rope_apply(x, grid_sizes, freqs): 57 | n, c = x.size(2), x.size(3) // 2 58 | 59 | # split freqs 60 | freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) 61 | 62 | # loop over samples 63 | output = [] 64 | for i, (f, h, w) in enumerate(grid_sizes.tolist()): 65 | seq_len = f * h * w 66 | 67 | # precompute multipliers 68 | x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape( 69 | seq_len, n, -1, 2)) 70 | freqs_i = torch.cat([ 71 | freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), 72 | freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), 73 | freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) 74 | ], 75 | dim=-1).reshape(seq_len, 1, -1) 76 | 77 | # apply rotary embedding 78 | x_i = torch.view_as_real(x_i * freqs_i).flatten(2) 79 | x_i = torch.cat([x_i, x[i, seq_len:]]) 80 | 81 | # append to collection 82 | output.append(x_i) 83 | return torch.stack(output).float() 84 | 85 | 86 | class WanRMSNorm(nn.Module): 87 | 88 | def __init__(self, dim, eps=1e-5): 89 | super().__init__() 90 | self.dim = dim 91 | self.eps = eps 92 | self.weight = nn.Parameter(torch.ones(dim)) 93 | 94 | def forward(self, x): 95 | r""" 96 | Args: 97 | x(Tensor): Shape [B, L, C] 98 | """ 99 | return self._norm(x.float()).type_as(x) * self.weight 100 | 101 | def _norm(self, x): 102 | return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) 103 | 104 | 105 | class WanLayerNorm(nn.LayerNorm): 106 | 107 | def __init__(self, dim, eps=1e-6, elementwise_affine=False): 108 | super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) 109 | 110 | def forward(self, x): 111 | r""" 112 | Args: 113 | x(Tensor): Shape [B, L, C] 114 | """ 115 | return 
super().forward(x.float()).type_as(x) 116 | 117 | 118 | class WanSelfAttention(nn.Module): 119 | 120 | def __init__(self, 121 | dim, 122 | num_heads, 123 | window_size=(-1, -1), 124 | qk_norm=True, 125 | eps=1e-6, 126 | attention_mode='sdpa'): 127 | assert dim % num_heads == 0 128 | super().__init__() 129 | self.dim = dim 130 | self.num_heads = num_heads 131 | self.head_dim = dim // num_heads 132 | self.window_size = window_size 133 | self.qk_norm = qk_norm 134 | self.eps = eps 135 | self.attention_mode = attention_mode 136 | 137 | # layers 138 | self.q = nn.Linear(dim, dim) 139 | self.k = nn.Linear(dim, dim) 140 | self.v = nn.Linear(dim, dim) 141 | self.o = nn.Linear(dim, dim) 142 | self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() 143 | self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() 144 | 145 | def forward(self, x, seq_lens, grid_sizes, freqs): 146 | r""" 147 | Args: 148 | x(Tensor): Shape [B, L, num_heads, C / num_heads] 149 | seq_lens(Tensor): Shape [B] 150 | grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) 151 | freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] 152 | """ 153 | b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim 154 | 155 | # query, key, value function 156 | def qkv_fn(x): 157 | q = self.norm_q(self.q(x)).view(b, s, n, d) 158 | k = self.norm_k(self.k(x)).view(b, s, n, d) 159 | v = self.v(x).view(b, s, n, d) 160 | return q, k, v 161 | 162 | q, k, v = qkv_fn(x) 163 | 164 | if self.attention_mode == 'spargeattn_tune' or self.attention_mode == 'spargeattn': 165 | tune_mode = False 166 | if self.attention_mode == 'spargeattn_tune': 167 | tune_mode = True 168 | 169 | if hasattr(self, 'inner_attention'): 170 | #print("has inner attention") 171 | q=rope_apply(q, grid_sizes, freqs) 172 | k=rope_apply(k, grid_sizes, freqs) 173 | q = q.permute(0, 2, 1, 3) 174 | k = k.permute(0, 2, 1, 3) 175 | v = v.permute(0, 2, 1, 3) 176 | x = self.inner_attention( 177 | q=q, 178 | k=k, 179 | v=v, 180 | is_causal=False, 181 | tune_mode=tune_mode 182 | ).permute(0, 2, 1, 3) 183 | #print("inner attention", x.shape) #inner attention torch.Size([1, 12, 32760, 128]) 184 | else: 185 | q=rope_apply(q, grid_sizes, freqs) 186 | k=rope_apply(k, grid_sizes, freqs) 187 | if is_enhance_enabled(): 188 | feta_scores = get_feta_scores(q, k) 189 | 190 | x = attention( 191 | q=q, 192 | k=k, 193 | v=v, 194 | k_lens=seq_lens, 195 | window_size=self.window_size, 196 | attention_mode=self.attention_mode) 197 | 198 | # output 199 | x = x.flatten(2) 200 | x = self.o(x) 201 | 202 | if is_enhance_enabled(): 203 | x *= feta_scores 204 | 205 | return x 206 | 207 | 208 | class WanT2VCrossAttention(WanSelfAttention): 209 | 210 | def forward(self, x, context, context_lens): 211 | r""" 212 | Args: 213 | x(Tensor): Shape [B, L1, C] 214 | context(Tensor): Shape [B, L2, C] 215 | context_lens(Tensor): Shape [B] 216 | """ 217 | b, n, d = x.size(0), self.num_heads, self.head_dim 218 | 219 | # compute query, key, value 220 | q = self.norm_q(self.q(x)).view(b, -1, n, d) 221 | k = self.norm_k(self.k(context)).view(b, -1, n, d) 222 | v = self.v(context).view(b, -1, n, d) 223 | 224 | # compute attention 225 | x = attention(q, k, v, k_lens=context_lens, attention_mode=self.attention_mode) 226 | 227 | # output 228 | x = x.flatten(2) 229 | x = self.o(x) 230 | return x 231 | 232 | 233 | class WanI2VCrossAttention(WanSelfAttention): 234 | 235 | def __init__(self, 236 | dim, 237 | num_heads, 238 | window_size=(-1, -1), 239 | qk_norm=True, 240 | eps=1e-6, 241 | 
attention_mode='sdpa'): 242 | super().__init__(dim, num_heads, window_size, qk_norm, eps) 243 | 244 | self.k_img = nn.Linear(dim, dim) 245 | self.v_img = nn.Linear(dim, dim) 246 | # self.alpha = nn.Parameter(torch.zeros((1, ))) 247 | self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() 248 | self.attention_mode = attention_mode 249 | 250 | def forward(self, x, context, context_lens): 251 | r""" 252 | Args: 253 | x(Tensor): Shape [B, L1, C] 254 | context(Tensor): Shape [B, L2, C] 255 | context_lens(Tensor): Shape [B] 256 | """ 257 | context_img = context[:, :257] 258 | context = context[:, 257:] 259 | b, n, d = x.size(0), self.num_heads, self.head_dim 260 | 261 | # compute query, key, value 262 | q = self.norm_q(self.q(x)).view(b, -1, n, d) 263 | k = self.norm_k(self.k(context)).view(b, -1, n, d) 264 | v = self.v(context).view(b, -1, n, d) 265 | k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d) 266 | v_img = self.v_img(context_img).view(b, -1, n, d) 267 | img_x = attention(q, k_img, v_img, k_lens=None, attention_mode=self.attention_mode) 268 | # compute attention 269 | x = attention(q, k, v, k_lens=context_lens, attention_mode=self.attention_mode) 270 | 271 | # output 272 | x = x.flatten(2) 273 | img_x = img_x.flatten(2) 274 | x = x + img_x 275 | x = self.o(x) 276 | return x 277 | 278 | 279 | WAN_CROSSATTENTION_CLASSES = { 280 | 't2v_cross_attn': WanT2VCrossAttention, 281 | 'i2v_cross_attn': WanI2VCrossAttention, 282 | } 283 | 284 | 285 | class WanAttentionBlock(nn.Module): 286 | 287 | def __init__(self, 288 | cross_attn_type, 289 | dim, 290 | ffn_dim, 291 | num_heads, 292 | window_size=(-1, -1), 293 | qk_norm=True, 294 | cross_attn_norm=False, 295 | eps=1e-6, 296 | attention_mode='sdpa'): 297 | super().__init__() 298 | self.dim = dim 299 | self.ffn_dim = ffn_dim 300 | self.num_heads = num_heads 301 | self.window_size = window_size 302 | self.qk_norm = qk_norm 303 | self.cross_attn_norm = cross_attn_norm 304 | self.eps = eps 305 | self.attention_mode = attention_mode 306 | 307 | # layers 308 | self.norm1 = WanLayerNorm(dim, eps) 309 | self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, 310 | eps, self.attention_mode) 311 | self.norm3 = WanLayerNorm( 312 | dim, eps, 313 | elementwise_affine=True) if cross_attn_norm else nn.Identity() 314 | self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, 315 | num_heads, 316 | (-1, -1), 317 | qk_norm, 318 | eps,#attention_mode=attention_mode sageattn doesn't seem faster here 319 | ) 320 | self.norm2 = WanLayerNorm(dim, eps) 321 | self.ffn = nn.Sequential( 322 | nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), 323 | nn.Linear(ffn_dim, dim)) 324 | 325 | # modulation 326 | self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) 327 | 328 | def forward( 329 | self, 330 | x, 331 | e, 332 | seq_lens, 333 | grid_sizes, 334 | freqs, 335 | context, 336 | context_lens, 337 | ): 338 | r""" 339 | Args: 340 | x(Tensor): Shape [B, L, C] 341 | e(Tensor): Shape [B, 6, C] 342 | seq_lens(Tensor): Shape [B], length of each sequence in batch 343 | grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) 344 | freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] 345 | """ 346 | assert e.dtype == torch.float32 347 | e = (self.modulation.to(torch.float32).to(e.device) + e.to(torch.float32)).chunk(6, dim=1) 348 | assert e[0].dtype == torch.float32 349 | 350 | # self-attention 351 | y = self.self_attn( 352 | self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, 
353 | freqs) 354 | x = x.to(torch.float32) + (y.to(torch.float32) * e[2].to(torch.float32)) 355 | 356 | # cross-attention & ffn function 357 | def cross_attn_ffn(x, context, context_lens, e): 358 | x = x + self.cross_attn(self.norm3(x), context, context_lens) 359 | y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) 360 | x = x.to(torch.float32) + (y.to(torch.float32) * e[5].to(torch.float32)) 361 | return x 362 | 363 | x = cross_attn_ffn(x, context, context_lens, e) 364 | return x 365 | 366 | 367 | class Head(nn.Module): 368 | 369 | def __init__(self, dim, out_dim, patch_size, eps=1e-6): 370 | super().__init__() 371 | self.dim = dim 372 | self.out_dim = out_dim 373 | self.patch_size = patch_size 374 | self.eps = eps 375 | 376 | # layers 377 | out_dim = math.prod(patch_size) * out_dim 378 | self.norm = WanLayerNorm(dim, eps) 379 | self.head = nn.Linear(dim, out_dim) 380 | 381 | # modulation 382 | self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) 383 | 384 | def forward(self, x, e): 385 | r""" 386 | Args: 387 | x(Tensor): Shape [B, L1, C] 388 | e(Tensor): Shape [B, C] 389 | """ 390 | assert e.dtype == torch.float32 391 | e_unsqueezed = e.unsqueeze(1).to(torch.float32) 392 | e = (self.modulation.to(torch.float32).to(e.device) + e_unsqueezed).chunk(2, dim=1) 393 | normed = self.norm(x).to(torch.float32) 394 | x = self.head(normed * (1 + e[1].to(torch.float32)) + e[0].to(torch.float32)) 395 | return x 396 | 397 | 398 | class MLPProj(torch.nn.Module): 399 | 400 | def __init__(self, in_dim, out_dim): 401 | super().__init__() 402 | 403 | self.proj = torch.nn.Sequential( 404 | torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), 405 | torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), 406 | torch.nn.LayerNorm(out_dim)) 407 | 408 | def forward(self, image_embeds): 409 | clip_extra_context_tokens = self.proj(image_embeds) 410 | return clip_extra_context_tokens 411 | 412 | 413 | class WanModel(ModelMixin, ConfigMixin): 414 | r""" 415 | Wan diffusion backbone supporting both text-to-video and image-to-video. 416 | """ 417 | 418 | ignore_for_config = [ 419 | 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size' 420 | ] 421 | _no_split_modules = ['WanAttentionBlock'] 422 | 423 | @register_to_config 424 | def __init__(self, 425 | model_type='t2v', 426 | patch_size=(1, 2, 2), 427 | text_len=512, 428 | in_dim=16, 429 | dim=2048, 430 | ffn_dim=8192, 431 | freq_dim=256, 432 | text_dim=4096, 433 | out_dim=16, 434 | num_heads=16, 435 | num_layers=32, 436 | window_size=(-1, -1), 437 | qk_norm=True, 438 | cross_attn_norm=True, 439 | eps=1e-6, 440 | attention_mode='sdpa', 441 | main_device=torch.device('cuda'), 442 | offload_device=torch.device('cpu'), 443 | teacache_coefficients=[],): 444 | r""" 445 | Initialize the diffusion model backbone. 
446 | 447 | Args: 448 | model_type (`str`, *optional*, defaults to 't2v'): 449 | Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) 450 | patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): 451 | 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) 452 | text_len (`int`, *optional*, defaults to 512): 453 | Fixed length for text embeddings 454 | in_dim (`int`, *optional*, defaults to 16): 455 | Input video channels (C_in) 456 | dim (`int`, *optional*, defaults to 2048): 457 | Hidden dimension of the transformer 458 | ffn_dim (`int`, *optional*, defaults to 8192): 459 | Intermediate dimension in feed-forward network 460 | freq_dim (`int`, *optional*, defaults to 256): 461 | Dimension for sinusoidal time embeddings 462 | text_dim (`int`, *optional*, defaults to 4096): 463 | Input dimension for text embeddings 464 | out_dim (`int`, *optional*, defaults to 16): 465 | Output video channels (C_out) 466 | num_heads (`int`, *optional*, defaults to 16): 467 | Number of attention heads 468 | num_layers (`int`, *optional*, defaults to 32): 469 | Number of transformer blocks 470 | window_size (`tuple`, *optional*, defaults to (-1, -1)): 471 | Window size for local attention (-1 indicates global attention) 472 | qk_norm (`bool`, *optional*, defaults to True): 473 | Enable query/key normalization 474 | cross_attn_norm (`bool`, *optional*, defaults to False): 475 | Enable cross-attention normalization 476 | eps (`float`, *optional*, defaults to 1e-6): 477 | Epsilon value for normalization layers 478 | """ 479 | 480 | super().__init__() 481 | 482 | assert model_type in ['t2v', 'i2v'] 483 | self.model_type = model_type 484 | 485 | self.patch_size = patch_size 486 | self.text_len = text_len 487 | self.in_dim = in_dim 488 | self.dim = dim 489 | self.ffn_dim = ffn_dim 490 | self.freq_dim = freq_dim 491 | self.text_dim = text_dim 492 | self.out_dim = out_dim 493 | self.num_heads = num_heads 494 | self.num_layers = num_layers 495 | self.window_size = window_size 496 | self.qk_norm = qk_norm 497 | self.cross_attn_norm = cross_attn_norm 498 | self.eps = eps 499 | self.attention_mode = attention_mode 500 | self.main_device = main_device 501 | self.offload_device = offload_device 502 | 503 | self.blocks_to_swap = -1 504 | self.offload_txt_emb = False 505 | self.offload_img_emb = False 506 | 507 | #init TeaCache variables 508 | self.enable_teacache = False 509 | self.rel_l1_thresh = 0.15 510 | self.teacache_start_step= 0 511 | self.teacache_end_step = -1 512 | self.teacache_cache_device = main_device 513 | self.teacache_state = TeaCacheState() 514 | self.teacache_coefficients = teacache_coefficients 515 | self.teacache_use_coefficients = False 516 | 517 | 518 | self.slg_blocks = None 519 | self.slg_start_percent = 0.0 520 | self.slg_end_percent = 1.0 521 | # self.l1_history_x = [] 522 | # self.l1_history_temb = [] 523 | # self.l1_history_rescaled = [] 524 | 525 | # embeddings 526 | self.patch_embedding = nn.Conv3d( 527 | in_dim, dim, kernel_size=patch_size, stride=patch_size) 528 | self.text_embedding = nn.Sequential( 529 | nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), 530 | nn.Linear(dim, dim)) 531 | 532 | self.time_embedding = nn.Sequential( 533 | nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) 534 | self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) 535 | 536 | # blocks 537 | cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' 538 | self.blocks = nn.ModuleList([ 539 | WanAttentionBlock(cross_attn_type, dim, 
ffn_dim, num_heads, 540 | window_size, qk_norm, cross_attn_norm, eps, 541 | attention_mode=self.attention_mode) 542 | for _ in range(num_layers) 543 | ]) 544 | 545 | # head 546 | self.head = Head(dim, out_dim, patch_size, eps) 547 | 548 | # buffers (don't use register_buffer otherwise dtype will be changed in to()) 549 | assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 550 | 551 | 552 | if model_type == 'i2v': 553 | self.img_emb = MLPProj(1280, dim) 554 | 555 | # initialize weights 556 | #self.init_weights() 557 | 558 | def block_swap(self, blocks_to_swap, offload_txt_emb=False, offload_img_emb=False): 559 | print(f"Swapping {blocks_to_swap + 1} transformer blocks") 560 | self.blocks_to_swap = blocks_to_swap 561 | self.offload_img_emb = offload_img_emb 562 | self.offload_txt_emb = offload_txt_emb 563 | 564 | total_offload_memory = 0 565 | total_main_memory = 0 566 | 567 | for b, block in tqdm(enumerate(self.blocks), total=len(self.blocks), desc="Initializing block swap"): 568 | block_memory = get_module_memory_mb(block) 569 | 570 | if b > self.blocks_to_swap: 571 | block.to(self.main_device) 572 | total_main_memory += block_memory 573 | else: 574 | block.to(self.offload_device) 575 | total_offload_memory += block_memory 576 | 577 | mm.soft_empty_cache() 578 | gc.collect() 579 | 580 | #print(f"Block {b}: {block_memory:.2f}MB on {block.parameters().__next__().device}") 581 | log.info("----------------------") 582 | log.info(f"Block swap memory summary:") 583 | log.info(f"Transformer blocks on {self.offload_device}: {total_offload_memory:.2f}MB") 584 | log.info(f"Transformer blocks on {self.main_device}: {total_main_memory:.2f}MB") 585 | log.info(f"Total memory used by transformer blocks: {(total_offload_memory + total_main_memory):.2f}MB") 586 | log.info("----------------------") 587 | 588 | def forward( 589 | self, 590 | x, 591 | t, 592 | context, 593 | seq_len, 594 | is_uncond=False, 595 | current_step_percentage=0.0, 596 | clip_fea=None, 597 | y=None, 598 | device=torch.device('cuda'), 599 | freqs=None, 600 | current_step=0, 601 | pred_id=None 602 | ): 603 | r""" 604 | Forward pass through the diffusion model 605 | 606 | Args: 607 | x (List[Tensor]): 608 | List of input video tensors, each with shape [C_in, F, H, W] 609 | t (Tensor): 610 | Diffusion timesteps tensor of shape [B] 611 | context (List[Tensor]): 612 | List of text embeddings each with shape [L, C] 613 | seq_len (`int`): 614 | Maximum sequence length for positional encoding 615 | clip_fea (Tensor, *optional*): 616 | CLIP image features for image-to-video mode 617 | y (List[Tensor], *optional*): 618 | Conditional video inputs for image-to-video mode, same shape as x 619 | 620 | Returns: 621 | List[Tensor]: 622 | List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] 623 | """ 624 | if self.model_type == 'i2v': 625 | assert clip_fea is not None and y is not None 626 | # params 627 | #device = self.patch_embedding.weight.device 628 | if freqs.device != device: 629 | freqs = freqs.to(device) 630 | 631 | if y is not None: 632 | #torch.Size([20, 17, 58, 104]) torch.Size([16, 17, 58, 104]) 633 | #c ,t,h,w 634 | x = torch.cat([x, y], dim=0) 635 | 636 | # embeddings 637 | x = [self.patch_embedding(x.unsqueeze(0))] 638 | grid_sizes = torch.stack( 639 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) 640 | x = [u.flatten(2).transpose(1, 2) for u in x] 641 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) 642 | assert seq_lens.max() <= seq_len 643 | x = torch.cat([ 644 | 
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], 645 | dim=1) for u in x 646 | ]) 647 | 648 | # time embeddings 649 | with torch.autocast(device_type='cuda', dtype=torch.float32): 650 | e = self.time_embedding( 651 | sinusoidal_embedding_1d(self.freq_dim, t).float()) 652 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) 653 | assert e.dtype == torch.float32 and e0.dtype == torch.float32 654 | 655 | # context 656 | context_lens = None 657 | if self.offload_txt_emb: 658 | self.text_embedding.to(self.main_device) 659 | context = self.text_embedding( 660 | torch.stack([ 661 | torch.cat( 662 | [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) 663 | for u in context 664 | ])) 665 | if self.offload_txt_emb: 666 | self.text_embedding.to(self.offload_device, non_blocking=True) 667 | 668 | if clip_fea is not None: 669 | if self.offload_img_emb: 670 | self.img_emb.to(self.main_device) 671 | context_clip = self.img_emb(clip_fea) # bs x 257 x dim 672 | context = torch.concat([context_clip, context], dim=1) 673 | if self.offload_img_emb: 674 | self.img_emb.to(self.offload_device, non_blocking=True) 675 | 676 | should_calc = True 677 | accumulated_rel_l1_distance = torch.tensor(0.0, dtype=torch.float32, device=device) 678 | if self.enable_teacache and self.teacache_start_step <= current_step <= self.teacache_end_step: 679 | if pred_id is None: 680 | pred_id = self.teacache_state.new_prediction() 681 | #log.info(current_step) 682 | #log.info(f"TeaCache: Initializing TeaCache variables for model pred: {pred_id}") 683 | should_calc = True 684 | else: 685 | previous_modulated_input = self.teacache_state.get(pred_id)['previous_modulated_input'] 686 | previous_modulated_input = previous_modulated_input.to(device) 687 | previous_residual = self.teacache_state.get(pred_id)['previous_residual'] 688 | accumulated_rel_l1_distance = self.teacache_state.get(pred_id)['accumulated_rel_l1_distance'] 689 | 690 | if self.teacache_use_coefficients: 691 | rescale_func = np.poly1d(self.teacache_coefficients) 692 | accumulated_rel_l1_distance += rescale_func(((e-previous_modulated_input).abs().mean() / previous_modulated_input.abs().mean()).cpu().item()) 693 | else: 694 | temb_relative_l1 = relative_l1_distance(previous_modulated_input, e0) 695 | accumulated_rel_l1_distance = accumulated_rel_l1_distance.to(e0.device) + temb_relative_l1 696 | 697 | #print("accumulated_rel_l1_distance", accumulated_rel_l1_distance) 698 | 699 | if accumulated_rel_l1_distance < self.rel_l1_thresh: 700 | should_calc = False 701 | else: 702 | should_calc = True 703 | accumulated_rel_l1_distance = torch.tensor(0.0, dtype=torch.float32, device=device) 704 | 705 | previous_modulated_input = e.clone() if self.teacache_use_coefficients else e0.clone() 706 | if not should_calc: 707 | x += previous_residual.to(x.device) 708 | #log.info(f"TeaCache: Skipping uncond step {current_step+1}") 709 | self.teacache_state.update( 710 | pred_id, 711 | accumulated_rel_l1_distance=accumulated_rel_l1_distance, 712 | skipped_steps=self.teacache_state.get(pred_id)['skipped_steps'] + 1, 713 | ) 714 | 715 | if not self.enable_teacache or (self.enable_teacache and should_calc): 716 | if self.enable_teacache: 717 | original_x = x.clone() 718 | # arguments 719 | kwargs = dict( 720 | e=e0, 721 | seq_lens=seq_lens, 722 | grid_sizes=grid_sizes, 723 | freqs=freqs, 724 | context=context, 725 | context_lens=context_lens) 726 | 727 | for b, block in enumerate(self.blocks): 728 | if self.slg_blocks is not None: 729 | if b in self.slg_blocks and is_uncond: 
730 | if self.slg_start_percent <= current_step_percentage <= self.slg_end_percent: 731 | continue 732 | if b <= self.blocks_to_swap and self.blocks_to_swap >= 0: 733 | block.to(self.main_device) 734 | x = block(x, **kwargs) 735 | if b <= self.blocks_to_swap and self.blocks_to_swap >= 0: 736 | block.to(self.offload_device, non_blocking=True) 737 | 738 | if self.enable_teacache and pred_id is not None: 739 | self.teacache_state.update( 740 | pred_id, 741 | previous_residual=(x - original_x), 742 | accumulated_rel_l1_distance=accumulated_rel_l1_distance, 743 | previous_modulated_input=previous_modulated_input 744 | ) 745 | #self.teacache_state.report() 746 | 747 | # head 748 | x = self.head(x, e) 749 | # unpatchify 750 | x = self.unpatchify(x, grid_sizes) 751 | return x, pred_id 752 | 753 | def unpatchify(self, x, grid_sizes): 754 | r""" 755 | Reconstruct video tensors from patch embeddings. 756 | 757 | Args: 758 | x (List[Tensor]): 759 | List of patchified features, each with shape [L, C_out * prod(patch_size)] 760 | grid_sizes (Tensor): 761 | Original spatial-temporal grid dimensions before patching, 762 | shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) 763 | 764 | Returns: 765 | List[Tensor]: 766 | Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] 767 | """ 768 | 769 | c = self.out_dim 770 | for v in grid_sizes.tolist(): 771 | x = x[:math.prod(v)].view(*v, *self.patch_size, c) 772 | x = torch.einsum('fhwpqrc->cfphqwr', x) 773 | x = x.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) 774 | return x 775 | 776 | class TeaCacheState: 777 | def __init__(self, cache_device='cpu'): 778 | self.cache_device = cache_device 779 | self.states = {} 780 | self._next_pred_id = 0 781 | 782 | def new_prediction(self): 783 | """Create new prediction state and return its ID""" 784 | pred_id = self._next_pred_id 785 | self._next_pred_id += 1 786 | self.states[pred_id] = { 787 | 'previous_residual': None, 788 | 'accumulated_rel_l1_distance': 0, 789 | 'previous_modulated_input': None, 790 | 'skipped_steps': 0 791 | } 792 | return pred_id 793 | 794 | def update(self, pred_id, **kwargs): 795 | """Update state for specific prediction""" 796 | if pred_id not in self.states: 797 | return None 798 | for key, value in kwargs.items(): 799 | if isinstance(value, torch.Tensor): 800 | value = value.to(self.cache_device) 801 | self.states[pred_id][key] = value 802 | 803 | def get(self, pred_id): 804 | return self.states.get(pred_id, {}) 805 | 806 | def report(self): 807 | for pred_id in self.states: 808 | log.info(f"Prediction {pred_id}: {self.states[pred_id]}") 809 | 810 | def clear_prediction(self, pred_id): 811 | if pred_id in self.states: 812 | del self.states[pred_id] 813 | 814 | def clear_all(self): 815 | self.states.clear() 816 | self._next_pred_id = 0 817 | 818 | def relative_l1_distance(last_tensor, current_tensor): 819 | l1_distance = torch.abs(last_tensor.to(current_tensor.device) - current_tensor).mean() 820 | norm = torch.abs(last_tensor).mean() 821 | relative_l1_distance = l1_distance / norm 822 | return relative_l1_distance.to(torch.float32).to(current_tensor.device) 823 | 824 | def normalize_values(values): 825 | min_val = min(values) 826 | max_val = max(values) 827 | if max_val == min_val: 828 | return [0.0] * len(values) 829 | return [(x - min_val) / (max_val - min_val) for x in values] 830 | 831 | def rescale_differences(input_diffs, output_diffs): 832 | """Polynomial fitting between input and output differences""" 833 | poly_degree = 4 834 | if 
len(input_diffs) < 2: 835 | return input_diffs 836 | 837 | x = np.array([x.item() for x in input_diffs]) 838 | y = np.array([y.item() for y in output_diffs]) 839 | print("x ", x) 840 | print("y ", y) 841 | 842 | # Fit polynomial 843 | coeffs = np.polyfit(x, y, poly_degree) 844 | 845 | # Apply polynomial transformation 846 | return np.polyval(coeffs, x) -------------------------------------------------------------------------------- /wanvideo/modules/t5.py: -------------------------------------------------------------------------------- 1 | # Modified from transformers.models.t5.modeling_t5 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 3 | import logging 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .tokenizers import HuggingfaceTokenizer 11 | 12 | __all__ = [ 13 | 'T5Model', 14 | 'T5Encoder', 15 | 'T5Decoder', 16 | 'T5EncoderModel', 17 | ] 18 | 19 | from accelerate import init_empty_weights 20 | from accelerate.utils import set_module_tensor_to_device 21 | 22 | def fp16_clamp(x): 23 | if x.dtype == torch.float16 and torch.isinf(x).any(): 24 | clamp = torch.finfo(x.dtype).max - 1000 25 | x = torch.clamp(x, min=-clamp, max=clamp) 26 | return x 27 | 28 | 29 | def init_weights(m): 30 | if isinstance(m, T5LayerNorm): 31 | nn.init.ones_(m.weight) 32 | elif isinstance(m, T5Model): 33 | nn.init.normal_(m.token_embedding.weight, std=1.0) 34 | elif isinstance(m, T5FeedForward): 35 | nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) 36 | nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) 37 | nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) 38 | elif isinstance(m, T5Attention): 39 | nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5) 40 | nn.init.normal_(m.k.weight, std=m.dim**-0.5) 41 | nn.init.normal_(m.v.weight, std=m.dim**-0.5) 42 | nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5) 43 | elif isinstance(m, T5RelativeEmbedding): 44 | nn.init.normal_( 45 | m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5) 46 | 47 | 48 | class GELU(nn.Module): 49 | 50 | def forward(self, x): 51 | return 0.5 * x * (1.0 + torch.tanh( 52 | math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 53 | 54 | 55 | class T5LayerNorm(nn.Module): 56 | 57 | def __init__(self, dim, eps=1e-6): 58 | super(T5LayerNorm, self).__init__() 59 | self.dim = dim 60 | self.eps = eps 61 | self.weight = nn.Parameter(torch.ones(dim)) 62 | 63 | def forward(self, x): 64 | x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + 65 | self.eps) 66 | if self.weight.dtype in [torch.float16, torch.bfloat16]: 67 | x = x.type_as(self.weight) 68 | return self.weight * x 69 | 70 | 71 | class T5Attention(nn.Module): 72 | 73 | def __init__(self, dim, dim_attn, num_heads, dropout=0.1): 74 | assert dim_attn % num_heads == 0 75 | super(T5Attention, self).__init__() 76 | self.dim = dim 77 | self.dim_attn = dim_attn 78 | self.num_heads = num_heads 79 | self.head_dim = dim_attn // num_heads 80 | 81 | # layers 82 | self.q = nn.Linear(dim, dim_attn, bias=False) 83 | self.k = nn.Linear(dim, dim_attn, bias=False) 84 | self.v = nn.Linear(dim, dim_attn, bias=False) 85 | self.o = nn.Linear(dim_attn, dim, bias=False) 86 | self.dropout = nn.Dropout(dropout) 87 | 88 | def forward(self, x, context=None, mask=None, pos_bias=None): 89 | """ 90 | x: [B, L1, C]. 91 | context: [B, L2, C] or None. 92 | mask: [B, L2] or [B, L1, L2] or None. 
93 | """ 94 | # check inputs 95 | context = x if context is None else context 96 | b, n, c = x.size(0), self.num_heads, self.head_dim 97 | 98 | # compute query, key, value 99 | q = self.q(x).view(b, -1, n, c) 100 | k = self.k(context).view(b, -1, n, c) 101 | v = self.v(context).view(b, -1, n, c) 102 | 103 | # attention bias 104 | attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) 105 | if pos_bias is not None: 106 | attn_bias += pos_bias 107 | if mask is not None: 108 | assert mask.ndim in [2, 3] 109 | mask = mask.view(b, 1, 1, 110 | -1) if mask.ndim == 2 else mask.unsqueeze(1) 111 | attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) 112 | 113 | # compute attention (T5 does not use scaling) 114 | attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias 115 | attn = F.softmax(attn.float(), dim=-1).type_as(attn) 116 | x = torch.einsum('bnij,bjnc->binc', attn, v) 117 | 118 | # output 119 | x = x.reshape(b, -1, n * c) 120 | x = self.o(x) 121 | x = self.dropout(x) 122 | return x 123 | 124 | 125 | class T5FeedForward(nn.Module): 126 | 127 | def __init__(self, dim, dim_ffn, dropout=0.1): 128 | super(T5FeedForward, self).__init__() 129 | self.dim = dim 130 | self.dim_ffn = dim_ffn 131 | 132 | # layers 133 | self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) 134 | self.fc1 = nn.Linear(dim, dim_ffn, bias=False) 135 | self.fc2 = nn.Linear(dim_ffn, dim, bias=False) 136 | self.dropout = nn.Dropout(dropout) 137 | 138 | def forward(self, x): 139 | x = self.fc1(x) * self.gate(x) 140 | x = self.dropout(x) 141 | x = self.fc2(x) 142 | x = self.dropout(x) 143 | return x 144 | 145 | 146 | class T5SelfAttention(nn.Module): 147 | 148 | def __init__(self, 149 | dim, 150 | dim_attn, 151 | dim_ffn, 152 | num_heads, 153 | num_buckets, 154 | shared_pos=True, 155 | dropout=0.1): 156 | super(T5SelfAttention, self).__init__() 157 | self.dim = dim 158 | self.dim_attn = dim_attn 159 | self.dim_ffn = dim_ffn 160 | self.num_heads = num_heads 161 | self.num_buckets = num_buckets 162 | self.shared_pos = shared_pos 163 | 164 | # layers 165 | self.norm1 = T5LayerNorm(dim) 166 | self.attn = T5Attention(dim, dim_attn, num_heads, dropout) 167 | self.norm2 = T5LayerNorm(dim) 168 | self.ffn = T5FeedForward(dim, dim_ffn, dropout) 169 | self.pos_embedding = None if shared_pos else T5RelativeEmbedding( 170 | num_buckets, num_heads, bidirectional=True) 171 | 172 | def forward(self, x, mask=None, pos_bias=None): 173 | e = pos_bias if self.shared_pos else self.pos_embedding( 174 | x.size(1), x.size(1)) 175 | x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) 176 | x = fp16_clamp(x + self.ffn(self.norm2(x))) 177 | return x 178 | 179 | 180 | class T5CrossAttention(nn.Module): 181 | 182 | def __init__(self, 183 | dim, 184 | dim_attn, 185 | dim_ffn, 186 | num_heads, 187 | num_buckets, 188 | shared_pos=True, 189 | dropout=0.1): 190 | super(T5CrossAttention, self).__init__() 191 | self.dim = dim 192 | self.dim_attn = dim_attn 193 | self.dim_ffn = dim_ffn 194 | self.num_heads = num_heads 195 | self.num_buckets = num_buckets 196 | self.shared_pos = shared_pos 197 | 198 | # layers 199 | self.norm1 = T5LayerNorm(dim) 200 | self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) 201 | self.norm2 = T5LayerNorm(dim) 202 | self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) 203 | self.norm3 = T5LayerNorm(dim) 204 | self.ffn = T5FeedForward(dim, dim_ffn, dropout) 205 | self.pos_embedding = None if shared_pos else T5RelativeEmbedding( 206 | num_buckets, num_heads, bidirectional=False) 
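# --- Hedged shape-check sketch (editor's addition, not part of the original file) --
# The T5 blocks above use unscaled dot-product attention with an additive bias
# (relative-position bias and/or key mask) plus a gated-GELU feed-forward,
# i.e. fc2(fc1(x) * gate(x)). A minimal smoke test with toy sizes; the numbers
# below are illustrative assumptions, not the umt5-xxl configuration.
import torch

attn = T5Attention(dim=32, dim_attn=32, num_heads=4, dropout=0.0)
ffn = T5FeedForward(dim=32, dim_ffn=64, dropout=0.0)

x = torch.randn(2, 7, 32)                      # [B, L, C]
key_mask = torch.ones(2, 7, dtype=torch.long)  # [B, L2]; zero entries would be masked out
y = attn(x, mask=key_mask)                     # self-attention: context defaults to x
z = ffn(x)
assert y.shape == x.shape and z.shape == x.shape
# -----------------------------------------------------------------------------------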
207 | 208 | def forward(self, 209 | x, 210 | mask=None, 211 | encoder_states=None, 212 | encoder_mask=None, 213 | pos_bias=None): 214 | e = pos_bias if self.shared_pos else self.pos_embedding( 215 | x.size(1), x.size(1)) 216 | x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) 217 | x = fp16_clamp(x + self.cross_attn( 218 | self.norm2(x), context=encoder_states, mask=encoder_mask)) 219 | x = fp16_clamp(x + self.ffn(self.norm3(x))) 220 | return x 221 | 222 | 223 | class T5RelativeEmbedding(nn.Module): 224 | 225 | def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): 226 | super(T5RelativeEmbedding, self).__init__() 227 | self.num_buckets = num_buckets 228 | self.num_heads = num_heads 229 | self.bidirectional = bidirectional 230 | self.max_dist = max_dist 231 | 232 | # layers 233 | self.embedding = nn.Embedding(num_buckets, num_heads) 234 | 235 | def forward(self, lq, lk): 236 | device = self.embedding.weight.device 237 | # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ 238 | # torch.arange(lq).unsqueeze(1).to(device) 239 | rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \ 240 | torch.arange(lq, device=device).unsqueeze(1) 241 | rel_pos = self._relative_position_bucket(rel_pos) 242 | rel_pos_embeds = self.embedding(rel_pos) 243 | rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze( 244 | 0) # [1, N, Lq, Lk] 245 | return rel_pos_embeds.contiguous() 246 | 247 | def _relative_position_bucket(self, rel_pos): 248 | # preprocess 249 | if self.bidirectional: 250 | num_buckets = self.num_buckets // 2 251 | rel_buckets = (rel_pos > 0).long() * num_buckets 252 | rel_pos = torch.abs(rel_pos) 253 | else: 254 | num_buckets = self.num_buckets 255 | rel_buckets = 0 256 | rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) 257 | 258 | # embeddings for small and large positions 259 | max_exact = num_buckets // 2 260 | rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / 261 | math.log(self.max_dist / max_exact) * 262 | (num_buckets - max_exact)).long() 263 | rel_pos_large = torch.min( 264 | rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) 265 | rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) 266 | return rel_buckets 267 | 268 | 269 | class T5Encoder(nn.Module): 270 | 271 | def __init__(self, 272 | vocab, 273 | dim, 274 | dim_attn, 275 | dim_ffn, 276 | num_heads, 277 | num_layers, 278 | num_buckets, 279 | shared_pos=True, 280 | dropout=0.1): 281 | super(T5Encoder, self).__init__() 282 | self.dim = dim 283 | self.dim_attn = dim_attn 284 | self.dim_ffn = dim_ffn 285 | self.num_heads = num_heads 286 | self.num_layers = num_layers 287 | self.num_buckets = num_buckets 288 | self.shared_pos = shared_pos 289 | 290 | # layers 291 | self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ 292 | else nn.Embedding(vocab, dim) 293 | self.pos_embedding = T5RelativeEmbedding( 294 | num_buckets, num_heads, bidirectional=True) if shared_pos else None 295 | self.dropout = nn.Dropout(dropout) 296 | self.blocks = nn.ModuleList([ 297 | T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, 298 | shared_pos, dropout) for _ in range(num_layers) 299 | ]) 300 | self.norm = T5LayerNorm(dim) 301 | 302 | # initialize weights 303 | self.apply(init_weights) 304 | 305 | def forward(self, ids, mask=None): 306 | x = self.token_embedding(ids) 307 | x = self.dropout(x) 308 | e = self.pos_embedding(x.size(1), 309 | x.size(1)) if self.shared_pos else None 310 | for block in self.blocks: 311 | x = 
block(x, mask, pos_bias=e) 312 | x = self.norm(x) 313 | x = self.dropout(x) 314 | return x 315 | 316 | 317 | class T5Decoder(nn.Module): 318 | 319 | def __init__(self, 320 | vocab, 321 | dim, 322 | dim_attn, 323 | dim_ffn, 324 | num_heads, 325 | num_layers, 326 | num_buckets, 327 | shared_pos=True, 328 | dropout=0.1): 329 | super(T5Decoder, self).__init__() 330 | self.dim = dim 331 | self.dim_attn = dim_attn 332 | self.dim_ffn = dim_ffn 333 | self.num_heads = num_heads 334 | self.num_layers = num_layers 335 | self.num_buckets = num_buckets 336 | self.shared_pos = shared_pos 337 | 338 | # layers 339 | self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ 340 | else nn.Embedding(vocab, dim) 341 | self.pos_embedding = T5RelativeEmbedding( 342 | num_buckets, num_heads, bidirectional=False) if shared_pos else None 343 | self.dropout = nn.Dropout(dropout) 344 | self.blocks = nn.ModuleList([ 345 | T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, 346 | shared_pos, dropout) for _ in range(num_layers) 347 | ]) 348 | self.norm = T5LayerNorm(dim) 349 | 350 | # initialize weights 351 | self.apply(init_weights) 352 | 353 | def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None): 354 | b, s = ids.size() 355 | 356 | # causal mask 357 | if mask is None: 358 | mask = torch.tril(torch.ones(1, s, s).to(ids.device)) 359 | elif mask.ndim == 2: 360 | mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) 361 | 362 | # layers 363 | x = self.token_embedding(ids) 364 | x = self.dropout(x) 365 | e = self.pos_embedding(x.size(1), 366 | x.size(1)) if self.shared_pos else None 367 | for block in self.blocks: 368 | x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) 369 | x = self.norm(x) 370 | x = self.dropout(x) 371 | return x 372 | 373 | 374 | class T5Model(nn.Module): 375 | 376 | def __init__(self, 377 | vocab_size, 378 | dim, 379 | dim_attn, 380 | dim_ffn, 381 | num_heads, 382 | encoder_layers, 383 | decoder_layers, 384 | num_buckets, 385 | shared_pos=True, 386 | dropout=0.1): 387 | super(T5Model, self).__init__() 388 | self.vocab_size = vocab_size 389 | self.dim = dim 390 | self.dim_attn = dim_attn 391 | self.dim_ffn = dim_ffn 392 | self.num_heads = num_heads 393 | self.encoder_layers = encoder_layers 394 | self.decoder_layers = decoder_layers 395 | self.num_buckets = num_buckets 396 | 397 | # layers 398 | self.token_embedding = nn.Embedding(vocab_size, dim) 399 | self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn, 400 | num_heads, encoder_layers, num_buckets, 401 | shared_pos, dropout) 402 | self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn, 403 | num_heads, decoder_layers, num_buckets, 404 | shared_pos, dropout) 405 | self.head = nn.Linear(dim, vocab_size, bias=False) 406 | 407 | # initialize weights 408 | self.apply(init_weights) 409 | 410 | def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): 411 | x = self.encoder(encoder_ids, encoder_mask) 412 | x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) 413 | x = self.head(x) 414 | return x 415 | 416 | 417 | def _t5(name, 418 | encoder_only=False, 419 | decoder_only=False, 420 | return_tokenizer=False, 421 | tokenizer_kwargs={}, 422 | dtype=torch.float32, 423 | device='cpu', 424 | **kwargs): 425 | # sanity check 426 | assert not (encoder_only and decoder_only) 427 | 428 | # params 429 | if encoder_only: 430 | model_cls = T5Encoder 431 | kwargs['vocab'] = kwargs.pop('vocab_size') 432 | kwargs['num_layers'] = kwargs.pop('encoder_layers') 433 | _ 
= kwargs.pop('decoder_layers') 434 | elif decoder_only: 435 | model_cls = T5Decoder 436 | kwargs['vocab'] = kwargs.pop('vocab_size') 437 | kwargs['num_layers'] = kwargs.pop('decoder_layers') 438 | _ = kwargs.pop('encoder_layers') 439 | else: 440 | model_cls = T5Model 441 | 442 | # init model 443 | with torch.device(device): 444 | model = model_cls(**kwargs) 445 | 446 | # set device 447 | #model = model.to(dtype=dtype, device=device) 448 | 449 | # init tokenizer 450 | if return_tokenizer: 451 | from .tokenizers import HuggingfaceTokenizer 452 | tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs) 453 | return model, tokenizer 454 | else: 455 | return model 456 | 457 | 458 | def umt5_xxl(**kwargs): 459 | cfg = dict( 460 | vocab_size=256384, 461 | dim=4096, 462 | dim_attn=4096, 463 | dim_ffn=10240, 464 | num_heads=64, 465 | encoder_layers=24, 466 | decoder_layers=24, 467 | num_buckets=32, 468 | shared_pos=False, 469 | dropout=0.1) 470 | cfg.update(**kwargs) 471 | return _t5('umt5-xxl', **cfg) 472 | 473 | 474 | class T5EncoderModel: 475 | 476 | def __init__( 477 | self, 478 | text_len, 479 | dtype=torch.bfloat16, 480 | device=torch.device('cuda'), 481 | state_dict=None, 482 | tokenizer_path=None, 483 | quantization="disabled", 484 | ): 485 | self.text_len = text_len 486 | self.dtype = dtype 487 | self.device = device 488 | self.tokenizer_path = tokenizer_path 489 | 490 | # init model 491 | with init_empty_weights(): 492 | model = umt5_xxl( 493 | encoder_only=True, 494 | return_tokenizer=False, 495 | dtype=dtype, 496 | device=device).eval().requires_grad_(False) 497 | 498 | if quantization == "fp8_e4m3fn": 499 | cast_dtype = torch.float8_e4m3fn 500 | else: 501 | cast_dtype = dtype 502 | 503 | params_to_keep = {'norm', 'pos_embedding', 'token_embedding'} 504 | for name, param in model.named_parameters(): 505 | dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype 506 | set_module_tensor_to_device(model, name, device=device, dtype=dtype_to_use, value=state_dict[name]) 507 | del state_dict 508 | self.model = model 509 | self.tokenizer = HuggingfaceTokenizer( 510 | name=tokenizer_path, seq_len=text_len, clean='whitespace') 511 | 512 | def __call__(self, texts, device): 513 | ids, mask = self.tokenizer( 514 | texts, return_mask=True, add_special_tokens=True) 515 | ids = ids.to(device) 516 | mask = mask.to(device) 517 | seq_lens = mask.gt(0).sum(dim=1).long() 518 | context = self.model(ids, mask) 519 | return [u[:v] for u, v in zip(context, seq_lens)] 520 | -------------------------------------------------------------------------------- /wanvideo/modules/tokenizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
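# --- Hedged usage sketch (editor's addition, not part of the original files) --
# The T5EncoderModel wrapper in t5.py above tokenizes prompts to text_len
# tokens, runs the encoder-only umt5-xxl, and trims each output to its real
# token count. The checkpoint path and tokenizer id below are illustrative
# assumptions only; the state dict keys must match the encoder-only module.
import torch

sd = torch.load("umt5-xxl-enc-bf16.pth", map_location="cpu")  # assumed checkpoint
t5 = T5EncoderModel(
    text_len=512,
    dtype=torch.bfloat16,
    device=torch.device("cuda"),
    state_dict=sd,
    tokenizer_path="google/umt5-xxl",                         # assumed tokenizer id / local path
)
context = t5(["a cat surfing a wave at sunset"], device=torch.device("cuda"))
# context is a list with one [L_i, 4096] tensor per prompt, where L_i is the
# prompt's non-padded token count (L_i <= 512).
# ------------------------------------------------------------------------------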
2 | import html 3 | import string 4 | 5 | import ftfy 6 | import regex as re 7 | from transformers import AutoTokenizer 8 | 9 | __all__ = ['HuggingfaceTokenizer'] 10 | 11 | 12 | def basic_clean(text): 13 | text = ftfy.fix_text(text) 14 | text = html.unescape(html.unescape(text)) 15 | return text.strip() 16 | 17 | 18 | def whitespace_clean(text): 19 | text = re.sub(r'\s+', ' ', text) 20 | text = text.strip() 21 | return text 22 | 23 | 24 | def canonicalize(text, keep_punctuation_exact_string=None): 25 | text = text.replace('_', ' ') 26 | if keep_punctuation_exact_string: 27 | text = keep_punctuation_exact_string.join( 28 | part.translate(str.maketrans('', '', string.punctuation)) 29 | for part in text.split(keep_punctuation_exact_string)) 30 | else: 31 | text = text.translate(str.maketrans('', '', string.punctuation)) 32 | text = text.lower() 33 | text = re.sub(r'\s+', ' ', text) 34 | return text.strip() 35 | 36 | 37 | class HuggingfaceTokenizer: 38 | 39 | def __init__(self, name, seq_len=None, clean=None, **kwargs): 40 | assert clean in (None, 'whitespace', 'lower', 'canonicalize') 41 | self.name = name 42 | self.seq_len = seq_len 43 | self.clean = clean 44 | 45 | # init tokenizer 46 | self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) 47 | self.vocab_size = self.tokenizer.vocab_size 48 | 49 | def __call__(self, sequence, **kwargs): 50 | return_mask = kwargs.pop('return_mask', False) 51 | 52 | # arguments 53 | _kwargs = {'return_tensors': 'pt'} 54 | if self.seq_len is not None: 55 | _kwargs.update({ 56 | 'padding': 'max_length', 57 | 'truncation': True, 58 | 'max_length': self.seq_len 59 | }) 60 | _kwargs.update(**kwargs) 61 | 62 | # tokenization 63 | if isinstance(sequence, str): 64 | sequence = [sequence] 65 | if self.clean: 66 | sequence = [self._clean(u) for u in sequence] 67 | ids = self.tokenizer(sequence, **_kwargs) 68 | 69 | # output 70 | if return_mask: 71 | return ids.input_ids, ids.attention_mask 72 | else: 73 | return ids.input_ids 74 | 75 | def _clean(self, text): 76 | if self.clean == 'whitespace': 77 | text = whitespace_clean(basic_clean(text)) 78 | elif self.clean == 'lower': 79 | text = whitespace_clean(basic_clean(text)).lower() 80 | elif self.clean == 'canonicalize': 81 | text = canonicalize(basic_clean(text)) 82 | return text 83 | -------------------------------------------------------------------------------- /wanvideo/modules/vae.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | from comfy.model_management import get_torch_device, get_autocast_device 9 | __all__ = [ 10 | 'WanVAE', 11 | ] 12 | 13 | CACHE_T = 2 14 | 15 | 16 | class CausalConv3d(nn.Conv3d): 17 | """ 18 | Causal 3d convolusion. 
19 | """ 20 | 21 | def __init__(self, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | self._padding = (self.padding[2], self.padding[2], self.padding[1], 24 | self.padding[1], 2 * self.padding[0], 0) 25 | self.padding = (0, 0, 0) 26 | 27 | def forward(self, x, cache_x=None): 28 | padding = list(self._padding) 29 | if cache_x is not None and self._padding[4] > 0: 30 | cache_x = cache_x.to(x.device) 31 | x = torch.cat([cache_x, x], dim=2) 32 | padding[4] -= cache_x.shape[2] 33 | x = F.pad(x, padding) 34 | 35 | return super().forward(x) 36 | 37 | 38 | class RMS_norm(nn.Module): 39 | 40 | def __init__(self, dim, channel_first=True, images=True, bias=False): 41 | super().__init__() 42 | broadcastable_dims = (1, 1, 1) if not images else (1, 1) 43 | shape = (dim, *broadcastable_dims) if channel_first else (dim,) 44 | 45 | self.channel_first = channel_first 46 | self.scale = dim**0.5 47 | self.gamma = nn.Parameter(torch.ones(shape)) 48 | self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. 49 | 50 | def forward(self, x): 51 | return F.normalize( 52 | x, dim=(1 if self.channel_first else 53 | -1)) * self.scale * self.gamma + self.bias 54 | 55 | 56 | class Upsample(nn.Upsample): 57 | 58 | def forward(self, x): 59 | """ 60 | Fix bfloat16 support for nearest neighbor interpolation. 61 | """ 62 | return super().forward(x.float()).type_as(x) 63 | 64 | 65 | class Resample(nn.Module): 66 | 67 | def __init__(self, dim, mode): 68 | assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', 69 | 'downsample3d') 70 | super().__init__() 71 | self.dim = dim 72 | self.mode = mode 73 | 74 | # layers 75 | if mode == 'upsample2d': 76 | self.resample = nn.Sequential( 77 | Upsample(scale_factor=(2., 2.), mode='nearest-exact'), 78 | nn.Conv2d(dim, dim // 2, 3, padding=1)) 79 | elif mode == 'upsample3d': 80 | self.resample = nn.Sequential( 81 | Upsample(scale_factor=(2., 2.), mode='nearest-exact'), 82 | nn.Conv2d(dim, dim // 2, 3, padding=1)) 83 | self.time_conv = CausalConv3d( 84 | dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) 85 | 86 | elif mode == 'downsample2d': 87 | self.resample = nn.Sequential( 88 | nn.ZeroPad2d((0, 1, 0, 1)), 89 | nn.Conv2d(dim, dim, 3, stride=(2, 2))) 90 | elif mode == 'downsample3d': 91 | self.resample = nn.Sequential( 92 | nn.ZeroPad2d((0, 1, 0, 1)), 93 | nn.Conv2d(dim, dim, 3, stride=(2, 2))) 94 | self.time_conv = CausalConv3d( 95 | dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) 96 | 97 | else: 98 | self.resample = nn.Identity() 99 | 100 | def forward(self, x, feat_cache=None, feat_idx=[0]): 101 | b, c, t, h, w = x.size() 102 | if self.mode == 'upsample3d': 103 | if feat_cache is not None: 104 | idx = feat_idx[0] 105 | if feat_cache[idx] is None: 106 | feat_cache[idx] = 'Rep' 107 | feat_idx[0] += 1 108 | else: 109 | 110 | cache_x = x[:, :, -CACHE_T:, :, :].clone() 111 | if cache_x.shape[2] < 2 and feat_cache[ 112 | idx] is not None and feat_cache[idx] != 'Rep': 113 | # cache last frame of last two chunk 114 | cache_x = torch.cat([ 115 | feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( 116 | cache_x.device), cache_x 117 | ], 118 | dim=2) 119 | if cache_x.shape[2] < 2 and feat_cache[ 120 | idx] is not None and feat_cache[idx] == 'Rep': 121 | cache_x = torch.cat([ 122 | torch.zeros_like(cache_x).to(cache_x.device), 123 | cache_x 124 | ], 125 | dim=2) 126 | if feat_cache[idx] == 'Rep': 127 | x = self.time_conv(x) 128 | else: 129 | x = self.time_conv(x, feat_cache[idx]) 130 | feat_cache[idx] = cache_x 131 | feat_idx[0] += 1 132 | 133 | x = x.reshape(b, 2, 
c, t, h, w) 134 | x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 135 | 3) 136 | x = x.reshape(b, c, t * 2, h, w) 137 | t = x.shape[2] 138 | x = rearrange(x, 'b c t h w -> (b t) c h w') 139 | x = self.resample(x) 140 | x = rearrange(x, '(b t) c h w -> b c t h w', t=t) 141 | 142 | if self.mode == 'downsample3d': 143 | if feat_cache is not None: 144 | idx = feat_idx[0] 145 | if feat_cache[idx] is None: 146 | feat_cache[idx] = x.clone() 147 | feat_idx[0] += 1 148 | else: 149 | 150 | cache_x = x[:, :, -1:, :, :].clone() 151 | # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': 152 | # # cache last frame of last two chunk 153 | # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) 154 | 155 | x = self.time_conv( 156 | torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) 157 | feat_cache[idx] = cache_x 158 | feat_idx[0] += 1 159 | return x 160 | 161 | def init_weight(self, conv): 162 | conv_weight = conv.weight 163 | nn.init.zeros_(conv_weight) 164 | c1, c2, t, h, w = conv_weight.size() 165 | one_matrix = torch.eye(c1, c2) 166 | init_matrix = one_matrix 167 | nn.init.zeros_(conv_weight) 168 | #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 169 | conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5 170 | conv.weight.data.copy_(conv_weight) 171 | nn.init.zeros_(conv.bias.data) 172 | 173 | def init_weight2(self, conv): 174 | conv_weight = conv.weight.data 175 | nn.init.zeros_(conv_weight) 176 | c1, c2, t, h, w = conv_weight.size() 177 | init_matrix = torch.eye(c1 // 2, c2) 178 | #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) 179 | conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix 180 | conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix 181 | conv.weight.data.copy_(conv_weight) 182 | nn.init.zeros_(conv.bias.data) 183 | 184 | 185 | class ResidualBlock(nn.Module): 186 | 187 | def __init__(self, in_dim, out_dim, dropout=0.0): 188 | super().__init__() 189 | self.in_dim = in_dim 190 | self.out_dim = out_dim 191 | 192 | # layers 193 | self.residual = nn.Sequential( 194 | RMS_norm(in_dim, images=False), nn.SiLU(), 195 | CausalConv3d(in_dim, out_dim, 3, padding=1), 196 | RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), 197 | CausalConv3d(out_dim, out_dim, 3, padding=1)) 198 | self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ 199 | if in_dim != out_dim else nn.Identity() 200 | 201 | def forward(self, x, feat_cache=None, feat_idx=[0]): 202 | h = self.shortcut(x) 203 | for layer in self.residual: 204 | if isinstance(layer, CausalConv3d) and feat_cache is not None: 205 | idx = feat_idx[0] 206 | cache_x = x[:, :, -CACHE_T:, :, :].clone() 207 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None: 208 | # cache last frame of last two chunk 209 | cache_x = torch.cat([ 210 | feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( 211 | cache_x.device), cache_x 212 | ], 213 | dim=2) 214 | x = layer(x, feat_cache[idx]) 215 | feat_cache[idx] = cache_x 216 | feat_idx[0] += 1 217 | else: 218 | x = layer(x) 219 | return x + h 220 | 221 | 222 | class AttentionBlock(nn.Module): 223 | """ 224 | Causal self-attention with a single head. 
225 | """ 226 | 227 | def __init__(self, dim): 228 | super().__init__() 229 | self.dim = dim 230 | 231 | # layers 232 | self.norm = RMS_norm(dim) 233 | self.to_qkv = nn.Conv2d(dim, dim * 3, 1) 234 | self.proj = nn.Conv2d(dim, dim, 1) 235 | 236 | # zero out the last layer params 237 | nn.init.zeros_(self.proj.weight) 238 | 239 | def forward(self, x): 240 | identity = x 241 | b, c, t, h, w = x.size() 242 | x = rearrange(x, 'b c t h w -> (b t) c h w') 243 | x = self.norm(x) 244 | # compute query, key, value 245 | q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, 246 | -1).permute(0, 1, 3, 247 | 2).contiguous().chunk( 248 | 3, dim=-1) 249 | 250 | # apply attention 251 | x = F.scaled_dot_product_attention( 252 | q, 253 | k, 254 | v, 255 | ) 256 | x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) 257 | 258 | # output 259 | x = self.proj(x) 260 | x = rearrange(x, '(b t) c h w-> b c t h w', t=t) 261 | return x + identity 262 | 263 | 264 | class Encoder3d(nn.Module): 265 | 266 | def __init__(self, 267 | dim=128, 268 | z_dim=4, 269 | dim_mult=[1, 2, 4, 4], 270 | num_res_blocks=2, 271 | attn_scales=[], 272 | temperal_downsample=[True, True, False], 273 | dropout=0.0): 274 | super().__init__() 275 | self.dim = dim 276 | self.z_dim = z_dim 277 | self.dim_mult = dim_mult 278 | self.num_res_blocks = num_res_blocks 279 | self.attn_scales = attn_scales 280 | self.temperal_downsample = temperal_downsample 281 | 282 | # dimensions 283 | dims = [dim * u for u in [1] + dim_mult] 284 | scale = 1.0 285 | 286 | # init block 287 | self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) 288 | 289 | # downsample blocks 290 | downsamples = [] 291 | for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): 292 | # residual (+attention) blocks 293 | for _ in range(num_res_blocks): 294 | downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) 295 | if scale in attn_scales: 296 | downsamples.append(AttentionBlock(out_dim)) 297 | in_dim = out_dim 298 | 299 | # downsample block 300 | if i != len(dim_mult) - 1: 301 | mode = 'downsample3d' if temperal_downsample[ 302 | i] else 'downsample2d' 303 | downsamples.append(Resample(out_dim, mode=mode)) 304 | scale /= 2.0 305 | self.downsamples = nn.Sequential(*downsamples) 306 | 307 | # middle blocks 308 | self.middle = nn.Sequential( 309 | ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), 310 | ResidualBlock(out_dim, out_dim, dropout)) 311 | 312 | # output blocks 313 | self.head = nn.Sequential( 314 | RMS_norm(out_dim, images=False), nn.SiLU(), 315 | CausalConv3d(out_dim, z_dim, 3, padding=1)) 316 | 317 | def forward(self, x, feat_cache=None, feat_idx=[0]): 318 | if feat_cache is not None: 319 | idx = feat_idx[0] 320 | cache_x = x[:, :, -CACHE_T:, :, :].clone() 321 | if cache_x.shape[2] < 2 and feat_cache[idx] is not None: 322 | # cache last frame of last two chunk 323 | cache_x = torch.cat([ 324 | feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( 325 | cache_x.device), cache_x 326 | ], 327 | dim=2) 328 | x = self.conv1(x, feat_cache[idx]) 329 | feat_cache[idx] = cache_x 330 | feat_idx[0] += 1 331 | else: 332 | x = self.conv1(x) 333 | 334 | ## downsamples 335 | for layer in self.downsamples: 336 | if feat_cache is not None: 337 | x = layer(x, feat_cache, feat_idx) 338 | else: 339 | x = layer(x) 340 | 341 | ## middle 342 | for layer in self.middle: 343 | if isinstance(layer, ResidualBlock) and feat_cache is not None: 344 | x = layer(x, feat_cache, feat_idx) 345 | else: 346 | x = layer(x) 347 | 348 | ## head 349 | for layer in self.head: 350 | if 
class Decoder3d(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2**(len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout))

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i == 1 or i == 2 or i == 3:
                in_dim = in_dim // 2
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, 3, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([
                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ],
                                    dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x


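# Editor's note: count_conv3d() below counts the CausalConv3d modules in a
# sub-network; WanVAE_.clear_cache() uses that count to allocate one cache slot
# per convolution before a video is streamed through the encoder or decoder
# chunk by chunk.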
def count_conv3d(model):
    count = 0
    for m in model.modules():
        if isinstance(m, CausalConv3d):
            count += 1
    return count


class WanVAE_(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z)
        return x_recon, mu, log_var

    def encode(self, x, scale):
        self.clear_cache()
        ## cache
        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4
        ## Split the encoder input x along the time axis into chunks of 1, 4, 4, 4, ...
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(
                    x[:, :, :1, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
                1, self.z_dim, 1, 1, 1)
        else:
            mu = (mu - scale[0]) * scale[1]
        self.clear_cache()
        return mu

    def decode(self, z, scale):
        self.clear_cache()
        # z: [b,c,t,h,w]
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
                1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
            else:
                out_ = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
                out = torch.cat([out, out_], 2)
        self.clear_cache()
        return out

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        mu, log_var = self.encode(imgs)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        #cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num


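# Editor's note: WanVAE_.forward() and WanVAE_.sample() above appear to be
# training-time leftovers -- they call encode() without the `scale` argument and
# expect a (mu, log_var) pair, whereas the inference wrapper WanVAE below only
# uses encode(x, scale) and decode(z, scale).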
def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
    """
    Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
    """
    # params
    cfg = dict(
        dim=96,
        z_dim=z_dim,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[False, True, True],
        dropout=0.0)
    cfg.update(**kwargs)

    # init model
    with torch.device('meta'):
        model = WanVAE_(**cfg)

    # load checkpoint
    logging.info(f'loading {pretrained_path}')
    model.load_state_dict(
        torch.load(pretrained_path, map_location=device), assign=True)

    return model


class WanVAE:

    def __init__(self,
                 z_dim=16,
                 vae_pth='cache/vae_step_411000.pth',
                 dtype=torch.float,
                 device="cuda"):
        self.dtype = dtype
        self.device = device

        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
        ]
        self.mean = torch.tensor(mean, dtype=dtype, device=device)
        self.std = torch.tensor(std, dtype=dtype, device=device)
        self.scale = [self.mean, 1.0 / self.std]

        # init model
        self.model = _video_vae(
            pretrained_path=vae_pth,
            z_dim=z_dim,
        ).eval().requires_grad_(False).to(device)

    def encode(self, videos):
        """
        videos: A list of videos each with shape [C, T, H, W].
        """
        with torch.autocast(device_type=get_autocast_device(get_torch_device()), enabled=False):
            return [
                self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
                for u in videos
            ]

    def decode(self, zs):
        with torch.autocast(device_type=get_autocast_device(get_torch_device()), enabled=False):
            return [
                self.model.decode(u.unsqueeze(0),
                                  self.scale).float().clamp_(-1, 1).squeeze(0)
                for u in zs
            ]

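# Editor's note: the block below is a minimal usage sketch added for
# illustration and is not part of the upstream module.  The checkpoint path is
# hypothetical -- point it at the Wan 2.1 VAE weights you actually have, and
# match device/dtype to how the checkpoint was saved.
if __name__ == "__main__":
    import torch

    vae = WanVAE(
        z_dim=16,
        vae_pth="models/vae/Wan2_1_VAE.pth",  # hypothetical path
        dtype=torch.float32,
        device="cuda",
    )

    # One 81-frame RGB clip in [-1, 1] with shape [C, T, H, W]; cast it to the
    # dtype of the loaded weights if they are not float32.
    video = torch.randn(3, 81, 480, 832, device="cuda").clamp_(-1, 1)

    latents = vae.encode([video])   # expected: list of [16, 21, 60, 104] tensors
    frames = vae.decode(latents)    # expected: list of [3, 81, 480, 832] tensors
    print(latents[0].shape, frames[0].shape)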
650 | """ 651 | with torch.autocast(device_type=get_autocast_device(get_torch_device()), enabled=False): 652 | return [ 653 | self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) 654 | for u in videos 655 | ] 656 | 657 | def decode(self, zs): 658 | with torch.autocast(device_type=get_autocast_device(get_torch_device()), enabled=False): 659 | return [ 660 | self.model.decode(u.unsqueeze(0), 661 | self.scale).float().clamp_(-1, 1).squeeze(0) 662 | for u in zs 663 | ] 664 | -------------------------------------------------------------------------------- /wanvideo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, 2 | retrieve_timesteps) 3 | from .fm_solvers_unipc import FlowUniPCMultistepScheduler 4 | 5 | __all__ = [ 6 | 'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps', 7 | 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler' 8 | ] 9 | -------------------------------------------------------------------------------- /wanvideo/utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raindrop313/ComfyUI-WanVideoStartEndFrames/340b4f071bccc54ccf69a22864050ff44a0ff8fb/wanvideo/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /wanvideo/utils/__pycache__/fm_solvers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raindrop313/ComfyUI-WanVideoStartEndFrames/340b4f071bccc54ccf69a22864050ff44a0ff8fb/wanvideo/utils/__pycache__/fm_solvers.cpython-310.pyc -------------------------------------------------------------------------------- /wanvideo/utils/__pycache__/fm_solvers_unipc.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raindrop313/ComfyUI-WanVideoStartEndFrames/340b4f071bccc54ccf69a22864050ff44a0ff8fb/wanvideo/utils/__pycache__/fm_solvers_unipc.cpython-310.pyc --------------------------------------------------------------------------------