├── .github
│   └── workflows
│       └── publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── README_CN.md
├── __init__.py
├── example
│   ├── PuLID_with_FBcache.png
│   ├── PuLID_with_teacache.png
│   ├── workflow_base.png
│   ├── workflow_hunyuanvideo.png
│   └── workflow_ltxvideo.png
├── nodes
│   ├── DitPatchNode.py
│   ├── FluxPatchNode.py
│   ├── ParaAttentionNode.py
│   ├── TeaCacheNode.py
│   ├── VideoPatchNode.py
│   ├── __init__.py
│   ├── node_utils.py
│   ├── patch_lib
│   │   ├── FluxPatch.py
│   │   ├── HunYuanVideoPatch.py
│   │   ├── LTXVideoPatch.py
│   │   ├── MochiVideoPatch.py
│   │   ├── WanVideoPatch.py
│   │   ├── __init__.py
│   │   └── old
│   │       ├── HunYuanVideoPatch.py
│   │       └── LTXVideoPatch.py
│   └── patch_util.py
├── pyproject.toml
└── requirements.txt

/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 |   workflow_dispatch:
4 |   push:
5 |     branches:
6 |       - main
7 |     paths:
8 |       - "pyproject.toml"
9 | 
10 | permissions:
11 |   issues: write
12 | 
13 | jobs:
14 |   publish-node:
15 |     name: Publish Custom Node to registry
16 |     runs-on: ubuntu-latest
17 |     if: ${{ github.repository_owner == 'lldacing' }}
18 |     steps:
19 |       - name: Check out code
20 |         uses: actions/checkout@v4
21 |       - name: Publish Custom Node
22 |         uses: Comfy-Org/publish-node-action@v1
23 |         with:
24 |           ## Add your own personal access token to your GitHub repository secrets and reference it here.
25 |           personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 | 
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | 
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 | 
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 | 
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 lldacing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
23 | ---
24 | 
25 | The code of this repository is released under the MIT License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [中文文档](README_CN.md)
2 | 
3 | Adds some hook methods, such as `TeaCache` and `First Block Cache`, for `PuLID-Flux`, `Flux`, `HunYuanVideo`, `LTXVideo`, `MochiVideo` and `WanVideo`.
4 | 
5 | Requires ComfyUI version >= 0.3.17.
6 | 
7 | ## Preview (image contains workflow)
8 | ![save api extended](example/workflow_base.png)
9 | 
10 | Working with `PuLID` (requires my other custom node pack [ComfyUI_PuLID_Flux_ll](https://github.com/lldacing/ComfyUI_PuLID_Flux_ll))
11 | ![save api extended](example/PuLID_with_teacache.png)
12 | 
13 | 
14 | ## Install
15 | 
16 | - Manual
17 | ```shell
18 | cd custom_nodes
19 | git clone https://github.com/lldacing/ComfyUI_Patches_ll.git
20 | # restart ComfyUI
21 | ```
22 | 
23 | ## Nodes
24 | - FluxForwardOverrider
25 |   - Adds hook method support to the `Flux` model
26 | - VideoForwardOverrider
27 |   - Adds hook method support to video models. Supports `HunYuanVideo`, `LTXVideo`, `MochiVideo`, `WanVideo`
28 | - DitForwardOverrider
29 |   - Adds hook methods to a model of any supported type (the model type is detected automatically). Supports `Flux`, `HunYuanVideo`, `LTXVideo`, `MochiVideo`, `WanVideo`
30 | - ApplyTeaCachePatch
31 |   - Uses the `hooks` provided by a `*ForwardOverrider` node to enable `TeaCache` acceleration. Supports `Flux`, `HunYuanVideo`, `LTXVideo`, `MochiVideo`, `WanVideo`
32 |   - In my tests, acceleration may fail for `MochiVideo`: video quality degrades and the output can even be all black
33 | - ApplyTeaCachePatchAdvanced
34 |   - Supports `start_at` and `end_at`
35 | - ApplyFirstBlockCachePatch
36 |   - Uses the `hooks` provided by a `*ForwardOverrider` node to enable `First Block Cache` acceleration. Supports `Flux`, `HunYuanVideo`, `LTXVideo`, `MochiVideo`, `WanVideo`
37 |   - In my tests, acceleration may fail for `MochiVideo`: video quality degrades and the output can even be all black
38 | - ApplyFirstBlockCachePatchAdvanced
39 |   - Supports `start_at` and `end_at`
40 | 
41 | ## Speedup reference
42 | ### TeaCache (rel_l1_thresh value)
43 | |              | Original | 1.5x | 1.8x | 2.0x |
44 | |--------------|----------|------|------|------|
45 | | Flux         | 0        | 0.25 | 0.4  | 0.6  |
46 | | HunYuanVideo | 0        | 0.1  | -    | 0.15 |
47 | | LTXVideo     | 0        | 0.03 | -    | 0.05 |
48 | | MochiVideo   | 0        | 0.06 | -    | 0.09 |
49 | | WanVideo     | 0        | -    | -    | -    |
50 | 
51 | Note: "-" means negligible speedup, poor quality, or untested. Acceleration varies across the different WanVideo models.
52 | 
53 | ### First Block Cache (residual_diff_threshold value)
54 | |              | Original | 1.2x | 1.5x | 1.8x |
55 | |--------------|----------|------|------|------|
56 | | Flux         | 0        | -    | -    | 0.12 |
57 | | HunYuanVideo | 0        | -    | 0.1  | -    |
58 | | LTXVideo     | 0        | 0.05 | -    | -    |
59 | | MochiVideo   | 0        | -    | 0.03 | -    |
60 | | WanVideo     | 0        | -    | 0.05 | -    |
61 | 
62 | Note: "-" means negligible speedup, poor quality, or untested. A script-style sketch of chaining these nodes follows below.
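
## Script-style example

The patch nodes are normally wired up in the graph editor, but they can also be chained from Python. The sketch below is illustrative and untested; it assumes ComfyUI and this node pack are importable (the import paths shown depend on where the repository is installed) and that `model` is an already-loaded ComfyUI `MODEL` object. Each `apply_patch*` method clones the model and returns it in a one-element tuple, mirroring the node classes in `nodes/`.

```python
# Illustrative import paths; adjust to your installation layout.
from nodes.DitPatchNode import DitForwardOverrider
from nodes.TeaCacheNode import ApplyTeaCachePatch

# 1. Install the forward hooks (the model type is detected automatically).
(model,) = DitForwardOverrider().apply_patch(model)

# 2. Enable TeaCache; 0.25 is the reference threshold for ~1.5x on Flux.
(model,) = ApplyTeaCachePatch().apply_patch(model, rel_l1_thresh=0.25)

# `model` can now be handed to a sampler as usual.
```

`ApplyFirstBlockCachePatch` follows the same pattern with `residual_diff_threshold`, and the `*Advanced` variants additionally take `start_at`/`end_at`.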
63 | 64 | 65 | ## Thanks 66 | 67 | [TeaCache](https://github.com/ali-vilab/TeaCache) 68 | [ParaAttention](https://github.com/chengzeyi/ParaAttention) 69 | [Comfy-WaveSpeed](https://github.com/chengzeyi/Comfy-WaveSpeed) 70 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | [English](README.md) 2 | 3 | 添加一些钩子方法。例如支持`TeaCache`和`First Block Cache`加速`PulID-Flux`、`Flux`、`混元视频`、`LTXVideo`、`MochiVideo`、`WanVideo`。 4 | 5 | ComfyUI主体版本需要>=0.3.17 6 | 7 | ## 预览 (图片含工作流) 8 | ![save api extended](example/workflow_base.png) 9 | 10 | 加速`PuLID` (需要配合我的另一插件 [ComfyUI_PuLID_Flux_ll](https://github.com/lldacing/ComfyUI_PuLID_Flux_ll)使用) 11 | ![save api extended](example/PuLID_with_teacache.png) 12 | 13 | 14 | ## 安装 15 | 16 | - 手动安装 17 | ```shell 18 | cd custom_nodes 19 | git clone https://github.com/lldacing/ComfyUI_Patches_ll.git 20 | # restart ComfyUI 21 | ``` 22 | 23 | ## 节点 24 | - FluxForwardOverrider 25 | - 为`Flux`模型增加一些`hook`方法 26 | - VideoForwardOverrider 27 | - 为视频模型添加一些`hook`方法. 支持 `HunYuanVideo`、 `LTXVideo`、`MochiVideo`、`WanVideo` 28 | - DitForwardOverrider 29 | - 为Dit架构模型增加一些`hook`方法(自动识别模型类型). 支持 `Flux`、 `HunYuanVideo`、 `LTXVideo`、`MochiVideo`、`WanVideo` 30 | - ApplyTeaCachePatch 31 | - 使用`*ForwardOverrider`中支持的`hook`方法提供`TeaCache`加速,支持 `Flux`、 `HunYuanVideo`、 `LTXVideo`、`MochiVideo`、`WanVideo` 32 | - 我测试结果,`MochiVideo`可能加速失败,加速后视频质量不太好,可能出现全黑视频 33 | - ApplyTeaCachePatchAdvanced 34 | - 支持设置 `start_at` 和 `end_at` 35 | - ApplyFirstBlockCachePatch 36 | - 使用`*ForwardOverrider`中支持的`hook`方法提供`First Block Cache`加速,支持 `Flux`、 `HunYuanVideo`、 `LTXVideo`、`MochiVideo`、`WanVideo` 37 | - 我测试结果,`MochiVideo`可能加速失败,加速后视频质量不太好,可能出现全黑视频 38 | - ApplyFirstBlockCachePatchAdvanced 39 | - 支持设置 `start_at` 和 `end_at` 40 | 41 | ## 加速参考 42 | ### TeaCache (rel_l1_thresh值) 43 | | | 原始速度 | 1.5x | 1.8x | 2.0x | 44 | |--------------|------|------|------|------| 45 | | Flux | 0 | 0.25 | 0.4 | 0.6 | 46 | | HunYuanVideo | 0 | 0.1 | - | 0.15 | 47 | | LTXVideo | 0 | 0.03 | - | 0.05 | 48 | | MochiVideo | 0 | 0.06 | - | 0.09 | 49 | | WanVideo | 0 | - | - | - | 50 | 51 | 注: "-" 表示加速不明显、低质量或未测试。WanVideo的不同模型加速效果有差异。 52 | 53 | ### First Block Cache (residual_diff_threshold value) 54 | | | 原始速度 | 1.2x | 1.5x | 1.8x | 55 | |--------------|------|------|------|------| 56 | | Flux | 0 | - | - | 0.12 | 57 | | HunYuanVideo | 0 | - | 0.1 | - | 58 | | LTXVideo | 0 | 0.05 | - | - | 59 | | MochiVideo | 0 | - | 0.03 | - | 60 | | WanVideo | 0 | - | 0.05 | - | 61 | 62 | 注: "-" 表示加速不明显、低质量或未测试。 63 | 64 | ## 感谢 65 | 66 | [TeaCache](https://github.com/ali-vilab/TeaCache) 67 | [ParaAttention](https://github.com/chengzeyi/ParaAttention) 68 | [Comfy-WaveSpeed](https://github.com/chengzeyi/Comfy-WaveSpeed) 69 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import importlib.util 3 | import os 4 | 5 | extension_folder = os.path.dirname(os.path.realpath(__file__)) 6 | 7 | NODE_CLASS_MAPPINGS = {} 8 | NODE_DISPLAY_NAME_MAPPINGS = {} 9 | 10 | pyPath = os.path.join(extension_folder, 'nodes') 11 | 12 | def loadCustomNodes(): 13 | files = glob.glob(os.path.join(pyPath, "*Node.py"), recursive=True) 14 | for file in files: 15 | file_relative_path = file[len(extension_folder):] 16 | model_name = file_relative_path.replace(os.sep, '.') 17 | model_name = os.path.splitext(model_name)[0] 18 | 
module = importlib.import_module(model_name, __name__)
19 |         if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
20 |             NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
21 |         if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
22 |             NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
23 | 
24 | loadCustomNodes()
25 | 
26 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
27 | 
--------------------------------------------------------------------------------
/example/PuLID_with_FBcache.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lldacing/ComfyUI_Patches_ll/4a9ea6e5d408768d5c2d666781ef756880eea0d0/example/PuLID_with_FBcache.png
--------------------------------------------------------------------------------
/example/PuLID_with_teacache.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lldacing/ComfyUI_Patches_ll/4a9ea6e5d408768d5c2d666781ef756880eea0d0/example/PuLID_with_teacache.png
--------------------------------------------------------------------------------
/example/workflow_base.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lldacing/ComfyUI_Patches_ll/4a9ea6e5d408768d5c2d666781ef756880eea0d0/example/workflow_base.png
--------------------------------------------------------------------------------
/example/workflow_hunyuanvideo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lldacing/ComfyUI_Patches_ll/4a9ea6e5d408768d5c2d666781ef756880eea0d0/example/workflow_hunyuanvideo.png
--------------------------------------------------------------------------------
/example/workflow_ltxvideo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lldacing/ComfyUI_Patches_ll/4a9ea6e5d408768d5c2d666781ef756880eea0d0/example/workflow_ltxvideo.png
--------------------------------------------------------------------------------
/nodes/DitPatchNode.py:
--------------------------------------------------------------------------------
1 | import comfy
2 | from .VideoPatchNode import video_outer_sample_function_wrapper
3 | from .FluxPatchNode import flux_outer_sample_function_wrapper
4 | from .patch_util import is_flux_model, is_hunyuan_video_model, is_ltxv_video_model, is_mochi_video_model, \
5 |     is_wan_video_model
6 | 
7 | 
8 | class DitForwardOverrider:
9 | 
10 |     @classmethod
11 |     def INPUT_TYPES(cls):
12 |         return {
13 |             "required": {
14 |                 "model": ("MODEL",),
15 |             }
16 |         }
17 | 
18 |     RETURN_TYPES = ("MODEL",)
19 |     RETURN_NAMES = ("model",)
20 |     FUNCTION = "apply_patch"
21 |     CATEGORY = "patches/dit"
22 |     DESCRIPTION = "Supports Flux, HunYuanVideo, LTXVideo, MochiVideo, WanVideo"
23 | 
24 |     def apply_patch(self, model):
25 | 
26 |         model = model.clone()
27 |         patch_key = "dit_forward_override_wrapper"
28 |         diffusion_model = model.get_model_object('diffusion_model')
29 |         if is_flux_model(diffusion_model):
30 |             if len(model.get_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, patch_key)) == 0:
31 |                 # Just add it once when connecting in series
32 |                 model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
33 |                                            patch_key,
34 |                                            flux_outer_sample_function_wrapper
35 |                                            )
36 |         elif is_hunyuan_video_model(diffusion_model)
or is_ltxv_video_model(diffusion_model) or is_mochi_video_model(diffusion_model)\ 37 | or is_wan_video_model(diffusion_model): 38 | if len(model.get_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, patch_key)) == 0: 39 | # Just add it once when connecting in series 40 | model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, 41 | patch_key, 42 | video_outer_sample_function_wrapper 43 | ) 44 | return (model, ) 45 | 46 | 47 | NODE_CLASS_MAPPINGS = { 48 | "DitForwardOverrider": DitForwardOverrider, 49 | } 50 | 51 | NODE_DISPLAY_NAME_MAPPINGS = { 52 | "DitForwardOverrider": "DitForwardOverrider", 53 | } 54 | -------------------------------------------------------------------------------- /nodes/FluxPatchNode.py: -------------------------------------------------------------------------------- 1 | import comfy 2 | from .node_utils import get_old_method_name, get_new_forward_orig 3 | from .patch_util import set_hook, clean_hook, is_flux_model 4 | 5 | def flux_outer_sample_function_wrapper(wrapper_executor, noise, latent_image, sampler, sigmas, denoise_mask=None, 6 | callback=None, disable_pbar=False, seed=None): 7 | cfg_guider = wrapper_executor.class_obj 8 | diffusion_model = cfg_guider.model_patcher.model.diffusion_model 9 | # set hook 10 | set_hook(diffusion_model, 'flux_old_forward_orig', get_new_forward_orig(diffusion_model), get_old_method_name(diffusion_model)) 11 | 12 | try: 13 | out = wrapper_executor(noise, latent_image, sampler, sigmas, denoise_mask=denoise_mask, callback=callback, 14 | disable_pbar=disable_pbar, seed=seed) 15 | finally: 16 | # cleanup hook 17 | clean_hook(diffusion_model, 'flux_old_forward_orig') 18 | return out 19 | 20 | 21 | class FluxForwardOverrider: 22 | 23 | @classmethod 24 | def INPUT_TYPES(cls): 25 | return { 26 | "required": { 27 | "model": ("MODEL",), 28 | } 29 | } 30 | 31 | RETURN_TYPES = ("MODEL",) 32 | RETURN_NAMES = ("model",) 33 | FUNCTION = "apply_patch" 34 | CATEGORY = "patches/dit" 35 | 36 | def apply_patch(self, model): 37 | 38 | model = model.clone() 39 | if is_flux_model(model.get_model_object('diffusion_model')): 40 | patch_key = "flux_forward_override_wrapper" 41 | if len(model.get_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, patch_key)) == 0: 42 | # Just add it once when connecting in series 43 | model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, 44 | patch_key, 45 | flux_outer_sample_function_wrapper 46 | ) 47 | return (model, ) 48 | 49 | 50 | NODE_CLASS_MAPPINGS = { 51 | "FluxForwardOverrider": FluxForwardOverrider, 52 | } 53 | 54 | NODE_DISPLAY_NAME_MAPPINGS = { 55 | "FluxForwardOverrider": "FluxForwardOverrider", 56 | } 57 | -------------------------------------------------------------------------------- /nodes/ParaAttentionNode.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import comfy 4 | from .patch_util import PatchKeys, add_model_patch_option, set_model_patch, set_model_patch_replace, \ 5 | is_hunyuan_video_model, is_flux_model, is_ltxv_video_model, is_mochi_video_model, is_wan_video_model 6 | 7 | fb_cache_key_attrs = "fb_cache_attr" 8 | fb_cache_model_temp = "flux_fb_cache" 9 | 10 | def get_fb_cache_global_cache(transformer_options, timesteps): 11 | diffusion_model = transformer_options.get(PatchKeys.running_net_model) 12 | if hasattr(diffusion_model, fb_cache_model_temp): 13 | tea_cache = getattr(diffusion_model, fb_cache_model_temp, {}) 14 | transformer_options[fb_cache_key_attrs] = tea_cache 15 | 16 | 
attrs = transformer_options.get(fb_cache_key_attrs, {})
17 |     attrs['step_i'] = timesteps[0].detach().cpu().item()
18 | 
19 | def fb_cache_enter_for_wanvideo(x, timestep, context, transformer_options):
20 |     get_fb_cache_global_cache(transformer_options, timestep)
21 |     return x, timestep, context
22 | 
23 | def fb_cache_enter_for_mochivideo(x, timestep, context, attention_mask, num_tokens, transformer_options):
24 |     get_fb_cache_global_cache(transformer_options, timestep)
25 |     return x, timestep, context, attention_mask, num_tokens
26 | 
27 | def fb_cache_enter_for_ltxvideo(x, timestep, context, attention_mask, frame_rate, guiding_latent, guiding_latent_noise_scale, transformer_options):
28 |     get_fb_cache_global_cache(transformer_options, timestep)
29 |     return x, timestep, context, attention_mask, frame_rate, guiding_latent, guiding_latent_noise_scale
30 | 
31 | # For Flux and HunYuanVideo
32 | def fb_cache_enter(img, img_ids, txt, txt_ids, timesteps, y, guidance, control, attn_mask, transformer_options):
33 |     get_fb_cache_global_cache(transformer_options, timesteps)
34 |     return img, img_ids, txt, txt_ids, timesteps, y, guidance, control, attn_mask
35 | 
36 | def are_two_tensors_similar(t1, t2, *, threshold):
37 |     if t1.shape != t2.shape:
38 |         return False
39 |     mean_diff = (t1 - t2).abs().mean()
40 |     mean_t1 = t1.abs().mean()
41 |     diff = mean_diff / mean_t1
42 |     return diff.item() < threshold
43 | 
44 | def fb_cache_patch_double_block_with_control_replace(original_args, wrapper_options):
45 |     transformer_options = wrapper_options.get('transformer_options', {})
46 |     attrs = transformer_options.get(fb_cache_key_attrs, {})
47 |     step_i = attrs['step_i']
48 |     timestep_start = attrs['timestep_start']
49 |     timestep_end = attrs['timestep_end']
50 |     in_step = timestep_end <= step_i <= timestep_start
51 |     if not in_step:
52 |         attrs['should_calc'] = True
53 |         return wrapper_options.get('original_func')(**original_args, transformer_options=transformer_options)
54 | 
55 |     block_i = original_args['i']
56 |     txt = original_args['txt']
57 |     if block_i == 0:
58 |         # Compare with the first double block's output from the previous sampling step; if the absolute mean difference is below the threshold, the cache can be used.
59 |         img, txt = wrapper_options.get('original_func')(**original_args, transformer_options=transformer_options)
60 | 
61 |         previous_first_block_residual = attrs.get('previous_first_block_residual')
62 |         if previous_first_block_residual is not None:
63 |             should_calc = not are_two_tensors_similar(previous_first_block_residual, img, threshold=attrs['rel_diff_threshold'])
64 |         else:
65 |             # Must compute, i.e. the cache is not used.
66 |             should_calc = True
67 | 
68 |         if should_calc:
69 |             attrs['previous_first_block_residual'] = img.clone()
70 |         else:
71 |             # Residual from the last non-cached sampling step.
72 |             previous_residual = attrs.get('previous_residual')
73 |             if previous_residual is not None:
74 |                 img += previous_residual
75 | 
76 |         attrs['should_calc'] = should_calc
77 |         attrs['ori_img'] = None
78 |     else:
79 |         img = original_args['img']
80 |         should_calc = attrs['should_calc']
81 |         if should_calc:
82 |             if attrs['ori_img'] is None:
83 |                 attrs['ori_img'] = original_args['img'].clone()
84 |             if block_i > 0:
85 |                 img, txt = wrapper_options.get('original_func')(**original_args, transformer_options=transformer_options)
86 | 
87 |     del attrs, transformer_options
88 |     return img, txt
89 | 
90 | def fb_cache_patch_blocks_transition_replace(original_args, wrapper_options):
91 |     img = original_args['img']
92 |     transformer_options = wrapper_options.get('transformer_options', {})
93 |     attrs = transformer_options.get(fb_cache_key_attrs, {})
94 |     should_calc = attrs.get('should_calc', True)
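    # should_calc was set by the first-block check above; when it is False the original transition is skipped and img (already carrying the cached residual) passes through unchanged.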
95 |     if should_calc:
96 |         img = wrapper_options.get('original_func')(**original_args, transformer_options=transformer_options)
97 |     return img
98 | 
99 | def fb_cache_patch_single_blocks_replace(original_args, wrapper_options):
100 |     img = original_args['img']
101 |     txt = original_args['txt']
102 |     transformer_options = wrapper_options.get('transformer_options', {})
103 |     attrs = transformer_options.get(fb_cache_key_attrs, {})
104 |     should_calc = attrs.get('should_calc', True)
105 |     if should_calc:
106 |         img = wrapper_options.get('original_blocks')(**original_args, transformer_options=transformer_options)
107 |     return img, txt
108 | 
109 | def fb_cache_patch_blocks_after_replace(original_args, wrapper_options):
110 |     img = original_args['img']
111 |     transformer_options = wrapper_options.get('transformer_options', {})
112 |     attrs = transformer_options.get(fb_cache_key_attrs, {})
113 |     should_calc = attrs.get('should_calc', True)
114 |     if should_calc:
115 |         img = wrapper_options.get('original_func')(**original_args)
116 |     return img
117 | 
118 | def fb_cache_patch_final_transition_after(img, txt, transformer_options):
119 |     attrs = transformer_options.get(fb_cache_key_attrs, {})
120 |     should_calc = attrs.get('should_calc', True)
121 |     if should_calc:
122 |         if attrs.get('ori_img') is not None:
123 |             attrs['previous_residual'] = img - attrs['ori_img']
124 |     return img
125 | 
126 | def fb_cache_patch_dit_exit(img, transformer_options):
127 |     fb_cache = transformer_options.get(fb_cache_key_attrs, {})
128 |     setattr(transformer_options.get(PatchKeys.running_net_model), fb_cache_model_temp, fb_cache)
129 |     return img
130 | 
131 | def fb_cache_prepare_wrapper(wrapper_executor, noise, latent_image, sampler, sigmas, denoise_mask=None,
132 |                              callback=None, disable_pbar=False, seed=None):
133 |     cfg_guider = wrapper_executor.class_obj
134 | 
135 |     try:
136 |         out = wrapper_executor(noise, latent_image, sampler, sigmas, denoise_mask=denoise_mask, callback=callback,
137 |                                disable_pbar=disable_pbar, seed=seed)
138 |     finally:
139 |         diffusion_model = cfg_guider.model_patcher.model.diffusion_model
140 |         if hasattr(diffusion_model, fb_cache_model_temp):
141 |             delattr(diffusion_model, fb_cache_model_temp)
142 | 
143 |     return out
144 | 
145 | class ApplyFirstBlockCachePatchAdvanced:
146 | 
147 |     @classmethod
148 |     def INPUT_TYPES(cls):
149 |         return {
150 |             "required": {
151 |                 "model": ("MODEL",),
152 |                 "residual_diff_threshold": ("FLOAT",
153 |                                             {
154 |                                                 "default": 0.00,
155 |                                                 "min": 0.0,
156 |                                                 "max": 1.0,
157 |                                                 "step": 0.01,
158 |                                                 "tooltip": "Flux: 0 (original), 0.12 (1.8x speedup).\n"
159 |                                                            "HunYuanVideo: 0 (original), 0.1 (1.6x speedup).\n"
160 |                                                            "LTXVideo: 0 (original), 0.05 (1.2x speedup).\n"
161 |                                                            "MochiVideo: 0 (original), 0.03 (1.5x speedup).\n"
162 |                                                            "WanVideo: 0 (original), 0.05 (1.5x speedup)."
163 |                                             }),
164 |                 "start_at": ("FLOAT",
165 |                              {
166 |                                  "default": 0.0,
167 |                                  "step": 0.01,
168 |                                  "max": 1.0,
169 |                                  "min": 0.0
170 |                              }
171 |                 ),
172 |                 "end_at": ("FLOAT",
173 |                            {
174 |                                "default": 1.0,
175 |                                "step": 0.01,
176 |                                "max": 1.0,
177 |                                "min": 0.0
178 |                            })
179 |             }
180 |         }
181 | 
182 |     RETURN_TYPES = ("MODEL",)
183 |     RETURN_NAMES = ("model",)
184 |     FUNCTION = "apply_patch_advanced"
185 |     CATEGORY = "patches/speed"
186 |     DESCRIPTION = ("Apply the First Block Cache patch to accelerate the model. Use it together with nodes that have the suffix ForwardOverrider."
187 | "\nThis is effective only for Flux, HunYuanVideo, LTXVideo, WanVideo and MochiVideo.") 188 | 189 | def apply_patch_advanced(self, model, residual_diff_threshold, start_at=0.0, end_at=1.0): 190 | 191 | model = model.clone() 192 | patch_key = "fb_cache_wrapper" 193 | if residual_diff_threshold == 0 or len(model.get_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, patch_key)) > 0: 194 | return (model,) 195 | 196 | diffusion_model = model.get_model_object('diffusion_model') 197 | if not is_flux_model(diffusion_model) and not is_hunyuan_video_model(diffusion_model) and not is_ltxv_video_model(diffusion_model)\ 198 | and not is_mochi_video_model(diffusion_model) and not is_wan_video_model(diffusion_model): 199 | logging.warning("First Block Cache patch is not applied because the model is not supported.") 200 | return (model,) 201 | 202 | fb_cache_attrs = add_model_patch_option(model, fb_cache_key_attrs) 203 | 204 | fb_cache_attrs['rel_diff_threshold'] = residual_diff_threshold 205 | model_sampling = model.get_model_object("model_sampling") 206 | sigma_start = model_sampling.percent_to_sigma(start_at) 207 | sigma_end = model_sampling.percent_to_sigma(end_at) 208 | fb_cache_attrs['timestep_start'] = model_sampling.timestep(sigma_start) 209 | fb_cache_attrs['timestep_end'] = model_sampling.timestep(sigma_end) 210 | 211 | if is_ltxv_video_model(diffusion_model): 212 | set_model_patch(model, PatchKeys.options_key, fb_cache_enter_for_ltxvideo, PatchKeys.dit_enter) 213 | elif is_mochi_video_model(diffusion_model): 214 | set_model_patch(model, PatchKeys.options_key, fb_cache_enter_for_mochivideo, PatchKeys.dit_enter) 215 | elif is_wan_video_model(diffusion_model): 216 | set_model_patch(model, PatchKeys.options_key, fb_cache_enter_for_wanvideo, PatchKeys.dit_enter) 217 | else: 218 | set_model_patch(model, PatchKeys.options_key, fb_cache_enter, PatchKeys.dit_enter) 219 | 220 | set_model_patch_replace(model, PatchKeys.options_key, fb_cache_patch_double_block_with_control_replace, PatchKeys.dit_double_block_with_control_replace) 221 | set_model_patch_replace(model, PatchKeys.options_key, fb_cache_patch_blocks_transition_replace, PatchKeys.dit_blocks_transition_replace) 222 | set_model_patch_replace(model, PatchKeys.options_key, fb_cache_patch_single_blocks_replace, PatchKeys.dit_single_blocks_replace) 223 | set_model_patch_replace(model, PatchKeys.options_key, fb_cache_patch_blocks_after_replace, PatchKeys.dit_blocks_after_transition_replace) 224 | 225 | set_model_patch(model, PatchKeys.options_key, fb_cache_patch_final_transition_after, PatchKeys.dit_final_layer_before) 226 | set_model_patch(model, PatchKeys.options_key, fb_cache_patch_dit_exit, PatchKeys.dit_exit) 227 | 228 | # Just add it once when connecting in series 229 | model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, 230 | patch_key, 231 | fb_cache_prepare_wrapper 232 | ) 233 | return (model, ) 234 | 235 | class ApplyFirstBlockCachePatch(ApplyFirstBlockCachePatchAdvanced): 236 | 237 | @classmethod 238 | def INPUT_TYPES(cls): 239 | return { 240 | "required": { 241 | "model": ("MODEL",), 242 | "residual_diff_threshold": ("FLOAT", 243 | { 244 | "default": 0.00, 245 | "min": 0.0, 246 | "max": 1.0, 247 | "step": 0.01, 248 | "tooltip": "Flux: 0 (original), 0.12 (1.8x speedup).\n" 249 | "HunYuanVideo: 0 (original), 0.1 (1.6x speedup).\n" 250 | "LTXVideo: 0 (original), 0.05 (1.2x speedup).\n" 251 | "MochiVideo: 0 (original), 0.03 (1.5x speedup).\n" 252 | "WanVideo: 0 (original), 0.05 (1.5x speedup)." 
253 | }) 254 | } 255 | } 256 | 257 | RETURN_TYPES = ("MODEL",) 258 | RETURN_NAMES = ("model",) 259 | FUNCTION = "apply_patch" 260 | CATEGORY = "patches/speed" 261 | DESCRIPTION = ("Apply the First Block Cache patch to accelerate the model. Use it together with nodes that have the suffix ForwardOverrider." 262 | "\nThis is effective only for Flux, HunYuanVideo, LTXVideo, WanVideo and MochiVideo.") 263 | 264 | def apply_patch(self, model, residual_diff_threshold): 265 | return super().apply_patch_advanced(model, residual_diff_threshold, start_at=0.0, end_at=1.0) 266 | 267 | NODE_CLASS_MAPPINGS = { 268 | "ApplyFirstBlockCachePatch": ApplyFirstBlockCachePatch, 269 | "ApplyFirstBlockCachePatchAdvanced": ApplyFirstBlockCachePatchAdvanced, 270 | } 271 | 272 | NODE_DISPLAY_NAME_MAPPINGS = { 273 | "ApplyFirstBlockCachePatch": "ApplyFirstBlockCachePatch", 274 | "ApplyFirstBlockCachePatchAdvanced": "ApplyFirstBlockCachePatchAdvanced", 275 | } 276 | -------------------------------------------------------------------------------- /nodes/TeaCacheNode.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | import comfy 7 | from .patch_util import PatchKeys, add_model_patch_option, set_model_patch, set_model_patch_replace, \ 8 | is_hunyuan_video_model, is_flux_model, is_ltxv_video_model, is_mochi_video_model, is_wan_video_model 9 | 10 | tea_cache_key_attrs = "tea_cache_attr" 11 | # https://github.com/ali-vilab/TeaCache/blob/main/TeaCache4FLUX/teacache_flux.py 12 | # https://github.com/ali-vilab/TeaCache/blob/main/TeaCache4HunyuanVideo/teacache_sample_video.py 13 | # https://github.com/ali-vilab/TeaCache/blob/main/TeaCache4LTX-Video/teacache_ltx.py 14 | # https://github.com/ali-vilab/TeaCache/blob/main/TeaCache4Mochi/teacache_mochi.py 15 | # https://github.com/ali-vilab/TeaCache/blob/main/TeaCache4Wan2.1/teacache_generate.py 16 | coefficients_obj = { 17 | 'Flux': [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01], 18 | 'HunYuanVideo': [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02], 19 | 'LTXVideo': [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03], 20 | 'MochiVideo': [-3.51241319e+03, 8.11675948e+02, -6.09400215e+01, 2.42429681e+00, 3.05291719e-03], 21 | # Supports 480P 22 | 'WanVideo_t2v_1.3B': [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01], 23 | # Supports both 480P and 720P 24 | 'WanVideo_t2v_14B': [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404], 25 | # Supports 480P 26 | 'WanVideo_i2v_14B_480P': [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01], 27 | # Supports 720P 28 | 'WanVideo_i2v_14B_720P': [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683], 29 | 'WanVideo_disabled': [], 30 | } 31 | 32 | def get_teacache_global_cache(transformer_options, timesteps): 33 | diffusion_model = transformer_options.get(PatchKeys.running_net_model) 34 | if hasattr(diffusion_model, "flux_tea_cache"): 35 | tea_cache = getattr(diffusion_model, "flux_tea_cache", {}) 36 | transformer_options[tea_cache_key_attrs] = tea_cache 37 | attrs = transformer_options.get(tea_cache_key_attrs, {}) 38 | attrs['step_i'] = timesteps[0].detach().cpu().item() 39 | # print(str(attrs['step_i'])) 40 | 41 | def tea_cache_enter_for_wanvideo(x, timestep, context, transformer_options, **kwargs): 42 | 
get_teacache_global_cache(transformer_options, timestep)
43 |     return x, timestep, context
44 | 
45 | def tea_cache_enter_for_mochivideo(x, timestep, context, attention_mask, num_tokens, transformer_options, **kwargs):
46 |     get_teacache_global_cache(transformer_options, timestep)
47 |     return x, timestep, context, attention_mask, num_tokens
48 | 
49 | def tea_cache_enter_for_ltxvideo(x, timestep, context, attention_mask, frame_rate, guiding_latent, guiding_latent_noise_scale, transformer_options, **kwargs):
50 |     get_teacache_global_cache(transformer_options, timestep)
51 |     return x, timestep, context, attention_mask, frame_rate, guiding_latent, guiding_latent_noise_scale
52 | 
53 | # For Flux and HunYuanVideo
54 | def tea_cache_enter(img, img_ids, txt, txt_ids, timesteps, y, guidance, control, attn_mask, transformer_options, **kwargs):
55 |     get_teacache_global_cache(transformer_options, timesteps)
56 |     return img, img_ids, txt, txt_ids, timesteps, y, guidance, control, attn_mask
57 | 
58 | def tea_cache_patch_blocks_before(img, txt, vec, ids, pe, transformer_options, **kwargs):
59 |     real_model = transformer_options[PatchKeys.running_net_model]
60 |     attrs = transformer_options.get(tea_cache_key_attrs, {})
61 |     step_i = attrs['step_i']
62 |     timestep_start = attrs['timestep_start']
63 |     timestep_end = attrs['timestep_end']
64 |     in_step = timestep_end <= step_i <= timestep_start
65 |     # print(str(timestep_end)+' '+ str(step_i)+' '+str(timestep_start))
66 | 
67 |     # Combining kijai's TeaCache variant with the official TeaCache implementation is best in both quality and speed (i.e. the implementation in KJ-Nodes).
68 |     # The official TeaCache implementation only computes accumulated_rel_l1_distance for cond, not for uncond.
69 |     accumulated_state = attrs.get('accumulated_state', {
70 |         "x": {'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'skipped_steps': 0, 'previous_residual': None},
71 |         'cond': {'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'skipped_steps': 0, 'previous_residual': None},
72 |         'uncond': {'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'skipped_steps': 0, 'previous_residual': None}
73 |     })
74 |     teacache_enabled = attrs['rel_l1_thresh'] > 0 and in_step
75 |     attrs['cache_enabled'] = teacache_enabled
76 |     current_state_type = 'x'
77 |     should_calc = True
78 |     if teacache_enabled:
79 |         inp = img
80 |         vec_ = vec
81 |         rescale_func_flag = True
82 |         # Takes effect when split_cnd_flag=True
83 |         coefficient_type = 'Flux'
84 |         if is_ltxv_video_model(real_model):
85 |             coefficient_type = 'LTXVideo'
86 |             modulated_inp = comfy.ldm.common_dit.rms_norm(inp)
87 |             double_block_0 = real_model.transformer_blocks[0]
88 |             num_ada_params = double_block_0.scale_shift_table.shape[0]
89 |             ada_values = double_block_0.scale_shift_table[None, None] + vec_.reshape(img.shape[0], vec_.shape[1], num_ada_params, -1)
90 |             shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
91 |             modulated_inp = modulated_inp * (1 + scale_msa) + shift_msa
92 |         elif is_mochi_video_model(real_model):
93 |             coefficient_type = 'MochiVideo'
94 |             double_block_0 = real_model.blocks[0]
95 |             mod_x = double_block_0.mod_x(F.silu(vec_))
96 |             scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1)
97 |             # copied from comfy.ldm.genmo.joint_model.asymm_models_joint.modulated_rmsnorm
98 |             modulated_inp = comfy.ldm.common_dit.rms_norm(inp)
99 |             modulated_inp = modulated_inp * (1 + scale_msa_x.unsqueeze(1))
100 |         elif is_wan_video_model(real_model):
101 |             coefficient_type = attrs.get("wan_coefficients_type", 'disabled')
102 |             if coefficient_type == 'disabled':
103 |                 # e0
104 |                 rescale_func_flag = False
105 |             else:
106 |                 vec_ = kwargs.get('e')
107 |             modulated_inp = vec_
108 |             coefficient_type = 'WanVideo_' + coefficient_type
109 |             is_cond_flag = transformer_options["cond_or_uncond"] == [0]
110 |             current_state_type = 'cond' if is_cond_flag else 'uncond'
111 |         else:
112 |             double_block_0 = real_model.double_blocks[0]
113 |             img_mod1, img_mod2 = double_block_0.img_mod(vec_)
114 |             modulated_inp = double_block_0.img_norm1(inp)
115 |             if is_hunyuan_video_model(real_model):
116 |                 coefficient_type = 'HunYuanVideo'
117 |                 # if img_mod1.scale is None and img_mod1.shift is None:
118 |                 #     pass
119 |                 # elif img_mod1.shift is None:
120 |                 #     modulated_inp = modulated_inp * (1 + img_mod1.scale)
121 |                 # elif img_mod1.scale is None:
122 |                 #     modulated_inp = modulated_inp + img_mod1.shift
123 |                 # else:
124 |                 #     modulated_inp = modulated_inp * (1 + img_mod1.scale) + img_mod1.shift
125 |                 if img_mod1.scale is not None:
126 |                     modulated_inp = modulated_inp * (1 + img_mod1.scale)
127 |                 if img_mod1.shift is not None:
128 |                     modulated_inp = modulated_inp + img_mod1.shift
129 |             else:
130 |                 # Flux
131 |                 modulated_inp = (1 + img_mod1.scale) * modulated_inp + img_mod1.shift
132 | 
133 |         current_state = accumulated_state[current_state_type]
134 | 
135 |         if current_state.get('previous_modulated_input', None) is None or attrs['cnt'] == 0 or attrs['cnt'] == attrs['total_steps'] - 1:
136 |             should_calc = True
137 |             current_state['accumulated_rel_l1_distance'] = 0
138 |         else:
139 |             if rescale_func_flag:
140 |                 coefficients = coefficients_obj[coefficient_type]
141 |                 rescale_func = np.poly1d(coefficients)
142 |                 current_state['accumulated_rel_l1_distance'] += rescale_func(((modulated_inp - current_state['previous_modulated_input']).abs().mean() / current_state['previous_modulated_input'].abs().mean()).cpu().item())
143 |             else:
144 |                 current_state['accumulated_rel_l1_distance'] += ((modulated_inp - current_state['previous_modulated_input']).abs().mean() / current_state['previous_modulated_input'].abs().mean()).cpu().item()
145 | 
146 |             if current_state['accumulated_rel_l1_distance'] < attrs['rel_l1_thresh']:
147 |                 should_calc = False
148 |             else:
149 |                 should_calc = True
150 |                 current_state['accumulated_rel_l1_distance'] = 0
151 | 
152 |         current_state['previous_modulated_input'] = modulated_inp.clone().detach()
153 | 
154 |         attrs['cnt'] += 1
155 |         if attrs['cnt'] == attrs['total_steps']:
156 |             attrs['cnt'] = 0
157 |         del inp, vec_
158 |     else:
159 |         # Initialization is needed when start_at is set.
160 |         if is_wan_video_model(real_model):
161 |             current_state_type = 'cond' if transformer_options["cond_or_uncond"] == [0] else 'uncond'
162 | 
163 |     attrs['should_calc'] = should_calc
164 |     attrs['accumulated_state'] = accumulated_state
165 |     attrs['current_state_type'] = current_state_type
166 |     del real_model
167 |     return img, txt, vec, ids, pe
168 | 
169 | def tea_cache_patch_double_blocks_replace(original_args, wrapper_options):
170 |     img = original_args['img']
171 |     txt = original_args['txt']
172 |     transformer_options = wrapper_options.get('transformer_options', {})
173 |     attrs = transformer_options.get(tea_cache_key_attrs, {})
174 | 
175 |     should_calc = not attrs.get('cache_enabled') or attrs.get('should_calc', True)
176 |     if not should_calc:
177 |         current_state = attrs['accumulated_state'][attrs['current_state_type']]
178 |         img += current_state['previous_residual'].to(img.device)
179 |         current_state['skipped_steps'] += 1
180 |     else:
181 |         # (b, seq_len, _)
182 |         attrs['ori_img'] = img.clone().detach()
183 |         img, txt =
wrapper_options.get('original_blocks')(**original_args, transformer_options=transformer_options)
184 |     return img, txt
185 | 
186 | def tea_cache_patch_blocks_transition_replace(original_args, wrapper_options):
187 |     img = original_args['img']
188 |     transformer_options = wrapper_options.get('transformer_options', {})
189 |     attrs = transformer_options.get(tea_cache_key_attrs, {})
190 |     should_calc = not attrs.get('cache_enabled', False) or attrs.get('should_calc', True)
191 |     if should_calc:
192 |         img = wrapper_options.get('original_func')(**original_args, transformer_options=transformer_options)
193 |     return img
194 | 
195 | def tea_cache_patch_single_blocks_replace(original_args, wrapper_options):
196 |     img = original_args['img']
197 |     txt = original_args['txt']
198 |     transformer_options = wrapper_options.get('transformer_options', {})
199 |     attrs = transformer_options.get(tea_cache_key_attrs, {})
200 |     should_calc = not attrs.get('cache_enabled', False) or attrs.get('should_calc', True)
201 |     if should_calc:
202 |         img = wrapper_options.get('original_blocks')(**original_args, transformer_options=transformer_options)
203 |     return img, txt
204 | 
205 | def tea_cache_patch_blocks_after_replace(original_args, wrapper_options):
206 |     img = original_args['img']
207 |     transformer_options = wrapper_options.get('transformer_options', {})
208 |     attrs = transformer_options.get(tea_cache_key_attrs, {})
209 |     should_calc = not attrs.get('cache_enabled', False) or attrs.get('should_calc', True)
210 |     if should_calc:
211 |         img = wrapper_options.get('original_func')(**original_args)
212 |     return img
213 | 
214 | def tea_cache_patch_final_transition_after(img, txt, transformer_options):
215 |     attrs = transformer_options.get(tea_cache_key_attrs, {})
216 |     should_calc = not attrs.get('cache_enabled', False) or attrs.get('should_calc', True)
217 |     if should_calc:
218 |         current_state = attrs['accumulated_state'][attrs['current_state_type']]
219 |         current_state['previous_residual'] = (img - attrs['ori_img']).to(attrs['cache_device'])
220 |     return img
221 | 
222 | def tea_cache_patch_dit_exit(img, transformer_options):
223 |     tea_cache = transformer_options.get(tea_cache_key_attrs, {})
224 |     setattr(transformer_options.get(PatchKeys.running_net_model), "flux_tea_cache", tea_cache)
225 |     return img
226 | 
227 | def tea_cache_prepare_wrapper(wrapper_executor, noise, latent_image, sampler, sigmas, denoise_mask=None,
228 |                               callback=None, disable_pbar=False, seed=None):
229 |     cfg_guider = wrapper_executor.class_obj
230 | 
231 |     # Use cfg_guider.model_options, which is copied from modelPatcher.model_options and will be restored after execution without any unexpected contamination
232 |     temp_options = add_model_patch_option(cfg_guider, tea_cache_key_attrs)
233 |     temp_options['total_steps'] = len(sigmas) - 1
234 |     temp_options['cnt'] = 0
235 |     try:
236 |         out = wrapper_executor(noise, latent_image, sampler, sigmas, denoise_mask=denoise_mask, callback=callback,
237 |                                disable_pbar=disable_pbar, seed=seed)
238 |     finally:
239 |         diffusion_model = cfg_guider.model_patcher.model.diffusion_model
240 |         if hasattr(diffusion_model, "flux_tea_cache"):
241 |             print_tea_cache_executed_state(getattr(diffusion_model, "flux_tea_cache"))
242 |             del diffusion_model.flux_tea_cache
243 | 
244 |     return out
245 | 
246 | def print_tea_cache_executed_state(attrs):
247 |     executed_state = attrs.get('accumulated_state', {})
248 |     for state_type, state in executed_state.items():
249 |         logging.info(f"skipped {state_type} steps: {state['skipped_steps']}")
250 | 
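# In brief, the caching rule implemented above: each step computes the relative
# L1 change of the first block's modulated input,
#     rel = (modulated_inp - previous_modulated_input).abs().mean()
#           / previous_modulated_input.abs().mean(),
# optionally rescales it with the per-model polynomial from coefficients_obj
# (np.poly1d) and adds it to accumulated_rel_l1_distance; while that sum stays
# below rel_l1_thresh the transformer blocks are skipped and previous_residual
# is reused, otherwise a full forward runs and the accumulator resets to 0.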
251 | class ApplyTeaCachePatchAdvanced: 252 | 253 | @classmethod 254 | def INPUT_TYPES(cls): 255 | return { 256 | "required": { 257 | "model": ("MODEL",), 258 | "rel_l1_thresh": ("FLOAT", 259 | { 260 | "default": 0.25, 261 | "min": 0.0, 262 | "max": 5.0, 263 | "step": 0.001, 264 | "tooltip": "Flux: 0 (original), 0.25 (1.5x speedup), 0.4 (1.8x speedup), 0.6 (2.0x speedup), and 0.8 (2.25x speedup).\n" 265 | "HunYuanVideo: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup).\n" 266 | "LTXVideo: 0 (original), 0.03 (1.6x speedup), 0.05 (2.1x speedup).\n" 267 | "MochiVideo: 0 (original), 0.06 (1.5x speedup), 0.09 (2.1x speedup).\n" 268 | "WanVideo: 0 (original), reference values\n" 269 | " Wan2.1 t2v 1.3B 0.05 0.07 0.08\n" 270 | " Wan2.1 t2v 14B 0.14 0.15 0.2\n" 271 | " Wan2.1 i2v 480P 0.13 0.19 0.26\n" 272 | " Wan2.1 i2v 720P 0.18 0.2 0.3" 273 | }), 274 | "start_at": ("FLOAT", 275 | { 276 | "default": 0.0, 277 | "step": 0.01, 278 | "max": 1.0, 279 | "min": 0.0, 280 | }, 281 | ), 282 | "end_at": ("FLOAT", { 283 | "default": 1.0, 284 | "step": 0.01, 285 | "max": 1.0, 286 | "min": 0.0, 287 | }), 288 | }, 289 | "optional": { 290 | "cache_device": (["main_device", "offload_device"], {"default": "offload_device"}), 291 | "wan_coefficients": (["disabled", "t2v_1.3B", "t2v_14B", "i2v_14B_480P", "i2v_14B_720P"], { 292 | "default": "disabled", 293 | "tooltip": "WanVideo coefficients." 294 | }), 295 | } 296 | } 297 | 298 | RETURN_TYPES = ("MODEL",) 299 | RETURN_NAMES = ("model",) 300 | FUNCTION = "apply_patch_advanced" 301 | CATEGORY = "patches/speed" 302 | DESCRIPTION = ("Apply the TeaCache patch to accelerate the model. Use it together with nodes that have the suffix ForwardOverrider." 303 | "\nThis is effective only for Flux, HunYuanVideo, LTXVideo, WanVideo and MochiVideo.") 304 | 305 | def apply_patch_advanced(self, model, rel_l1_thresh, start_at=0.0, end_at=1.0, cache_device="offload_device", wan_coefficients="disabled", from_simple=False): 306 | 307 | model = model.clone() 308 | patch_key = "tea_cache_wrapper" 309 | if rel_l1_thresh == 0 or len(model.get_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, patch_key)) > 0: 310 | return (model,) 311 | 312 | diffusion_model = model.get_model_object('diffusion_model') 313 | if not is_flux_model(diffusion_model) and not is_hunyuan_video_model(diffusion_model) and not is_ltxv_video_model(diffusion_model)\ 314 | and not is_mochi_video_model(diffusion_model) and not is_wan_video_model(diffusion_model): 315 | logging.warning("TeaCache patch is not applied because the model is not supported.") 316 | return (model,) 317 | 318 | tea_cache_attrs = add_model_patch_option(model, tea_cache_key_attrs) 319 | 320 | tea_cache_attrs['rel_l1_thresh'] = rel_l1_thresh 321 | model_sampling = model.get_model_object("model_sampling") 322 | # For WanVideo, when wan_coefficients is disabled, the results of the first few steps are unstable? 
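        # When from_simple is True and wan_coefficients is disabled on a Wan model, start_at is clamped to at least 0.2 below so those early steps are never cached.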
323 | sigma_start = model_sampling.percent_to_sigma(max(start_at, 0.2) if from_simple and wan_coefficients == 'disabled' and is_wan_video_model(diffusion_model) else start_at) 324 | sigma_end = model_sampling.percent_to_sigma(end_at) 325 | tea_cache_attrs['timestep_start'] = model_sampling.timestep(sigma_start) 326 | tea_cache_attrs['timestep_end'] = model_sampling.timestep(sigma_end) 327 | tea_cache_attrs['cache_device'] = comfy.model_management.get_torch_device() if cache_device == "main_device" else comfy.model_management.unet_offload_device() 328 | 329 | if is_ltxv_video_model(diffusion_model): 330 | set_model_patch(model, PatchKeys.options_key, tea_cache_enter_for_ltxvideo, PatchKeys.dit_enter) 331 | elif is_mochi_video_model(diffusion_model): 332 | set_model_patch(model, PatchKeys.options_key, tea_cache_enter_for_mochivideo, PatchKeys.dit_enter) 333 | elif is_wan_video_model(diffusion_model): 334 | # i2v or t2v 335 | model_type = diffusion_model.model_type 336 | tea_cache_attrs['wan_coefficients_type'] = wan_coefficients 337 | if wan_coefficients != "disabled" and not wan_coefficients.startswith(model_type): 338 | logging.warning(f"The wan video's model type is {model_type}, but the selected wan_coefficients is {wan_coefficients}.") 339 | set_model_patch(model, PatchKeys.options_key, tea_cache_enter_for_wanvideo, PatchKeys.dit_enter) 340 | else: 341 | set_model_patch(model, PatchKeys.options_key, tea_cache_enter, PatchKeys.dit_enter) 342 | 343 | set_model_patch(model, PatchKeys.options_key, tea_cache_patch_blocks_before, PatchKeys.dit_blocks_before) 344 | 345 | set_model_patch_replace(model, PatchKeys.options_key, tea_cache_patch_double_blocks_replace, PatchKeys.dit_double_blocks_replace) 346 | set_model_patch_replace(model, PatchKeys.options_key, tea_cache_patch_blocks_transition_replace, PatchKeys.dit_blocks_transition_replace) 347 | set_model_patch_replace(model, PatchKeys.options_key, tea_cache_patch_single_blocks_replace, PatchKeys.dit_single_blocks_replace) 348 | set_model_patch_replace(model, PatchKeys.options_key, tea_cache_patch_blocks_after_replace, PatchKeys.dit_blocks_after_transition_replace) 349 | 350 | set_model_patch(model, PatchKeys.options_key, tea_cache_patch_final_transition_after, PatchKeys.dit_final_layer_before) 351 | set_model_patch(model, PatchKeys.options_key, tea_cache_patch_dit_exit, PatchKeys.dit_exit) 352 | 353 | # Just add it once when connecting in series 354 | model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, 355 | patch_key, 356 | tea_cache_prepare_wrapper 357 | ) 358 | return (model, ) 359 | 360 | class ApplyTeaCachePatch(ApplyTeaCachePatchAdvanced): 361 | 362 | @classmethod 363 | def INPUT_TYPES(cls): 364 | return { 365 | "required": { 366 | "model": ("MODEL",), 367 | "rel_l1_thresh": ("FLOAT", 368 | { 369 | "default": 0.25, 370 | "min": 0.0, 371 | "max": 5.0, 372 | "step": 0.001, 373 | "tooltip": "Flux: 0 (original), 0.25 (1.5x speedup), 0.4 (1.8x speedup), 0.6 (2.0x speedup), and 0.8 (2.25x speedup).\n" 374 | "HunYuanVideo: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup).\n" 375 | "LTXVideo: 0 (original), 0.03 (1.6x speedup), 0.05 (2.1x speedup).\n" 376 | "MochiVideo: 0 (original), 0.06 (1.5x speedup), 0.09 (2.1x speedup).\n" 377 | "WanVideo: 0 (original), reference values\n" 378 | " Wan2.1 t2v 1.3B 0.05 0.07 0.08\n" 379 | " Wan2.1 t2v 14B 0.14 0.15 0.2\n" 380 | " Wan2.1 i2v 480P 0.13 0.19 0.26\n" 381 | " Wan2.1 i2v 720P 0.18 0.2 0.3" 382 | }), 383 | }, 384 | "optional": { 385 | "cache_device": (["main_device", 
"offload_device"], {"default": "offload_device"}), 386 | "wan_coefficients": (["disabled", "t2v_1.3B", "t2v_14B", "i2v_14B_480P", "i2v_14B_720P"], { 387 | "default": "disabled", 388 | "tooltip": "WanVideo coefficients." 389 | }), 390 | } 391 | } 392 | 393 | RETURN_TYPES = ("MODEL",) 394 | RETURN_NAMES = ("model",) 395 | FUNCTION = "apply_patch" 396 | CATEGORY = "patches/speed" 397 | DESCRIPTION = ("Apply the TeaCache patch to accelerate the model. Use it together with nodes that have the suffix ForwardOverrider." 398 | "\nThis is effective only for Flux, HunYuanVideo, LTXVideo, WanVideo and MochiVideo.") 399 | 400 | def apply_patch(self, model, rel_l1_thresh, cache_device="offload_device", wan_coefficients="disabled"): 401 | 402 | return super().apply_patch_advanced(model, rel_l1_thresh, start_at=0.0, end_at=1.0, cache_device=cache_device, wan_coefficients=wan_coefficients, from_simple=False) 403 | 404 | NODE_CLASS_MAPPINGS = { 405 | "ApplyTeaCachePatch": ApplyTeaCachePatch, 406 | "ApplyTeaCachePatchAdvanced": ApplyTeaCachePatchAdvanced, 407 | } 408 | 409 | NODE_DISPLAY_NAME_MAPPINGS = { 410 | "ApplyTeaCachePatch": "ApplyTeaCachePatch", 411 | "ApplyTeaCachePatchAdvanced": "ApplyTeaCachePatchAdvanced", 412 | } 413 | -------------------------------------------------------------------------------- /nodes/VideoPatchNode.py: -------------------------------------------------------------------------------- 1 | import comfy 2 | from .patch_util import set_hook, clean_hook, is_hunyuan_video_model, is_ltxv_video_model, is_mochi_video_model, is_wan_video_model 3 | from .node_utils import get_new_forward_orig, get_old_method_name 4 | from .patch_lib.WanVideoPatch import wan_forward 5 | 6 | 7 | def video_outer_sample_function_wrapper(wrapper_executor, noise, latent_image, sampler, sigmas, denoise_mask=None, 8 | callback=None, disable_pbar=False, seed=None): 9 | cfg_guider = wrapper_executor.class_obj 10 | diffusion_model = cfg_guider.model_patcher.model.diffusion_model 11 | # set hook 12 | set_hook(diffusion_model, 'video_old_forward_orig', get_new_forward_orig(diffusion_model), get_old_method_name(diffusion_model)) 13 | if is_wan_video_model(diffusion_model): 14 | # 原forward方法调用forward_origin时没有传transform_options,所以需要打补丁加上 15 | set_hook(diffusion_model, 'video_old_forward', wan_forward, "forward") 16 | 17 | try: 18 | out = wrapper_executor(noise, latent_image, sampler, sigmas, denoise_mask=denoise_mask, callback=callback, 19 | disable_pbar=disable_pbar, seed=seed) 20 | finally: 21 | # cleanup hook 22 | clean_hook(diffusion_model, 'video_old_forward_orig', get_old_method_name(diffusion_model)) 23 | if is_wan_video_model(diffusion_model): 24 | clean_hook(diffusion_model, 'video_old_forward', "forward") 25 | return out 26 | 27 | 28 | class VideoForwardOverrider: 29 | 30 | @classmethod 31 | def INPUT_TYPES(cls): 32 | return { 33 | "required": { 34 | "model": ("MODEL",), 35 | } 36 | } 37 | 38 | RETURN_TYPES = ("MODEL",) 39 | RETURN_NAMES = ("model",) 40 | FUNCTION = "apply_patch" 41 | CATEGORY = "patches/dit" 42 | DESCRIPTION = "Support HunYuanVideo" 43 | 44 | def apply_patch(self, model): 45 | model = model.clone() 46 | diffusion_model = model.get_model_object('diffusion_model') 47 | if is_hunyuan_video_model(diffusion_model) or is_ltxv_video_model(diffusion_model) or is_mochi_video_model(diffusion_model)\ 48 | or is_wan_video_model(diffusion_model): 49 | patch_key = "video_forward_override_wrapper" 50 | if len(model.get_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, patch_key)) == 0: 51 | # 
Just add it once when connecting in series 52 | model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, 53 | patch_key, 54 | video_outer_sample_function_wrapper 55 | ) 56 | return (model,) 57 | 58 | 59 | NODE_CLASS_MAPPINGS = { 60 | "VideoForwardOverrider": VideoForwardOverrider, 61 | } 62 | 63 | NODE_DISPLAY_NAME_MAPPINGS = { 64 | "VideoForwardOverrider": "VideoForwardOverrider", 65 | } 66 | -------------------------------------------------------------------------------- /nodes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lldacing/ComfyUI_Patches_ll/4a9ea6e5d408768d5c2d666781ef756880eea0d0/nodes/__init__.py -------------------------------------------------------------------------------- /nodes/node_utils.py: -------------------------------------------------------------------------------- 1 | from packaging import version as version 2 | import comfyui_version 3 | from .patch_lib.FluxPatch import flux_forward_orig 4 | comfyui_ver = version.parse(comfyui_version.__version__) 5 | 6 | if comfyui_ver >= version.parse('0.3.25'): 7 | from .patch_lib.HunYuanVideoPatch import hunyuan_forward_orig 8 | else: 9 | from .patch_lib.old.HunYuanVideoPatch import hunyuan_forward_orig 10 | 11 | if comfyui_ver > version.parse('0.3.19'): 12 | # support LTXV 0.9.5 13 | from .patch_lib.LTXVideoPatch import ltx_forward_orig 14 | else: 15 | from .patch_lib.old.LTXVideoPatch import ltx_forward_orig 16 | from .patch_lib.MochiVideoPatch import mochi_forward 17 | from .patch_lib.WanVideoPatch import wan_forward_orig 18 | from .patch_util import is_hunyuan_video_model, is_ltxv_video_model, is_flux_model, is_mochi_video_model, \ 19 | is_wan_video_model 20 | 21 | 22 | def get_new_forward_orig(diffusion_model): 23 | if is_hunyuan_video_model(diffusion_model): 24 | return hunyuan_forward_orig 25 | if is_ltxv_video_model(diffusion_model): 26 | return ltx_forward_orig 27 | if is_flux_model(diffusion_model): 28 | return flux_forward_orig 29 | if is_mochi_video_model(diffusion_model): 30 | return mochi_forward 31 | if is_wan_video_model(diffusion_model): 32 | return wan_forward_orig 33 | return None 34 | 35 | def get_old_method_name(diffusion_model): 36 | if is_flux_model(diffusion_model) or is_hunyuan_video_model(diffusion_model) or is_wan_video_model(diffusion_model): 37 | return 'forward_orig' 38 | if is_ltxv_video_model(diffusion_model) or is_mochi_video_model(diffusion_model): 39 | return 'forward' 40 | return None -------------------------------------------------------------------------------- /nodes/patch_lib/FluxPatch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from ..patch_util import PatchKeys, set_hook, clean_hook 5 | from comfy.ldm.flux.layers import timestep_embedding 6 | 7 | def flux_forward_orig( 8 | self, 9 | img: Tensor, 10 | img_ids: Tensor, 11 | txt: Tensor, 12 | txt_ids: Tensor, 13 | timesteps: Tensor, 14 | y: Tensor, 15 | guidance: Tensor = None, 16 | control = None, 17 | transformer_options={}, 18 | attn_mask: Tensor = None, 19 | ) -> Tensor: 20 | patches_replace = transformer_options.get("patches_replace", {}) 21 | patches_point = transformer_options.get(PatchKeys.options_key, {}) 22 | 23 | if img.ndim != 3 or txt.ndim != 3: 24 | raise ValueError("Input img and txt tensors must have 3 dimensions.") 25 | 26 | transformer_options[PatchKeys.running_net_model] = self 27 | 28 | patches_enter = 
patches_point.get(PatchKeys.dit_enter, []) 29 | if patches_enter is not None and len(patches_enter) > 0: 30 | for patch_enter in patches_enter: 31 | img, img_ids, txt, txt_ids, timesteps, y, guidance, control, attn_mask = patch_enter(img, 32 | img_ids, 33 | txt, 34 | txt_ids, 35 | timesteps, 36 | y, 37 | guidance, 38 | control, 39 | attn_mask, 40 | transformer_options 41 | ) 42 | 43 | # running on sequences img 44 | img = self.img_in(img) 45 | vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype)) 46 | if self.params.guidance_embed: 47 | if guidance is not None: 48 | vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) 49 | 50 | vec = vec + self.vector_in(y) 51 | txt = self.txt_in(txt) 52 | 53 | ids = torch.cat((txt_ids, img_ids), dim=1) 54 | pe = self.pe_embedder(ids) 55 | 56 | blocks_replace = patches_replace.get("dit", {}) 57 | 58 | patch_blocks_before = patches_point.get(PatchKeys.dit_blocks_before, []) 59 | if patch_blocks_before is not None and len(patch_blocks_before) > 0: 60 | for blocks_before in patch_blocks_before: 61 | img, txt, vec, ids, pe = blocks_before(img, txt, vec, ids, pe, transformer_options) 62 | 63 | def double_blocks_wrap(img, txt, vec, pe, control=None, attn_mask=None, transformer_options={}): 64 | running_net_model = transformer_options[PatchKeys.running_net_model] 65 | patch_double_blocks_with_control_replace = patches_point.get(PatchKeys.dit_double_block_with_control_replace) 66 | for i, block in enumerate(running_net_model.double_blocks): 67 | # 0 -> 18 68 | if patch_double_blocks_with_control_replace is not None: 69 | img, txt = patch_double_blocks_with_control_replace({'i': i, 70 | 'block': block, 71 | 'img': img, 72 | 'txt': txt, 73 | 'vec': vec, 74 | 'pe': pe, 75 | 'control': control, 76 | 'attn_mask': attn_mask 77 | }, 78 | { 79 | "original_func": double_block_and_control_replace, 80 | "transformer_options": transformer_options 81 | }) 82 | else: 83 | img, txt = double_block_and_control_replace(i=i, 84 | block=block, 85 | img=img, 86 | txt=txt, 87 | vec=vec, 88 | pe=pe, 89 | control=control, 90 | attn_mask=attn_mask, 91 | transformer_options=transformer_options 92 | ) 93 | 94 | del patch_double_blocks_with_control_replace 95 | return img, txt 96 | 97 | patch_double_blocks_replace = patches_point.get(PatchKeys.dit_double_blocks_replace) 98 | 99 | if patch_double_blocks_replace is not None: 100 | img, txt = patch_double_blocks_replace({"img": img, 101 | "txt": txt, 102 | "vec": vec, 103 | "pe": pe, 104 | "control": control, 105 | "attn_mask": attn_mask, 106 | }, 107 | { 108 | "original_blocks": double_blocks_wrap, 109 | "transformer_options": transformer_options 110 | }) 111 | else: 112 | img, txt = double_blocks_wrap(img=img, 113 | txt=txt, 114 | vec=vec, 115 | pe=pe, 116 | control=control, 117 | attn_mask=attn_mask, 118 | transformer_options=transformer_options 119 | ) 120 | 121 | patches_double_blocks_after = patches_point.get(PatchKeys.dit_double_blocks_after, []) 122 | if patches_double_blocks_after is not None and len(patches_double_blocks_after) > 0: 123 | for patch_double_blocks_after in patches_double_blocks_after: 124 | img, txt = patch_double_blocks_after(img, txt, transformer_options) 125 | 126 | patch_blocks_transition = patches_point.get(PatchKeys.dit_blocks_transition_replace) 127 | 128 | def blocks_transition_wrap(**kwargs): 129 | txt = kwargs["txt"] 130 | img = kwargs["img"] 131 | return torch.cat((txt, img), 1) 132 | 133 | if patch_blocks_transition is not None: 134 | img = 
patch_blocks_transition({"img": img, "txt": txt, "vec": vec, "pe": pe}, 135 | { 136 | "original_func": blocks_transition_wrap, 137 | "transformer_options": transformer_options 138 | }) 139 | else: 140 | img = blocks_transition_wrap(img=img, txt=txt) 141 | 142 | patches_single_blocks_before = patches_point.get(PatchKeys.dit_single_blocks_before, []) 143 | if patches_single_blocks_before is not None and len(patches_single_blocks_before) > 0: 144 | for patch_single_blocks_before in patches_single_blocks_before: 145 | img, txt = patch_single_blocks_before(img, txt, transformer_options) 146 | 147 | def single_blocks_wrap(img, txt, vec, pe, control=None, attn_mask=None, transformer_options={}): 148 | running_net_model = transformer_options[PatchKeys.running_net_model] 149 | for i, block in enumerate(running_net_model.single_blocks): 150 | # 0 -> 37 151 | if ("single_block", i) in blocks_replace: 152 | def block_wrap(args): 153 | out = {} 154 | out["img"] = block(args["img"], 155 | vec=args["vec"], 156 | pe=args["pe"], 157 | attn_mask=args.get("attn_mask")) 158 | return out 159 | 160 | out = blocks_replace[("single_block", i)]({"img": img, 161 | "vec": vec, 162 | "pe": pe, 163 | "attn_mask": attn_mask}, 164 | { 165 | "original_block": block_wrap, 166 | "transformer_options": transformer_options 167 | }) 168 | img = out["img"] 169 | else: 170 | img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) 171 | 172 | if control is not None: # Controlnet 173 | control_o = control.get("output") 174 | if i < len(control_o): 175 | add = control_o[i] 176 | if add is not None: 177 | img[:, txt.shape[1]:, ...] += add 178 | 179 | return img 180 | 181 | patch_single_blocks_replace = patches_point.get(PatchKeys.dit_single_blocks_replace) 182 | 183 | if patch_single_blocks_replace is not None: 184 | img, txt = patch_single_blocks_replace({"img": img, 185 | "txt": txt, 186 | "vec": vec, 187 | "pe": pe, 188 | "control": control, 189 | "attn_mask": attn_mask 190 | }, 191 | { 192 | "original_blocks": single_blocks_wrap, 193 | "transformer_options": transformer_options 194 | }) 195 | else: 196 | img = single_blocks_wrap(img=img, 197 | txt=txt, 198 | vec=vec, 199 | pe=pe, 200 | control=control, 201 | attn_mask=attn_mask, 202 | transformer_options=transformer_options 203 | ) 204 | 205 | patch_blocks_exit = patches_point.get(PatchKeys.dit_blocks_after, []) 206 | if patch_blocks_exit is not None and len(patch_blocks_exit) > 0: 207 | for blocks_after in patch_blocks_exit: 208 | img, txt = blocks_after(img, txt, transformer_options) 209 | 210 | def final_transition_wrap(**kwargs): 211 | img = kwargs["img"] 212 | txt = kwargs["txt"] 213 | return img[:, txt.shape[1]:, ...] 
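    # The *_replace patch points in this file all follow the same pattern: the patch is
    # called with an args dict and a second dict carrying the original callable
    # ("original_func" or "original_blocks") plus "transformer_options", and it must
    # return the replacement result. For PatchKeys.dit_blocks_after_transition_replace
    # below, a minimal pass-through sketch could look like this (hypothetical patch
    # name, for illustration only; not part of this repo):
    #
    #     def my_after_transition_patch(args, extra):
    #         img = extra["original_func"](**args)  # default behaviour: strip the txt tokens
    #         return img  # a real patch could transform img before returning it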
214 | 215 | patch_blocks_after_transition_replace = patches_point.get(PatchKeys.dit_blocks_after_transition_replace) 216 | if patch_blocks_after_transition_replace is not None: 217 | img = patch_blocks_after_transition_replace({"img": img, "txt": txt, "vec": vec, "pe": pe}, 218 | { 219 | "original_func": final_transition_wrap, 220 | "transformer_options": transformer_options 221 | }) 222 | else: 223 | img = final_transition_wrap(img=img, txt=txt) 224 | 225 | patches_final_layer_before = patches_point.get(PatchKeys.dit_final_layer_before, []) 226 | if patches_final_layer_before is not None and len(patches_final_layer_before) > 0: 227 | for patch_final_layer_before in patches_final_layer_before: 228 | img = patch_final_layer_before(img, txt, transformer_options) 229 | 230 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) 231 | 232 | patches_exit = patches_point.get(PatchKeys.dit_exit, []) 233 | if patches_exit is not None and len(patches_exit) > 0: 234 | for patch_exit in patches_exit: 235 | img = patch_exit(img, transformer_options) 236 | 237 | del transformer_options[PatchKeys.running_net_model] 238 | 239 | return img 240 | 241 | def double_block_and_control_replace(i, block, img, txt=None, vec=None, pe=None, control=None, attn_mask=None, transformer_options={}): 242 | blocks_replace = transformer_options.get("patches_replace", {}).get("dit", {}) 243 | if ("double_block", i) in blocks_replace: 244 | def block_wrap(args): 245 | out = {} 246 | out["img"], out["txt"] = block(img=args["img"], 247 | txt=args["txt"], 248 | vec=args["vec"], 249 | pe=args["pe"], 250 | attn_mask=args.get("attn_mask")) 251 | return out 252 | 253 | out = blocks_replace[("double_block", i)]({"img": img, 254 | "txt": txt, 255 | "vec": vec, 256 | "pe": pe, 257 | "attn_mask": attn_mask 258 | }, 259 | { 260 | "original_block": block_wrap, 261 | "transformer_options": transformer_options 262 | }) 263 | txt = out["txt"] 264 | img = out["img"] 265 | else: 266 | img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask) 267 | 268 | if control is not None: # Controlnet 269 | control_i = control.get("input") 270 | if i < len(control_i): 271 | add = control_i[i] 272 | if add is not None: 273 | img += add 274 | 275 | del blocks_replace 276 | return img, txt 277 | -------------------------------------------------------------------------------- /nodes/patch_lib/HunYuanVideoPatch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from ..patch_util import PatchKeys 5 | from comfy.ldm.flux.layers import timestep_embedding 6 | 7 | 8 | def hunyuan_forward_orig( 9 | self, 10 | img: Tensor, 11 | img_ids: Tensor, 12 | txt: Tensor, 13 | txt_ids: Tensor, 14 | txt_mask: Tensor, 15 | timesteps: Tensor, 16 | y: Tensor, 17 | guidance: Tensor = None, 18 | guiding_frame_index=None, 19 | control=None, 20 | transformer_options={}, 21 | ) -> Tensor: 22 | patches_replace = transformer_options.get("patches_replace", {}) 23 | patches_point = transformer_options.get(PatchKeys.options_key, {}) 24 | 25 | transformer_options[PatchKeys.running_net_model] = self 26 | 27 | patches_enter = patches_point.get(PatchKeys.dit_enter, []) 28 | if patches_enter is not None and len(patches_enter) > 0: 29 | for patch_enter in patches_enter: 30 | img, img_ids, txt, txt_ids, timesteps, y, guidance, control, txt_mask = patch_enter(img, 31 | img_ids, 32 | txt, 33 | txt_ids, 34 | timesteps, 35 | y, 36 | guidance, 37 | control, 38 | attn_mask=txt_mask, 
39 | transformer_options=transformer_options 40 | ) 41 | 42 | initial_shape = list(img.shape) 43 | # running on sequences img 44 | img = self.img_in(img) 45 | vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype)) 46 | 47 | if guiding_frame_index is not None: 48 | token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0)) 49 | vec_ = self.vector_in(y[:, :self.params.vec_in_dim]) 50 | vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1) 51 | frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2]) 52 | modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)] 53 | modulation_dims_txt = [(0, None, 1)] 54 | else: 55 | vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) 56 | modulation_dims = None 57 | modulation_dims_txt = None 58 | 59 | if self.params.guidance_embed: 60 | if guidance is not None: 61 | vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) 62 | 63 | if txt_mask is not None and not torch.is_floating_point(txt_mask): 64 | txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max 65 | 66 | txt = self.txt_in(txt, timesteps, txt_mask) 67 | 68 | ids = torch.cat((img_ids, txt_ids), dim=1) 69 | pe = self.pe_embedder(ids) 70 | 71 | img_len = img.shape[1] 72 | if txt_mask is not None: 73 | attn_mask_len = img_len + txt.shape[1] 74 | attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device) 75 | attn_mask[:, 0, img_len:] = txt_mask 76 | else: 77 | attn_mask = None 78 | 79 | blocks_replace = patches_replace.get("dit", {}) 80 | 81 | patch_blocks_before = patches_point.get(PatchKeys.dit_blocks_before, []) 82 | if patch_blocks_before is not None and len(patch_blocks_before) > 0: 83 | for blocks_before in patch_blocks_before: 84 | img, txt, vec, ids, pe = blocks_before(img, txt, vec, ids, pe, transformer_options) 85 | 86 | def double_blocks_wrap(img, txt, vec, pe, control=None, attn_mask=None, transformer_options={}, 87 | modulation_dims_img=None, modulation_dims_txt=None): 88 | running_net_model = transformer_options[PatchKeys.running_net_model] 89 | patch_double_blocks_with_control_replace = patches_point.get(PatchKeys.dit_double_block_with_control_replace) 90 | for i, block in enumerate(running_net_model.double_blocks): 91 | if patch_double_blocks_with_control_replace is not None: 92 | img, txt = patch_double_blocks_with_control_replace({'i': i, 93 | 'block': block, 94 | 'img': img, 95 | 'txt': txt, 96 | 'vec': vec, 97 | 'pe': pe, 98 | 'control': control, 99 | 'attn_mask': attn_mask, 100 | 'modulation_dims_img': modulation_dims_img, 101 | 'modulation_dims_txt': modulation_dims_txt 102 | }, 103 | { 104 | "original_func": double_block_and_control_replace, 105 | "transformer_options": transformer_options 106 | }) 107 | else: 108 | img, txt = double_block_and_control_replace(i=i, 109 | block=block, 110 | img=img, 111 | txt=txt, 112 | vec=vec, 113 | pe=pe, 114 | control=control, 115 | attn_mask=attn_mask, 116 | modulation_dims_img=modulation_dims_img, 117 | modulation_dims_txt=modulation_dims_txt, 118 | transformer_options=transformer_options 119 | ) 120 | 121 | del patch_double_blocks_with_control_replace 122 | return img, txt 123 | 124 | patch_double_blocks_replace = patches_point.get(PatchKeys.dit_double_blocks_replace) 125 | 126 | if patch_double_blocks_replace is not None: 127 | img, txt = patch_double_blocks_replace({"img": img, 128 | "txt": txt, 129 | "vec": vec, 
130 | "pe": pe, 131 | "control": control, 132 | "attn_mask": attn_mask, 133 | "modulation_dims_img": modulation_dims, 134 | "modulation_dims_txt": modulation_dims_txt, 135 | }, 136 | { 137 | "original_blocks": double_blocks_wrap, 138 | "transformer_options": transformer_options 139 | }) 140 | else: 141 | img, txt = double_blocks_wrap(img=img, 142 | txt=txt, 143 | vec=vec, 144 | pe=pe, 145 | control=control, 146 | attn_mask=attn_mask, 147 | modulation_dims_img=modulation_dims, 148 | modulation_dims_txt=modulation_dims_txt, 149 | transformer_options=transformer_options 150 | ) 151 | 152 | patches_double_blocks_after = patches_point.get(PatchKeys.dit_double_blocks_after, []) 153 | if patches_double_blocks_after is not None and len(patches_double_blocks_after) > 0: 154 | for patch_double_blocks_after in patches_double_blocks_after: 155 | img, txt = patch_double_blocks_after(img, txt, transformer_options) 156 | 157 | patch_blocks_transition = patches_point.get(PatchKeys.dit_blocks_transition_replace) 158 | 159 | def blocks_transition_wrap(**kwargs): 160 | txt = kwargs["txt"] 161 | img = kwargs["img"] 162 | return torch.cat((img, txt), 1) 163 | 164 | if patch_blocks_transition is not None: 165 | img = patch_blocks_transition({"img": img, "txt": txt, "vec": vec, "pe": pe}, 166 | { 167 | "original_func": blocks_transition_wrap, 168 | "transformer_options": transformer_options 169 | }) 170 | else: 171 | img = blocks_transition_wrap(img=img, txt=txt) 172 | 173 | patches_single_blocks_before = patches_point.get(PatchKeys.dit_single_blocks_before, []) 174 | if patches_single_blocks_before is not None and len(patches_single_blocks_before) > 0: 175 | for patch_single_blocks_before in patches_single_blocks_before: 176 | img, txt = patch_single_blocks_before(img, txt, transformer_options) 177 | 178 | def single_blocks_wrap(img, txt, vec, pe, control=None, attn_mask=None, transformer_options={}, modulation_dims=None): 179 | running_net_model = transformer_options[PatchKeys.running_net_model] 180 | for i, block in enumerate(running_net_model.single_blocks): 181 | if ("single_block", i) in blocks_replace: 182 | def block_wrap(args): 183 | out = {} 184 | out["img"] = block(args["img"], 185 | vec=args["vec"], 186 | pe=args["pe"], 187 | attn_mask=args.get("attention_mask"), 188 | modulation_dims=args.get("modulation_dims")) 189 | return out 190 | 191 | out = blocks_replace[("single_block", i)]({"img": img, 192 | "vec": vec, 193 | "pe": pe, 194 | "attention_mask": attn_mask, 195 | 'modulation_dims': modulation_dims}, 196 | { 197 | "original_block": block_wrap, 198 | "transformer_options": transformer_options 199 | }) 200 | img = out["img"] 201 | else: 202 | img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims) 203 | 204 | if control is not None: # Controlnet 205 | control_o = control.get("output") 206 | if i < len(control_o): 207 | add = control_o[i] 208 | if add is not None: 209 | img[:, : img_len] += add 210 | 211 | return img 212 | 213 | patch_single_blocks_replace = patches_point.get(PatchKeys.dit_single_blocks_replace) 214 | 215 | if patch_single_blocks_replace is not None: 216 | img, txt = patch_single_blocks_replace({"img": img, 217 | "txt": txt, 218 | "vec": vec, 219 | "pe": pe, 220 | "control": control, 221 | "attn_mask": attn_mask, 222 | "modulation_dims": modulation_dims, 223 | }, 224 | { 225 | "original_blocks": single_blocks_wrap, 226 | "transformer_options": transformer_options 227 | }) 228 | else: 229 | img = single_blocks_wrap(img=img, 230 | txt=txt, 231 | 
vec=vec, 232 | pe=pe, 233 | control=control, 234 | attn_mask=attn_mask, 235 | modulation_dims=modulation_dims, 236 | transformer_options=transformer_options 237 | ) 238 | 239 | patch_blocks_exit = patches_point.get(PatchKeys.dit_blocks_after, []) 240 | if patch_blocks_exit is not None and len(patch_blocks_exit) > 0: 241 | for blocks_after in patch_blocks_exit: 242 | img, txt = blocks_after(img, txt, transformer_options) 243 | 244 | def final_transition_wrap(**kwargs): 245 | img = kwargs["img"] 246 | img_len = kwargs["img_len"] 247 | return img[:, : img_len] 248 | 249 | patch_blocks_after_transition_replace = patches_point.get(PatchKeys.dit_blocks_after_transition_replace) 250 | if patch_blocks_after_transition_replace is not None: 251 | img = patch_blocks_after_transition_replace({"img": img, "txt": txt, "vec": vec, "pe": pe, "img_len": img_len}, 252 | { 253 | "original_func": final_transition_wrap, 254 | "transformer_options": transformer_options 255 | }) 256 | else: 257 | img = final_transition_wrap(img=img, img_len=img_len) 258 | 259 | patches_final_layer_before = patches_point.get(PatchKeys.dit_final_layer_before, []) 260 | if patches_final_layer_before is not None and len(patches_final_layer_before) > 0: 261 | for patch_final_layer_before in patches_final_layer_before: 262 | img = patch_final_layer_before(img, txt, transformer_options) 263 | 264 | img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels) 265 | 266 | shape = initial_shape[-3:] 267 | for i in range(len(shape)): 268 | shape[i] = shape[i] // self.patch_size[i] 269 | img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size) 270 | img = img.permute(0, 4, 1, 5, 2, 6, 3, 7) 271 | img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) 272 | 273 | patches_exit = patches_point.get(PatchKeys.dit_exit, []) 274 | if patches_exit is not None and len(patches_exit) > 0: 275 | for patch_exit in patches_exit: 276 | img = patch_exit(img, transformer_options) 277 | 278 | del transformer_options[PatchKeys.running_net_model] 279 | 280 | return img 281 | 282 | 283 | def double_block_and_control_replace(i, block, img, txt=None, vec=None, pe=None, control=None, attn_mask=None, transformer_options={}, modulation_dims_img=None, modulation_dims_txt=None): 284 | blocks_replace = transformer_options.get("patches_replace", {}).get("dit", {}) 285 | if ("double_block", i) in blocks_replace: 286 | def block_wrap(args): 287 | out = {} 288 | out["img"], out["txt"] = block(img=args["img"], 289 | txt=args["txt"], 290 | vec=args["vec"], 291 | pe=args["pe"], 292 | attn_mask=args.get("attention_mask"), 293 | modulation_dims_img=args["modulation_dims_img"], 294 | modulation_dims_txt=args["modulation_dims_txt"]) 295 | return out 296 | 297 | out = blocks_replace[("double_block", i)]({"img": img, 298 | "txt": txt, 299 | "vec": vec, 300 | "pe": pe, 301 | "attention_mask": attn_mask, 302 | 'modulation_dims_img': modulation_dims_img, 303 | 'modulation_dims_txt': modulation_dims_txt 304 | }, 305 | { 306 | "original_block": block_wrap, 307 | "transformer_options": transformer_options 308 | }) 309 | txt = out["txt"] 310 | img = out["img"] 311 | else: 312 | img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims_img, modulation_dims_txt=modulation_dims_txt) 313 | if control is not None: # Controlnet 314 | control_i = control.get("input") 315 | if i < len(control_i): 316 | add = control_i[i] 317 
| if add is not None: 318 | img += add 319 | 320 | return img, txt 321 | -------------------------------------------------------------------------------- /nodes/patch_lib/LTXVideoPatch.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from comfy.ldm.lightricks.model import precompute_freqs_cis 7 | from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords 8 | from ..patch_util import PatchKeys 9 | 10 | 11 | def ltx_forward_orig( 12 | self, 13 | x, 14 | timestep, 15 | context, 16 | attention_mask, 17 | frame_rate=25, 18 | guiding_latent=None, 19 | transformer_options={}, 20 | keyframe_idxs=None, 21 | **kwargs 22 | ) -> Tensor: 23 | patches_replace = transformer_options.get("patches_replace", {}) 24 | patches_point = transformer_options.get(PatchKeys.options_key, {}) 25 | 26 | transformer_options[PatchKeys.running_net_model] = self 27 | 28 | patches_enter = patches_point.get(PatchKeys.dit_enter, []) 29 | if patches_enter is not None and len(patches_enter) > 0: 30 | for patch_enter in patches_enter: 31 | x, timestep, context, attention_mask, frame_rate, guiding_latent, keyframe_idxs = patch_enter( 32 | x, 33 | timestep, 34 | context, 35 | attention_mask, 36 | frame_rate, 37 | guiding_latent, 38 | keyframe_idxs, 39 | transformer_options 40 | ) 41 | 42 | orig_shape = list(x.shape) 43 | 44 | x, latent_coords = self.patchifier.patchify(x) 45 | pixel_coords = latent_to_pixel_coords( 46 | latent_coords=latent_coords, 47 | scale_factors=self.vae_scale_factors, 48 | causal_fix=self.causal_temporal_positioning, 49 | ) 50 | 51 | if keyframe_idxs is not None: 52 | pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs 53 | 54 | fractional_coords = pixel_coords.to(torch.float32) 55 | fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) 56 | 57 | x = self.patchify_proj(x) 58 | timestep = timestep * 1000.0 59 | 60 | if attention_mask is not None and not torch.is_floating_point(attention_mask): 61 | attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max 62 | 63 | pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype) 64 | 65 | batch_size = x.shape[0] 66 | timestep, embedded_timestep = self.adaln_single( 67 | timestep.flatten(), 68 | {"resolution": None, "aspect_ratio": None}, 69 | batch_size=batch_size, 70 | hidden_dtype=x.dtype, 71 | ) 72 | # Second dimension is 1 or number of tokens (if timestep_per_token) 73 | timestep = timestep.view(batch_size, -1, timestep.shape[-1]) 74 | embedded_timestep = embedded_timestep.view( 75 | batch_size, -1, embedded_timestep.shape[-1] 76 | ) 77 | 78 | # 2. 
Blocks 79 | if self.caption_projection is not None: 80 | batch_size = x.shape[0] 81 | context = self.caption_projection(context) 82 | context = context.view( 83 | batch_size, -1, x.shape[-1] 84 | ) 85 | 86 | patch_blocks_before = patches_point.get(PatchKeys.dit_blocks_before, []) 87 | if patch_blocks_before is not None and len(patch_blocks_before) > 0: 88 | for blocks_before in patch_blocks_before: 89 | x, context, timestep, ids, pe = blocks_before(img=x, txt=context, vec=timestep, ids=None, pe=pe, transformer_options=transformer_options) 90 | 91 | def double_blocks_wrap(img, txt, vec, pe, control=None, attn_mask=None, transformer_options={}): 92 | running_net_model = transformer_options[PatchKeys.running_net_model] 93 | patch_double_blocks_with_control_replace = patches_point.get(PatchKeys.dit_double_block_with_control_replace) 94 | for i, block in enumerate(running_net_model.transformer_blocks): 95 | if patch_double_blocks_with_control_replace is not None: 96 | img, txt = patch_double_blocks_with_control_replace({'i': i, 97 | 'block': block, 98 | 'img': img, 99 | 'txt': txt, 100 | 'vec': vec, 101 | 'pe': pe, 102 | 'control': control, 103 | 'attn_mask': attn_mask 104 | }, 105 | { 106 | "original_func": double_block_and_control_replace, 107 | "transformer_options": transformer_options 108 | }) 109 | else: 110 | img, txt = double_block_and_control_replace(i=i, 111 | block=block, 112 | img=img, 113 | txt=txt, 114 | vec=vec, 115 | pe=pe, 116 | control=control, 117 | attn_mask=attn_mask, 118 | transformer_options=transformer_options 119 | ) 120 | 121 | del patch_double_blocks_with_control_replace 122 | return img, txt 123 | 124 | patch_double_blocks_replace = patches_point.get(PatchKeys.dit_double_blocks_replace) 125 | 126 | if patch_double_blocks_replace is not None: 127 | x, context = patch_double_blocks_replace({"img": x, 128 | "txt": context, 129 | "vec": timestep, 130 | "pe": pe, 131 | "control": None, 132 | "attn_mask": attention_mask, 133 | }, 134 | { 135 | "original_blocks": double_blocks_wrap, 136 | "transformer_options": transformer_options 137 | }) 138 | else: 139 | x, context = double_blocks_wrap(img=x, 140 | txt=context, 141 | vec=timestep, 142 | pe=pe, 143 | control=None, 144 | attn_mask=attention_mask, 145 | transformer_options=transformer_options 146 | ) 147 | 148 | patches_double_blocks_after = patches_point.get(PatchKeys.dit_double_blocks_after, []) 149 | if patches_double_blocks_after is not None and len(patches_double_blocks_after) > 0: 150 | for patch_double_blocks_after in patches_double_blocks_after: 151 | x, context = patch_double_blocks_after(x, context, transformer_options) 152 | 153 | patch_blocks_transition = patches_point.get(PatchKeys.dit_blocks_transition_replace) 154 | 155 | def blocks_transition_wrap(**kwargs): 156 | x = kwargs["img"] 157 | return x 158 | 159 | if patch_blocks_transition is not None: 160 | x = patch_blocks_transition({"img": x, "txt": context, "vec": timestep, "pe": pe}, 161 | { 162 | "original_func": blocks_transition_wrap, 163 | "transformer_options": transformer_options 164 | }) 165 | else: 166 | x = blocks_transition_wrap(img=x, txt=context) 167 | 168 | patches_single_blocks_before = patches_point.get(PatchKeys.dit_single_blocks_before, []) 169 | if patches_single_blocks_before is not None and len(patches_single_blocks_before) > 0: 170 | for patch_single_blocks_before in patches_single_blocks_before: 171 | x, context = patch_single_blocks_before(x, context, transformer_options) 172 | 173 | def single_blocks_wrap(img, **kwargs): 174 | return 
img 175 | 176 | patch_single_blocks_replace = patches_point.get(PatchKeys.dit_single_blocks_replace) 177 | 178 | if patch_single_blocks_replace is not None: 179 | x, context = patch_single_blocks_replace({"img": x, 180 | "txt": context, 181 | "vec": timestep, 182 | "pe": pe, 183 | "control": None, 184 | "attn_mask": attention_mask 185 | }, 186 | { 187 | "original_blocks": single_blocks_wrap, 188 | "transformer_options": transformer_options 189 | }) 190 | else: 191 | x = single_blocks_wrap(img=x, 192 | txt=context, 193 | vec=timestep, 194 | pe=pe, 195 | control=None, 196 | attn_mask=attention_mask, 197 | transformer_options=transformer_options 198 | ) 199 | 200 | patch_blocks_exit = patches_point.get(PatchKeys.dit_blocks_after, []) 201 | if patch_blocks_exit is not None and len(patch_blocks_exit) > 0: 202 | for blocks_after in patch_blocks_exit: 203 | x, context = blocks_after(x, context, transformer_options) 204 | 205 | # 3. Output 206 | def final_transition_wrap(**kwargs): 207 | running_net_model = transformer_options[PatchKeys.running_net_model] 208 | x = kwargs["img"] 209 | embedded_timestep = kwargs["embedded_timestep"] 210 | scale_shift_values = ( 211 | running_net_model.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] 212 | ) 213 | shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] 214 | x = running_net_model.norm_out(x) 215 | # Modulation 216 | x = x * (1 + scale) + shift 217 | return x 218 | 219 | patch_blocks_after_transition_replace = patches_point.get(PatchKeys.dit_blocks_after_transition_replace) 220 | if patch_blocks_after_transition_replace is not None: 221 | x = patch_blocks_after_transition_replace({"img": x, "txt": context, "vec": timestep, "pe": pe, "embedded_timestep": embedded_timestep}, 222 | { 223 | "original_func": final_transition_wrap, 224 | "transformer_options": transformer_options 225 | }) 226 | else: 227 | x = final_transition_wrap(img=x, embedded_timestep=embedded_timestep) 228 | 229 | patches_final_layer_before = patches_point.get(PatchKeys.dit_final_layer_before, []) 230 | if patches_final_layer_before is not None and len(patches_final_layer_before) > 0: 231 | for patch_final_layer_before in patches_final_layer_before: 232 | x = patch_final_layer_before(img=x, txt=context, transformer_options=transformer_options) 233 | 234 | x = self.proj_out(x) 235 | 236 | x = self.patchifier.unpatchify( 237 | latents=x, 238 | output_height=orig_shape[3], 239 | output_width=orig_shape[4], 240 | output_num_frames=orig_shape[2], 241 | out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size), 242 | ) 243 | 244 | patches_exit = patches_point.get(PatchKeys.dit_exit, []) 245 | if patches_exit is not None and len(patches_exit) > 0: 246 | for patch_exit in patches_exit: 247 | x = patch_exit(x, transformer_options) 248 | 249 | del transformer_options[PatchKeys.running_net_model] 250 | 251 | return x 252 | 253 | def double_block_and_control_replace(i, block, img, txt=None, vec=None, pe=None, control=None, attn_mask=None, transformer_options={}): 254 | blocks_replace = transformer_options.get("patches_replace", {}).get("dit", {}) 255 | if ("double_block", i) in blocks_replace: 256 | def block_wrap(args): 257 | out = {} 258 | out["img"] = block(x=args["img"], 259 | context=args["txt"], 260 | timestep=args["vec"], 261 | pe=args["pe"], 262 | attention_mask=args.get("attention_mask")) 263 | return out 264 | 265 | out = blocks_replace[("double_block", i)]({"img": img, 266 | "txt": txt, 267 | "vec": vec, 268 | 
"pe": pe, 269 | "attention_mask": attn_mask, 270 | }, 271 | { 272 | "original_block": block_wrap, 273 | "transformer_options": transformer_options 274 | }) 275 | img = out["img"] 276 | else: 277 | img = block(x=img, context=txt, timestep=vec, pe=pe, attention_mask=attn_mask) 278 | 279 | return img, txt 280 | -------------------------------------------------------------------------------- /nodes/patch_lib/MochiVideoPatch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | import torch 4 | from einops import rearrange 5 | import torch.nn.functional as F 6 | 7 | from ..patch_util import PatchKeys 8 | 9 | # copied from comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint.forward 10 | def mochi_forward( 11 | self, 12 | x: torch.Tensor, 13 | timestep: torch.Tensor, 14 | context: List[torch.Tensor], 15 | attention_mask: List[torch.Tensor], 16 | num_tokens=256, 17 | packed_indices: Dict[str, torch.Tensor] = None, 18 | rope_cos: torch.Tensor = None, 19 | rope_sin: torch.Tensor = None, 20 | control=None, transformer_options={}, **kwargs 21 | ): 22 | patches_replace = transformer_options.get("patches_replace", {}) 23 | patches_point = transformer_options.get(PatchKeys.options_key, {}) 24 | transformer_options[PatchKeys.running_net_model] = self 25 | 26 | patches_enter = patches_point.get(PatchKeys.dit_enter, []) 27 | if patches_enter is not None and len(patches_enter) > 0: 28 | for patch_enter in patches_enter: 29 | x, timestep, context, attention_mask, num_tokens = patch_enter( 30 | x, 31 | timestep, 32 | context, 33 | attention_mask, 34 | num_tokens, 35 | transformer_options 36 | ) 37 | 38 | y_feat = context 39 | y_mask = attention_mask 40 | sigma = timestep 41 | """Forward pass of DiT. 42 | 43 | Args: 44 | x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images) 45 | sigma: (B,) tensor of noise standard deviations 46 | y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048) 47 | y_mask: List((B, L) boolean tensor indicating which tokens are not padding) 48 | packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices. 
49 | """ 50 | B, _, T, H, W = x.shape 51 | 52 | x, c, y_feat, rope_cos, rope_sin = self.prepare( 53 | x, sigma, y_feat, y_mask 54 | ) 55 | del y_mask 56 | 57 | pe = [rope_cos, rope_sin] 58 | patch_blocks_before = patches_point.get(PatchKeys.dit_blocks_before, []) 59 | if patch_blocks_before is not None and len(patch_blocks_before) > 0: 60 | for blocks_before in patch_blocks_before: 61 | x, y_feat, c, ids, pe = blocks_before(img=x, txt=y_feat, vec=c, ids=None, pe=pe, 62 | transformer_options=transformer_options) 63 | 64 | def double_blocks_wrap(img, txt, vec, pe, control=None, attn_mask=None, transformer_options={}): 65 | running_net_model = transformer_options[PatchKeys.running_net_model] 66 | patch_double_blocks_with_control_replace = patches_point.get(PatchKeys.dit_double_block_with_control_replace) 67 | for i, block in enumerate(running_net_model.blocks): 68 | if patch_double_blocks_with_control_replace is not None: 69 | img, txt = patch_double_blocks_with_control_replace({'i': i, 70 | 'block': block, 71 | 'img': img, 72 | 'txt': txt, 73 | 'vec': vec, 74 | 'pe': pe, 75 | 'control': control, 76 | 'attn_mask': attn_mask 77 | }, 78 | { 79 | "original_func": double_block_and_control_replace, 80 | "transformer_options": transformer_options 81 | }) 82 | else: 83 | img, txt = double_block_and_control_replace(i=i, 84 | block=block, 85 | img=img, 86 | txt=txt, 87 | vec=vec, 88 | pe=pe, 89 | control=control, 90 | attn_mask=attn_mask, 91 | transformer_options=transformer_options 92 | ) 93 | del patch_double_blocks_with_control_replace 94 | return img, txt 95 | 96 | patch_double_blocks_replace = patches_point.get(PatchKeys.dit_double_blocks_replace) 97 | 98 | transformer_options["db_blocks_num_tokens"] = num_tokens 99 | if patch_double_blocks_replace is not None: 100 | x, y_feat = patch_double_blocks_replace({"img": x, 101 | "txt": y_feat, 102 | "vec": c, 103 | "pe": pe, 104 | "control": None, 105 | "attn_mask": None, 106 | }, 107 | { 108 | "original_blocks": double_blocks_wrap, 109 | "transformer_options": transformer_options 110 | }) 111 | else: 112 | x, y_feat = double_blocks_wrap(img=x, 113 | txt=y_feat, 114 | vec=c, 115 | pe=pe, 116 | control=None, 117 | attn_mask=None, 118 | transformer_options=transformer_options 119 | ) 120 | del transformer_options["db_blocks_num_tokens"] 121 | 122 | # del y_feat # Final layers don't use dense text features. 
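    # Unlike the upstream Mochi forward (which deletes y_feat right after the blocks,
    # see the commented-out "del y_feat" above), y_feat is kept alive here because the
    # dit_double_blocks_after / dit_blocks_after patch points below still receive it
    # as their txt argument; it is only deleted once the final-layer patches have run
    # (see the "del y_feat" further down).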
123 | 124 | patches_double_blocks_after = patches_point.get(PatchKeys.dit_double_blocks_after, []) 125 | if patches_double_blocks_after is not None and len(patches_double_blocks_after) > 0: 126 | for patch_double_blocks_after in patches_double_blocks_after: 127 | x, y_feat = patch_double_blocks_after(x, y_feat, transformer_options) 128 | 129 | patch_blocks_transition = patches_point.get(PatchKeys.dit_blocks_transition_replace) 130 | 131 | def blocks_transition_wrap(**kwargs): 132 | img = kwargs["img"] 133 | return img 134 | 135 | if patch_blocks_transition is not None: 136 | x = patch_blocks_transition({"img": x, "txt": y_feat, "vec": c, "pe": pe}, 137 | { 138 | "original_func": blocks_transition_wrap, 139 | "transformer_options": transformer_options 140 | }) 141 | else: 142 | x = blocks_transition_wrap(img=x, txt=y_feat) 143 | 144 | patches_single_blocks_before = patches_point.get(PatchKeys.dit_single_blocks_before, []) 145 | if patches_single_blocks_before is not None and len(patches_single_blocks_before) > 0: 146 | for patch_single_blocks_before in patches_single_blocks_before: 147 | x, y_feat = patch_single_blocks_before(x, y_feat, transformer_options) 148 | 149 | def single_blocks_wrap(img, **kwargs): 150 | return img 151 | 152 | patch_single_blocks_replace = patches_point.get(PatchKeys.dit_single_blocks_replace) 153 | 154 | if patch_single_blocks_replace is not None: 155 | x, y_feat = patch_single_blocks_replace({"img": x, 156 | "txt": y_feat, 157 | "vec": c, 158 | "pe": pe, 159 | "control": None, 160 | "attn_mask": None 161 | }, 162 | { 163 | "original_blocks": single_blocks_wrap, 164 | "transformer_options": transformer_options 165 | }) 166 | else: 167 | x = single_blocks_wrap(img=x, 168 | txt=y_feat, 169 | vec=c, 170 | pe=pe, 171 | control=None, 172 | attn_mask=None, 173 | transformer_options=transformer_options 174 | ) 175 | 176 | patch_blocks_exit = patches_point.get(PatchKeys.dit_blocks_after, []) 177 | if patch_blocks_exit is not None and len(patch_blocks_exit) > 0: 178 | for blocks_after in patch_blocks_exit: 179 | x, y_feat = blocks_after(x, y_feat, transformer_options) 180 | 181 | def final_transition_wrap(**kwargs): 182 | # pipe => x = normal_out(x) 183 | img = kwargs["img"] 184 | _c = kwargs["vec"] 185 | temp_model = transformer_options[PatchKeys.running_net_model] 186 | shift, scale = temp_model.final_layer.mod(F.silu(_c)).chunk(2, dim=1) 187 | # comfyui => x = modulate(self.norm_final(x), shift, scale) 188 | img = temp_model.final_layer.norm_final(img) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 189 | del temp_model 190 | return img 191 | 192 | patch_blocks_after_transition_replace = patches_point.get(PatchKeys.dit_blocks_after_transition_replace) 193 | if patch_blocks_after_transition_replace is not None: 194 | x = patch_blocks_after_transition_replace({"img": x, "txt": y_feat, "vec": c, "pe": pe}, 195 | { 196 | "original_func": final_transition_wrap, 197 | "transformer_options": transformer_options 198 | }) 199 | else: 200 | x = final_transition_wrap(img=x, txt=y_feat, vec=c) 201 | 202 | patches_final_layer_before = patches_point.get(PatchKeys.dit_final_layer_before, []) 203 | if patches_final_layer_before is not None and len(patches_final_layer_before) > 0: 204 | for patch_final_layer_before in patches_final_layer_before: 205 | x = patch_final_layer_before(x, y_feat, transformer_options) 206 | 207 | del y_feat 208 | 209 | # pipe => x = proj_out(x) 210 | x = self.final_layer.linear(x) # (B, M, patch_size ** 2 * out_channels) 211 | # x = self.final_layer(x, c) # 
(B, M, patch_size ** 2 * out_channels) 212 | x = rearrange( 213 | x, 214 | "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)", 215 | T=T, 216 | hp=H // self.patch_size, 217 | wp=W // self.patch_size, 218 | p1=self.patch_size, 219 | p2=self.patch_size, 220 | c=self.out_channels, 221 | ) 222 | 223 | patches_exit = patches_point.get(PatchKeys.dit_exit, []) 224 | if patches_exit is not None and len(patches_exit) > 0: 225 | for patch_exit in patches_exit: 226 | x = patch_exit(x, transformer_options) 227 | 228 | del transformer_options[PatchKeys.running_net_model] 229 | 230 | return -x 231 | 232 | def double_block_and_control_replace(i, block, img, txt=None, vec=None, pe=None, control=None, attn_mask=None, 233 | transformer_options={}): 234 | blocks_replace = transformer_options.get("patches_replace", {}).get("dit", {}) 235 | _num_tokens = transformer_options["db_blocks_num_tokens"] 236 | if ("double_block", i) in blocks_replace: 237 | def block_wrap(args): 238 | out = {} 239 | out["img"], out["txt"] = block(args["img"], 240 | args["vec"], 241 | args["txt"], 242 | rope_cos=args["rope_cos"], 243 | rope_sin=args["rope_sin"], 244 | crop_y=args["num_tokens"]) 245 | return out 246 | 247 | out = blocks_replace[("double_block", i)]({"img": img, 248 | "txt": txt, 249 | "vec": vec, 250 | "pe": pe, 251 | "rope_cos": pe[0], 252 | "rope_sin": pe[1], 253 | "crop_y": _num_tokens 254 | }, 255 | { 256 | "original_block": block_wrap, 257 | "transformer_options": transformer_options 258 | }) 259 | txt = out["txt"] 260 | img = out["img"] 261 | else: 262 | img, txt = block(img, 263 | vec, 264 | txt, 265 | rope_cos=pe[0], 266 | rope_sin=pe[1], 267 | crop_y=_num_tokens) # (B, M, D), (B, L, D) 268 | 269 | return img, txt 270 | -------------------------------------------------------------------------------- /nodes/patch_lib/WanVideoPatch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from comfy.ldm.wan.model import sinusoidal_embedding_1d 5 | from ..patch_util import PatchKeys 6 | from einops import repeat 7 | import comfy.ldm.common_dit 8 | 9 | 10 | def wan_forward(self, x, timestep, context, clip_fea=None, **kwargs): 11 | bs, c, t, h, w = x.shape 12 | x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) 13 | patch_size = self.patch_size 14 | t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) 15 | h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) 16 | w_len = ((w + (patch_size[2] // 2)) // patch_size[2]) 17 | img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype) 18 | img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, 19 | dtype=x.dtype).reshape(-1, 1, 1) 20 | img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, 21 | dtype=x.dtype).reshape(1, -1, 1) 22 | img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, 23 | dtype=x.dtype).reshape(1, 1, -1) 24 | img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs) 25 | 26 | freqs = self.rope_embedder(img_ids).movedim(1, 2) 27 | return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, **kwargs)[:, :, :t, :h, :w] 28 | 29 | def wan_forward_orig( 30 | self, 31 | x, 32 | t, 33 | context, 34 | clip_fea=None, 35 | freqs=None, 36 | transformer_options={}, 37 | **kwargs 38 | ) -> Tensor: 39 | # patches_replace = transformer_options.get("patches_replace", {}) 40 | patches_point = 
transformer_options.get(PatchKeys.options_key, {}) 41 | 42 | transformer_options[PatchKeys.running_net_model] = self 43 | 44 | patches_enter = patches_point.get(PatchKeys.dit_enter, []) 45 | if patches_enter is not None and len(patches_enter) > 0: 46 | for patch_enter in patches_enter: 47 | x, t, context = patch_enter( 48 | x, 49 | t, 50 | context, 51 | transformer_options 52 | ) 53 | 54 | # embeddings 55 | x = self.patch_embedding(x.float()).to(x.dtype) 56 | grid_sizes = x.shape[2:] 57 | x = x.flatten(2).transpose(1, 2) 58 | 59 | # time embeddings 60 | e = self.time_embedding( 61 | sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype)) 62 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) 63 | 64 | # context 65 | context = self.text_embedding(context) 66 | 67 | if clip_fea is not None and self.img_emb is not None: 68 | context_clip = self.img_emb(clip_fea) # bs x 257 x dim 69 | context = torch.concat([context_clip, context], dim=1) 70 | 71 | attention_mask = None 72 | 73 | patch_blocks_before = patches_point.get(PatchKeys.dit_blocks_before, []) 74 | if patch_blocks_before is not None and len(patch_blocks_before) > 0: 75 | for blocks_before in patch_blocks_before: 76 | x, context, e0, ids, freqs = blocks_before(img=x, txt=context, vec=e0, ids=None, pe=freqs, transformer_options=transformer_options, e=e) 77 | 78 | def double_blocks_wrap(img, txt, vec, pe, control=None, attn_mask=None, transformer_options={}): 79 | running_net_model = transformer_options[PatchKeys.running_net_model] 80 | patch_double_blocks_with_control_replace = patches_point.get(PatchKeys.dit_double_block_with_control_replace) 81 | for i, block in enumerate(running_net_model.blocks): 82 | if patch_double_blocks_with_control_replace is not None: 83 | img, txt = patch_double_blocks_with_control_replace({'i': i, 84 | 'block': block, 85 | 'img': img, 86 | 'txt': txt, 87 | 'vec': vec, 88 | 'pe': pe, 89 | 'control': control, 90 | 'attn_mask': attn_mask 91 | }, 92 | { 93 | "original_func": double_block_and_control_replace, 94 | "transformer_options": transformer_options 95 | }) 96 | else: 97 | img, txt = double_block_and_control_replace(i=i, 98 | block=block, 99 | img=img, 100 | txt=txt, 101 | vec=vec, 102 | pe=pe, 103 | control=control, 104 | attn_mask=attn_mask, 105 | transformer_options=transformer_options 106 | ) 107 | 108 | del patch_double_blocks_with_control_replace 109 | return img, txt 110 | 111 | patch_double_blocks_replace = patches_point.get(PatchKeys.dit_double_blocks_replace) 112 | 113 | if patch_double_blocks_replace is not None: 114 | x, context = patch_double_blocks_replace({"img": x, 115 | "txt": context, 116 | "vec": e0, 117 | "pe": freqs, 118 | "control": None, 119 | "attn_mask": attention_mask, 120 | }, 121 | { 122 | "original_blocks": double_blocks_wrap, 123 | "transformer_options": transformer_options 124 | }) 125 | else: 126 | x, context = double_blocks_wrap(img=x, 127 | txt=context, 128 | vec=e0, 129 | pe=freqs, 130 | control=None, 131 | attn_mask=attention_mask, 132 | transformer_options=transformer_options 133 | ) 134 | 135 | patches_double_blocks_after = patches_point.get(PatchKeys.dit_double_blocks_after, []) 136 | if patches_double_blocks_after is not None and len(patches_double_blocks_after) > 0: 137 | for patch_double_blocks_after in patches_double_blocks_after: 138 | x, context = patch_double_blocks_after(x, context, transformer_options) 139 | 140 | patch_blocks_transition = patches_point.get(PatchKeys.dit_blocks_transition_replace) 141 | 142 | def 
blocks_transition_wrap(**kwargs): 143 | x = kwargs["img"] 144 | return x 145 | 146 | if patch_blocks_transition is not None: 147 | x = patch_blocks_transition({"img": x, "txt": context, "vec": e0, "pe": freqs}, 148 | { 149 | "original_func": blocks_transition_wrap, 150 | "transformer_options": transformer_options 151 | }) 152 | else: 153 | x = blocks_transition_wrap(img=x, txt=context) 154 | 155 | patches_single_blocks_before = patches_point.get(PatchKeys.dit_single_blocks_before, []) 156 | if patches_single_blocks_before is not None and len(patches_single_blocks_before) > 0: 157 | for patch_single_blocks_before in patches_single_blocks_before: 158 | x, context = patch_single_blocks_before(x, context, transformer_options) 159 | 160 | def single_blocks_wrap(img, **kwargs): 161 | return img 162 | 163 | patch_single_blocks_replace = patches_point.get(PatchKeys.dit_single_blocks_replace) 164 | 165 | if patch_single_blocks_replace is not None: 166 | x, context = patch_single_blocks_replace({"img": x, 167 | "txt": context, 168 | "vec": e0, 169 | "pe": freqs, 170 | "control": None, 171 | "attn_mask": attention_mask 172 | }, 173 | { 174 | "original_blocks": single_blocks_wrap, 175 | "transformer_options": transformer_options 176 | }) 177 | else: 178 | x = single_blocks_wrap(img=x, 179 | txt=context, 180 | vec=e0, 181 | pe=freqs, 182 | control=None, 183 | attn_mask=attention_mask, 184 | transformer_options=transformer_options 185 | ) 186 | 187 | patch_blocks_exit = patches_point.get(PatchKeys.dit_blocks_after, []) 188 | if patch_blocks_exit is not None and len(patch_blocks_exit) > 0: 189 | for blocks_after in patch_blocks_exit: 190 | x, context = blocks_after(x, context, transformer_options) 191 | 192 | def final_transition_wrap(**kwargs): 193 | img = kwargs["img"] 194 | # img = self.head.norm(img) 195 | return img 196 | 197 | patch_blocks_after_transition_replace = patches_point.get(PatchKeys.dit_blocks_after_transition_replace) 198 | if patch_blocks_after_transition_replace is not None: 199 | x = patch_blocks_after_transition_replace({"img": x, "txt": context, "vec": e0, "pe": freqs}, 200 | { 201 | "original_func": final_transition_wrap, 202 | "transformer_options": transformer_options 203 | }) 204 | else: 205 | x = final_transition_wrap(img=x) 206 | 207 | patches_final_layer_before = patches_point.get(PatchKeys.dit_final_layer_before, []) 208 | if patches_final_layer_before is not None and len(patches_final_layer_before) > 0: 209 | for patch_final_layer_before in patches_final_layer_before: 210 | x = patch_final_layer_before(img=x, txt=context, transformer_options=transformer_options) 211 | 212 | # e = (comfy.model_management.cast_to(self.head.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1) 213 | # x = (self.head.head(x * (1 + e[1]) + e[0])) 214 | # head 215 | x = self.head(x, e) 216 | 217 | # unpatchify 218 | x = self.unpatchify(x, grid_sizes) 219 | 220 | patches_exit = patches_point.get(PatchKeys.dit_exit, []) 221 | if patches_exit is not None and len(patches_exit) > 0: 222 | for patch_exit in patches_exit: 223 | x = patch_exit(x, transformer_options) 224 | 225 | del transformer_options[PatchKeys.running_net_model] 226 | 227 | return x 228 | 229 | def double_block_and_control_replace(i, block, img, txt=None, vec=None, pe=None, control=None, attn_mask=None, transformer_options={}): 230 | blocks_replace = transformer_options.get("patches_replace", {}).get("dit", {}) 231 | # arguments 232 | kwargs = dict( 233 | e=vec, 234 | freqs=pe, 235 | context=txt) 236 | if 
("double_block", i) in blocks_replace: 237 | def block_wrap(args): 238 | out = {} 239 | out["img"] = block(x=args["img"], **kwargs) 240 | return out 241 | 242 | out = blocks_replace[("double_block", i)]({"img": img, 243 | "txt": txt, 244 | "vec": vec, 245 | "pe": pe, 246 | "attention_mask": attn_mask, 247 | }, 248 | { 249 | "original_block": block_wrap, 250 | "transformer_options": transformer_options 251 | }) 252 | img = out["img"] 253 | else: 254 | img = block(img, **kwargs) 255 | 256 | return img, txt 257 | -------------------------------------------------------------------------------- /nodes/patch_lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lldacing/ComfyUI_Patches_ll/4a9ea6e5d408768d5c2d666781ef756880eea0d0/nodes/patch_lib/__init__.py -------------------------------------------------------------------------------- /nodes/patch_lib/old/HunYuanVideoPatch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from ...patch_util import PatchKeys 5 | from comfy.ldm.flux.layers import timestep_embedding 6 | 7 | 8 | def hunyuan_forward_orig( 9 | self, 10 | img: Tensor, 11 | img_ids: Tensor, 12 | txt: Tensor, 13 | txt_ids: Tensor, 14 | txt_mask: Tensor, 15 | timesteps: Tensor, 16 | y: Tensor, 17 | guidance: Tensor = None, 18 | control=None, 19 | transformer_options={}, 20 | **kwargs 21 | ) -> Tensor: 22 | patches_replace = transformer_options.get("patches_replace", {}) 23 | patches_point = transformer_options.get(PatchKeys.options_key, {}) 24 | 25 | transformer_options[PatchKeys.running_net_model] = self 26 | 27 | patches_enter = patches_point.get(PatchKeys.dit_enter, []) 28 | if patches_enter is not None and len(patches_enter) > 0: 29 | for patch_enter in patches_enter: 30 | img, img_ids, txt, txt_ids, timesteps, y, guidance, control, txt_mask = patch_enter(img, 31 | img_ids, 32 | txt, 33 | txt_ids, 34 | timesteps, 35 | y, 36 | guidance, 37 | control, 38 | attn_mask=txt_mask, 39 | transformer_options=transformer_options 40 | ) 41 | 42 | initial_shape = list(img.shape) 43 | # running on sequences img 44 | img = self.img_in(img) 45 | vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype)) 46 | 47 | vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) 48 | 49 | if self.params.guidance_embed: 50 | if guidance is not None: 51 | vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) 52 | 53 | if txt_mask is not None and not torch.is_floating_point(txt_mask): 54 | txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max 55 | 56 | txt = self.txt_in(txt, timesteps, txt_mask) 57 | 58 | ids = torch.cat((img_ids, txt_ids), dim=1) 59 | pe = self.pe_embedder(ids) 60 | 61 | img_len = img.shape[1] 62 | if txt_mask is not None: 63 | attn_mask_len = img_len + txt.shape[1] 64 | attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device) 65 | attn_mask[:, 0, img_len:] = txt_mask 66 | else: 67 | attn_mask = None 68 | 69 | blocks_replace = patches_replace.get("dit", {}) 70 | 71 | patch_blocks_before = patches_point.get(PatchKeys.dit_blocks_before, []) 72 | if patch_blocks_before is not None and len(patch_blocks_before) > 0: 73 | for blocks_before in patch_blocks_before: 74 | img, txt, vec, ids, pe = blocks_before(img, txt, vec, ids, pe, transformer_options) 75 | 76 | def double_blocks_wrap(img, txt, vec, pe, control=None, attn_mask=None, 
transformer_options={}): 77 | running_net_model = transformer_options[PatchKeys.running_net_model] 78 | patch_double_blocks_with_control_replace = patches_point.get(PatchKeys.dit_double_block_with_control_replace) 79 | for i, block in enumerate(running_net_model.double_blocks): 80 | if patch_double_blocks_with_control_replace is not None: 81 | img, txt = patch_double_blocks_with_control_replace({'i': i, 82 | 'block': block, 83 | 'img': img, 84 | 'txt': txt, 85 | 'vec': vec, 86 | 'pe': pe, 87 | 'control': control, 88 | 'attn_mask': attn_mask 89 | }, 90 | { 91 | "original_func": double_block_and_control_replace, 92 | "transformer_options": transformer_options 93 | }) 94 | else: 95 | img, txt = double_block_and_control_replace(i=i, 96 | block=block, 97 | img=img, 98 | txt=txt, 99 | vec=vec, 100 | pe=pe, 101 | control=control, 102 | attn_mask=attn_mask, 103 | transformer_options=transformer_options 104 | ) 105 | 106 | del patch_double_blocks_with_control_replace 107 | return img, txt 108 | 109 | patch_double_blocks_replace = patches_point.get(PatchKeys.dit_double_blocks_replace) 110 | 111 | if patch_double_blocks_replace is not None: 112 | img, txt = patch_double_blocks_replace({"img": img, 113 | "txt": txt, 114 | "vec": vec, 115 | "pe": pe, 116 | "control": control, 117 | "attn_mask": attn_mask, 118 | }, 119 | { 120 | "original_blocks": double_blocks_wrap, 121 | "transformer_options": transformer_options 122 | }) 123 | else: 124 | img, txt = double_blocks_wrap(img=img, 125 | txt=txt, 126 | vec=vec, 127 | pe=pe, 128 | control=control, 129 | attn_mask=attn_mask, 130 | transformer_options=transformer_options 131 | ) 132 | 133 | patches_double_blocks_after = patches_point.get(PatchKeys.dit_double_blocks_after, []) 134 | if patches_double_blocks_after is not None and len(patches_double_blocks_after) > 0: 135 | for patch_double_blocks_after in patches_double_blocks_after: 136 | img, txt = patch_double_blocks_after(img, txt, transformer_options) 137 | 138 | patch_blocks_transition = patches_point.get(PatchKeys.dit_blocks_transition_replace) 139 | 140 | def blocks_transition_wrap(**kwargs): 141 | txt = kwargs["txt"] 142 | img = kwargs["img"] 143 | return torch.cat((img, txt), 1) 144 | 145 | if patch_blocks_transition is not None: 146 | img = patch_blocks_transition({"img": img, "txt": txt, "vec": vec, "pe": pe}, 147 | { 148 | "original_func": blocks_transition_wrap, 149 | "transformer_options": transformer_options 150 | }) 151 | else: 152 | img = blocks_transition_wrap(img=img, txt=txt) 153 | 154 | patches_single_blocks_before = patches_point.get(PatchKeys.dit_single_blocks_before, []) 155 | if patches_single_blocks_before is not None and len(patches_single_blocks_before) > 0: 156 | for patch_single_blocks_before in patches_single_blocks_before: 157 | img, txt = patch_single_blocks_before(img, txt, transformer_options) 158 | 159 | def single_blocks_wrap(img, txt, vec, pe, control=None, attn_mask=None, transformer_options={}): 160 | running_net_model = transformer_options[PatchKeys.running_net_model] 161 | for i, block in enumerate(running_net_model.single_blocks): 162 | if ("single_block", i) in blocks_replace: 163 | def block_wrap(args): 164 | out = {} 165 | out["img"] = block(args["img"], 166 | vec=args["vec"], 167 | pe=args["pe"], 168 | attn_mask=args.get("attention_mask")) 169 | return out 170 | 171 | out = blocks_replace[("single_block", i)]({"img": img, 172 | "vec": vec, 173 | "pe": pe, 174 | "attention_mask": attn_mask}, 175 | { 176 | "original_block": block_wrap, 177 | "transformer_options": 
transformer_options 178 | }) 179 | img = out["img"] 180 | else: 181 | img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) 182 | 183 | if control is not None: # Controlnet 184 | control_o = control.get("output") 185 | if i < len(control_o): 186 | add = control_o[i] 187 | if add is not None: 188 | img[:, : img_len] += add 189 | 190 | return img 191 | 192 | patch_single_blocks_replace = patches_point.get(PatchKeys.dit_single_blocks_replace) 193 | 194 | if patch_single_blocks_replace is not None: 195 | img, txt = patch_single_blocks_replace({"img": img, 196 | "txt": txt, 197 | "vec": vec, 198 | "pe": pe, 199 | "control": control, 200 | "attn_mask": attn_mask 201 | }, 202 | { 203 | "original_blocks": single_blocks_wrap, 204 | "transformer_options": transformer_options 205 | }) 206 | else: 207 | img = single_blocks_wrap(img=img, 208 | txt=txt, 209 | vec=vec, 210 | pe=pe, 211 | control=control, 212 | attn_mask=attn_mask, 213 | transformer_options=transformer_options 214 | ) 215 | 216 | patch_blocks_exit = patches_point.get(PatchKeys.dit_blocks_after, []) 217 | if patch_blocks_exit is not None and len(patch_blocks_exit) > 0: 218 | for blocks_after in patch_blocks_exit: 219 | img, txt = blocks_after(img, txt, transformer_options) 220 | 221 | def final_transition_wrap(**kwargs): 222 | img = kwargs["img"] 223 | img_len = kwargs["img_len"] 224 | return img[:, : img_len] 225 | 226 | patch_blocks_after_transition_replace = patches_point.get(PatchKeys.dit_blocks_after_transition_replace) 227 | if patch_blocks_after_transition_replace is not None: 228 | img = patch_blocks_after_transition_replace({"img": img, "txt": txt, "vec": vec, "pe": pe, "img_len": img_len}, 229 | { 230 | "original_func": final_transition_wrap, 231 | "transformer_options": transformer_options 232 | }) 233 | else: 234 | img = final_transition_wrap(img=img, img_len=img_len) 235 | 236 | patches_final_layer_before = patches_point.get(PatchKeys.dit_final_layer_before, []) 237 | if patches_final_layer_before is not None and len(patches_final_layer_before) > 0: 238 | for patch_final_layer_before in patches_final_layer_before: 239 | img = patch_final_layer_before(img, txt, transformer_options) 240 | 241 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) 242 | 243 | shape = initial_shape[-3:] 244 | for i in range(len(shape)): 245 | shape[i] = shape[i] // self.patch_size[i] 246 | img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size) 247 | img = img.permute(0, 4, 1, 5, 2, 6, 3, 7) 248 | img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) 249 | 250 | patches_exit = patches_point.get(PatchKeys.dit_exit, []) 251 | if patches_exit is not None and len(patches_exit) > 0: 252 | for patch_exit in patches_exit: 253 | img = patch_exit(img, transformer_options) 254 | 255 | del transformer_options[PatchKeys.running_net_model] 256 | 257 | return img 258 | 259 | def double_block_and_control_replace(i, block, img, txt=None, vec=None, pe=None, control=None, attn_mask=None, transformer_options={}): 260 | blocks_replace = transformer_options.get("patches_replace", {}).get("dit", {}) 261 | if ("double_block", i) in blocks_replace: 262 | def block_wrap(args): 263 | out = {} 264 | out["img"], out["txt"] = block(img=args["img"], 265 | txt=args["txt"], 266 | vec=args["vec"], 267 | pe=args["pe"], 268 | attn_mask=args.get("attention_mask")) 269 | return out 270 | 271 | out = blocks_replace[("double_block", i)]({"img": img, 272 | "txt": txt, 273 | "vec": vec, 
274 |                                                "pe": pe,
275 |                                                "attention_mask": attn_mask
276 |                                                },
277 |                                               {
278 |                                                   "original_block": block_wrap,
279 |                                                   "transformer_options": transformer_options
280 |                                               })
281 |         txt = out["txt"]
282 |         img = out["img"]
283 |     else:
284 |         img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
285 |     if control is not None: # Controlnet
286 |         control_i = control.get("input")
287 |         if i < len(control_i):
288 |             add = control_i[i]
289 |             if add is not None:
290 |                 img += add
291 | 
292 |     return img, txt
293 | 
--------------------------------------------------------------------------------
/nodes/patch_lib/old/LTXVideoPatch.py:
--------------------------------------------------------------------------------
  1 | import math
  2 | 
  3 | import torch
  4 | from torch import Tensor
  5 | 
  6 | from comfy.ldm.lightricks.model import precompute_freqs_cis
  7 | from ...patch_util import PatchKeys
  8 | 
  9 | # changed in ComfyUI commit 93fedd92fe0eb67a09e29069b05adebb40678639 (between ComfyUI versions 0.3.19 and 0.3.20)
 10 | def ltx_forward_orig(
 11 |     self,
 12 |     x,
 13 |     timestep,
 14 |     context,
 15 |     attention_mask,
 16 |     frame_rate=25,
 17 |     guiding_latent=None,
 18 |     guiding_latent_noise_scale=0,
 19 |     transformer_options={},
 20 |     **kwargs
 21 | ) -> Tensor:
 22 |     patches_point = transformer_options.get(PatchKeys.options_key, {})
 23 | 
 24 |     transformer_options[PatchKeys.running_net_model] = self
 25 | 
 26 |     patches_enter = patches_point.get(PatchKeys.dit_enter, [])
 27 |     if patches_enter is not None and len(patches_enter) > 0:
 28 |         for patch_enter in patches_enter:
 29 |             x, timestep, context, attention_mask, frame_rate, guiding_latent, guiding_latent_noise_scale = patch_enter(
 30 |                 x,
 31 |                 timestep,
 32 |                 context,
 33 |                 attention_mask,
 34 |                 frame_rate,
 35 |                 guiding_latent,
 36 |                 guiding_latent_noise_scale,
 37 |                 transformer_options
 38 |             )
 39 | 
 40 |     indices_grid = self.patchifier.get_grid(
 41 |         orig_num_frames=x.shape[2],
 42 |         orig_height=x.shape[3],
 43 |         orig_width=x.shape[4],
 44 |         batch_size=x.shape[0],
 45 |         scale_grid=((1 / frame_rate) * 8, 32, 32),
 46 |         device=x.device,
 47 |     )
 48 | 
 49 |     if guiding_latent is not None:
 50 |         ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
 51 |         input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
 52 |         ts *= input_ts
 53 |         ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
 54 |         timestep = self.patchifier.patchify(ts)
 55 |         input_x = x.clone()
 56 |         x[:, :, 0] = guiding_latent[:, :, 0]
 57 |         if guiding_latent_noise_scale > 0:
 58 |             if self.generator is None:
 59 |                 self.generator = torch.Generator(device=x.device).manual_seed(42)
 60 |             elif self.generator.device != x.device:
 61 |                 self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
 62 | 
 63 |             noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
 64 |             scale = guiding_latent_noise_scale * (input_ts ** 2)
 65 |             guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
 66 | 
 67 |             x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
 68 | 
 69 | 
 70 |     orig_shape = list(x.shape)
 71 | 
 72 |     x = self.patchifier.patchify(x)
 73 | 
 74 |     x = self.patchify_proj(x)
 75 |     timestep = timestep * 1000.0
 76 | 
 77 |     if attention_mask is not None and not torch.is_floating_point(attention_mask):
 78 |         attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
 79 | 
 80 
| pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype) 81 | 82 | batch_size = x.shape[0] 83 | timestep, embedded_timestep = self.adaln_single( 84 | timestep.flatten(), 85 | {"resolution": None, "aspect_ratio": None}, 86 | batch_size=batch_size, 87 | hidden_dtype=x.dtype, 88 | ) 89 | # Second dimension is 1 or number of tokens (if timestep_per_token) 90 | timestep = timestep.view(batch_size, -1, timestep.shape[-1]) 91 | embedded_timestep = embedded_timestep.view( 92 | batch_size, -1, embedded_timestep.shape[-1] 93 | ) 94 | 95 | # 2. Blocks 96 | if self.caption_projection is not None: 97 | batch_size = x.shape[0] 98 | context = self.caption_projection(context) 99 | context = context.view( 100 | batch_size, -1, x.shape[-1] 101 | ) 102 | 103 | patch_blocks_before = patches_point.get(PatchKeys.dit_blocks_before, []) 104 | if patch_blocks_before is not None and len(patch_blocks_before) > 0: 105 | for blocks_before in patch_blocks_before: 106 | x, context, timestep, ids, pe = blocks_before(img=x, txt=context, vec=timestep, ids=None, pe=pe, transformer_options=transformer_options) 107 | 108 | def double_blocks_wrap(img, txt, vec, pe, control=None, attn_mask=None, transformer_options={}): 109 | running_net_model = transformer_options[PatchKeys.running_net_model] 110 | patch_double_blocks_with_control_replace = patches_point.get(PatchKeys.dit_double_block_with_control_replace) 111 | for i, block in enumerate(running_net_model.transformer_blocks): 112 | if patch_double_blocks_with_control_replace is not None: 113 | img, txt = patch_double_blocks_with_control_replace({'i': i, 114 | 'block': block, 115 | 'img': img, 116 | 'txt': txt, 117 | 'vec': vec, 118 | 'pe': pe, 119 | 'control': control, 120 | 'attn_mask': attn_mask 121 | }, 122 | { 123 | "original_func": double_block_and_control_replace, 124 | "transformer_options": transformer_options 125 | }) 126 | else: 127 | img, txt = double_block_and_control_replace(i=i, 128 | block=block, 129 | img=img, 130 | txt=txt, 131 | vec=vec, 132 | pe=pe, 133 | control=control, 134 | attn_mask=attn_mask, 135 | transformer_options=transformer_options 136 | ) 137 | 138 | del patch_double_blocks_with_control_replace 139 | return img, txt 140 | 141 | patch_double_blocks_replace = patches_point.get(PatchKeys.dit_double_blocks_replace) 142 | 143 | if patch_double_blocks_replace is not None: 144 | x, context = patch_double_blocks_replace({"img": x, 145 | "txt": context, 146 | "vec": timestep, 147 | "pe": pe, 148 | "control": None, 149 | "attn_mask": attention_mask, 150 | }, 151 | { 152 | "original_blocks": double_blocks_wrap, 153 | "transformer_options": transformer_options 154 | }) 155 | else: 156 | x, context = double_blocks_wrap(img=x, 157 | txt=context, 158 | vec=timestep, 159 | pe=pe, 160 | control=None, 161 | attn_mask=attention_mask, 162 | transformer_options=transformer_options 163 | ) 164 | 165 | patches_double_blocks_after = patches_point.get(PatchKeys.dit_double_blocks_after, []) 166 | if patches_double_blocks_after is not None and len(patches_double_blocks_after) > 0: 167 | for patch_double_blocks_after in patches_double_blocks_after: 168 | x, context = patch_double_blocks_after(x, context, transformer_options) 169 | 170 | patch_blocks_transition = patches_point.get(PatchKeys.dit_blocks_transition_replace) 171 | 172 | def blocks_transition_wrap(**kwargs): 173 | x = kwargs["img"] 174 | return x 175 | 176 | if patch_blocks_transition is not None: 177 | x = patch_blocks_transition({"img": x, "txt": context, "vec": timestep, "pe": pe}, 178 
| { 179 | "original_func": blocks_transition_wrap, 180 | "transformer_options": transformer_options 181 | }) 182 | else: 183 | x = blocks_transition_wrap(img=x, txt=context) 184 | 185 | patches_single_blocks_before = patches_point.get(PatchKeys.dit_single_blocks_before, []) 186 | if patches_single_blocks_before is not None and len(patches_single_blocks_before) > 0: 187 | for patch_single_blocks_before in patches_single_blocks_before: 188 | x, context = patch_single_blocks_before(x, context, transformer_options) 189 | 190 | def single_blocks_wrap(img, **kwargs): 191 | return img 192 | 193 | patch_single_blocks_replace = patches_point.get(PatchKeys.dit_single_blocks_replace) 194 | 195 | if patch_single_blocks_replace is not None: 196 | x, context = patch_single_blocks_replace({"img": x, 197 | "txt": context, 198 | "vec": timestep, 199 | "pe": pe, 200 | "control": None, 201 | "attn_mask": attention_mask 202 | }, 203 | { 204 | "original_blocks": single_blocks_wrap, 205 | "transformer_options": transformer_options 206 | }) 207 | else: 208 | x = single_blocks_wrap(img=x, 209 | txt=context, 210 | vec=timestep, 211 | pe=pe, 212 | control=None, 213 | attn_mask=attention_mask, 214 | transformer_options=transformer_options 215 | ) 216 | 217 | patch_blocks_exit = patches_point.get(PatchKeys.dit_blocks_after, []) 218 | if patch_blocks_exit is not None and len(patch_blocks_exit) > 0: 219 | for blocks_after in patch_blocks_exit: 220 | x, context = blocks_after(x, context, transformer_options) 221 | 222 | # 3. Output 223 | def final_transition_wrap(**kwargs): 224 | running_net_model = transformer_options[PatchKeys.running_net_model] 225 | x = kwargs["img"] 226 | embedded_timestep = kwargs["embedded_timestep"] 227 | scale_shift_values = ( 228 | running_net_model.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] 229 | ) 230 | shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] 231 | x = running_net_model.norm_out(x) 232 | # Modulation 233 | x = x * (1 + scale) + shift 234 | return x 235 | 236 | patch_blocks_after_transition_replace = patches_point.get(PatchKeys.dit_blocks_after_transition_replace) 237 | if patch_blocks_after_transition_replace is not None: 238 | x = patch_blocks_after_transition_replace({"img": x, "txt": context, "vec": timestep, "pe": pe, "embedded_timestep": embedded_timestep}, 239 | { 240 | "original_func": final_transition_wrap, 241 | "transformer_options": transformer_options 242 | }) 243 | else: 244 | x = final_transition_wrap(img=x, embedded_timestep=embedded_timestep) 245 | 246 | patches_final_layer_before = patches_point.get(PatchKeys.dit_final_layer_before, []) 247 | if patches_final_layer_before is not None and len(patches_final_layer_before) > 0: 248 | for patch_final_layer_before in patches_final_layer_before: 249 | x = patch_final_layer_before(img=x, txt=context, transformer_options=transformer_options) 250 | 251 | x = self.proj_out(x) 252 | 253 | x = self.patchifier.unpatchify( 254 | latents=x, 255 | output_height=orig_shape[3], 256 | output_width=orig_shape[4], 257 | output_num_frames=orig_shape[2], 258 | out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size), 259 | ) 260 | 261 | if guiding_latent is not None: 262 | x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0] 263 | 264 | patches_exit = patches_point.get(PatchKeys.dit_exit, []) 265 | if patches_exit is not None and len(patches_exit) > 0: 266 | for patch_exit in patches_exit: 267 | x = patch_exit(x, 
transformer_options)
268 | 
269 |     del transformer_options[PatchKeys.running_net_model]
270 | 
271 |     return x
272 | 
273 | def double_block_and_control_replace(i, block, img, txt=None, vec=None, pe=None, control=None, attn_mask=None, transformer_options={}):
274 |     blocks_replace = transformer_options.get("patches_replace", {}).get("dit", {})
275 |     if ("double_block", i) in blocks_replace:
276 |         def block_wrap(args):
277 |             out = {}
278 |             out["img"] = block(x=args["img"],
279 |                                context=args["txt"],
280 |                                timestep=args["vec"],
281 |                                pe=args["pe"],
282 |                                attention_mask=args.get("attention_mask"))
283 |             return out
284 | 
285 |         out = blocks_replace[("double_block", i)]({"img": img,
286 |                                                    "txt": txt,
287 |                                                    "vec": vec,
288 |                                                    "pe": pe,
289 |                                                    "attention_mask": attn_mask,
290 |                                                    },
291 |                                                   {
292 |                                                       "original_block": block_wrap,
293 |                                                       "transformer_options": transformer_options
294 |                                                   })
295 |         img = out["img"]
296 |     else:
297 |         img = block(x=img, context=txt, timestep=vec, pe=pe, attention_mask=attn_mask)
298 | 
299 |     return img, txt
--------------------------------------------------------------------------------
/nodes/patch_util.py:
--------------------------------------------------------------------------------
 1 | import types
 2 | import comfy
 3 | 
 4 | class PatchKeys:
 5 |     ################## transformer_options patches ##################
 6 |     options_key = "patches_point"
 7 |     running_net_model = "running_net_model"
 8 |     # Patch hooks that can be registered under patches_point
 9 |     dit_enter = "patch_dit_enter"
10 |     dit_blocks_before = "patch_dit_blocks_before"
11 |     dit_double_blocks_replace = "patch_dit_double_blocks_replace"
12 |     dit_double_block_with_control_replace = "patch_dit_double_block_with_control_replace"
13 |     dit_double_blocks_after = "patch_dit_double_blocks_after"
14 |     dit_blocks_transition_replace = "patch_dit_blocks_transition_replace"
15 |     dit_single_blocks_before = "patch_dit_single_blocks_before"
16 |     dit_single_blocks_replace = "patch_dit_single_blocks_replace"
17 |     dit_blocks_after = "patch_dit_blocks_after"
18 |     dit_blocks_after_transition_replace = "patch_dit_final_layer_before_replace"
19 |     dit_final_layer_before = "patch_dit_final_layer_before"
20 |     dit_exit = "patch_dit_exit"
21 |     ################## transformer_options patches ##################
22 | 
23 | 
24 | def set_model_patch(model_patcher, options_key, patch, name):
25 |     # Appends `patch` to the list of patches registered under `name`.
26 |     to = model_patcher.model_options["transformer_options"]
27 |     if options_key not in to:
28 |         to[options_key] = {}
29 |     to[options_key][name] = to[options_key].get(name, []) + [patch]
30 | 
31 | def set_model_patch_replace(model_patcher, options_key, patch, name):
32 |     # Stores `patch` as the single replacement patch under `name`.
33 |     to = model_patcher.model_options["transformer_options"]
34 |     if options_key not in to:
35 |         to[options_key] = {}
36 |     to[options_key][name] = patch
37 | 
38 | def add_model_patch_option(model, patch_key):
39 |     if 'transformer_options' not in model.model_options:
40 |         model.model_options['transformer_options'] = {}
41 |     to = model.model_options['transformer_options']
42 |     if patch_key not in to:
43 |         to[patch_key] = {}
44 |     return to[patch_key]
45 | 
46 | 
47 | def set_hook(diffusion_model, bak_method_name, new_method, old_method_name='forward_orig'):
48 |     # Back up the original method, then bind new_method to the model instance in its place.
49 |     if new_method is not None:
50 |         setattr(diffusion_model, bak_method_name, getattr(diffusion_model, old_method_name))
51 |         setattr(diffusion_model, old_method_name, types.MethodType(new_method, diffusion_model))
52 | 
53 | 
54 | def clean_hook(diffusion_model, bak_method_name, old_method_name='forward_orig'):
55 |     # Restore the backed-up method and remove the backup attribute.
56 |     if hasattr(diffusion_model, bak_method_name):
57 |         setattr(diffusion_model, old_method_name, getattr(diffusion_model, bak_method_name))
58 |         delattr(diffusion_model, bak_method_name)
59 | 
60 | 
61 | def is_hunyuan_video_model(model):
62 |     return isinstance(model, comfy.ldm.hunyuan_video.model.HunyuanVideo)
63 | 
64 | def is_ltxv_video_model(model):
65 |     return isinstance(model, comfy.ldm.lightricks.model.LTXVModel)
66 | 
67 | def is_flux_model(model):
68 |     return isinstance(model, comfy.ldm.flux.model.Flux)
69 | 
70 | def is_mochi_video_model(model):
71 |     return isinstance(model, comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
72 | 
73 | def is_wan_video_model(model):
74 |     return isinstance(model, comfy.ldm.wan.model.WanModel)
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
 1 | [project]
 2 | name = "comfyui_patches_ll"
 3 | description = "Patches for Flux|HunYuanVideo|LTXVideo|MochiVideo|WanVideo etc., with support for TeaCache, PuLID, and First Block Cache."
 4 | version = "1.1.1"
 5 | license = {file = "LICENSE"}
 6 | dependencies = []
 7 | 
 8 | [project.urls]
 9 | Repository = "https://github.com/lldacing/ComfyUI_Patches_ll"
10 | # Used by Comfy Registry https://comfyregistry.org
11 | 
12 | [tool.comfy]
13 | PublisherId = "lldacing"
14 | DisplayName = "ComfyUI_Patches_ll"
15 | Icon = ""
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | packaging
--------------------------------------------------------------------------------
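Usage sketch (not a file in this repository):
--------------------------------------------------------------------------------
The patched forward_orig functions above read the dict stored under
transformer_options["patches_point"] (PatchKeys.options_key) and invoke whatever
callbacks are registered for each stage, from dit_enter through dit_exit. Below
is a minimal sketch of registering one such callback from a custom node. The
node class, category, and no-op patch body are illustrative assumptions; the
sketch also assumes it lives alongside nodes/patch_util.py and that the model's
forward_orig has already been swapped in via set_hook (as this repo's patch
nodes do). PatchKeys, set_model_patch, and the dit_exit call signature come
from the code above.

from .patch_util import PatchKeys, set_model_patch

def my_dit_exit_patch(img, transformer_options):
    # dit_exit hooks are invoked as patch_exit(img, transformer_options) after
    # the final reshape/unpatchify step and must return the output tensor.
    return img

class MyDitExitNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"model": ("MODEL",)}}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_patch"
    CATEGORY = "patches/example"

    def apply_patch(self, model):
        model = model.clone()  # patch a clone so the input model is untouched
        # set_model_patch appends to a list, so several nodes can stack hooks
        # on the same stage; set_model_patch_replace installs a single
        # replacement callback for the *_replace stages instead.
        set_model_patch(model, PatchKeys.options_key, my_dit_exit_patch, PatchKeys.dit_exit)
        return (model,)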