├── .gitattributes ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── __init__.py ├── decoder_noise.py ├── easy_samplers.py ├── example_workflows ├── 13b-distilled │ ├── ltxv-13b-dist-i2v-base-fp8.json │ ├── ltxv-13b-dist-i2v-base.json │ ├── ltxv-13b-dist-i2v-extend.json │ └── ltxv-13b-dist-i2v-keyframes.json ├── low_level │ ├── end.jpg │ ├── fox.jpg │ ├── jeep.mp4 │ ├── ltxvideo-first-sequence-conditioning.json │ ├── ltxvideo-frame-interpolation.json │ ├── ltxvideo-i2v-distilled.json │ ├── ltxvideo-i2v.json │ ├── ltxvideo-last-sequence-conditioning.json │ ├── ltxvideo-t2v.json │ ├── moto.png │ ├── shrek2.jpg │ └── start.jpg ├── ltxv-13b-i2v-base-fp8.json ├── ltxv-13b-i2v-base.json ├── ltxv-13b-i2v-extend.json ├── ltxv-13b-i2v-keyframes.json ├── ltxv-13b-i2v-mixed-multiscale.json └── tricks │ ├── ltxvideo-flow-edit.json │ ├── ltxvideo-flow-edit.png │ ├── ltxvideo-rf-edit.json │ ├── ltxvideo-rf-edit.png │ ├── ref.png │ ├── shot.mp4 │ ├── shot2.mp4 │ ├── shrek2.jpg │ └── shrek3.jpg ├── film_grain.py ├── guide.py ├── latent_adain.py ├── latent_upsampler.py ├── latents.py ├── nodes_registry.py ├── presets └── stg_advanced_presets.json ├── prompt_enhancer_nodes.py ├── prompt_enhancer_utils.py ├── pyproject.toml ├── q8_nodes.py ├── recurrent_sampler.py ├── requirements.txt ├── stg.py ├── tiled_sampler.py └── tricks ├── __init__.py ├── modules └── ltx_model.py ├── nodes ├── attn_bank_nodes.py ├── attn_override_node.py ├── latent_guide_node.py ├── ltx_feta_enhance_node.py ├── ltx_flowedit_nodes.py ├── ltx_inverse_model_pred_nodes.py ├── ltx_pag_node.py ├── modify_ltx_model_node.py ├── rectified_sampler_nodes.py └── rf_edit_sampler_nodes.py └── utils ├── attn_bank.py ├── feta_enhance_utils.py ├── latent_guide.py ├── module_utils.py └── noise_utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | **/*.png filter=lfs diff=lfs merge=lfs -text 2 | **/*.mp4 filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # general formatting 2 | repos: 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: v4.6.0 5 | hooks: 6 | - id: check-yaml 7 | - id: end-of-file-fixer 8 | - id: trailing-whitespace 9 | - id: requirements-txt-fixer 10 | - id: check-added-large-files 11 | - id: check-toml 12 | # Isort (import sorting) 13 | - repo: https://github.com/PyCQA/isort 14 | rev: 5.13.2 15 | hooks: 16 | - id: isort 17 | args: [--profile, black] 18 | # Black (code formatting) 19 | - repo: https://github.com/psf/black 20 | rev: 24.4.2 # Replace by any tag/version: https://github.com/psf/black/tags 21 | hooks: 22 | - id: black 23 | language_version: python3.10 24 | # flake8 linter 25 | - repo: https://github.com/charliermarsh/ruff-pre-commit 26 | rev: 'v0.4.4' 27 | hooks: 28 | - id: ruff 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-LTXVideo 2 | 3 | ComfyUI-LTXVideo is a collection of custom nodes for ComfyUI, designed to provide useful tools for working with the LTXV model. 4 | The model itself is supported in the core ComfyUI [code](https://github.com/comfyanonymous/ComfyUI/tree/master/comfy/ldm/lightricks). 5 | The main LTXVideo repository can be found [here](https://github.com/Lightricks/LTX-Video). 
6 | 7 | # ⭐ 14.05.2025 – LTXVideo 13B 0.9.7 Distilled Release ⭐ 8 | 9 | ### 🚀 What's New in LTXVideo 13B 0.9.7 Distilled 10 | 1. **LTXV 13B Distilled 🥳 0.9.7**
11 | Delivers cinematic-quality videos at a fraction of the steps needed to run the full model. Only 4 or 8 steps are needed for a single generation.<br>
12 | 👉 [Download here](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled.safetensors) 13 | 14 | 2. **LTXV 13B Distilled Quantized 0.9.7**
15 | Offers reduced memory requirements and even faster inference speeds. 16 | Ideal for consumer-grade GPUs (e.g., NVIDIA 4090, 5090).
17 | ***Important:*** To get the best performance with the quantized version, please install the [q8_kernels](https://github.com/Lightricks/LTXVideo-Q8-Kernels) package and use the dedicated flow below.<br>
18 | 👉 [Download here](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled-fp8.safetensors)
19 | 🧩 Example ComfyUI flow available in the [Example Workflows](#example-workflows) section. 20 | 21 | 3. **Updated LTXV 13B Quantized version**<br>
22 | From now on, all our 8-bit quantized models run natively in ComfyUI; you will still get the best inference speed with our Q8 patcher node.<br>
23 | 👉 [Download here](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev-fp8.safetensors)
24 | # ⭐ 06.05.2025 – LTXVideo 13B 0.9.7 Release ⭐ 25 | 26 | ### 🚀 What's New in LTXVideo 13B 0.9.7 27 | 28 | 1. **LTXV 13B 0.9.7** 29 | Delivers cinematic-quality videos at unprecedented speed.
30 | 👉 [Download here](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev.safetensors) 31 | 32 | 2. **LTXV 13B Quantized 0.9.7** 33 | Offers reduced memory requirements and even faster inference speeds. 34 | Ideal for consumer-grade GPUs (e.g., NVIDIA 4090, 5090). 35 | Delivers outstanding quality with improved performance.
36 | ***Important:*** To run the quantized version, please install the [LTXVideo-Q8-Kernels](https://github.com/Lightricks/LTXVideo-Q8-Kernels) package and use the dedicated flow below. Loading the model in ComfyUI with the LoadCheckpoint node won't work.<br>
37 | 👉 [Download here](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev-fp8.safetensors)
38 | 🧩 Example ComfyUI flow available in the [Example Workflows](#example-workflows) section. 39 | 40 | 3. **Latent Upscaling Models** 41 | Enables inference across multiple scales by upscaling latent tensors without decoding/encoding. 42 | Multiscale inference delivers high-quality results in a fraction of the time compared to similar models.
43 | ***Important:*** Make sure you put the models below in the **models/upscale_models** folder.<br>
44 | 👉 Spatial upscaling: [Download here](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-spatial-upscaler-0.9.7.safetensors).
45 | 👉 Temporal upscaling: [Download here](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-temporal-upscaler-0.9.7.safetensors).
46 | 🧩 Example ComfyUI flow available in the [Example Workflows](#example-workflows) section. 47 | 48 | 49 | ### Technical Updates 50 | 51 | 1. ***New simplified flows and nodes***
52 | 1.1. Simplified image to video: [Download here](example_workflows/ltxv-13b-i2v-base.json).
53 | 1.2. Simplified image to video with extension: [Download here](example_workflows/ltxv-13b-i2v-extend.json).
54 | 1.3. Simplified image to video with keyframes: [Download here](example_workflows/ltxv-13b-i2v-keyframes.json).
55 | 56 | # 17.04.2025 ⭐ LTXVideo 0.9.6 Release ⭐ 57 | 58 | ### LTXVideo 0.9.6 introduces: 59 | 60 | 1. LTXV 0.9.6 – higher quality, faster, great for final output. Download from [here](https://huggingface.co/Lightricks/LTX-Video/resolve/main/ltxv-2b-0.9.6-dev-04-25.safetensors). 61 | 2. LTXV 0.9.6 Distilled – our fastest model yet (only 8 steps for generation), lighter, great for rapid iteration. Download from [here](https://huggingface.co/Lightricks/LTX-Video/resolve/main/ltxv-2b-0.9.6-distilled-04-25.safetensors). 62 | 63 | ### Technical Updates 64 | 65 | We introduce the __STGGuiderAdvanced__ node, which applies different CFG and STG parameters at various diffusion steps. All flows have been updated to use this node and are designed to provide optimal parameters for the best quality. 66 | See the [Example Workflows](#example-workflows) section. 67 | 68 | # 5.03.2025 ⭐ LTXVideo 0.9.5 Release ⭐ 69 | 70 | ### LTXVideo 0.9.5 introduces: 71 | 72 | 1. Improved quality with reduced artifacts. 73 | 2. Support for higher resolution and longer sequences. 74 | 3. Frame and sequence conditioning (beyond the first frame). 75 | 4. Enhanced prompt understanding. 76 | 5. Commercial license availability. 77 | 78 | ### Technical Updates 79 | 80 | Since LTXVideo is now fully supported in the ComfyUI core, we have removed the custom model implementation. Instead, we provide updated workflows to showcase the new features: 81 | 82 | 1. **Frame Conditioning** – Enables interpolation between given frames. 83 | 2. **Sequence Conditioning** – Allows motion interpolation from a given frame sequence, enabling video extension from the beginning, end, or middle of the original video. 84 | 3. **Prompt Enhancer** – A new node that helps generate prompts optimized for the best model performance. 85 | See the [Example Workflows](#example-workflows) section for more details. 86 | 87 | ### LTXTricks Update 88 | 89 | The LTXTricks code has been integrated into this repository (in the `/tricks` folder) and will be maintained here. The original [repo](https://github.com/logtd/ComfyUI-LTXTricks) is no longer maintained, but all existing workflows should continue to function as expected. 90 | 91 | ## 22.12.2024 92 | 93 | Fixed a bug that caused the model to produce artifacts on short negative prompts when using the native CLIP Loader node. 94 | 95 | ## 19.12.2024 ⭐ Update ⭐ 96 | 97 | 1. Improved model – removes "strobing texture" artifacts and generates better motion. Download from [here](https://huggingface.co/Lightricks/LTX-Video/resolve/main/ltx-video-2b-v0.9.1.safetensors). 98 | 2. STG support. 99 | 3. Integrated image degradation system for improved motion generation. 100 | 4. Additional optional initial latent input to chain latents for high-res generation. 101 | 5. Image captioning in the image-to-video [flow](example_workflows/ltxvideo-i2v.json). 102 | 103 | ## Installation 104 | 105 | Installation via [ComfyUI-Manager](https://github.com/ltdrdata/ComfyUI-Manager) is preferred. Simply search for `ComfyUI-LTXVideo` in the list of nodes and follow the installation instructions. 106 | 107 | ### Manual installation 108 | 109 | 1. Install ComfyUI. 110 | 2. Clone this repository to the `custom_nodes` folder in your ComfyUI installation directory. 111 | 3. 
Install the required packages: 112 | 113 | ```bash 114 | cd custom_nodes/ComfyUI-LTXVideo && pip install -r requirements.txt 115 | ``` 116 | 117 | For portable ComfyUI installations, run: 118 | 119 | ``` 120 | .\python_embeded\python.exe -m pip install -r .\ComfyUI\custom_nodes\ComfyUI-LTXVideo\requirements.txt 121 | ``` 122 | 123 | ### Models 124 | 125 | 1. Download [ltx-video-2b-v0.9.1.safetensors](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) from Hugging Face and place it under `models/checkpoints`. 126 | 2. Install one of the T5 text encoders, for example [google_t5-v1_1-xxl_encoderonly](https://huggingface.co/mcmonkey/google_t5-v1_1-xxl_encoderonly/tree/main). You can install it using ComfyUI Model Manager. 127 | 128 | ## Example workflows 129 | 130 | Note that to run the example workflows, you need to have some additional custom nodes, like [ComfyUI-VideoHelperSuite](https://github.com/kosinkadink/ComfyUI-VideoHelperSuite) and others, installed. You can install them by pressing the "Install Missing Custom Nodes" button in ComfyUI Manager (a manual install sketch is shown below). 131 | 132 | ### Easy-to-use multi-scale generation workflows 133 | 134 | 🧩 [Image to video mixed](example_workflows/ltxv-13b-i2v-mixed-multiscale.json): a mixed flow combining the full and distilled models for the best quality/speed trade-off.<br>
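The sketch below is a hypothetical manual alternative to the ComfyUI-Manager route mentioned above, for one of the extra node packs the example workflows rely on. It assumes a standard `custom_nodes` layout and that the pack ships a `requirements.txt` (as ComfyUI-VideoHelperSuite does):

```bash
# Manual install of an extra node pack used by the example workflows.
# Run from your ComfyUI installation directory; repeat for any other pack
# a workflow reports as missing, then restart ComfyUI so the nodes register.
cd custom_nodes
git clone https://github.com/kosinkadink/ComfyUI-VideoHelperSuite
pip install -r ComfyUI-VideoHelperSuite/requirements.txt
```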
135 | 136 | ### 13B model
137 | 🧩 [Image to video](example_workflows/ltxv-13b-i2v-base.json)
138 | 🧩 [Image to video with keyframes](example_workflows/ltxv-13b-i2v-keyframes.json)
139 | 🧩 [Image to video with duration extension](example_workflows/ltxv-13b-i2v-extend.json)
140 | 🧩 [Image to video 8-bit quantized](example_workflows/ltxv-13b-i2v-base-fp8.json) 141 | 142 | ### 13B distilled model<br>
143 | 🧩 [Image to video](example_workflows/13b-distilled/ltxv-13b-dist-i2v-base.json)
144 | 🧩 [Image to video with keyframes](example_workflows/13b-distilled/ltxv-13b-dist-i2v-keyframes.json)
145 | 🧩 [Image to video with duration extension](example_workflows/13b-distilled/ltxv-13b-dist-i2v-extend.json)
146 | 🧩 [Image to video 8-bit quantized](example_workflows/13b-distilled/ltxv-13b-dist-i2v-base-fp8.json) 147 | 148 | ### Inversion 149 | 150 | #### Flow Edit 151 | 152 | 🧩 [Download workflow](example_workflows/tricks/ltxvideo-flow-edit.json)<br>
153 | ![workflow](example_workflows/tricks/ltxvideo-flow-edit.png) 154 | 155 | #### RF Edit 156 | 157 | 🧩 [Download workflow](example_workflows/tricks/ltxvideo-rf-edit.json)
158 | ![workflow](example_workflows/tricks/ltxvideo-rf-edit.png) 159 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .decoder_noise import DecoderNoise 2 | from .easy_samplers import LTXVBaseSampler 3 | from .film_grain import LTXVFilmGrain 4 | from .guide import LTXVAddGuideAdvanced 5 | from .latent_adain import LTXVAdainLatent 6 | from .latent_upsampler import LTXVLatentUpsampler 7 | from .latents import LTXVSelectLatents, LTXVSetVideoLatentNoiseMasks 8 | from .nodes_registry import NODE_CLASS_MAPPINGS as RUNTIME_NODE_CLASS_MAPPINGS 9 | from .nodes_registry import ( 10 | NODE_DISPLAY_NAME_MAPPINGS as RUNTIME_NODE_DISPLAY_NAME_MAPPINGS, 11 | ) 12 | from .nodes_registry import NODES_DISPLAY_NAME_PREFIX, camel_case_to_spaces 13 | from .prompt_enhancer_nodes import LTXVPromptEnhancer, LTXVPromptEnhancerLoader 14 | from .q8_nodes import LTXVQ8Patch 15 | from .recurrent_sampler import LinearOverlapLatentTransition, LTXVRecurrentKSampler 16 | from .stg import ( 17 | LTXVApplySTG, 18 | STGAdvancedPresetsNode, 19 | STGGuiderAdvancedNode, 20 | STGGuiderNode, 21 | ) 22 | from .tiled_sampler import LTXVTiledSampler 23 | from .tricks import NODE_CLASS_MAPPINGS as TRICKS_NODE_CLASS_MAPPINGS 24 | from .tricks import NODE_DISPLAY_NAME_MAPPINGS as TRICKS_NODE_DISPLAY_NAME_MAPPINGS 25 | 26 | # Static node mappings, required for ComfyUI-Manager mapping to work 27 | NODE_CLASS_MAPPINGS = { 28 | "Set VAE Decoder Noise": DecoderNoise, 29 | "LinearOverlapLatentTransition": LinearOverlapLatentTransition, 30 | "LTXVAddGuideAdvanced": LTXVAddGuideAdvanced, 31 | "LTXVAdainLatent": LTXVAdainLatent, 32 | "LTXVApplySTG": LTXVApplySTG, 33 | "LTXVBaseSampler": LTXVBaseSampler, 34 | "LTXVFilmGrain": LTXVFilmGrain, 35 | "LTXVLatentUpsampler": LTXVLatentUpsampler, 36 | "LTXVPromptEnhancer": LTXVPromptEnhancer, 37 | "LTXVPromptEnhancerLoader": LTXVPromptEnhancerLoader, 38 | "LTXQ8Patch": LTXVQ8Patch, 39 | "LTXVRecurrentKSampler": LTXVRecurrentKSampler, 40 | "LTXVSelectLatents": LTXVSelectLatents, 41 | "LTXVSetVideoLatentNoiseMasks": LTXVSetVideoLatentNoiseMasks, 42 | "LTXVTiledSampler": LTXVTiledSampler, 43 | "STGAdvancedPresets": STGAdvancedPresetsNode, 44 | "STGGuiderAdvanced": STGGuiderAdvancedNode, 45 | "STGGuiderNode": STGGuiderNode, 46 | } 47 | 48 | # Consistent display names between static and dynamic node mappings in nodes_registry.py, 49 | # to prevent ComfyUI initializing them with default display names. 50 | NODE_DISPLAY_NAME_MAPPINGS = { 51 | name: f"{NODES_DISPLAY_NAME_PREFIX} {camel_case_to_spaces(name)}" 52 | for name in NODE_CLASS_MAPPINGS.keys() 53 | } 54 | 55 | # Merge the node mappings from tricks into the main mappings 56 | NODE_CLASS_MAPPINGS.update(TRICKS_NODE_CLASS_MAPPINGS) 57 | NODE_DISPLAY_NAME_MAPPINGS.update(TRICKS_NODE_DISPLAY_NAME_MAPPINGS) 58 | 59 | # Update with runtime mappings (these will override static mappings if there are any differences) 60 | NODE_CLASS_MAPPINGS.update(RUNTIME_NODE_CLASS_MAPPINGS) 61 | NODE_DISPLAY_NAME_MAPPINGS.update(RUNTIME_NODE_DISPLAY_NAME_MAPPINGS) 62 | 63 | # Export so that ComfyUI can pick them up. 
64 | __all__ = [ 65 | "NODE_CLASS_MAPPINGS", 66 | "NODE_DISPLAY_NAME_MAPPINGS", 67 | ] 68 | -------------------------------------------------------------------------------- /decoder_noise.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | 3 | from .nodes_registry import comfy_node 4 | 5 | 6 | @comfy_node(name="Set VAE Decoder Noise") 7 | class DecoderNoise: 8 | @classmethod 9 | def INPUT_TYPES(cls): 10 | return { 11 | "required": { 12 | "vae": ("VAE",), 13 | "timestep": ( 14 | "FLOAT", 15 | { 16 | "default": 0.05, 17 | "min": 0.0, 18 | "max": 1.0, 19 | "step": 0.001, 20 | "tooltip": "The timestep used for decoding the noise.", 21 | }, 22 | ), 23 | "scale": ( 24 | "FLOAT", 25 | { 26 | "default": 0.025, 27 | "min": 0.0, 28 | "max": 1.0, 29 | "step": 0.001, 30 | "tooltip": "The scale of the noise added to the decoder.", 31 | }, 32 | ), 33 | "seed": ( 34 | "INT", 35 | { 36 | "default": 42, 37 | "min": 0, 38 | "max": 0xFFFFFFFFFFFFFFFF, 39 | "tooltip": "The random seed used for creating the noise.", 40 | }, 41 | ), 42 | } 43 | } 44 | 45 | FUNCTION = "add_noise" 46 | RETURN_TYPES = ("VAE",) 47 | CATEGORY = "lightricks/LTXV" 48 | 49 | def add_noise(self, vae, timestep, scale, seed): 50 | result = copy(vae) 51 | if hasattr(result, "first_stage_model"): 52 | result.first_stage_model.decode_timestep = timestep 53 | result.first_stage_model.decode_noise_scale = scale 54 | result._decode_timestep = timestep 55 | result.decode_noise_scale = scale 56 | result.seed = seed 57 | return (result,) 58 | -------------------------------------------------------------------------------- /easy_samplers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import comfy 4 | import comfy_extras 5 | import nodes 6 | from comfy_extras.nodes_custom_sampler import SamplerCustomAdvanced 7 | from comfy_extras.nodes_lt import ( 8 | EmptyLTXVLatentVideo, 9 | LTXVAddGuide, 10 | LTXVCropGuides, 11 | LTXVImgToVideo, 12 | ) 13 | 14 | from .guide import blur_internal 15 | from .latents import LTXVAddLatentGuide, LTXVSelectLatents 16 | from .nodes_registry import comfy_node 17 | from .recurrent_sampler import LinearOverlapLatentTransition 18 | 19 | 20 | @comfy_node( 21 | name="LTXVBaseSampler", 22 | ) 23 | class LTXVBaseSampler: 24 | 25 | @classmethod 26 | def INPUT_TYPES(s): 27 | return { 28 | "required": { 29 | "model": ("MODEL",), 30 | "vae": ("VAE",), 31 | "width": ( 32 | "INT", 33 | { 34 | "default": 768, 35 | "min": 64, 36 | "max": nodes.MAX_RESOLUTION, 37 | "step": 32, 38 | }, 39 | ), 40 | "height": ( 41 | "INT", 42 | { 43 | "default": 512, 44 | "min": 64, 45 | "max": nodes.MAX_RESOLUTION, 46 | "step": 32, 47 | }, 48 | ), 49 | "num_frames": ( 50 | "INT", 51 | {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}, 52 | ), 53 | "guider": ("GUIDER",), 54 | "sampler": ("SAMPLER",), 55 | "sigmas": ("SIGMAS",), 56 | "noise": ("NOISE",), 57 | }, 58 | "optional": { 59 | "optional_cond_images": ("IMAGE",), 60 | "optional_cond_indices": ("STRING",), 61 | "strength": ("FLOAT", {"default": 0.9, "min": 0, "max": 1}), 62 | "crop": (["center", "disabled"], {"default": "disabled"}), 63 | "crf": ("INT", {"default": 35, "min": 0, "max": 100}), 64 | "blur": ("INT", {"default": 0, "min": 0, "max": 10}), 65 | }, 66 | } 67 | 68 | RETURN_TYPES = ("LATENT",) 69 | RETURN_NAMES = ("denoised_output",) 70 | FUNCTION = "sample" 71 | CATEGORY = "sampling" 72 | 73 | def sample( 74 | self, 75 | model, 76 | vae, 77 | width, 78 
| height, 79 | num_frames, 80 | guider, 81 | sampler, 82 | sigmas, 83 | noise, 84 | optional_cond_images=None, 85 | optional_cond_indices=None, 86 | strength=0.9, 87 | crop="disabled", 88 | crf=35, 89 | blur=0, 90 | ): 91 | 92 | if optional_cond_images is not None: 93 | optional_cond_images = ( 94 | comfy.utils.common_upscale( 95 | optional_cond_images.movedim(-1, 1), 96 | width, 97 | height, 98 | "bilinear", 99 | crop=crop, 100 | ) 101 | .movedim(1, -1) 102 | .clamp(0, 1) 103 | ) 104 | print("optional_cond_images shape", optional_cond_images.shape) 105 | optional_cond_images = comfy_extras.nodes_lt.LTXVPreprocess().preprocess( 106 | optional_cond_images, crf 107 | )[0] 108 | for i in range(optional_cond_images.shape[0]): 109 | optional_cond_images[i] = blur_internal( 110 | optional_cond_images[i].unsqueeze(0), blur 111 | ) 112 | 113 | if optional_cond_indices is not None and optional_cond_images is not None: 114 | optional_cond_indices = optional_cond_indices.split(",") 115 | optional_cond_indices = [int(i) for i in optional_cond_indices] 116 | assert len(optional_cond_indices) == len( 117 | optional_cond_images 118 | ), "Number of optional cond images must match number of optional cond indices" 119 | 120 | try: 121 | positive, negative = guider.raw_conds 122 | except AttributeError: 123 | raise ValueError( 124 | "Guider does not have raw conds, cannot use it as a guider. " 125 | "Please use STGGuiderAdvanced." 126 | ) 127 | 128 | if optional_cond_images is None: 129 | (latents,) = EmptyLTXVLatentVideo().generate(width, height, num_frames, 1) 130 | elif optional_cond_images.shape[0] == 1 and optional_cond_indices[0] == 0: 131 | ( 132 | positive, 133 | negative, 134 | latents, 135 | ) = LTXVImgToVideo().generate( 136 | positive=positive, 137 | negative=negative, 138 | vae=vae, 139 | image=optional_cond_images[0].unsqueeze(0), 140 | width=width, 141 | height=height, 142 | length=num_frames, 143 | batch_size=1, 144 | strength=strength, 145 | ) 146 | else: 147 | (latents,) = EmptyLTXVLatentVideo().generate(width, height, num_frames, 1) 148 | for cond_image, cond_idx in zip( 149 | optional_cond_images, optional_cond_indices 150 | ): 151 | ( 152 | positive, 153 | negative, 154 | latents, 155 | ) = LTXVAddGuide().generate( 156 | positive=positive, 157 | negative=negative, 158 | vae=vae, 159 | latent=latents, 160 | image=cond_image.unsqueeze(0), 161 | frame_idx=cond_idx, 162 | strength=strength, 163 | ) 164 | 165 | guider = copy.copy(guider) 166 | guider.set_conds(positive, negative) 167 | 168 | # Denoise the latent video 169 | (output_latents, denoised_output_latents) = SamplerCustomAdvanced().sample( 170 | noise=noise, 171 | guider=guider, 172 | sampler=sampler, 173 | sigmas=sigmas, 174 | latent_image=latents, 175 | ) 176 | 177 | # Clean up guides if image conditioning was used 178 | print("before guide crop", denoised_output_latents["samples"].shape) 179 | positive, negative, denoised_output_latents = LTXVCropGuides().crop( 180 | positive=positive, 181 | negative=negative, 182 | latent=denoised_output_latents, 183 | ) 184 | print("after guide crop", denoised_output_latents["samples"].shape) 185 | 186 | return (denoised_output_latents,) 187 | 188 | 189 | @comfy_node( 190 | name="LTXVExtendSampler", 191 | ) 192 | class LTXVExtendSampler: 193 | 194 | @classmethod 195 | def INPUT_TYPES(s): 196 | return { 197 | "required": { 198 | "model": ("MODEL",), 199 | "vae": ("VAE",), 200 | "latents": ("LATENT",), 201 | "num_new_frames": ( 202 | "INT", 203 | {"default": 80, "min": 8, "max": 
nodes.MAX_RESOLUTION, "step": 8}, 204 | ), 205 | "frame_overlap": ( 206 | "INT", 207 | {"default": 16, "min": 16, "max": 128, "step": 8}, 208 | ), 209 | "guider": ("GUIDER",), 210 | "sampler": ("SAMPLER",), 211 | "sigmas": ("SIGMAS",), 212 | "noise": ("NOISE",), 213 | }, 214 | } 215 | 216 | RETURN_TYPES = ("LATENT",) 217 | RETURN_NAMES = ("denoised_output",) 218 | FUNCTION = "sample" 219 | CATEGORY = "sampling" 220 | 221 | def sample( 222 | self, 223 | model, 224 | vae, 225 | latents, 226 | num_new_frames, 227 | frame_overlap, 228 | guider, 229 | sampler, 230 | sigmas, 231 | noise, 232 | ): 233 | 234 | try: 235 | positive, negative = guider.raw_conds 236 | except AttributeError: 237 | raise ValueError( 238 | "Guider does not have raw conds, cannot use it as a guider. " 239 | "Please use STGGuiderAdvanced." 240 | ) 241 | 242 | samples = latents["samples"] 243 | batch, channels, frames, height, width = samples.shape 244 | time_scale_factor, width_scale_factor, height_scale_factor = ( 245 | vae.downscale_index_formula 246 | ) 247 | overlap = frame_overlap // time_scale_factor 248 | 249 | (last_overlap_latents,) = LTXVSelectLatents().select_latents( 250 | latents, -overlap, -1 251 | ) 252 | 253 | new_latents = EmptyLTXVLatentVideo().generate( 254 | width=width * width_scale_factor, 255 | height=height * height_scale_factor, 256 | length=overlap * time_scale_factor + num_new_frames, 257 | batch_size=1, 258 | )[0] 259 | print("new_latents shape: ", new_latents["samples"].shape) 260 | ( 261 | positive, 262 | negative, 263 | new_latents, 264 | ) = LTXVAddLatentGuide().generate( 265 | vae=vae, 266 | positive=positive, 267 | negative=negative, 268 | latent=new_latents, 269 | guiding_latent=last_overlap_latents, 270 | latent_idx=0, 271 | strength=1.0, 272 | ) 273 | 274 | guider = copy.copy(guider) 275 | guider.set_conds(positive, negative) 276 | 277 | # Denoise the latent video 278 | (output_latents, denoised_output_latents) = SamplerCustomAdvanced().sample( 279 | noise=noise, 280 | guider=guider, 281 | sampler=sampler, 282 | sigmas=sigmas, 283 | latent_image=new_latents, 284 | ) 285 | 286 | # Clean up guides if image conditioning was used 287 | print("before guide crop", denoised_output_latents["samples"].shape) 288 | positive, negative, denoised_output_latents = LTXVCropGuides().crop( 289 | positive=positive, 290 | negative=negative, 291 | latent=denoised_output_latents, 292 | ) 293 | print("after guide crop", denoised_output_latents["samples"].shape) 294 | 295 | # drop first output latent as it's a reinterpreted 8-frame latent understood as a 1-frame latent 296 | truncated_denoised_output_latents = LTXVSelectLatents().select_latents( 297 | denoised_output_latents, 1, -1 298 | )[0] 299 | # Fuse new frames with old ones by calling LinearOverlapLatentTransition 300 | (latents,) = LinearOverlapLatentTransition().process( 301 | latents, truncated_denoised_output_latents, overlap - 1, axis=2 302 | ) 303 | print("latents shape after linear overlap blend: ", latents["samples"].shape) 304 | return (latents,) 305 | -------------------------------------------------------------------------------- /example_workflows/low_level/end.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightricks/ComfyUI-LTXVideo/6e9e6de05624b0aab09b81a2f4a5f473fa97988a/example_workflows/low_level/end.jpg -------------------------------------------------------------------------------- /example_workflows/low_level/fox.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightricks/ComfyUI-LTXVideo/6e9e6de05624b0aab09b81a2f4a5f473fa97988a/example_workflows/low_level/fox.jpg -------------------------------------------------------------------------------- /example_workflows/low_level/jeep.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5b8ac2148bdf1092f4342eb4f08961458a9e22743e8896f5dbcf749166f4f526 3 | size 1582922 4 | -------------------------------------------------------------------------------- /example_workflows/low_level/ltxvideo-i2v-distilled.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "d8be2e45-0fd2-4f4a-8065-469fdf4bc018", 3 | "revision": 0, 4 | "last_node_id": 109, 5 | "last_link_id": 268, 6 | "nodes": [ 7 | { 8 | "id": 38, 9 | "type": "CLIPLoader", 10 | "pos": [ 11 | 60, 12 | 190 13 | ], 14 | "size": [ 15 | 315, 16 | 98 17 | ], 18 | "flags": {}, 19 | "order": 0, 20 | "mode": 0, 21 | "inputs": [], 22 | "outputs": [ 23 | { 24 | "name": "CLIP", 25 | "type": "CLIP", 26 | "slot_index": 0, 27 | "links": [ 28 | 74, 29 | 75 30 | ] 31 | } 32 | ], 33 | "properties": { 34 | "cnr_id": "comfy-core", 35 | "ver": "0.3.28", 36 | "Node name for S&R": "CLIPLoader" 37 | }, 38 | "widgets_values": [ 39 | "t5xxl_fp16.safetensors", 40 | "ltxv", 41 | "default" 42 | ] 43 | }, 44 | { 45 | "id": 8, 46 | "type": "VAEDecode", 47 | "pos": [ 48 | 1800, 49 | 430 50 | ], 51 | "size": [ 52 | 210, 53 | 46 54 | ], 55 | "flags": {}, 56 | "order": 16, 57 | "mode": 0, 58 | "inputs": [ 59 | { 60 | "name": "samples", 61 | "type": "LATENT", 62 | "link": 255 63 | }, 64 | { 65 | "name": "vae", 66 | "type": "VAE", 67 | "link": 87 68 | } 69 | ], 70 | "outputs": [ 71 | { 72 | "name": "IMAGE", 73 | "type": "IMAGE", 74 | "slot_index": 0, 75 | "links": [ 76 | 261 77 | ] 78 | } 79 | ], 80 | "properties": { 81 | "cnr_id": "comfy-core", 82 | "ver": "0.3.28", 83 | "Node name for S&R": "VAEDecode" 84 | }, 85 | "widgets_values": [] 86 | }, 87 | { 88 | "id": 103, 89 | "type": "VHS_VideoCombine", 90 | "pos": [ 91 | 1789.6087646484375, 92 | 548.7745361328125 93 | ], 94 | "size": [ 95 | 315, 96 | 545 97 | ], 98 | "flags": {}, 99 | "order": 17, 100 | "mode": 0, 101 | "inputs": [ 102 | { 103 | "name": "images", 104 | "shape": 7, 105 | "type": "IMAGE", 106 | "link": 261 107 | }, 108 | { 109 | "name": "audio", 110 | "shape": 7, 111 | "type": "AUDIO", 112 | "link": null 113 | }, 114 | { 115 | "name": "meta_batch", 116 | "shape": 7, 117 | "type": "VHS_BatchManager", 118 | "link": null 119 | }, 120 | { 121 | "name": "vae", 122 | "shape": 7, 123 | "type": "VAE", 124 | "link": null 125 | } 126 | ], 127 | "outputs": [ 128 | { 129 | "name": "Filenames", 130 | "type": "VHS_FILENAMES", 131 | "links": null 132 | } 133 | ], 134 | "properties": { 135 | "cnr_id": "comfyui-videohelpersuite", 136 | "ver": "972c87da577b47211c4e9aeed30dc38c7bae607f", 137 | "Node name for S&R": "VHS_VideoCombine" 138 | }, 139 | "widgets_values": { 140 | "frame_rate": 24, 141 | "loop_count": 0, 142 | "filename_prefix": "ltxv", 143 | "format": "video/h264-mp4", 144 | "pix_fmt": "yuv420p", 145 | "crf": 20, 146 | "save_metadata": false, 147 | "trim_to_audio": false, 148 | "pingpong": false, 149 | "save_output": true, 150 | "videopreview": { 151 | "hidden": false, 152 | "paused": false, 153 | "params": { 154 | "filename": "ltxv_00058.mp4", 155 | "subfolder": "", 156 | "type": 
"output", 157 | "format": "video/h264-mp4", 158 | "frame_rate": 24, 159 | "workflow": "ltxv_00058.png", 160 | "fullpath": "C:\\Users\\Zeev\\Projects\\ComfyUI\\output\\ltxv_00058.mp4" 161 | } 162 | } 163 | } 164 | }, 165 | { 166 | "id": 82, 167 | "type": "LTXVPreprocess", 168 | "pos": [ 169 | 488.92791748046875, 170 | 629.9364624023438 171 | ], 172 | "size": [ 173 | 275.9266662597656, 174 | 58 175 | ], 176 | "flags": {}, 177 | "order": 10, 178 | "mode": 0, 179 | "inputs": [ 180 | { 181 | "name": "image", 182 | "type": "IMAGE", 183 | "link": 226 184 | } 185 | ], 186 | "outputs": [ 187 | { 188 | "name": "output_image", 189 | "type": "IMAGE", 190 | "slot_index": 0, 191 | "links": [ 192 | 248 193 | ] 194 | } 195 | ], 196 | "properties": { 197 | "cnr_id": "comfy-core", 198 | "ver": "0.3.28", 199 | "Node name for S&R": "LTXVPreprocess" 200 | }, 201 | "widgets_values": [ 202 | 38 203 | ] 204 | }, 205 | { 206 | "id": 7, 207 | "type": "CLIPTextEncode", 208 | "pos": [ 209 | 420, 210 | 390 211 | ], 212 | "size": [ 213 | 425.27801513671875, 214 | 180.6060791015625 215 | ], 216 | "flags": {}, 217 | "order": 9, 218 | "mode": 0, 219 | "inputs": [ 220 | { 221 | "name": "clip", 222 | "type": "CLIP", 223 | "link": 75 224 | } 225 | ], 226 | "outputs": [ 227 | { 228 | "name": "CONDITIONING", 229 | "type": "CONDITIONING", 230 | "slot_index": 0, 231 | "links": [ 232 | 240 233 | ] 234 | } 235 | ], 236 | "title": "CLIP Text Encode (Negative Prompt)", 237 | "properties": { 238 | "cnr_id": "comfy-core", 239 | "ver": "0.3.28", 240 | "Node name for S&R": "CLIPTextEncode" 241 | }, 242 | "widgets_values": [ 243 | "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly" 244 | ], 245 | "color": "#322", 246 | "bgcolor": "#533" 247 | }, 248 | { 249 | "id": 78, 250 | "type": "LoadImage", 251 | "pos": [ 252 | 40, 253 | 630 254 | ], 255 | "size": [ 256 | 385.15606689453125, 257 | 333.3305358886719 258 | ], 259 | "flags": {}, 260 | "order": 1, 261 | "mode": 0, 262 | "inputs": [], 263 | "outputs": [ 264 | { 265 | "name": "IMAGE", 266 | "type": "IMAGE", 267 | "slot_index": 0, 268 | "links": [ 269 | 226 270 | ] 271 | }, 272 | { 273 | "name": "MASK", 274 | "type": "MASK", 275 | "links": null 276 | } 277 | ], 278 | "properties": { 279 | "cnr_id": "comfy-core", 280 | "ver": "0.3.28", 281 | "Node name for S&R": "LoadImage" 282 | }, 283 | "widgets_values": [ 284 | "fox.jpg", 285 | "image", 286 | "" 287 | ] 288 | }, 289 | { 290 | "id": 6, 291 | "type": "CLIPTextEncode", 292 | "pos": [ 293 | 420, 294 | 180 295 | ], 296 | "size": [ 297 | 422.84503173828125, 298 | 164.31304931640625 299 | ], 300 | "flags": {}, 301 | "order": 8, 302 | "mode": 0, 303 | "inputs": [ 304 | { 305 | "name": "clip", 306 | "type": "CLIP", 307 | "link": 74 308 | } 309 | ], 310 | "outputs": [ 311 | { 312 | "name": "CONDITIONING", 313 | "type": "CONDITIONING", 314 | "slot_index": 0, 315 | "links": [ 316 | 239 317 | ] 318 | } 319 | ], 320 | "title": "CLIP Text Encode (Positive Prompt)", 321 | "properties": { 322 | "cnr_id": "comfy-core", 323 | "ver": "0.3.28", 324 | "Node name for S&R": "CLIPTextEncode" 325 | }, 326 | "widgets_values": [ 327 | "A red fox moving gracefully, its russet coat vibrant against the white landscape, leaving perfect star-shaped prints behind as steam rises from its breath in the crisp winter air. The scene is wrapped in snow-muffled silence, broken only by the gentle murmur of water still flowing beneath the ice." 
328 | ], 329 | "color": "#232", 330 | "bgcolor": "#353" 331 | }, 332 | { 333 | "id": 76, 334 | "type": "Note", 335 | "pos": [ 336 | 36.85359573364258, 337 | 360.48809814453125 338 | ], 339 | "size": [ 340 | 360, 341 | 200 342 | ], 343 | "flags": {}, 344 | "order": 2, 345 | "mode": 0, 346 | "inputs": [], 347 | "outputs": [], 348 | "properties": {}, 349 | "widgets_values": [ 350 | "While LTXV-2b model prefers long descriptive prompts, this version supports experimentation with broader prompting styles." 351 | ], 352 | "color": "#432", 353 | "bgcolor": "#653" 354 | }, 355 | { 356 | "id": 69, 357 | "type": "LTXVConditioning", 358 | "pos": [ 359 | 1183.358642578125, 360 | 292.5882873535156 361 | ], 362 | "size": [ 363 | 223.8660125732422, 364 | 78 365 | ], 366 | "flags": {}, 367 | "order": 13, 368 | "mode": 0, 369 | "inputs": [ 370 | { 371 | "name": "positive", 372 | "type": "CONDITIONING", 373 | "link": 245 374 | }, 375 | { 376 | "name": "negative", 377 | "type": "CONDITIONING", 378 | "link": 246 379 | } 380 | ], 381 | "outputs": [ 382 | { 383 | "name": "positive", 384 | "type": "CONDITIONING", 385 | "slot_index": 0, 386 | "links": [ 387 | 256 388 | ] 389 | }, 390 | { 391 | "name": "negative", 392 | "type": "CONDITIONING", 393 | "slot_index": 1, 394 | "links": [ 395 | 257 396 | ] 397 | } 398 | ], 399 | "properties": { 400 | "cnr_id": "comfy-core", 401 | "ver": "0.3.28", 402 | "Node name for S&R": "LTXVConditioning" 403 | }, 404 | "widgets_values": [ 405 | 24.000000000000004 406 | ] 407 | }, 408 | { 409 | "id": 101, 410 | "type": "SamplerCustomAdvanced", 411 | "pos": [ 412 | 1780, 413 | 270 414 | ], 415 | "size": [ 416 | 236.8000030517578, 417 | 106 418 | ], 419 | "flags": {}, 420 | "order": 15, 421 | "mode": 0, 422 | "inputs": [ 423 | { 424 | "name": "noise", 425 | "type": "NOISE", 426 | "link": 260 427 | }, 428 | { 429 | "name": "guider", 430 | "type": "GUIDER", 431 | "link": 252 432 | }, 433 | { 434 | "name": "sampler", 435 | "type": "SAMPLER", 436 | "link": 253 437 | }, 438 | { 439 | "name": "sigmas", 440 | "type": "SIGMAS", 441 | "link": 268 442 | }, 443 | { 444 | "name": "latent_image", 445 | "type": "LATENT", 446 | "link": 259 447 | } 448 | ], 449 | "outputs": [ 450 | { 451 | "name": "output", 452 | "type": "LATENT", 453 | "slot_index": 0, 454 | "links": [] 455 | }, 456 | { 457 | "name": "denoised_output", 458 | "type": "LATENT", 459 | "slot_index": 1, 460 | "links": [ 461 | 255 462 | ] 463 | } 464 | ], 465 | "properties": { 466 | "cnr_id": "comfy-core", 467 | "ver": "0.3.15", 468 | "Node name for S&R": "SamplerCustomAdvanced" 469 | }, 470 | "widgets_values": [] 471 | }, 472 | { 473 | "id": 95, 474 | "type": "LTXVImgToVideo", 475 | "pos": [ 476 | 900, 477 | 290 478 | ], 479 | "size": [ 480 | 210, 481 | 190 482 | ], 483 | "flags": {}, 484 | "order": 12, 485 | "mode": 0, 486 | "inputs": [ 487 | { 488 | "name": "positive", 489 | "type": "CONDITIONING", 490 | "link": 239 491 | }, 492 | { 493 | "name": "negative", 494 | "type": "CONDITIONING", 495 | "link": 240 496 | }, 497 | { 498 | "name": "vae", 499 | "type": "VAE", 500 | "link": 250 501 | }, 502 | { 503 | "name": "image", 504 | "type": "IMAGE", 505 | "link": 248 506 | } 507 | ], 508 | "outputs": [ 509 | { 510 | "name": "positive", 511 | "type": "CONDITIONING", 512 | "slot_index": 0, 513 | "links": [ 514 | 245 515 | ] 516 | }, 517 | { 518 | "name": "negative", 519 | "type": "CONDITIONING", 520 | "slot_index": 1, 521 | "links": [ 522 | 246 523 | ] 524 | }, 525 | { 526 | "name": "latent", 527 | "type": "LATENT", 528 | "slot_index": 2, 529 | 
"links": [ 530 | 259 531 | ] 532 | } 533 | ], 534 | "properties": { 535 | "cnr_id": "comfy-core", 536 | "ver": "0.3.28", 537 | "Node name for S&R": "LTXVImgToVideo" 538 | }, 539 | "widgets_values": [ 540 | 768, 541 | 512, 542 | 97, 543 | 1 544 | ] 545 | }, 546 | { 547 | "id": 100, 548 | "type": "StringToFloatList", 549 | "pos": [ 550 | 1116.0089111328125, 551 | 604.5989379882812 552 | ], 553 | "size": [ 554 | 395.74224853515625, 555 | 88 556 | ], 557 | "flags": {}, 558 | "order": 3, 559 | "mode": 0, 560 | "inputs": [], 561 | "outputs": [ 562 | { 563 | "name": "FLOAT", 564 | "type": "FLOAT", 565 | "links": [ 566 | 251 567 | ] 568 | } 569 | ], 570 | "properties": { 571 | "cnr_id": "comfyui-kjnodes", 572 | "ver": "0addfc6101f7a834c7fb6e0a1b26529360ab5350", 573 | "Node name for S&R": "StringToFloatList" 574 | }, 575 | "widgets_values": [ 576 | "1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250, 0.4219, 0.0" 577 | ], 578 | "color": "#223", 579 | "bgcolor": "#335" 580 | }, 581 | { 582 | "id": 105, 583 | "type": "Note", 584 | "pos": [ 585 | 1126.7423095703125, 586 | 424.2294616699219 587 | ], 588 | "size": [ 589 | 335.8657531738281, 590 | 106.6832046508789 591 | ], 592 | "flags": {}, 593 | "order": 4, 594 | "mode": 0, 595 | "inputs": [], 596 | "outputs": [], 597 | "properties": {}, 598 | "widgets_values": [ 599 | "Distilled model expects the following sigma schedule:\n1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250, 0.4219, 0.0\n\n\nEuler_ancestral is the recommended default sampler." 600 | ], 601 | "color": "#432", 602 | "bgcolor": "#653" 603 | }, 604 | { 605 | "id": 97, 606 | "type": "KSamplerSelect", 607 | "pos": [ 608 | 1477.55615234375, 609 | 427.04132080078125 610 | ], 611 | "size": [ 612 | 275.599365234375, 613 | 58 614 | ], 615 | "flags": {}, 616 | "order": 5, 617 | "mode": 0, 618 | "inputs": [], 619 | "outputs": [ 620 | { 621 | "name": "SAMPLER", 622 | "type": "SAMPLER", 623 | "slot_index": 0, 624 | "links": [ 625 | 253 626 | ] 627 | } 628 | ], 629 | "properties": { 630 | "cnr_id": "comfy-core", 631 | "ver": "0.3.15", 632 | "Node name for S&R": "KSamplerSelect" 633 | }, 634 | "widgets_values": [ 635 | "euler_ancestral" 636 | ] 637 | }, 638 | { 639 | "id": 99, 640 | "type": "CFGGuider", 641 | "pos": [ 642 | 1536.1513671875, 643 | 280 644 | ], 645 | "size": [ 646 | 210, 647 | 98 648 | ], 649 | "flags": {}, 650 | "order": 14, 651 | "mode": 0, 652 | "inputs": [ 653 | { 654 | "name": "model", 655 | "type": "MODEL", 656 | "link": 258 657 | }, 658 | { 659 | "name": "positive", 660 | "type": "CONDITIONING", 661 | "link": 256 662 | }, 663 | { 664 | "name": "negative", 665 | "type": "CONDITIONING", 666 | "link": 257 667 | } 668 | ], 669 | "outputs": [ 670 | { 671 | "name": "GUIDER", 672 | "type": "GUIDER", 673 | "slot_index": 0, 674 | "links": [ 675 | 252 676 | ] 677 | } 678 | ], 679 | "properties": { 680 | "cnr_id": "comfy-core", 681 | "ver": "0.3.26", 682 | "Node name for S&R": "CFGGuider" 683 | }, 684 | "widgets_values": [ 685 | 1 686 | ] 687 | }, 688 | { 689 | "id": 102, 690 | "type": "RandomNoise", 691 | "pos": [ 692 | 1524.6190185546875, 693 | 138.8462371826172 694 | ], 695 | "size": [ 696 | 210, 697 | 82 698 | ], 699 | "flags": {}, 700 | "order": 6, 701 | "mode": 0, 702 | "inputs": [], 703 | "outputs": [ 704 | { 705 | "name": "NOISE", 706 | "type": "NOISE", 707 | "links": [ 708 | 260 709 | ] 710 | } 711 | ], 712 | "properties": { 713 | "cnr_id": "comfy-core", 714 | "ver": "0.3.28", 715 | "Node name for S&R": "RandomNoise" 716 | }, 717 | "widgets_values": [ 718 | 45, 719 | 
"fixed" 720 | ] 721 | }, 722 | { 723 | "id": 98, 724 | "type": "FloatToSigmas", 725 | "pos": [ 726 | 1574.70166015625, 727 | 620.4625244140625 728 | ], 729 | "size": [ 730 | 210, 731 | 58 732 | ], 733 | "flags": { 734 | "collapsed": true 735 | }, 736 | "order": 11, 737 | "mode": 0, 738 | "inputs": [ 739 | { 740 | "name": "float_list", 741 | "type": "FLOAT", 742 | "widget": { 743 | "name": "float_list" 744 | }, 745 | "link": null 746 | }, 747 | { 748 | "name": "float_list", 749 | "type": "FLOAT", 750 | "widget": { 751 | "name": "float_list" 752 | }, 753 | "link": 251 754 | } 755 | ], 756 | "outputs": [ 757 | { 758 | "name": "SIGMAS", 759 | "type": "SIGMAS", 760 | "links": [ 761 | 268 762 | ] 763 | } 764 | ], 765 | "properties": { 766 | "cnr_id": "comfyui-kjnodes", 767 | "ver": "0addfc6101f7a834c7fb6e0a1b26529360ab5350", 768 | "Node name for S&R": "FloatToSigmas" 769 | }, 770 | "widgets_values": [ 771 | 0 772 | ] 773 | }, 774 | { 775 | "id": 44, 776 | "type": "CheckpointLoaderSimple", 777 | "pos": [ 778 | 899.0839233398438, 779 | 113.54173278808594 780 | ], 781 | "size": [ 782 | 315, 783 | 98 784 | ], 785 | "flags": {}, 786 | "order": 7, 787 | "mode": 0, 788 | "inputs": [], 789 | "outputs": [ 790 | { 791 | "name": "MODEL", 792 | "type": "MODEL", 793 | "slot_index": 0, 794 | "links": [ 795 | 258 796 | ] 797 | }, 798 | { 799 | "name": "CLIP", 800 | "type": "CLIP", 801 | "links": null 802 | }, 803 | { 804 | "name": "VAE", 805 | "type": "VAE", 806 | "slot_index": 2, 807 | "links": [ 808 | 87, 809 | 250 810 | ] 811 | } 812 | ], 813 | "properties": { 814 | "cnr_id": "comfy-core", 815 | "ver": "0.3.28", 816 | "Node name for S&R": "CheckpointLoaderSimple" 817 | }, 818 | "widgets_values": [ 819 | "ltxv-2b-0.9.6-distilled-04-25.safetensors" 820 | ] 821 | } 822 | ], 823 | "links": [ 824 | [ 825 | 74, 826 | 38, 827 | 0, 828 | 6, 829 | 0, 830 | "CLIP" 831 | ], 832 | [ 833 | 75, 834 | 38, 835 | 0, 836 | 7, 837 | 0, 838 | "CLIP" 839 | ], 840 | [ 841 | 87, 842 | 44, 843 | 2, 844 | 8, 845 | 1, 846 | "VAE" 847 | ], 848 | [ 849 | 226, 850 | 78, 851 | 0, 852 | 82, 853 | 0, 854 | "IMAGE" 855 | ], 856 | [ 857 | 239, 858 | 6, 859 | 0, 860 | 95, 861 | 0, 862 | "CONDITIONING" 863 | ], 864 | [ 865 | 240, 866 | 7, 867 | 0, 868 | 95, 869 | 1, 870 | "CONDITIONING" 871 | ], 872 | [ 873 | 245, 874 | 95, 875 | 0, 876 | 69, 877 | 0, 878 | "CONDITIONING" 879 | ], 880 | [ 881 | 246, 882 | 95, 883 | 1, 884 | 69, 885 | 1, 886 | "CONDITIONING" 887 | ], 888 | [ 889 | 248, 890 | 82, 891 | 0, 892 | 95, 893 | 3, 894 | "IMAGE" 895 | ], 896 | [ 897 | 250, 898 | 44, 899 | 2, 900 | 95, 901 | 2, 902 | "VAE" 903 | ], 904 | [ 905 | 251, 906 | 100, 907 | 0, 908 | 98, 909 | 1, 910 | "FLOAT" 911 | ], 912 | [ 913 | 252, 914 | 99, 915 | 0, 916 | 101, 917 | 1, 918 | "GUIDER" 919 | ], 920 | [ 921 | 253, 922 | 97, 923 | 0, 924 | 101, 925 | 2, 926 | "SAMPLER" 927 | ], 928 | [ 929 | 255, 930 | 101, 931 | 1, 932 | 8, 933 | 0, 934 | "LATENT" 935 | ], 936 | [ 937 | 256, 938 | 69, 939 | 0, 940 | 99, 941 | 1, 942 | "CONDITIONING" 943 | ], 944 | [ 945 | 257, 946 | 69, 947 | 1, 948 | 99, 949 | 2, 950 | "CONDITIONING" 951 | ], 952 | [ 953 | 258, 954 | 44, 955 | 0, 956 | 99, 957 | 0, 958 | "MODEL" 959 | ], 960 | [ 961 | 259, 962 | 95, 963 | 2, 964 | 101, 965 | 4, 966 | "LATENT" 967 | ], 968 | [ 969 | 260, 970 | 102, 971 | 0, 972 | 101, 973 | 0, 974 | "NOISE" 975 | ], 976 | [ 977 | 261, 978 | 8, 979 | 0, 980 | 103, 981 | 0, 982 | "IMAGE" 983 | ], 984 | [ 985 | 268, 986 | 98, 987 | 0, 988 | 101, 989 | 3, 990 | "SIGMAS" 991 | ] 992 | ], 993 | "groups": [], 994 
| "config": {}, 995 | "extra": { 996 | "ds": { 997 | "scale": 0.6303940863129158, 998 | "offset": [ 999 | 627.9092558924768, 1000 | 478.72279841993054 1001 | ] 1002 | }, 1003 | "VHS_latentpreview": true, 1004 | "VHS_latentpreviewrate": 0, 1005 | "VHS_MetadataImage": true, 1006 | "VHS_KeepIntermediate": true 1007 | }, 1008 | "version": 0.4 1009 | } 1010 | -------------------------------------------------------------------------------- /example_workflows/low_level/moto.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:8681ff86edb43d0a12472175d5503a590910b24d5ec5d5d5b78eed079960eac0 3 | size 998115 4 | -------------------------------------------------------------------------------- /example_workflows/low_level/shrek2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightricks/ComfyUI-LTXVideo/6e9e6de05624b0aab09b81a2f4a5f473fa97988a/example_workflows/low_level/shrek2.jpg -------------------------------------------------------------------------------- /example_workflows/low_level/start.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightricks/ComfyUI-LTXVideo/6e9e6de05624b0aab09b81a2f4a5f473fa97988a/example_workflows/low_level/start.jpg -------------------------------------------------------------------------------- /example_workflows/tricks/ltxvideo-flow-edit.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:041f3820e0ab4109137d743b3f6b8e2dea0cddcad194aac08d3baec7571200f9 3 | size 1841100 4 | -------------------------------------------------------------------------------- /example_workflows/tricks/ltxvideo-rf-edit.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:9d2b002f12155d0b929366ad38561680a5bbaecea1925e5655658b46d281baf9 3 | size 2441562 4 | -------------------------------------------------------------------------------- /example_workflows/tricks/ref.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3726ed083a5c52873d29adab87e8fe0a65c7e390beaf8dfcf23e5c9091d6da73 3 | size 521574 4 | -------------------------------------------------------------------------------- /example_workflows/tricks/shot.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:01d44bb728e04879a5513a2c33bdd47d443bf901b33d3c1445918e75ba8b78b1 3 | size 678605 4 | -------------------------------------------------------------------------------- /example_workflows/tricks/shot2.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5432150d5477d19586b8d8049ca214038f886d3def9117830cf2af58879c48b5 3 | size 31584408 4 | -------------------------------------------------------------------------------- /example_workflows/tricks/shrek2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightricks/ComfyUI-LTXVideo/6e9e6de05624b0aab09b81a2f4a5f473fa97988a/example_workflows/tricks/shrek2.jpg 
-------------------------------------------------------------------------------- /example_workflows/tricks/shrek3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightricks/ComfyUI-LTXVideo/6e9e6de05624b0aab09b81a2f4a5f473fa97988a/example_workflows/tricks/shrek3.jpg -------------------------------------------------------------------------------- /film_grain.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import comfy 4 | import torch 5 | 6 | from .nodes_registry import comfy_node 7 | 8 | 9 | @comfy_node( 10 | name="LTXVFilmGrain", 11 | ) 12 | class LTXVFilmGrain: 13 | @classmethod 14 | def INPUT_TYPES(s): 15 | return { 16 | "required": { 17 | "images": ("IMAGE",), 18 | "grain_intensity": ( 19 | "FLOAT", 20 | {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01}, 21 | ), 22 | "saturation": ( 23 | "FLOAT", 24 | {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}, 25 | ), 26 | }, 27 | "optional": {}, 28 | } 29 | 30 | RETURN_TYPES = ("IMAGE",) 31 | FUNCTION = "add_film_grain" 32 | CATEGORY = "effects" 33 | DESCRIPTION = "Adds film grain to the image." 34 | 35 | def add_film_grain( 36 | self, images: torch.Tensor, grain_intensity: float, saturation: float 37 | ) -> Tuple[torch.Tensor]: 38 | if grain_intensity < 0 or grain_intensity > 1: 39 | raise ValueError("Grain intensity must be between 0 and 1.") 40 | 41 | device = comfy.model_management.get_torch_device() 42 | images = images.to(device) 43 | 44 | grain = torch.randn_like(images, device=device) 45 | grain[:, :, :, 0] *= 2 46 | grain[:, :, :, 2] *= 3 47 | grain = grain * saturation + grain[:, :, :, 1].unsqueeze(3).repeat( 48 | 1, 1, 1, 3 49 | ) * (1 - saturation) 50 | 51 | # Blend the grain with the image 52 | noised_images = images + grain_intensity * grain 53 | noised_images.clamp_(0, 1) 54 | noised_images = noised_images.to(comfy.model_management.intermediate_device()) 55 | 56 | return (noised_images,) 57 | -------------------------------------------------------------------------------- /guide.py: -------------------------------------------------------------------------------- 1 | import comfy 2 | import comfy_extras.nodes_lt as nodes_lt 3 | import comfy_extras.nodes_post_processing as post_processing 4 | import nodes 5 | 6 | from .nodes_registry import comfy_node 7 | 8 | 9 | def blur_internal(image, blur_radius): 10 | if blur_radius > 0: 11 | # https://docs.opencv.org/2.4/modules/imgproc/doc/filtering.html#getgaussiankernel 12 | # sigma = 0.3 * blur_radius + 0.5 is what is recommended in the OpenCV doc for the 13 | # relationship between sigma and kernel size 2*blur_radius + 1, however we want somewhat weaker 14 | # blurring, so we use 0.3 * blur_radius instead, reducing the sigma value by 0.5 15 | sigma = 0.3 * blur_radius 16 | image = post_processing.Blur().blur(image, blur_radius, sigma)[0] 17 | return image 18 | 19 | 20 | @comfy_node(name="LTXVAddGuideAdvanced") 21 | class LTXVAddGuideAdvanced: 22 | @classmethod 23 | def INPUT_TYPES(s): 24 | return { 25 | "required": { 26 | "positive": ("CONDITIONING",), 27 | "negative": ("CONDITIONING",), 28 | "vae": ("VAE",), 29 | "latent": ("LATENT",), 30 | "image": ("IMAGE",), 31 | "frame_idx": ( 32 | "INT", 33 | { 34 | "default": 0, 35 | "min": -9999, 36 | "max": 9999, 37 | "tooltip": "Frame index to start the conditioning at. For single-frame images or " 38 | "videos with 1-8 frames, any frame_idx value is acceptable. 
For videos with 9+ " 39 | "frames, frame_idx must be divisible by 8, otherwise it will be rounded down to " 40 | "the nearest multiple of 8. Negative values are counted from the end of the video.", 41 | }, 42 | ), 43 | "strength": ( 44 | "FLOAT", 45 | { 46 | "default": 1.0, 47 | "min": 0.0, 48 | "max": 1.0, 49 | "tooltip": "Strength of the conditioning. Higher values will make the conditioning more exact.", 50 | }, 51 | ), 52 | "crf": ( 53 | "INT", 54 | { 55 | "default": 29, 56 | "min": 0, 57 | "max": 51, 58 | "step": 1, 59 | "tooltip": "CRF value for the video. Higher values mean more motion, lower values mean higher quality.", 60 | }, 61 | ), 62 | "blur_radius": ( 63 | "INT", 64 | { 65 | "default": 0, 66 | "min": 0, 67 | "max": 7, 68 | "step": 1, 69 | "tooltip": "Blur kernel radius size. Higher values mean more motion, lower values mean higher quality.", 70 | }, 71 | ), 72 | "interpolation": ( 73 | [ 74 | "lanczos", 75 | "bislerp", 76 | "nearest", 77 | "bilinear", 78 | "bicubic", 79 | "area", 80 | "nearest-exact", 81 | ], 82 | {"default": "lanczos"}, 83 | ), 84 | "crop": (["center", "disabled"], {"default": "disabled"}), 85 | } 86 | } 87 | 88 | RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") 89 | RETURN_NAMES = ("positive", "negative", "latent") 90 | 91 | CATEGORY = "conditioning/video_models" 92 | FUNCTION = "generate" 93 | 94 | DESCRIPTION = ( 95 | "Adds a conditioning frame or a video at a specific frame index. " 96 | "This node is used to add a keyframe or a video segment which should appear in the " 97 | "generated video at a specified index. It resizes the image to the correct size and " 98 | "applies preprocessing to it." 99 | ) 100 | 101 | def generate( 102 | self, 103 | positive, 104 | negative, 105 | vae, 106 | latent, 107 | image, 108 | frame_idx, 109 | strength, 110 | crf, 111 | blur_radius, 112 | interpolation, 113 | crop, 114 | ): 115 | _, width_scale_factor, height_scale_factor = vae.downscale_index_formula 116 | width, height = ( 117 | latent["samples"].shape[4] * width_scale_factor, 118 | latent["samples"].shape[3] * height_scale_factor, 119 | ) 120 | image = ( 121 | comfy.utils.common_upscale( 122 | image.movedim(-1, 1), width, height, interpolation, crop=crop 123 | ) 124 | .movedim(1, -1) 125 | .clamp(0, 1) 126 | ) 127 | image = nodes_lt.LTXVPreprocess().preprocess(image, crf)[0] 128 | image = blur_internal(image, blur_radius) 129 | return nodes_lt.LTXVAddGuide().generate( 130 | positive=positive, 131 | negative=negative, 132 | vae=vae, 133 | latent=latent, 134 | image=image, 135 | frame_idx=frame_idx, 136 | strength=strength, 137 | ) 138 | 139 | 140 | @comfy_node(name="LTXVImgToVideoAdvanced") 141 | class LTXVImgToVideoAdvanced: 142 | @classmethod 143 | def INPUT_TYPES(s): 144 | return { 145 | "required": { 146 | "positive": ("CONDITIONING",), 147 | "negative": ("CONDITIONING",), 148 | "vae": ("VAE",), 149 | "image": ("IMAGE",), 150 | "width": ( 151 | "INT", 152 | { 153 | "default": 768, 154 | "min": 64, 155 | "max": nodes.MAX_RESOLUTION, 156 | "step": 32, 157 | }, 158 | ), 159 | "height": ( 160 | "INT", 161 | { 162 | "default": 512, 163 | "min": 64, 164 | "max": nodes.MAX_RESOLUTION, 165 | "step": 32, 166 | }, 167 | ), 168 | "length": ( 169 | "INT", 170 | {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}, 171 | ), 172 | "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), 173 | "crf": ( 174 | "INT", 175 | { 176 | "default": 29, 177 | "min": 0, 178 | "max": 51, 179 | "step": 1, 180 | "tooltip": "CRF value for the video. 
Higher values mean more motion, lower values mean higher quality.", 181 | }, 182 | ), 183 | "blur_radius": ( 184 | "INT", 185 | { 186 | "default": 0, 187 | "min": 0, 188 | "max": 7, 189 | "step": 1, 190 | "tooltip": "Blur kernel radius size. Higher values mean more motion, lower values mean higher quality.", 191 | }, 192 | ), 193 | "interpolation": ( 194 | [ 195 | "lanczos", 196 | "bislerp", 197 | "nearest", 198 | "bilinear", 199 | "bicubic", 200 | "area", 201 | "nearest-exact", 202 | ], 203 | {"default": "lanczos"}, 204 | ), 205 | "crop": (["center", "disabled"], {"default": "disabled"}), 206 | "strength": ("FLOAT", {"default": 0.9, "min": 0, "max": 1}), 207 | } 208 | } 209 | 210 | RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") 211 | RETURN_NAMES = ("positive", "negative", "latent") 212 | 213 | CATEGORY = "conditioning/video_models" 214 | FUNCTION = "generate" 215 | 216 | DESCRIPTION = ( 217 | "Adds a conditioning frame or a video at index 0. " 218 | "This node is used to add a keyframe or a video segment which should appear in the " 219 | "generated video at index 0. It resizes the image to the correct size " 220 | "and applies preprocessing to it." 221 | ) 222 | 223 | def generate( 224 | self, 225 | positive, 226 | negative, 227 | vae, 228 | image, 229 | width, 230 | height, 231 | length, 232 | batch_size, 233 | crf, 234 | blur_radius, 235 | interpolation, 236 | crop, 237 | strength, 238 | ): 239 | image = comfy.utils.common_upscale( 240 | image.movedim(-1, 1), width, height, interpolation, crop=crop 241 | ).movedim(1, -1) 242 | image = nodes_lt.LTXVPreprocess().preprocess(image, crf)[0] 243 | image = blur_internal(image, blur_radius) 244 | return nodes_lt.LTXVImgToVideo().generate( 245 | positive=positive, 246 | negative=negative, 247 | vae=vae, 248 | image=image, 249 | width=width, 250 | height=height, 251 | length=length, 252 | batch_size=batch_size, 253 | strength=strength, 254 | ) 255 | -------------------------------------------------------------------------------- /latent_adain.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | 5 | from .nodes_registry import comfy_node 6 | 7 | 8 | @comfy_node(name="LTXVAdainLatent") 9 | class LTXVAdainLatent: 10 | def __init__(self): 11 | pass 12 | 13 | @classmethod 14 | def INPUT_TYPES(s): 15 | return { 16 | "required": { 17 | "latents": ("LATENT",), 18 | "reference": ("LATENT",), 19 | "factor": ( 20 | "FLOAT", 21 | { 22 | "default": 1.0, 23 | "min": -10.0, 24 | "max": 10.0, 25 | "step": 0.01, 26 | "round": 0.01, 27 | }, 28 | ), 29 | }, 30 | } 31 | 32 | RETURN_TYPES = ("LATENT",) 33 | FUNCTION = "batch_normalize" 34 | 35 | CATEGORY = "Lightricks/latents" 36 | 37 | def batch_normalize(self, latents, reference, factor): 38 | latents_copy = copy.deepcopy(latents) 39 | t = latents_copy["samples"] # B x C x F x H x W 40 | 41 | for i in range(t.size(0)): # batch 42 | for c in range(t.size(1)): # channel 43 | r_sd, r_mean = torch.std_mean( 44 | reference["samples"][i, c], dim=None 45 | ) # index by original dim order 46 | i_sd, i_mean = torch.std_mean(t[i, c], dim=None) 47 | 48 | t[i, c] = ((t[i, c] - i_mean) / i_sd) * r_sd + r_mean 49 | 50 | latents_copy["samples"] = torch.lerp(latents["samples"], t, factor) 51 | return (latents_copy,) 52 | -------------------------------------------------------------------------------- /latent_upsampler.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | 
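# This module implements a learned 2x upsampler that operates directly on VAE
# latents. The PixelShuffle blocks below trade channels for resolution: with
# upscale factor r=2, PixelShuffle3D maps (B, C, F, H, W) to
# (B, C/8, 2F, 2H, 2W), PixelShuffle2D maps (B, C, H, W) to (B, C/4, 2H, 2W),
# and PixelShuffle1D maps (B, C, F, H, W) to (B, C/2, 2F, H, W), which is why
# the conv feeding each shuffle widens the channel count by 8x, 4x, or 2x
# (see LatentUpsampler.__init__).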
import folder_paths 4 | import torch 5 | import torch.nn as nn 6 | from comfy import model_management 7 | from diffusers import ConfigMixin, ModelMixin 8 | from einops import rearrange 9 | 10 | from .nodes_registry import comfy_node 11 | 12 | 13 | class PixelShuffle3D(nn.Module): 14 | def __init__(self, upscale_factor: int = 2): 15 | super().__init__() 16 | self.r = upscale_factor 17 | 18 | def forward(self, x): 19 | b, c, f, h, w = x.shape 20 | r = self.r 21 | out_c = c // (r**3) 22 | x = x.view(b, out_c, r, r, r, f, h, w) 23 | x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) # (b, out_c, f, r, h, r, w, r) 24 | x = x.reshape(b, out_c, f * r, h * r, w * r) 25 | return x 26 | 27 | 28 | class PixelShuffle2D(nn.Module): 29 | def __init__(self, upscale_factor: int = 2): 30 | super().__init__() 31 | self.r = upscale_factor 32 | 33 | def forward(self, x): 34 | b, c, h, w = x.size() 35 | r = self.r 36 | out_c = c // (r * r) 37 | x = x.view(b, out_c, r, r, h, w) 38 | x = x.permute(0, 1, 4, 2, 5, 3) # (b, out_c, h, r, w, r) 39 | x = x.reshape(b, out_c, h * r, w * r) 40 | return x 41 | 42 | 43 | class PixelShuffle1D(nn.Module): 44 | def __init__(self, upscale_factor: int = 2): 45 | super().__init__() 46 | self.r = upscale_factor 47 | 48 | def forward(self, x): 49 | b, c, f, h, w = x.shape 50 | r = self.r 51 | out_c = c // r 52 | x = x.view(b, out_c, r, f, h, w) # [B, C//r, r, F, H, W] 53 | x = x.permute(0, 1, 3, 2, 4, 5) # [B, C//r, F, r, H, W] 54 | x = x.reshape(b, out_c, f * r, h, w) 55 | return x 56 | 57 | 58 | class ResBlock(nn.Module): 59 | def __init__( 60 | self, channels: int, mid_channels: Optional[int] = None, dims: int = 3 61 | ): 62 | super().__init__() 63 | if mid_channels is None: 64 | mid_channels = channels 65 | 66 | Conv = nn.Conv2d if dims == 2 else nn.Conv3d 67 | 68 | self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) 69 | self.norm1 = nn.GroupNorm(32, mid_channels) 70 | self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) 71 | self.norm2 = nn.GroupNorm(32, channels) 72 | self.activation = nn.SiLU() 73 | 74 | def forward(self, x: torch.Tensor) -> torch.Tensor: 75 | residual = x 76 | x = self.conv1(x) 77 | x = self.norm1(x) 78 | x = self.activation(x) 79 | x = self.conv2(x) 80 | x = self.norm2(x) 81 | x = self.activation(x + residual) 82 | return x 83 | 84 | 85 | class LatentUpsampler(ModelMixin, ConfigMixin): 86 | """ 87 | Model to spatially upsample VAE latents. 
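    The latent passes through an initial conv + GroupNorm + SiLU stem, a stack of
    ResBlocks, a channel-widening conv followed by a pixel-shuffle upsampler, a
    second stack of ResBlocks, and a final conv projecting back to in_channels.
    With dims=2, frames are folded into the batch so each frame is upsampled
    independently; when temporal upsampling is enabled, the frame count is
    doubled and the first upsampled frame is dropped.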
88 | 89 | Args: 90 | in_channels (`int`): Number of channels in the input latent 91 | mid_channels (`int`): Number of channels in the middle layers 92 | num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) 93 | dims (`int`): Number of dimensions for convolutions (2 or 3) 94 | spatial_upsample (`bool`): Whether to spatially upsample the latent 95 | temporal_upsample (`bool`): Whether to temporally upsample the latent 96 | """ 97 | 98 | def __init__( 99 | self, 100 | in_channels: int = 128, 101 | mid_channels: int = 512, 102 | num_blocks_per_stage: int = 4, 103 | dims: int = 3, 104 | spatial_upsample: bool = True, 105 | temporal_upsample: bool = False, 106 | ): 107 | super().__init__() 108 | 109 | self.in_channels = in_channels 110 | self.mid_channels = mid_channels 111 | self.num_blocks_per_stage = num_blocks_per_stage 112 | self.dims = dims 113 | self.spatial_upsample = spatial_upsample 114 | self.temporal_upsample = temporal_upsample 115 | 116 | Conv = nn.Conv2d if dims == 2 else nn.Conv3d 117 | 118 | self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1) 119 | self.initial_norm = nn.GroupNorm(32, mid_channels) 120 | self.initial_activation = nn.SiLU() 121 | 122 | self.res_blocks = nn.ModuleList( 123 | [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] 124 | ) 125 | 126 | if spatial_upsample and temporal_upsample: 127 | self.upsampler = nn.Sequential( 128 | nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), 129 | PixelShuffle3D(2), 130 | ) 131 | elif spatial_upsample: 132 | self.upsampler = nn.Sequential( 133 | nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), 134 | PixelShuffle2D(2), 135 | ) 136 | elif temporal_upsample: 137 | self.upsampler = nn.Sequential( 138 | nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), 139 | PixelShuffle1D(2), 140 | ) 141 | else: 142 | raise ValueError( 143 | "Either spatial_upsample or temporal_upsample must be True" 144 | ) 145 | 146 | self.post_upsample_res_blocks = nn.ModuleList( 147 | [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] 148 | ) 149 | 150 | self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1) 151 | 152 | def forward(self, latent: torch.Tensor) -> torch.Tensor: 153 | b, c, f, h, w = latent.shape 154 | 155 | if self.dims == 2: 156 | x = rearrange(latent, "b c f h w -> (b f) c h w") 157 | x = self.initial_conv(x) 158 | x = self.initial_norm(x) 159 | x = self.initial_activation(x) 160 | 161 | for block in self.res_blocks: 162 | x = block(x) 163 | 164 | x = self.upsampler(x) 165 | 166 | for block in self.post_upsample_res_blocks: 167 | x = block(x) 168 | 169 | x = self.final_conv(x) 170 | x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) 171 | else: 172 | x = self.initial_conv(latent) 173 | x = self.initial_norm(x) 174 | x = self.initial_activation(x) 175 | 176 | for block in self.res_blocks: 177 | x = block(x) 178 | 179 | if self.temporal_upsample: 180 | x = self.upsampler(x) 181 | x = x[:, :, 1:, :, :] 182 | else: 183 | x = rearrange(x, "b c f h w -> (b f) c h w") 184 | x = self.upsampler(x) 185 | x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) 186 | 187 | for block in self.post_upsample_res_blocks: 188 | x = block(x) 189 | 190 | x = self.final_conv(x) 191 | 192 | return x 193 | 194 | @classmethod 195 | def from_config(cls, config): 196 | return cls( 197 | in_channels=config.get("in_channels", 4), 198 | mid_channels=config.get("mid_channels", 128), 
199 | num_blocks_per_stage=config.get("num_blocks_per_stage", 4), 200 | dims=config.get("dims", 2), 201 | spatial_upsample=config.get("spatial_upsample", True), 202 | temporal_upsample=config.get("temporal_upsample", False), 203 | ) 204 | 205 | def config(self): 206 | return { 207 | "_class_name": "LatentUpsampler", 208 | "in_channels": self.in_channels, 209 | "mid_channels": self.mid_channels, 210 | "num_blocks_per_stage": self.num_blocks_per_stage, 211 | "dims": self.dims, 212 | "spatial_upsample": self.spatial_upsample, 213 | "temporal_upsample": self.temporal_upsample, 214 | } 215 | 216 | def load_weights(self, weights_path: str) -> None: 217 | """ 218 | Load model weights from a .safetensors file and switch to evaluation mode. 219 | 220 | Args: 221 | weights_path (str): Path to the .safetensors file containing the model weights 222 | 223 | Raises: 224 | RuntimeError: If there are missing or unexpected keys in the state dict 225 | """ 226 | import safetensors.torch 227 | 228 | sd = safetensors.torch.load_file(weights_path) 229 | self.load_state_dict(sd, strict=False, assign=True) 230 | # Switch to evaluation mode 231 | self.eval() 232 | 233 | 234 | @comfy_node(name="LTXVLatentUpsampler") 235 | class LTXVLatentUpsampler: 236 | """ 237 | Upsamples a video latent by a factor of 2. 238 | """ 239 | 240 | @classmethod 241 | def INPUT_TYPES(s): 242 | return { 243 | "required": { 244 | "samples": ("LATENT",), 245 | "upscale_model": ("UPSCALE_MODEL",), 246 | "vae": ("VAE",), 247 | } 248 | } 249 | 250 | RETURN_TYPES = ("LATENT",) 251 | FUNCTION = "upsample_latent" 252 | CATEGORY = "latent/video" 253 | 254 | def upsample_latent( 255 | self, samples: dict, upscale_model: LatentUpsampler, vae 256 | ) -> tuple: 257 | """ 258 | Upsample the input latent using the provided model. 259 | 260 | Args: 261 | samples (dict): Input latent samples 262 | upscale_model (LatentUpsampler): Loaded upscale model 263 | 264 | Returns: 265 | tuple: Tuple containing the upsampled latent 266 | """ 267 | latents = samples["samples"] 268 | 269 | # Ensure latents are on the same device as the model 270 | if latents.device != upscale_model.device: 271 | latents = latents.to(upscale_model.device) 272 | latents = vae.first_stage_model.per_channel_statistics.un_normalize(latents) 273 | upsampled_latents = upscale_model(latents) 274 | upsampled_latents = vae.first_stage_model.per_channel_statistics.normalize( 275 | upsampled_latents 276 | ) 277 | upsampled_latents = upsampled_latents.to(model_management.intermediate_device()) 278 | return_dict = samples.copy() 279 | return_dict["samples"] = upsampled_latents 280 | return (return_dict,) 281 | 282 | 283 | @comfy_node(name="LTXVLatentUpsamplerModelLoader") 284 | class LTXVLatentUpsamplerModelLoader: 285 | """ 286 | Loads a latent upsampler model from a .safetensors file. 287 | """ 288 | 289 | @classmethod 290 | def INPUT_TYPES(s): 291 | return { 292 | "required": { 293 | "upscale_model": (folder_paths.get_filename_list("upscale_models"),), 294 | "spatial_upsample": ("BOOLEAN", {"default": True}), 295 | "temporal_upsample": ("BOOLEAN", {"default": False}), 296 | } 297 | } 298 | 299 | RETURN_TYPES = ("UPSCALE_MODEL",) 300 | FUNCTION = "load_model" 301 | CATEGORY = "latent/video" 302 | 303 | def load_model( 304 | self, upscale_model: str, spatial_upsample: bool, temporal_upsample: bool 305 | ) -> tuple: 306 | """ 307 | Load the upscale model from the specified file. 
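        The name must resolve to a .safetensors file in ComfyUI's "upscale_models"
        folder; a LatentUpsampler is constructed with the requested
        spatial/temporal mode, its weights are loaded, and the model is put in
        eval mode on the active torch device.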
308 | 309 | Args: 310 | upscale_model (str): Name of the upscale model file 311 | 312 | Returns: 313 | tuple: Tuple containing the loaded model 314 | """ 315 | upscale_model_path = folder_paths.get_full_path("upscale_models", upscale_model) 316 | if upscale_model_path is None: 317 | raise ValueError(f"Upscale model {upscale_model} not found") 318 | 319 | try: 320 | latent_upsampler = LatentUpsampler( 321 | num_blocks_per_stage=4, 322 | dims=3, 323 | spatial_upsample=spatial_upsample, 324 | temporal_upsample=temporal_upsample, 325 | ) 326 | latent_upsampler.load_weights(upscale_model_path) 327 | except Exception as e: 328 | raise ValueError( 329 | f"Failed to initialize LatentUpsampler with this configuration: {str(e)}" 330 | ) 331 | 332 | latent_upsampler.eval() 333 | # Move model to appropriate device 334 | device = model_management.get_torch_device() 335 | latent_upsampler.to(device) 336 | 337 | return (latent_upsampler,) 338 | -------------------------------------------------------------------------------- /latents.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import comfy_extras.nodes_lt as nodes_lt 4 | import torch 5 | 6 | from .nodes_registry import comfy_node 7 | 8 | 9 | @comfy_node(name="LTXVSelectLatents") 10 | class LTXVSelectLatents: 11 | """ 12 | Selects a range of frames from a video latent. 13 | 14 | Features: 15 | - Supports positive and negative indexing 16 | - Preserves batch processing capabilities 17 | - Handles noise masks if present 18 | - Maintains 5D tensor format 19 | """ 20 | 21 | @classmethod 22 | def INPUT_TYPES(s): 23 | return { 24 | "required": { 25 | "samples": ("LATENT",), 26 | "start_index": ( 27 | "INT", 28 | {"default": 0, "min": -9999, "max": 9999, "step": 1}, 29 | ), 30 | "end_index": ( 31 | "INT", 32 | {"default": -1, "min": -9999, "max": 9999, "step": 1}, 33 | ), 34 | } 35 | } 36 | 37 | RETURN_TYPES = ("LATENT",) 38 | FUNCTION = "select_latents" 39 | CATEGORY = "latent/video" 40 | DESCRIPTION = ( 41 | "Selects a range of frames from the video latent. " 42 | "start_index and end_index define a closed interval (inclusive of both endpoints)." 43 | ) 44 | 45 | def select_latents(self, samples: dict, start_index: int, end_index: int) -> tuple: 46 | """ 47 | Selects a range of frames from the video latent. 
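        Both endpoints are inclusive: for a latent with 8 frames, start_index=0
        with end_index=-1 returns all 8 frames, while start_index=-2 with
        end_index=-1 returns only the last two.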
48 | 49 | Args: 50 | samples (dict): Video latent dictionary 51 | start_index (int): Starting frame index (supports negative indexing) 52 | end_index (int): Ending frame index (supports negative indexing) 53 | 54 | Returns: 55 | tuple: Contains modified latent dictionary with selected frames 56 | 57 | Raises: 58 | ValueError: If indices are invalid 59 | """ 60 | try: 61 | s = samples.copy() 62 | video_latent = s["samples"] 63 | batch, channels, frames, height, width = video_latent.shape 64 | 65 | # Handle negative indices 66 | start_idx = frames + start_index if start_index < 0 else start_index 67 | end_idx = frames + end_index if end_index < 0 else end_index 68 | 69 | # Validate and clamp indices 70 | start_idx = max(0, min(start_idx, frames - 1)) 71 | end_idx = max(0, min(end_idx, frames - 1)) 72 | if start_idx > end_idx: 73 | start_idx = min(start_idx, end_idx) 74 | 75 | # Select frames while maintaining 5D format 76 | s["samples"] = video_latent[:, :, start_idx : end_idx + 1, :, :] 77 | 78 | # Handle noise mask if present 79 | if "noise_mask" in s: 80 | s["noise_mask"] = s["noise_mask"][:, :, start_idx : end_idx + 1, :, :] 81 | 82 | return (s,) 83 | 84 | except Exception as e: 85 | print(f"[LTXVSelectLatents] Error: {str(e)}") 86 | raise 87 | 88 | 89 | @comfy_node(name="LTXVAddLatents") 90 | class LTXVAddLatents: 91 | """ 92 | Concatenates two video latents along the frames dimension. 93 | 94 | Features: 95 | - Validates dimension compatibility 96 | - Handles device placement 97 | - Preserves noise masks with proper handling 98 | - Supports batch processing 99 | """ 100 | 101 | @classmethod 102 | def INPUT_TYPES(s): 103 | return { 104 | "required": { 105 | "latents1": ("LATENT",), 106 | "latents2": ("LATENT",), 107 | } 108 | } 109 | 110 | RETURN_TYPES = ("LATENT",) 111 | FUNCTION = "add_latents" 112 | CATEGORY = "latent/video" 113 | DESCRIPTION = ( 114 | "Concatenates two video latents along the frames dimension. " 115 | "latents1 and latents2 must have the same dimensions except for the frames dimension." 116 | ) 117 | 118 | def add_latents( 119 | self, latents1: torch.Tensor, latents2: torch.Tensor 120 | ) -> torch.Tensor: 121 | """ 122 | Concatenates two video latents along the frames dimension. 
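        For example, concatenating a (1, 128, 2, 16, 16) latent with a
        (1, 128, 3, 16, 16) latent yields a (1, 128, 5, 16, 16) latent; batch,
        channel, height, and width must already match.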
123 | 124 | Args: 125 | latents1 (dict): First video latent dictionary 126 | latents2 (dict): Second video latent dictionary 127 | 128 | Returns: 129 | tuple: Contains concatenated latent dictionary 130 | 131 | Raises: 132 | ValueError: If latent dimensions don't match 133 | RuntimeError: If tensor operations fail 134 | """ 135 | try: 136 | s = latents1.copy() 137 | video_latent1 = latents1["samples"] 138 | video_latent2 = latents2["samples"] 139 | 140 | # Ensure tensors are on the same device 141 | target_device = video_latent1.device 142 | video_latent2 = video_latent2.to(target_device) 143 | 144 | # Validate dimensions 145 | self._validate_dimensions(video_latent1, video_latent2) 146 | 147 | # Concatenate along frames dimension 148 | s["samples"] = torch.cat([video_latent1, video_latent2], dim=2) 149 | 150 | # Handle noise masks 151 | s["noise_mask"] = self._merge_noise_masks( 152 | latents1, latents2, video_latent1.shape[2], video_latent2.shape[2] 153 | ) 154 | 155 | return (s,) 156 | 157 | except Exception as e: 158 | print(f"[LTXVAddLatents] Error: {str(e)}") 159 | raise 160 | 161 | def _validate_dimensions(self, latent1: torch.Tensor, latent2: torch.Tensor): 162 | """Validates that latent dimensions match except for frames.""" 163 | b1, c1, f1, h1, w1 = latent1.shape 164 | b2, c2, f2, h2, w2 = latent2.shape 165 | 166 | if not (b1 == b2 and c1 == c2 and h1 == h2 and w1 == w2): 167 | raise ValueError( 168 | f"Latent dimensions must match (except frames dimension).\n" 169 | f"Got shapes {latent1.shape} and {latent2.shape}" 170 | ) 171 | 172 | def _merge_noise_masks( 173 | self, latents1: torch.Tensor, latents2: torch.Tensor, frames1: int, frames2: int 174 | ) -> Optional[torch.Tensor]: 175 | """Merges noise masks from both latents with proper handling.""" 176 | if "noise_mask" in latents1 and "noise_mask" in latents2: 177 | return torch.cat([latents1["noise_mask"], latents2["noise_mask"]], dim=2) 178 | elif "noise_mask" in latents1: 179 | zeros = torch.zeros_like(latents1["noise_mask"][:, :, :frames2, :, :]) 180 | return torch.cat([latents1["noise_mask"], zeros], dim=2) 181 | elif "noise_mask" in latents2: 182 | zeros = torch.zeros_like(latents2["noise_mask"][:, :, :frames1, :, :]) 183 | return torch.cat([zeros, latents2["noise_mask"]], dim=2) 184 | return None 185 | 186 | 187 | @comfy_node(name="LTXVSetVideoLatentNoiseMasks") 188 | class LTXVSetVideoLatentNoiseMasks: 189 | """ 190 | Applies multiple masks to a video latent. 191 | 192 | Features: 193 | - Supports multiple input mask formats (2D, 3D, 4D) 194 | - Automatically handles fewer masks than frames by reusing the last mask 195 | - Resizes masks to match latent dimensions 196 | - Preserves batch processing capabilities 197 | 198 | Input Formats: 199 | - 2D mask: Single mask [H, W] 200 | - 3D mask: Multiple masks [M, H, W] 201 | - 4D mask: Multiple masks with channels [M, C, H, W] 202 | """ 203 | 204 | @classmethod 205 | def INPUT_TYPES(s): 206 | return { 207 | "required": { 208 | "samples": ("LATENT",), 209 | "masks": ("MASK",), 210 | } 211 | } 212 | 213 | RETURN_TYPES = ("LATENT",) 214 | FUNCTION = "set_mask" 215 | CATEGORY = "latent/video" 216 | DESCRIPTION = ( 217 | "Applies multiple masks to a video latent. " 218 | "masks can be 2D, 3D, or 4D tensors. " 219 | "If there are fewer masks than frames, the last mask will be reused." 220 | ) 221 | 222 | def set_mask(self, samples: dict, masks: torch.Tensor) -> tuple: 223 | """ 224 | Applies masks to video latent frames. 
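        Masks are bilinearly resized to the latent's spatial size, and if fewer
        masks than frames are supplied the last mask is reused for the remaining
        frames; a zero-filled noise_mask is created first when the latent does
        not carry one.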
225 | 226 | Args: 227 | samples (dict): Video latent dictionary containing 'samples' tensor 228 | masks (torch.Tensor): Mask tensor in various possible formats 229 | - 2D: [H, W] single mask 230 | - 3D: [M, H, W] multiple masks 231 | - 4D: [M, C, H, W] multiple masks with channels 232 | 233 | Returns: 234 | tuple: Contains modified latent dictionary with applied masks 235 | 236 | Raises: 237 | ValueError: If mask dimensions are unsupported 238 | RuntimeError: If tensor operations fail 239 | """ 240 | try: 241 | s = samples.copy() 242 | video_latent = s["samples"] 243 | batch_size, channels, num_frames, height, width = video_latent.shape 244 | 245 | # Initialize noise_mask if not present 246 | if "noise_mask" not in s: 247 | s["noise_mask"] = torch.zeros( 248 | (batch_size, 1, num_frames, height, width), 249 | dtype=video_latent.dtype, 250 | device=video_latent.device, 251 | ) 252 | 253 | # Process masks 254 | masks_reshaped = self._reshape_masks(masks) 255 | M = masks_reshaped.shape[0] 256 | resized_masks = self._resize_masks(masks_reshaped, height, width) 257 | 258 | # Apply masks efficiently 259 | self._apply_masks(s["noise_mask"], resized_masks, num_frames, M) 260 | return (s,) 261 | 262 | except Exception as e: 263 | print(f"[LTXVSetVideoLatentNoiseMasks] Error: {str(e)}") 264 | raise 265 | 266 | def _reshape_masks(self, masks: torch.Tensor) -> torch.Tensor: 267 | """Reshapes input masks to consistent 4D format.""" 268 | original_shape = tuple(masks.shape) 269 | ndims = masks.ndim 270 | 271 | if ndims == 2: 272 | return masks.unsqueeze(0).unsqueeze(0) 273 | elif ndims == 3: 274 | return masks.reshape(masks.shape[0], 1, masks.shape[1], masks.shape[2]) 275 | elif ndims == 4: 276 | return masks.reshape(masks.shape[0], 1, masks.shape[2], masks.shape[3]) 277 | else: 278 | raise ValueError( 279 | f"Unsupported 'masks' dimension: {original_shape}. " 280 | "Must be 2D (H,W), 3D (M,H,W), or 4D (M,C,H,W)." 281 | ) 282 | 283 | def _resize_masks( 284 | self, masks: torch.Tensor, height: int, width: int 285 | ) -> torch.Tensor: 286 | """Resizes all masks to match latent dimensions.""" 287 | return torch.nn.functional.interpolate( 288 | masks, size=(height, width), mode="bilinear", align_corners=False 289 | ) 290 | 291 | def _apply_masks( 292 | self, 293 | noise_mask: torch.Tensor, 294 | resized_masks: torch.Tensor, 295 | num_frames: int, 296 | M: int, 297 | ) -> None: 298 | """Applies resized masks to all frames.""" 299 | for f in range(num_frames): 300 | mask_idx = min(f, M - 1) # Reuse last mask if we run out 301 | noise_mask[:, :, f] = resized_masks[mask_idx] 302 | 303 | 304 | @comfy_node(name="LTXVAddLatentGuide") 305 | class LTXVAddLatentGuide(nodes_lt.LTXVAddGuide): 306 | @classmethod 307 | def INPUT_TYPES(s): 308 | return { 309 | "required": { 310 | "vae": ("VAE",), 311 | "positive": ("CONDITIONING",), 312 | "negative": ("CONDITIONING",), 313 | "latent": ("LATENT",), 314 | "guiding_latent": ("LATENT",), 315 | "latent_idx": ( 316 | "INT", 317 | { 318 | "default": 0, 319 | "min": -9999, 320 | "max": 9999, 321 | "step": 1, 322 | "tooltip": "Latent index to start the conditioning at. 
Can be negative to" 323 | "indicate that the conditioning is on the frames before the latent.", 324 | }, 325 | ), 326 | "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}), 327 | } 328 | } 329 | 330 | RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") 331 | RETURN_NAMES = ("positive", "negative", "latent") 332 | 333 | CATEGORY = "ltxtricks" 334 | FUNCTION = "generate" 335 | 336 | DESCRIPTION = "Adds a keyframe or a video segment at a specific frame index." 337 | 338 | def generate( 339 | self, vae, positive, negative, latent, guiding_latent, latent_idx, strength 340 | ): 341 | noise_mask = nodes_lt.get_noise_mask(latent) 342 | latent = latent["samples"] 343 | guiding_latent = guiding_latent["samples"] 344 | scale_factors = vae.downscale_index_formula 345 | 346 | if latent_idx <= 0: 347 | frame_idx = latent_idx * scale_factors[0] 348 | else: 349 | frame_idx = 1 + (latent_idx - 1) * scale_factors[0] 350 | 351 | positive, negative, latent, noise_mask = self.append_keyframe( 352 | positive=positive, 353 | negative=negative, 354 | frame_idx=frame_idx, 355 | latent_image=latent, 356 | noise_mask=noise_mask, 357 | guiding_latent=guiding_latent, 358 | strength=strength, 359 | scale_factors=scale_factors, 360 | ) 361 | 362 | return ( 363 | positive, 364 | negative, 365 | {"samples": latent, "noise_mask": noise_mask}, 366 | ) 367 | -------------------------------------------------------------------------------- /nodes_registry.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Callable, Optional, Type 3 | 4 | NODE_CLASS_MAPPINGS = {} 5 | NODE_DISPLAY_NAME_MAPPINGS = {} 6 | 7 | NODES_DISPLAY_NAME_PREFIX = "🅛🅣🅧" 8 | EXPERIMENTAL_DISPLAY_NAME_PREFIX = "(Experimental 🧪)" 9 | DEPRECATED_DISPLAY_NAME_PREFIX = "(Deprecated 🚫)" 10 | DEFAULT_CATEGORY_NAME = "Lightricks" 11 | 12 | 13 | def register_node(node_class: Type, name: str, description: str) -> None: 14 | """ 15 | Register a ComfyUI node class to ComfyUI's global nodes' registry. 16 | 17 | Args: 18 | node_class (Type): The class of the node to be registered. 19 | name (str): The name of the node. 20 | description (str): The short user-friendly description of the node. 21 | 22 | Raises: 23 | ValueError: If `node_class` is not a class, or `class_name` or `display_name` is not a string. 24 | """ 25 | 26 | if not isinstance(node_class, type): 27 | raise ValueError("`node_class` must be a class") 28 | 29 | if not isinstance(name, str): 30 | raise ValueError("`name` must be a string") 31 | 32 | if not isinstance(description, str): 33 | raise ValueError("`description` must be a string") 34 | 35 | NODE_CLASS_MAPPINGS[name] = node_class 36 | NODE_DISPLAY_NAME_MAPPINGS[name] = description 37 | 38 | 39 | def comfy_node( 40 | node_class: Optional[Type] = None, 41 | name: Optional[str] = None, 42 | description: Optional[str] = None, 43 | experimental: bool = False, 44 | deprecated: bool = False, 45 | skip: bool = False, 46 | ) -> Callable: 47 | """ 48 | Decorator for registering a node class with optional name, description, and status flags. 49 | 50 | Args: 51 | node_class (Type): The class of the node to be registered. 52 | name (str, optional): The name of the class. If not provided, the class name will be used. 53 | description (str, optional): The description of the class. 54 | If not provided, an auto-formatted description will be used based on the class name. 55 | experimental (bool): Flag indicating if the class is experimental. Defaults to False. 
56 | deprecated (bool): Flag indicating if the class is deprecated. Defaults to False. 57 | skip (bool): Flag indicating if the node registration should be skipped. Defaults to False. 58 | This is useful for conditionally registering nodes based on certain conditions 59 | (e.g. unavailability of certain dependencies). 60 | 61 | Returns: 62 | Callable: The decorator function. 63 | 64 | Raises: 65 | ValueError: If `node_class` is not a class. 66 | """ 67 | 68 | def decorator(node_class: Type) -> Type: 69 | if skip: 70 | return node_class 71 | 72 | if not isinstance(node_class, type): 73 | raise ValueError("`node_class` must be a class") 74 | 75 | nonlocal name, description 76 | if name is None: 77 | name = node_class.__name__ 78 | 79 | # Remove possible "Node" suffix from the class name, e.g. "EditImageNode -> EditImage" 80 | if name is not None and name.endswith("Node"): 81 | name = name[:-4] 82 | 83 | description = _format_description(description, name, experimental, deprecated) 84 | 85 | register_node(node_class, name, description) 86 | return node_class 87 | 88 | # If the decorator is used without parentheses 89 | if node_class is None: 90 | return decorator 91 | else: 92 | return decorator(node_class) 93 | 94 | 95 | def _format_description( 96 | description: str, class_name: str, experimental: bool, deprecated: bool 97 | ) -> str: 98 | """Format nodes display name to a standard format""" 99 | 100 | # If description is not provided, auto-generate one based on the class name 101 | if description is None: 102 | description = camel_case_to_spaces(class_name) 103 | 104 | # Strip the prefix if it's already there 105 | prefix_len = len(NODES_DISPLAY_NAME_PREFIX) 106 | if description.startswith(NODES_DISPLAY_NAME_PREFIX): 107 | description = description[prefix_len:].lstrip() 108 | 109 | # Add the deprecated / experimental prefixes 110 | if deprecated: 111 | description = f"{DEPRECATED_DISPLAY_NAME_PREFIX} {description}" 112 | elif experimental: 113 | description = f"{EXPERIMENTAL_DISPLAY_NAME_PREFIX} {description}" 114 | 115 | # Add the prefix 116 | description = f"{NODES_DISPLAY_NAME_PREFIX} {description}" 117 | 118 | return description 119 | 120 | 121 | def camel_case_to_spaces(text: str) -> str: 122 | # Add space before each capital letter except the first one 123 | spaced_text = re.sub(r"(?<=[a-z])(?=[A-Z])", " ", text) 124 | # Handle sequences of uppercase letters followed by a lowercase letter 125 | spaced_text = re.sub(r"(?<=[A-Z])(?=[A-Z][a-z])", " ", spaced_text) 126 | # Handle sequences of uppercase letters not followed by a lowercase letter 127 | spaced_text = re.sub(r"(?<=[A-Z])(?=[A-Z][A-Z][a-z])", " ", spaced_text) 128 | return spaced_text 129 | -------------------------------------------------------------------------------- /presets/stg_advanced_presets.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "Custom" 4 | }, 5 | { 6 | "name": "13b Dynamic", 7 | "skip_steps_sigma_threshold": 0.997, 8 | "cfg_star_rescale": true, 9 | "sigmas": [1.0, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180], 10 | "cfg_values": [1, 6, 8, 6, 1, 1], 11 | "stg_scale_values": [0, 4, 4, 4, 2, 1], 12 | "stg_rescale_values": [1, 0.5, 0.5, 1, 1, 1], 13 | "stg_layers_indices": [[11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]] 14 | }, 15 | { 16 | "name": "13b Balanced", 17 | "skip_steps_sigma_threshold": 0.998, 18 | "cfg_star_rescale": true, 19 | "sigmas": [1.0, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180], 20 | "cfg_values": [1, 6, 8, 6, 1, 1], 21 | 
"stg_scale_values": [0, 4, 4, 4, 2, 1], 22 | "stg_rescale_values": [1, 0.5, 0.5, 1, 1, 1], 23 | "stg_layers_indices": [[12], [12], [5], [5], [28], [29]] 24 | }, 25 | { 26 | "name": "13b Upscale", 27 | "skip_steps_sigma_threshold": 0.997, 28 | "cfg_star_rescale": true, 29 | "sigmas": [1.0, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180], 30 | "cfg_values": [1, 1, 1, 1, 1, 1], 31 | "stg_scale_values": [1, 1, 1, 1, 1, 1], 32 | "stg_rescale_values": [1, 1, 1, 1, 1, 1], 33 | "stg_layers_indices": [[42], [42], [42], [42], [42], [42]] 34 | }, 35 | { 36 | "name": "2b", 37 | "skip_steps_sigma_threshold": 0.997, 38 | "cfg_star_rescale": true, 39 | "sigmas": [1.0, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180], 40 | "cfg_values": [4, 4, 4, 4, 1, 1], 41 | "stg_scale_values": [2, 2, 2, 2, 1, 0], 42 | "stg_rescale_values": [1, 1, 1, 1, 1, 1], 43 | "stg_layers_indices": [[14], [14], [14], [14], [14], [14]] 44 | } 45 | ] 46 | -------------------------------------------------------------------------------- /prompt_enhancer_nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import comfy.model_management 5 | import comfy.model_patcher 6 | import folder_paths 7 | import torch 8 | from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer 9 | 10 | from .nodes_registry import comfy_node 11 | from .prompt_enhancer_utils import generate_cinematic_prompt 12 | 13 | LLM_NAME = ["unsloth/Llama-3.2-3B-Instruct"] 14 | 15 | IMAGE_CAPTIONER = ["MiaoshouAI/Florence-2-large-PromptGen-v2.0"] 16 | 17 | MODELS_PATH_KEY = "LLM" 18 | 19 | 20 | class PromptEnhancer(torch.nn.Module): 21 | def __init__( 22 | self, 23 | image_caption_processor: AutoProcessor, 24 | image_caption_model: AutoModelForCausalLM, 25 | llm_model: AutoModelForCausalLM, 26 | llm_tokenizer: AutoTokenizer, 27 | ): 28 | super().__init__() 29 | self.image_caption_processor = image_caption_processor 30 | self.image_caption_model = image_caption_model 31 | self.llm_model = llm_model 32 | self.llm_tokenizer = llm_tokenizer 33 | self.device = image_caption_model.device 34 | # model parameters and buffer sizes plus some extra 1GB. 
35 | self.model_size = ( 36 | self.get_model_size(self.image_caption_model) 37 | + self.get_model_size(self.llm_model) 38 | + 1073741824 39 | ) 40 | 41 | def forward(self, prompt, image_conditioning, max_resulting_tokens): 42 | enhanced_prompt = generate_cinematic_prompt( 43 | self.image_caption_model, 44 | self.image_caption_processor, 45 | self.llm_model, 46 | self.llm_tokenizer, 47 | prompt, 48 | image_conditioning, 49 | max_new_tokens=max_resulting_tokens, 50 | ) 51 | 52 | return enhanced_prompt 53 | 54 | @staticmethod 55 | def get_model_size(model): 56 | total_size = sum(p.numel() * p.element_size() for p in model.parameters()) 57 | total_size += sum(b.numel() * b.element_size() for b in model.buffers()) 58 | return total_size 59 | 60 | def memory_required(self, input_shape): 61 | return self.model_size 62 | 63 | 64 | @comfy_node(name="LTXVPromptEnhancerLoader") 65 | class LTXVPromptEnhancerLoader: 66 | @classmethod 67 | def INPUT_TYPES(s): 68 | return { 69 | "required": { 70 | "llm_name": ( 71 | "STRING", 72 | { 73 | "default": LLM_NAME, 74 | "tooltip": "The hugging face name of the llm model to load.", 75 | }, 76 | ), 77 | "image_captioner_name": ( 78 | "STRING", 79 | { 80 | "default": IMAGE_CAPTIONER, 81 | "tooltip": "The hugging face name of the image captioning model to load.", 82 | }, 83 | ), 84 | } 85 | } 86 | 87 | RETURN_TYPES = ("LTXV_PROMPT_ENHANCER",) 88 | RETURN_NAMES = ("prompt_enhancer",) 89 | FUNCTION = "load" 90 | CATEGORY = "lightricks/LTXV" 91 | TITLE = "LTXV Prompt Enhancer (Down)Loader" 92 | OUTPUT_NODE = False 93 | DESCRIPTION = "Downloads and initializes LLM and image captioning models from Hugging Face to enhance text prompts for image generation." 94 | 95 | def model_path_download_if_needed(self, model_name): 96 | model_directory = os.path.join(folder_paths.models_dir, MODELS_PATH_KEY) 97 | os.makedirs(model_directory, exist_ok=True) 98 | 99 | model_name_ = model_name.rsplit("/", 1)[-1] 100 | model_path = os.path.join(model_directory, model_name_) 101 | 102 | if not os.path.exists(model_path): 103 | from huggingface_hub import snapshot_download 104 | 105 | try: 106 | snapshot_download( 107 | repo_id=model_name, 108 | local_dir=model_path, 109 | local_dir_use_symlinks=False, 110 | ) 111 | except Exception: 112 | shutil.rmtree(model_path, ignore_errors=True) 113 | raise 114 | return model_path 115 | 116 | def down_load_llm_model(self, llm_name, load_device): 117 | model_path = self.model_path_download_if_needed(llm_name) 118 | llm_model = AutoModelForCausalLM.from_pretrained( 119 | model_path, 120 | torch_dtype=torch.bfloat16, 121 | ) 122 | 123 | llm_tokenizer = AutoTokenizer.from_pretrained( 124 | model_path, 125 | ) 126 | 127 | return llm_model, llm_tokenizer 128 | 129 | def down_load_image_captioner(self, image_captioner, load_device): 130 | model_path = self.model_path_download_if_needed(image_captioner) 131 | image_caption_model = AutoModelForCausalLM.from_pretrained( 132 | model_path, trust_remote_code=True 133 | ) 134 | 135 | image_caption_processor = AutoProcessor.from_pretrained( 136 | model_path, trust_remote_code=True 137 | ) 138 | 139 | return image_caption_model, image_caption_processor 140 | 141 | def load(self, llm_name, image_captioner_name): 142 | load_device = comfy.model_management.get_torch_device() 143 | offload_device = comfy.model_management.vae_offload_device() 144 | llm_model, llm_tokenizer = self.down_load_llm_model(llm_name, load_device) 145 | image_caption_model, image_caption_processor = self.down_load_image_captioner( 146 | 
image_captioner_name, load_device 147 | ) 148 | 149 | enhancer = PromptEnhancer( 150 | image_caption_processor, image_caption_model, llm_model, llm_tokenizer 151 | ) 152 | patcher = comfy.model_patcher.ModelPatcher( 153 | enhancer, 154 | load_device, 155 | offload_device, 156 | ) 157 | return (patcher,) 158 | 159 | 160 | @comfy_node(name="LTXVPromptEnhancer") 161 | class LTXVPromptEnhancer: 162 | @classmethod 163 | def INPUT_TYPES(s): 164 | return { 165 | "required": { 166 | "prompt": ("STRING",), 167 | "prompt_enhancer": ("LTXV_PROMPT_ENHANCER",), 168 | "max_resulting_tokens": ( 169 | "INT", 170 | {"default": 256, "min": 32, "max": 512}, 171 | ), 172 | }, 173 | "optional": { 174 | "image_prompt": ("IMAGE",), 175 | }, 176 | } 177 | 178 | RETURN_TYPES = ("STRING",) 179 | RETURN_NAMES = ("str",) 180 | FUNCTION = "enhance" 181 | CATEGORY = "lightricks/LTXV" 182 | TITLE = "LTXV Prompt Enhancer" 183 | OUTPUT_NODE = False 184 | DESCRIPTION = ( 185 | "Enhances text prompts for image generation using LLMs. " 186 | "Optionally incorporates reference images to create more contextually relevant descriptions." 187 | ) 188 | 189 | def enhance( 190 | self, 191 | prompt, 192 | prompt_enhancer: comfy.model_patcher.ModelPatcher, 193 | image_prompt: torch.Tensor = None, 194 | max_resulting_tokens=256, 195 | ): 196 | comfy.model_management.free_memory( 197 | prompt_enhancer.memory_required([]), 198 | comfy.model_management.get_torch_device(), 199 | ) 200 | comfy.model_management.load_model_gpu(prompt_enhancer) 201 | model = prompt_enhancer.model 202 | image_conditioning = None 203 | if image_prompt is not None: 204 | permuted_image = image_prompt.permute(3, 0, 1, 2)[None, :] 205 | image_conditioning = [(permuted_image, 0, 1.0)] 206 | 207 | enhanced_prompt = model(prompt, image_conditioning, max_resulting_tokens) 208 | return (enhanced_prompt[0],) 209 | -------------------------------------------------------------------------------- /prompt_enhancer_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import torch 6 | from PIL import Image 7 | 8 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 9 | 10 | T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. 11 | Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. 12 | Start directly with the action, and keep descriptions literal and precise. 13 | Think like a cinematographer describing a shot list. 14 | Do not change the user input intent, just enhance it. 15 | Keep within 150 words. 16 | For best results, build your prompts using this structure: 17 | Start with main action in a single sentence 18 | Add specific details about movements and gestures 19 | Describe character/object appearances precisely 20 | Include background and environment details 21 | Specify camera angles and movements 22 | Describe lighting and colors 23 | Note any changes or sudden events 24 | Do not exceed the 150 word limit! 25 | Output the enhanced prompt only. 26 | 27 | Examples: 28 | user prompt: A man drives a toyota car. 29 | enhanced prompt: A person is driving a car on a two-lane road, holding the steering wheel with both hands. 
The person's hands are light-skinned and they are wearing a black long-sleeved shirt. The steering wheel has a Toyota logo in the center and black leather around it. The car's dashboard is visible, showing a speedometer, tachometer, and navigation screen. The road ahead is straight and there are trees and fields visible on either side. The camera is positioned inside the car, providing a view from the driver's perspective. The lighting is natural and overcast, with a slightly cool tone. 30 | 31 | user prompt: A young woman is sitting on a chair. 32 | enhanced prompt: A young woman with dark, curly hair and pale skin sits on a chair; she wears a dark, intricately patterned dress with a high collar and long, dark gloves that extend past her elbows; the scene is dimly lit, with light streaming in from a large window behind the characters. 33 | 34 | user prompt: Aerial view of a city skyline. 35 | enhanced prompt: The camera pans across a cityscape of tall buildings with a circular building in the center. The camera moves from left to right, showing the tops of the buildings and the circular building in the center. The buildings are various shades of gray and white, and the circular building has a green roof. The camera angle is high, looking down at the city. The lighting is bright, with the sun shining from the upper left, casting shadows from the buildings. 36 | """ 37 | 38 | I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. 39 | Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. 40 | Start directly with the action, and keep descriptions literal and precise. 41 | Think like a cinematographer describing a shot list. 42 | Keep within 150 words. 43 | For best results, build your prompts using this structure: 44 | Describe the image first and then add the user input. Image description should be in first priority! Align to the image caption if it contradicts the user text input. 45 | Start with main action in a single sentence 46 | Add specific details about movements and gestures 47 | Describe character/object appearances precisely 48 | Include background and environment details 49 | Specify camera angles and movements 50 | Describe lighting and colors 51 | Note any changes or sudden events 52 | Align to the image caption if it contradicts the user text input. 53 | Do not exceed the 150 word limit! 54 | Output the enhanced prompt only. 
55 | """ 56 | 57 | 58 | def tensor_to_pil(tensor): 59 | # Ensure tensor is in range [-1, 1] 60 | assert tensor.min() >= -1 and tensor.max() <= 1 61 | 62 | # Convert from [-1, 1] to [0, 1] 63 | tensor = (tensor + 1) / 2 64 | 65 | # Rearrange from [C, H, W] to [H, W, C] 66 | tensor = tensor.permute(1, 2, 0) 67 | 68 | # Convert to numpy array and then to uint8 range [0, 255] 69 | numpy_image = (tensor.cpu().numpy() * 255).astype("uint8") 70 | 71 | # Convert to PIL Image 72 | return Image.fromarray(numpy_image) 73 | 74 | 75 | def generate_cinematic_prompt( 76 | image_caption_model, 77 | image_caption_processor, 78 | prompt_enhancer_model, 79 | prompt_enhancer_tokenizer, 80 | prompt: Union[str, List[str]], 81 | conditioning_items: Optional[List[Tuple[torch.Tensor, int, float]]] = None, 82 | max_new_tokens: int = 256, 83 | ) -> List[str]: 84 | prompts = [prompt] if isinstance(prompt, str) else prompt 85 | 86 | if conditioning_items is None: 87 | prompts = _generate_t2v_prompt( 88 | prompt_enhancer_model, 89 | prompt_enhancer_tokenizer, 90 | prompts, 91 | max_new_tokens, 92 | T2V_CINEMATIC_PROMPT, 93 | ) 94 | else: 95 | # if len(conditioning_items) > 1 or conditioning_items[0][1] != 0: 96 | # logger.warning( 97 | # "prompt enhancement does only support first frame of conditioning items, returning original prompts" 98 | # ) 99 | # return prompts 100 | 101 | first_frame_conditioning_item = conditioning_items[0] 102 | first_frames = _get_first_frames_from_conditioning_item( 103 | first_frame_conditioning_item 104 | ) 105 | 106 | assert len(first_frames) == len( 107 | prompts 108 | ), "Number of conditioning frames must match number of prompts" 109 | 110 | prompts = _generate_i2v_prompt( 111 | image_caption_model, 112 | image_caption_processor, 113 | prompt_enhancer_model, 114 | prompt_enhancer_tokenizer, 115 | prompts, 116 | first_frames, 117 | max_new_tokens, 118 | I2V_CINEMATIC_PROMPT, 119 | ) 120 | 121 | return prompts 122 | 123 | 124 | def _get_first_frames_from_conditioning_item( 125 | conditioning_item: Tuple[torch.Tensor, int, float] 126 | ) -> List[Image.Image]: 127 | frames_tensor = conditioning_item[0] 128 | # tensor shape: [batch_size, 3, num_frames, height, width], take first frame from each sample 129 | return [ 130 | tensor_to_pil(frames_tensor[i, :, 0, :, :]) 131 | for i in range(frames_tensor.shape[0]) 132 | ] 133 | 134 | 135 | def _generate_t2v_prompt( 136 | prompt_enhancer_model, 137 | prompt_enhancer_tokenizer, 138 | prompts: List[str], 139 | max_new_tokens: int, 140 | system_prompt: str, 141 | ) -> List[str]: 142 | messages = [ 143 | [ 144 | {"role": "system", "content": system_prompt}, 145 | {"role": "user", "content": f"user_prompt: {p}"}, 146 | ] 147 | for p in prompts 148 | ] 149 | 150 | texts = [ 151 | prompt_enhancer_tokenizer.apply_chat_template( 152 | m, tokenize=False, add_generation_prompt=True 153 | ) 154 | for m in messages 155 | ] 156 | model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to( 157 | prompt_enhancer_model.device 158 | ) 159 | 160 | return _generate_and_decode_prompts( 161 | prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens 162 | ) 163 | 164 | 165 | def _generate_i2v_prompt( 166 | image_caption_model, 167 | image_caption_processor, 168 | prompt_enhancer_model, 169 | prompt_enhancer_tokenizer, 170 | prompts: List[str], 171 | first_frames: List[Image.Image], 172 | max_new_tokens: int, 173 | system_prompt: str, 174 | ) -> List[str]: 175 | image_captions = _generate_image_captions( 176 | image_caption_model, 
image_caption_processor, first_frames 177 | ) 178 | 179 | messages = [ 180 | [ 181 | {"role": "system", "content": system_prompt}, 182 | {"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"}, 183 | ] 184 | for p, c in zip(prompts, image_captions) 185 | ] 186 | 187 | texts = [ 188 | prompt_enhancer_tokenizer.apply_chat_template( 189 | m, tokenize=False, add_generation_prompt=True 190 | ) 191 | for m in messages 192 | ] 193 | model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to( 194 | prompt_enhancer_model.device 195 | ) 196 | 197 | return _generate_and_decode_prompts( 198 | prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens 199 | ) 200 | 201 | 202 | def _generate_image_captions( 203 | image_caption_model, 204 | image_caption_processor, 205 | images: List[Image.Image], 206 | system_prompt: str = "", 207 | ) -> List[str]: 208 | image_caption_prompts = [system_prompt] * len(images) 209 | inputs = image_caption_processor( 210 | image_caption_prompts, images, return_tensors="pt" 211 | ).to(image_caption_model.device) 212 | 213 | with torch.inference_mode(): 214 | generated_ids = image_caption_model.generate( 215 | input_ids=inputs["input_ids"], 216 | pixel_values=inputs["pixel_values"], 217 | max_new_tokens=1024, 218 | do_sample=False, 219 | num_beams=3, 220 | ) 221 | 222 | return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True) 223 | 224 | 225 | def _get_random_scene_type(): 226 | """ 227 | Randomly select a scene type to add to the prompt. 228 | """ 229 | types = [ 230 | "The scene is captured in real-life footage.", 231 | "The scene is computer-generated imagery.", 232 | "The scene appears to be from a movie.", 233 | "The scene appears to be from a TV show.", 234 | "The scene is captured in a studio.", 235 | ] 236 | return random.choice(types) 237 | 238 | 239 | def _generate_and_decode_prompts( 240 | prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int 241 | ) -> List[str]: 242 | with torch.inference_mode(): 243 | outputs = prompt_enhancer_model.generate( 244 | **model_inputs, max_new_tokens=max_new_tokens 245 | ) 246 | generated_ids = [ 247 | output_ids[len(input_ids) :] 248 | for input_ids, output_ids in zip(model_inputs.input_ids, outputs) 249 | ] 250 | decoded_prompts = prompt_enhancer_tokenizer.batch_decode( 251 | generated_ids, skip_special_tokens=True 252 | ) 253 | 254 | decoded_prompts = [p + f" {_get_random_scene_type()}"
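        # append one randomly chosen scene-type sentence (see
        # _get_random_scene_type above); the choices already end with a period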
for p in decoded_prompts] 255 | 256 | print(decoded_prompts) 257 | 258 | return decoded_prompts 259 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-ltxvideo" 3 | version = "0.1.0" 4 | description = "Custom nodes for LTX-Video support in ComfyUI" 5 | authors = [ 6 | { name = "Andrew Kvochko", email = "akvochko@lightricks.com" } 7 | ] 8 | requires-python = ">=3.10" 9 | readme = "README.md" 10 | license = { file = "LICENSE" } 11 | dependencies = [ 12 | "diffusers", 13 | "huggingface_hub>=0.25.2", 14 | "transformers[timm]>=4.45.0", 15 | "einops" 16 | ] 17 | 18 | [project.optional-dependencies] 19 | internal = [ 20 | "ltx-video@git+https://github.com/Lightricks/LTX-Video@ltx-video-0.9.7", 21 | "av>=10.0.0", 22 | "q8-kernels==0.0.5" 23 | ] 24 | dev = [ 25 | "pre-commit>=4.0.1", 26 | "pytest>=8.0.0", 27 | "websocket-client==1.6.1", 28 | "scikit-image==0.24.0" 29 | ] 30 | 31 | [tool.isort] 32 | profile = "black" 33 | line_length = 88 34 | force_single_line = false 35 | -------------------------------------------------------------------------------- /q8_nodes.py: -------------------------------------------------------------------------------- 1 | try: 2 | from q8_kernels.integration.patch_transformer import ( 3 | patch_comfyui_native_transformer, 4 | patch_comfyui_transformer, 5 | ) 6 | 7 | Q8_AVAILABLE = True 8 | except ImportError: 9 | Q8_AVAILABLE = False 10 | 11 | from .nodes_registry import comfy_node 12 | 13 | 14 | def check_q8_available(): 15 | if not Q8_AVAILABLE: 16 | raise ImportError( 17 | "Q8 kernels are not available. To use this feature, install the q8_kernels package from " 18 | "https://github.com/Lightricks/LTX-Video-Q8-Kernels" 19 | ) 20 | 21 | 22 | @comfy_node(name="LTXQ8Patch") 23 | class LTXVQ8Patch: 24 | @classmethod 25 | def INPUT_TYPES(s): 26 | return { 27 | "required": { 28 | "model": ("MODEL",), 29 | "use_fp8_attention": ( 30 | "BOOLEAN", 31 | {"default": False, "tooltip": "Use FP8 attention."}, 32 | ), 33 | } 34 | } 35 | 36 | RETURN_TYPES = ("MODEL",) 37 | FUNCTION = "patch" 38 | CATEGORY = "lightricks/LTXV" 39 | TITLE = "LTXV Q8 Patcher" 40 | 41 | def patch(self, model, use_fp8_attention): 42 | check_q8_available() 43 | m = model.clone() 44 | diffusion_key = "diffusion_model" 45 | diffusion_model = m.get_model_object(diffusion_key) 46 | if diffusion_model.__class__.__name__ == "LTXVTransformer3D": 47 | transformer_key = "diffusion_model.transformer" 48 | patcher = patch_comfyui_transformer 49 | else: 50 | transformer_key = "diffusion_model" 51 | patcher = patch_comfyui_native_transformer 52 | transformer = m.get_model_object(transformer_key) 53 | patcher(transformer, use_fp8_attention, True) 54 | m.add_object_patch(transformer_key, transformer) 55 | return (m,) 56 | -------------------------------------------------------------------------------- /recurrent_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from comfy_extras.nodes_custom_sampler import CFGGuider, SamplerCustomAdvanced 3 | from comfy_extras.nodes_lt import LTXVAddGuide, LTXVCropGuides 4 | 5 | from .latents import LTXVAddLatents, LTXVSelectLatents 6 | from .nodes_registry import comfy_node 7 | from .tricks import AddLatentGuideNode 8 | 9 | 10 | @comfy_node(description="Linear transition with overlap") 11 | class LinearOverlapLatentTransition: 12 | @classmethod 13 | def 
INPUT_TYPES(s): 14 | return { 15 | "required": { 16 | "samples1": ("LATENT",), 17 | "samples2": ("LATENT",), 18 | "overlap": ("INT", {"default": 1, "min": 1, "max": 256}), 19 | }, 20 | "optional": { 21 | "axis": ("INT", {"default": 0}), 22 | }, 23 | } 24 | 25 | RETURN_TYPES = ("LATENT",) 26 | FUNCTION = "process" 27 | 28 | CATEGORY = "Lightricks/latent" 29 | 30 | def get_subbatch(self, samples): 31 | s = samples.copy() 32 | samples = s["samples"] 33 | return samples 34 | 35 | def process(self, samples1, samples2, overlap, axis=0): 36 | samples1 = self.get_subbatch(samples1) 37 | samples2 = self.get_subbatch(samples2) 38 | 39 | # Create transition coefficients 40 | alpha = torch.linspace(1, 0, overlap + 2)[1:-1].to(samples1.device) 41 | 42 | # Create shape for broadcasting based on the axis 43 | shape = [1] * samples1.dim() 44 | shape[axis] = alpha.size(0) 45 | alpha = alpha.reshape(shape) 46 | 47 | # Create slices for the overlap regions 48 | slice_all = [slice(None)] * samples1.dim() 49 | slice_overlap1 = slice_all.copy() 50 | slice_overlap1[axis] = slice(-overlap, None) 51 | slice_overlap2 = slice_all.copy() 52 | slice_overlap2[axis] = slice(0, overlap) 53 | slice_rest1 = slice_all.copy() 54 | slice_rest1[axis] = slice(None, -overlap) 55 | slice_rest2 = slice_all.copy() 56 | slice_rest2[axis] = slice(overlap, None) 57 | 58 | # Combine samples 59 | parts = [ 60 | samples1[tuple(slice_rest1)], 61 | alpha * samples1[tuple(slice_overlap1)] 62 | + (1 - alpha) * samples2[tuple(slice_overlap2)], 63 | samples2[tuple(slice_rest2)], 64 | ] 65 | 66 | combined_samples = torch.cat(parts, dim=axis) 67 | combined_batch_index = torch.arange(0, combined_samples.shape[0]) 68 | 69 | return ( 70 | { 71 | "samples": combined_samples, 72 | "batch_index": combined_batch_index, 73 | }, 74 | ) 75 | 76 | 77 | @comfy_node( 78 | name="LTXVRecurrentKSampler", 79 | ) 80 | class LTXVRecurrentKSampler: 81 | 82 | @classmethod 83 | def INPUT_TYPES(s): 84 | return { 85 | "required": { 86 | "model": ("MODEL",), 87 | "vae": ("VAE",), 88 | "noise": ("NOISE",), 89 | "sampler": ("SAMPLER",), 90 | "sigmas": ("SIGMAS",), 91 | "latents": ("LATENT",), 92 | "chunk_sizes": ("STRING", {"default": "3", "multiline": False}), 93 | "overlaps": ("STRING", {"default": "1", "multiline": False}), 94 | "positive": ("CONDITIONING",), 95 | "negative": ("CONDITIONING",), 96 | "input_image": ("IMAGE",), 97 | "linear_blend_latents": ("BOOLEAN", {"default": True}), 98 | "conditioning_strength": ( 99 | "FLOAT", 100 | {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.01}, 101 | ), 102 | }, 103 | "optional": { 104 | "guider": ("GUIDER",), 105 | }, 106 | } 107 | 108 | RETURN_TYPES = ( 109 | "LATENT", 110 | "LATENT", 111 | ) 112 | RETURN_NAMES = ( 113 | "output", 114 | "denoised_output", 115 | ) 116 | 117 | FUNCTION = "sample" 118 | 119 | CATEGORY = "sampling" 120 | 121 | def sample( 122 | self, 123 | model, 124 | vae, 125 | noise, 126 | sampler, 127 | sigmas, 128 | latents, 129 | chunk_sizes, 130 | overlaps, 131 | positive, 132 | negative, 133 | input_image, 134 | linear_blend_latents, 135 | conditioning_strength, 136 | guider=None, 137 | ): 138 | select_latents = LTXVSelectLatents().select_latents 139 | add_latent_guide = AddLatentGuideNode().generate 140 | add_latents = LTXVAddLatents().add_latents 141 | positive_orig = positive.copy() 142 | negative_orig = negative.copy() 143 | 144 | # Parse chunk sizes and overlaps from strings 145 | chunk_sizes = [int(x) for x in chunk_sizes.split(",")] 146 | overlaps = [int(x) for x in overlaps.split(",")] 147 | 
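        # For intuition (a sketch, assuming chunk_sizes="3", overlaps="1" and
        # 9 latent frames): chunk_stride = 3 - 1 = 2, so chunks start at
        # frames 0, 2, 4 and 6, and each chunk is re-conditioned on the last
        # `overlap` latents of the chunk before it.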
148 | # Extend lists if shorter than number of sigma steps 149 | n_steps = len(sigmas) - 1 150 | if len(chunk_sizes) < n_steps: 151 | chunk_sizes.extend([chunk_sizes[-1]] * (n_steps - len(chunk_sizes))) 152 | if len(overlaps) < n_steps: 153 | overlaps.extend([overlaps[-1]] * (n_steps - len(overlaps))) 154 | 155 | # Initialize working latents 156 | current_latents = latents.copy() 157 | t_latents = None 158 | # Loop through sigma pairs for progressive denoising 159 | for i in range(n_steps): 160 | current_sigmas = sigmas[i : i + 2] 161 | current_chunk_size = chunk_sizes[i] 162 | current_overlap = overlaps[i] 163 | 164 | print(f"\nProcessing sigma step {i} with sigmas {current_sigmas}") 165 | print( 166 | f"Using chunk size {current_chunk_size} and overlap {current_overlap}" 167 | ) 168 | 169 | # Calculate valid chunk starts to ensure the last chunk isn't shorter than the overlap 170 | total_frames = current_latents["samples"].shape[2] 171 | chunk_stride = current_chunk_size - current_overlap 172 | valid_chunk_starts = list( 173 | range(0, total_frames - current_overlap, chunk_stride) 174 | ) 175 | 176 | # If the last chunk would be shorter than a full chunk, warn about its reduced size 177 | if ( 178 | total_frames > chunk_stride 179 | and (total_frames - valid_chunk_starts[-1]) < current_chunk_size 180 | ): 181 | print( 182 | "last chunk is too short, it will only be of size", 183 | total_frames - valid_chunk_starts[-1], 184 | "frames", 185 | ) 186 | 187 | # Process each chunk for current sigma pair 188 | for i_chunk, chunk_start in enumerate(valid_chunk_starts): 189 | (latents_chunk,) = select_latents( 190 | current_latents, chunk_start, chunk_start + current_chunk_size - 1 191 | ) 192 | print(f"Processing chunk {i_chunk} starting at frame {chunk_start}") 193 | 194 | if i_chunk == 0: 195 | positive, negative, latents_chunk = LTXVAddGuide().generate( 196 | positive_orig, 197 | negative_orig, 198 | vae, 199 | latents_chunk, 200 | input_image, 201 | 0, 202 | 0.75, 203 | ) 204 | else: 205 | (cond_latent,) = select_latents(t_latents, -current_overlap, -1) 206 | model, latents_chunk = add_latent_guide( 207 | model, latents_chunk, cond_latent, 0, conditioning_strength 208 | ) 209 | 210 | if guider is None: 211 | (guider_obj,) = CFGGuider().get_guider( 212 | model, positive, negative, 1.0 213 | ) 214 | else: 215 | guider_obj = guider 216 | (_, denoised_latents) = SamplerCustomAdvanced().sample( 217 | noise, guider_obj, sampler, current_sigmas, latents_chunk 218 | ) 219 | (positive, negative, denoised_latents) = LTXVCropGuides().crop( 220 | positive, negative, denoised_latents 221 | ) 222 | 223 | if i_chunk == 0: 224 | t_latents = denoised_latents 225 | else: 226 | if linear_blend_latents and current_overlap > 1: 227 | # the first output latent is the result of a 1:8 latent 228 | # reinterpreted as a 1:1 latent, so we ignore it 229 | (denoised_latents_drop_first,) = select_latents( 230 | denoised_latents, 1, -1 231 | ) 232 | (t_latents,) = LinearOverlapLatentTransition().process( 233 | t_latents, 234 | denoised_latents_drop_first, 235 | current_overlap - 1, 236 | axis=2, 237 | ) 238 | else: 239 | (truncated_denoised_latents,) = select_latents( 240 | denoised_latents, current_overlap, -1 241 | ) 242 | (t_latents,) = add_latents( 243 | t_latents, truncated_denoised_latents 244 | ) 245 | 246 | print( 247 | f"Completed chunk {i_chunk}, current output shape: {t_latents['samples'].shape}" 248 | ) 249 | 250 | # Update current_latents for next sigma step 251 | current_latents = t_latents.copy() 252 | print(f"Completed 
sigma step {i}") 253 | 254 | return t_latents, t_latents 255 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers 2 | einops 3 | huggingface_hub>=0.25.2 4 | transformers[timm]>=4.45.0 5 | -------------------------------------------------------------------------------- /tiled_sampler.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import comfy 4 | import torch 5 | from comfy_extras.nodes_custom_sampler import SamplerCustomAdvanced 6 | from comfy_extras.nodes_lt import LTXVAddGuide, LTXVCropGuides 7 | 8 | from .latents import LTXVAddLatentGuide, LTXVSelectLatents 9 | from .nodes_registry import comfy_node 10 | 11 | 12 | @comfy_node( 13 | name="LTXVTiledSampler", 14 | ) 15 | class LTXVTiledSampler: 16 | 17 | @classmethod 18 | def INPUT_TYPES(s): 19 | return { 20 | "required": { 21 | "model": ("MODEL",), 22 | "vae": ("VAE",), 23 | "noise": ("NOISE",), 24 | "sampler": ("SAMPLER",), 25 | "sigmas": ("SIGMAS",), 26 | "guider": ("GUIDER",), 27 | "latents": ("LATENT",), 28 | "horizontal_tiles": ("INT", {"default": 1, "min": 1, "max": 6}), 29 | "vertical_tiles": ("INT", {"default": 1, "min": 1, "max": 6}), 30 | "overlap": ("INT", {"default": 1, "min": 1, "max": 8}), 31 | "latents_cond_strength": ( 32 | "FLOAT", 33 | {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01}, 34 | ), 35 | "boost_latent_similarity": ( 36 | "BOOLEAN", 37 | {"default": False}, 38 | ), 39 | "crop": (["center", "disabled"], {"default": "disabled"}), 40 | }, 41 | "optional": { 42 | "optional_cond_images": ("IMAGE",), 43 | "optional_cond_indices": ("STRING", {"default": "0"}), 44 | "images_cond_strengths": ("STRING", {"default": "0.9"}), 45 | }, 46 | } 47 | 48 | RETURN_TYPES = ( 49 | "LATENT", 50 | "LATENT", 51 | ) 52 | RETURN_NAMES = ( 53 | "output", 54 | "denoised_output", 55 | ) 56 | 57 | FUNCTION = "sample" 58 | 59 | CATEGORY = "sampling" 60 | 61 | def sample( 62 | self, 63 | model, 64 | vae, 65 | noise, 66 | sampler, 67 | sigmas, 68 | guider, 69 | latents, 70 | horizontal_tiles, 71 | vertical_tiles, 72 | overlap, 73 | latents_cond_strength, 74 | boost_latent_similarity, 75 | crop="disabled", 76 | optional_cond_images=None, 77 | optional_cond_indices="0", 78 | images_cond_strengths="0.9", 79 | ): 80 | 81 | # Get the latent samples 82 | samples = latents["samples"] 83 | 84 | batch, channels, frames, height, width = samples.shape 85 | time_scale_factor, width_scale_factor, height_scale_factor = ( 86 | vae.downscale_index_formula 87 | ) 88 | # Validate image dimensions if provided 89 | if optional_cond_images is not None: 90 | img_height = height * height_scale_factor 91 | img_width = width * width_scale_factor 92 | cond_images = comfy.utils.common_upscale( 93 | optional_cond_images.movedim(-1, 1), 94 | img_width, 95 | img_height, 96 | "bicubic", 97 | crop=crop, 98 | ).movedim(1, -1) 99 | print("cond_images shape after resize", cond_images.shape) 100 | img_batch, img_height, img_width, img_channels = cond_images.shape 101 | else: 102 | cond_images = None 103 | 104 | if optional_cond_indices is not None and optional_cond_images is not None: 105 | optional_cond_indices = optional_cond_indices.split(",") 106 | optional_cond_indices = [int(i) for i in optional_cond_indices] 107 | assert len(optional_cond_indices) == len( 108 | optional_cond_images 109 | ), "Number of optional cond images must match number of optional cond 
indices" 110 | 111 | images_cond_strengths = [float(i) for i in images_cond_strengths.split(",")] 112 | if optional_cond_images is not None and len(images_cond_strengths) < len( 113 | optional_cond_images 114 | ): 115 | # Repeat the last value to match the length of optional_cond_images 116 | images_cond_strengths = images_cond_strengths + [ 117 | images_cond_strengths[-1] 118 | ] * (len(optional_cond_images) - len(images_cond_strengths)) 119 | 120 | # Calculate tile sizes with overlap 121 | base_tile_height = (height + (vertical_tiles - 1) * overlap) // vertical_tiles 122 | base_tile_width = (width + (horizontal_tiles - 1) * overlap) // horizontal_tiles 123 | 124 | # Initialize output tensor and weight tensor 125 | output = torch.zeros_like(samples) 126 | denoised_output = torch.zeros_like(samples) 127 | weights = torch.zeros_like(samples) 128 | 129 | # Get positive and negative conditioning 130 | try: 131 | positive, negative = guider.raw_conds 132 | except AttributeError: 133 | raise ValueError( 134 | "Guider does not have raw conds, cannot use it as a guider. " 135 | "Please use STGGuiderAdvanced." 136 | ) 137 | 138 | # Process each tile 139 | for v in range(vertical_tiles): 140 | for h in range(horizontal_tiles): 141 | # Calculate tile boundaries 142 | h_start = h * (base_tile_width - overlap) 143 | v_start = v * (base_tile_height - overlap) 144 | 145 | # Adjust end positions for edge tiles 146 | h_end = ( 147 | min(h_start + base_tile_width, width) 148 | if h < horizontal_tiles - 1 149 | else width 150 | ) 151 | v_end = ( 152 | min(v_start + base_tile_height, height) 153 | if v < vertical_tiles - 1 154 | else height 155 | ) 156 | 157 | # Calculate actual tile dimensions 158 | tile_height = v_end - v_start 159 | tile_width = h_end - h_start 160 | 161 | print(f"Processing tile at row {v}, col {h}:") 162 | print(f" Position: ({v_start}:{v_end}, {h_start}:{h_end})") 163 | print(f" Size: {tile_height}x{tile_width}") 164 | 165 | # Extract tile 166 | tile = samples[:, :, :, v_start:v_end, h_start:h_end] 167 | 168 | # Create tile latents dict 169 | tile_latents = {"samples": tile} 170 | unconditioned_tile_latents = tile_latents.copy() 171 | 172 | # Handle image conditioning if provided 173 | if cond_images is not None: 174 | # Scale coordinates for image 175 | img_h_start = v_start * height_scale_factor 176 | img_h_end = v_end * height_scale_factor 177 | img_w_start = h_start * width_scale_factor 178 | img_w_end = h_end * width_scale_factor 179 | 180 | # Create copies of conditioning for this tile 181 | tile_positive = positive.copy() 182 | tile_negative = negative.copy() 183 | 184 | for i_cond_image, ( 185 | cond_image, 186 | cond_image_idx, 187 | cond_image_strength, 188 | ) in enumerate( 189 | zip(cond_images, optional_cond_indices, images_cond_strengths) 190 | ): 191 | # Extract image tile 192 | img_tile = cond_image[ 193 | img_h_start:img_h_end, img_w_start:img_w_end, : 194 | ].unsqueeze(0) 195 | 196 | print( 197 | f"Applying image conditioning on cond image {i_cond_image} for tile at row {v}, col {h} with strength {cond_image_strength} at position {cond_image_idx}:" 198 | ) 199 | print( 200 | f" Image tile position: ({img_h_start}:{img_h_end}, {img_w_start}:{img_w_end})" 201 | ) 202 | print(f" Image tile size: {img_tile.shape}") 203 | 204 | # Add guide from image tile 205 | ( 206 | tile_positive, 207 | tile_negative, 208 | tile_latents, 209 | ) = LTXVAddGuide().generate( 210 | positive=tile_positive, 211 | negative=tile_negative, 212 | vae=vae, 213 | latent=tile_latents, 214 | 
image=img_tile, 215 | frame_idx=cond_image_idx, 216 | strength=cond_image_strength, 217 | ) 218 | if boost_latent_similarity: 219 | middle_latent_idx = (frames - 1) // 2 220 | middle_index_latent = LTXVSelectLatents().select_latents( 221 | samples=unconditioned_tile_latents, 222 | start_index=middle_latent_idx, 223 | end_index=middle_latent_idx, 224 | )[0] 225 | last_index_latent = LTXVSelectLatents().select_latents( 226 | samples=unconditioned_tile_latents, 227 | start_index=-1, 228 | end_index=-1, 229 | )[0] 230 | print( 231 | f"using LTXVAddLatentGuide on tiled latent with latent index {middle_latent_idx} and strength {latents_cond_strength}" 232 | ) 233 | ( 234 | tile_positive, 235 | tile_negative, 236 | tile_latents, 237 | ) = LTXVAddLatentGuide().generate( 238 | vae=vae, 239 | positive=tile_positive, 240 | negative=tile_negative, 241 | latent=tile_latents, 242 | guiding_latent=middle_index_latent, 243 | latent_idx=middle_latent_idx, 244 | strength=latents_cond_strength, 245 | ) 246 | print( 247 | f"using LTXVAddLatentGuide on tiled latent with latent index {frames-1} and strength {latents_cond_strength}" 248 | ) 249 | ( 250 | tile_positive, 251 | tile_negative, 252 | tile_latents, 253 | ) = LTXVAddLatentGuide().generate( 254 | vae=vae, 255 | positive=tile_positive, 256 | negative=tile_negative, 257 | latent=tile_latents, 258 | guiding_latent=last_index_latent, 259 | latent_idx=frames - 1, 260 | strength=latents_cond_strength, 261 | ) 262 | 263 | guider = copy.copy(guider) 264 | guider.set_conds(tile_positive, tile_negative) 265 | 266 | # Denoise the tile 267 | denoised_tile = SamplerCustomAdvanced().sample( 268 | noise=noise, 269 | guider=guider, 270 | sampler=sampler, 271 | sigmas=sigmas, 272 | latent_image=tile_latents, 273 | )[0] 274 | 275 | # Clean up guides if image conditioning was used 276 | if cond_images is not None: 277 | print("before guide crop", denoised_tile["samples"].shape) 278 | tile_positive, tile_negative, denoised_tile = LTXVCropGuides().crop( 279 | positive=tile_positive, 280 | negative=tile_negative, 281 | latent=denoised_tile, 282 | ) 283 | print("after guide crop", denoised_tile["samples"].shape) 284 | 285 | # Create weight mask for this tile 286 | tile_weights = torch.ones_like(tile) 287 | 288 | # Apply horizontal blending weights 289 | if h > 0: # Left overlap 290 | h_blend = torch.linspace(0, 1, overlap, device=tile.device) 291 | tile_weights[:, :, :, :, :overlap] *= h_blend.view(1, 1, 1, 1, -1) 292 | if h < horizontal_tiles - 1: # Right overlap 293 | h_blend = torch.linspace(1, 0, overlap, device=tile.device) 294 | tile_weights[:, :, :, :, -overlap:] *= h_blend.view(1, 1, 1, 1, -1) 295 | 296 | # Apply vertical blending weights 297 | if v > 0: # Top overlap 298 | v_blend = torch.linspace(0, 1, overlap, device=tile.device) 299 | tile_weights[:, :, :, :overlap, :] *= v_blend.view(1, 1, 1, -1, 1) 300 | if v < vertical_tiles - 1: # Bottom overlap 301 | v_blend = torch.linspace(1, 0, overlap, device=tile.device) 302 | tile_weights[:, :, :, -overlap:, :] *= v_blend.view(1, 1, 1, -1, 1) 303 | 304 | # Add weighted tile to output 305 | output[:, :, :, v_start:v_end, h_start:h_end] += ( 306 | denoised_tile["samples"] * tile_weights 307 | ) 308 | denoised_output[:, :, :, v_start:v_end, h_start:h_end] += ( 309 | denoised_tile["samples"] * tile_weights 310 | ) 311 | 312 | # Add weights to weight tensor 313 | weights[:, :, :, v_start:v_end, h_start:h_end] += tile_weights 314 | 315 | # Normalize by weights 316 | output = output / (weights + 1e-8) 317 | denoised_output = 
denoised_output / (weights + 1e-8) 318 | 319 | return {"samples": output}, {"samples": denoised_output} 320 | -------------------------------------------------------------------------------- /tricks/__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes.attn_bank_nodes import ( 2 | LTXAttentionBankNode, 3 | LTXAttentioOverrideNode, 4 | LTXPrepareAttnInjectionsNode, 5 | ) 6 | from .nodes.attn_override_node import LTXAttnOverrideNode 7 | from .nodes.latent_guide_node import AddLatentGuideNode 8 | from .nodes.ltx_feta_enhance_node import LTXFetaEnhanceNode 9 | from .nodes.ltx_flowedit_nodes import LTXFlowEditCFGGuiderNode, LTXFlowEditSamplerNode 10 | from .nodes.ltx_inverse_model_pred_nodes import ( 11 | LTXForwardModelSamplingPredNode, 12 | LTXReverseModelSamplingPredNode, 13 | ) 14 | from .nodes.ltx_pag_node import LTXPerturbedAttentionNode 15 | from .nodes.modify_ltx_model_node import ModifyLTXModelNode 16 | from .nodes.rectified_sampler_nodes import ( 17 | LTXRFForwardODESamplerNode, 18 | LTXRFReverseODESamplerNode, 19 | ) 20 | 21 | NODE_CLASS_MAPPINGS = { 22 | "ModifyLTXModel": ModifyLTXModelNode, 23 | "AddLatentGuide": AddLatentGuideNode, 24 | "LTXForwardModelSamplingPred": LTXForwardModelSamplingPredNode, 25 | "LTXReverseModelSamplingPred": LTXReverseModelSamplingPredNode, 26 | "LTXRFForwardODESampler": LTXRFForwardODESamplerNode, 27 | "LTXRFReverseODESampler": LTXRFReverseODESamplerNode, 28 | "LTXAttentionBank": LTXAttentionBankNode, 29 | "LTXPrepareAttnInjections": LTXPrepareAttnInjectionsNode, 30 | "LTXAttentioOverride": LTXAttentioOverrideNode, 31 | "LTXPerturbedAttention": LTXPerturbedAttentionNode, 32 | "LTXAttnOverride": LTXAttnOverrideNode, 33 | "LTXFlowEditCFGGuider": LTXFlowEditCFGGuiderNode, 34 | "LTXFlowEditSampler": LTXFlowEditSamplerNode, 35 | "LTXFetaEnhance": LTXFetaEnhanceNode, 36 | } 37 | 38 | NODE_DISPLAY_NAME_MAPPINGS = { 39 | "ModifyLTXModel": "Modify LTX Model", 40 | "AddLatentGuide": "Add LTX Latent Guide", 41 | "LTXAddImageGuide": "Add LTX Image Guide", 42 | "LTXForwardModelSamplingPred": "LTX Forward Model Pred", 43 | "LTXReverseModelSamplingPred": "LTX Reverse Model Pred", 44 | "LTXRFForwardODESampler": "LTX Rf-Inv Forward Sampler", 45 | "LTXRFReverseODESampler": "LTX Rf-Inv Reverse Sampler", 46 | "LTXAttentionBank": "LTX Attention Bank", 47 | "LTXPrepareAttnInjections": "LTX Prepare Attn Injection", 48 | "LTXAttentioOverride": "LTX Attn Block Override", 49 | "LTXPerturbedAttention": "LTX Apply Perturbed Attention", 50 | "LTXAttnOverride": "LTX Attention Override", 51 | "LTXFlowEditCFGGuider": "LTX Flow Edit CFG Guider", 52 | "LTXFlowEditSampler": "LTX Flow Edit Sampler", 53 | "LTXFetaEnhance": "LTX Feta Enhance", 54 | } 55 | -------------------------------------------------------------------------------- /tricks/modules/ltx_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import comfy.ldm.common_dit 4 | import comfy.ldm.modules.attention 5 | import torch 6 | from comfy.ldm.lightricks.model import ( 7 | BasicTransformerBlock, 8 | LTXVModel, 9 | apply_rotary_emb, 10 | precompute_freqs_cis, 11 | ) 12 | from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords 13 | from torch import nn 14 | 15 | from ..utils.feta_enhance_utils import get_feta_scores 16 | 17 | 18 | class LTXModifiedCrossAttention(nn.Module): 19 | def forward(self, x, context=None, mask=None, pe=None, transformer_options={}): 20 | context = x if context is None 
else context 21 | context_v = x if context is None else context 22 | 23 | step = transformer_options.get("step", -1) 24 | total_steps = transformer_options.get("total_steps", 0) 25 | attn_bank = transformer_options.get("attn_bank", None) 26 | sample_mode = transformer_options.get("sample_mode", None) 27 | if attn_bank is not None and self.idx in attn_bank["block_map"]: 28 | len_conds = len(transformer_options["cond_or_uncond"]) 29 | pred_order = transformer_options["pred_order"] 30 | if ( 31 | sample_mode == "forward" 32 | and total_steps - step - 1 < attn_bank["save_steps"] 33 | ): 34 | step_idx = f"{pred_order}_{total_steps-step-1}" 35 | attn_bank["block_map"][self.idx][step_idx] = x.cpu() 36 | elif sample_mode == "reverse" and step < attn_bank["inject_steps"]: 37 | step_idx = f"{pred_order}_{step}" 38 | inject_settings = attn_bank.get("inject_settings", {}) 39 | if len(inject_settings) > 0: 40 | inj = ( 41 | attn_bank["block_map"][self.idx][step_idx] 42 | .to(x.device) 43 | .repeat(len_conds, 1, 1) 44 | ) 45 | if "q" in inject_settings: 46 | x = inj 47 | if "k" in inject_settings: 48 | context = inj 49 | if "v" in inject_settings: 50 | context_v = inj 51 | 52 | q = self.to_q(x) 53 | k = self.to_k(context) 54 | v = self.to_v(context_v) 55 | 56 | q = self.q_norm(q) 57 | k = self.k_norm(k) 58 | 59 | if pe is not None: 60 | q = apply_rotary_emb(q, pe) 61 | k = apply_rotary_emb(k, pe) 62 | 63 | feta_score = None 64 | if ( 65 | transformer_options.get("feta_weight", 0) > 0 66 | and self.idx in transformer_options["feta_layers"]["layers"] 67 | ): 68 | feta_score = get_feta_scores(q, k, self.heads, transformer_options) 69 | 70 | alt_attn_fn = ( 71 | transformer_options.get("patches_replace", {}) 72 | .get("layer", {}) 73 | .get(("self_attn", self.idx), None) 74 | ) 75 | if alt_attn_fn is not None: 76 | out = alt_attn_fn( 77 | q, 78 | k, 79 | v, 80 | self.heads, 81 | attn_precision=self.attn_precision, 82 | transformer_options=transformer_options, 83 | ) 84 | elif mask is None: 85 | out = comfy.ldm.modules.attention.optimized_attention( 86 | q, k, v, self.heads, attn_precision=self.attn_precision 87 | ) 88 | else: 89 | out = comfy.ldm.modules.attention.optimized_attention_masked( 90 | q, k, v, self.heads, mask, attn_precision=self.attn_precision 91 | ) 92 | 93 | if feta_score is not None: 94 | out *= feta_score 95 | 96 | return self.to_out(out) 97 | 98 | 99 | class LTXModifiedBasicTransformerBlock(BasicTransformerBlock): 100 | def forward( 101 | self, 102 | x, 103 | context=None, 104 | attention_mask=None, 105 | timestep=None, 106 | pe=None, 107 | transformer_options={}, 108 | ): 109 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 110 | self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) 111 | + timestep.reshape( 112 | x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1 113 | ) 114 | ).unbind(dim=2) 115 | x += ( 116 | self.attn1( 117 | comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, 118 | pe=pe, 119 | transformer_options=transformer_options, 120 | ) 121 | * gate_msa 122 | ) 123 | 124 | x += self.attn2(x, context=context, mask=attention_mask) 125 | 126 | y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp 127 | x += self.ff(y) * gate_mlp 128 | 129 | return x 130 | 131 | 132 | class LTXVModelModified(LTXVModel): 133 | 134 | def forward( 135 | self, 136 | x, 137 | timestep, 138 | context, 139 | attention_mask, 140 | frame_rate=25, 141 | transformer_options={}, 142 | keyframe_idxs=None, 143 | **kwargs, 144 | ): 145 
| patches_replace = transformer_options.get("patches_replace", {}) 146 | 147 | orig_shape = list(x.shape) 148 | 149 | x, latent_coords = self.patchifier.patchify(x) 150 | pixel_coords = latent_to_pixel_coords( 151 | latent_coords=latent_coords, 152 | scale_factors=self.vae_scale_factors, 153 | causal_fix=self.causal_temporal_positioning, 154 | ) 155 | 156 | if keyframe_idxs is not None: 157 | pixel_coords[:, :, -keyframe_idxs.shape[2] :] = keyframe_idxs 158 | 159 | fractional_coords = pixel_coords.to(torch.float32) 160 | fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) 161 | 162 | x = self.patchify_proj(x) 163 | timestep = timestep * 1000.0 164 | 165 | if attention_mask is not None and not torch.is_floating_point(attention_mask): 166 | attention_mask = (attention_mask - 1).to(x.dtype).reshape( 167 | (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) 168 | ) * torch.finfo(x.dtype).max 169 | 170 | pe = precompute_freqs_cis( 171 | fractional_coords, dim=self.inner_dim, out_dtype=x.dtype 172 | ) 173 | 174 | batch_size = x.shape[0] 175 | timestep, embedded_timestep = self.adaln_single( 176 | timestep.flatten(), 177 | {"resolution": None, "aspect_ratio": None}, 178 | batch_size=batch_size, 179 | hidden_dtype=x.dtype, 180 | ) 181 | # Second dimension is 1 or number of tokens (if timestep_per_token) 182 | timestep = timestep.view(batch_size, -1, timestep.shape[-1]) 183 | embedded_timestep = embedded_timestep.view( 184 | batch_size, -1, embedded_timestep.shape[-1] 185 | ) 186 | 187 | # 2. Blocks 188 | if self.caption_projection is not None: 189 | batch_size = x.shape[0] 190 | context = self.caption_projection(context) 191 | context = context.view(batch_size, -1, x.shape[-1]) 192 | 193 | blocks_replace = patches_replace.get("dit", {}) 194 | for i, block in enumerate(self.transformer_blocks): 195 | if ("double_block", i) in blocks_replace: 196 | 197 | def block_wrap(args): 198 | out = {} 199 | out["img"] = block( 200 | args["img"], 201 | context=args["txt"], 202 | attention_mask=args["attention_mask"], 203 | timestep=args["vec"], 204 | pe=args["pe"], 205 | ) 206 | return out 207 | 208 | out = blocks_replace[("double_block", i)]( 209 | { 210 | "img": x, 211 | "txt": context, 212 | "attention_mask": attention_mask, 213 | "vec": timestep, 214 | "pe": pe, 215 | }, 216 | {"original_block": block_wrap}, 217 | ) 218 | x = out["img"] 219 | else: 220 | x = block( 221 | x, 222 | context=context, 223 | attention_mask=attention_mask, 224 | timestep=timestep, 225 | pe=pe, 226 | transformer_options=transformer_options, 227 | ) 228 | 229 | # 3. 
Output 230 | scale_shift_values = ( 231 | self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) 232 | + embedded_timestep[:, :, None] 233 | ) 234 | shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] 235 | x = self.norm_out(x) 236 | # Modulation 237 | x = x * (1 + scale) + shift 238 | x = self.proj_out(x) 239 | 240 | x = self.patchifier.unpatchify( 241 | latents=x, 242 | output_height=orig_shape[3], 243 | output_width=orig_shape[4], 244 | output_num_frames=orig_shape[2], 245 | out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size), 246 | ) 247 | 248 | return x 249 | 250 | 251 | def inject_model(diffusion_model): 252 | diffusion_model.__class__ = LTXVModelModified 253 | for idx, transformer_block in enumerate(diffusion_model.transformer_blocks): 254 | transformer_block.__class__ = LTXModifiedBasicTransformerBlock 255 | transformer_block.idx = idx 256 | transformer_block.attn1.__class__ = LTXModifiedCrossAttention 257 | transformer_block.attn1.idx = idx 258 | return diffusion_model 259 | -------------------------------------------------------------------------------- /tricks/nodes/attn_bank_nodes.py: -------------------------------------------------------------------------------- 1 | from ..utils.attn_bank import AttentionBank 2 | 3 | 4 | class LTXAttentionBankNode: 5 | @classmethod 6 | def INPUT_TYPES(s): 7 | return { 8 | "required": { 9 | "save_steps": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}), 10 | "blocks": ("STRING", {"multiline": True}), 11 | } 12 | } 13 | 14 | RETURN_TYPES = ("ATTN_BANK",) 15 | FUNCTION = "build" 16 | 17 | CATEGORY = "ltxtricks" 18 | 19 | def build(self, save_steps, blocks=""): 20 | block_map = {} 21 | block_list = blocks.split(",") 22 | for block in block_list: 23 | block_idx = int(block) 24 | block_map[block_idx] = {} 25 | 26 | bank = AttentionBank(save_steps, block_map) 27 | return (bank,) 28 | 29 | 30 | class LTXPrepareAttnInjectionsNode: 31 | @classmethod 32 | def INPUT_TYPES(s): 33 | return { 34 | "required": { 35 | "latent": ("LATENT",), 36 | "attn_bank": ("ATTN_BANK",), 37 | "query": ("BOOLEAN", {"default": False}), 38 | "key": ("BOOLEAN", {"default": False}), 39 | "value": ("BOOLEAN", {"default": False}), 40 | "inject_steps": ( 41 | "INT", 42 | {"default": 0, "min": 0, "max": 1000, "step": 1}, 43 | ), 44 | }, 45 | "optional": {"blocks": ("LTX_BLOCKS",)}, 46 | } 47 | 48 | RETURN_TYPES = ("LATENT", "ATTN_INJ") 49 | FUNCTION = "prepare" 50 | 51 | CATEGORY = "ltxtricks" 52 | 53 | def prepare(self, latent, attn_bank, query, key, value, inject_steps, blocks=None): 54 | if inject_steps > attn_bank["save_steps"]: 55 | raise ValueError("Cannot inject more steps than were saved.") 56 | attn_bank = AttentionBank( 57 | attn_bank["save_steps"], attn_bank["block_map"], inject_steps 58 | ) 59 | attn_bank["inject_settings"] = set([]) 60 | if query: 61 | attn_bank["inject_settings"].add("q") 62 | if key: 63 | attn_bank["inject_settings"].add("k") 64 | if value: 65 | attn_bank["inject_settings"].add("v") 66 | 67 | if blocks is not None: 68 | attn_bank["block_map"] = {**attn_bank["block_map"]} 69 | for block_idx in list(attn_bank["block_map"].keys()): 70 | if block_idx not in blocks: 71 | del attn_bank["block_map"][block_idx] 72 | 73 | # Hack to force order of operations in ComfyUI graph 74 | return (latent, attn_bank) 75 | 76 | 77 | class LTXAttentioOverrideNode: 78 | @classmethod 79 | def INPUT_TYPES(s): 80 | return {"required": {"blocks": ("STRING", {"multiline": True})}} 81 | 82 | RETURN_TYPES = ("LTX_BLOCKS",) 83 | FUNCTION = 
"build" 84 | 85 | CATEGORY = "ltxtricks" 86 | 87 | def build(self, blocks=""): 88 | block_set = set(list(int(block) for block in blocks.split(","))) 89 | 90 | return (block_set,) 91 | -------------------------------------------------------------------------------- /tricks/nodes/attn_override_node.py: -------------------------------------------------------------------------------- 1 | def is_integer(string): 2 | try: 3 | int(string) 4 | return True 5 | except ValueError: 6 | return False 7 | 8 | 9 | class LTXAttnOverrideNode: 10 | @classmethod 11 | def INPUT_TYPES(s): 12 | return { 13 | "required": { 14 | "layers": ("STRING", {"multiline": True}), 15 | } 16 | } 17 | 18 | RETURN_TYPES = ("ATTN_OVERRIDE",) 19 | FUNCTION = "build" 20 | 21 | CATEGORY = "ltxtricks/attn" 22 | 23 | def build(self, layers): 24 | layers_map = set([]) 25 | for block in layers.split(","): 26 | block = block.strip() 27 | if is_integer(block): 28 | layers_map.add(int(block)) 29 | 30 | return ({"layers": layers_map},) 31 | -------------------------------------------------------------------------------- /tricks/nodes/latent_guide_node.py: -------------------------------------------------------------------------------- 1 | import comfy_extras.nodes_lt as nodes_lt 2 | 3 | 4 | class AddLatentGuideNode(nodes_lt.LTXVAddGuide): 5 | @classmethod 6 | def INPUT_TYPES(s): 7 | return { 8 | "required": { 9 | "model": ("MODEL",), 10 | "latent": ("LATENT",), 11 | "image_latent": ("LATENT",), 12 | "index": ("INT", {"default": 0, "min": -1, "max": 9999, "step": 1}), 13 | "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}), 14 | } 15 | } 16 | 17 | RETURN_TYPES = ("MODEL", "LATENT") 18 | RETURN_NAMES = ("model", "latent") 19 | 20 | CATEGORY = "ltxtricks" 21 | FUNCTION = "generate" 22 | 23 | def generate(self, model, latent, image_latent, index, strength): 24 | noise_mask = nodes_lt.get_noise_mask(latent) 25 | latent = latent["samples"] 26 | 27 | image_latent = image_latent["samples"] 28 | 29 | latent, noise_mask = self.replace_latent_frames( 30 | latent, 31 | noise_mask, 32 | image_latent, 33 | index, 34 | strength, 35 | ) 36 | 37 | return ( 38 | model, 39 | {"samples": latent, "noise_mask": noise_mask}, 40 | ) 41 | -------------------------------------------------------------------------------- /tricks/nodes/ltx_feta_enhance_node.py: -------------------------------------------------------------------------------- 1 | DEFAULT_ATTN = { 2 | "layers": [i for i in range(0, 100, 1)], 3 | } 4 | 5 | 6 | class LTXFetaEnhanceNode: 7 | @classmethod 8 | def INPUT_TYPES(s): 9 | return { 10 | "required": { 11 | "model": ("MODEL",), 12 | "feta_weight": ( 13 | "FLOAT", 14 | {"default": 4, "min": -100.0, "max": 100.0, "step": 0.01}, 15 | ), 16 | }, 17 | "optional": {"attn_override": ("ATTN_OVERRIDE",)}, 18 | } 19 | 20 | RETURN_TYPES = ("MODEL",) 21 | 22 | CATEGORY = "ltxtricks" 23 | FUNCTION = "apply" 24 | 25 | def apply(self, model, feta_weight, attn_override=DEFAULT_ATTN): 26 | model = model.clone() 27 | 28 | model_options = model.model_options.copy() 29 | transformer_options = model_options["transformer_options"].copy() 30 | 31 | transformer_options["feta_weight"] = feta_weight 32 | transformer_options["feta_layers"] = attn_override 33 | model_options["transformer_options"] = transformer_options 34 | 35 | model.model_options = model_options 36 | return (model,) 37 | -------------------------------------------------------------------------------- /tricks/nodes/ltx_flowedit_nodes.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from comfy.samplers import KSAMPLER, CFGGuider, sampling_function 3 | from tqdm import trange 4 | 5 | 6 | class FlowEditGuider(CFGGuider): 7 | def __init__(self, model_patcher): 8 | super().__init__(model_patcher) 9 | self.cfgs = {} 10 | 11 | def set_conds(self, **kwargs): 12 | self.inner_set_conds(kwargs) 13 | 14 | def set_cfgs(self, **kwargs): 15 | self.cfgs = {**kwargs} 16 | 17 | def predict_noise(self, x, timestep, model_options={}, seed=None): 18 | latent_type = model_options["transformer_options"]["latent_type"] 19 | positive = self.conds.get(f"{latent_type}_positive", None) 20 | negative = self.conds.get(f"{latent_type}_negative", None) 21 | cfg = self.cfgs.get(latent_type, self.cfg) 22 | return sampling_function( 23 | self.inner_model, 24 | x, 25 | timestep, 26 | negative, 27 | positive, 28 | cfg, 29 | model_options=model_options, 30 | seed=seed, 31 | ) 32 | 33 | 34 | class LTXFlowEditCFGGuiderNode: 35 | @classmethod 36 | def INPUT_TYPES(s): 37 | return { 38 | "required": { 39 | "model": ("MODEL",), 40 | "source_pos": ("CONDITIONING",), 41 | "source_neg": ("CONDITIONING",), 42 | "target_pos": ("CONDITIONING",), 43 | "target_neg": ("CONDITIONING",), 44 | "source_cfg": ( 45 | "FLOAT", 46 | {"default": 2, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, "step": 0.01}, 47 | ), 48 | "target_cfg": ( 49 | "FLOAT", 50 | {"default": 4.5, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, "step": 0.01}, 51 | ), 52 | } 53 | } 54 | 55 | RETURN_TYPES = ("GUIDER",) 56 | 57 | FUNCTION = "get_guider" 58 | CATEGORY = "ltxtricks" 59 | 60 | def get_guider( 61 | self, 62 | model, 63 | source_pos, 64 | source_neg, 65 | target_pos, 66 | target_neg, 67 | source_cfg, 68 | target_cfg, 69 | ): 70 | guider = FlowEditGuider(model) 71 | guider.set_conds( 72 | source_positive=source_pos, 73 | source_negative=source_neg, 74 | target_positive=target_pos, 75 | target_negative=target_neg, 76 | ) 77 | guider.set_cfgs(source=source_cfg, target=target_cfg) 78 | return (guider,) 79 | 80 | 81 | def get_flowedit_sample(skip_steps, refine_steps, seed): 82 | generator = torch.manual_seed(seed) 83 | 84 | @torch.no_grad() 85 | def flowedit_sample( 86 | model, x_init, sigmas, extra_args=None, callback=None, disable=None 87 | ): 88 | extra_args = {} if extra_args is None else extra_args 89 | 90 | model_options = extra_args.get("model_options", {}) 91 | transformer_options = model_options.get("transformer_options", {}) 92 | transformer_options = {**transformer_options} 93 | model_options["transformer_options"] = transformer_options 94 | extra_args["model_options"] = model_options 95 | denoise_mask = extra_args.get("denoise_mask", 1) 96 | if denoise_mask is None: 97 | denoise_mask = 1 98 | else: 99 | extra_args["denoise_mask"] = torch.ones_like(denoise_mask) 100 | 101 | source_extra_args = { 102 | **extra_args, 103 | "model_options": { 104 | "transformer_options": {**transformer_options, "latent_type": "source"} 105 | }, 106 | } 107 | 108 | sigmas = sigmas[skip_steps:] 109 | 110 | x_tgt = x_init.clone() 111 | N = len(sigmas) - 1 112 | s_in = x_init.new_ones([x_init.shape[0]]) 113 | 114 | for i in trange(N, disable=disable): 115 | sigma = sigmas[i] 116 | noise = torch.randn(x_init.shape, generator=generator).to(x_init.device) 117 | 118 | zt_src = (1 - sigma) * x_init + sigma * noise 119 | 120 | if i < N - refine_steps: 121 | zt_tgt = x_tgt + (zt_src - x_init) * denoise_mask 122 | transformer_options["latent_type"] = "source" 123 | 
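                # FlowEdit integrates only the velocity difference: zt_src is
                # the re-noised source, zt_tgt mirrors it shifted by the edit
                # accumulated so far in x_tgt, and x_tgt is advanced along
                # (vt_tgt - vt_src) below rather than being resampled outright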
source_extra_args["model_options"]["transformer_options"][ 124 | "latent_type" 125 | ] = "source" 126 | vt_src = model(zt_src, sigma * s_in, **source_extra_args) 127 | else: 128 | if i == N - refine_steps: 129 | x_tgt = x_tgt + (zt_src - x_init) * denoise_mask 130 | zt_tgt = x_tgt 131 | vt_src = 0 132 | 133 | transformer_options["latent_type"] = "target" 134 | vt_tgt = model(zt_tgt, sigma * s_in, **extra_args) 135 | 136 | v_delta = (vt_tgt - vt_src) * denoise_mask 137 | x_tgt += (sigmas[i + 1] - sigmas[i]) * v_delta 138 | 139 | if callback is not None: 140 | callback( 141 | { 142 | "x": x_tgt, 143 | "denoised": x_tgt, 144 | "i": i + skip_steps, 145 | "sigma": sigmas[i], 146 | "sigma_hat": sigmas[i], 147 | } 148 | ) 149 | 150 | return x_tgt 151 | 152 | return flowedit_sample 153 | 154 | 155 | class LTXFlowEditSamplerNode: 156 | @classmethod 157 | def INPUT_TYPES(s): 158 | return { 159 | "required": { 160 | "skip_steps": ( 161 | "INT", 162 | {"default": 4, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}, 163 | ), 164 | "refine_steps": ( 165 | "INT", 166 | {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}, 167 | ), 168 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}), 169 | }, 170 | "optional": {}, 171 | } 172 | 173 | RETURN_TYPES = ("SAMPLER",) 174 | FUNCTION = "build" 175 | 176 | CATEGORY = "ltxtricks" 177 | 178 | def build(self, skip_steps, refine_steps, seed): 179 | sampler = KSAMPLER(get_flowedit_sample(skip_steps, refine_steps, seed)) 180 | return (sampler,) 181 | -------------------------------------------------------------------------------- /tricks/nodes/ltx_inverse_model_pred_nodes.py: -------------------------------------------------------------------------------- 1 | import comfy.latent_formats 2 | import comfy.model_sampling 3 | import comfy.sd 4 | 5 | 6 | class InverseCONST: 7 | def calculate_input(self, sigma, noise): 8 | return noise 9 | 10 | def calculate_denoised(self, sigma, model_output, model_input): 11 | sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) 12 | return model_output 13 | 14 | def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): 15 | return latent_image 16 | 17 | def inverse_noise_scaling(self, sigma, latent): 18 | return latent 19 | 20 | 21 | class LTXForwardModelSamplingPredNode: 22 | @classmethod 23 | def INPUT_TYPES(s): 24 | return { 25 | "required": { 26 | "model": ("MODEL",), 27 | } 28 | } 29 | 30 | RETURN_TYPES = ("MODEL",) 31 | FUNCTION = "patch" 32 | 33 | CATEGORY = "ltxtricks" 34 | 35 | def patch(self, model): 36 | m = model.clone() 37 | 38 | sampling_base = comfy.model_sampling.ModelSamplingFlux 39 | sampling_type = InverseCONST 40 | 41 | class ModelSamplingAdvanced(sampling_base, sampling_type): 42 | pass 43 | 44 | model_sampling = ModelSamplingAdvanced(model.model.model_config) 45 | model_sampling.set_parameters(shift=1.15) 46 | m.add_object_patch("model_sampling", model_sampling) 47 | return (m,) 48 | 49 | 50 | class ReverseCONST: 51 | def calculate_input(self, sigma, noise): 52 | return noise 53 | 54 | def calculate_denoised(self, sigma, model_output, model_input): 55 | sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) 56 | return model_output # model_input - model_output * sigma 57 | 58 | def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): 59 | return latent_image 60 | 61 | def inverse_noise_scaling(self, sigma, latent): 62 | return latent / (1.0 - sigma) 63 | 64 | 65 | class LTXReverseModelSamplingPredNode: 66 | @classmethod 67 | def INPUT_TYPES(s): 68 | 
return { 69 | "required": { 70 | "model": ("MODEL",), 71 | } 72 | } 73 | 74 | RETURN_TYPES = ("MODEL",) 75 | FUNCTION = "patch" 76 | 77 | CATEGORY = "ltxtricks" 78 | 79 | def patch(self, model): 80 | m = model.clone() 81 | 82 | sampling_base = comfy.model_sampling.ModelSamplingFlux 83 | sampling_type = ReverseCONST 84 | 85 | class ModelSamplingAdvanced(sampling_base, sampling_type): 86 | pass 87 | 88 | model_sampling = ModelSamplingAdvanced(model.model.model_config) 89 | model_sampling.set_parameters(shift=1.15) 90 | m.add_object_patch("model_sampling", model_sampling) 91 | return (m,) 92 | -------------------------------------------------------------------------------- /tricks/nodes/ltx_pag_node.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import comfy.model_patcher 4 | import comfy.samplers 5 | import torch 6 | import torch.nn.functional as F 7 | from comfy.ldm.modules.attention import optimized_attention 8 | from einops import rearrange 9 | 10 | DEFAULT_PAG_LTX = {"layers": set([14])} 11 | 12 | 13 | def gaussian_blur_2d(img, kernel_size, sigma): 14 | height = img.shape[-1] 15 | kernel_size = min(kernel_size, height - (height % 2 - 1)) 16 | ksize_half = (kernel_size - 1) * 0.5 17 | 18 | x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) 19 | 20 | pdf = torch.exp(-0.5 * (x / sigma).pow(2)) 21 | 22 | x_kernel = pdf / pdf.sum() 23 | x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) 24 | 25 | kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) 26 | kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) 27 | 28 | padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] 29 | 30 | img = F.pad(img, padding, mode="reflect") 31 | img = F.conv2d(img, kernel2d, groups=img.shape[-3]) 32 | 33 | return img 34 | 35 | 36 | class LTXPerturbedAttentionNode: 37 | @classmethod 38 | def INPUT_TYPES(s): 39 | return { 40 | "required": { 41 | "model": ("MODEL",), 42 | "scale": ( 43 | "FLOAT", 44 | { 45 | "default": 2.0, 46 | "min": 0.0, 47 | "max": 100.0, 48 | "step": 0.01, 49 | "round": 0.01, 50 | }, 51 | ), 52 | "rescale": ( 53 | "FLOAT", 54 | { 55 | "default": 0.5, 56 | "min": 0.0, 57 | "max": 100.0, 58 | "step": 0.01, 59 | "round": 0.01, 60 | }, 61 | ), 62 | "cfg": ( 63 | "FLOAT", 64 | { 65 | "default": 3.0, 66 | "min": 0.0, 67 | "max": 100.0, 68 | "step": 0.01, 69 | "round": 0.01, 70 | }, 71 | ), 72 | }, 73 | "optional": { 74 | "attn_override": ("ATTN_OVERRIDE",), 75 | # "attn_type": (["PAG", "SEG"],), 76 | }, 77 | } 78 | 79 | RETURN_TYPES = ("MODEL",) 80 | FUNCTION = "patch" 81 | 82 | CATEGORY = "ltxtricks/attn" 83 | 84 | def patch( 85 | self, model, scale, rescale, cfg, attn_override=DEFAULT_PAG_LTX, attn_type="PAG" 86 | ): 87 | m = model.clone() 88 | 89 | def pag_fn(q, k, v, heads, attn_precision=None, transformer_options=None): 90 | return v 91 | 92 | def seg_fn(q, k, v, heads, attn_precision=None, transformer_options=None): 93 | _, sequence_length, _ = q.shape 94 | b, c, f, h, w = transformer_options["original_shape"] 95 | 96 | q = rearrange(q, "b (f h w) d -> b (f d) w h", h=h, w=w) 97 | kernel_size = math.ceil(6 * scale) + 1 - math.ceil(6 * scale) % 2 98 | q = gaussian_blur_2d(q, kernel_size, scale) 99 | q = rearrange(q, "b (f d) w h -> b (f h w) d", f=f) 100 | return optimized_attention(q, k, v, heads, attn_precision=attn_precision) 101 | 102 | def post_cfg_function(args): 103 | model = args["model"] 104 | 105 | cond_pred = args["cond_denoised"] 106 | 
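            # perturbed-attention guidance: calc_cond_batch below re-runs the
            # conditional pass with self-attention in the chosen layers replaced
            # (pag_fn simply returns v), then the combined output pushes the
            # prediction away from that degraded estimate by `scale`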
uncond_pred = args["uncond_denoised"] 107 | 108 | len_conds = 1 if args.get("uncond", None) is None else 2 109 | 110 | cond = args["cond"] 111 | sigma = args["sigma"] 112 | model_options = args["model_options"].copy() 113 | x = args["input"] 114 | 115 | if scale == 0: 116 | if len_conds == 1: 117 | return cond_pred 118 | return uncond_pred + (cond_pred - uncond_pred) 119 | 120 | attn_fn = pag_fn if attn_type == "PAG" else seg_fn 121 | for block_idx in attn_override["layers"]: 122 | model_options = comfy.model_patcher.set_model_options_patch_replace( 123 | model_options, attn_fn, "layer", "self_attn", int(block_idx) 124 | ) 125 | 126 | (perturbed,) = comfy.samplers.calc_cond_batch( 127 | model, [cond], x, sigma, model_options 128 | ) 129 | 130 | # if len_conds == 1: 131 | # output = cond_pred + scale * (cond_pred - pag) 132 | # else: 133 | # output = cond_pred + (scale-1.0) * (cond_pred - uncond_pred) + scale * (cond_pred - pag) 134 | 135 | output = ( 136 | uncond_pred 137 | + cfg * (cond_pred - uncond_pred) 138 | + scale * (cond_pred - perturbed) 139 | ) 140 | if rescale > 0: 141 | factor = cond_pred.std() / output.std() 142 | factor = rescale * factor + (1 - rescale) 143 | output = output * factor 144 | 145 | return output 146 | 147 | m.set_model_sampler_post_cfg_function(post_cfg_function) 148 | 149 | return (m,) 150 | -------------------------------------------------------------------------------- /tricks/nodes/modify_ltx_model_node.py: -------------------------------------------------------------------------------- 1 | from ..modules.ltx_model import inject_model 2 | 3 | 4 | class ModifyLTXModelNode: 5 | @classmethod 6 | def INPUT_TYPES(s): 7 | return { 8 | "required": { 9 | "model": ("MODEL",), 10 | } 11 | } 12 | 13 | RETURN_TYPES = ("MODEL",) 14 | 15 | CATEGORY = "ltxtricks" 16 | FUNCTION = "modify" 17 | 18 | def modify(self, model): 19 | model.model.diffusion_model = inject_model(model.model.diffusion_model) 20 | return (model,) 21 | -------------------------------------------------------------------------------- /tricks/nodes/rectified_sampler_nodes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from comfy.samplers import KSAMPLER 3 | from tqdm import trange 4 | 5 | 6 | def generate_trend_values(steps, start_time, end_time, eta, eta_trend): 7 | eta_values = [0] * steps 8 | 9 | if eta_trend == "constant": 10 | for i in range(start_time, end_time): 11 | eta_values[i] = eta 12 | elif eta_trend == "linear_increase": 13 | for i in range(start_time, end_time): 14 | progress = (i - start_time) / (end_time - start_time - 1) 15 | eta_values[i] = eta * progress 16 | elif eta_trend == "linear_decrease": 17 | for i in range(start_time, end_time): 18 | progress = 1 - (i - start_time) / (end_time - start_time - 1) 19 | eta_values[i] = eta * progress 20 | 21 | return eta_values 22 | 23 | 24 | def get_sample_forward( 25 | gamma, start_step, end_step, gamma_trend, seed, attn_bank=None, order="first" 26 | ): 27 | # Controlled Forward ODE (Algorithm 1) 28 | generator = torch.Generator() 29 | generator.manual_seed(seed) 30 | 31 | @torch.no_grad() 32 | def sample_forward(model, y0, sigmas, extra_args=None, callback=None, disable=None): 33 | if attn_bank is not None: 34 | for block_idx in attn_bank["block_map"]: 35 | attn_bank["block_map"][block_idx].clear() 36 | 37 | extra_args = {} if extra_args is None else extra_args 38 | model_options = extra_args.get("model_options", {}) 39 | model_options = {**model_options} 40 | transformer_options = 
model_options.get("transformer_options", {}) 41 | transformer_options = { 42 | **transformer_options, 43 | "total_steps": len(sigmas) - 1, 44 | "sample_mode": "forward", 45 | "attn_bank": attn_bank, 46 | } 47 | model_options["transformer_options"] = transformer_options 48 | extra_args["model_options"] = model_options 49 | 50 | Y = y0.clone() 51 | y1 = torch.randn(Y.shape, generator=generator).to(y0.device) 52 | N = len(sigmas) - 1 53 | s_in = y0.new_ones([y0.shape[0]]) 54 | gamma_values = generate_trend_values( 55 | N, start_step, end_step, gamma, gamma_trend 56 | ) 57 | for i in trange(N, disable=disable): 58 | transformer_options["step"] = i 59 | sigma = sigmas[i] 60 | sigma_next = sigmas[i + 1] 61 | t_i = model.inner_model.inner_model.model_sampling.timestep(sigmas[i]) 62 | 63 | conditional_vector_field = (y1 - Y) / (1 - t_i) 64 | 65 | transformer_options["pred_order"] = "first" 66 | pred = model( 67 | Y, s_in * sigmas[i], **extra_args 68 | ) # this implementation takes sigma instead of timestep 69 | 70 | if order == "second": 71 | transformer_options["pred_order"] = "second" 72 | img_mid = Y + (sigma_next - sigma) / 2 * pred 73 | sigma_mid = sigma + (sigma_next - sigma) / 2 74 | pred_mid = model(img_mid, s_in * sigma_mid, **extra_args) 75 | 76 | first_order = (pred_mid - pred) / ((sigma_next - sigma) / 2) 77 | pred = pred + gamma_values[i] * (conditional_vector_field - pred) 78 | # first_order = first_order + gamma_values[i] * (conditional_vector_field - first_order) 79 | Y = ( 80 | Y 81 | + (sigma_next - sigma) * pred 82 | + 0.5 * (sigma_next - sigma) ** 2 * first_order 83 | ) 84 | else: 85 | pred = pred + gamma_values[i] * (conditional_vector_field - pred) 86 | Y = Y + pred * (sigma_next - sigma) 87 | 88 | if callback is not None: 89 | callback( 90 | {"x": Y, "denoised": Y, "i": i, "sigma": sigma, "sigma_hat": sigma} 91 | ) 92 | 93 | return Y 94 | 95 | return sample_forward 96 | 97 | 98 | def get_sample_reverse( 99 | latent_image, eta, start_time, end_time, eta_trend, attn_bank=None, order="first" 100 | ): 101 | # Controlled Reverse ODE (Algorithm 2) 102 | @torch.no_grad() 103 | def sample_reverse(model, y1, sigmas, extra_args=None, callback=None, disable=None): 104 | extra_args = {} if extra_args is None else extra_args 105 | model_options = extra_args.get("model_options", {}) 106 | model_options = {**model_options} 107 | transformer_options = model_options.get("transformer_options", {}) 108 | transformer_options = { 109 | **transformer_options, 110 | "total_steps": len(sigmas) - 1, 111 | "sample_mode": "reverse", 112 | "attn_bank": attn_bank, 113 | } 114 | model_options["transformer_options"] = transformer_options 115 | extra_args["model_options"] = model_options 116 | 117 | X = y1.clone() 118 | N = len(sigmas) - 1 119 | y0 = latent_image.clone().to(y1.device) 120 | s_in = y0.new_ones([y0.shape[0]]) 121 | eta_values = generate_trend_values(N, start_time, end_time, eta, eta_trend) 122 | for i in trange(N, disable=disable): 123 | transformer_options["step"] = i 124 | t_i = 1 - model.inner_model.inner_model.model_sampling.timestep(sigmas[i]) 125 | sigma = sigmas[i] 126 | sigma_prev = sigmas[i + 1] 127 | 128 | conditional_vector_field = (y0 - X) / (1 - t_i) 129 | 130 | transformer_options["pred_order"] = "first" 131 | pred = model( 132 | X, sigma * s_in, **extra_args 133 | ) # this implementation takes sigma instead of timestep 134 | 135 | if order == "second": 136 | transformer_options["pred_order"] = "second" 137 | img_mid = X + (sigma_prev - sigma) / 2 * pred 138 | sigma_mid = sigma 
+ (sigma_prev - sigma) / 2 139 | pred_mid = model(img_mid, s_in * sigma_mid, **extra_args) 140 | 141 | first_order = (pred_mid - pred) / ((sigma_prev - sigma) / 2) 142 | pred = -pred + eta_values[i] * (conditional_vector_field + pred) 143 | 144 | first_order = -first_order + eta_values[i] * ( 145 | conditional_vector_field + first_order 146 | ) 147 | X = ( 148 | X 149 | + (sigma - sigma_prev) * pred 150 | + 0.5 * (sigma - sigma_prev) ** 2 * first_order 151 | ) 152 | else: 153 | controlled_vector_field = -pred + eta_values[i] * ( 154 | conditional_vector_field + pred 155 | ) 156 | X = X + controlled_vector_field * (sigma - sigma_prev) 157 | 158 | if callback is not None: 159 | callback( 160 | { 161 | "x": X, 162 | "denoised": X, 163 | "i": i, 164 | "sigma": sigmas[i], 165 | "sigma_hat": sigmas[i], 166 | } 167 | ) 168 | 169 | return X 170 | 171 | return sample_reverse 172 | 173 | 174 | class LTXRFForwardODESamplerNode: 175 | @classmethod 176 | def INPUT_TYPES(s): 177 | return { 178 | "required": { 179 | "gamma": ( 180 | "FLOAT", 181 | {"default": 0.5, "min": 0.0, "max": 100.0, "step": 0.01}, 182 | ), 183 | "start_step": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}), 184 | "end_step": ("INT", {"default": 5, "min": 0, "max": 1000, "step": 1}), 185 | "gamma_trend": (["linear_decrease", "linear_increase", "constant"],), 186 | }, 187 | "optional": { 188 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}), 189 | "attn_bank": ("ATTN_BANK",), 190 | "order": (["first", "second"],), 191 | }, 192 | } 193 | 194 | RETURN_TYPES = ("SAMPLER",) 195 | FUNCTION = "build" 196 | 197 | CATEGORY = "ltxtricks" 198 | 199 | def build( 200 | self, 201 | gamma, 202 | start_step, 203 | end_step, 204 | gamma_trend, 205 | seed=0, 206 | attn_bank=None, 207 | order="first", 208 | ): 209 | sampler = KSAMPLER( 210 | get_sample_forward( 211 | gamma, 212 | start_step, 213 | end_step, 214 | gamma_trend, 215 | seed, 216 | attn_bank=attn_bank, 217 | order=order, 218 | ) 219 | ) 220 | 221 | return (sampler,) 222 | 223 | 224 | class LTXRFReverseODESamplerNode: 225 | @classmethod 226 | def INPUT_TYPES(s): 227 | return { 228 | "required": { 229 | "model": ("MODEL",), 230 | "latent_image": ("LATENT",), 231 | "eta": ( 232 | "FLOAT", 233 | {"default": 0.8, "min": 0.0, "max": 100.0, "step": 0.01}, 234 | ), 235 | "start_step": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}), 236 | "end_step": ("INT", {"default": 15, "min": 0, "max": 1000, "step": 1}), 237 | }, 238 | "optional": { 239 | "eta_trend": (["linear_decrease", "linear_increase", "constant"],), 240 | "attn_inj": ("ATTN_INJ",), 241 | "order": (["first", "second"],), 242 | }, 243 | } 244 | 245 | RETURN_TYPES = ("SAMPLER",) 246 | FUNCTION = "build" 247 | 248 | CATEGORY = "ltxtricks" 249 | 250 | def build( 251 | self, 252 | model, 253 | latent_image, 254 | eta, 255 | start_step, 256 | end_step, 257 | eta_trend="constant", 258 | attn_inj=None, 259 | order="first", 260 | ): 261 | process_latent_in = model.get_model_object("process_latent_in") 262 | latent_image = process_latent_in(latent_image["samples"]) 263 | sampler = KSAMPLER( 264 | get_sample_reverse( 265 | latent_image, 266 | eta, 267 | start_step, 268 | end_step, 269 | eta_trend, 270 | attn_bank=attn_inj, 271 | order=order, 272 | ) 273 | ) 274 | 275 | return (sampler,) 276 | -------------------------------------------------------------------------------- /tricks/nodes/rf_edit_sampler_nodes.py: -------------------------------------------------------------------------------- 1 | import torch 
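# Overview of this module, inferred from the code below: sample_forward runs the
# inversion pass with a second-order (midpoint) rectified-flow step and, via
# transformer_options["rfedit"], signals the patched model (the attention patch
# lives under tricks/modules/ltx_model.py) to save per-block self-attention into
# the bank for the last `save_steps` steps; sample_reverse replays the same
# solver and requests injection of the saved attention for the first
# `inject_steps` steps, with `single_layers` / `double_layers` selecting which
# transformer blocks participate.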
2 | from comfy.samplers import KSAMPLER 3 | from tqdm import trange 4 | 5 | 6 | def get_sample_forward(attn_bank, save_steps, single_layers, double_layers): 7 | @torch.no_grad() 8 | def sample_forward(model, x, sigmas, extra_args=None, callback=None, disable=None): 9 | attn_bank.clear() 10 | attn_bank["save_steps"] = save_steps 11 | 12 | extra_args = {} if extra_args is None else extra_args 13 | 14 | model_options = extra_args.get("model_options", {}) 15 | model_options = {**model_options} 16 | transformer_options = model_options.get("transformer_options", {}) 17 | transformer_options = {**transformer_options} 18 | model_options["transformer_options"] = transformer_options 19 | extra_args["model_options"] = model_options 20 | 21 | N = len(sigmas) - 1 22 | s_in = x.new_ones([x.shape[0]]) 23 | for i in trange(N, disable=disable): 24 | sigma = sigmas[i] 25 | sigma_next = sigmas[i + 1] 26 | 27 | if N - i - 1 < save_steps: 28 | attn_bank[N - i - 1] = {"first": {}, "mid": {}} 29 | 30 | transformer_options["rfedit"] = { 31 | "step": N - i - 1, 32 | "process": "forward" if N - i - 1 < save_steps else None, 33 | "pred": "first", 34 | "bank": attn_bank, 35 | "single_layers": single_layers, 36 | "double_layers": double_layers, 37 | } 38 | 39 | pred = model(x, s_in * sigma, **extra_args) 40 | 41 | transformer_options["rfedit"] = { 42 | "step": N - i - 1, 43 | "process": "forward" if N - i - 1 < save_steps else None, 44 | "pred": "mid", 45 | "bank": attn_bank, 46 | "single_layers": single_layers, 47 | "double_layers": double_layers, 48 | } 49 | 50 | img_mid = x + (sigma_next - sigma) / 2 * pred 51 | sigma_mid = sigma + (sigma_next - sigma) / 2 52 | pred_mid = model(img_mid, s_in * sigma_mid, **extra_args) 53 | 54 | first_order = (pred_mid - pred) / ((sigma_next - sigma) / 2) 55 | x = ( 56 | x 57 | + (sigma_next - sigma) * pred 58 | + 0.5 * (sigma_next - sigma) ** 2 * first_order 59 | ) 60 | 61 | if callback is not None: 62 | callback( 63 | { 64 | "x": x, 65 | "denoised": x, 66 | "i": i, 67 | "sigma": sigmas[i], 68 | "sigma_hat": sigmas[i], 69 | } 70 | ) 71 | 72 | return x 73 | 74 | return sample_forward 75 | 76 | 77 | def get_sample_reverse(attn_bank, inject_steps, single_layers, double_layers): 78 | @torch.no_grad() 79 | def sample_reverse(model, x, sigmas, extra_args=None, callback=None, disable=None): 80 | if inject_steps > attn_bank["save_steps"]: 81 | raise ValueError( 82 | f'You must save at least as many steps as you want to inject. 
save_steps: {attn_bank["save_steps"]}, inject_steps: {inject_steps}' 83 | ) 84 | 85 | extra_args = {} if extra_args is None else extra_args 86 | 87 | model_options = extra_args.get("model_options", {}) 88 | model_options = {**model_options} 89 | transformer_options = model_options.get("transformer_options", {}) 90 | transformer_options = {**transformer_options} 91 | model_options["transformer_options"] = transformer_options 92 | extra_args["model_options"] = model_options 93 | 94 | N = len(sigmas) - 1 95 | s_in = x.new_ones([x.shape[0]]) 96 | for i in trange(N, disable=disable): 97 | sigma = sigmas[i] 98 | sigma_prev = sigmas[i + 1] 99 | 100 | transformer_options["rfedit"] = { 101 | "step": i, 102 | "process": "reverse" if i < inject_steps else None, 103 | "pred": "first", 104 | "bank": attn_bank, 105 | "single_layers": single_layers, 106 | "double_layers": double_layers, 107 | } 108 | 109 | pred = model(x, s_in * sigma, **extra_args) 110 | 111 | transformer_options["rfedit"] = { 112 | "step": i, 113 | "process": "reverse" if i < inject_steps else None, 114 | "pred": "mid", 115 | "bank": attn_bank, 116 | "single_layers": single_layers, 117 | "double_layers": double_layers, 118 | } 119 | 120 | img_mid = x + (sigma_prev - sigma) / 2 * pred 121 | sigma_mid = sigma + (sigma_prev - sigma) / 2 122 | pred_mid = model(img_mid, s_in * sigma_mid, **extra_args) 123 | 124 | first_order = (pred_mid - pred) / ((sigma_prev - sigma) / 2) 125 | x = ( 126 | x 127 | + (sigma_prev - sigma) * pred 128 | + 0.5 * (sigma_prev - sigma) ** 2 * first_order 129 | ) 130 | 131 | if callback is not None: 132 | callback( 133 | { 134 | "x": x, 135 | "denoised": x, 136 | "i": i, 137 | "sigma": sigmas[i], 138 | "sigma_hat": sigmas[i], 139 | } 140 | ) 141 | 142 | return x 143 | 144 | return sample_reverse 145 | 146 | 147 | DEFAULT_SINGLE_LAYERS = {} 148 | for i in range(38): 149 | DEFAULT_SINGLE_LAYERS[f"{i}"] = i > 19 150 | 151 | DEFAULT_DOUBLE_LAYERS = {} 152 | for i in range(19): 153 | DEFAULT_DOUBLE_LAYERS[f"{i}"] = False 154 | 155 | 156 | class FlowEditForwardSamplerNode: 157 | @classmethod 158 | def INPUT_TYPES(s): 159 | return { 160 | "required": { 161 | "save_steps": ( 162 | "INT", 163 | {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}, 164 | ), 165 | }, 166 | "optional": { 167 | "single_layers": ("SINGLE_LAYERS",), 168 | "double_layers": ("DOUBLE_LAYERS",), 169 | }, 170 | } 171 | 172 | RETURN_TYPES = ("SAMPLER", "ATTN_INJ") 173 | FUNCTION = "build" 174 | 175 | CATEGORY = "fluxtapoz" 176 | 177 | def build( 178 | self, 179 | save_steps, 180 | single_layers=DEFAULT_SINGLE_LAYERS, 181 | double_layers=DEFAULT_DOUBLE_LAYERS, 182 | ): 183 | attn_bank = {} 184 | sampler = KSAMPLER( 185 | get_sample_forward(attn_bank, save_steps, single_layers, double_layers) 186 | ) 187 | 188 | return (sampler, attn_bank) 189 | 190 | 191 | # class FlowEditReverseSamplerNode: 192 | # @classmethod 193 | # def INPUT_TYPES(s): 194 | # return { 195 | # "required": { 196 | # "attn_inj": ("ATTN_INJ",), 197 | # "latent_image": ("LATENT",), 198 | # "eta": ( 199 | # "FLOAT", 200 | # {"default": 0.8, "min": 0.0, "max": 100.0, "step": 0.01}, 201 | # ), 202 | # "start_step": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}), 203 | # "end_step": ("INT", {"default": 5, "min": 0, "max": 1000, "step": 1}), 204 | # }, 205 | # "optional": {}, 206 | # } 207 | # 208 | # RETURN_TYPES = ("SAMPLER",) 209 | # FUNCTION = "build" 210 | # 211 | # CATEGORY = "fluxtapoz" 212 | # 213 | # def build(self, latent_image, eta, start_step, end_step): 214 | # sampler 
= KSAMPLER( 215 | # get_sample_reverse(attn_inj, inject_steps, single_layers, double_layers) 216 | # ) 217 | # return (sampler,) 218 | 219 | 220 | def get_sample_reverse2(attn_bank, inject_steps, single_layers, double_layers): 221 | @torch.no_grad() 222 | def sample_reverse(model, x, sigmas, extra_args=None, callback=None, disable=None): 223 | if inject_steps > attn_bank["save_steps"]: 224 | raise ValueError( 225 | f'You must save at least as many steps as you want to inject. save_steps: {attn_bank["save_steps"]}, inject_steps: {inject_steps}' 226 | ) 227 | 228 | extra_args = {} if extra_args is None else extra_args 229 | 230 | model_options = extra_args.get("model_options", {}) 231 | model_options = {**model_options} 232 | transformer_options = model_options.get("transformer_options", {}) 233 | transformer_options = {**transformer_options} 234 | model_options["transformer_options"] = transformer_options 235 | extra_args["model_options"] = model_options 236 | 237 | N = len(sigmas) - 1 238 | s_in = x.new_ones([x.shape[0]]) 239 | for i in trange(N, disable=disable): 240 | sigma = sigmas[i] 241 | sigma_prev = sigmas[i + 1] 242 | 243 | transformer_options["rfedit"] = { 244 | "step": i, 245 | "process": "reverse" if i < inject_steps else None, 246 | "pred": "first", 247 | "bank": attn_bank, 248 | "single_layers": single_layers, 249 | "double_layers": double_layers, 250 | } 251 | 252 | pred = model(x, s_in * sigma, **extra_args) 253 | 254 | transformer_options["rfedit"] = { 255 | "step": i, 256 | "process": "reverse" if i < inject_steps else None, 257 | "pred": "mid", 258 | "bank": attn_bank, 259 | "single_layers": single_layers, 260 | "double_layers": double_layers, 261 | } 262 | 263 | img_mid = x + (sigma_prev - sigma) / 2 * pred 264 | sigma_mid = sigma + (sigma_prev - sigma) / 2 265 | pred_mid = model(img_mid, s_in * sigma_mid, **extra_args) 266 | 267 | first_order = (pred_mid - pred) / ((sigma_prev - sigma) / 2) 268 | x = ( 269 | x 270 | + (sigma_prev - sigma) * pred 271 | + 0.5 * (sigma_prev - sigma) ** 2 * first_order 272 | ) 273 | 274 | if callback is not None: 275 | callback( 276 | { 277 | "x": x, 278 | "denoised": x, 279 | "i": i, 280 | "sigma": sigmas[i], 281 | "sigma_hat": sigmas[i], 282 | } 283 | ) 284 | 285 | return x 286 | 287 | return sample_reverse 288 | 289 | 290 | class FlowEdit2ReverseSamplerNode: 291 | @classmethod 292 | def INPUT_TYPES(s): 293 | return { 294 | "required": { 295 | "attn_inj": ("ATTN_INJ",), 296 | "inject_steps": ( 297 | "INT", 298 | {"default": 0, "min": 0, "max": 1000, "step": 1}, 299 | ), 300 | }, 301 | "optional": { 302 | "single_layers": ("SINGLE_LAYERS",), 303 | "double_layers": ("DOUBLE_LAYERS",), 304 | }, 305 | } 306 | 307 | RETURN_TYPES = ("SAMPLER",) 308 | FUNCTION = "build" 309 | 310 | CATEGORY = "ltxtricks" 311 | 312 | def build( 313 | self, 314 | attn_inj, 315 | inject_steps, 316 | single_layers=DEFAULT_SINGLE_LAYERS, 317 | double_layers=DEFAULT_DOUBLE_LAYERS, 318 | ): 319 | sampler = KSAMPLER( 320 | get_sample_reverse(attn_inj, inject_steps, single_layers, double_layers) 321 | ) 322 | return (sampler,) 323 | 324 | 325 | class PrepareAttnBankNode: 326 | @classmethod 327 | def INPUT_TYPES(s): 328 | return { 329 | "required": { 330 | "latent": ("LATENT",), 331 | "attn_inj": ("ATTN_INJ",), 332 | } 333 | } 334 | 335 | RETURN_TYPES = ("LATENT", "ATTN_INJ") 336 | FUNCTION = "prepare" 337 | 338 | CATEGORY = "ltxtricks" 339 | 340 | def prepare(self, latent, attn_inj): 341 | # Hack to force order of operations in ComfyUI graph 342 | return (latent, 
attn_inj) 343 | 344 | 345 | class RFSingleBlocksOverrideNode: 346 | @classmethod 347 | def INPUT_TYPES(s): 348 | layers = {} 349 | for i in range(38): 350 | layers[f"{i}"] = ("BOOLEAN", {"default": i > 19}) 351 | return {"required": layers} 352 | 353 | RETURN_TYPES = ("SINGLE_LAYERS",) 354 | FUNCTION = "build" 355 | 356 | CATEGORY = "ltxtricks" 357 | 358 | def build(self, *args, **kwargs): 359 | return (kwargs,) 360 | 361 | 362 | class RFDoubleBlocksOverrideNode: 363 | @classmethod 364 | def INPUT_TYPES(s): 365 | layers = {} 366 | for i in range(19): 367 | layers[f"{i}"] = ("BOOLEAN", {"default": False}) 368 | return {"required": layers} 369 | 370 | RETURN_TYPES = ("DOUBLE_LAYERS",) 371 | FUNCTION = "build" 372 | 373 | CATEGORY = "ltxtricks" 374 | 375 | def build(self, *args, **kwargs): 376 | return (kwargs,) 377 | -------------------------------------------------------------------------------- /tricks/utils/attn_bank.py: -------------------------------------------------------------------------------- 1 | class AttentionBank: 2 | def __init__(self, save_steps, block_map, inject_steps=None): 3 | self._data = { 4 | "save_steps": save_steps, 5 | "block_map": block_map, 6 | "inject_steps": inject_steps, 7 | } 8 | 9 | def __getitem__(self, key): 10 | return self._data[key] 11 | 12 | def __setitem__(self, key, value): 13 | self._data[key] = value 14 | 15 | def get(self, key, default=None): 16 | return self._data.get(key, default) 17 | -------------------------------------------------------------------------------- /tricks/utils/feta_enhance_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | 4 | 5 | def _feta_score(query_image, key_image, head_dim, num_frames, enhance_weight): 6 | scale = head_dim**-0.5 7 | query_image = query_image * scale 8 | attn_temp = query_image @ key_image.transpose(-2, -1) # translate attn to float32 9 | attn_temp = attn_temp.to(torch.float32) 10 | attn_temp = attn_temp.softmax(dim=-1) 11 | 12 | # Reshape to [batch_size * num_tokens, num_frames, num_frames] 13 | attn_temp = attn_temp.reshape(-1, num_frames, num_frames) 14 | 15 | # Create a mask for diagonal elements 16 | diag_mask = torch.eye(num_frames, device=attn_temp.device).bool() 17 | diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1) 18 | 19 | # Zero out diagonal elements 20 | attn_wo_diag = attn_temp.masked_fill(diag_mask, 0) 21 | 22 | # Calculate mean for each token's attention matrix 23 | # Number of off-diagonal elements per matrix is n*n - n 24 | num_off_diag = num_frames * num_frames - num_frames 25 | mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag 26 | 27 | enhance_scores = mean_scores.mean() * (num_frames + enhance_weight) 28 | enhance_scores = enhance_scores.clamp(min=1) 29 | return enhance_scores 30 | 31 | 32 | def get_feta_scores(img_q, img_k, num_heads, transformer_options): 33 | num_frames = transformer_options["original_shape"][2] 34 | _, ST, dim = img_q.shape 35 | head_dim = dim // num_heads 36 | spatial_dim = ST // num_frames 37 | 38 | query_image = rearrange( 39 | img_q, 40 | "B (T S) (N C) -> (B S) N T C", 41 | T=num_frames, 42 | S=spatial_dim, 43 | N=num_heads, 44 | C=head_dim, 45 | ) 46 | key_image = rearrange( 47 | img_k, 48 | "B (T S) (N C) -> (B S) N T C", 49 | T=num_frames, 50 | S=spatial_dim, 51 | N=num_heads, 52 | C=head_dim, 53 | ) 54 | weight = transformer_options.get("feta_weight", 0) 55 | return _feta_score(query_image, key_image, head_dim, num_frames, weight) 56 | 
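A minimal standalone sketch of the helper above; the import path follows the repo layout but is an assumption (inside ComfyUI these modules load as a custom-node package), and the tensor sizes, `opts` contents, and variable names are made up for illustration. In the repo the real caller is the patched attention, which supplies img_q / img_k and the transformer_options:

    import torch

    from tricks.utils.feta_enhance_utils import get_feta_scores  # assumed import path

    # Hypothetical sizes: batch 2, 8 frames x 16 spatial tokens, 4 heads x 32 dims per head.
    B, T, S, N, C = 2, 8, 16, 4, 32
    img_q = torch.randn(B, T * S, N * C)
    img_k = torch.randn(B, T * S, N * C)

    # Only original_shape[2] (the frame count) is read; feta_weight defaults to 0.
    # Here (B, C, T, H, W) = (2, 128, 8, 4, 4), so H * W matches S = 16.
    opts = {"original_shape": (B, 128, T, 4, 4), "feta_weight": 4.0}

    score = get_feta_scores(img_q, img_k, num_heads=N, transformer_options=opts)
    print(score)  # scalar tensor clamped to >= 1, an enhancement factor for temporal attention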
-------------------------------------------------------------------------------- /tricks/utils/latent_guide.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LatentGuide(torch.nn.Module): 5 | def __init__(self, latent: torch.Tensor, index) -> None: 6 | super().__init__() 7 | self.index = index 8 | self.register_buffer("latent", latent) 9 | -------------------------------------------------------------------------------- /tricks/utils/module_utils.py: -------------------------------------------------------------------------------- 1 | def isinstance_str(x: object, cls_name: str): 2 | for _cls in x.__class__.__mro__: 3 | if _cls.__name__ == cls_name: 4 | return True 5 | 6 | return False 7 | -------------------------------------------------------------------------------- /tricks/utils/noise_utils.py: -------------------------------------------------------------------------------- 1 | def get_alphacumprod(sigma): 2 | return 1 / ((sigma * sigma) + 1) 3 | 4 | 5 | def add_noise(src_latent, noise, sigma): 6 | alphas_cumprod = get_alphacumprod(sigma) 7 | 8 | sqrt_alpha_prod = alphas_cumprod**0.5 9 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 10 | while len(sqrt_alpha_prod.shape) < len(src_latent.shape): 11 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 12 | 13 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod) ** 0.5 14 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 15 | while len(sqrt_one_minus_alpha_prod.shape) < len(src_latent.shape): 16 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 17 | 18 | noisy_samples = sqrt_alpha_prod * src_latent + sqrt_one_minus_alpha_prod * noise 19 | return noisy_samples 20 | 21 | 22 | def add_noise_flux(src_latent, noise, sigma): 23 | return sigma * noise + (1.0 - sigma) * src_latent 24 | --------------------------------------------------------------------------------
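To make the two noising conventions above concrete, a small self-contained sketch (the import path follows the repo layout but is an assumption, and the shapes and values are illustrative): add_noise is the variance-preserving form, with alpha_cumprod = 1 / (sigma^2 + 1) derived from the k-diffusion sigma, so sigma must be a tensor there (it is flattened and broadcast across the latent dims); add_noise_flux is the rectified-flow linear interpolation used by LTX-style models.

    import torch

    from tricks.utils.noise_utils import add_noise, add_noise_flux  # assumed import path

    latent = torch.zeros(1, 128, 8, 16, 16)  # (B, C, T, H, W) video latent, illustrative
    noise = torch.randn_like(latent)

    # Variance-preserving: x = sqrt(a) * x0 + sqrt(1 - a) * eps with a = 1 / (sigma^2 + 1).
    # sigma = 1.0 gives a = 0.5, i.e. an equal-energy mix; sigma must be a tensor here.
    vp = add_noise(latent, noise, torch.tensor(1.0))

    # Rectified flow: x = sigma * eps + (1 - sigma) * x0 with sigma in [0, 1].
    rf = add_noise_flux(latent, noise, 0.7)

    print(vp.std(), rf.std())  # roughly sqrt(0.5) and 0.7 here, since the latent is all zeros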