├── LICENSE
├── README.md
├── __init__.py
├── modules
│   ├── wan_blocks.py
│   └── wan_model.py
├── nodes
│   ├── modify_wan_model_node.py
│   ├── wan_flowedit_nodes.py
│   └── wan_model_pred_nodes.py
└── utils
    └── debug_utils.py

/LICENSE:
--------------------------------------------------------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ComfyUI-MagicWan
Implements FlowEdit, and possibly other inversion techniques down the line, for the Wan video generation model.

Borrows very heavily from [HunyuanLoom](https://github.com/logtd/ComfyUI-HunyuanLoom), which implements [FlowEdit](https://github.com/fallenshock/FlowEdit).

WIP, very experimental, and probably buggy in various ways.
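There is no tested reference workflow yet, but given the node set, a plausible FlowEdit graph looks roughly like this (`SamplerCustomAdvanced` and `BasicScheduler` are stock ComfyUI nodes; the inverse/reverse pred nodes are for separate inversion graphs and are not needed for FlowEdit itself):

```
Load Diffusion Model (Wan)
  └─ Configure Modified Wan Model
       └─ Wan FlowEdit Guider CFG ◄─ source/target cond + uncond (4× CLIP Text Encode)
            └─ SamplerCustomAdvanced
                 ├─ sampler: Wan FlowEdit Sampler (skip_steps, drift_steps, seed)
                 ├─ sigmas:  BasicScheduler (on the configured model)
                 └─ latent:  VAE-encoded source video
```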
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
from .nodes.wan_flowedit_nodes import WanFlowEditGuiderNode, WanFlowEditGuiderAdvNode, WanFlowEditGuiderCFGNode, WanFlowEditGuiderCFGAdvNode, WanFlowEditSamplerNode
from .nodes.modify_wan_model_node import ConfigureModifiedWanNode
from .nodes.wan_model_pred_nodes import WanInverseModelSamplingPredNode, WanReverseModelSamplingPredNode

NODE_CLASS_MAPPINGS = {
    "WanFlowEditGuider": WanFlowEditGuiderNode,
    "WanFlowEditGuiderAdv": WanFlowEditGuiderAdvNode,
    "WanFlowEditGuiderCFG": WanFlowEditGuiderCFGNode,
    "WanFlowEditGuiderCFGAdv": WanFlowEditGuiderCFGAdvNode,
    "WanFlowEditSampler": WanFlowEditSamplerNode,
    "ConfigureModifiedWan": ConfigureModifiedWanNode,
    "WanInverseModelSamplingPred": WanInverseModelSamplingPredNode,
    "WanReverseModelSamplingPred": WanReverseModelSamplingPredNode,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "WanFlowEditGuider": "Wan FlowEdit Guider",
    "WanFlowEditGuiderAdv": "Wan FlowEdit Guider Advanced",
    "WanFlowEditGuiderCFG": "Wan FlowEdit Guider CFG",
    "WanFlowEditGuiderCFGAdv": "Wan FlowEdit Guider CFG Advanced",
    "WanFlowEditSampler": "Wan FlowEdit Sampler",
    "ConfigureModifiedWan": "Configure Modified Wan Model",
    "WanInverseModelSamplingPred": "Wan Inverse Model Pred",
    "WanReverseModelSamplingPred": "Wan Reverse Model Pred",
}
--------------------------------------------------------------------------------
/modules/wan_blocks.py:
--------------------------------------------------------------------------------
import torch
from torch import nn

from comfy.ldm.flux.math import apply_rope
from comfy.ldm.wan.model import WanAttentionBlock, WanT2VCrossAttention, WanI2VCrossAttention


class ModifiedWanAttentionBlock(WanAttentionBlock):
    """
    Modified Wan attention block that supports FlowEdit operations
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.idx = 0  # Will be set by inject_blocks

    def forward(
        self,
        x,
        e,
        freqs,
        context,
        transformer_options={},
    ):
        # Get original forward implementation
        original_output = super().forward(x, e, freqs, context)

        # If we're in FlowEdit mode and need special handling.
        # Note: `transformer_options` is already the inner options dict here,
        # so look up 'latent_type' directly (the previous nested
        # .get('transformer_options', {}) lookup always returned None).
        latent_type = transformer_options.get('latent_type', None)

        # Apply any custom modifications for FlowEdit if needed
        if latent_type is not None:
            # Could add specialized handling here if needed
            pass

        return original_output


class ModifiedWanT2VCrossAttention(WanT2VCrossAttention):
    """
    Modified T2V cross attention for FlowEdit support
    """
    def forward(self, x, context, transformer_options={}):
        # Get original implementation
        return super().forward(x, context)


class ModifiedWanI2VCrossAttention(WanI2VCrossAttention):
    """
    Modified I2V cross attention for FlowEdit support
    """
    def forward(self, x, context, transformer_options={}):
        # Get original implementation
        return super().forward(x, context)


def inject_blocks(diffusion_model):
    """
    Replace all attention blocks with our modified versions
    """
    # Replace the attention blocks
    for i, block in enumerate(diffusion_model.blocks):
        block.__class__ = ModifiedWanAttentionBlock
        block.idx = i

        # Replace the cross attention mechanisms
        if hasattr(block, 'cross_attn'):
            if block.cross_attn.__class__ == WanT2VCrossAttention:
                block.cross_attn.__class__ = ModifiedWanT2VCrossAttention
            elif block.cross_attn.__class__ == WanI2VCrossAttention:
                block.cross_attn.__class__ = ModifiedWanI2VCrossAttention

    return diffusion_model
--------------------------------------------------------------------------------
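Note that `inject_blocks` (and `inject_model` below) patch by reassigning `__class__` on the live modules rather than rebuilding them, so existing weights, buffers, and device placement are untouched. A minimal self-contained sketch of the same pattern with toy classes (nothing here is Wan-specific):

```python
import torch
from torch import nn

class Plain(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)

class Traced(Plain):
    """Same parameters, extra behavior -- mirrors ModifiedWanAttentionBlock."""
    def forward(self, x, transformer_options={}):
        out = super().forward(x)
        if transformer_options.get('latent_type') == 'source':
            pass  # hook point, as in the blocks above
        return out

m = Plain()
w_before = m.proj.weight.clone()
m.__class__ = Traced                          # in-place class swap, no re-init
assert torch.equal(m.proj.weight, w_before)   # weights survive the swap
print(m(torch.randn(2, 4), transformer_options={'latent_type': 'source'}).shape)
```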
/modules/wan_model.py:
--------------------------------------------------------------------------------
import torch
from einops import repeat

from comfy.ldm.wan.model import WanModel, sinusoidal_embedding_1d


class ModifiedWanModel(WanModel):
    def forward_orig(
        self,
        x,
        t,
        context,
        clip_fea=None,
        freqs=None,
        transformer_options={},
    ):
        """
        Modified forward pass to support FlowEdit's dual conditioning
        """
        # Store original shape for unpatchifying
        original_shape = list(x.shape)
        transformer_options['original_shape'] = original_shape

        # Embeddings
        x = self.patch_embedding(x.float()).to(x.dtype)
        grid_sizes = x.shape[2:]
        x = x.flatten(2).transpose(1, 2)

        # Time embeddings
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x.dtype))
        e0 = self.time_projection(e).unflatten(1, (6, self.dim))

        # Context processing
        context = self.text_embedding(context)

        # Handle clip features for I2V if provided
        if clip_fea is not None and self.img_emb is not None:
            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
            context = torch.concat([context_clip, context], dim=1)

        # Store context size for attention blocks
        if 'txt_size' not in transformer_options:
            transformer_options['txt_size'] = context.shape[1]
            if clip_fea is not None and self.img_emb is not None:
                transformer_options['txt_size'] = context.shape[1] - 257  # Account for clip tokens

        # Process through attention blocks
        kwargs = dict(
            e=e0,
            freqs=freqs,
            context=context,
            transformer_options=transformer_options)

        for block in self.blocks:
            x = block(x, **kwargs)

        # Final head
        x = self.head(x, e)

        # Unpatchify
        x = self.unpatchify(x, grid_sizes)

        # Crop to original dimensions if needed
        if list(x.shape[2:]) != original_shape[2:]:
            x = x[:, :, :original_shape[2], :original_shape[3], :original_shape[4]]

        return x

    def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
        bs, c, t, h, w = x.shape

        # Handle padding to patch size if needed
        from comfy.ldm.common_dit import pad_to_patch_size
        x = pad_to_patch_size(x, self.patch_size)

        # Calculate positional embeddings
        patch_size = self.patch_size
        t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
        h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
        w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
        img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
        img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)

        # Calculate rope frequencies
        freqs = self.rope_embedder(img_ids).movedim(1, 2)

        # Regional conditioning support (similar to HunyuanVideo)
        regional_conditioning = transformer_options.get('patches', {}).get('regional_conditioning', None)
        if regional_conditioning is not None:
            context = regional_conditioning[0](context, transformer_options)

        return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)


def inject_model(diffusion_model):
    """
    Replace the model class with our modified version
    """
    diffusion_model.__class__ = ModifiedWanModel
    return diffusion_model
--------------------------------------------------------------------------------
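For intuition about the patch-grid arithmetic in `forward`: each axis is rounded up to a whole number of patches before the RoPE position ids are built. A quick numeric check, assuming Wan's usual `patch_size` of `(1, 2, 2)` (the grid dimensions are illustrative, not tied to a particular checkpoint):

```python
# Illustrative latent grid: 21 frames x 60 x 104, patch_size (1, 2, 2).
patch_size = (1, 2, 2)
t, h, w = 21, 60, 104
t_len = (t + patch_size[0] // 2) // patch_size[0]   # (21 + 0) // 1  = 21
h_len = (h + patch_size[1] // 2) // patch_size[1]   # (60 + 1) // 2  = 30
w_len = (w + patch_size[2] // 2) // patch_size[2]   # (104 + 1) // 2 = 52
print(t_len, h_len, w_len, t_len * h_len * w_len)   # 21 30 52 -> 32760 tokens
```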
/nodes/modify_wan_model_node.py:
--------------------------------------------------------------------------------
from ..modules.wan_model import inject_model
from ..modules.wan_blocks import inject_blocks


class ConfigureModifiedWanNode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL",),
        }}
    RETURN_TYPES = ("MODEL",)

    CATEGORY = "magicwan"
    FUNCTION = "apply"

    def apply(self, model):
        # Inject modified model and block classes.
        # Note: this patches the diffusion model in place (no model.clone()),
        # so the change persists on the loaded model object.
        inject_model(model.model.diffusion_model)
        inject_blocks(model.model.diffusion_model)
        return (model,)
--------------------------------------------------------------------------------
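Outside a graph, the node is a thin wrapper around the two inject calls. A hypothetical standalone check (assumes `model` is a ComfyUI `ModelPatcher` wrapping a loaded Wan checkpoint):

```python
# Hypothetical standalone use; `model` is assumed to be a ComfyUI ModelPatcher.
(patched,) = ConfigureModifiedWanNode().apply(model)

dm = patched.model.diffusion_model
print(type(dm).__name__)                     # ModifiedWanModel
print(type(dm.blocks[0]).__name__)           # ModifiedWanAttentionBlock
print(dm.blocks[0].idx, dm.blocks[-1].idx)   # 0, len(dm.blocks) - 1
```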
/nodes/wan_flowedit_nodes.py:
--------------------------------------------------------------------------------
import torch
from tqdm import trange

from comfy.samplers import KSAMPLER, CFGGuider, sampling_function


class FlowEditGuider(CFGGuider):
    def __init__(self, model_patcher):
        super().__init__(model_patcher)
        self.cfgs = {}
        self.num_repeats = 1

    def set_conds(self, **kwargs):
        self.inner_set_conds(kwargs)

    def set_cfgs(self, **kwargs):
        self.cfgs = {**kwargs}

    def set_num_repeats(self, num_repeats):
        self.num_repeats = num_repeats

    def predict_noise(self, x, timestep, model_options={}, seed=None):
        # The FlowEdit sampler sets 'latent_type' to 'source' or 'target', so a
        # single guider can route to the matching conditioning and CFG value.
        latent_type = model_options['transformer_options']['latent_type']
        positive = self.conds.get(f'{latent_type}_positive', None)
        negative = self.conds.get(f'{latent_type}_negative', None)
        cfg = self.cfgs.get(latent_type, self.cfg)

        if self.num_repeats == 1:
            return sampling_function(self.inner_model, x, timestep, negative, positive, cfg, model_options=model_options, seed=seed)

        # Multiple samples case: average predictions (with offset seeds) for more stable sampling
        predictions = None
        for i in range(self.num_repeats):
            current_seed = None if seed is None else seed + i
            current_pred = sampling_function(
                self.inner_model, x, timestep, negative, positive, cfg,
                model_options=model_options, seed=current_seed
            )
            if predictions is None:
                predictions = current_pred
            else:
                predictions += current_pred

        return predictions / self.num_repeats


# Basic node for model compatibility
class WanFlowEditGuiderNode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
            {
                "model": ("MODEL",),
                "source_cond": ("CONDITIONING", ),
                "target_cond": ("CONDITIONING", ),
            }
        }

    RETURN_TYPES = ("GUIDER",)
    FUNCTION = "get_guider"
    CATEGORY = "magicwan"

    def get_guider(self, model, source_cond, target_cond):
        guider = FlowEditGuider(model)
        guider.set_conds(source_positive=source_cond, target_positive=target_cond)
        guider.set_cfg(1.0)
        return (guider,)


class WanFlowEditGuiderAdvNode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
            {
                "model": ("MODEL",),
                "source_cond": ("CONDITIONING", ),
                "target_cond": ("CONDITIONING", ),
                "num_repeats": ("INT", {"default": 1, "min": 1, "max": 10}),
            }
        }

    RETURN_TYPES = ("GUIDER",)
    FUNCTION = "get_guider"
    CATEGORY = "magicwan"

    def get_guider(self, model, source_cond, target_cond, num_repeats):
        guider = FlowEditGuider(model)
        guider.set_conds(source_positive=source_cond, target_positive=target_cond)
        guider.set_cfg(1.0)
        guider.set_num_repeats(num_repeats)
        return (guider,)


class WanFlowEditGuiderCFGNode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
            {
                "model": ("MODEL",),
                "source_cond": ("CONDITIONING", ),
                "source_uncond": ("CONDITIONING", ),
                "target_cond": ("CONDITIONING", ),
                "target_uncond": ("CONDITIONING", ),
                "source_cfg": ("FLOAT", {"default": 2, "min": 0, "max": 0xffffffffffffffff, "step": 0.01 }),
                "target_cfg": ("FLOAT", {"default": 4.5, "min": 0, "max": 0xffffffffffffffff, "step": 0.01 }),
            }
        }

    RETURN_TYPES = ("GUIDER",)
    FUNCTION = "get_guider"
    CATEGORY = "magicwan"

    def get_guider(self, model, source_cond, source_uncond, target_cond, target_uncond, source_cfg, target_cfg):
        guider = FlowEditGuider(model)
        guider.set_conds(source_positive=source_cond, source_negative=source_uncond,
                         target_positive=target_cond, target_negative=target_uncond)
        guider.set_cfgs(source=source_cfg, target=target_cfg)
        return (guider,)


class WanFlowEditGuiderCFGAdvNode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
            {
                "model": ("MODEL",),
                "source_cond": ("CONDITIONING", ),
                "source_uncond": ("CONDITIONING", ),
                "target_cond": ("CONDITIONING", ),
                "target_uncond": ("CONDITIONING", ),
                "source_cfg": ("FLOAT", {"default": 2, "min": 0, "max": 0xffffffffffffffff, "step": 0.01 }),
                "target_cfg": ("FLOAT", {"default": 4.5, "min": 0, "max": 0xffffffffffffffff, "step": 0.01 }),
                "num_repeats": ("INT", {"default": 1, "min": 1, "max": 10}),
            }
        }

    RETURN_TYPES = ("GUIDER",)
    FUNCTION = "get_guider"
    CATEGORY = "magicwan"

    def get_guider(self, model, source_cond, source_uncond, target_cond, target_uncond, source_cfg, target_cfg, num_repeats):
        guider = FlowEditGuider(model)
        guider.set_conds(source_positive=source_cond, source_negative=source_uncond,
                         target_positive=target_cond, target_negative=target_uncond)
        guider.set_cfgs(source=source_cfg, target=target_cfg)
        guider.set_num_repeats(num_repeats)
        return (guider,)


def get_flowedit_sample(skip_steps, refine_steps, generator):
    # `refine_steps` is exposed as `drift_steps` on the sampler node.
    @torch.no_grad()
    def flowedit_sample(model, x_init, sigmas, extra_args=None, callback=None, disable=None):
        extra_args = {} if extra_args is None else extra_args

        # Shallow-copy transformer_options so the 'latent_type' flag set below
        # doesn't leak into the caller's model_options.
        model_options = extra_args.get('model_options', {})
        transformer_options = model_options.get('transformer_options', {})
        transformer_options = {**transformer_options}
        model_options['transformer_options'] = transformer_options
        extra_args['model_options'] = model_options

        source_extra_args = {**extra_args, 'model_options': { 'transformer_options': { 'latent_type': 'source'} }}

        sigmas = sigmas[skip_steps:]

        x_tgt = x_init.clone()
        N = len(sigmas) - 1
        s_in = x_init.new_ones([x_init.shape[0]])
        noise_mask = extra_args.get('denoise_mask', None)
        if noise_mask is None:
            noise_mask = torch.ones_like(x_init)
        else:
            extra_args['denoise_mask'] = None
            source_extra_args['denoise_mask'] = None

        for i in trange(N, disable=disable):
            sigma = sigmas[i]
            noise = torch.randn(x_init.shape, generator=generator).to(x_init.device)

            # Renoise the *source* latent to the current sigma (flow-matching interpolation)
            zt_src = (1 - sigma) * x_init + sigma * noise

            if i < N - refine_steps:
                # FlowEdit coupling: carry the source's noise over to the target
                zt_tgt = x_tgt + zt_src - x_init
                vt_src = model(zt_src, sigma * s_in, **source_extra_args)
            else:
                # Final refine/drift steps: plain target-only sampling
                if i == N - refine_steps:
                    zt_tgt = x_tgt + (zt_src - x_init)
                    x_tgt = x_tgt + (zt_src - x_init) * noise_mask
                else:
                    zt_tgt = x_tgt * (noise_mask) + (1 - noise_mask) * ((1 - sigma) * x_tgt + sigma * noise)
                vt_src = 0

            transformer_options['latent_type'] = 'target'
            vt_tgt = model(zt_tgt, sigma * s_in, **extra_args)

            # Integrate the velocity *difference* so unedited regions stay put
            v_delta = vt_tgt - vt_src
            x_tgt += (sigmas[i + 1] - sigmas[i]) * v_delta * noise_mask

            if callback is not None:
                callback({'x': x_tgt, 'denoised': x_tgt, 'i': i + skip_steps, 'sigma': sigmas[i], 'sigma_hat': sigmas[i]})

        return x_tgt

    return flowedit_sample


class WanFlowEditSamplerNode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "skip_steps": ("INT", {"default": 4, "min": 0, "max": 0xffffffffffffffff }),
            "drift_steps": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff }),
            "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff }),
        }, "optional": {
        }}
    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "build"

    CATEGORY = "magicwan"

    def build(self, skip_steps, drift_steps, seed):
        generator = torch.manual_seed(seed)
        sampler = KSAMPLER(get_flowedit_sample(skip_steps, drift_steps, generator))
        return (sampler, )
--------------------------------------------------------------------------------
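Stripped of the ComfyUI plumbing, FlowEdit is an Euler integration of the *difference* between target and source velocity predictions under shared noise. A toy, self-contained version with a dummy velocity function standing in for the Wan model (purely illustrative, no masks or skip/drift steps):

```python
import torch

def toy_velocity(z, sigma, cond):
    # Stand-in for the diffusion model: pulls z toward a per-condition anchor.
    return cond - z  # a real model predicts the flow-matching velocity

x_src = torch.randn(1, 4)                          # "source video" latent
src_anchor, tgt_anchor = torch.zeros(1, 4), torch.ones(1, 4)
sigmas = torch.linspace(1.0, 0.0, 11)              # 10 Euler steps

x_tgt = x_src.clone()
for i in range(len(sigmas) - 1):
    sigma = sigmas[i]
    noise = torch.randn_like(x_src)
    zt_src = (1 - sigma) * x_src + sigma * noise   # renoise the source
    zt_tgt = x_tgt + (zt_src - x_src)              # target shares the same noise
    v_delta = toy_velocity(zt_tgt, sigma, tgt_anchor) - toy_velocity(zt_src, sigma, src_anchor)
    x_tgt = x_tgt + (sigmas[i + 1] - sigma) * v_delta
print(x_tgt)  # drifted toward the "target" while tracking the source structure
```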
{"required": { 208 | "skip_steps": ("INT", {"default": 4, "min": 0, "max": 0xffffffffffffffff }), 209 | "drift_steps": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff }), 210 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff }), 211 | }, "optional": { 212 | }} 213 | RETURN_TYPES = ("SAMPLER",) 214 | FUNCTION = "build" 215 | 216 | CATEGORY = "magicwan" 217 | 218 | def build(self, skip_steps, drift_steps, seed): 219 | generator = torch.manual_seed(seed) 220 | sampler = KSAMPLER(get_flowedit_sample(skip_steps, drift_steps, generator)) 221 | return (sampler, ) 222 | -------------------------------------------------------------------------------- /nodes/wan_model_pred_nodes.py: -------------------------------------------------------------------------------- 1 | import comfy.sd 2 | import comfy.model_sampling 3 | import comfy.latent_formats 4 | import torch 5 | import nodes 6 | 7 | 8 | class WanInverseModelSamplingPredNode: 9 | @classmethod 10 | def INPUT_TYPES(s): 11 | return {"required": { 12 | "model": ("MODEL",), 13 | "shift": ("FLOAT", {"default": 7, "min": 0.0, "max": 100.0, "step":0.01}), 14 | }} 15 | 16 | RETURN_TYPES = ("MODEL",) 17 | FUNCTION = "patch" 18 | 19 | CATEGORY = "magicwan" 20 | 21 | def patch(self, model, shift): 22 | m = model.clone() 23 | 24 | # Use the same InverseCONST class as for HunyuanVideo 25 | class InverseCONST: 26 | def calculate_input(self, sigma, noise): 27 | return noise 28 | 29 | def calculate_denoised(self, sigma, model_output, model_input): 30 | sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) 31 | return model_output 32 | 33 | def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): 34 | return latent_image 35 | 36 | def inverse_noise_scaling(self, sigma, latent): 37 | return latent 38 | 39 | # Sampling base class depends on the model's sampling method 40 | sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow 41 | sampling_type = InverseCONST 42 | 43 | class ModelSamplingAdvanced(sampling_base, sampling_type): 44 | pass 45 | 46 | model_sampling = ModelSamplingAdvanced(model.model.model_config) 47 | model_sampling.set_parameters(shift=shift, multiplier=1000) 48 | m.add_object_patch("model_sampling", model_sampling) 49 | return (m, ) 50 | 51 | 52 | class WanReverseModelSamplingPredNode: 53 | @classmethod 54 | def INPUT_TYPES(s): 55 | return {"required": { 56 | "model": ("MODEL",), 57 | "shift": ("FLOAT", {"default": 7, "min": 0.0, "max": 100.0, "step":0.01}), 58 | }} 59 | 60 | RETURN_TYPES = ("MODEL",) 61 | FUNCTION = "patch" 62 | 63 | CATEGORY = "magicwan" 64 | 65 | def patch(self, model, shift): 66 | m = model.clone() 67 | 68 | # Use the same ReverseCONST class as for HunyuanVideo 69 | class ReverseCONST: 70 | def calculate_input(self, sigma, noise): 71 | return noise 72 | 73 | def calculate_denoised(self, sigma, model_output, model_input): 74 | return model_output # model_input - model_output * sigma 75 | 76 | def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): 77 | return latent_image 78 | 79 | def inverse_noise_scaling(self, sigma, latent): 80 | return latent / (1.0 - sigma) 81 | 82 | # Sampling base class depends on the model's sampling method 83 | sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow 84 | sampling_type = ReverseCONST 85 | 86 | class ModelSamplingAdvanced(sampling_base, sampling_type): 87 | pass 88 | 89 | model_sampling = ModelSamplingAdvanced(model.model.model_config) 90 | model_sampling.set_parameters(shift=shift, multiplier=1000) 
/utils/debug_utils.py:
--------------------------------------------------------------------------------
import torch
import os

def debug_enabled():
    """Check if debugging is enabled via environment variable"""
    return os.environ.get("MAGICWAN_DEBUG", "0") == "1"

def debug_print(*args, **kwargs):
    """Print debug messages if debugging is enabled"""
    if debug_enabled():
        print("[MagicWan Debug]", *args, **kwargs)

def debug_tensor(name, tensor):
    """Print tensor info for debugging"""
    if debug_enabled():
        if tensor is None:
            print(f"[MagicWan Debug] {name} is None")
        elif isinstance(tensor, torch.Tensor):
            print(f"[MagicWan Debug] {name}: shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device}")
            print(f"  - min/max/mean: {tensor.min().item():.4f}/{tensor.max().item():.4f}/{tensor.mean().item():.4f}")
            if tensor.numel() < 10:
                print(f"  - values: {tensor.tolist()}")
            elif tensor.dim() <= 2:
                print(f"  - first few values: {tensor.flatten()[:5].tolist()}")
        else:
            print(f"[MagicWan Debug] {name}: {type(tensor)} (not a tensor)")
--------------------------------------------------------------------------------
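The helpers check the environment at call time and are no-ops unless `MAGICWAN_DEBUG=1`, so they can be left in hot paths. Quick standalone usage (the import path assumes the repo root is on `sys.path`):

```python
import os
os.environ["MAGICWAN_DEBUG"] = "1"   # enable before the calls below run

import torch
from utils.debug_utils import debug_print, debug_tensor

debug_print("sampler step", 3)
debug_tensor("latent", torch.randn(1, 16, 4, 8, 8))
debug_tensor("missing", None)
# [MagicWan Debug] sampler step 3
# [MagicWan Debug] latent: shape=torch.Size([1, 16, 4, 8, 8]), dtype=torch.float32, device=cpu
# [MagicWan Debug] missing is None
```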