├── .gitattributes
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── __init__.py
├── decoder_noise.py
├── easy_samplers.py
├── example_workflows
│   ├── 13b-distilled
│   │   ├── ltxv-13b-dist-i2v-base-fp8.json
│   │   ├── ltxv-13b-dist-i2v-base.json
│   │   ├── ltxv-13b-dist-i2v-extend.json
│   │   └── ltxv-13b-dist-i2v-keyframes.json
│   ├── low_level
│   │   ├── end.jpg
│   │   ├── fox.jpg
│   │   ├── jeep.mp4
│   │   ├── ltxvideo-first-sequence-conditioning.json
│   │   ├── ltxvideo-frame-interpolation.json
│   │   ├── ltxvideo-i2v-distilled.json
│   │   ├── ltxvideo-i2v.json
│   │   ├── ltxvideo-last-sequence-conditioning.json
│   │   ├── ltxvideo-t2v.json
│   │   ├── moto.png
│   │   ├── shrek2.jpg
│   │   └── start.jpg
│   ├── ltxv-13b-i2v-base-fp8.json
│   ├── ltxv-13b-i2v-base.json
│   ├── ltxv-13b-i2v-extend.json
│   ├── ltxv-13b-i2v-keyframes.json
│   ├── ltxv-13b-i2v-mixed-multiscale.json
│   └── tricks
│       ├── ltxvideo-flow-edit.json
│       ├── ltxvideo-flow-edit.png
│       ├── ltxvideo-rf-edit.json
│       ├── ltxvideo-rf-edit.png
│       ├── ref.png
│       ├── shot.mp4
│       ├── shot2.mp4
│       ├── shrek2.jpg
│       └── shrek3.jpg
├── film_grain.py
├── guide.py
├── latent_adain.py
├── latent_upsampler.py
├── latents.py
├── nodes_registry.py
├── presets
│   └── stg_advanced_presets.json
├── prompt_enhancer_nodes.py
├── prompt_enhancer_utils.py
├── pyproject.toml
├── q8_nodes.py
├── recurrent_sampler.py
├── requirements.txt
├── stg.py
├── tiled_sampler.py
└── tricks
    ├── __init__.py
    ├── modules
    │   └── ltx_model.py
    ├── nodes
    │   ├── attn_bank_nodes.py
    │   ├── attn_override_node.py
    │   ├── latent_guide_node.py
    │   ├── ltx_feta_enhance_node.py
    │   ├── ltx_flowedit_nodes.py
    │   ├── ltx_inverse_model_pred_nodes.py
    │   ├── ltx_pag_node.py
    │   ├── modify_ltx_model_node.py
    │   ├── rectified_sampler_nodes.py
    │   └── rf_edit_sampler_nodes.py
    └── utils
        ├── attn_bank.py
        ├── feta_enhance_utils.py
        ├── latent_guide.py
        ├── module_utils.py
        └── noise_utils.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | **/*.png filter=lfs diff=lfs merge=lfs -text
2 | **/*.mp4 filter=lfs diff=lfs merge=lfs -text
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # general formatting
2 | repos:
3 | - repo: https://github.com/pre-commit/pre-commit-hooks
4 | rev: v4.6.0
5 | hooks:
6 | - id: check-yaml
7 | - id: end-of-file-fixer
8 | - id: trailing-whitespace
9 | - id: requirements-txt-fixer
10 | - id: check-added-large-files
11 | - id: check-toml
12 | # Isort (import sorting)
13 | - repo: https://github.com/PyCQA/isort
14 | rev: 5.13.2
15 | hooks:
16 | - id: isort
17 | args: [--profile, black]
18 | # Black (code formatting)
19 | - repo: https://github.com/psf/black
20 | rev: 24.4.2 # Replace by any tag/version: https://github.com/psf/black/tags
21 | hooks:
22 | - id: black
23 | language_version: python3.10
24 | # Ruff linter
25 | - repo: https://github.com/charliermarsh/ruff-pre-commit
26 | rev: 'v0.4.4'
27 | hooks:
28 | - id: ruff
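# Usage sketch: the hooks above are run with the standard pre-commit CLI,
# e.g. from the repository root:
#   pip install pre-commit
#   pre-commit install          # register the git hook
#   pre-commit run --all-files  # run all hooks once over the whole repo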
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ComfyUI-LTXVideo
2 |
3 | ComfyUI-LTXVideo is a collection of custom nodes for ComfyUI, designed to provide useful tools for working with the LTXV model.
4 | The model itself is supported in the core ComfyUI [code](https://github.com/comfyanonymous/ComfyUI/tree/master/comfy/ldm/lightricks).
5 | The main LTXVideo repository can be found [here](https://github.com/Lightricks/LTX-Video).
6 |
7 | # ⭐ 14.05.2025 – LTXVideo 13B 0.9.7 Distilled Release ⭐
8 |
9 | ### 🚀 What's New in LTXVideo 13B 0.9.7 Distilled
10 | 1. **LTXV 13B Distilled 🥳 0.9.7**
11 | Delivers cinematic-quality videos at a fraction of the steps needed to run the full model. Only 4 or 8 steps are needed for a single generation.
12 | 👉 [Download here](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled.safetensors)
13 |
14 | 2. **LTXV 13B Distilled Quantized 0.9.7**
15 | Offers reduced memory requirements and even faster inference speeds.
16 | Ideal for consumer-grade GPUs (e.g., NVIDIA 4090, 5090).
17 | ***Important:*** To get the best performance with the quantized version, please install the [q8_kernels](https://github.com/Lightricks/LTXVideo-Q8-Kernels) package and use the dedicated flow below.
18 | 👉 [Download here](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled-fp8.safetensors)
19 | 🧩 Example ComfyUI flow available in the [Example Workflows](#example-workflows) section.
20 |
21 | 3. **Updated LTXV 13B Quantized version**
22 | From now on, all our 8-bit quantized models run natively in ComfyUI; however, our Q8 patcher node still gives the best inference speed.
23 | 👉 [Download here](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev-fp8.safetensors)
24 | # ⭐ 06.05.2025 – LTXVideo 13B 0.9.7 Release ⭐
25 |
26 | ### 🚀 What's New in LTXVideo 13B 0.9.7
27 |
28 | 1. **LTXV 13B 0.9.7**
29 | Delivers cinematic-quality videos at unprecedented speed.
30 | 👉 [Download here](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev.safetensors)
31 |
32 | 2. **LTXV 13B Quantized 0.9.7**
33 | Offers reduced memory requirements and even faster inference speeds.
34 | Ideal for consumer-grade GPUs (e.g., NVIDIA 4090, 5090).
35 | Delivers outstanding quality with improved performance.
36 | ***Important:*** To run the quantized version, please install the [LTXVideo-Q8-Kernels](https://github.com/Lightricks/LTXVideo-Q8-Kernels) package and use the dedicated flow below. Loading the model in ComfyUI with the LoadCheckpoint node won't work.
37 | 👉 [Download here](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev-fp8.safetensors)
38 | 🧩 Example ComfyUI flow available in the [Example Workflows](#example-workflows) section.
39 |
40 | 3. **Latent Upscaling Models**
41 | Enables inference across multiple scales by upscaling latent tensors without decoding/encoding.
42 | Multiscale inference delivers high-quality results in a fraction of the time compared to similar models.
43 | ***Important:*** Make sure you place the models below in the **models/upscale_models** folder.
44 | 👉 Spatial upscaling: [Download here](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-spatial-upscaler-0.9.7.safetensors).
45 | 👉 Temporal upscaling: [Download here](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-temporal-upscaler-0.9.7.safetensors).
46 | 🧩 Example ComfyUI flow available in the [Example Workflows](#example-workflows) section.
47 |
48 |
49 | ### Technical Updates
50 |
51 | 1. ***New simplified flows and nodes***
52 | 1.1. Simplified image to video: [Download here](example_workflows/ltxv-13b-i2v-base.json).
53 | 1.2. Simplified image to video with extension: [Download here](example_workflows/ltxv-13b-i2v-extend.json).
54 | 1.3. Simplified image to video with keyframes: [Download here](example_workflows/ltxv-13b-i2v-keyframes.json).
55 |
56 | # 17.04.2025 ⭐ LTXVideo 0.9.6 Release ⭐
57 |
58 | ### LTXVideo 0.9.6 introduces:
59 |
60 | 1. LTXV 0.9.6 – higher quality, faster, great for final output. Download from [here](https://huggingface.co/Lightricks/LTX-Video/resolve/main/ltxv-2b-0.9.6-dev-04-25.safetensors).
61 | 2. LTXV 0.9.6 Distilled – our fastest model yet (only 8 steps for generation), lighter, great for rapid iteration. Download from [here](https://huggingface.co/Lightricks/LTX-Video/resolve/main/ltxv-2b-0.9.6-distilled-04-25.safetensors).
62 |
63 | ### Technical Updates
64 |
65 | We introduce the __STGGuiderAdvanced__ node, which applies different CFG and STG parameters at various diffusion steps. All flows have been updated to use this node and are designed to provide optimal parameters for the best quality.
66 | See the [Example Workflows](#example-workflows) section.
67 |
68 | # 5.03.2025 ⭐ LTXVideo 0.9.5 Release ⭐
69 |
70 | ### LTXVideo 0.9.5 introduces:
71 |
72 | 1. Improved quality with reduced artifacts.
73 | 2. Support for higher resolution and longer sequences.
74 | 3. Frame and sequence conditioning (beyond the first frame).
75 | 4. Enhanced prompt understanding.
76 | 5. Commercial license availability.
77 |
78 | ### Technical Updates
79 |
80 | Since LTXVideo is now fully supported in the ComfyUI core, we have removed the custom model implementation. Instead, we provide updated workflows to showcase the new features:
81 |
82 | 1. **Frame Conditioning** – Enables interpolation between given frames.
83 | 2. **Sequence Conditioning** – Allows motion interpolation from a given frame sequence, enabling video extension from the beginning, end, or middle of the original video.
84 | 3. **Prompt Enhancer** – A new node that helps generate prompts optimized for the best model performance.
85 | See the [Example Workflows](#example-workflows) section for more details.
86 |
87 | ### LTXTricks Update
88 |
89 | The LTXTricks code has been integrated into this repository (in the `/tricks` folder) and will be maintained here. The original [repo](https://github.com/logtd/ComfyUI-LTXTricks) is no longer maintained, but all existing workflows should continue to function as expected.
90 |
91 | ## 22.12.2024
92 |
93 | Fixed a bug which caused the model to produce artifacts on short negative prompts when using a native CLIP Loader node.
94 |
95 | ## 19.12.2024 ⭐ Update ⭐
96 |
97 | 1. Improved model - removes "strobing texture" artifacts and generates better motion. Download from [here](https://huggingface.co/Lightricks/LTX-Video/resolve/main/ltx-video-2b-v0.9.1.safetensors).
98 | 2. STG support
99 | 3. Integrated image degradation system for improved motion generation.
100 | 4. Additional optional initial-latent input for chaining latents in high-resolution generation.
101 | 5. Image captioning in the image-to-video [flow](example_workflows/low_level/ltxvideo-i2v.json).
102 |
103 | ## Installation
104 |
105 | Installation via [ComfyUI-Manager](https://github.com/ltdrdata/ComfyUI-Manager) is preferred. Simply search for `ComfyUI-LTXVideo` in the list of nodes and follow the installation instructions.
106 |
107 | ### Manual installation
108 |
109 | 1. Install ComfyUI
110 | 2. Clone this repository into the `custom_nodes` folder in your ComfyUI installation directory (see the command-line sketch below).
111 | 3. Install the required packages:
112 |
113 | ```bash
114 | cd custom_nodes/ComfyUI-LTXVideo && pip install -r requirements.txt
115 | ```
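
If you prefer doing both steps from the command line, here is a minimal sketch (assuming your ComfyUI installation directory as the working directory):

```bash
# Minimal sketch: clone into custom_nodes and install the Python dependencies.
cd custom_nodes
git clone https://github.com/Lightricks/ComfyUI-LTXVideo.git
cd ComfyUI-LTXVideo
pip install -r requirements.txt
```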
116 |
117 | For portable ComfyUI installations, run
118 |
119 | ```
120 | .\python_embeded\python.exe -m pip install -r .\ComfyUI\custom_nodes\ComfyUI-LTXVideo\requirements.txt
121 | ```
122 |
123 | ### Models
124 |
125 | 1. Download [ltx-video-2b-v0.9.1.safetensors](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) from Hugging Face and place it under `models/checkpoints`.
126 | 2. Install one of the T5 text encoders, for example [google_t5-v1_1-xxl_encoderonly](https://huggingface.co/mcmonkey/google_t5-v1_1-xxl_encoderonly/tree/main). You can install it using the ComfyUI Model Manager; a command-line download sketch is shown below.
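
If you prefer downloading the checkpoint from the command line, here is a minimal sketch (run from your ComfyUI root; it assumes the Hugging Face `resolve` URL form for a direct file download and that `wget` is available):

```bash
# Minimal sketch: download the LTXV checkpoint into models/checkpoints.
wget -P models/checkpoints \
  https://huggingface.co/Lightricks/LTX-Video/resolve/main/ltx-video-2b-v0.9.1.safetensors
# Place the T5 text encoder under models/text_encoders (models/clip on older ComfyUI versions).
```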
127 |
128 | ## Example workflows
129 |
130 | Note that to run the example workflows, you need some additional custom nodes installed, such as [ComfyUI-VideoHelperSuite](https://github.com/kosinkadink/ComfyUI-VideoHelperSuite). You can install them by pressing the "Install Missing Custom Nodes" button in ComfyUI Manager.
131 |
132 | ### Easy-to-use multiscale generation workflows
133 |
134 | 🧩 [Image to video mixed](example_workflows/ltxv-13b-i2v-mixed-multiscale.json): a mixed flow using both the full and distilled models for the best quality/speed trade-off.
135 |
136 | ### 13B model
137 | 🧩 [Image to video](example_workflows/ltxv-13b-i2v-base.json)
138 | 🧩 [Image to video with keyframes](example_workflows/ltxv-13b-i2v-keyframes.json)
139 | 🧩 [Image to video with duration extension](example_workflows/ltxv-13b-i2v-extend.json)
140 | 🧩 [Image to video, 8-bit quantized (fp8)](example_workflows/ltxv-13b-i2v-base-fp8.json)
141 |
142 | ### 13B distilled model
143 | 🧩 [Image to video](example_workflows/13b-distilled/ltxv-13b-dist-i2v-base.json)
144 | 🧩 [Image to video with keyframes](example_workflows/13b-distilled/ltxv-13b-dist-i2v-keyframes.json)
145 | 🧩 [Image to video with duration extension](example_workflows/13b-distilled/ltxv-13b-dist-i2v-extend.json)
146 | 🧩 [Image to video, 8-bit quantized (fp8)](example_workflows/13b-distilled/ltxv-13b-dist-i2v-base-fp8.json)
147 |
148 | ### Inversion
149 |
150 | #### Flow Edit
151 |
152 | 🧩 [Download workflow](example_workflows/tricks/ltxvideo-flow-edit.json)
153 | 
154 |
155 | #### RF Edit
156 |
157 | 🧩 [Download workflow](example_workflows/tricks/ltxvideo-rf-edit.json)
158 | 
159 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .decoder_noise import DecoderNoise
2 | from .easy_samplers import LTXVBaseSampler
3 | from .film_grain import LTXVFilmGrain
4 | from .guide import LTXVAddGuideAdvanced
5 | from .latent_adain import LTXVAdainLatent
6 | from .latent_upsampler import LTXVLatentUpsampler
7 | from .latents import LTXVSelectLatents, LTXVSetVideoLatentNoiseMasks
8 | from .nodes_registry import NODE_CLASS_MAPPINGS as RUNTIME_NODE_CLASS_MAPPINGS
9 | from .nodes_registry import (
10 | NODE_DISPLAY_NAME_MAPPINGS as RUNTIME_NODE_DISPLAY_NAME_MAPPINGS,
11 | )
12 | from .nodes_registry import NODES_DISPLAY_NAME_PREFIX, camel_case_to_spaces
13 | from .prompt_enhancer_nodes import LTXVPromptEnhancer, LTXVPromptEnhancerLoader
14 | from .q8_nodes import LTXVQ8Patch
15 | from .recurrent_sampler import LinearOverlapLatentTransition, LTXVRecurrentKSampler
16 | from .stg import (
17 | LTXVApplySTG,
18 | STGAdvancedPresetsNode,
19 | STGGuiderAdvancedNode,
20 | STGGuiderNode,
21 | )
22 | from .tiled_sampler import LTXVTiledSampler
23 | from .tricks import NODE_CLASS_MAPPINGS as TRICKS_NODE_CLASS_MAPPINGS
24 | from .tricks import NODE_DISPLAY_NAME_MAPPINGS as TRICKS_NODE_DISPLAY_NAME_MAPPINGS
25 |
26 | # Static node mappings, required for ComfyUI-Manager mapping to work
27 | NODE_CLASS_MAPPINGS = {
28 | "Set VAE Decoder Noise": DecoderNoise,
29 | "LinearOverlapLatentTransition": LinearOverlapLatentTransition,
30 | "LTXVAddGuideAdvanced": LTXVAddGuideAdvanced,
31 | "LTXVAdainLatent": LTXVAdainLatent,
32 | "LTXVApplySTG": LTXVApplySTG,
33 | "LTXVBaseSampler": LTXVBaseSampler,
34 | "LTXVFilmGrain": LTXVFilmGrain,
35 | "LTXVLatentUpsampler": LTXVLatentUpsampler,
36 | "LTXVPromptEnhancer": LTXVPromptEnhancer,
37 | "LTXVPromptEnhancerLoader": LTXVPromptEnhancerLoader,
38 | "LTXQ8Patch": LTXVQ8Patch,
39 | "LTXVRecurrentKSampler": LTXVRecurrentKSampler,
40 | "LTXVSelectLatents": LTXVSelectLatents,
41 | "LTXVSetVideoLatentNoiseMasks": LTXVSetVideoLatentNoiseMasks,
42 | "LTXVTiledSampler": LTXVTiledSampler,
43 | "STGAdvancedPresets": STGAdvancedPresetsNode,
44 | "STGGuiderAdvanced": STGGuiderAdvancedNode,
45 | "STGGuiderNode": STGGuiderNode,
46 | }
47 |
48 | # Keep display names consistent between the static and dynamic node mappings in nodes_registry.py,
49 | # so that ComfyUI does not initialize them with default display names.
50 | NODE_DISPLAY_NAME_MAPPINGS = {
51 | name: f"{NODES_DISPLAY_NAME_PREFIX} {camel_case_to_spaces(name)}"
52 | for name in NODE_CLASS_MAPPINGS.keys()
53 | }
54 |
55 | # Merge the node mappings from tricks into the main mappings
56 | NODE_CLASS_MAPPINGS.update(TRICKS_NODE_CLASS_MAPPINGS)
57 | NODE_DISPLAY_NAME_MAPPINGS.update(TRICKS_NODE_DISPLAY_NAME_MAPPINGS)
58 |
59 | # Update with runtime mappings (these will override static mappings if there are any differences)
60 | NODE_CLASS_MAPPINGS.update(RUNTIME_NODE_CLASS_MAPPINGS)
61 | NODE_DISPLAY_NAME_MAPPINGS.update(RUNTIME_NODE_DISPLAY_NAME_MAPPINGS)
62 |
63 | # Export so that ComfyUI can pick them up.
64 | __all__ = [
65 | "NODE_CLASS_MAPPINGS",
66 | "NODE_DISPLAY_NAME_MAPPINGS",
67 | ]
68 |
--------------------------------------------------------------------------------
/decoder_noise.py:
--------------------------------------------------------------------------------
1 | from copy import copy
2 |
3 | from .nodes_registry import comfy_node
4 |
5 |
6 | @comfy_node(name="Set VAE Decoder Noise")
7 | class DecoderNoise:
8 | @classmethod
9 | def INPUT_TYPES(cls):
10 | return {
11 | "required": {
12 | "vae": ("VAE",),
13 | "timestep": (
14 | "FLOAT",
15 | {
16 | "default": 0.05,
17 | "min": 0.0,
18 | "max": 1.0,
19 | "step": 0.001,
20 | "tooltip": "The timestep used for decoding the noise.",
21 | },
22 | ),
23 | "scale": (
24 | "FLOAT",
25 | {
26 | "default": 0.025,
27 | "min": 0.0,
28 | "max": 1.0,
29 | "step": 0.001,
30 | "tooltip": "The scale of the noise added to the decoder.",
31 | },
32 | ),
33 | "seed": (
34 | "INT",
35 | {
36 | "default": 42,
37 | "min": 0,
38 | "max": 0xFFFFFFFFFFFFFFFF,
39 | "tooltip": "The random seed used for creating the noise.",
40 | },
41 | ),
42 | }
43 | }
44 |
45 | FUNCTION = "add_noise"
46 | RETURN_TYPES = ("VAE",)
47 | CATEGORY = "lightricks/LTXV"
48 |
49 | def add_noise(self, vae, timestep, scale, seed):
50 | result = copy(vae)  # shallow-copy the VAE so the original object passed in is left untouched
51 | if hasattr(result, "first_stage_model"):
52 | result.first_stage_model.decode_timestep = timestep
53 | result.first_stage_model.decode_noise_scale = scale
54 | result._decode_timestep = timestep
55 | result.decode_noise_scale = scale
56 | result.seed = seed
57 | return (result,)
58 |
--------------------------------------------------------------------------------
/easy_samplers.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import comfy
4 | import comfy_extras
5 | import nodes
6 | from comfy_extras.nodes_custom_sampler import SamplerCustomAdvanced
7 | from comfy_extras.nodes_lt import (
8 | EmptyLTXVLatentVideo,
9 | LTXVAddGuide,
10 | LTXVCropGuides,
11 | LTXVImgToVideo,
12 | )
13 |
14 | from .guide import blur_internal
15 | from .latents import LTXVAddLatentGuide, LTXVSelectLatents
16 | from .nodes_registry import comfy_node
17 | from .recurrent_sampler import LinearOverlapLatentTransition
18 |
19 |
20 | @comfy_node(
21 | name="LTXVBaseSampler",
22 | )
23 | class LTXVBaseSampler:
24 |
25 | @classmethod
26 | def INPUT_TYPES(s):
27 | return {
28 | "required": {
29 | "model": ("MODEL",),
30 | "vae": ("VAE",),
31 | "width": (
32 | "INT",
33 | {
34 | "default": 768,
35 | "min": 64,
36 | "max": nodes.MAX_RESOLUTION,
37 | "step": 32,
38 | },
39 | ),
40 | "height": (
41 | "INT",
42 | {
43 | "default": 512,
44 | "min": 64,
45 | "max": nodes.MAX_RESOLUTION,
46 | "step": 32,
47 | },
48 | ),
49 | "num_frames": (
50 | "INT",
51 | {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8},
52 | ),
53 | "guider": ("GUIDER",),
54 | "sampler": ("SAMPLER",),
55 | "sigmas": ("SIGMAS",),
56 | "noise": ("NOISE",),
57 | },
58 | "optional": {
59 | "optional_cond_images": ("IMAGE",),
60 | "optional_cond_indices": ("STRING",),
61 | "strength": ("FLOAT", {"default": 0.9, "min": 0, "max": 1}),
62 | "crop": (["center", "disabled"], {"default": "disabled"}),
63 | "crf": ("INT", {"default": 35, "min": 0, "max": 100}),
64 | "blur": ("INT", {"default": 0, "min": 0, "max": 10}),
65 | },
66 | }
67 |
68 | RETURN_TYPES = ("LATENT",)
69 | RETURN_NAMES = ("denoised_output",)
70 | FUNCTION = "sample"
71 | CATEGORY = "sampling"
72 |
73 | def sample(
74 | self,
75 | model,
76 | vae,
77 | width,
78 | height,
79 | num_frames,
80 | guider,
81 | sampler,
82 | sigmas,
83 | noise,
84 | optional_cond_images=None,
85 | optional_cond_indices=None,
86 | strength=0.9,
87 | crop="disabled",
88 | crf=35,
89 | blur=0,
90 | ):
91 |
92 | if optional_cond_images is not None:
93 | optional_cond_images = (
94 | comfy.utils.common_upscale(
95 | optional_cond_images.movedim(-1, 1),
96 | width,
97 | height,
98 | "bilinear",
99 | crop=crop,
100 | )
101 | .movedim(1, -1)
102 | .clamp(0, 1)
103 | )
104 | print("optional_cond_images shape", optional_cond_images.shape)
105 | optional_cond_images = comfy_extras.nodes_lt.LTXVPreprocess().preprocess(
106 | optional_cond_images, crf
107 | )[0]
108 | for i in range(optional_cond_images.shape[0]):
109 | optional_cond_images[i] = blur_internal(
110 | optional_cond_images[i].unsqueeze(0), blur
111 | )
112 |
113 | if optional_cond_indices is not None and optional_cond_images is not None:
114 | optional_cond_indices = optional_cond_indices.split(",")
115 | optional_cond_indices = [int(i) for i in optional_cond_indices]
116 | assert len(optional_cond_indices) == len(
117 | optional_cond_images
118 | ), "Number of optional cond images must match number of optional cond indices"
119 |
120 | try:
121 | positive, negative = guider.raw_conds
122 | except AttributeError:
123 | raise ValueError(
124 | "Guider does not have raw conds, cannot use it as a guider. "
125 | "Please use STGGuiderAdvanced."
126 | )
127 |
128 | if optional_cond_images is None:
129 | (latents,) = EmptyLTXVLatentVideo().generate(width, height, num_frames, 1)
130 | elif optional_cond_images.shape[0] == 1 and optional_cond_indices[0] == 0:
131 | (
132 | positive,
133 | negative,
134 | latents,
135 | ) = LTXVImgToVideo().generate(
136 | positive=positive,
137 | negative=negative,
138 | vae=vae,
139 | image=optional_cond_images[0].unsqueeze(0),
140 | width=width,
141 | height=height,
142 | length=num_frames,
143 | batch_size=1,
144 | strength=strength,
145 | )
146 | else:
147 | (latents,) = EmptyLTXVLatentVideo().generate(width, height, num_frames, 1)
148 | for cond_image, cond_idx in zip(
149 | optional_cond_images, optional_cond_indices
150 | ):
151 | (
152 | positive,
153 | negative,
154 | latents,
155 | ) = LTXVAddGuide().generate(
156 | positive=positive,
157 | negative=negative,
158 | vae=vae,
159 | latent=latents,
160 | image=cond_image.unsqueeze(0),
161 | frame_idx=cond_idx,
162 | strength=strength,
163 | )
164 |
165 | guider = copy.copy(guider)
166 | guider.set_conds(positive, negative)
167 |
168 | # Denoise the latent video
169 | (output_latents, denoised_output_latents) = SamplerCustomAdvanced().sample(
170 | noise=noise,
171 | guider=guider,
172 | sampler=sampler,
173 | sigmas=sigmas,
174 | latent_image=latents,
175 | )
176 |
177 | # Clean up guides if image conditioning was used
178 | print("before guide crop", denoised_output_latents["samples"].shape)
179 | positive, negative, denoised_output_latents = LTXVCropGuides().crop(
180 | positive=positive,
181 | negative=negative,
182 | latent=denoised_output_latents,
183 | )
184 | print("after guide crop", denoised_output_latents["samples"].shape)
185 |
186 | return (denoised_output_latents,)
187 |
188 |
189 | @comfy_node(
190 | name="LTXVExtendSampler",
191 | )
192 | class LTXVExtendSampler:
193 |
194 | @classmethod
195 | def INPUT_TYPES(s):
196 | return {
197 | "required": {
198 | "model": ("MODEL",),
199 | "vae": ("VAE",),
200 | "latents": ("LATENT",),
201 | "num_new_frames": (
202 | "INT",
203 | {"default": 80, "min": 8, "max": nodes.MAX_RESOLUTION, "step": 8},
204 | ),
205 | "frame_overlap": (
206 | "INT",
207 | {"default": 16, "min": 16, "max": 128, "step": 8},
208 | ),
209 | "guider": ("GUIDER",),
210 | "sampler": ("SAMPLER",),
211 | "sigmas": ("SIGMAS",),
212 | "noise": ("NOISE",),
213 | },
214 | }
215 |
216 | RETURN_TYPES = ("LATENT",)
217 | RETURN_NAMES = ("denoised_output",)
218 | FUNCTION = "sample"
219 | CATEGORY = "sampling"
220 |
221 | def sample(
222 | self,
223 | model,
224 | vae,
225 | latents,
226 | num_new_frames,
227 | frame_overlap,
228 | guider,
229 | sampler,
230 | sigmas,
231 | noise,
232 | ):
233 |
234 | try:
235 | positive, negative = guider.raw_conds
236 | except AttributeError:
237 | raise ValueError(
238 | "Guider does not have raw conds, cannot use it as a guider. "
239 | "Please use STGGuiderAdvanced."
240 | )
241 |
242 | samples = latents["samples"]
243 | batch, channels, frames, height, width = samples.shape
244 | time_scale_factor, width_scale_factor, height_scale_factor = (
245 | vae.downscale_index_formula
246 | )
247 | overlap = frame_overlap // time_scale_factor  # overlap expressed in latent frames
248 |
249 | (last_overlap_latents,) = LTXVSelectLatents().select_latents(
250 | latents, -overlap, -1
251 | )
252 |
253 | new_latents = EmptyLTXVLatentVideo().generate(
254 | width=width * width_scale_factor,
255 | height=height * height_scale_factor,
256 | length=overlap * time_scale_factor + num_new_frames,
257 | batch_size=1,
258 | )[0]
259 | print("new_latents shape: ", new_latents["samples"].shape)
260 | (
261 | positive,
262 | negative,
263 | new_latents,
264 | ) = LTXVAddLatentGuide().generate(
265 | vae=vae,
266 | positive=positive,
267 | negative=negative,
268 | latent=new_latents,
269 | guiding_latent=last_overlap_latents,
270 | latent_idx=0,
271 | strength=1.0,
272 | )
273 |
274 | guider = copy.copy(guider)
275 | guider.set_conds(positive, negative)
276 |
277 | # Denoise the latent video
278 | (output_latents, denoised_output_latents) = SamplerCustomAdvanced().sample(
279 | noise=noise,
280 | guider=guider,
281 | sampler=sampler,
282 | sigmas=sigmas,
283 | latent_image=new_latents,
284 | )
285 |
286 | # Clean up guides if image conditioning was used
287 | print("before guide crop", denoised_output_latents["samples"].shape)
288 | positive, negative, denoised_output_latents = LTXVCropGuides().crop(
289 | positive=positive,
290 | negative=negative,
291 | latent=denoised_output_latents,
292 | )
293 | print("after guide crop", denoised_output_latents["samples"].shape)
294 |
295 | # drop the first output latent: it is an 8-frame latent that gets reinterpreted as a 1-frame latent
296 | truncated_denoised_output_latents = LTXVSelectLatents().select_latents(
297 | denoised_output_latents, 1, -1
298 | )[0]
299 | # Fuse new frames with old ones by calling LinearOverlapLatentTransition
300 | (latents,) = LinearOverlapLatentTransition().process(
301 | latents, truncated_denoised_output_latents, overlap - 1, axis=2
302 | )
303 | print("latents shape after linear overlap blend: ", latents["samples"].shape)
304 | return (latents,)
305 |
--------------------------------------------------------------------------------
/example_workflows/low_level/end.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightricks/ComfyUI-LTXVideo/6e9e6de05624b0aab09b81a2f4a5f473fa97988a/example_workflows/low_level/end.jpg
--------------------------------------------------------------------------------
/example_workflows/low_level/fox.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightricks/ComfyUI-LTXVideo/6e9e6de05624b0aab09b81a2f4a5f473fa97988a/example_workflows/low_level/fox.jpg
--------------------------------------------------------------------------------
/example_workflows/low_level/jeep.mp4:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:5b8ac2148bdf1092f4342eb4f08961458a9e22743e8896f5dbcf749166f4f526
3 | size 1582922
4 |
--------------------------------------------------------------------------------
/example_workflows/low_level/ltxvideo-i2v-distilled.json:
--------------------------------------------------------------------------------
1 | {
2 | "id": "d8be2e45-0fd2-4f4a-8065-469fdf4bc018",
3 | "revision": 0,
4 | "last_node_id": 109,
5 | "last_link_id": 268,
6 | "nodes": [
7 | {
8 | "id": 38,
9 | "type": "CLIPLoader",
10 | "pos": [
11 | 60,
12 | 190
13 | ],
14 | "size": [
15 | 315,
16 | 98
17 | ],
18 | "flags": {},
19 | "order": 0,
20 | "mode": 0,
21 | "inputs": [],
22 | "outputs": [
23 | {
24 | "name": "CLIP",
25 | "type": "CLIP",
26 | "slot_index": 0,
27 | "links": [
28 | 74,
29 | 75
30 | ]
31 | }
32 | ],
33 | "properties": {
34 | "cnr_id": "comfy-core",
35 | "ver": "0.3.28",
36 | "Node name for S&R": "CLIPLoader"
37 | },
38 | "widgets_values": [
39 | "t5xxl_fp16.safetensors",
40 | "ltxv",
41 | "default"
42 | ]
43 | },
44 | {
45 | "id": 8,
46 | "type": "VAEDecode",
47 | "pos": [
48 | 1800,
49 | 430
50 | ],
51 | "size": [
52 | 210,
53 | 46
54 | ],
55 | "flags": {},
56 | "order": 16,
57 | "mode": 0,
58 | "inputs": [
59 | {
60 | "name": "samples",
61 | "type": "LATENT",
62 | "link": 255
63 | },
64 | {
65 | "name": "vae",
66 | "type": "VAE",
67 | "link": 87
68 | }
69 | ],
70 | "outputs": [
71 | {
72 | "name": "IMAGE",
73 | "type": "IMAGE",
74 | "slot_index": 0,
75 | "links": [
76 | 261
77 | ]
78 | }
79 | ],
80 | "properties": {
81 | "cnr_id": "comfy-core",
82 | "ver": "0.3.28",
83 | "Node name for S&R": "VAEDecode"
84 | },
85 | "widgets_values": []
86 | },
87 | {
88 | "id": 103,
89 | "type": "VHS_VideoCombine",
90 | "pos": [
91 | 1789.6087646484375,
92 | 548.7745361328125
93 | ],
94 | "size": [
95 | 315,
96 | 545
97 | ],
98 | "flags": {},
99 | "order": 17,
100 | "mode": 0,
101 | "inputs": [
102 | {
103 | "name": "images",
104 | "shape": 7,
105 | "type": "IMAGE",
106 | "link": 261
107 | },
108 | {
109 | "name": "audio",
110 | "shape": 7,
111 | "type": "AUDIO",
112 | "link": null
113 | },
114 | {
115 | "name": "meta_batch",
116 | "shape": 7,
117 | "type": "VHS_BatchManager",
118 | "link": null
119 | },
120 | {
121 | "name": "vae",
122 | "shape": 7,
123 | "type": "VAE",
124 | "link": null
125 | }
126 | ],
127 | "outputs": [
128 | {
129 | "name": "Filenames",
130 | "type": "VHS_FILENAMES",
131 | "links": null
132 | }
133 | ],
134 | "properties": {
135 | "cnr_id": "comfyui-videohelpersuite",
136 | "ver": "972c87da577b47211c4e9aeed30dc38c7bae607f",
137 | "Node name for S&R": "VHS_VideoCombine"
138 | },
139 | "widgets_values": {
140 | "frame_rate": 24,
141 | "loop_count": 0,
142 | "filename_prefix": "ltxv",
143 | "format": "video/h264-mp4",
144 | "pix_fmt": "yuv420p",
145 | "crf": 20,
146 | "save_metadata": false,
147 | "trim_to_audio": false,
148 | "pingpong": false,
149 | "save_output": true,
150 | "videopreview": {
151 | "hidden": false,
152 | "paused": false,
153 | "params": {
154 | "filename": "ltxv_00058.mp4",
155 | "subfolder": "",
156 | "type": "output",
157 | "format": "video/h264-mp4",
158 | "frame_rate": 24,
159 | "workflow": "ltxv_00058.png",
160 | "fullpath": "C:\\Users\\Zeev\\Projects\\ComfyUI\\output\\ltxv_00058.mp4"
161 | }
162 | }
163 | }
164 | },
165 | {
166 | "id": 82,
167 | "type": "LTXVPreprocess",
168 | "pos": [
169 | 488.92791748046875,
170 | 629.9364624023438
171 | ],
172 | "size": [
173 | 275.9266662597656,
174 | 58
175 | ],
176 | "flags": {},
177 | "order": 10,
178 | "mode": 0,
179 | "inputs": [
180 | {
181 | "name": "image",
182 | "type": "IMAGE",
183 | "link": 226
184 | }
185 | ],
186 | "outputs": [
187 | {
188 | "name": "output_image",
189 | "type": "IMAGE",
190 | "slot_index": 0,
191 | "links": [
192 | 248
193 | ]
194 | }
195 | ],
196 | "properties": {
197 | "cnr_id": "comfy-core",
198 | "ver": "0.3.28",
199 | "Node name for S&R": "LTXVPreprocess"
200 | },
201 | "widgets_values": [
202 | 38
203 | ]
204 | },
205 | {
206 | "id": 7,
207 | "type": "CLIPTextEncode",
208 | "pos": [
209 | 420,
210 | 390
211 | ],
212 | "size": [
213 | 425.27801513671875,
214 | 180.6060791015625
215 | ],
216 | "flags": {},
217 | "order": 9,
218 | "mode": 0,
219 | "inputs": [
220 | {
221 | "name": "clip",
222 | "type": "CLIP",
223 | "link": 75
224 | }
225 | ],
226 | "outputs": [
227 | {
228 | "name": "CONDITIONING",
229 | "type": "CONDITIONING",
230 | "slot_index": 0,
231 | "links": [
232 | 240
233 | ]
234 | }
235 | ],
236 | "title": "CLIP Text Encode (Negative Prompt)",
237 | "properties": {
238 | "cnr_id": "comfy-core",
239 | "ver": "0.3.28",
240 | "Node name for S&R": "CLIPTextEncode"
241 | },
242 | "widgets_values": [
243 | "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly"
244 | ],
245 | "color": "#322",
246 | "bgcolor": "#533"
247 | },
248 | {
249 | "id": 78,
250 | "type": "LoadImage",
251 | "pos": [
252 | 40,
253 | 630
254 | ],
255 | "size": [
256 | 385.15606689453125,
257 | 333.3305358886719
258 | ],
259 | "flags": {},
260 | "order": 1,
261 | "mode": 0,
262 | "inputs": [],
263 | "outputs": [
264 | {
265 | "name": "IMAGE",
266 | "type": "IMAGE",
267 | "slot_index": 0,
268 | "links": [
269 | 226
270 | ]
271 | },
272 | {
273 | "name": "MASK",
274 | "type": "MASK",
275 | "links": null
276 | }
277 | ],
278 | "properties": {
279 | "cnr_id": "comfy-core",
280 | "ver": "0.3.28",
281 | "Node name for S&R": "LoadImage"
282 | },
283 | "widgets_values": [
284 | "fox.jpg",
285 | "image",
286 | ""
287 | ]
288 | },
289 | {
290 | "id": 6,
291 | "type": "CLIPTextEncode",
292 | "pos": [
293 | 420,
294 | 180
295 | ],
296 | "size": [
297 | 422.84503173828125,
298 | 164.31304931640625
299 | ],
300 | "flags": {},
301 | "order": 8,
302 | "mode": 0,
303 | "inputs": [
304 | {
305 | "name": "clip",
306 | "type": "CLIP",
307 | "link": 74
308 | }
309 | ],
310 | "outputs": [
311 | {
312 | "name": "CONDITIONING",
313 | "type": "CONDITIONING",
314 | "slot_index": 0,
315 | "links": [
316 | 239
317 | ]
318 | }
319 | ],
320 | "title": "CLIP Text Encode (Positive Prompt)",
321 | "properties": {
322 | "cnr_id": "comfy-core",
323 | "ver": "0.3.28",
324 | "Node name for S&R": "CLIPTextEncode"
325 | },
326 | "widgets_values": [
327 | "A red fox moving gracefully, its russet coat vibrant against the white landscape, leaving perfect star-shaped prints behind as steam rises from its breath in the crisp winter air. The scene is wrapped in snow-muffled silence, broken only by the gentle murmur of water still flowing beneath the ice."
328 | ],
329 | "color": "#232",
330 | "bgcolor": "#353"
331 | },
332 | {
333 | "id": 76,
334 | "type": "Note",
335 | "pos": [
336 | 36.85359573364258,
337 | 360.48809814453125
338 | ],
339 | "size": [
340 | 360,
341 | 200
342 | ],
343 | "flags": {},
344 | "order": 2,
345 | "mode": 0,
346 | "inputs": [],
347 | "outputs": [],
348 | "properties": {},
349 | "widgets_values": [
350 | "While LTXV-2b model prefers long descriptive prompts, this version supports experimentation with broader prompting styles."
351 | ],
352 | "color": "#432",
353 | "bgcolor": "#653"
354 | },
355 | {
356 | "id": 69,
357 | "type": "LTXVConditioning",
358 | "pos": [
359 | 1183.358642578125,
360 | 292.5882873535156
361 | ],
362 | "size": [
363 | 223.8660125732422,
364 | 78
365 | ],
366 | "flags": {},
367 | "order": 13,
368 | "mode": 0,
369 | "inputs": [
370 | {
371 | "name": "positive",
372 | "type": "CONDITIONING",
373 | "link": 245
374 | },
375 | {
376 | "name": "negative",
377 | "type": "CONDITIONING",
378 | "link": 246
379 | }
380 | ],
381 | "outputs": [
382 | {
383 | "name": "positive",
384 | "type": "CONDITIONING",
385 | "slot_index": 0,
386 | "links": [
387 | 256
388 | ]
389 | },
390 | {
391 | "name": "negative",
392 | "type": "CONDITIONING",
393 | "slot_index": 1,
394 | "links": [
395 | 257
396 | ]
397 | }
398 | ],
399 | "properties": {
400 | "cnr_id": "comfy-core",
401 | "ver": "0.3.28",
402 | "Node name for S&R": "LTXVConditioning"
403 | },
404 | "widgets_values": [
405 | 24.000000000000004
406 | ]
407 | },
408 | {
409 | "id": 101,
410 | "type": "SamplerCustomAdvanced",
411 | "pos": [
412 | 1780,
413 | 270
414 | ],
415 | "size": [
416 | 236.8000030517578,
417 | 106
418 | ],
419 | "flags": {},
420 | "order": 15,
421 | "mode": 0,
422 | "inputs": [
423 | {
424 | "name": "noise",
425 | "type": "NOISE",
426 | "link": 260
427 | },
428 | {
429 | "name": "guider",
430 | "type": "GUIDER",
431 | "link": 252
432 | },
433 | {
434 | "name": "sampler",
435 | "type": "SAMPLER",
436 | "link": 253
437 | },
438 | {
439 | "name": "sigmas",
440 | "type": "SIGMAS",
441 | "link": 268
442 | },
443 | {
444 | "name": "latent_image",
445 | "type": "LATENT",
446 | "link": 259
447 | }
448 | ],
449 | "outputs": [
450 | {
451 | "name": "output",
452 | "type": "LATENT",
453 | "slot_index": 0,
454 | "links": []
455 | },
456 | {
457 | "name": "denoised_output",
458 | "type": "LATENT",
459 | "slot_index": 1,
460 | "links": [
461 | 255
462 | ]
463 | }
464 | ],
465 | "properties": {
466 | "cnr_id": "comfy-core",
467 | "ver": "0.3.15",
468 | "Node name for S&R": "SamplerCustomAdvanced"
469 | },
470 | "widgets_values": []
471 | },
472 | {
473 | "id": 95,
474 | "type": "LTXVImgToVideo",
475 | "pos": [
476 | 900,
477 | 290
478 | ],
479 | "size": [
480 | 210,
481 | 190
482 | ],
483 | "flags": {},
484 | "order": 12,
485 | "mode": 0,
486 | "inputs": [
487 | {
488 | "name": "positive",
489 | "type": "CONDITIONING",
490 | "link": 239
491 | },
492 | {
493 | "name": "negative",
494 | "type": "CONDITIONING",
495 | "link": 240
496 | },
497 | {
498 | "name": "vae",
499 | "type": "VAE",
500 | "link": 250
501 | },
502 | {
503 | "name": "image",
504 | "type": "IMAGE",
505 | "link": 248
506 | }
507 | ],
508 | "outputs": [
509 | {
510 | "name": "positive",
511 | "type": "CONDITIONING",
512 | "slot_index": 0,
513 | "links": [
514 | 245
515 | ]
516 | },
517 | {
518 | "name": "negative",
519 | "type": "CONDITIONING",
520 | "slot_index": 1,
521 | "links": [
522 | 246
523 | ]
524 | },
525 | {
526 | "name": "latent",
527 | "type": "LATENT",
528 | "slot_index": 2,
529 | "links": [
530 | 259
531 | ]
532 | }
533 | ],
534 | "properties": {
535 | "cnr_id": "comfy-core",
536 | "ver": "0.3.28",
537 | "Node name for S&R": "LTXVImgToVideo"
538 | },
539 | "widgets_values": [
540 | 768,
541 | 512,
542 | 97,
543 | 1
544 | ]
545 | },
546 | {
547 | "id": 100,
548 | "type": "StringToFloatList",
549 | "pos": [
550 | 1116.0089111328125,
551 | 604.5989379882812
552 | ],
553 | "size": [
554 | 395.74224853515625,
555 | 88
556 | ],
557 | "flags": {},
558 | "order": 3,
559 | "mode": 0,
560 | "inputs": [],
561 | "outputs": [
562 | {
563 | "name": "FLOAT",
564 | "type": "FLOAT",
565 | "links": [
566 | 251
567 | ]
568 | }
569 | ],
570 | "properties": {
571 | "cnr_id": "comfyui-kjnodes",
572 | "ver": "0addfc6101f7a834c7fb6e0a1b26529360ab5350",
573 | "Node name for S&R": "StringToFloatList"
574 | },
575 | "widgets_values": [
576 | "1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250, 0.4219, 0.0"
577 | ],
578 | "color": "#223",
579 | "bgcolor": "#335"
580 | },
581 | {
582 | "id": 105,
583 | "type": "Note",
584 | "pos": [
585 | 1126.7423095703125,
586 | 424.2294616699219
587 | ],
588 | "size": [
589 | 335.8657531738281,
590 | 106.6832046508789
591 | ],
592 | "flags": {},
593 | "order": 4,
594 | "mode": 0,
595 | "inputs": [],
596 | "outputs": [],
597 | "properties": {},
598 | "widgets_values": [
599 | "Distilled model expects the following sigma schedule:\n1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250, 0.4219, 0.0\n\n\nEuler_ancestral is the recommended default sampler."
600 | ],
601 | "color": "#432",
602 | "bgcolor": "#653"
603 | },
604 | {
605 | "id": 97,
606 | "type": "KSamplerSelect",
607 | "pos": [
608 | 1477.55615234375,
609 | 427.04132080078125
610 | ],
611 | "size": [
612 | 275.599365234375,
613 | 58
614 | ],
615 | "flags": {},
616 | "order": 5,
617 | "mode": 0,
618 | "inputs": [],
619 | "outputs": [
620 | {
621 | "name": "SAMPLER",
622 | "type": "SAMPLER",
623 | "slot_index": 0,
624 | "links": [
625 | 253
626 | ]
627 | }
628 | ],
629 | "properties": {
630 | "cnr_id": "comfy-core",
631 | "ver": "0.3.15",
632 | "Node name for S&R": "KSamplerSelect"
633 | },
634 | "widgets_values": [
635 | "euler_ancestral"
636 | ]
637 | },
638 | {
639 | "id": 99,
640 | "type": "CFGGuider",
641 | "pos": [
642 | 1536.1513671875,
643 | 280
644 | ],
645 | "size": [
646 | 210,
647 | 98
648 | ],
649 | "flags": {},
650 | "order": 14,
651 | "mode": 0,
652 | "inputs": [
653 | {
654 | "name": "model",
655 | "type": "MODEL",
656 | "link": 258
657 | },
658 | {
659 | "name": "positive",
660 | "type": "CONDITIONING",
661 | "link": 256
662 | },
663 | {
664 | "name": "negative",
665 | "type": "CONDITIONING",
666 | "link": 257
667 | }
668 | ],
669 | "outputs": [
670 | {
671 | "name": "GUIDER",
672 | "type": "GUIDER",
673 | "slot_index": 0,
674 | "links": [
675 | 252
676 | ]
677 | }
678 | ],
679 | "properties": {
680 | "cnr_id": "comfy-core",
681 | "ver": "0.3.26",
682 | "Node name for S&R": "CFGGuider"
683 | },
684 | "widgets_values": [
685 | 1
686 | ]
687 | },
688 | {
689 | "id": 102,
690 | "type": "RandomNoise",
691 | "pos": [
692 | 1524.6190185546875,
693 | 138.8462371826172
694 | ],
695 | "size": [
696 | 210,
697 | 82
698 | ],
699 | "flags": {},
700 | "order": 6,
701 | "mode": 0,
702 | "inputs": [],
703 | "outputs": [
704 | {
705 | "name": "NOISE",
706 | "type": "NOISE",
707 | "links": [
708 | 260
709 | ]
710 | }
711 | ],
712 | "properties": {
713 | "cnr_id": "comfy-core",
714 | "ver": "0.3.28",
715 | "Node name for S&R": "RandomNoise"
716 | },
717 | "widgets_values": [
718 | 45,
719 | "fixed"
720 | ]
721 | },
722 | {
723 | "id": 98,
724 | "type": "FloatToSigmas",
725 | "pos": [
726 | 1574.70166015625,
727 | 620.4625244140625
728 | ],
729 | "size": [
730 | 210,
731 | 58
732 | ],
733 | "flags": {
734 | "collapsed": true
735 | },
736 | "order": 11,
737 | "mode": 0,
738 | "inputs": [
739 | {
740 | "name": "float_list",
741 | "type": "FLOAT",
742 | "widget": {
743 | "name": "float_list"
744 | },
745 | "link": null
746 | },
747 | {
748 | "name": "float_list",
749 | "type": "FLOAT",
750 | "widget": {
751 | "name": "float_list"
752 | },
753 | "link": 251
754 | }
755 | ],
756 | "outputs": [
757 | {
758 | "name": "SIGMAS",
759 | "type": "SIGMAS",
760 | "links": [
761 | 268
762 | ]
763 | }
764 | ],
765 | "properties": {
766 | "cnr_id": "comfyui-kjnodes",
767 | "ver": "0addfc6101f7a834c7fb6e0a1b26529360ab5350",
768 | "Node name for S&R": "FloatToSigmas"
769 | },
770 | "widgets_values": [
771 | 0
772 | ]
773 | },
774 | {
775 | "id": 44,
776 | "type": "CheckpointLoaderSimple",
777 | "pos": [
778 | 899.0839233398438,
779 | 113.54173278808594
780 | ],
781 | "size": [
782 | 315,
783 | 98
784 | ],
785 | "flags": {},
786 | "order": 7,
787 | "mode": 0,
788 | "inputs": [],
789 | "outputs": [
790 | {
791 | "name": "MODEL",
792 | "type": "MODEL",
793 | "slot_index": 0,
794 | "links": [
795 | 258
796 | ]
797 | },
798 | {
799 | "name": "CLIP",
800 | "type": "CLIP",
801 | "links": null
802 | },
803 | {
804 | "name": "VAE",
805 | "type": "VAE",
806 | "slot_index": 2,
807 | "links": [
808 | 87,
809 | 250
810 | ]
811 | }
812 | ],
813 | "properties": {
814 | "cnr_id": "comfy-core",
815 | "ver": "0.3.28",
816 | "Node name for S&R": "CheckpointLoaderSimple"
817 | },
818 | "widgets_values": [
819 | "ltxv-2b-0.9.6-distilled-04-25.safetensors"
820 | ]
821 | }
822 | ],
823 | "links": [
824 | [
825 | 74,
826 | 38,
827 | 0,
828 | 6,
829 | 0,
830 | "CLIP"
831 | ],
832 | [
833 | 75,
834 | 38,
835 | 0,
836 | 7,
837 | 0,
838 | "CLIP"
839 | ],
840 | [
841 | 87,
842 | 44,
843 | 2,
844 | 8,
845 | 1,
846 | "VAE"
847 | ],
848 | [
849 | 226,
850 | 78,
851 | 0,
852 | 82,
853 | 0,
854 | "IMAGE"
855 | ],
856 | [
857 | 239,
858 | 6,
859 | 0,
860 | 95,
861 | 0,
862 | "CONDITIONING"
863 | ],
864 | [
865 | 240,
866 | 7,
867 | 0,
868 | 95,
869 | 1,
870 | "CONDITIONING"
871 | ],
872 | [
873 | 245,
874 | 95,
875 | 0,
876 | 69,
877 | 0,
878 | "CONDITIONING"
879 | ],
880 | [
881 | 246,
882 | 95,
883 | 1,
884 | 69,
885 | 1,
886 | "CONDITIONING"
887 | ],
888 | [
889 | 248,
890 | 82,
891 | 0,
892 | 95,
893 | 3,
894 | "IMAGE"
895 | ],
896 | [
897 | 250,
898 | 44,
899 | 2,
900 | 95,
901 | 2,
902 | "VAE"
903 | ],
904 | [
905 | 251,
906 | 100,
907 | 0,
908 | 98,
909 | 1,
910 | "FLOAT"
911 | ],
912 | [
913 | 252,
914 | 99,
915 | 0,
916 | 101,
917 | 1,
918 | "GUIDER"
919 | ],
920 | [
921 | 253,
922 | 97,
923 | 0,
924 | 101,
925 | 2,
926 | "SAMPLER"
927 | ],
928 | [
929 | 255,
930 | 101,
931 | 1,
932 | 8,
933 | 0,
934 | "LATENT"
935 | ],
936 | [
937 | 256,
938 | 69,
939 | 0,
940 | 99,
941 | 1,
942 | "CONDITIONING"
943 | ],
944 | [
945 | 257,
946 | 69,
947 | 1,
948 | 99,
949 | 2,
950 | "CONDITIONING"
951 | ],
952 | [
953 | 258,
954 | 44,
955 | 0,
956 | 99,
957 | 0,
958 | "MODEL"
959 | ],
960 | [
961 | 259,
962 | 95,
963 | 2,
964 | 101,
965 | 4,
966 | "LATENT"
967 | ],
968 | [
969 | 260,
970 | 102,
971 | 0,
972 | 101,
973 | 0,
974 | "NOISE"
975 | ],
976 | [
977 | 261,
978 | 8,
979 | 0,
980 | 103,
981 | 0,
982 | "IMAGE"
983 | ],
984 | [
985 | 268,
986 | 98,
987 | 0,
988 | 101,
989 | 3,
990 | "SIGMAS"
991 | ]
992 | ],
993 | "groups": [],
994 | "config": {},
995 | "extra": {
996 | "ds": {
997 | "scale": 0.6303940863129158,
998 | "offset": [
999 | 627.9092558924768,
1000 | 478.72279841993054
1001 | ]
1002 | },
1003 | "VHS_latentpreview": true,
1004 | "VHS_latentpreviewrate": 0,
1005 | "VHS_MetadataImage": true,
1006 | "VHS_KeepIntermediate": true
1007 | },
1008 | "version": 0.4
1009 | }
1010 |
--------------------------------------------------------------------------------
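Editor's note: a minimal sketch (not part of the repository) of roughly what the
StringToFloatList -> FloatToSigmas pair in the workflow above computes: the fixed
sigma schedule the distilled checkpoint expects, parsed into a 1-D tensor that is
then handed to the custom sampler together with euler_ancestral.

import torch

SIGMA_SCHEDULE = "1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250, 0.4219, 0.0"

def schedule_to_sigmas(schedule: str) -> torch.Tensor:
    # Parse the comma-separated schedule into a float32 tensor of sigmas.
    return torch.tensor([float(s) for s in schedule.split(",")], dtype=torch.float32)

sigmas = schedule_to_sigmas(SIGMA_SCHEDULE)  # shape (9,), descending, ends at 0.0

--------------------------------------------------------------------------------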
/example_workflows/low_level/moto.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:8681ff86edb43d0a12472175d5503a590910b24d5ec5d5d5b78eed079960eac0
3 | size 998115
4 |
--------------------------------------------------------------------------------
/example_workflows/low_level/shrek2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightricks/ComfyUI-LTXVideo/6e9e6de05624b0aab09b81a2f4a5f473fa97988a/example_workflows/low_level/shrek2.jpg
--------------------------------------------------------------------------------
/example_workflows/low_level/start.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightricks/ComfyUI-LTXVideo/6e9e6de05624b0aab09b81a2f4a5f473fa97988a/example_workflows/low_level/start.jpg
--------------------------------------------------------------------------------
/example_workflows/tricks/ltxvideo-flow-edit.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:041f3820e0ab4109137d743b3f6b8e2dea0cddcad194aac08d3baec7571200f9
3 | size 1841100
4 |
--------------------------------------------------------------------------------
/example_workflows/tricks/ltxvideo-rf-edit.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:9d2b002f12155d0b929366ad38561680a5bbaecea1925e5655658b46d281baf9
3 | size 2441562
4 |
--------------------------------------------------------------------------------
/example_workflows/tricks/ref.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:3726ed083a5c52873d29adab87e8fe0a65c7e390beaf8dfcf23e5c9091d6da73
3 | size 521574
4 |
--------------------------------------------------------------------------------
/example_workflows/tricks/shot.mp4:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:01d44bb728e04879a5513a2c33bdd47d443bf901b33d3c1445918e75ba8b78b1
3 | size 678605
4 |
--------------------------------------------------------------------------------
/example_workflows/tricks/shot2.mp4:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:5432150d5477d19586b8d8049ca214038f886d3def9117830cf2af58879c48b5
3 | size 31584408
4 |
--------------------------------------------------------------------------------
/example_workflows/tricks/shrek2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightricks/ComfyUI-LTXVideo/6e9e6de05624b0aab09b81a2f4a5f473fa97988a/example_workflows/tricks/shrek2.jpg
--------------------------------------------------------------------------------
/example_workflows/tricks/shrek3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightricks/ComfyUI-LTXVideo/6e9e6de05624b0aab09b81a2f4a5f473fa97988a/example_workflows/tricks/shrek3.jpg
--------------------------------------------------------------------------------
/film_grain.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import comfy
4 | import torch
5 |
6 | from .nodes_registry import comfy_node
7 |
8 |
9 | @comfy_node(
10 | name="LTXVFilmGrain",
11 | )
12 | class LTXVFilmGrain:
13 | @classmethod
14 | def INPUT_TYPES(s):
15 | return {
16 | "required": {
17 | "images": ("IMAGE",),
18 | "grain_intensity": (
19 | "FLOAT",
20 | {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01},
21 | ),
22 | "saturation": (
23 | "FLOAT",
24 | {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01},
25 | ),
26 | },
27 | "optional": {},
28 | }
29 |
30 | RETURN_TYPES = ("IMAGE",)
31 | FUNCTION = "add_film_grain"
32 | CATEGORY = "effects"
33 | DESCRIPTION = "Adds film grain to the image."
34 |
35 | def add_film_grain(
36 | self, images: torch.Tensor, grain_intensity: float, saturation: float
37 | ) -> Tuple[torch.Tensor]:
38 | if grain_intensity < 0 or grain_intensity > 1:
39 | raise ValueError("Grain intensity must be between 0 and 1.")
40 |
41 | device = comfy.model_management.get_torch_device()
42 | images = images.to(device)
43 |
44 | grain = torch.randn_like(images, device=device)
45 | grain[:, :, :, 0] *= 2
46 | grain[:, :, :, 2] *= 3
47 | grain = grain * saturation + grain[:, :, :, 1].unsqueeze(3).repeat(
48 | 1, 1, 1, 3
49 | ) * (1 - saturation)
50 |
51 | # Blend the grain with the image
52 | noised_images = images + grain_intensity * grain
53 | noised_images.clamp_(0, 1)
54 | noised_images = noised_images.to(comfy.model_management.intermediate_device())
55 |
56 | return (noised_images,)
57 |
--------------------------------------------------------------------------------
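Editor's note: a self-contained sketch (not part of the repository) of the grain
math used by LTXVFilmGrain above, run on a dummy batch with plain torch instead of
the ComfyUI device helpers.

import torch

images = torch.rand(1, 64, 64, 3)             # B x H x W x C in [0, 1]
grain_intensity, saturation = 0.1, 0.5

grain = torch.randn_like(images)
grain[:, :, :, 0] *= 2                        # stronger noise on the red channel
grain[:, :, :, 2] *= 3                        # stronger noise on the blue channel
# saturation=1 keeps the colored noise; saturation=0 collapses it to the green
# channel, i.e. monochrome grain.
grain = grain * saturation + grain[:, :, :, 1:2].repeat(1, 1, 1, 3) * (1 - saturation)

noised_images = (images + grain_intensity * grain).clamp(0, 1)

--------------------------------------------------------------------------------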
/guide.py:
--------------------------------------------------------------------------------
1 | import comfy
2 | import comfy_extras.nodes_lt as nodes_lt
3 | import comfy_extras.nodes_post_processing as post_processing
4 | import nodes
5 |
6 | from .nodes_registry import comfy_node
7 |
8 |
9 | def blur_internal(image, blur_radius):
10 | if blur_radius > 0:
11 | # https://docs.opencv.org/2.4/modules/imgproc/doc/filtering.html#getgaussiankernel
12 |         # The OpenCV doc recommends sigma = 0.3 * blur_radius + 0.5 for the relationship
13 |         # between sigma and kernel size 2*blur_radius + 1. We want somewhat weaker blurring,
14 |         # so we drop the +0.5 and use sigma = 0.3 * blur_radius instead.
15 | sigma = 0.3 * blur_radius
16 | image = post_processing.Blur().blur(image, blur_radius, sigma)[0]
17 | return image
18 |
19 |
20 | @comfy_node(name="LTXVAddGuideAdvanced")
21 | class LTXVAddGuideAdvanced:
22 | @classmethod
23 | def INPUT_TYPES(s):
24 | return {
25 | "required": {
26 | "positive": ("CONDITIONING",),
27 | "negative": ("CONDITIONING",),
28 | "vae": ("VAE",),
29 | "latent": ("LATENT",),
30 | "image": ("IMAGE",),
31 | "frame_idx": (
32 | "INT",
33 | {
34 | "default": 0,
35 | "min": -9999,
36 | "max": 9999,
37 | "tooltip": "Frame index to start the conditioning at. For single-frame images or "
38 | "videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ "
39 | "frames, frame_idx must be divisible by 8, otherwise it will be rounded down to "
40 | "the nearest multiple of 8. Negative values are counted from the end of the video.",
41 | },
42 | ),
43 | "strength": (
44 | "FLOAT",
45 | {
46 | "default": 1.0,
47 | "min": 0.0,
48 | "max": 1.0,
49 | "tooltip": "Strength of the conditioning. Higher values will make the conditioning more exact.",
50 | },
51 | ),
52 | "crf": (
53 | "INT",
54 | {
55 | "default": 29,
56 | "min": 0,
57 | "max": 51,
58 | "step": 1,
59 | "tooltip": "CRF value for the video. Higher values mean more motion, lower values mean higher quality.",
60 | },
61 | ),
62 | "blur_radius": (
63 | "INT",
64 | {
65 | "default": 0,
66 | "min": 0,
67 | "max": 7,
68 | "step": 1,
69 | "tooltip": "Blur kernel radius size. Higher values mean more motion, lower values mean higher quality.",
70 | },
71 | ),
72 | "interpolation": (
73 | [
74 | "lanczos",
75 | "bislerp",
76 | "nearest",
77 | "bilinear",
78 | "bicubic",
79 | "area",
80 | "nearest-exact",
81 | ],
82 | {"default": "lanczos"},
83 | ),
84 | "crop": (["center", "disabled"], {"default": "disabled"}),
85 | }
86 | }
87 |
88 | RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
89 | RETURN_NAMES = ("positive", "negative", "latent")
90 |
91 | CATEGORY = "conditioning/video_models"
92 | FUNCTION = "generate"
93 |
94 | DESCRIPTION = (
95 | "Adds a conditioning frame or a video at a specific frame index. "
96 | "This node is used to add a keyframe or a video segment which should appear in the "
97 | "generated video at a specified index. It resizes the image to the correct size and "
98 | "applies preprocessing to it."
99 | )
100 |
101 | def generate(
102 | self,
103 | positive,
104 | negative,
105 | vae,
106 | latent,
107 | image,
108 | frame_idx,
109 | strength,
110 | crf,
111 | blur_radius,
112 | interpolation,
113 | crop,
114 | ):
115 | _, width_scale_factor, height_scale_factor = vae.downscale_index_formula
116 | width, height = (
117 | latent["samples"].shape[4] * width_scale_factor,
118 | latent["samples"].shape[3] * height_scale_factor,
119 | )
120 | image = (
121 | comfy.utils.common_upscale(
122 | image.movedim(-1, 1), width, height, interpolation, crop=crop
123 | )
124 | .movedim(1, -1)
125 | .clamp(0, 1)
126 | )
127 | image = nodes_lt.LTXVPreprocess().preprocess(image, crf)[0]
128 | image = blur_internal(image, blur_radius)
129 | return nodes_lt.LTXVAddGuide().generate(
130 | positive=positive,
131 | negative=negative,
132 | vae=vae,
133 | latent=latent,
134 | image=image,
135 | frame_idx=frame_idx,
136 | strength=strength,
137 | )
138 |
139 |
140 | @comfy_node(name="LTXVImgToVideoAdvanced")
141 | class LTXVImgToVideoAdvanced:
142 | @classmethod
143 | def INPUT_TYPES(s):
144 | return {
145 | "required": {
146 | "positive": ("CONDITIONING",),
147 | "negative": ("CONDITIONING",),
148 | "vae": ("VAE",),
149 | "image": ("IMAGE",),
150 | "width": (
151 | "INT",
152 | {
153 | "default": 768,
154 | "min": 64,
155 | "max": nodes.MAX_RESOLUTION,
156 | "step": 32,
157 | },
158 | ),
159 | "height": (
160 | "INT",
161 | {
162 | "default": 512,
163 | "min": 64,
164 | "max": nodes.MAX_RESOLUTION,
165 | "step": 32,
166 | },
167 | ),
168 | "length": (
169 | "INT",
170 | {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8},
171 | ),
172 | "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
173 | "crf": (
174 | "INT",
175 | {
176 | "default": 29,
177 | "min": 0,
178 | "max": 51,
179 | "step": 1,
180 | "tooltip": "CRF value for the video. Higher values mean more motion, lower values mean higher quality.",
181 | },
182 | ),
183 | "blur_radius": (
184 | "INT",
185 | {
186 | "default": 0,
187 | "min": 0,
188 | "max": 7,
189 | "step": 1,
190 | "tooltip": "Blur kernel radius size. Higher values mean more motion, lower values mean higher quality.",
191 | },
192 | ),
193 | "interpolation": (
194 | [
195 | "lanczos",
196 | "bislerp",
197 | "nearest",
198 | "bilinear",
199 | "bicubic",
200 | "area",
201 | "nearest-exact",
202 | ],
203 | {"default": "lanczos"},
204 | ),
205 | "crop": (["center", "disabled"], {"default": "disabled"}),
206 | "strength": ("FLOAT", {"default": 0.9, "min": 0, "max": 1}),
207 | }
208 | }
209 |
210 | RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
211 | RETURN_NAMES = ("positive", "negative", "latent")
212 |
213 | CATEGORY = "conditioning/video_models"
214 | FUNCTION = "generate"
215 |
216 | DESCRIPTION = (
217 | "Adds a conditioning frame or a video at index 0. "
218 | "This node is used to add a keyframe or a video segment which should appear in the "
219 | "generated video at index 0. It resizes the image to the correct size "
220 | "and applies preprocessing to it."
221 | )
222 |
223 | def generate(
224 | self,
225 | positive,
226 | negative,
227 | vae,
228 | image,
229 | width,
230 | height,
231 | length,
232 | batch_size,
233 | crf,
234 | blur_radius,
235 | interpolation,
236 | crop,
237 | strength,
238 | ):
239 | image = comfy.utils.common_upscale(
240 | image.movedim(-1, 1), width, height, interpolation, crop=crop
241 | ).movedim(1, -1)
242 | image = nodes_lt.LTXVPreprocess().preprocess(image, crf)[0]
243 | image = blur_internal(image, blur_radius)
244 | return nodes_lt.LTXVImgToVideo().generate(
245 | positive=positive,
246 | negative=negative,
247 | vae=vae,
248 | image=image,
249 | width=width,
250 | height=height,
251 | length=length,
252 | batch_size=batch_size,
253 | strength=strength,
254 | )
255 |
--------------------------------------------------------------------------------
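Editor's note: a small sketch (not part of the repository) of the sigma relation
used by blur_internal above. For a Gaussian kernel of size ksize, the OpenCV doc
suggests sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8; with ksize = 2 * blur_radius + 1
this reduces to 0.3 * blur_radius + 0.5, and blur_internal drops the +0.5 to blur
slightly more weakly.

for blur_radius in range(1, 8):
    ksize = 2 * blur_radius + 1
    opencv_sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8   # == 0.3 * blur_radius + 0.5
    used_sigma = 0.3 * blur_radius                       # value passed to Blur()
    print(blur_radius, round(opencv_sigma, 2), round(used_sigma, 2))

--------------------------------------------------------------------------------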
/latent_adain.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 |
5 | from .nodes_registry import comfy_node
6 |
7 |
8 | @comfy_node(name="LTXVAdainLatent")
9 | class LTXVAdainLatent:
10 | def __init__(self):
11 | pass
12 |
13 | @classmethod
14 | def INPUT_TYPES(s):
15 | return {
16 | "required": {
17 | "latents": ("LATENT",),
18 | "reference": ("LATENT",),
19 | "factor": (
20 | "FLOAT",
21 | {
22 | "default": 1.0,
23 | "min": -10.0,
24 | "max": 10.0,
25 | "step": 0.01,
26 | "round": 0.01,
27 | },
28 | ),
29 | },
30 | }
31 |
32 | RETURN_TYPES = ("LATENT",)
33 | FUNCTION = "batch_normalize"
34 |
35 | CATEGORY = "Lightricks/latents"
36 |
37 | def batch_normalize(self, latents, reference, factor):
38 | latents_copy = copy.deepcopy(latents)
39 | t = latents_copy["samples"] # B x C x F x H x W
40 |
41 | for i in range(t.size(0)): # batch
42 | for c in range(t.size(1)): # channel
43 | r_sd, r_mean = torch.std_mean(
44 | reference["samples"][i, c], dim=None
45 | ) # index by original dim order
46 | i_sd, i_mean = torch.std_mean(t[i, c], dim=None)
47 |
48 | t[i, c] = ((t[i, c] - i_mean) / i_sd) * r_sd + r_mean
49 |
50 | latents_copy["samples"] = torch.lerp(latents["samples"], t, factor)
51 | return (latents_copy,)
52 |
--------------------------------------------------------------------------------
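Editor's note: a self-contained sketch (not part of the repository) of the
per-channel statistic matching performed by LTXVAdainLatent above, on dummy
5-D latents (B x C x F x H x W).

import torch

latents = torch.randn(1, 8, 4, 16, 16)
reference = torch.randn(1, 8, 4, 16, 16)
factor = 1.0

t = latents.clone()
for i in range(t.size(0)):           # batch
    for c in range(t.size(1)):       # channel
        r_sd, r_mean = torch.std_mean(reference[i, c])
        i_sd, i_mean = torch.std_mean(t[i, c])
        # Re-standardize the channel, then rescale it to the reference statistics.
        t[i, c] = ((t[i, c] - i_mean) / i_sd) * r_sd + r_mean

result = torch.lerp(latents, t, factor)  # factor blends original vs. matched latents

--------------------------------------------------------------------------------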
/latent_upsampler.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import folder_paths
4 | import torch
5 | import torch.nn as nn
6 | from comfy import model_management
7 | from diffusers import ConfigMixin, ModelMixin
8 | from einops import rearrange
9 |
10 | from .nodes_registry import comfy_node
11 |
12 |
13 | class PixelShuffle3D(nn.Module):
14 | def __init__(self, upscale_factor: int = 2):
15 | super().__init__()
16 | self.r = upscale_factor
17 |
18 | def forward(self, x):
19 | b, c, f, h, w = x.shape
20 | r = self.r
21 | out_c = c // (r**3)
22 | x = x.view(b, out_c, r, r, r, f, h, w)
23 | x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) # (b, out_c, f, r, h, r, w, r)
24 | x = x.reshape(b, out_c, f * r, h * r, w * r)
25 | return x
26 |
27 |
28 | class PixelShuffle2D(nn.Module):
29 | def __init__(self, upscale_factor: int = 2):
30 | super().__init__()
31 | self.r = upscale_factor
32 |
33 | def forward(self, x):
34 | b, c, h, w = x.size()
35 | r = self.r
36 | out_c = c // (r * r)
37 | x = x.view(b, out_c, r, r, h, w)
38 | x = x.permute(0, 1, 4, 2, 5, 3) # (b, out_c, h, r, w, r)
39 | x = x.reshape(b, out_c, h * r, w * r)
40 | return x
41 |
42 |
43 | class PixelShuffle1D(nn.Module):
44 | def __init__(self, upscale_factor: int = 2):
45 | super().__init__()
46 | self.r = upscale_factor
47 |
48 | def forward(self, x):
49 | b, c, f, h, w = x.shape
50 | r = self.r
51 | out_c = c // r
52 | x = x.view(b, out_c, r, f, h, w) # [B, C//r, r, F, H, W]
53 | x = x.permute(0, 1, 3, 2, 4, 5) # [B, C//r, F, r, H, W]
54 | x = x.reshape(b, out_c, f * r, h, w)
55 | return x
56 |
57 |
58 | class ResBlock(nn.Module):
59 | def __init__(
60 | self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
61 | ):
62 | super().__init__()
63 | if mid_channels is None:
64 | mid_channels = channels
65 |
66 | Conv = nn.Conv2d if dims == 2 else nn.Conv3d
67 |
68 | self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
69 | self.norm1 = nn.GroupNorm(32, mid_channels)
70 | self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
71 | self.norm2 = nn.GroupNorm(32, channels)
72 | self.activation = nn.SiLU()
73 |
74 | def forward(self, x: torch.Tensor) -> torch.Tensor:
75 | residual = x
76 | x = self.conv1(x)
77 | x = self.norm1(x)
78 | x = self.activation(x)
79 | x = self.conv2(x)
80 | x = self.norm2(x)
81 | x = self.activation(x + residual)
82 | return x
83 |
84 |
85 | class LatentUpsampler(ModelMixin, ConfigMixin):
86 | """
87 | Model to spatially upsample VAE latents.
88 |
89 | Args:
90 | in_channels (`int`): Number of channels in the input latent
91 | mid_channels (`int`): Number of channels in the middle layers
92 | num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
93 | dims (`int`): Number of dimensions for convolutions (2 or 3)
94 | spatial_upsample (`bool`): Whether to spatially upsample the latent
95 | temporal_upsample (`bool`): Whether to temporally upsample the latent
96 | """
97 |
98 | def __init__(
99 | self,
100 | in_channels: int = 128,
101 | mid_channels: int = 512,
102 | num_blocks_per_stage: int = 4,
103 | dims: int = 3,
104 | spatial_upsample: bool = True,
105 | temporal_upsample: bool = False,
106 | ):
107 | super().__init__()
108 |
109 | self.in_channels = in_channels
110 | self.mid_channels = mid_channels
111 | self.num_blocks_per_stage = num_blocks_per_stage
112 | self.dims = dims
113 | self.spatial_upsample = spatial_upsample
114 | self.temporal_upsample = temporal_upsample
115 |
116 | Conv = nn.Conv2d if dims == 2 else nn.Conv3d
117 |
118 | self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
119 | self.initial_norm = nn.GroupNorm(32, mid_channels)
120 | self.initial_activation = nn.SiLU()
121 |
122 | self.res_blocks = nn.ModuleList(
123 | [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
124 | )
125 |
126 | if spatial_upsample and temporal_upsample:
127 | self.upsampler = nn.Sequential(
128 | nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
129 | PixelShuffle3D(2),
130 | )
131 | elif spatial_upsample:
132 | self.upsampler = nn.Sequential(
133 | nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
134 | PixelShuffle2D(2),
135 | )
136 | elif temporal_upsample:
137 | self.upsampler = nn.Sequential(
138 | nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
139 | PixelShuffle1D(2),
140 | )
141 | else:
142 | raise ValueError(
143 | "Either spatial_upsample or temporal_upsample must be True"
144 | )
145 |
146 | self.post_upsample_res_blocks = nn.ModuleList(
147 | [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
148 | )
149 |
150 | self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)
151 |
152 | def forward(self, latent: torch.Tensor) -> torch.Tensor:
153 | b, c, f, h, w = latent.shape
154 |
155 | if self.dims == 2:
156 | x = rearrange(latent, "b c f h w -> (b f) c h w")
157 | x = self.initial_conv(x)
158 | x = self.initial_norm(x)
159 | x = self.initial_activation(x)
160 |
161 | for block in self.res_blocks:
162 | x = block(x)
163 |
164 | x = self.upsampler(x)
165 |
166 | for block in self.post_upsample_res_blocks:
167 | x = block(x)
168 |
169 | x = self.final_conv(x)
170 | x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
171 | else:
172 | x = self.initial_conv(latent)
173 | x = self.initial_norm(x)
174 | x = self.initial_activation(x)
175 |
176 | for block in self.res_blocks:
177 | x = block(x)
178 |
179 | if self.temporal_upsample:
180 | x = self.upsampler(x)
181 | x = x[:, :, 1:, :, :]
182 | else:
183 | x = rearrange(x, "b c f h w -> (b f) c h w")
184 | x = self.upsampler(x)
185 | x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
186 |
187 | for block in self.post_upsample_res_blocks:
188 | x = block(x)
189 |
190 | x = self.final_conv(x)
191 |
192 | return x
193 |
194 | @classmethod
195 | def from_config(cls, config):
196 | return cls(
197 | in_channels=config.get("in_channels", 4),
198 | mid_channels=config.get("mid_channels", 128),
199 | num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
200 | dims=config.get("dims", 2),
201 | spatial_upsample=config.get("spatial_upsample", True),
202 | temporal_upsample=config.get("temporal_upsample", False),
203 | )
204 |
205 | def config(self):
206 | return {
207 | "_class_name": "LatentUpsampler",
208 | "in_channels": self.in_channels,
209 | "mid_channels": self.mid_channels,
210 | "num_blocks_per_stage": self.num_blocks_per_stage,
211 | "dims": self.dims,
212 | "spatial_upsample": self.spatial_upsample,
213 | "temporal_upsample": self.temporal_upsample,
214 | }
215 |
216 | def load_weights(self, weights_path: str) -> None:
217 | """
218 | Load model weights from a .safetensors file and switch to evaluation mode.
219 |
220 | Args:
221 | weights_path (str): Path to the .safetensors file containing the model weights
222 |
223 |         Note:
224 |             Weights are loaded with strict=False, so missing or unexpected keys are ignored.
225 | """
226 | import safetensors.torch
227 |
228 | sd = safetensors.torch.load_file(weights_path)
229 | self.load_state_dict(sd, strict=False, assign=True)
230 | # Switch to evaluation mode
231 | self.eval()
232 |
233 |
234 | @comfy_node(name="LTXVLatentUpsampler")
235 | class LTXVLatentUpsampler:
236 | """
237 | Upsamples a video latent by a factor of 2.
238 | """
239 |
240 | @classmethod
241 | def INPUT_TYPES(s):
242 | return {
243 | "required": {
244 | "samples": ("LATENT",),
245 | "upscale_model": ("UPSCALE_MODEL",),
246 | "vae": ("VAE",),
247 | }
248 | }
249 |
250 | RETURN_TYPES = ("LATENT",)
251 | FUNCTION = "upsample_latent"
252 | CATEGORY = "latent/video"
253 |
254 | def upsample_latent(
255 | self, samples: dict, upscale_model: LatentUpsampler, vae
256 | ) -> tuple:
257 | """
258 | Upsample the input latent using the provided model.
259 |
260 | Args:
261 | samples (dict): Input latent samples
262 | upscale_model (LatentUpsampler): Loaded upscale model
263 |
264 | Returns:
265 | tuple: Tuple containing the upsampled latent
266 | """
267 | latents = samples["samples"]
268 |
269 | # Ensure latents are on the same device as the model
270 | if latents.device != upscale_model.device:
271 | latents = latents.to(upscale_model.device)
272 | latents = vae.first_stage_model.per_channel_statistics.un_normalize(latents)
273 | upsampled_latents = upscale_model(latents)
274 | upsampled_latents = vae.first_stage_model.per_channel_statistics.normalize(
275 | upsampled_latents
276 | )
277 | upsampled_latents = upsampled_latents.to(model_management.intermediate_device())
278 | return_dict = samples.copy()
279 | return_dict["samples"] = upsampled_latents
280 | return (return_dict,)
281 |
282 |
283 | @comfy_node(name="LTXVLatentUpsamplerModelLoader")
284 | class LTXVLatentUpsamplerModelLoader:
285 | """
286 | Loads a latent upsampler model from a .safetensors file.
287 | """
288 |
289 | @classmethod
290 | def INPUT_TYPES(s):
291 | return {
292 | "required": {
293 | "upscale_model": (folder_paths.get_filename_list("upscale_models"),),
294 | "spatial_upsample": ("BOOLEAN", {"default": True}),
295 | "temporal_upsample": ("BOOLEAN", {"default": False}),
296 | }
297 | }
298 |
299 | RETURN_TYPES = ("UPSCALE_MODEL",)
300 | FUNCTION = "load_model"
301 | CATEGORY = "latent/video"
302 |
303 | def load_model(
304 | self, upscale_model: str, spatial_upsample: bool, temporal_upsample: bool
305 | ) -> tuple:
306 | """
307 | Load the upscale model from the specified file.
308 |
309 | Args:
310 | upscale_model (str): Name of the upscale model file
311 |
312 | Returns:
313 | tuple: Tuple containing the loaded model
314 | """
315 | upscale_model_path = folder_paths.get_full_path("upscale_models", upscale_model)
316 | if upscale_model_path is None:
317 | raise ValueError(f"Upscale model {upscale_model} not found")
318 |
319 | try:
320 | latent_upsampler = LatentUpsampler(
321 | num_blocks_per_stage=4,
322 | dims=3,
323 | spatial_upsample=spatial_upsample,
324 | temporal_upsample=temporal_upsample,
325 | )
326 | latent_upsampler.load_weights(upscale_model_path)
327 | except Exception as e:
328 | raise ValueError(
329 | f"Failed to initialize LatentUpsampler with this configuration: {str(e)}"
330 | )
331 |
332 | latent_upsampler.eval()
333 | # Move model to appropriate device
334 | device = model_management.get_torch_device()
335 | latent_upsampler.to(device)
336 |
337 | return (latent_upsampler,)
338 |
--------------------------------------------------------------------------------
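Editor's note: a quick shape check (not part of the repository) for the
PixelShuffle3D block defined above: channels shrink by r**3 while frames, height
and width each grow by r, which is how the combined spatial+temporal upsampler
doubles every axis of the latent.

import torch

r, b, out_c, f, h, w = 2, 1, 4, 3, 5, 5
x = torch.randn(b, out_c * r**3, f, h, w)      # (1, 32, 3, 5, 5)

x = x.view(b, out_c, r, r, r, f, h, w)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)          # (b, out_c, f, r, h, r, w, r)
x = x.reshape(b, out_c, f * r, h * r, w * r)
assert x.shape == (1, 4, 6, 10, 10)

--------------------------------------------------------------------------------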
/latents.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import comfy_extras.nodes_lt as nodes_lt
4 | import torch
5 |
6 | from .nodes_registry import comfy_node
7 |
8 |
9 | @comfy_node(name="LTXVSelectLatents")
10 | class LTXVSelectLatents:
11 | """
12 | Selects a range of frames from a video latent.
13 |
14 | Features:
15 | - Supports positive and negative indexing
16 | - Preserves batch processing capabilities
17 | - Handles noise masks if present
18 | - Maintains 5D tensor format
19 | """
20 |
21 | @classmethod
22 | def INPUT_TYPES(s):
23 | return {
24 | "required": {
25 | "samples": ("LATENT",),
26 | "start_index": (
27 | "INT",
28 | {"default": 0, "min": -9999, "max": 9999, "step": 1},
29 | ),
30 | "end_index": (
31 | "INT",
32 | {"default": -1, "min": -9999, "max": 9999, "step": 1},
33 | ),
34 | }
35 | }
36 |
37 | RETURN_TYPES = ("LATENT",)
38 | FUNCTION = "select_latents"
39 | CATEGORY = "latent/video"
40 | DESCRIPTION = (
41 | "Selects a range of frames from the video latent. "
42 | "start_index and end_index define a closed interval (inclusive of both endpoints)."
43 | )
44 |
45 | def select_latents(self, samples: dict, start_index: int, end_index: int) -> tuple:
46 | """
47 | Selects a range of frames from the video latent.
48 |
49 | Args:
50 | samples (dict): Video latent dictionary
51 | start_index (int): Starting frame index (supports negative indexing)
52 | end_index (int): Ending frame index (supports negative indexing)
53 |
54 | Returns:
55 | tuple: Contains modified latent dictionary with selected frames
56 |
57 | Raises:
58 | ValueError: If indices are invalid
59 | """
60 | try:
61 | s = samples.copy()
62 | video_latent = s["samples"]
63 | batch, channels, frames, height, width = video_latent.shape
64 |
65 | # Handle negative indices
66 | start_idx = frames + start_index if start_index < 0 else start_index
67 | end_idx = frames + end_index if end_index < 0 else end_index
68 |
69 | # Validate and clamp indices
70 | start_idx = max(0, min(start_idx, frames - 1))
71 | end_idx = max(0, min(end_idx, frames - 1))
72 | if start_idx > end_idx:
73 |                 start_idx = end_idx  # collapse to a single-frame selection
74 |
75 | # Select frames while maintaining 5D format
76 | s["samples"] = video_latent[:, :, start_idx : end_idx + 1, :, :]
77 |
78 | # Handle noise mask if present
79 | if "noise_mask" in s:
80 | s["noise_mask"] = s["noise_mask"][:, :, start_idx : end_idx + 1, :, :]
81 |
82 | return (s,)
83 |
84 | except Exception as e:
85 | print(f"[LTXVSelectLatents] Error: {str(e)}")
86 | raise
87 |
88 |
89 | @comfy_node(name="LTXVAddLatents")
90 | class LTXVAddLatents:
91 | """
92 | Concatenates two video latents along the frames dimension.
93 |
94 | Features:
95 | - Validates dimension compatibility
96 | - Handles device placement
97 | - Preserves noise masks with proper handling
98 | - Supports batch processing
99 | """
100 |
101 | @classmethod
102 | def INPUT_TYPES(s):
103 | return {
104 | "required": {
105 | "latents1": ("LATENT",),
106 | "latents2": ("LATENT",),
107 | }
108 | }
109 |
110 | RETURN_TYPES = ("LATENT",)
111 | FUNCTION = "add_latents"
112 | CATEGORY = "latent/video"
113 | DESCRIPTION = (
114 | "Concatenates two video latents along the frames dimension. "
115 | "latents1 and latents2 must have the same dimensions except for the frames dimension."
116 | )
117 |
118 | def add_latents(
119 |         self, latents1: dict, latents2: dict
120 |     ) -> tuple:
121 | """
122 | Concatenates two video latents along the frames dimension.
123 |
124 | Args:
125 | latents1 (dict): First video latent dictionary
126 | latents2 (dict): Second video latent dictionary
127 |
128 | Returns:
129 | tuple: Contains concatenated latent dictionary
130 |
131 | Raises:
132 | ValueError: If latent dimensions don't match
133 | RuntimeError: If tensor operations fail
134 | """
135 | try:
136 | s = latents1.copy()
137 | video_latent1 = latents1["samples"]
138 | video_latent2 = latents2["samples"]
139 |
140 | # Ensure tensors are on the same device
141 | target_device = video_latent1.device
142 | video_latent2 = video_latent2.to(target_device)
143 |
144 | # Validate dimensions
145 | self._validate_dimensions(video_latent1, video_latent2)
146 |
147 | # Concatenate along frames dimension
148 | s["samples"] = torch.cat([video_latent1, video_latent2], dim=2)
149 |
150 | # Handle noise masks
151 | s["noise_mask"] = self._merge_noise_masks(
152 | latents1, latents2, video_latent1.shape[2], video_latent2.shape[2]
153 | )
154 |
155 | return (s,)
156 |
157 | except Exception as e:
158 | print(f"[LTXVAddLatents] Error: {str(e)}")
159 | raise
160 |
161 | def _validate_dimensions(self, latent1: torch.Tensor, latent2: torch.Tensor):
162 | """Validates that latent dimensions match except for frames."""
163 | b1, c1, f1, h1, w1 = latent1.shape
164 | b2, c2, f2, h2, w2 = latent2.shape
165 |
166 | if not (b1 == b2 and c1 == c2 and h1 == h2 and w1 == w2):
167 | raise ValueError(
168 | f"Latent dimensions must match (except frames dimension).\n"
169 | f"Got shapes {latent1.shape} and {latent2.shape}"
170 | )
171 |
172 | def _merge_noise_masks(
173 |         self, latents1: dict, latents2: dict, frames1: int, frames2: int
174 | ) -> Optional[torch.Tensor]:
175 | """Merges noise masks from both latents with proper handling."""
176 | if "noise_mask" in latents1 and "noise_mask" in latents2:
177 | return torch.cat([latents1["noise_mask"], latents2["noise_mask"]], dim=2)
178 | elif "noise_mask" in latents1:
179 | zeros = torch.zeros_like(latents1["noise_mask"][:, :, :frames2, :, :])
180 | return torch.cat([latents1["noise_mask"], zeros], dim=2)
181 | elif "noise_mask" in latents2:
182 | zeros = torch.zeros_like(latents2["noise_mask"][:, :, :frames1, :, :])
183 | return torch.cat([zeros, latents2["noise_mask"]], dim=2)
184 | return None
185 |
186 |
187 | @comfy_node(name="LTXVSetVideoLatentNoiseMasks")
188 | class LTXVSetVideoLatentNoiseMasks:
189 | """
190 | Applies multiple masks to a video latent.
191 |
192 | Features:
193 | - Supports multiple input mask formats (2D, 3D, 4D)
194 | - Automatically handles fewer masks than frames by reusing the last mask
195 | - Resizes masks to match latent dimensions
196 | - Preserves batch processing capabilities
197 |
198 | Input Formats:
199 | - 2D mask: Single mask [H, W]
200 | - 3D mask: Multiple masks [M, H, W]
201 | - 4D mask: Multiple masks with channels [M, C, H, W]
202 | """
203 |
204 | @classmethod
205 | def INPUT_TYPES(s):
206 | return {
207 | "required": {
208 | "samples": ("LATENT",),
209 | "masks": ("MASK",),
210 | }
211 | }
212 |
213 | RETURN_TYPES = ("LATENT",)
214 | FUNCTION = "set_mask"
215 | CATEGORY = "latent/video"
216 | DESCRIPTION = (
217 | "Applies multiple masks to a video latent. "
218 | "masks can be 2D, 3D, or 4D tensors. "
219 | "If there are fewer masks than frames, the last mask will be reused."
220 | )
221 |
222 | def set_mask(self, samples: dict, masks: torch.Tensor) -> tuple:
223 | """
224 | Applies masks to video latent frames.
225 |
226 | Args:
227 | samples (dict): Video latent dictionary containing 'samples' tensor
228 | masks (torch.Tensor): Mask tensor in various possible formats
229 | - 2D: [H, W] single mask
230 | - 3D: [M, H, W] multiple masks
231 | - 4D: [M, C, H, W] multiple masks with channels
232 |
233 | Returns:
234 | tuple: Contains modified latent dictionary with applied masks
235 |
236 | Raises:
237 | ValueError: If mask dimensions are unsupported
238 | RuntimeError: If tensor operations fail
239 | """
240 | try:
241 | s = samples.copy()
242 | video_latent = s["samples"]
243 | batch_size, channels, num_frames, height, width = video_latent.shape
244 |
245 | # Initialize noise_mask if not present
246 | if "noise_mask" not in s:
247 | s["noise_mask"] = torch.zeros(
248 | (batch_size, 1, num_frames, height, width),
249 | dtype=video_latent.dtype,
250 | device=video_latent.device,
251 | )
252 |
253 | # Process masks
254 | masks_reshaped = self._reshape_masks(masks)
255 | M = masks_reshaped.shape[0]
256 | resized_masks = self._resize_masks(masks_reshaped, height, width)
257 |
258 | # Apply masks efficiently
259 | self._apply_masks(s["noise_mask"], resized_masks, num_frames, M)
260 | return (s,)
261 |
262 | except Exception as e:
263 | print(f"[LTXVSetVideoLatentNoiseMasks] Error: {str(e)}")
264 | raise
265 |
266 | def _reshape_masks(self, masks: torch.Tensor) -> torch.Tensor:
267 | """Reshapes input masks to consistent 4D format."""
268 | original_shape = tuple(masks.shape)
269 | ndims = masks.ndim
270 |
271 | if ndims == 2:
272 | return masks.unsqueeze(0).unsqueeze(0)
273 | elif ndims == 3:
274 | return masks.reshape(masks.shape[0], 1, masks.shape[1], masks.shape[2])
275 | elif ndims == 4:
276 | return masks.reshape(masks.shape[0], 1, masks.shape[2], masks.shape[3])
277 | else:
278 | raise ValueError(
279 | f"Unsupported 'masks' dimension: {original_shape}. "
280 | "Must be 2D (H,W), 3D (M,H,W), or 4D (M,C,H,W)."
281 | )
282 |
283 | def _resize_masks(
284 | self, masks: torch.Tensor, height: int, width: int
285 | ) -> torch.Tensor:
286 | """Resizes all masks to match latent dimensions."""
287 | return torch.nn.functional.interpolate(
288 | masks, size=(height, width), mode="bilinear", align_corners=False
289 | )
290 |
291 | def _apply_masks(
292 | self,
293 | noise_mask: torch.Tensor,
294 | resized_masks: torch.Tensor,
295 | num_frames: int,
296 | M: int,
297 | ) -> None:
298 | """Applies resized masks to all frames."""
299 | for f in range(num_frames):
300 | mask_idx = min(f, M - 1) # Reuse last mask if we run out
301 | noise_mask[:, :, f] = resized_masks[mask_idx]
302 |
303 |
304 | @comfy_node(name="LTXVAddLatentGuide")
305 | class LTXVAddLatentGuide(nodes_lt.LTXVAddGuide):
306 | @classmethod
307 | def INPUT_TYPES(s):
308 | return {
309 | "required": {
310 | "vae": ("VAE",),
311 | "positive": ("CONDITIONING",),
312 | "negative": ("CONDITIONING",),
313 | "latent": ("LATENT",),
314 | "guiding_latent": ("LATENT",),
315 | "latent_idx": (
316 | "INT",
317 | {
318 | "default": 0,
319 | "min": -9999,
320 | "max": 9999,
321 | "step": 1,
322 |                         "tooltip": "Latent index to start the conditioning at. Can be negative to "
323 | "indicate that the conditioning is on the frames before the latent.",
324 | },
325 | ),
326 | "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}),
327 | }
328 | }
329 |
330 | RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
331 | RETURN_NAMES = ("positive", "negative", "latent")
332 |
333 | CATEGORY = "ltxtricks"
334 | FUNCTION = "generate"
335 |
336 | DESCRIPTION = "Adds a keyframe or a video segment at a specific frame index."
337 |
338 | def generate(
339 | self, vae, positive, negative, latent, guiding_latent, latent_idx, strength
340 | ):
341 | noise_mask = nodes_lt.get_noise_mask(latent)
342 | latent = latent["samples"]
343 | guiding_latent = guiding_latent["samples"]
344 | scale_factors = vae.downscale_index_formula
345 |
346 | if latent_idx <= 0:
347 | frame_idx = latent_idx * scale_factors[0]
348 | else:
349 | frame_idx = 1 + (latent_idx - 1) * scale_factors[0]
350 |
351 | positive, negative, latent, noise_mask = self.append_keyframe(
352 | positive=positive,
353 | negative=negative,
354 | frame_idx=frame_idx,
355 | latent_image=latent,
356 | noise_mask=noise_mask,
357 | guiding_latent=guiding_latent,
358 | strength=strength,
359 | scale_factors=scale_factors,
360 | )
361 |
362 | return (
363 | positive,
364 | negative,
365 | {"samples": latent, "noise_mask": noise_mask},
366 | )
367 |
--------------------------------------------------------------------------------
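Editor's note: a small sketch (not part of the repository) of the latent-index to
pixel-frame-index mapping used by LTXVAddLatentGuide above, assuming a temporal
downscale factor of 8 (the value implied by the 8-frame steps and the 9-frame
minimum length used elsewhere in this repository).

def latent_idx_to_frame_idx(latent_idx: int, temporal_scale: int = 8) -> int:
    if latent_idx <= 0:
        return latent_idx * temporal_scale
    # Positive indices are offset by one, so latent 1 maps to pixel frame 1
    # rather than frame 8.
    return 1 + (latent_idx - 1) * temporal_scale

assert latent_idx_to_frame_idx(0) == 0
assert latent_idx_to_frame_idx(1) == 1
assert latent_idx_to_frame_idx(2) == 9
assert latent_idx_to_frame_idx(-1) == -8

--------------------------------------------------------------------------------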
/nodes_registry.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Callable, Optional, Type
3 |
4 | NODE_CLASS_MAPPINGS = {}
5 | NODE_DISPLAY_NAME_MAPPINGS = {}
6 |
7 | NODES_DISPLAY_NAME_PREFIX = "🅛🅣🅧"
8 | EXPERIMENTAL_DISPLAY_NAME_PREFIX = "(Experimental 🧪)"
9 | DEPRECATED_DISPLAY_NAME_PREFIX = "(Deprecated 🚫)"
10 | DEFAULT_CATEGORY_NAME = "Lightricks"
11 |
12 |
13 | def register_node(node_class: Type, name: str, description: str) -> None:
14 | """
15 | Register a ComfyUI node class to ComfyUI's global nodes' registry.
16 |
17 | Args:
18 | node_class (Type): The class of the node to be registered.
19 | name (str): The name of the node.
20 | description (str): The short user-friendly description of the node.
21 |
22 | Raises:
23 |         ValueError: If `node_class` is not a class, or `name` or `description` is not a string.
24 | """
25 |
26 | if not isinstance(node_class, type):
27 | raise ValueError("`node_class` must be a class")
28 |
29 | if not isinstance(name, str):
30 | raise ValueError("`name` must be a string")
31 |
32 | if not isinstance(description, str):
33 | raise ValueError("`description` must be a string")
34 |
35 | NODE_CLASS_MAPPINGS[name] = node_class
36 | NODE_DISPLAY_NAME_MAPPINGS[name] = description
37 |
38 |
39 | def comfy_node(
40 | node_class: Optional[Type] = None,
41 | name: Optional[str] = None,
42 | description: Optional[str] = None,
43 | experimental: bool = False,
44 | deprecated: bool = False,
45 | skip: bool = False,
46 | ) -> Callable:
47 | """
48 | Decorator for registering a node class with optional name, description, and status flags.
49 |
50 | Args:
51 | node_class (Type): The class of the node to be registered.
52 | name (str, optional): The name of the class. If not provided, the class name will be used.
53 | description (str, optional): The description of the class.
54 | If not provided, an auto-formatted description will be used based on the class name.
55 | experimental (bool): Flag indicating if the class is experimental. Defaults to False.
56 | deprecated (bool): Flag indicating if the class is deprecated. Defaults to False.
57 | skip (bool): Flag indicating if the node registration should be skipped. Defaults to False.
58 | This is useful for conditionally registering nodes based on certain conditions
59 | (e.g. unavailability of certain dependencies).
60 |
61 | Returns:
62 | Callable: The decorator function.
63 |
64 | Raises:
65 | ValueError: If `node_class` is not a class.
66 | """
67 |
68 | def decorator(node_class: Type) -> Type:
69 | if skip:
70 | return node_class
71 |
72 | if not isinstance(node_class, type):
73 | raise ValueError("`node_class` must be a class")
74 |
75 | nonlocal name, description
76 | if name is None:
77 | name = node_class.__name__
78 |
79 | # Remove possible "Node" suffix from the class name, e.g. "EditImageNode -> EditImage"
80 | if name is not None and name.endswith("Node"):
81 | name = name[:-4]
82 |
83 | description = _format_description(description, name, experimental, deprecated)
84 |
85 | register_node(node_class, name, description)
86 | return node_class
87 |
88 | # If the decorator is used without parentheses
89 | if node_class is None:
90 | return decorator
91 | else:
92 | return decorator(node_class)
93 |
94 |
95 | def _format_description(
96 | description: str, class_name: str, experimental: bool, deprecated: bool
97 | ) -> str:
98 |     """Format a node's display name to a standard format."""
99 |
100 | # If description is not provided, auto-generate one based on the class name
101 | if description is None:
102 | description = camel_case_to_spaces(class_name)
103 |
104 | # Strip the prefix if it's already there
105 | prefix_len = len(NODES_DISPLAY_NAME_PREFIX)
106 | if description.startswith(NODES_DISPLAY_NAME_PREFIX):
107 | description = description[prefix_len:].lstrip()
108 |
109 | # Add the deprecated / experimental prefixes
110 | if deprecated:
111 | description = f"{DEPRECATED_DISPLAY_NAME_PREFIX} {description}"
112 | elif experimental:
113 | description = f"{EXPERIMENTAL_DISPLAY_NAME_PREFIX} {description}"
114 |
115 | # Add the prefix
116 | description = f"{NODES_DISPLAY_NAME_PREFIX} {description}"
117 |
118 | return description
119 |
120 |
121 | def camel_case_to_spaces(text: str) -> str:
122 | # Add space before each capital letter except the first one
123 | spaced_text = re.sub(r"(?<=[a-z])(?=[A-Z])", " ", text)
124 | # Handle sequences of uppercase letters followed by a lowercase letter
125 | spaced_text = re.sub(r"(?<=[A-Z])(?=[A-Z][a-z])", " ", spaced_text)
126 | # Handle sequences of uppercase letters not followed by a lowercase letter
127 | spaced_text = re.sub(r"(?<=[A-Z])(?=[A-Z][A-Z][a-z])", " ", spaced_text)
128 | return spaced_text
129 |
--------------------------------------------------------------------------------
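Editor's note: a self-contained check (not part of the repository) of the
display-name auto-formatting in nodes_registry.py above; the regexes are copied
verbatim from camel_case_to_spaces.

import re

def camel_case_to_spaces(text: str) -> str:
    spaced_text = re.sub(r"(?<=[a-z])(?=[A-Z])", " ", text)
    spaced_text = re.sub(r"(?<=[A-Z])(?=[A-Z][a-z])", " ", spaced_text)
    spaced_text = re.sub(r"(?<=[A-Z])(?=[A-Z][A-Z][a-z])", " ", spaced_text)
    return spaced_text

assert camel_case_to_spaces("LTXVFilmGrain") == "LTXV Film Grain"
assert camel_case_to_spaces("EditImage") == "Edit Image"
# The registry then prepends the display-name prefix (and, when requested, the
# experimental/deprecated markers) to build the final display name.

--------------------------------------------------------------------------------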
/presets/stg_advanced_presets.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "name": "Custom"
4 | },
5 | {
6 | "name": "13b Dynamic",
7 | "skip_steps_sigma_threshold": 0.997,
8 | "cfg_star_rescale": true,
9 | "sigmas": [1.0, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180],
10 | "cfg_values": [1, 6, 8, 6, 1, 1],
11 | "stg_scale_values": [0, 4, 4, 4, 2, 1],
12 | "stg_rescale_values": [1, 0.5, 0.5, 1, 1, 1],
13 | "stg_layers_indices": [[11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
14 | },
15 | {
16 | "name": "13b Balanced",
17 | "skip_steps_sigma_threshold": 0.998,
18 | "cfg_star_rescale": true,
19 | "sigmas": [1.0, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180],
20 | "cfg_values": [1, 6, 8, 6, 1, 1],
21 | "stg_scale_values": [0, 4, 4, 4, 2, 1],
22 | "stg_rescale_values": [1, 0.5, 0.5, 1, 1, 1],
23 | "stg_layers_indices": [[12], [12], [5], [5], [28], [29]]
24 | },
25 | {
26 | "name": "13b Upscale",
27 | "skip_steps_sigma_threshold": 0.997,
28 | "cfg_star_rescale": true,
29 | "sigmas": [1.0, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180],
30 | "cfg_values": [1, 1, 1, 1, 1, 1],
31 | "stg_scale_values": [1, 1, 1, 1, 1, 1],
32 | "stg_rescale_values": [1, 1, 1, 1, 1, 1],
33 | "stg_layers_indices": [[42], [42], [42], [42], [42], [42]]
34 | },
35 | {
36 | "name": "2b",
37 | "skip_steps_sigma_threshold": 0.997,
38 | "cfg_star_rescale": true,
39 | "sigmas": [1.0, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180],
40 | "cfg_values": [4, 4, 4, 4, 1, 1],
41 | "stg_scale_values": [2, 2, 2, 2, 1, 0],
42 | "stg_rescale_values": [1, 1, 1, 1, 1, 1],
43 | "stg_layers_indices": [[14], [14], [14], [14], [14], [14]]
44 | }
45 | ]
46 |
--------------------------------------------------------------------------------
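Editor's note: a small sketch (not part of the repository) showing how the preset
file above is shaped; the path is relative to the repository root, and exactly how
the STG guider consumes the parallel arrays is not shown in this file.

import json

with open("presets/stg_advanced_presets.json") as f:
    presets = {p["name"]: p for p in json.load(f)}

preset = presets["13b Dynamic"]
for sigma, cfg, stg in zip(
    preset["sigmas"], preset["cfg_values"], preset["stg_scale_values"]
):
    print(f"sigma {sigma}: cfg={cfg}, stg_scale={stg}")

--------------------------------------------------------------------------------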
/prompt_enhancer_nodes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import comfy.model_management
5 | import comfy.model_patcher
6 | import folder_paths
7 | import torch
8 | from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
9 |
10 | from .nodes_registry import comfy_node
11 | from .prompt_enhancer_utils import generate_cinematic_prompt
12 |
13 | LLM_NAME = ["unsloth/Llama-3.2-3B-Instruct"]
14 |
15 | IMAGE_CAPTIONER = ["MiaoshouAI/Florence-2-large-PromptGen-v2.0"]
16 |
17 | MODELS_PATH_KEY = "LLM"
18 |
19 |
20 | class PromptEnhancer(torch.nn.Module):
21 | def __init__(
22 | self,
23 | image_caption_processor: AutoProcessor,
24 | image_caption_model: AutoModelForCausalLM,
25 | llm_model: AutoModelForCausalLM,
26 | llm_tokenizer: AutoTokenizer,
27 | ):
28 | super().__init__()
29 | self.image_caption_processor = image_caption_processor
30 | self.image_caption_model = image_caption_model
31 | self.llm_model = llm_model
32 | self.llm_tokenizer = llm_tokenizer
33 | self.device = image_caption_model.device
34 |         # Total size of model parameters and buffers, plus an extra 1 GB margin.
35 | self.model_size = (
36 | self.get_model_size(self.image_caption_model)
37 | + self.get_model_size(self.llm_model)
38 | + 1073741824
39 | )
40 |
41 | def forward(self, prompt, image_conditioning, max_resulting_tokens):
42 | enhanced_prompt = generate_cinematic_prompt(
43 | self.image_caption_model,
44 | self.image_caption_processor,
45 | self.llm_model,
46 | self.llm_tokenizer,
47 | prompt,
48 | image_conditioning,
49 | max_new_tokens=max_resulting_tokens,
50 | )
51 |
52 | return enhanced_prompt
53 |
54 | @staticmethod
55 | def get_model_size(model):
56 | total_size = sum(p.numel() * p.element_size() for p in model.parameters())
57 | total_size += sum(b.numel() * b.element_size() for b in model.buffers())
58 | return total_size
59 |
60 | def memory_required(self, input_shape):
61 | return self.model_size
62 |
63 |
64 | @comfy_node(name="LTXVPromptEnhancerLoader")
65 | class LTXVPromptEnhancerLoader:
66 | @classmethod
67 | def INPUT_TYPES(s):
68 | return {
69 | "required": {
70 | "llm_name": (
71 | "STRING",
72 | {
73 | "default": LLM_NAME,
74 |                         "tooltip": "The Hugging Face name of the LLM to load.",
75 | },
76 | ),
77 | "image_captioner_name": (
78 | "STRING",
79 | {
80 | "default": IMAGE_CAPTIONER,
81 |                         "tooltip": "The Hugging Face name of the image captioning model to load.",
82 | },
83 | ),
84 | }
85 | }
86 |
87 | RETURN_TYPES = ("LTXV_PROMPT_ENHANCER",)
88 | RETURN_NAMES = ("prompt_enhancer",)
89 | FUNCTION = "load"
90 | CATEGORY = "lightricks/LTXV"
91 | TITLE = "LTXV Prompt Enhancer (Down)Loader"
92 | OUTPUT_NODE = False
93 |     DESCRIPTION = "Downloads and initializes LLM and image captioning models from Hugging Face to enhance text prompts for video generation."
94 |
95 | def model_path_download_if_needed(self, model_name):
96 | model_directory = os.path.join(folder_paths.models_dir, MODELS_PATH_KEY)
97 | os.makedirs(model_directory, exist_ok=True)
98 |
99 | model_name_ = model_name.rsplit("/", 1)[-1]
100 | model_path = os.path.join(model_directory, model_name_)
101 |
102 | if not os.path.exists(model_path):
103 | from huggingface_hub import snapshot_download
104 |
105 | try:
106 | snapshot_download(
107 | repo_id=model_name,
108 | local_dir=model_path,
109 | local_dir_use_symlinks=False,
110 | )
111 | except Exception:
112 | shutil.rmtree(model_path, ignore_errors=True)
113 | raise
114 | return model_path
115 |
116 | def down_load_llm_model(self, llm_name, load_device):
117 | model_path = self.model_path_download_if_needed(llm_name)
118 | llm_model = AutoModelForCausalLM.from_pretrained(
119 | model_path,
120 | torch_dtype=torch.bfloat16,
121 | )
122 |
123 | llm_tokenizer = AutoTokenizer.from_pretrained(
124 | model_path,
125 | )
126 |
127 | return llm_model, llm_tokenizer
128 |
129 | def down_load_image_captioner(self, image_captioner, load_device):
130 | model_path = self.model_path_download_if_needed(image_captioner)
131 | image_caption_model = AutoModelForCausalLM.from_pretrained(
132 | model_path, trust_remote_code=True
133 | )
134 |
135 | image_caption_processor = AutoProcessor.from_pretrained(
136 | model_path, trust_remote_code=True
137 | )
138 |
139 | return image_caption_model, image_caption_processor
140 |
141 | def load(self, llm_name, image_captioner_name):
142 | load_device = comfy.model_management.get_torch_device()
143 | offload_device = comfy.model_management.vae_offload_device()
144 | llm_model, llm_tokenizer = self.down_load_llm_model(llm_name, load_device)
145 | image_caption_model, image_caption_processor = self.down_load_image_captioner(
146 | image_captioner_name, load_device
147 | )
148 |
149 | enhancer = PromptEnhancer(
150 | image_caption_processor, image_caption_model, llm_model, llm_tokenizer
151 | )
152 | patcher = comfy.model_patcher.ModelPatcher(
153 | enhancer,
154 | load_device,
155 | offload_device,
156 | )
157 | return (patcher,)
158 |
159 |
160 | @comfy_node(name="LTXVPromptEnhancer")
161 | class LTXVPromptEnhancer:
162 | @classmethod
163 | def INPUT_TYPES(s):
164 | return {
165 | "required": {
166 | "prompt": ("STRING",),
167 | "prompt_enhancer": ("LTXV_PROMPT_ENHANCER",),
168 | "max_resulting_tokens": (
169 | "INT",
170 | {"default": 256, "min": 32, "max": 512},
171 | ),
172 | },
173 | "optional": {
174 | "image_prompt": ("IMAGE",),
175 | },
176 | }
177 |
178 | RETURN_TYPES = ("STRING",)
179 | RETURN_NAMES = ("str",)
180 | FUNCTION = "enhance"
181 | CATEGORY = "lightricks/LTXV"
182 | TITLE = "LTXV Prompt Enhancer"
183 | OUTPUT_NODE = False
184 | DESCRIPTION = (
185 |         "Enhances text prompts for video generation using LLMs. "
186 | "Optionally incorporates reference images to create more contextually relevant descriptions."
187 | )
188 |
189 | def enhance(
190 | self,
191 | prompt,
192 | prompt_enhancer: comfy.model_patcher.ModelPatcher,
193 | image_prompt: torch.Tensor = None,
194 | max_resulting_tokens=256,
195 | ):
196 | comfy.model_management.free_memory(
197 | prompt_enhancer.memory_required([]),
198 | comfy.model_management.get_torch_device(),
199 | )
200 | comfy.model_management.load_model_gpu(prompt_enhancer)
201 | model = prompt_enhancer.model
202 | image_conditioning = None
203 | if image_prompt is not None:
204 | permuted_image = image_prompt.permute(3, 0, 1, 2)[None, :]
205 | image_conditioning = [(permuted_image, 0, 1.0)]
206 |
207 | enhanced_prompt = model(prompt, image_conditioning, max_resulting_tokens)
208 | return (enhanced_prompt[0],)
209 |
--------------------------------------------------------------------------------
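Editor's note: a self-contained sketch (not part of the repository) of the memory
estimate used by PromptEnhancer.get_model_size above, applied to a small dummy
module instead of the captioner/LLM pair.

import torch

def get_model_size(model: torch.nn.Module) -> int:
    # Sum of parameter and buffer sizes in bytes, as in PromptEnhancer.
    total = sum(p.numel() * p.element_size() for p in model.parameters())
    total += sum(b.numel() * b.element_size() for b in model.buffers())
    return total

dummy = torch.nn.Sequential(torch.nn.Linear(128, 256), torch.nn.BatchNorm1d(256))
print(get_model_size(dummy), "bytes")
# PromptEnhancer adds a further 1 GiB (1073741824 bytes) safety margin on top of
# this sum before asking ComfyUI to free memory.

--------------------------------------------------------------------------------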
/prompt_enhancer_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import random
3 | from typing import List, Optional, Tuple, Union
4 |
5 | import torch
6 | from PIL import Image
7 |
8 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
9 |
10 | T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award-winning movies. When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
11 | Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
12 | Start directly with the action, and keep descriptions literal and precise.
13 | Think like a cinematographer describing a shot list.
14 | Do not change the user input intent, just enhance it.
15 | Keep within 150 words.
16 | For best results, build your prompts using this structure:
17 | Start with main action in a single sentence
18 | Add specific details about movements and gestures
19 | Describe character/object appearances precisely
20 | Include background and environment details
21 | Specify camera angles and movements
22 | Describe lighting and colors
23 | Note any changes or sudden events
24 | Do not exceed the 150 word limit!
25 | Output the enhanced prompt only.
26 |
27 | Examples:
28 | user prompt: A man drives a toyota car.
29 | enhanced prompt: A person is driving a car on a two-lane road, holding the steering wheel with both hands. The person's hands are light-skinned and they are wearing a black long-sleeved shirt. The steering wheel has a Toyota logo in the center and black leather around it. The car's dashboard is visible, showing a speedometer, tachometer, and navigation screen. The road ahead is straight and there are trees and fields visible on either side. The camera is positioned inside the car, providing a view from the driver's perspective. The lighting is natural and overcast, with a slightly cool tone.
30 |
31 | user prompt: A young woman is sitting on a chair.
32 | enhanced prompt: A young woman with dark, curly hair and pale skin sits on a chair; she wears a dark, intricately patterned dress with a high collar and long, dark gloves that extend past her elbows; the scene is dimly lit, with light streaming in from a large window behind the characters.
33 |
34 | user prompt: Aerial view of a city skyline.
35 | enhanced prompt: The camera pans across a cityscape of tall buildings with a circular building in the center. The camera moves from left to right, showing the tops of the buildings and the circular building in the center. The buildings are various shades of gray and white, and the circular building has a green roof. The camera angle is high, looking down at the city. The lighting is bright, with the sun shining from the upper left, casting shadows from the buildings.
36 | """
37 |
38 | I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award-winning movies. When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
39 | Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
40 | Start directly with the action, and keep descriptions literal and precise.
41 | Think like a cinematographer describing a shot list.
42 | Keep within 150 words.
43 | For best results, build your prompts using this structure:
44 | Describe the image first and then add the user input. Image description should be in first priority! Align to the image caption if it contradicts the user text input.
45 | Start with main action in a single sentence
46 | Add specific details about movements and gestures
47 | Describe character/object appearances precisely
48 | Include background and environment details
49 | Specify camera angles and movements
50 | Describe lighting and colors
51 | Note any changes or sudden events
52 | Align to the image caption if it contradicts the user text input.
53 | Do not exceed the 150 word limit!
54 | Output the enhanced prompt only.
55 | """
56 |
57 |
58 | def tensor_to_pil(tensor):
59 | # Ensure tensor is in range [-1, 1]
60 | assert tensor.min() >= -1 and tensor.max() <= 1
61 |
62 | # Convert from [-1, 1] to [0, 1]
63 | tensor = (tensor + 1) / 2
64 |
65 | # Rearrange from [C, H, W] to [H, W, C]
66 | tensor = tensor.permute(1, 2, 0)
67 |
68 | # Convert to numpy array and then to uint8 range [0, 255]
69 | numpy_image = (tensor.cpu().numpy() * 255).astype("uint8")
70 |
71 | # Convert to PIL Image
72 | return Image.fromarray(numpy_image)
73 |
74 |
75 | def generate_cinematic_prompt(
76 | image_caption_model,
77 | image_caption_processor,
78 | prompt_enhancer_model,
79 | prompt_enhancer_tokenizer,
80 | prompt: Union[str, List[str]],
81 | conditioning_items: Optional[List[Tuple[torch.Tensor, int, float]]] = None,
82 | max_new_tokens: int = 256,
83 | ) -> List[str]:
84 | prompts = [prompt] if isinstance(prompt, str) else prompt
85 |
86 | if conditioning_items is None:
87 | prompts = _generate_t2v_prompt(
88 | prompt_enhancer_model,
89 | prompt_enhancer_tokenizer,
90 | prompts,
91 | max_new_tokens,
92 | T2V_CINEMATIC_PROMPT,
93 | )
94 | else:
95 | # if len(conditioning_items) > 1 or conditioning_items[0][1] != 0:
96 | # logger.warning(
97 | # "prompt enhancement does only support first frame of conditioning items, returning original prompts"
98 | # )
99 | # return prompts
100 |
101 | first_frame_conditioning_item = conditioning_items[0]
102 | first_frames = _get_first_frames_from_conditioning_item(
103 | first_frame_conditioning_item
104 | )
105 |
106 | assert len(first_frames) == len(
107 | prompts
108 | ), "Number of conditioning frames must match number of prompts"
109 |
110 | prompts = _generate_i2v_prompt(
111 | image_caption_model,
112 | image_caption_processor,
113 | prompt_enhancer_model,
114 | prompt_enhancer_tokenizer,
115 | prompts,
116 | first_frames,
117 | max_new_tokens,
118 | I2V_CINEMATIC_PROMPT,
119 | )
120 |
121 | return prompts
122 |
123 |
124 | def _get_first_frames_from_conditioning_item(
125 | conditioning_item: Tuple[torch.Tensor, int, float]
126 | ) -> List[Image.Image]:
127 | frames_tensor = conditioning_item[0]
128 | # tensor shape: [batch_size, 3, num_frames, height, width], take first frame from each sample
129 | return [
130 | tensor_to_pil(frames_tensor[i, :, 0, :, :])
131 | for i in range(frames_tensor.shape[0])
132 | ]
133 |
134 |
135 | def _generate_t2v_prompt(
136 | prompt_enhancer_model,
137 | prompt_enhancer_tokenizer,
138 | prompts: List[str],
139 | max_new_tokens: int,
140 | system_prompt: str,
141 | ) -> List[str]:
142 | messages = [
143 | [
144 | {"role": "system", "content": system_prompt},
145 | {"role": "user", "content": f"user_prompt: {p}"},
146 | ]
147 | for p in prompts
148 | ]
149 |
150 | texts = [
151 | prompt_enhancer_tokenizer.apply_chat_template(
152 | m, tokenize=False, add_generation_prompt=True
153 | )
154 | for m in messages
155 | ]
156 | model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to(
157 | prompt_enhancer_model.device
158 | )
159 |
160 | return _generate_and_decode_prompts(
161 | prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens
162 | )
163 |
164 |
165 | def _generate_i2v_prompt(
166 | image_caption_model,
167 | image_caption_processor,
168 | prompt_enhancer_model,
169 | prompt_enhancer_tokenizer,
170 | prompts: List[str],
171 | first_frames: List[Image.Image],
172 | max_new_tokens: int,
173 | system_prompt: str,
174 | ) -> List[str]:
175 | image_captions = _generate_image_captions(
176 | image_caption_model, image_caption_processor, first_frames
177 | )
178 |
179 | messages = [
180 | [
181 | {"role": "system", "content": system_prompt},
182 | {"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"},
183 | ]
184 | for p, c in zip(prompts, image_captions)
185 | ]
186 |
187 | texts = [
188 | prompt_enhancer_tokenizer.apply_chat_template(
189 | m, tokenize=False, add_generation_prompt=True
190 | )
191 | for m in messages
192 | ]
193 | model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to(
194 | prompt_enhancer_model.device
195 | )
196 |
197 | return _generate_and_decode_prompts(
198 | prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens
199 | )
200 |
201 |
202 | def _generate_image_captions(
203 | image_caption_model,
204 | image_caption_processor,
205 | images: List[Image.Image],
206 | system_prompt: str = "",
207 | ) -> List[str]:
208 | image_caption_prompts = [system_prompt] * len(images)
209 | inputs = image_caption_processor(
210 | image_caption_prompts, images, return_tensors="pt"
211 | ).to(image_caption_model.device)
212 |
213 | with torch.inference_mode():
214 | generated_ids = image_caption_model.generate(
215 | input_ids=inputs["input_ids"],
216 | pixel_values=inputs["pixel_values"],
217 | max_new_tokens=1024,
218 | do_sample=False,
219 | num_beams=3,
220 | )
221 |
222 | return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True)
223 |
224 |
225 | def _get_random_scene_type():
226 | """
227 | Randomly select a scene type to add to the prompt.
228 | """
229 | types = [
230 | "The scene is captured in real-life footage.",
231 | "The scene is computer-generated imagery.",
232 | "The scene appears to be from a movie.",
233 | "The scene appears to be from a TV show.",
234 | "The scene is captured in a studio.",
235 | ]
236 | return random.choice(types)
237 |
238 |
239 | def _generate_and_decode_prompts(
240 | prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int
241 | ) -> List[str]:
242 | with torch.inference_mode():
243 | outputs = prompt_enhancer_model.generate(
244 | **model_inputs, max_new_tokens=max_new_tokens
245 | )
246 | generated_ids = [
247 | output_ids[len(input_ids) :]
248 | for input_ids, output_ids in zip(model_inputs.input_ids, outputs)
249 | ]
250 | decoded_prompts = prompt_enhancer_tokenizer.batch_decode(
251 | generated_ids, skip_special_tokens=True
252 | )
253 |
254 |     decoded_prompts = [p + f" {_get_random_scene_type()}" for p in decoded_prompts]
255 |
256 | print(decoded_prompts)
257 |
258 | return decoded_prompts
259 |
--------------------------------------------------------------------------------
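Note: the enhancer above follows the standard Hugging Face chat-template flow: build per-prompt message lists, render them with apply_chat_template, generate, strip the echoed prompt tokens, then decode. A minimal standalone sketch of that flow follows; the model id is a placeholder, and a single prompt is used so no padding handling is needed.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "your/instruct-model"  # hypothetical; any chat-tuned causal LM works here

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)

messages = [
    {"role": "system", "content": "Rewrite the user prompt as one cinematic paragraph."},
    {"role": "user", "content": "user_prompt: A man drives a toyota car."},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

with torch.inference_mode():
    output_ids = model.generate(**inputs, max_new_tokens=256)

# Drop the echoed prompt tokens before decoding, as _generate_and_decode_prompts does.
new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))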
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "comfyui-ltxvideo"
3 | version = "0.1.0"
4 | description = "Custom nodes for LTX-Video support in ComfyUI"
5 | authors = [
6 | { name = "Andrew Kvochko", email = "akvochko@lightricks.com" }
7 | ]
8 | requires-python = ">=3.10"
9 | readme = "README.md"
10 | license = { file = "LICENSE" }
11 | dependencies = [
12 | "diffusers",
13 | "huggingface_hub>=0.25.2",
14 | "transformers[timm]>=4.45.0",
15 | "einops"
16 | ]
17 |
18 | [project.optional-dependencies]
19 | internal = [
20 | "ltx-video@git+https://github.com/Lightricks/LTX-Video@ltx-video-0.9.7",
21 | "av>=10.0.0",
22 | "q8-kernels==0.0.5"
23 | ]
24 | dev = [
25 | "pre-commit>=4.0.1",
26 | "pytest>=8.0.0",
27 | "websocket-client==1.6.1",
28 | "scikit-image==0.24.0"
29 | ]
30 |
31 | [tool.isort]
32 | profile = "black"
33 | line_length = 88
34 | force_single_line = false
35 |
--------------------------------------------------------------------------------
/q8_nodes.py:
--------------------------------------------------------------------------------
1 | try:
2 | from q8_kernels.integration.patch_transformer import (
3 | patch_comfyui_native_transformer,
4 | patch_comfyui_transformer,
5 | )
6 |
7 | Q8_AVAILABLE = True
8 | except ImportError:
9 | Q8_AVAILABLE = False
10 |
11 | from .nodes_registry import comfy_node
12 |
13 |
14 | def check_q8_available():
15 | if not Q8_AVAILABLE:
16 | raise ImportError(
17 |             "Q8 kernels are not available. To use this feature, install the q8_kernels package from here: "
18 | "https://github.com/Lightricks/LTX-Video-Q8-Kernels"
19 | )
20 |
21 |
22 | @comfy_node(name="LTXQ8Patch")
23 | class LTXVQ8Patch:
24 | @classmethod
25 | def INPUT_TYPES(s):
26 | return {
27 | "required": {
28 | "model": ("MODEL",),
29 | "use_fp8_attention": (
30 | "BOOLEAN",
31 | {"default": False, "tooltip": "Use FP8 attention."},
32 | ),
33 | }
34 | }
35 |
36 | RETURN_TYPES = ("MODEL",)
37 | FUNCTION = "patch"
38 | CATEGORY = "lightricks/LTXV"
39 | TITLE = "LTXV Q8 Patcher"
40 |
41 | def patch(self, model, use_fp8_attention):
42 | check_q8_available()
43 | m = model.clone()
44 | diffusion_key = "diffusion_model"
45 | diffusion_model = m.get_model_object(diffusion_key)
46 | if diffusion_model.__class__.__name__ == "LTXVTransformer3D":
47 | transformer_key = "diffusion_model.transformer"
48 | patcher = patch_comfyui_transformer
49 | else:
50 | transformer_key = "diffusion_model"
51 | patcher = patch_comfyui_native_transformer
52 | transformer = m.get_model_object(transformer_key)
53 | patcher(transformer, use_fp8_attention, True)
54 | m.add_object_patch(transformer_key, transformer)
55 | return (m,)
56 |
--------------------------------------------------------------------------------
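Note: the module above uses a common optional-dependency guard: attempt the import once at load time, remember the result, and raise a descriptive error only when the feature is actually invoked. A generic sketch of the pattern (package name hypothetical):

# Generic optional-dependency guard (hypothetical package name "fast_kernels").
try:
    import fast_kernels  # noqa: F401
    FAST_KERNELS_AVAILABLE = True
except ImportError:
    FAST_KERNELS_AVAILABLE = False


def require_fast_kernels():
    # Fail lazily, with an actionable message, only when the feature is used.
    if not FAST_KERNELS_AVAILABLE:
        raise ImportError("fast_kernels is not installed; install it to use this node.")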
/recurrent_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from comfy_extras.nodes_custom_sampler import CFGGuider, SamplerCustomAdvanced
3 | from comfy_extras.nodes_lt import LTXVAddGuide, LTXVCropGuides
4 |
5 | from .latents import LTXVAddLatents, LTXVSelectLatents
6 | from .nodes_registry import comfy_node
7 | from .tricks import AddLatentGuideNode
8 |
9 |
10 | @comfy_node(description="Linear transition with overlap")
11 | class LinearOverlapLatentTransition:
12 | @classmethod
13 | def INPUT_TYPES(s):
14 | return {
15 | "required": {
16 | "samples1": ("LATENT",),
17 | "samples2": ("LATENT",),
18 | "overlap": ("INT", {"default": 1, "min": 1, "max": 256}),
19 | },
20 | "optional": {
21 | "axis": ("INT", {"default": 0}),
22 | },
23 | }
24 |
25 | RETURN_TYPES = ("LATENT",)
26 | FUNCTION = "process"
27 |
28 | CATEGORY = "Lightricks/latent"
29 |
30 | def get_subbatch(self, samples):
31 | s = samples.copy()
32 | samples = s["samples"]
33 | return samples
34 |
35 | def process(self, samples1, samples2, overlap, axis=0):
36 | samples1 = self.get_subbatch(samples1)
37 | samples2 = self.get_subbatch(samples2)
38 |
39 | # Create transition coefficients
40 | alpha = torch.linspace(1, 0, overlap + 2)[1:-1].to(samples1.device)
41 |
42 | # Create shape for broadcasting based on the axis
43 | shape = [1] * samples1.dim()
44 | shape[axis] = alpha.size(0)
45 | alpha = alpha.reshape(shape)
46 |
47 | # Create slices for the overlap regions
48 | slice_all = [slice(None)] * samples1.dim()
49 | slice_overlap1 = slice_all.copy()
50 | slice_overlap1[axis] = slice(-overlap, None)
51 | slice_overlap2 = slice_all.copy()
52 | slice_overlap2[axis] = slice(0, overlap)
53 | slice_rest1 = slice_all.copy()
54 | slice_rest1[axis] = slice(None, -overlap)
55 | slice_rest2 = slice_all.copy()
56 | slice_rest2[axis] = slice(overlap, None)
57 |
58 | # Combine samples
59 | parts = [
60 | samples1[tuple(slice_rest1)],
61 | alpha * samples1[tuple(slice_overlap1)]
62 | + (1 - alpha) * samples2[tuple(slice_overlap2)],
63 | samples2[tuple(slice_rest2)],
64 | ]
65 |
66 | combined_samples = torch.cat(parts, dim=axis)
67 | combined_batch_index = torch.arange(0, combined_samples.shape[0])
68 |
69 | return (
70 | {
71 | "samples": combined_samples,
72 | "batch_index": combined_batch_index,
73 | },
74 | )
75 |
76 |
77 | @comfy_node(
78 | name="LTXVRecurrentKSampler",
79 | )
80 | class LTXVRecurrentKSampler:
81 |
82 | @classmethod
83 | def INPUT_TYPES(s):
84 | return {
85 | "required": {
86 | "model": ("MODEL",),
87 | "vae": ("VAE",),
88 | "noise": ("NOISE",),
89 | "sampler": ("SAMPLER",),
90 | "sigmas": ("SIGMAS",),
91 | "latents": ("LATENT",),
92 | "chunk_sizes": ("STRING", {"default": "3", "multiline": False}),
93 | "overlaps": ("STRING", {"default": "1", "multiline": False}),
94 | "positive": ("CONDITIONING",),
95 | "negative": ("CONDITIONING",),
96 | "input_image": ("IMAGE",),
97 | "linear_blend_latents": ("BOOLEAN", {"default": True}),
98 | "conditioning_strength": (
99 | "FLOAT",
100 | {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.01},
101 | ),
102 | },
103 | "optional": {
104 | "guider": ("GUIDER",),
105 | },
106 | }
107 |
108 | RETURN_TYPES = (
109 | "LATENT",
110 | "LATENT",
111 | )
112 | RETURN_NAMES = (
113 | "output",
114 | "denoised_output",
115 | )
116 |
117 | FUNCTION = "sample"
118 |
119 | CATEGORY = "sampling"
120 |
121 | def sample(
122 | self,
123 | model,
124 | vae,
125 | noise,
126 | sampler,
127 | sigmas,
128 | latents,
129 | chunk_sizes,
130 | overlaps,
131 | positive,
132 | negative,
133 | input_image,
134 | linear_blend_latents,
135 | conditioning_strength,
136 | guider=None,
137 | ):
138 | select_latents = LTXVSelectLatents().select_latents
139 | add_latent_guide = AddLatentGuideNode().generate
140 | add_latents = LTXVAddLatents().add_latents
141 | positive_orig = positive.copy()
142 | negative_orig = negative.copy()
143 |
144 | # Parse chunk sizes and overlaps from strings
145 | chunk_sizes = [int(x) for x in chunk_sizes.split(",")]
146 | overlaps = [int(x) for x in overlaps.split(",")]
147 |
148 | # Extend lists if shorter than number of sigma steps
149 | n_steps = len(sigmas) - 1
150 | if len(chunk_sizes) < n_steps:
151 | chunk_sizes.extend([chunk_sizes[-1]] * (n_steps - len(chunk_sizes)))
152 | if len(overlaps) < n_steps:
153 | overlaps.extend([overlaps[-1]] * (n_steps - len(overlaps)))
154 |
155 | # Initialize working latents
156 | current_latents = latents.copy()
157 | t_latents = None
158 | # Loop through sigma pairs for progressive denoising
159 | for i in range(n_steps):
160 | current_sigmas = sigmas[i : i + 2]
161 | current_chunk_size = chunk_sizes[i]
162 | current_overlap = overlaps[i]
163 |
164 | print(f"\nProcessing sigma step {i} with sigmas {current_sigmas}")
165 | print(
166 | f"Using chunk size {current_chunk_size} and overlap {current_overlap}"
167 | )
168 |
169 | # Calculate valid chunk starts to ensure the last chunk isn't shorter than the overlap
170 | total_frames = current_latents["samples"].shape[2]
171 | chunk_stride = current_chunk_size - current_overlap
172 | valid_chunk_starts = list(
173 | range(0, total_frames - current_overlap, chunk_stride)
174 | )
175 |
176 |             # Warn if the final chunk ends up shorter than a full chunk size
177 | if (
178 | total_frames > chunk_stride
179 | and (total_frames - valid_chunk_starts[-1]) < current_chunk_size
180 | ):
181 | print(
182 | "last chunk is too short, it will only be of size",
183 | total_frames - valid_chunk_starts[-1],
184 | "frames",
185 | )
186 |
187 | # Process each chunk for current sigma pair
188 | for i_chunk, chunk_start in enumerate(valid_chunk_starts):
189 | (latents_chunk,) = select_latents(
190 | current_latents, chunk_start, chunk_start + current_chunk_size - 1
191 | )
192 | print(f"Processing chunk {i_chunk} starting at frame {chunk_start}")
193 |
194 | if i_chunk == 0:
195 | positive, negative, latents_chunk = LTXVAddGuide().generate(
196 | positive_orig,
197 | negative_orig,
198 | vae,
199 | latents_chunk,
200 | input_image,
201 | 0,
202 | 0.75,
203 | )
204 | else:
205 | (cond_latent,) = select_latents(t_latents, -current_overlap, -1)
206 | model, latents_chunk = add_latent_guide(
207 | model, latents_chunk, cond_latent, 0, conditioning_strength
208 | )
209 |
210 | if guider is None:
211 | (guider_obj,) = CFGGuider().get_guider(
212 | model, positive, negative, 1.0
213 | )
214 | else:
215 | guider_obj = guider
216 | (_, denoised_latents) = SamplerCustomAdvanced().sample(
217 | noise, guider_obj, sampler, current_sigmas, latents_chunk
218 | )
219 | (positive, negative, denoised_latents) = LTXVCropGuides().crop(
220 | positive, negative, denoised_latents
221 | )
222 |
223 | if i_chunk == 0:
224 | t_latents = denoised_latents
225 | else:
226 | if linear_blend_latents and current_overlap > 1:
227 | # the first output latent is the result of a 1:8 latent
228 | # reinterpreted as a 1:1 latent, so we ignore it
229 | (denoised_latents_drop_first,) = select_latents(
230 | denoised_latents, 1, -1
231 | )
232 | (t_latents,) = LinearOverlapLatentTransition().process(
233 | t_latents,
234 | denoised_latents_drop_first,
235 | current_overlap - 1,
236 | axis=2,
237 | )
238 | else:
239 | (truncated_denoised_latents,) = select_latents(
240 | denoised_latents, current_overlap, -1
241 | )
242 | (t_latents,) = add_latents(
243 | t_latents, truncated_denoised_latents
244 | )
245 |
246 | print(
247 | f"Completed chunk {i_chunk}, current output shape: {t_latents['samples'].shape}"
248 | )
249 |
250 | # Update current_latents for next sigma step
251 | current_latents = t_latents.copy()
252 | print(f"Completed sigma step {i}")
253 |
254 | return t_latents, t_latents
255 |
--------------------------------------------------------------------------------
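Note: LinearOverlapLatentTransition cross-fades the trailing `overlap` frames of one latent batch into the leading `overlap` frames of the next, so the combined length is F1 + F2 - overlap. A standalone sketch of the same blend math on a toy tensor (frame axis 2, matching LTXV latent layout), without the node class:

import torch

def linear_overlap_blend(a, b, overlap, axis=2):
    # Interior blend weights: linspace(1, 0, overlap + 2) with both endpoints dropped,
    # exactly as in LinearOverlapLatentTransition.process.
    alpha = torch.linspace(1, 0, overlap + 2)[1:-1]
    shape = [1] * a.dim()
    shape[axis] = overlap
    alpha = alpha.reshape(shape)

    idx = [slice(None)] * a.dim()
    head, tail = list(idx), list(idx)
    head[axis] = slice(None, -overlap)      # a without its last `overlap` frames
    tail[axis] = slice(overlap, None)       # b without its first `overlap` frames
    a_olap, b_olap = list(idx), list(idx)
    a_olap[axis] = slice(-overlap, None)
    b_olap[axis] = slice(0, overlap)

    blended = alpha * a[tuple(a_olap)] + (1 - alpha) * b[tuple(b_olap)]
    return torch.cat([a[tuple(head)], blended, b[tuple(tail)]], dim=axis)

a = torch.randn(1, 128, 8, 16, 16)   # [batch, channels, frames, height, width]
b = torch.randn(1, 128, 8, 16, 16)
out = linear_overlap_blend(a, b, overlap=2)
print(out.shape)                     # torch.Size([1, 128, 14, 16, 16]) -> 8 + 8 - 2 frames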
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers
2 | einops
3 | huggingface_hub>=0.25.2
4 | transformers[timm]>=4.45.0
5 |
--------------------------------------------------------------------------------
/tiled_sampler.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import comfy
4 | import torch
5 | from comfy_extras.nodes_custom_sampler import SamplerCustomAdvanced
6 | from comfy_extras.nodes_lt import LTXVAddGuide, LTXVCropGuides
7 |
8 | from .latents import LTXVAddLatentGuide, LTXVSelectLatents
9 | from .nodes_registry import comfy_node
10 |
11 |
12 | @comfy_node(
13 | name="LTXVTiledSampler",
14 | )
15 | class LTXVTiledSampler:
16 |
17 | @classmethod
18 | def INPUT_TYPES(s):
19 | return {
20 | "required": {
21 | "model": ("MODEL",),
22 | "vae": ("VAE",),
23 | "noise": ("NOISE",),
24 | "sampler": ("SAMPLER",),
25 | "sigmas": ("SIGMAS",),
26 | "guider": ("GUIDER",),
27 | "latents": ("LATENT",),
28 | "horizontal_tiles": ("INT", {"default": 1, "min": 1, "max": 6}),
29 | "vertical_tiles": ("INT", {"default": 1, "min": 1, "max": 6}),
30 | "overlap": ("INT", {"default": 1, "min": 1, "max": 8}),
31 | "latents_cond_strength": (
32 | "FLOAT",
33 | {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01},
34 | ),
35 | "boost_latent_similarity": (
36 | "BOOLEAN",
37 | {"default": False},
38 | ),
39 | "crop": (["center", "disabled"], {"default": "disabled"}),
40 | },
41 | "optional": {
42 | "optional_cond_images": ("IMAGE",),
43 | "optional_cond_indices": ("STRING", {"default": "0"}),
44 | "images_cond_strengths": ("STRING", {"default": "0.9"}),
45 | },
46 | }
47 |
48 | RETURN_TYPES = (
49 | "LATENT",
50 | "LATENT",
51 | )
52 | RETURN_NAMES = (
53 | "output",
54 | "denoised_output",
55 | )
56 |
57 | FUNCTION = "sample"
58 |
59 | CATEGORY = "sampling"
60 |
61 | def sample(
62 | self,
63 | model,
64 | vae,
65 | noise,
66 | sampler,
67 | sigmas,
68 | guider,
69 | latents,
70 | horizontal_tiles,
71 | vertical_tiles,
72 | overlap,
73 | latents_cond_strength,
74 | boost_latent_similarity,
75 | crop="disabled",
76 | optional_cond_images=None,
77 | optional_cond_indices="0",
78 | images_cond_strengths="0.9",
79 | ):
80 |
81 | # Get the latent samples
82 | samples = latents["samples"]
83 |
84 | batch, channels, frames, height, width = samples.shape
85 | time_scale_factor, width_scale_factor, height_scale_factor = (
86 | vae.downscale_index_formula
87 | )
88 | # Validate image dimensions if provided
89 | if optional_cond_images is not None:
90 | img_height = height * height_scale_factor
91 | img_width = width * width_scale_factor
92 | cond_images = comfy.utils.common_upscale(
93 | optional_cond_images.movedim(-1, 1),
94 | img_width,
95 | img_height,
96 | "bicubic",
97 | crop=crop,
98 | ).movedim(1, -1)
99 | print("cond_images shape after resize", cond_images.shape)
100 | img_batch, img_height, img_width, img_channels = cond_images.shape
101 | else:
102 | cond_images = None
103 |
104 | if optional_cond_indices is not None and optional_cond_images is not None:
105 | optional_cond_indices = optional_cond_indices.split(",")
106 | optional_cond_indices = [int(i) for i in optional_cond_indices]
107 | assert len(optional_cond_indices) == len(
108 | optional_cond_images
109 | ), "Number of optional cond images must match number of optional cond indices"
110 |
111 | images_cond_strengths = [float(i) for i in images_cond_strengths.split(",")]
112 | if optional_cond_images is not None and len(images_cond_strengths) < len(
113 | optional_cond_images
114 | ):
115 | # Repeat the last value to match the length of optional_cond_images
116 | images_cond_strengths = images_cond_strengths + [
117 | images_cond_strengths[-1]
118 | ] * (len(optional_cond_images) - len(images_cond_strengths))
119 |
120 | # Calculate tile sizes with overlap
121 | base_tile_height = (height + (vertical_tiles - 1) * overlap) // vertical_tiles
122 | base_tile_width = (width + (horizontal_tiles - 1) * overlap) // horizontal_tiles
123 |
124 | # Initialize output tensor and weight tensor
125 | output = torch.zeros_like(samples)
126 | denoised_output = torch.zeros_like(samples)
127 | weights = torch.zeros_like(samples)
128 |
129 | # Get positive and negative conditioning
130 | try:
131 | positive, negative = guider.raw_conds
132 | except AttributeError:
133 | raise ValueError(
134 | "Guider does not have raw conds, cannot use it as a guider. "
135 | "Please use STGGuiderAdvanced."
136 | )
137 |
138 | # Process each tile
139 | for v in range(vertical_tiles):
140 | for h in range(horizontal_tiles):
141 | # Calculate tile boundaries
142 | h_start = h * (base_tile_width - overlap)
143 | v_start = v * (base_tile_height - overlap)
144 |
145 | # Adjust end positions for edge tiles
146 | h_end = (
147 | min(h_start + base_tile_width, width)
148 | if h < horizontal_tiles - 1
149 | else width
150 | )
151 | v_end = (
152 | min(v_start + base_tile_height, height)
153 | if v < vertical_tiles - 1
154 | else height
155 | )
156 |
157 | # Calculate actual tile dimensions
158 | tile_height = v_end - v_start
159 | tile_width = h_end - h_start
160 |
161 | print(f"Processing tile at row {v}, col {h}:")
162 | print(f" Position: ({v_start}:{v_end}, {h_start}:{h_end})")
163 | print(f" Size: {tile_height}x{tile_width}")
164 |
165 | # Extract tile
166 | tile = samples[:, :, :, v_start:v_end, h_start:h_end]
167 |
168 | # Create tile latents dict
169 | tile_latents = {"samples": tile}
170 | unconditioned_tile_latents = tile_latents.copy()
171 |
172 | # Handle image conditioning if provided
173 | if cond_images is not None:
174 | # Scale coordinates for image
175 | img_h_start = v_start * height_scale_factor
176 | img_h_end = v_end * height_scale_factor
177 | img_w_start = h_start * width_scale_factor
178 | img_w_end = h_end * width_scale_factor
179 |
180 | # Create copies of conditioning for this tile
181 | tile_positive = positive.copy()
182 | tile_negative = negative.copy()
183 |
184 | for i_cond_image, (
185 | cond_image,
186 | cond_image_idx,
187 | cond_image_strength,
188 | ) in enumerate(
189 | zip(cond_images, optional_cond_indices, images_cond_strengths)
190 | ):
191 | # Extract image tile
192 | img_tile = cond_image[
193 | img_h_start:img_h_end, img_w_start:img_w_end, :
194 | ].unsqueeze(0)
195 |
196 | print(
197 | f"Applying image conditioning on cond image {i_cond_image} for tile at row {v}, col {h} with strength {cond_image_strength} at position {cond_image_idx}:"
198 | )
199 | print(
200 | f" Image tile position: ({img_h_start}:{img_h_end}, {img_w_start}:{img_w_end})"
201 | )
202 | print(f" Image tile size: {img_tile.shape}")
203 |
204 | # Add guide from image tile
205 | (
206 | tile_positive,
207 | tile_negative,
208 | tile_latents,
209 | ) = LTXVAddGuide().generate(
210 | positive=tile_positive,
211 | negative=tile_negative,
212 | vae=vae,
213 | latent=tile_latents,
214 | image=img_tile,
215 | frame_idx=cond_image_idx,
216 | strength=cond_image_strength,
217 | )
218 | if boost_latent_similarity:
219 | middle_latent_idx = (frames - 1) // 2
220 | middle_index_latent = LTXVSelectLatents().select_latents(
221 | samples=unconditioned_tile_latents,
222 | start_index=middle_latent_idx,
223 | end_index=middle_latent_idx,
224 | )[0]
225 | last_index_latent = LTXVSelectLatents().select_latents(
226 | samples=unconditioned_tile_latents,
227 | start_index=-1,
228 | end_index=-1,
229 | )[0]
230 | print(
231 | f"using LTXVAddLatentGuide on tiled latent with latent index {middle_latent_idx} and strength {latents_cond_strength}"
232 | )
233 | (
234 | tile_positive,
235 | tile_negative,
236 | tile_latents,
237 | ) = LTXVAddLatentGuide().generate(
238 | vae=vae,
239 | positive=tile_positive,
240 | negative=tile_negative,
241 | latent=tile_latents,
242 | guiding_latent=middle_index_latent,
243 | latent_idx=middle_latent_idx,
244 | strength=latents_cond_strength,
245 | )
246 | print(
247 | f"using LTXVAddLatentGuide on tiled latent with latent index {frames-1} and strength {latents_cond_strength}"
248 | )
249 | (
250 | tile_positive,
251 | tile_negative,
252 | tile_latents,
253 | ) = LTXVAddLatentGuide().generate(
254 | vae=vae,
255 | positive=tile_positive,
256 | negative=tile_negative,
257 | latent=tile_latents,
258 | guiding_latent=last_index_latent,
259 | latent_idx=frames - 1,
260 | strength=latents_cond_strength,
261 | )
262 |
263 | guider = copy.copy(guider)
264 | guider.set_conds(tile_positive, tile_negative)
265 |
266 | # Denoise the tile
267 | denoised_tile = SamplerCustomAdvanced().sample(
268 | noise=noise,
269 | guider=guider,
270 | sampler=sampler,
271 | sigmas=sigmas,
272 | latent_image=tile_latents,
273 | )[0]
274 |
275 | # Clean up guides if image conditioning was used
276 | if cond_images is not None:
277 | print("before guide crop", denoised_tile["samples"].shape)
278 | tile_positive, tile_negative, denoised_tile = LTXVCropGuides().crop(
279 | positive=tile_positive,
280 | negative=tile_negative,
281 | latent=denoised_tile,
282 | )
283 | print("after guide crop", denoised_tile["samples"].shape)
284 |
285 | # Create weight mask for this tile
286 | tile_weights = torch.ones_like(tile)
287 |
288 | # Apply horizontal blending weights
289 | if h > 0: # Left overlap
290 | h_blend = torch.linspace(0, 1, overlap, device=tile.device)
291 | tile_weights[:, :, :, :, :overlap] *= h_blend.view(1, 1, 1, 1, -1)
292 | if h < horizontal_tiles - 1: # Right overlap
293 | h_blend = torch.linspace(1, 0, overlap, device=tile.device)
294 | tile_weights[:, :, :, :, -overlap:] *= h_blend.view(1, 1, 1, 1, -1)
295 |
296 | # Apply vertical blending weights
297 | if v > 0: # Top overlap
298 | v_blend = torch.linspace(0, 1, overlap, device=tile.device)
299 | tile_weights[:, :, :, :overlap, :] *= v_blend.view(1, 1, 1, -1, 1)
300 | if v < vertical_tiles - 1: # Bottom overlap
301 | v_blend = torch.linspace(1, 0, overlap, device=tile.device)
302 | tile_weights[:, :, :, -overlap:, :] *= v_blend.view(1, 1, 1, -1, 1)
303 |
304 | # Add weighted tile to output
305 | output[:, :, :, v_start:v_end, h_start:h_end] += (
306 | denoised_tile["samples"] * tile_weights
307 | )
308 | denoised_output[:, :, :, v_start:v_end, h_start:h_end] += (
309 | denoised_tile["samples"] * tile_weights
310 | )
311 |
312 | # Add weights to weight tensor
313 | weights[:, :, :, v_start:v_end, h_start:h_end] += tile_weights
314 |
315 | # Normalize by weights
316 | output = output / (weights + 1e-8)
317 | denoised_output = denoised_output / (weights + 1e-8)
318 |
319 | return {"samples": output}, {"samples": denoised_output}
320 |
--------------------------------------------------------------------------------
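Note: the tiled sampler blends overlapping tiles by multiplying each tile with a linear ramp over the overlap region, accumulating both the weighted tiles and the weights, and normalizing at the end. A 1-D toy sketch of that accumulate-and-normalize scheme (no ComfyUI dependency; the "denoiser" is the identity, so the blend should reproduce the input exactly):

import torch

width, tile_w, overlap = 16, 10, 4
stride = tile_w - overlap

signal = torch.arange(width, dtype=torch.float32)    # stand-in for one latent row
output = torch.zeros(width)
weights = torch.zeros(width)

starts = [0, stride]                                 # two tiles: [0:10] and [6:16]
for i, s in enumerate(starts):
    tile = signal[s : s + tile_w].clone()            # "denoised" tile (identity here)
    w = torch.ones(tile_w)
    if i > 0:                                        # left edge fades in
        w[:overlap] *= torch.linspace(0, 1, overlap)
    if i < len(starts) - 1:                          # right edge fades out
        w[-overlap:] *= torch.linspace(1, 0, overlap)
    output[s : s + tile_w] += tile * w
    weights[s : s + tile_w] += w

blended = output / (weights + 1e-8)
print(torch.allclose(blended, signal, atol=1e-4))    # True: complementary ramps sum to 1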
/tricks/__init__.py:
--------------------------------------------------------------------------------
1 | from .nodes.attn_bank_nodes import (
2 | LTXAttentionBankNode,
3 | LTXAttentioOverrideNode,
4 | LTXPrepareAttnInjectionsNode,
5 | )
6 | from .nodes.attn_override_node import LTXAttnOverrideNode
7 | from .nodes.latent_guide_node import AddLatentGuideNode
8 | from .nodes.ltx_feta_enhance_node import LTXFetaEnhanceNode
9 | from .nodes.ltx_flowedit_nodes import LTXFlowEditCFGGuiderNode, LTXFlowEditSamplerNode
10 | from .nodes.ltx_inverse_model_pred_nodes import (
11 | LTXForwardModelSamplingPredNode,
12 | LTXReverseModelSamplingPredNode,
13 | )
14 | from .nodes.ltx_pag_node import LTXPerturbedAttentionNode
15 | from .nodes.modify_ltx_model_node import ModifyLTXModelNode
16 | from .nodes.rectified_sampler_nodes import (
17 | LTXRFForwardODESamplerNode,
18 | LTXRFReverseODESamplerNode,
19 | )
20 |
21 | NODE_CLASS_MAPPINGS = {
22 | "ModifyLTXModel": ModifyLTXModelNode,
23 | "AddLatentGuide": AddLatentGuideNode,
24 | "LTXForwardModelSamplingPred": LTXForwardModelSamplingPredNode,
25 | "LTXReverseModelSamplingPred": LTXReverseModelSamplingPredNode,
26 | "LTXRFForwardODESampler": LTXRFForwardODESamplerNode,
27 | "LTXRFReverseODESampler": LTXRFReverseODESamplerNode,
28 | "LTXAttentionBank": LTXAttentionBankNode,
29 | "LTXPrepareAttnInjections": LTXPrepareAttnInjectionsNode,
30 | "LTXAttentioOverride": LTXAttentioOverrideNode,
31 | "LTXPerturbedAttention": LTXPerturbedAttentionNode,
32 | "LTXAttnOverride": LTXAttnOverrideNode,
33 | "LTXFlowEditCFGGuider": LTXFlowEditCFGGuiderNode,
34 | "LTXFlowEditSampler": LTXFlowEditSamplerNode,
35 | "LTXFetaEnhance": LTXFetaEnhanceNode,
36 | }
37 |
38 | NODE_DISPLAY_NAME_MAPPINGS = {
39 | "ModifyLTXModel": "Modify LTX Model",
40 | "AddLatentGuide": "Add LTX Latent Guide",
41 | "LTXAddImageGuide": "Add LTX Image Guide",
42 | "LTXForwardModelSamplingPred": "LTX Forward Model Pred",
43 | "LTXReverseModelSamplingPred": "LTX Reverse Model Pred",
44 | "LTXRFForwardODESampler": "LTX Rf-Inv Forward Sampler",
45 | "LTXRFReverseODESampler": "LTX Rf-Inv Reverse Sampler",
46 | "LTXAttentionBank": "LTX Attention Bank",
47 | "LTXPrepareAttnInjections": "LTX Prepare Attn Injection",
48 | "LTXAttentioOverride": "LTX Attn Block Override",
49 | "LTXPerturbedAttention": "LTX Apply Perturbed Attention",
50 | "LTXAttnOverride": "LTX Attention Override",
51 | "LTXFlowEditCFGGuider": "LTX Flow Edit CFG Guider",
52 | "LTXFlowEditSampler": "LTX Flow Edit Sampler",
53 | "LTXFetaEnhance": "LTX Feta Enhance",
54 | }
55 |
--------------------------------------------------------------------------------
/tricks/modules/ltx_model.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import comfy.ldm.common_dit
4 | import comfy.ldm.modules.attention
5 | import torch
6 | from comfy.ldm.lightricks.model import (
7 | BasicTransformerBlock,
8 | LTXVModel,
9 | apply_rotary_emb,
10 | precompute_freqs_cis,
11 | )
12 | from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
13 | from torch import nn
14 |
15 | from ..utils.feta_enhance_utils import get_feta_scores
16 |
17 |
18 | class LTXModifiedCrossAttention(nn.Module):
19 | def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
20 | context = x if context is None else context
21 | context_v = x if context is None else context
22 |
23 | step = transformer_options.get("step", -1)
24 | total_steps = transformer_options.get("total_steps", 0)
25 | attn_bank = transformer_options.get("attn_bank", None)
26 | sample_mode = transformer_options.get("sample_mode", None)
27 | if attn_bank is not None and self.idx in attn_bank["block_map"]:
28 | len_conds = len(transformer_options["cond_or_uncond"])
29 | pred_order = transformer_options["pred_order"]
30 | if (
31 | sample_mode == "forward"
32 | and total_steps - step - 1 < attn_bank["save_steps"]
33 | ):
34 | step_idx = f"{pred_order}_{total_steps-step-1}"
35 | attn_bank["block_map"][self.idx][step_idx] = x.cpu()
36 | elif sample_mode == "reverse" and step < attn_bank["inject_steps"]:
37 | step_idx = f"{pred_order}_{step}"
38 | inject_settings = attn_bank.get("inject_settings", {})
39 | if len(inject_settings) > 0:
40 | inj = (
41 | attn_bank["block_map"][self.idx][step_idx]
42 | .to(x.device)
43 | .repeat(len_conds, 1, 1)
44 | )
45 | if "q" in inject_settings:
46 | x = inj
47 | if "k" in inject_settings:
48 | context = inj
49 | if "v" in inject_settings:
50 | context_v = inj
51 |
52 | q = self.to_q(x)
53 | k = self.to_k(context)
54 | v = self.to_v(context_v)
55 |
56 | q = self.q_norm(q)
57 | k = self.k_norm(k)
58 |
59 | if pe is not None:
60 | q = apply_rotary_emb(q, pe)
61 | k = apply_rotary_emb(k, pe)
62 |
63 | feta_score = None
64 | if (
65 | transformer_options.get("feta_weight", 0) > 0
66 | and self.idx in transformer_options["feta_layers"]["layers"]
67 | ):
68 | feta_score = get_feta_scores(q, k, self.heads, transformer_options)
69 |
70 | alt_attn_fn = (
71 | transformer_options.get("patches_replace", {})
72 | .get("layer", {})
73 | .get(("self_attn", self.idx), None)
74 | )
75 | if alt_attn_fn is not None:
76 | out = alt_attn_fn(
77 | q,
78 | k,
79 | v,
80 | self.heads,
81 | attn_precision=self.attn_precision,
82 | transformer_options=transformer_options,
83 | )
84 | elif mask is None:
85 | out = comfy.ldm.modules.attention.optimized_attention(
86 | q, k, v, self.heads, attn_precision=self.attn_precision
87 | )
88 | else:
89 | out = comfy.ldm.modules.attention.optimized_attention_masked(
90 | q, k, v, self.heads, mask, attn_precision=self.attn_precision
91 | )
92 |
93 | if feta_score is not None:
94 | out *= feta_score
95 |
96 | return self.to_out(out)
97 |
98 |
99 | class LTXModifiedBasicTransformerBlock(BasicTransformerBlock):
100 | def forward(
101 | self,
102 | x,
103 | context=None,
104 | attention_mask=None,
105 | timestep=None,
106 | pe=None,
107 | transformer_options={},
108 | ):
109 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
110 | self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
111 | + timestep.reshape(
112 | x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1
113 | )
114 | ).unbind(dim=2)
115 | x += (
116 | self.attn1(
117 | comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa,
118 | pe=pe,
119 | transformer_options=transformer_options,
120 | )
121 | * gate_msa
122 | )
123 |
124 | x += self.attn2(x, context=context, mask=attention_mask)
125 |
126 | y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
127 | x += self.ff(y) * gate_mlp
128 |
129 | return x
130 |
131 |
132 | class LTXVModelModified(LTXVModel):
133 |
134 | def forward(
135 | self,
136 | x,
137 | timestep,
138 | context,
139 | attention_mask,
140 | frame_rate=25,
141 | transformer_options={},
142 | keyframe_idxs=None,
143 | **kwargs,
144 | ):
145 | patches_replace = transformer_options.get("patches_replace", {})
146 |
147 | orig_shape = list(x.shape)
148 |
149 | x, latent_coords = self.patchifier.patchify(x)
150 | pixel_coords = latent_to_pixel_coords(
151 | latent_coords=latent_coords,
152 | scale_factors=self.vae_scale_factors,
153 | causal_fix=self.causal_temporal_positioning,
154 | )
155 |
156 | if keyframe_idxs is not None:
157 | pixel_coords[:, :, -keyframe_idxs.shape[2] :] = keyframe_idxs
158 |
159 | fractional_coords = pixel_coords.to(torch.float32)
160 | fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
161 |
162 | x = self.patchify_proj(x)
163 | timestep = timestep * 1000.0
164 |
165 | if attention_mask is not None and not torch.is_floating_point(attention_mask):
166 | attention_mask = (attention_mask - 1).to(x.dtype).reshape(
167 | (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
168 | ) * torch.finfo(x.dtype).max
169 |
170 | pe = precompute_freqs_cis(
171 | fractional_coords, dim=self.inner_dim, out_dtype=x.dtype
172 | )
173 |
174 | batch_size = x.shape[0]
175 | timestep, embedded_timestep = self.adaln_single(
176 | timestep.flatten(),
177 | {"resolution": None, "aspect_ratio": None},
178 | batch_size=batch_size,
179 | hidden_dtype=x.dtype,
180 | )
181 | # Second dimension is 1 or number of tokens (if timestep_per_token)
182 | timestep = timestep.view(batch_size, -1, timestep.shape[-1])
183 | embedded_timestep = embedded_timestep.view(
184 | batch_size, -1, embedded_timestep.shape[-1]
185 | )
186 |
187 | # 2. Blocks
188 | if self.caption_projection is not None:
189 | batch_size = x.shape[0]
190 | context = self.caption_projection(context)
191 | context = context.view(batch_size, -1, x.shape[-1])
192 |
193 | blocks_replace = patches_replace.get("dit", {})
194 | for i, block in enumerate(self.transformer_blocks):
195 | if ("double_block", i) in blocks_replace:
196 |
197 | def block_wrap(args):
198 | out = {}
199 | out["img"] = block(
200 | args["img"],
201 | context=args["txt"],
202 | attention_mask=args["attention_mask"],
203 | timestep=args["vec"],
204 | pe=args["pe"],
205 | )
206 | return out
207 |
208 | out = blocks_replace[("double_block", i)](
209 | {
210 | "img": x,
211 | "txt": context,
212 | "attention_mask": attention_mask,
213 | "vec": timestep,
214 | "pe": pe,
215 | },
216 | {"original_block": block_wrap},
217 | )
218 | x = out["img"]
219 | else:
220 | x = block(
221 | x,
222 | context=context,
223 | attention_mask=attention_mask,
224 | timestep=timestep,
225 | pe=pe,
226 | transformer_options=transformer_options,
227 | )
228 |
229 | # 3. Output
230 | scale_shift_values = (
231 | self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
232 | + embedded_timestep[:, :, None]
233 | )
234 | shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
235 | x = self.norm_out(x)
236 | # Modulation
237 | x = x * (1 + scale) + shift
238 | x = self.proj_out(x)
239 |
240 | x = self.patchifier.unpatchify(
241 | latents=x,
242 | output_height=orig_shape[3],
243 | output_width=orig_shape[4],
244 | output_num_frames=orig_shape[2],
245 | out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
246 | )
247 |
248 | return x
249 |
250 |
251 | def inject_model(diffusion_model):
252 | diffusion_model.__class__ = LTXVModelModified
253 | for idx, transformer_block in enumerate(diffusion_model.transformer_blocks):
254 | transformer_block.__class__ = LTXModifiedBasicTransformerBlock
255 | transformer_block.idx = idx
256 | transformer_block.attn1.__class__ = LTXModifiedCrossAttention
257 | transformer_block.attn1.idx = idx
258 | return diffusion_model
259 |
--------------------------------------------------------------------------------
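Note: inject_model patches an already-instantiated model by reassigning __class__ on the model, its transformer blocks, and their attn1 modules, so the modified forward methods run without rebuilding the module or reloading weights. A toy sketch of that class-swap pattern (hypothetical classes, no ComfyUI dependency):

class Greeter:
    def __init__(self, name):
        self.name = name

    def greet(self):
        return f"hello {self.name}"


class LoudGreeter(Greeter):
    # Same attributes, overridden behavior -- analogous to LTXVModelModified
    # overriding forward() while keeping the original weights and buffers.
    def greet(self):
        return f"HELLO {self.name.upper()}!"


g = Greeter("ltxv")
print(g.greet())            # hello ltxv

g.__class__ = LoudGreeter   # swap behavior in place; instance state is untouched
print(g.greet())            # HELLO LTXV!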
/tricks/nodes/attn_bank_nodes.py:
--------------------------------------------------------------------------------
1 | from ..utils.attn_bank import AttentionBank
2 |
3 |
4 | class LTXAttentionBankNode:
5 | @classmethod
6 | def INPUT_TYPES(s):
7 | return {
8 | "required": {
9 | "save_steps": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
10 | "blocks": ("STRING", {"multiline": True}),
11 | }
12 | }
13 |
14 | RETURN_TYPES = ("ATTN_BANK",)
15 | FUNCTION = "build"
16 |
17 | CATEGORY = "ltxtricks"
18 |
19 | def build(self, save_steps, blocks=""):
20 | block_map = {}
21 | block_list = blocks.split(",")
22 | for block in block_list:
23 | block_idx = int(block)
24 | block_map[block_idx] = {}
25 |
26 | bank = AttentionBank(save_steps, block_map)
27 | return (bank,)
28 |
29 |
30 | class LTXPrepareAttnInjectionsNode:
31 | @classmethod
32 | def INPUT_TYPES(s):
33 | return {
34 | "required": {
35 | "latent": ("LATENT",),
36 | "attn_bank": ("ATTN_BANK",),
37 | "query": ("BOOLEAN", {"default": False}),
38 | "key": ("BOOLEAN", {"default": False}),
39 | "value": ("BOOLEAN", {"default": False}),
40 | "inject_steps": (
41 | "INT",
42 | {"default": 0, "min": 0, "max": 1000, "step": 1},
43 | ),
44 | },
45 | "optional": {"blocks": ("LTX_BLOCKS",)},
46 | }
47 |
48 | RETURN_TYPES = ("LATENT", "ATTN_INJ")
49 | FUNCTION = "prepare"
50 |
51 | CATEGORY = "fluxtapoz"
52 |
53 | def prepare(self, latent, attn_bank, query, key, value, inject_steps, blocks=None):
54 | if inject_steps > attn_bank["save_steps"]:
55 |             raise ValueError("Cannot inject more steps than were saved.")
56 | attn_bank = AttentionBank(
57 | attn_bank["save_steps"], attn_bank["block_map"], inject_steps
58 | )
59 | attn_bank["inject_settings"] = set([])
60 | if query:
61 | attn_bank["inject_settings"].add("q")
62 | if key:
63 | attn_bank["inject_settings"].add("k")
64 | if value:
65 | attn_bank["inject_settings"].add("v")
66 |
67 | if blocks is not None:
68 | attn_bank["block_map"] = {**attn_bank["block_map"]}
69 | for key in list(attn_bank["block_map"].keys()):
70 | if key not in blocks:
71 | del attn_bank["block_map"][key]
72 |
73 | # Hack to force order of operations in ComfyUI graph
74 | return (latent, attn_bank)
75 |
76 |
77 | class LTXAttentioOverrideNode:
78 | @classmethod
79 | def INPUT_TYPES(s):
80 | return {"required": {"blocks": ("STRING", {"multiline": True})}}
81 |
82 | RETURN_TYPES = ("LTX_BLOCKS",)
83 | FUNCTION = "build"
84 |
85 | CATEGORY = "ltxtricks"
86 |
87 | def build(self, blocks=""):
88 | block_set = set(list(int(block) for block in blocks.split(",")))
89 |
90 | return (block_set,)
91 |
--------------------------------------------------------------------------------
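Note: the bank built here is keyed per transformer block, and LTXModifiedCrossAttention (tricks/modules/ltx_model.py above) saves and re-reads activations under step keys of the form f"{pred_order}_{step}". A minimal sketch of that keying, with a plain dict standing in for AttentionBank (whose implementation lives in tricks/utils/attn_bank.py and is not shown here):

import torch

# Plain-dict stand-in for AttentionBank: save_steps / inject_steps / per-block maps.
bank = {"save_steps": 2, "inject_steps": 2, "block_map": {14: {}, 21: {}}}

# Forward (inversion) pass: only the last `save_steps` steps are saved per block.
total_steps = 8
for step in range(total_steps):
    if total_steps - step - 1 < bank["save_steps"]:
        step_idx = f"first_{total_steps - step - 1}"     # pred_order "first"
        for block_idx in bank["block_map"]:
            bank["block_map"][block_idx][step_idx] = torch.randn(1, 4, 8)

# Reverse (editing) pass: the first `inject_steps` steps read the same keys back.
for step in range(bank["inject_steps"]):
    step_idx = f"first_{step}"
    injected = bank["block_map"][14][step_idx]
    print(step_idx, tuple(injected.shape))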
/tricks/nodes/attn_override_node.py:
--------------------------------------------------------------------------------
1 | def is_integer(string):
2 | try:
3 | int(string)
4 | return True
5 | except ValueError:
6 | return False
7 |
8 |
9 | class LTXAttnOverrideNode:
10 | @classmethod
11 | def INPUT_TYPES(s):
12 | return {
13 | "required": {
14 | "layers": ("STRING", {"multiline": True}),
15 | }
16 | }
17 |
18 | RETURN_TYPES = ("ATTN_OVERRIDE",)
19 | FUNCTION = "build"
20 |
21 | CATEGORY = "ltxtricks/attn"
22 |
23 | def build(self, layers):
24 | layers_map = set([])
25 | for block in layers.split(","):
26 | block = block.strip()
27 | if is_integer(block):
28 | layers_map.add(int(block))
29 |
30 | return ({"layers": layers_map},)
31 |
--------------------------------------------------------------------------------
/tricks/nodes/latent_guide_node.py:
--------------------------------------------------------------------------------
1 | import comfy_extras.nodes_lt as nodes_lt
2 |
3 |
4 | class AddLatentGuideNode(nodes_lt.LTXVAddGuide):
5 | @classmethod
6 | def INPUT_TYPES(s):
7 | return {
8 | "required": {
9 | "model": ("MODEL",),
10 | "latent": ("LATENT",),
11 | "image_latent": ("LATENT",),
12 | "index": ("INT", {"default": 0, "min": -1, "max": 9999, "step": 1}),
13 | "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}),
14 | }
15 | }
16 |
17 | RETURN_TYPES = ("MODEL", "LATENT")
18 | RETURN_NAMES = ("model", "latent")
19 |
20 | CATEGORY = "ltxtricks"
21 | FUNCTION = "generate"
22 |
23 | def generate(self, model, latent, image_latent, index, strength):
24 | noise_mask = nodes_lt.get_noise_mask(latent)
25 | latent = latent["samples"]
26 |
27 | image_latent = image_latent["samples"]
28 |
29 | latent, noise_mask = self.replace_latent_frames(
30 | latent,
31 | noise_mask,
32 | image_latent,
33 | index,
34 | strength,
35 | )
36 |
37 | return (
38 | model,
39 | {"samples": latent, "noise_mask": noise_mask},
40 | )
41 |
--------------------------------------------------------------------------------
/tricks/nodes/ltx_feta_enhance_node.py:
--------------------------------------------------------------------------------
1 | DEFAULT_ATTN = {
2 | "layers": [i for i in range(0, 100, 1)],
3 | }
4 |
5 |
6 | class LTXFetaEnhanceNode:
7 | @classmethod
8 | def INPUT_TYPES(s):
9 | return {
10 | "required": {
11 | "model": ("MODEL",),
12 | "feta_weight": (
13 | "FLOAT",
14 | {"default": 4, "min": -100.0, "max": 100.0, "step": 0.01},
15 | ),
16 | },
17 | "optional": {"attn_override": ("ATTN_OVERRIDE",)},
18 | }
19 |
20 | RETURN_TYPES = ("MODEL",)
21 |
22 | CATEGORY = "ltxtricks"
23 | FUNCTION = "apply"
24 |
25 | def apply(self, model, feta_weight, attn_override=DEFAULT_ATTN):
26 | model = model.clone()
27 |
28 | model_options = model.model_options.copy()
29 | transformer_options = model_options["transformer_options"].copy()
30 |
31 | transformer_options["feta_weight"] = feta_weight
32 | transformer_options["feta_layers"] = attn_override
33 | model_options["transformer_options"] = transformer_options
34 |
35 | model.model_options = model_options
36 | return (model,)
37 |
--------------------------------------------------------------------------------
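Note: this node only records feta_weight and feta_layers in the cloned model's transformer_options; the actual scaling happens in LTXModifiedCrossAttention, which checks the weight and whether the block index is in the configured layer set. A tiny sketch of that lookup:

transformer_options = {"feta_weight": 4.0, "feta_layers": {"layers": {0, 1, 2}}}

def feta_active(block_idx, options):
    # Mirrors the check in LTXModifiedCrossAttention.forward.
    return options.get("feta_weight", 0) > 0 and block_idx in options["feta_layers"]["layers"]

print(feta_active(1, transformer_options))   # True
print(feta_active(7, transformer_options))   # False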
/tricks/nodes/ltx_flowedit_nodes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from comfy.samplers import KSAMPLER, CFGGuider, sampling_function
3 | from tqdm import trange
4 |
5 |
6 | class FlowEditGuider(CFGGuider):
7 | def __init__(self, model_patcher):
8 | super().__init__(model_patcher)
9 | self.cfgs = {}
10 |
11 | def set_conds(self, **kwargs):
12 | self.inner_set_conds(kwargs)
13 |
14 | def set_cfgs(self, **kwargs):
15 | self.cfgs = {**kwargs}
16 |
17 | def predict_noise(self, x, timestep, model_options={}, seed=None):
18 | latent_type = model_options["transformer_options"]["latent_type"]
19 | positive = self.conds.get(f"{latent_type}_positive", None)
20 | negative = self.conds.get(f"{latent_type}_negative", None)
21 | cfg = self.cfgs.get(latent_type, self.cfg)
22 | return sampling_function(
23 | self.inner_model,
24 | x,
25 | timestep,
26 | negative,
27 | positive,
28 | cfg,
29 | model_options=model_options,
30 | seed=seed,
31 | )
32 |
33 |
34 | class LTXFlowEditCFGGuiderNode:
35 | @classmethod
36 | def INPUT_TYPES(s):
37 | return {
38 | "required": {
39 | "model": ("MODEL",),
40 | "source_pos": ("CONDITIONING",),
41 | "source_neg": ("CONDITIONING",),
42 | "target_pos": ("CONDITIONING",),
43 | "target_neg": ("CONDITIONING",),
44 | "source_cfg": (
45 | "FLOAT",
46 | {"default": 2, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, "step": 0.01},
47 | ),
48 | "target_cfg": (
49 | "FLOAT",
50 | {"default": 4.5, "min": 0, "max": 0xFFFFFFFFFFFFFFFF, "step": 0.01},
51 | ),
52 | }
53 | }
54 |
55 | RETURN_TYPES = ("GUIDER",)
56 |
57 | FUNCTION = "get_guider"
58 | CATEGORY = "ltxtricks"
59 |
60 | def get_guider(
61 | self,
62 | model,
63 | source_pos,
64 | source_neg,
65 | target_pos,
66 | target_neg,
67 | source_cfg,
68 | target_cfg,
69 | ):
70 | guider = FlowEditGuider(model)
71 | guider.set_conds(
72 | source_positive=source_pos,
73 | source_negative=source_neg,
74 | target_positive=target_pos,
75 | target_negative=target_neg,
76 | )
77 | guider.set_cfgs(source=source_cfg, target=target_cfg)
78 | return (guider,)
79 |
80 |
81 | def get_flowedit_sample(skip_steps, refine_steps, seed):
82 | generator = torch.manual_seed(seed)
83 |
84 | @torch.no_grad()
85 | def flowedit_sample(
86 | model, x_init, sigmas, extra_args=None, callback=None, disable=None
87 | ):
88 | extra_args = {} if extra_args is None else extra_args
89 |
90 | model_options = extra_args.get("model_options", {})
91 | transformer_options = model_options.get("transformer_options", {})
92 | transformer_options = {**transformer_options}
93 | model_options["transformer_options"] = transformer_options
94 | extra_args["model_options"] = model_options
95 | denoise_mask = extra_args.get("denoise_mask", 1)
96 | if denoise_mask is None:
97 | denoise_mask = 1
98 | else:
99 | extra_args["denoise_mask"] = torch.ones_like(denoise_mask)
100 |
101 | source_extra_args = {
102 | **extra_args,
103 | "model_options": {
104 | "transformer_options": {**transformer_options, "latent_type": "source"}
105 | },
106 | }
107 |
108 | sigmas = sigmas[skip_steps:]
109 |
110 | x_tgt = x_init.clone()
111 | N = len(sigmas) - 1
112 | s_in = x_init.new_ones([x_init.shape[0]])
113 |
114 | for i in trange(N, disable=disable):
115 | sigma = sigmas[i]
116 | noise = torch.randn(x_init.shape, generator=generator).to(x_init.device)
117 |
118 | zt_src = (1 - sigma) * x_init + sigma * noise
119 |
120 | if i < N - refine_steps:
121 | zt_tgt = x_tgt + (zt_src - x_init) * denoise_mask
122 | transformer_options["latent_type"] = "source"
123 | source_extra_args["model_options"]["transformer_options"][
124 | "latent_type"
125 | ] = "source"
126 | vt_src = model(zt_src, sigma * s_in, **source_extra_args)
127 | else:
128 | if i == N - refine_steps:
129 | x_tgt = x_tgt + (zt_src - x_init) * denoise_mask
130 | zt_tgt = x_tgt
131 | vt_src = 0
132 |
133 | transformer_options["latent_type"] = "target"
134 | vt_tgt = model(zt_tgt, sigma * s_in, **extra_args)
135 |
136 | v_delta = (vt_tgt - vt_src) * denoise_mask
137 | x_tgt += (sigmas[i + 1] - sigmas[i]) * v_delta
138 |
139 | if callback is not None:
140 | callback(
141 | {
142 | "x": x_tgt,
143 | "denoised": x_tgt,
144 | "i": i + skip_steps,
145 | "sigma": sigmas[i],
146 | "sigma_hat": sigmas[i],
147 | }
148 | )
149 |
150 | return x_tgt
151 |
152 | return flowedit_sample
153 |
154 |
155 | class LTXFlowEditSamplerNode:
156 | @classmethod
157 | def INPUT_TYPES(s):
158 | return {
159 | "required": {
160 | "skip_steps": (
161 | "INT",
162 | {"default": 4, "min": 0, "max": 0xFFFFFFFFFFFFFFFF},
163 | ),
164 | "refine_steps": (
165 | "INT",
166 | {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF},
167 | ),
168 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}),
169 | },
170 | "optional": {},
171 | }
172 |
173 | RETURN_TYPES = ("SAMPLER",)
174 | FUNCTION = "build"
175 |
176 | CATEGORY = "ltxtricks"
177 |
178 | def build(self, skip_steps, refine_steps, seed):
179 | sampler = KSAMPLER(get_flowedit_sample(skip_steps, refine_steps, seed))
180 | return (sampler,)
181 |
--------------------------------------------------------------------------------
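Note: the core of flowedit_sample is one Euler step on a velocity difference: re-noise the source latent to the current sigma, evaluate the model once as "source" and once as "target", and move the target latent by (sigma_next - sigma) times the difference. A stripped-down sketch of that update with a dummy velocity function standing in for the wrapped model (skip/refine steps and the denoise mask omitted):

import torch

def dummy_velocity(z, sigma, latent_type):
    # Stand-in for the wrapped diffusion model; returns a fake velocity field.
    return -z * (0.5 if latent_type == "source" else 0.7)

torch.manual_seed(0)
x_src = torch.randn(1, 4, 8, 8)          # source latent (the video being edited)
x_tgt = x_src.clone()                    # running target latent
sigmas = torch.linspace(1.0, 0.0, 9)     # descending flow-matching schedule

for i in range(len(sigmas) - 1):
    sigma, sigma_next = sigmas[i], sigmas[i + 1]
    noise = torch.randn_like(x_src)
    zt_src = (1 - sigma) * x_src + sigma * noise            # re-noise the source
    zt_tgt = x_tgt + (zt_src - x_src)                       # shift target by the same noise offset
    v_src = dummy_velocity(zt_src, sigma, "source")
    v_tgt = dummy_velocity(zt_tgt, sigma, "target")
    x_tgt = x_tgt + (sigma_next - sigma) * (v_tgt - v_src)  # Euler step on the velocity delta

print(x_tgt.shape)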
/tricks/nodes/ltx_inverse_model_pred_nodes.py:
--------------------------------------------------------------------------------
1 | import comfy.latent_formats
2 | import comfy.model_sampling
3 | import comfy.sd
4 |
5 |
6 | class InverseCONST:
7 | def calculate_input(self, sigma, noise):
8 | return noise
9 |
10 | def calculate_denoised(self, sigma, model_output, model_input):
11 | sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
12 | return model_output
13 |
14 | def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
15 | return latent_image
16 |
17 | def inverse_noise_scaling(self, sigma, latent):
18 | return latent
19 |
20 |
21 | class LTXForwardModelSamplingPredNode:
22 | @classmethod
23 | def INPUT_TYPES(s):
24 | return {
25 | "required": {
26 | "model": ("MODEL",),
27 | }
28 | }
29 |
30 | RETURN_TYPES = ("MODEL",)
31 | FUNCTION = "patch"
32 |
33 | CATEGORY = "ltxtricks"
34 |
35 | def patch(self, model):
36 | m = model.clone()
37 |
38 | sampling_base = comfy.model_sampling.ModelSamplingFlux
39 | sampling_type = InverseCONST
40 |
41 | class ModelSamplingAdvanced(sampling_base, sampling_type):
42 | pass
43 |
44 | model_sampling = ModelSamplingAdvanced(model.model.model_config)
45 | model_sampling.set_parameters(shift=1.15)
46 | m.add_object_patch("model_sampling", model_sampling)
47 | return (m,)
48 |
49 |
50 | class ReverseCONST:
51 | def calculate_input(self, sigma, noise):
52 | return noise
53 |
54 | def calculate_denoised(self, sigma, model_output, model_input):
55 | sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
56 | return model_output # model_input - model_output * sigma
57 |
58 | def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
59 | return latent_image
60 |
61 | def inverse_noise_scaling(self, sigma, latent):
62 | return latent / (1.0 - sigma)
63 |
64 |
65 | class LTXReverseModelSamplingPredNode:
66 | @classmethod
67 | def INPUT_TYPES(s):
68 | return {
69 | "required": {
70 | "model": ("MODEL",),
71 | }
72 | }
73 |
74 | RETURN_TYPES = ("MODEL",)
75 | FUNCTION = "patch"
76 |
77 | CATEGORY = "ltxtricks"
78 |
79 | def patch(self, model):
80 | m = model.clone()
81 |
82 | sampling_base = comfy.model_sampling.ModelSamplingFlux
83 | sampling_type = ReverseCONST
84 |
85 | class ModelSamplingAdvanced(sampling_base, sampling_type):
86 | pass
87 |
88 | model_sampling = ModelSamplingAdvanced(model.model.model_config)
89 | model_sampling.set_parameters(shift=1.15)
90 | m.add_object_patch("model_sampling", model_sampling)
91 | return (m,)
92 |
--------------------------------------------------------------------------------
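Note: both patch nodes compose a new model_sampling object from a base schedule class and a prediction-type mixin via multiple inheritance, then attach it to the clone with add_object_patch. A toy sketch of that mixin composition (hypothetical classes, no ComfyUI dependency):

class ShiftedSchedule:
    # Stand-in for a schedule class such as ModelSamplingFlux: owns schedule parameters.
    def __init__(self, shift=1.0):
        self.shift = shift


class IdentityPrediction:
    # Stand-in for InverseCONST/ReverseCONST: maps model output to a denoised latent.
    def calculate_denoised(self, sigma, model_output, model_input):
        return model_output


class ModelSamplingAdvanced(ShiftedSchedule, IdentityPrediction):
    # The node builds exactly this kind of empty subclass; the MRO supplies schedule
    # behavior from one parent and prediction behavior from the other.
    pass


ms = ModelSamplingAdvanced(shift=1.15)
print(ms.shift, ms.calculate_denoised(None, "v_pred", None))   # 1.15 v_pred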
/tricks/nodes/ltx_pag_node.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import comfy.model_patcher
4 | import comfy.samplers
5 | import torch
6 | import torch.nn.functional as F
7 | from comfy.ldm.modules.attention import optimized_attention
8 | from einops import rearrange
9 |
10 | DEFAULT_PAG_LTX = {"layers": set([14])}
11 |
12 |
13 | def gaussian_blur_2d(img, kernel_size, sigma):
14 | height = img.shape[-1]
15 | kernel_size = min(kernel_size, height - (height % 2 - 1))
16 | ksize_half = (kernel_size - 1) * 0.5
17 |
18 | x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
19 |
20 | pdf = torch.exp(-0.5 * (x / sigma).pow(2))
21 |
22 | x_kernel = pdf / pdf.sum()
23 | x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
24 |
25 | kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
26 | kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
27 |
28 | padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
29 |
30 | img = F.pad(img, padding, mode="reflect")
31 | img = F.conv2d(img, kernel2d, groups=img.shape[-3])
32 |
33 | return img
34 |
35 |
36 | class LTXPerturbedAttentionNode:
37 | @classmethod
38 | def INPUT_TYPES(s):
39 | return {
40 | "required": {
41 | "model": ("MODEL",),
42 | "scale": (
43 | "FLOAT",
44 | {
45 | "default": 2.0,
46 | "min": 0.0,
47 | "max": 100.0,
48 | "step": 0.01,
49 | "round": 0.01,
50 | },
51 | ),
52 | "rescale": (
53 | "FLOAT",
54 | {
55 | "default": 0.5,
56 | "min": 0.0,
57 | "max": 100.0,
58 | "step": 0.01,
59 | "round": 0.01,
60 | },
61 | ),
62 | "cfg": (
63 | "FLOAT",
64 | {
65 | "default": 3.0,
66 | "min": 0.0,
67 | "max": 100.0,
68 | "step": 0.01,
69 | "round": 0.01,
70 | },
71 | ),
72 | },
73 | "optional": {
74 | "attn_override": ("ATTN_OVERRIDE",),
75 | # "attn_type": (["PAG", "SEG"],),
76 | },
77 | }
78 |
79 | RETURN_TYPES = ("MODEL",)
80 | FUNCTION = "patch"
81 |
82 | CATEGORY = "ltxtricks/attn"
83 |
84 | def patch(
85 | self, model, scale, rescale, cfg, attn_override=DEFAULT_PAG_LTX, attn_type="PAG"
86 | ):
87 | m = model.clone()
88 |
89 | def pag_fn(q, k, v, heads, attn_precision=None, transformer_options=None):
90 | return v
91 |
92 | def seg_fn(q, k, v, heads, attn_precision=None, transformer_options=None):
93 | _, sequence_length, _ = q.shape
94 | b, c, f, h, w = transformer_options["original_shape"]
95 |
96 | q = rearrange(q, "b (f h w) d -> b (f d) w h", h=h, w=w)
97 | kernel_size = math.ceil(6 * scale) + 1 - math.ceil(6 * scale) % 2
98 | q = gaussian_blur_2d(q, kernel_size, scale)
99 | q = rearrange(q, "b (f d) w h -> b (f h w) d", f=f)
100 | return optimized_attention(q, k, v, heads, attn_precision=attn_precision)
101 |
102 | def post_cfg_function(args):
103 | model = args["model"]
104 |
105 | cond_pred = args["cond_denoised"]
106 | uncond_pred = args["uncond_denoised"]
107 |
108 | len_conds = 1 if args.get("uncond", None) is None else 2
109 |
110 | cond = args["cond"]
111 | sigma = args["sigma"]
112 | model_options = args["model_options"].copy()
113 | x = args["input"]
114 |
115 | if scale == 0:
116 | if len_conds == 1:
117 | return cond_pred
118 | return uncond_pred + (cond_pred - uncond_pred)
119 |
120 | attn_fn = pag_fn if attn_type == "PAG" else seg_fn
121 | for block_idx in attn_override["layers"]:
122 | model_options = comfy.model_patcher.set_model_options_patch_replace(
123 | model_options, attn_fn, "layer", "self_attn", int(block_idx)
124 | )
125 |
126 | (perturbed,) = comfy.samplers.calc_cond_batch(
127 | model, [cond], x, sigma, model_options
128 | )
129 |
130 | # if len_conds == 1:
131 | # output = cond_pred + scale * (cond_pred - pag)
132 | # else:
133 | # output = cond_pred + (scale-1.0) * (cond_pred - uncond_pred) + scale * (cond_pred - pag)
134 |
135 | output = (
136 | uncond_pred
137 | + cfg * (cond_pred - uncond_pred)
138 | + scale * (cond_pred - perturbed)
139 | )
140 | if rescale > 0:
141 | factor = cond_pred.std() / output.std()
142 | factor = rescale * factor + (1 - rescale)
143 | output = output * factor
144 |
145 | return output
146 |
147 | m.set_model_sampler_post_cfg_function(post_cfg_function)
148 |
149 | return (m,)
150 |
--------------------------------------------------------------------------------
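Note: the post-CFG hook combines three predictions: regular CFG plus a perturbed-attention term, optionally rescaled toward the conditional prediction's standard deviation. A small numeric sketch of that combination, with random tensors standing in for the three model outputs:

import torch

torch.manual_seed(0)
cond = torch.randn(1, 4, 8, 8)        # cond_denoised
uncond = torch.randn(1, 4, 8, 8)      # uncond_denoised
perturbed = torch.randn(1, 4, 8, 8)   # prediction with self-attention replaced (PAG)

cfg, scale, rescale = 3.0, 2.0, 0.5

# Same combination as post_cfg_function in LTXPerturbedAttentionNode.
output = uncond + cfg * (cond - uncond) + scale * (cond - perturbed)

# Optional rescale pulls the output's std back toward the conditional prediction's std.
if rescale > 0:
    factor = cond.std() / output.std()
    factor = rescale * factor + (1 - rescale)
    output = output * factor

print(output.shape, float(output.std()))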
/tricks/nodes/modify_ltx_model_node.py:
--------------------------------------------------------------------------------
1 | from ..modules.ltx_model import inject_model
2 |
3 |
4 | class ModifyLTXModelNode:
5 | @classmethod
6 | def INPUT_TYPES(s):
7 | return {
8 | "required": {
9 | "model": ("MODEL",),
10 | }
11 | }
12 |
13 | RETURN_TYPES = ("MODEL",)
14 |
15 | CATEGORY = "ltxtricks"
16 | FUNCTION = "modify"
17 |
18 | def modify(self, model):
19 | model.model.diffusion_model = inject_model(model.model.diffusion_model)
20 | return (model,)
21 |
--------------------------------------------------------------------------------
/tricks/nodes/rectified_sampler_nodes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from comfy.samplers import KSAMPLER
3 | from tqdm import trange
4 |
5 |
6 | def generate_trend_values(steps, start_time, end_time, eta, eta_trend):
7 | eta_values = [0] * steps
8 |
9 | if eta_trend == "constant":
10 | for i in range(start_time, end_time):
11 | eta_values[i] = eta
12 | elif eta_trend == "linear_increase":
13 | for i in range(start_time, end_time):
14 | progress = (i - start_time) / (end_time - start_time - 1)
15 | eta_values[i] = eta * progress
16 | elif eta_trend == "linear_decrease":
17 | for i in range(start_time, end_time):
18 |             progress = 1 - (i - start_time) / span
19 | eta_values[i] = eta * progress
20 |
21 | return eta_values
22 |
23 |
24 | def get_sample_forward(
25 | gamma, start_step, end_step, gamma_trend, seed, attn_bank=None, order="first"
26 | ):
27 | # Controlled Forward ODE (Algorithm 1)
28 | generator = torch.Generator()
29 | generator.manual_seed(seed)
30 |
31 | @torch.no_grad()
32 | def sample_forward(model, y0, sigmas, extra_args=None, callback=None, disable=None):
33 | if attn_bank is not None:
34 | for block_idx in attn_bank["block_map"]:
35 | attn_bank["block_map"][block_idx].clear()
36 |
37 | extra_args = {} if extra_args is None else extra_args
38 | model_options = extra_args.get("model_options", {})
39 | model_options = {**model_options}
40 | transformer_options = model_options.get("transformer_options", {})
41 | transformer_options = {
42 | **transformer_options,
43 | "total_steps": len(sigmas) - 1,
44 | "sample_mode": "forward",
45 | "attn_bank": attn_bank,
46 | }
47 | model_options["transformer_options"] = transformer_options
48 | extra_args["model_options"] = model_options
49 |
50 | Y = y0.clone()
51 | y1 = torch.randn(Y.shape, generator=generator).to(y0.device)
52 | N = len(sigmas) - 1
53 | s_in = y0.new_ones([y0.shape[0]])
54 | gamma_values = generate_trend_values(
55 | N, start_step, end_step, gamma, gamma_trend
56 | )
57 | for i in trange(N, disable=disable):
58 | transformer_options["step"] = i
59 | sigma = sigmas[i]
60 | sigma_next = sigmas[i + 1]
61 | t_i = model.inner_model.inner_model.model_sampling.timestep(sigmas[i])
62 |
63 | conditional_vector_field = (y1 - Y) / (1 - t_i)
64 |
65 | transformer_options["pred_order"] = "first"
66 | pred = model(
67 | Y, s_in * sigmas[i], **extra_args
68 | ) # this implementation takes sigma instead of timestep
69 |
70 | if order == "second":
71 | transformer_options["pred_order"] = "second"
72 | img_mid = Y + (sigma_next - sigma) / 2 * pred
73 | sigma_mid = sigma + (sigma_next - sigma) / 2
74 | pred_mid = model(img_mid, s_in * sigma_mid, **extra_args)
75 |
76 | first_order = (pred_mid - pred) / ((sigma_next - sigma) / 2)
77 | pred = pred + gamma_values[i] * (conditional_vector_field - pred)
78 | # first_order = first_order + gamma_values[i] * (conditional_vector_field - first_order)
79 | Y = (
80 | Y
81 | + (sigma_next - sigma) * pred
82 | + 0.5 * (sigma_next - sigma) ** 2 * first_order
83 | )
84 | else:
85 | pred = pred + gamma_values[i] * (conditional_vector_field - pred)
86 | Y = Y + pred * (sigma_next - sigma)
87 |
88 | if callback is not None:
89 | callback(
90 | {"x": Y, "denoised": Y, "i": i, "sigma": sigma, "sigma_hat": sigma}
91 | )
92 |
93 | return Y
94 |
95 | return sample_forward
96 |
97 |
98 | def get_sample_reverse(
99 | latent_image, eta, start_time, end_time, eta_trend, attn_bank=None, order="first"
100 | ):
101 | # Controlled Reverse ODE (Algorithm 2)
102 | @torch.no_grad()
103 | def sample_reverse(model, y1, sigmas, extra_args=None, callback=None, disable=None):
104 | extra_args = {} if extra_args is None else extra_args
105 | model_options = extra_args.get("model_options", {})
106 | model_options = {**model_options}
107 | transformer_options = model_options.get("transformer_options", {})
108 | transformer_options = {
109 | **transformer_options,
110 | "total_steps": len(sigmas) - 1,
111 | "sample_mode": "reverse",
112 | "attn_bank": attn_bank,
113 | }
114 | model_options["transformer_options"] = transformer_options
115 | extra_args["model_options"] = model_options
116 |
117 | X = y1.clone()
118 | N = len(sigmas) - 1
119 | y0 = latent_image.clone().to(y1.device)
120 | s_in = y0.new_ones([y0.shape[0]])
121 | eta_values = generate_trend_values(N, start_time, end_time, eta, eta_trend)
122 | for i in trange(N, disable=disable):
123 | transformer_options["step"] = i
124 | t_i = 1 - model.inner_model.inner_model.model_sampling.timestep(sigmas[i])
125 | sigma = sigmas[i]
126 | sigma_prev = sigmas[i + 1]
127 |
128 | conditional_vector_field = (y0 - X) / (1 - t_i)
129 |
130 | transformer_options["pred_order"] = "first"
131 | pred = model(
132 | X, sigma * s_in, **extra_args
133 | ) # this implementation takes sigma instead of timestep
134 |
135 | if order == "second":
136 | transformer_options["pred_order"] = "second"
137 | img_mid = X + (sigma_prev - sigma) / 2 * pred
138 | sigma_mid = sigma + (sigma_prev - sigma) / 2
139 | pred_mid = model(img_mid, s_in * sigma_mid, **extra_args)
140 |
141 | first_order = (pred_mid - pred) / ((sigma_prev - sigma) / 2)
142 | pred = -pred + eta_values[i] * (conditional_vector_field + pred)
143 |
144 | first_order = -first_order + eta_values[i] * (
145 | conditional_vector_field + first_order
146 | )
147 | X = (
148 | X
149 | + (sigma - sigma_prev) * pred
150 | + 0.5 * (sigma - sigma_prev) ** 2 * first_order
151 | )
152 | else:
153 | controlled_vector_field = -pred + eta_values[i] * (
154 | conditional_vector_field + pred
155 | )
156 | X = X + controlled_vector_field * (sigma - sigma_prev)
157 |
158 | if callback is not None:
159 | callback(
160 | {
161 | "x": X,
162 | "denoised": X,
163 | "i": i,
164 | "sigma": sigmas[i],
165 | "sigma_hat": sigmas[i],
166 | }
167 | )
168 |
169 | return X
170 |
171 | return sample_reverse
172 |
173 |
174 | class LTXRFForwardODESamplerNode:
175 | @classmethod
176 | def INPUT_TYPES(s):
177 | return {
178 | "required": {
179 | "gamma": (
180 | "FLOAT",
181 | {"default": 0.5, "min": 0.0, "max": 100.0, "step": 0.01},
182 | ),
183 | "start_step": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
184 | "end_step": ("INT", {"default": 5, "min": 0, "max": 1000, "step": 1}),
185 | "gamma_trend": (["linear_decrease", "linear_increase", "constant"],),
186 | },
187 | "optional": {
188 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF}),
189 | "attn_bank": ("ATTN_BANK",),
190 | "order": (["first", "second"],),
191 | },
192 | }
193 |
194 | RETURN_TYPES = ("SAMPLER",)
195 | FUNCTION = "build"
196 |
197 | CATEGORY = "ltxtricks"
198 |
199 | def build(
200 | self,
201 | gamma,
202 | start_step,
203 | end_step,
204 | gamma_trend,
205 | seed=0,
206 | attn_bank=None,
207 | order="first",
208 | ):
209 | sampler = KSAMPLER(
210 | get_sample_forward(
211 | gamma,
212 | start_step,
213 | end_step,
214 | gamma_trend,
215 | seed,
216 | attn_bank=attn_bank,
217 | order=order,
218 | )
219 | )
220 |
221 | return (sampler,)
222 |
223 |
224 | class LTXRFReverseODESamplerNode:
225 | @classmethod
226 | def INPUT_TYPES(s):
227 | return {
228 | "required": {
229 | "model": ("MODEL",),
230 | "latent_image": ("LATENT",),
231 | "eta": (
232 | "FLOAT",
233 | {"default": 0.8, "min": 0.0, "max": 100.0, "step": 0.01},
234 | ),
235 | "start_step": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
236 | "end_step": ("INT", {"default": 15, "min": 0, "max": 1000, "step": 1}),
237 | },
238 | "optional": {
239 | "eta_trend": (["linear_decrease", "linear_increase", "constant"],),
240 | "attn_inj": ("ATTN_INJ",),
241 | "order": (["first", "second"],),
242 | },
243 | }
244 |
245 | RETURN_TYPES = ("SAMPLER",)
246 | FUNCTION = "build"
247 |
248 | CATEGORY = "ltxtricks"
249 |
250 | def build(
251 | self,
252 | model,
253 | latent_image,
254 | eta,
255 | start_step,
256 | end_step,
257 | eta_trend="constant",
258 | attn_inj=None,
259 | order="first",
260 | ):
261 | process_latent_in = model.get_model_object("process_latent_in")
262 | latent_image = process_latent_in(latent_image["samples"])
263 | sampler = KSAMPLER(
264 | get_sample_reverse(
265 | latent_image,
266 | eta,
267 | start_step,
268 | end_step,
269 | eta_trend,
270 | attn_bank=attn_inj,
271 | order=order,
272 | )
273 | )
274 |
275 | return (sampler,)
276 |
--------------------------------------------------------------------------------
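The two nodes above wrap the controlled forward/reverse ODEs as ComfyUI SAMPLER objects. A hedged usage sketch, assuming `model` is a loaded MODEL and `latent` is the LATENT dict (with a "samples" key) of the video to invert; both come from earlier nodes and are not defined here:

    # Forward pass: drift the clean latent toward noise, nudged toward the target noise
    # by gamma for steps [start_step, end_step). Reverse pass: integrate back toward
    # latent_image with the eta schedule.
    (forward_sampler,) = LTXRFForwardODESamplerNode().build(
        gamma=0.5, start_step=0, end_step=5, gamma_trend="linear_decrease", seed=42
    )
    (reverse_sampler,) = LTXRFReverseODESamplerNode().build(
        model, latent, eta=0.8, start_step=0, end_step=15, eta_trend="constant"
    )
    # Each SAMPLER is then fed to a custom-sampling node (e.g. SamplerCustom) with matching
    # sigmas; the forward result becomes the starting latent for the reverse pass.
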
/tricks/nodes/rf_edit_sampler_nodes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from comfy.samplers import KSAMPLER
3 | from tqdm import trange
4 |
5 |
6 | def get_sample_forward(attn_bank, save_steps, single_layers, double_layers):
7 | @torch.no_grad()
8 | def sample_forward(model, x, sigmas, extra_args=None, callback=None, disable=None):
9 | attn_bank.clear()
10 | attn_bank["save_steps"] = save_steps
11 |
12 | extra_args = {} if extra_args is None else extra_args
13 |
14 | model_options = extra_args.get("model_options", {})
15 | model_options = {**model_options}
16 | transformer_options = model_options.get("transformer_options", {})
17 | transformer_options = {**transformer_options}
18 | model_options["transformer_options"] = transformer_options
19 | extra_args["model_options"] = model_options
20 |
21 | N = len(sigmas) - 1
22 | s_in = x.new_ones([x.shape[0]])
23 | for i in trange(N, disable=disable):
24 | sigma = sigmas[i]
25 | sigma_next = sigmas[i + 1]
26 |
27 | if N - i - 1 < save_steps:
28 | attn_bank[N - i - 1] = {"first": {}, "mid": {}}
29 |
30 | transformer_options["rfedit"] = {
31 | "step": N - i - 1,
32 | "process": "forward" if N - i - 1 < save_steps else None,
33 | "pred": "first",
34 | "bank": attn_bank,
35 | "single_layers": single_layers,
36 | "double_layers": double_layers,
37 | }
38 |
39 | pred = model(x, s_in * sigma, **extra_args)
40 |
41 | transformer_options["rfedit"] = {
42 | "step": N - i - 1,
43 | "process": "forward" if N - i - 1 < save_steps else None,
44 | "pred": "mid",
45 | "bank": attn_bank,
46 | "single_layers": single_layers,
47 | "double_layers": double_layers,
48 | }
49 |
50 | img_mid = x + (sigma_next - sigma) / 2 * pred
51 | sigma_mid = sigma + (sigma_next - sigma) / 2
52 | pred_mid = model(img_mid, s_in * sigma_mid, **extra_args)
53 |
54 | first_order = (pred_mid - pred) / ((sigma_next - sigma) / 2)
55 | x = (
56 | x
57 | + (sigma_next - sigma) * pred
58 | + 0.5 * (sigma_next - sigma) ** 2 * first_order
59 | )
60 |
61 | if callback is not None:
62 | callback(
63 | {
64 | "x": x,
65 | "denoised": x,
66 | "i": i,
67 | "sigma": sigmas[i],
68 | "sigma_hat": sigmas[i],
69 | }
70 | )
71 |
72 | return x
73 |
74 | return sample_forward
75 |
76 |
77 | def get_sample_reverse(attn_bank, inject_steps, single_layers, double_layers):
78 | @torch.no_grad()
79 | def sample_reverse(model, x, sigmas, extra_args=None, callback=None, disable=None):
80 | if inject_steps > attn_bank["save_steps"]:
81 | raise ValueError(
82 | f'You must save at least as many steps as you want to inject. save_steps: {attn_bank["save_steps"]}, inject_steps: {inject_steps}'
83 | )
84 |
85 | extra_args = {} if extra_args is None else extra_args
86 |
87 | model_options = extra_args.get("model_options", {})
88 | model_options = {**model_options}
89 | transformer_options = model_options.get("transformer_options", {})
90 | transformer_options = {**transformer_options}
91 | model_options["transformer_options"] = transformer_options
92 | extra_args["model_options"] = model_options
93 |
94 | N = len(sigmas) - 1
95 | s_in = x.new_ones([x.shape[0]])
96 | for i in trange(N, disable=disable):
97 | sigma = sigmas[i]
98 | sigma_prev = sigmas[i + 1]
99 |
100 | transformer_options["rfedit"] = {
101 | "step": i,
102 | "process": "reverse" if i < inject_steps else None,
103 | "pred": "first",
104 | "bank": attn_bank,
105 | "single_layers": single_layers,
106 | "double_layers": double_layers,
107 | }
108 |
109 | pred = model(x, s_in * sigma, **extra_args)
110 |
111 | transformer_options["rfedit"] = {
112 | "step": i,
113 | "process": "reverse" if i < inject_steps else None,
114 | "pred": "mid",
115 | "bank": attn_bank,
116 | "single_layers": single_layers,
117 | "double_layers": double_layers,
118 | }
119 |
120 | img_mid = x + (sigma_prev - sigma) / 2 * pred
121 | sigma_mid = sigma + (sigma_prev - sigma) / 2
122 | pred_mid = model(img_mid, s_in * sigma_mid, **extra_args)
123 |
124 | first_order = (pred_mid - pred) / ((sigma_prev - sigma) / 2)
125 | x = (
126 | x
127 | + (sigma_prev - sigma) * pred
128 | + 0.5 * (sigma_prev - sigma) ** 2 * first_order
129 | )
130 |
131 | if callback is not None:
132 | callback(
133 | {
134 | "x": x,
135 | "denoised": x,
136 | "i": i,
137 | "sigma": sigmas[i],
138 | "sigma_hat": sigmas[i],
139 | }
140 | )
141 |
142 | return x
143 |
144 | return sample_reverse
145 |
146 |
147 | DEFAULT_SINGLE_LAYERS = {}
148 | for i in range(38):
149 | DEFAULT_SINGLE_LAYERS[f"{i}"] = i > 19
150 |
151 | DEFAULT_DOUBLE_LAYERS = {}
152 | for i in range(19):
153 | DEFAULT_DOUBLE_LAYERS[f"{i}"] = False
154 |
155 |
156 | class FlowEditForwardSamplerNode:
157 | @classmethod
158 | def INPUT_TYPES(s):
159 | return {
160 | "required": {
161 | "save_steps": (
162 | "INT",
163 | {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF},
164 | ),
165 | },
166 | "optional": {
167 | "single_layers": ("SINGLE_LAYERS",),
168 | "double_layers": ("DOUBLE_LAYERS",),
169 | },
170 | }
171 |
172 | RETURN_TYPES = ("SAMPLER", "ATTN_INJ")
173 | FUNCTION = "build"
174 |
175 |     CATEGORY = "ltxtricks"
176 |
177 | def build(
178 | self,
179 | save_steps,
180 | single_layers=DEFAULT_SINGLE_LAYERS,
181 | double_layers=DEFAULT_DOUBLE_LAYERS,
182 | ):
183 | attn_bank = {}
184 | sampler = KSAMPLER(
185 | get_sample_forward(attn_bank, save_steps, single_layers, double_layers)
186 | )
187 |
188 | return (sampler, attn_bank)
189 |
190 |
191 | # class FlowEditReverseSamplerNode:
192 | # @classmethod
193 | # def INPUT_TYPES(s):
194 | # return {
195 | # "required": {
196 | # "attn_inj": ("ATTN_INJ",),
197 | # "latent_image": ("LATENT",),
198 | # "eta": (
199 | # "FLOAT",
200 | # {"default": 0.8, "min": 0.0, "max": 100.0, "step": 0.01},
201 | # ),
202 | # "start_step": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
203 | # "end_step": ("INT", {"default": 5, "min": 0, "max": 1000, "step": 1}),
204 | # },
205 | # "optional": {},
206 | # }
207 | #
208 | # RETURN_TYPES = ("SAMPLER",)
209 | # FUNCTION = "build"
210 | #
211 | # CATEGORY = "fluxtapoz"
212 | #
213 | # def build(self, latent_image, eta, start_step, end_step):
214 | # sampler = KSAMPLER(
215 | # get_sample_reverse(attn_inj, inject_steps, single_layers, double_layers)
216 | # )
217 | # return (sampler,)
218 |
219 |
220 | def get_sample_reverse2(attn_bank, inject_steps, single_layers, double_layers):
221 | @torch.no_grad()
222 | def sample_reverse(model, x, sigmas, extra_args=None, callback=None, disable=None):
223 | if inject_steps > attn_bank["save_steps"]:
224 | raise ValueError(
225 | f'You must save at least as many steps as you want to inject. save_steps: {attn_bank["save_steps"]}, inject_steps: {inject_steps}'
226 | )
227 |
228 | extra_args = {} if extra_args is None else extra_args
229 |
230 | model_options = extra_args.get("model_options", {})
231 | model_options = {**model_options}
232 | transformer_options = model_options.get("transformer_options", {})
233 | transformer_options = {**transformer_options}
234 | model_options["transformer_options"] = transformer_options
235 | extra_args["model_options"] = model_options
236 |
237 | N = len(sigmas) - 1
238 | s_in = x.new_ones([x.shape[0]])
239 | for i in trange(N, disable=disable):
240 | sigma = sigmas[i]
241 | sigma_prev = sigmas[i + 1]
242 |
243 | transformer_options["rfedit"] = {
244 | "step": i,
245 | "process": "reverse" if i < inject_steps else None,
246 | "pred": "first",
247 | "bank": attn_bank,
248 | "single_layers": single_layers,
249 | "double_layers": double_layers,
250 | }
251 |
252 | pred = model(x, s_in * sigma, **extra_args)
253 |
254 | transformer_options["rfedit"] = {
255 | "step": i,
256 | "process": "reverse" if i < inject_steps else None,
257 | "pred": "mid",
258 | "bank": attn_bank,
259 | "single_layers": single_layers,
260 | "double_layers": double_layers,
261 | }
262 |
263 | img_mid = x + (sigma_prev - sigma) / 2 * pred
264 | sigma_mid = sigma + (sigma_prev - sigma) / 2
265 | pred_mid = model(img_mid, s_in * sigma_mid, **extra_args)
266 |
267 | first_order = (pred_mid - pred) / ((sigma_prev - sigma) / 2)
268 | x = (
269 | x
270 | + (sigma_prev - sigma) * pred
271 | + 0.5 * (sigma_prev - sigma) ** 2 * first_order
272 | )
273 |
274 | if callback is not None:
275 | callback(
276 | {
277 | "x": x,
278 | "denoised": x,
279 | "i": i,
280 | "sigma": sigmas[i],
281 | "sigma_hat": sigmas[i],
282 | }
283 | )
284 |
285 | return x
286 |
287 | return sample_reverse
288 |
289 |
290 | class FlowEdit2ReverseSamplerNode:
291 | @classmethod
292 | def INPUT_TYPES(s):
293 | return {
294 | "required": {
295 | "attn_inj": ("ATTN_INJ",),
296 | "inject_steps": (
297 | "INT",
298 | {"default": 0, "min": 0, "max": 1000, "step": 1},
299 | ),
300 | },
301 | "optional": {
302 | "single_layers": ("SINGLE_LAYERS",),
303 | "double_layers": ("DOUBLE_LAYERS",),
304 | },
305 | }
306 |
307 | RETURN_TYPES = ("SAMPLER",)
308 | FUNCTION = "build"
309 |
310 | CATEGORY = "ltxtricks"
311 |
312 | def build(
313 | self,
314 | attn_inj,
315 | inject_steps,
316 | single_layers=DEFAULT_SINGLE_LAYERS,
317 | double_layers=DEFAULT_DOUBLE_LAYERS,
318 | ):
319 | sampler = KSAMPLER(
320 | get_sample_reverse(attn_inj, inject_steps, single_layers, double_layers)
321 | )
322 | return (sampler,)
323 |
324 |
325 | class PrepareAttnBankNode:
326 | @classmethod
327 | def INPUT_TYPES(s):
328 | return {
329 | "required": {
330 | "latent": ("LATENT",),
331 | "attn_inj": ("ATTN_INJ",),
332 | }
333 | }
334 |
335 | RETURN_TYPES = ("LATENT", "ATTN_INJ")
336 | FUNCTION = "prepare"
337 |
338 | CATEGORY = "ltxtricks"
339 |
340 | def prepare(self, latent, attn_inj):
341 | # Hack to force order of operations in ComfyUI graph
342 | return (latent, attn_inj)
343 |
344 |
345 | class RFSingleBlocksOverrideNode:
346 | @classmethod
347 | def INPUT_TYPES(s):
348 | layers = {}
349 | for i in range(38):
350 | layers[f"{i}"] = ("BOOLEAN", {"default": i > 19})
351 | return {"required": layers}
352 |
353 | RETURN_TYPES = ("SINGLE_LAYERS",)
354 | FUNCTION = "build"
355 |
356 | CATEGORY = "ltxtricks"
357 |
358 | def build(self, *args, **kwargs):
359 | return (kwargs,)
360 |
361 |
362 | class RFDoubleBlocksOverrideNode:
363 | @classmethod
364 | def INPUT_TYPES(s):
365 | layers = {}
366 | for i in range(19):
367 | layers[f"{i}"] = ("BOOLEAN", {"default": False})
368 | return {"required": layers}
369 |
370 | RETURN_TYPES = ("DOUBLE_LAYERS",)
371 | FUNCTION = "build"
372 |
373 | CATEGORY = "ltxtricks"
374 |
375 | def build(self, *args, **kwargs):
376 | return (kwargs,)
377 |
--------------------------------------------------------------------------------
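A usage sketch for the RF-edit pair above (hypothetical wiring, not part of the file): the forward sampler stores self-attention from the last `save_steps` inversion steps into the returned bank, and the reverse sampler re-injects it during the first `inject_steps` steps; `inject_steps` must not exceed `save_steps`, otherwise the ValueError above is raised.

    (forward_sampler, attn_bank) = FlowEditForwardSamplerNode().build(save_steps=5)
    (reverse_sampler,) = FlowEdit2ReverseSamplerNode().build(attn_bank, inject_steps=5)
    # Run the forward sampler first (it fills attn_bank), then the reverse sampler.
    # single_layers/double_layers default to the DEFAULT_* maps defined above.
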
/tricks/utils/attn_bank.py:
--------------------------------------------------------------------------------
1 | class AttentionBank:
2 | def __init__(self, save_steps, block_map, inject_steps=None):
3 | self._data = {
4 | "save_steps": save_steps,
5 | "block_map": block_map,
6 | "inject_steps": inject_steps,
7 | }
8 |
9 | def __getitem__(self, key):
10 | return self._data[key]
11 |
12 | def __setitem__(self, key, value):
13 | self._data[key] = value
14 |
15 | def get(self, key, default=None):
16 | return self._data.get(key, default)
17 |
--------------------------------------------------------------------------------
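AttentionBank is a thin dict-like container; the samplers only rely on item access and .get(). A minimal sketch with illustrative block indices and step keys:

    # block_map values are per-block stores that the samplers clear and refill each run.
    bank = AttentionBank(save_steps=8, block_map={0: {}, 1: {}})
    bank["save_steps"]           # -> 8
    bank.get("inject_steps")     # -> None until set
    bank[3] = {"first": {}, "mid": {}}   # extra keys (e.g. per-step entries) are allowed
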
/tricks/utils/feta_enhance_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange
3 |
4 |
5 | def _feta_score(query_image, key_image, head_dim, num_frames, enhance_weight):
6 | scale = head_dim**-0.5
7 | query_image = query_image * scale
8 |     attn_temp = query_image @ key_image.transpose(-2, -1)
9 |     attn_temp = attn_temp.to(torch.float32)  # cast to float32 before softmax for numerical stability
10 | attn_temp = attn_temp.softmax(dim=-1)
11 |
12 | # Reshape to [batch_size * num_tokens, num_frames, num_frames]
13 | attn_temp = attn_temp.reshape(-1, num_frames, num_frames)
14 |
15 | # Create a mask for diagonal elements
16 | diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
17 | diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)
18 |
19 | # Zero out diagonal elements
20 | attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)
21 |
22 | # Calculate mean for each token's attention matrix
23 | # Number of off-diagonal elements per matrix is n*n - n
24 | num_off_diag = num_frames * num_frames - num_frames
25 | mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag
26 |
27 | enhance_scores = mean_scores.mean() * (num_frames + enhance_weight)
28 | enhance_scores = enhance_scores.clamp(min=1)
29 | return enhance_scores
30 |
31 |
32 | def get_feta_scores(img_q, img_k, num_heads, transformer_options):
33 | num_frames = transformer_options["original_shape"][2]
34 | _, ST, dim = img_q.shape
35 | head_dim = dim // num_heads
36 | spatial_dim = ST // num_frames
37 |
38 | query_image = rearrange(
39 | img_q,
40 | "B (T S) (N C) -> (B S) N T C",
41 | T=num_frames,
42 | S=spatial_dim,
43 | N=num_heads,
44 | C=head_dim,
45 | )
46 | key_image = rearrange(
47 | img_k,
48 | "B (T S) (N C) -> (B S) N T C",
49 | T=num_frames,
50 | S=spatial_dim,
51 | N=num_heads,
52 | C=head_dim,
53 | )
54 | weight = transformer_options.get("feta_weight", 0)
55 | return _feta_score(query_image, key_image, head_dim, num_frames, weight)
56 |
--------------------------------------------------------------------------------
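A shape-checked sketch of calling get_feta_scores with dummy tensors (all sizes illustrative): original_shape[2] supplies the frame count and feta_weight the enhancement knob; the result is a scalar tensor clamped to at least 1, used as a temporal-attention enhancement factor by the injected model.

    import torch

    B, T, S, heads, head_dim = 1, 8, 64, 4, 32        # batch, frames, spatial tokens, heads, dim per head
    img_q = torch.randn(B, T * S, heads * head_dim)
    img_k = torch.randn(B, T * S, heads * head_dim)
    opts = {"original_shape": (B, 128, T, 16, 16), "feta_weight": 4.0}

    score = get_feta_scores(img_q, img_k, heads, opts)  # scalar tensor, >= 1
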
/tricks/utils/latent_guide.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class LatentGuide(torch.nn.Module):
5 | def __init__(self, latent: torch.Tensor, index) -> None:
6 | super().__init__()
7 | self.index = index
8 | self.register_buffer("latent", latent)
9 |
--------------------------------------------------------------------------------
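LatentGuide only stores a latent and the index it should guide; registering the latent as a buffer (rather than a plain attribute) lets it follow the module across .to(device)/dtype moves and appear in state_dict(). A tiny sketch with an illustrative latent shape:

    import torch

    guide = LatentGuide(torch.randn(1, 128, 1, 16, 16), index=0)
    guide = guide.to("cpu")                 # the registered buffer moves with the module
    print(guide.index, guide.latent.shape)
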
/tricks/utils/module_utils.py:
--------------------------------------------------------------------------------
1 | def isinstance_str(x: object, cls_name: str):
2 | for _cls in x.__class__.__mro__:
3 | if _cls.__name__ == cls_name:
4 | return True
5 |
6 | return False
7 |
--------------------------------------------------------------------------------
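isinstance_str matches on class names anywhere in the MRO, which is useful when the concrete class cannot be imported directly (as with patched model wrappers) or when module reloads would defeat a real isinstance check. A small example with made-up class names:

    class BaseModel: ...
    class LTXVModel(BaseModel): ...

    isinstance_str(LTXVModel(), "LTXVModel")      # True
    isinstance_str(LTXVModel(), "BaseModel")      # True (matched via the MRO)
    isinstance_str(LTXVModel(), "SomethingElse")  # False
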
/tricks/utils/noise_utils.py:
--------------------------------------------------------------------------------
1 | def get_alphacumprod(sigma):
2 | return 1 / ((sigma * sigma) + 1)
3 |
4 |
5 | def add_noise(src_latent, noise, sigma):
6 | alphas_cumprod = get_alphacumprod(sigma)
7 |
8 | sqrt_alpha_prod = alphas_cumprod**0.5
9 | sqrt_alpha_prod = sqrt_alpha_prod.flatten()
10 | while len(sqrt_alpha_prod.shape) < len(src_latent.shape):
11 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
12 |
13 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod) ** 0.5
14 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
15 | while len(sqrt_one_minus_alpha_prod.shape) < len(src_latent.shape):
16 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
17 |
18 | noisy_samples = sqrt_alpha_prod * src_latent + sqrt_one_minus_alpha_prod * noise
19 | return noisy_samples
20 |
21 |
22 | def add_noise_flux(src_latent, noise, sigma):
23 | return sigma * noise + (1.0 - sigma) * src_latent
24 |
--------------------------------------------------------------------------------
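The two helpers implement different noising conventions: add_noise is the variance-preserving mix driven by alpha_cumprod = 1/(sigma^2 + 1), while add_noise_flux is the plain rectified-flow interpolation used by flow-matching models. A minimal sketch with dummy tensors (shapes illustrative; sigma must be a tensor for add_noise because it is flattened and broadcast):

    import torch

    latent = torch.randn(1, 128, 2, 16, 16)
    noise = torch.randn_like(latent)

    noisy_vp = add_noise(latent, noise, torch.tensor([0.7]))  # sqrt(a)*latent + sqrt(1-a)*noise
    noisy_rf = add_noise_flux(latent, noise, 0.7)             # 0.7*noise + 0.3*latent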