├── .gitattributes
├── .github
│   └── workflows
│       └── pylint.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── __init__.py
├── configs
│   ├── ltxv-13b-0.9.7-dev.yaml
│   ├── ltxv-13b-0.9.7-distilled.yaml
│   ├── ltxv-2b-0.9.1.yaml
│   ├── ltxv-2b-0.9.5.yaml
│   ├── ltxv-2b-0.9.6-dev.yaml
│   ├── ltxv-2b-0.9.6-distilled.yaml
│   └── ltxv-2b-0.9.yaml
├── docs
│   └── _static
│       ├── ltx-video_example_00001.gif
│       ├── ltx-video_example_00005.gif
│       ├── ltx-video_example_00006.gif
│       ├── ltx-video_example_00007.gif
│       ├── ltx-video_example_00010.gif
│       ├── ltx-video_example_00011.gif
│       ├── ltx-video_example_00013.gif
│       ├── ltx-video_example_00014.gif
│       ├── ltx-video_example_00015.gif
│       ├── ltx-video_i2v_example_00001.gif
│       ├── ltx-video_i2v_example_00002.gif
│       ├── ltx-video_i2v_example_00003.gif
│       ├── ltx-video_i2v_example_00004.gif
│       ├── ltx-video_i2v_example_00005.gif
│       ├── ltx-video_i2v_example_00006.gif
│       ├── ltx-video_i2v_example_00007.gif
│       ├── ltx-video_i2v_example_00008.gif
│       └── ltx-video_i2v_example_00009.gif
├── inference.py
├── ltx_video
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── autoencoders
│   │   │   ├── __init__.py
│   │   │   ├── causal_conv3d.py
│   │   │   ├── causal_video_autoencoder.py
│   │   │   ├── conv_nd_factory.py
│   │   │   ├── dual_conv3d.py
│   │   │   ├── latent_upsampler.py
│   │   │   ├── pixel_norm.py
│   │   │   ├── pixel_shuffle.py
│   │   │   ├── vae.py
│   │   │   ├── vae_encode.py
│   │   │   └── video_autoencoder.py
│   │   └── transformers
│   │       ├── __init__.py
│   │       ├── attention.py
│   │       ├── embeddings.py
│   │       ├── symmetric_patchifier.py
│   │       └── transformer3d.py
│   ├── pipelines
│   │   ├── __init__.py
│   │   ├── crf_compressor.py
│   │   └── pipeline_ltx_video.py
│   ├── schedulers
│   │   ├── __init__.py
│   │   └── rf.py
│   └── utils
│       ├── __init__.py
│       ├── diffusers_config_mapping.py
│       ├── prompt_enhance_utils.py
│       ├── skip_layer_strategy.py
│       └── torch_utils.py
├── pyproject.toml
└── tests
    ├── conftest.py
    ├── test_inference.py
    ├── test_scheduler.py
    ├── test_vae.py
    └── utils
        ├── .gitattributes
        ├── woman.jpeg
        └── woman.mp4
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.jpg filter=lfs diff=lfs merge=lfs -text
2 | *.jpeg filter=lfs diff=lfs merge=lfs -text
3 | *.png filter=lfs diff=lfs merge=lfs -text
4 | *.gif filter=lfs diff=lfs merge=lfs -text
5 |
--------------------------------------------------------------------------------
/.github/workflows/pylint.yml:
--------------------------------------------------------------------------------
1 | name: Ruff
2 |
3 | on: [push]
4 |
5 | jobs:
6 | build:
7 | runs-on: ubuntu-latest
8 | strategy:
9 | matrix:
10 | python-version: ["3.10"]
11 | steps:
12 | - name: Checkout repository and submodules
13 | uses: actions/checkout@v3
14 | - name: Set up Python ${{ matrix.python-version }}
15 | uses: actions/setup-python@v3
16 | with:
17 | python-version: ${{ matrix.python-version }}
18 | - name: Install dependencies
19 | run: |
20 | python -m pip install --upgrade pip
21 | pip install ruff==0.2.2 black==24.2.0
22 | - name: Analyzing the code with ruff
23 | run: |
24 |         ruff check $(git ls-files '*.py')
25 | - name: Verify that no Black changes are required
26 | run: |
27 | black --check $(git ls-files '*.py')
28 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | .idea/
163 |
164 | # From inference.py
165 | outputs/
166 | video_output_*.mp4
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/astral-sh/ruff-pre-commit
3 | # Ruff version.
4 | rev: v0.2.2
5 | hooks:
6 | # Run the linter.
7 | - id: ruff
8 | args: [--fix] # Automatically fix issues if possible.
9 | types: [python] # Ensure it only runs on .py files.
10 |
11 | - repo: https://github.com/psf/black
12 | rev: 24.2.0 # Specify the version of Black you want
13 | hooks:
14 | - id: black
15 | name: Black code formatter
16 | language_version: python3 # Use the Python version you're targeting (e.g., 3.10)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # LTX-Video
4 |
5 | This is the official repository for LTX-Video.
6 |
7 | [Website](https://www.lightricks.com/ltxv) |
8 | [Model](https://huggingface.co/Lightricks/LTX-Video) |
9 | [Demo](https://app.ltx.studio/motion-workspace?videoModel=ltxv-13b) |
10 | [Paper](https://arxiv.org/abs/2501.00103) |
11 | [Trainer](https://github.com/Lightricks/LTX-Video-Trainer) |
12 | [Discord](https://discord.gg/Mn8BRgUKKy)
13 |
14 |
15 |
16 | ## Table of Contents
17 |
18 | - [Introduction](#introduction)
19 | - [What's new](#news)
20 | - [Models & Workflows](#models--workflows)
21 | - [Quick Start Guide](#quick-start-guide)
22 | - [Use online](#online-inference)
23 | - [Run locally](#run-locally)
24 | - [Installation](#installation)
25 | - [Inference](#inference)
26 | - [ComfyUI Integration](#comfyui-integration)
27 | - [Diffusers Integration](#diffusers-integration)
28 | - [Model User Guide](#model-user-guide)
29 | - [Community Contribution](#community-contribution)
30 | - [Training](#⚡️-training)
31 | - [Join Us!](#🚀-join-us)
32 | - [Acknowledgement](#acknowledgement)
33 |
34 | # Introduction
35 |
36 | LTX-Video is the first DiT-based video generation model that can generate high-quality videos in *real-time*.
37 | It can generate 30 FPS videos at 1216×704 resolution, faster than it takes to watch them.
38 | The model is trained on a large-scale dataset of diverse videos and can generate high-resolution videos
39 | with realistic and diverse content.
40 |
41 | The model supports text-to-image, image-to-video, keyframe-based animation, video extension (both forward and backward), video-to-video transformations, and any combination of these features.
42 |
43 | ### Image to video examples
44 | | | | |
45 | |:---:|:---:|:---:|
46 | |  |  |  |
47 | |  |  |  |
48 | |  |  |  |
49 |
50 |
51 | ### Text to video examples
52 | | | | |
53 | |:---:|:---:|:---:|
54 | | <details><summary>A woman with long brown hair and light skin smiles at another woman...</summary>A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage.</details> | <details><summary>A clear, turquoise river flows through a rocky canyon...</summary>A clear, turquoise river flows through a rocky canyon, cascading over a small waterfall and forming a pool of water at the bottom. The river is the main focus of the scene, with its clear water reflecting the surrounding trees and rocks. The canyon walls are steep and rocky, with some vegetation growing on them. The trees are mostly pine trees, with their green needles contrasting with the brown and gray rocks. The overall tone of the scene is one of peace and tranquility.</details> | <details><summary>Two police officers in dark blue uniforms and matching hats...</summary>Two police officers in dark blue uniforms and matching hats enter a dimly lit room through a doorway on the left side of the frame. The first officer, with short brown hair and a mustache, steps inside first, followed by his partner, who has a shaved head and a goatee. Both officers have serious expressions and maintain a steady pace as they move deeper into the room. The camera remains stationary, capturing them from a slightly low angle as they enter. The room has exposed brick walls and a corrugated metal ceiling, with a barred window visible in the background. The lighting is low-key, casting shadows on the officers' faces and emphasizing the grim atmosphere. The scene appears to be from a film or television show.</details> |
55 | | <details><summary>A woman with light skin, wearing a blue jacket and a black hat...</summary>A woman with light skin, wearing a blue jacket and a black hat with a veil, looks down and to her right, then back up as she speaks; she has brown hair styled in an updo, light brown eyebrows, and is wearing a white collared shirt under her jacket; the camera remains stationary on her face as she speaks; the background is out of focus, but shows trees and people in period clothing; the scene is captured in real-life footage.</details> | <details><summary>A man in a dimly lit room talks on a vintage telephone...</summary>A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie.</details> | <details><summary>A prison guard unlocks and opens a cell door...</summary>A prison guard unlocks and opens a cell door to reveal a young man sitting at a table with a woman. The guard, wearing a dark blue uniform with a badge on his left chest, unlocks the cell door with a key held in his right hand and pulls it open; he has short brown hair, light skin, and a neutral expression. The young man, wearing a black and white striped shirt, sits at a table covered with a white tablecloth, facing the woman; he has short brown hair, light skin, and a neutral expression. The woman, wearing a dark blue shirt, sits opposite the young man, her face turned towards him; she has short blonde hair and light skin. The camera remains stationary, capturing the scene from a medium distance, positioned slightly to the right of the guard. The room is dimly lit, with a single light fixture illuminating the table and the two figures. The walls are made of large, grey concrete blocks, and a metal door is visible in the background. The scene is captured in real-life footage.</details> |
56 | | <details><summary>A man walks towards a window, looks out, and then turns around...</summary>A man walks towards a window, looks out, and then turns around. He has short, dark hair, dark skin, and is wearing a brown coat over a red and gray scarf. He walks from left to right towards a window, his gaze fixed on something outside. The camera follows him from behind at a medium distance. The room is brightly lit, with white walls and a large window covered by a white curtain. As he approaches the window, he turns his head slightly to the left, then back to the right. He then turns his entire body to the right, facing the window. The camera remains stationary as he stands in front of the window. The scene is captured in real-life footage.</details> | <details><summary>The camera pans across a cityscape of tall buildings...</summary>The camera pans across a cityscape of tall buildings with a circular building in the center. The camera moves from left to right, showing the tops of the buildings and the circular building in the center. The buildings are various shades of gray and white, and the circular building has a green roof. The camera angle is high, looking down at the city. The lighting is bright, with the sun shining from the upper left, casting shadows from the buildings. The scene is computer-generated imagery.</details> | <details><summary>A man in a suit enters a room and speaks to two women...</summary>A man in a suit enters a room and speaks to two women sitting on a couch. The man, wearing a dark suit with a gold tie, enters the room from the left and walks towards the center of the frame. He has short gray hair, light skin, and a serious expression. He places his right hand on the back of a chair as he approaches the couch. Two women are seated on a light-colored couch in the background. The woman on the left wears a light blue sweater and has short blonde hair. The woman on the right wears a white sweater and has short blonde hair. The camera remains stationary, focusing on the man as he enters the room. The room is brightly lit, with warm tones reflecting off the walls and furniture. The scene appears to be from a film or television show.</details> |
57 |
58 | # News
59 |
60 | ## May, 14th, 2025: New distilled model 13B v0.9.7:
61 | - Release a new 13B distilled model [ltxv-13b-0.9.7-distilled](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled.safetensors)
62 |   * Amazing for iterative work - generates HD videos in 10 seconds, with a low-res preview after just 3 seconds (on H100)!
63 |   * Does not require classifier-free guidance or spatio-temporal guidance.
64 |   * Supports sampling with 8 (recommended) or fewer diffusion steps.
65 | * Also released a LoRA version of the distilled model, [ltxv-13b-0.9.7-distilled-lora128](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled-lora128.safetensors)
66 | * Requires only 1GB of VRAM
67 | * Can be used with the full 13B model for fast inference
68 | - Release a new quantized distilled model [ltxv-13b-0.9.7-distilled-fp8](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled-fp8.safetensors) for *real-time* generation (on H100) with even less VRAM (Supported in the [official ComfyUI workflow](https://github.com/Lightricks/ComfyUI-LTXVideo/))
69 |
70 | ## May, 5th, 2025: New model 13B v0.9.7:
71 | - Release a new 13B model [ltxv-13b-0.9.7-dev](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev.safetensors)
72 | - Release a new quantized model [ltxv-13b-0.9.7-dev-fp8](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev-fp8.safetensors) for faster inference with less VRAM (Supported in the [official ComfyUI workflow](https://github.com/Lightricks/ComfyUI-LTXVideo/))
73 | - Release new upscalers:
74 | * [ltxv-temporal-upscaler-0.9.7](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-temporal-upscaler-0.9.7.safetensors)
75 | * [ltxv-spatial-upscaler-0.9.7](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-spatial-upscaler-0.9.7.safetensors)
76 | - Breakthrough prompt adherence and physical understanding.
77 | - New pipeline for multi-scale video rendering for fast, high-quality results
78 |
79 |
80 | ## April, 15th, 2025: New checkpoints v0.9.6:
81 | - Release a new checkpoint [ltxv-2b-0.9.6-dev-04-25](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-2b-0.9.6-dev-04-25.safetensors) with improved quality
82 | - Release a new distilled model [ltxv-2b-0.9.6-distilled-04-25](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-2b-0.9.6-distilled-04-25.safetensors)
83 |   * 15x faster inference than the non-distilled model.
84 |   * Does not require classifier-free guidance or spatio-temporal guidance.
85 |   * Supports sampling with 8 (recommended) or fewer diffusion steps.
86 | - Improved prompt adherence, motion quality and fine details.
87 | - New default resolution and FPS: 1216 × 704 pixels at 30 FPS
88 | * Still real time on H100 with the distilled model.
89 | * Other resolutions and FPS are still supported.
90 | - Support stochastic inference (can improve visual quality when using the distilled model)
91 |
92 | ## March, 5th, 2025: New checkpoint v0.9.5
93 | - New license for commercial use ([OpenRail-M](https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.5.license.txt))
94 | - Release a new checkpoint v0.9.5 with improved quality
95 | - Support keyframes and video extension
96 | - Support higher resolutions
97 | - Improved prompt understanding
98 | - Improved VAE
99 | - New online web app in [LTX-Studio](https://app.ltx.studio/ltx-video)
100 | - Automatic prompt enhancement
101 |
102 | ## February, 20th, 2025: More inference options
103 | - Improve STG (Spatiotemporal Guidance) for LTX-Video
104 | - Support MPS on macOS with PyTorch 2.3.0
105 | - Add support for 8-bit model, LTX-VideoQ8
106 | - Add TeaCache for LTX-Video
107 | - Add [ComfyUI-LTXTricks](#comfyui-integration)
108 | - Add Diffusion-Pipe
109 |
110 | ## December 31st, 2024: Research paper
111 | - Release the [research paper](https://arxiv.org/abs/2501.00103)
112 |
113 | ## December 20th, 2024: New checkpoint v0.9.1
114 | - Release a new checkpoint v0.9.1 with improved quality
115 | - Support for STG / PAG
116 | - Support loading checkpoints of LTX-Video in Diffusers format (conversion is done on-the-fly)
117 | - Support offloading unused parts to CPU
118 | - Support the new timestep-conditioned VAE decoder
119 | - Reference contributions from the community in the readme file
120 | - Relax transformers dependency
121 |
122 | ## November 21st, 2024: Initial release v0.9.0
123 | - Initial release of LTX-Video
124 | - Support text-to-video and image-to-video generation
125 |
126 |
127 | # Models & Workflows
128 |
129 | | Name | Notes | inference.py config | ComfyUI workflow (Recommended) |
130 | |-------------------------|--------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------|
131 | | ltxv-13b-0.9.7-dev | Highest quality, requires more VRAM | [ltxv-13b-0.9.7-dev.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.7-dev.yaml) | [ltxv-13b-i2v-base.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ltxv-13b-i2v-base.json) |
132 | | [ltxv-13b-0.9.7-mix](https://app.ltx.studio/motion-workspace?videoModel=ltxv-13b) | Mix ltxv-13b-dev and ltxv-13b-distilled in the same multi-scale rendering workflow for balanced speed-quality | N/A | [ltxv-13b-i2v-mixed-multiscale.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ltxv-13b-i2v-mixed-multiscale.json) |
133 | | [ltxv-13b-0.9.7-distilled](https://app.ltx.studio/motion-workspace?videoModel=ltxv) | Faster, less VRAM usage, slight quality reduction compared to 13b. Ideal for rapid iterations | [ltxv-13b-0.9.7-distilled.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.7-distilled.yaml) | [ltxv-13b-dist-i2v-base.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/13b-distilled/ltxv-13b-dist-i2v-base.json) |
134 | | [ltxv-13b-0.9.7-distilled-lora128](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled-lora128.safetensors) | LoRA to make ltxv-13b-dev behave like the distilled model | N/A | N/A |
135 | | ltxv-13b-0.9.7-fp8 | Quantized version of ltxv-13b | Coming soon | [ltxv-13b-i2v-base-fp8.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ltxv-13b-i2v-base-fp8.json) |
136 | | ltxv-13b-0.9.7-distilled-fp8 | Quantized version of ltxv-13b-distilled | Coming soon | [ltxv-13b-dist-i2v-base-fp8.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/13b-distilled/ltxv-13b-dist-i2v-base-fp8.json) |
137 | | ltxv-2b-0.9.6 | Good quality, lower VRAM requirement than ltxv-13b | [ltxv-2b-0.9.6-dev.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.6-dev.yaml) | [ltxvideo-i2v.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/low_level/ltxvideo-i2v.json) |
138 | | ltxv-2b-0.9.6-distilled | 15× faster, real-time capable, fewer steps needed, no STG/CFG required | [ltxv-2b-0.9.6-distilled.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.6-distilled.yaml) | [ltxvideo-i2v-distilled.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/low_level/ltxvideo-i2v-distilled.json) |
139 |
140 |
141 | # Quick Start Guide
142 |
143 | ## Online inference
144 | The model is accessible right away via the following links:
145 | - [LTX-Studio image-to-video (13B-mix)](https://app.ltx.studio/motion-workspace?videoModel=ltxv-13b)
146 | - [LTX-Studio image-to-video (13B distilled)](https://app.ltx.studio/motion-workspace?videoModel=ltxv)
147 | - [Fal.ai text-to-video](https://fal.ai/models/fal-ai/ltx-video)
148 | - [Fal.ai image-to-video](https://fal.ai/models/fal-ai/ltx-video/image-to-video)
149 | - [Replicate text-to-video and image-to-video](https://replicate.com/lightricks/ltx-video)
150 |
151 | ## Run locally
152 |
153 | ### Installation
154 | The codebase was tested with Python 3.10.5, CUDA version 12.2, and supports PyTorch >= 2.1.2.
155 | On macOS, MPS was tested with PyTorch 2.3.0; PyTorch == 2.3 or >= 2.6 should also be supported.
156 |
157 | ```bash
158 | git clone https://github.com/Lightricks/LTX-Video.git
159 | cd LTX-Video
160 |
161 | # create env
162 | python -m venv env
163 | source env/bin/activate
164 | python -m pip install -e .\[inference-script\]
165 | ```
166 |
167 | ### Inference
168 |
169 | 📝 **Note:** For best results, we recommend using our [ComfyUI](#comfyui-integration) workflow. We’re working on updating the inference.py script to match the high quality and output fidelity of ComfyUI.
170 |
171 | To use our model, please follow the inference code in [inference.py](./inference.py):
172 |
173 | #### For text-to-video generation:
174 |
175 | ```bash
176 | python inference.py --prompt "PROMPT" --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.7-distilled.yaml
177 | ```
178 |
179 | #### For image-to-video generation:
180 |
181 | ```bash
182 | python inference.py --prompt "PROMPT" --conditioning_media_paths IMAGE_PATH --conditioning_start_frames 0 --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.7-distilled.yaml
183 | ```
184 |
185 | #### Extending a video:
186 |
187 | 📝 **Note:** Input video segments must contain a multiple of 8 frames plus 1 (e.g., 9, 17, 25, etc.), and the target frame number should be a multiple of 8.
188 |
189 |
190 | ```bash
191 | python inference.py --prompt "PROMPT" --conditioning_media_paths VIDEO_PATH --conditioning_start_frames START_FRAME --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.7-distilled.yaml
192 | ```
193 |
194 | #### For video generation with multiple conditions:
195 |
196 | You can now generate a video conditioned on a set of images and/or short video segments.
197 | Simply provide a list of paths to the images or video segments you want to condition on, along with their target frame numbers in the generated video. You can also specify the conditioning strength for each item (default: 1.0).
198 |
199 | ```bash
200 | python inference.py --prompt "PROMPT" --conditioning_media_paths IMAGE_OR_VIDEO_PATH_1 IMAGE_OR_VIDEO_PATH_2 --conditioning_start_frames TARGET_FRAME_1 TARGET_FRAME_2 --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.7-distilled.yaml
201 | ```
202 |
203 | ## ComfyUI Integration
204 | To use our model with ComfyUI, please follow the instructions at [https://github.com/Lightricks/ComfyUI-LTXVideo/](https://github.com/Lightricks/ComfyUI-LTXVideo/).
205 |
206 | ## Diffusers Integration
207 | To use our model with the Diffusers Python library, check out the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).
208 |
209 | Diffusers also supports an 8-bit version of LTX-Video; [see details below](#ltx-videoq8).
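
For reference, here is a minimal text-to-video sketch with the Diffusers pipeline. It assumes a recent `diffusers` release with LTX-Video support; the prompt, resolution, and step count are illustrative only, and the official documentation linked above remains the authoritative reference.

```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

# Load the LTX-Video text-to-video pipeline from the Hugging Face Hub in bfloat16.
pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Width/height should be divisible by 32 and num_frames should be a multiple of 8 plus 1.
video = pipe(
    prompt="A clear, turquoise river flows through a rocky canyon",
    width=704,
    height=480,
    num_frames=121,
    num_inference_steps=40,
).frames[0]

export_to_video(video, "output.mp4", fps=24)
```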
210 |
211 | # Model User Guide
212 |
213 | ## 📝 Prompt Engineering
214 |
215 | When writing prompts, focus on detailed, chronological descriptions of actions and scenes. Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. Start directly with the action, and keep descriptions literal and precise. Think like a cinematographer describing a shot list. Keep within 200 words. For best results, build your prompts using this structure:
216 |
217 | * Start with main action in a single sentence
218 | * Add specific details about movements and gestures
219 | * Describe character/object appearances precisely
220 | * Include background and environment details
221 | * Specify camera angles and movements
222 | * Describe lighting and colors
223 | * Note any changes or sudden events
224 | * See [examples](#introduction) for more inspiration.
225 |
226 | ### Automatic Prompt Enhancement
227 | When using `inference.py`, short prompts (below `prompt_enhancement_words_threshold` words) are automatically enhanced by a language model. This is supported with text-to-video and image-to-video (first-frame conditioning).
228 |
229 | When using `LTXVideoPipeline` directly, you can enable prompt enhancement by setting `enhance_prompt=True`.
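
The sketch below is illustrative only: it mirrors the words-threshold logic described above, using the `prompt_enhancement_words_threshold` value set in the bundled configs, and notes (in a comment) where `enhance_prompt=True` would be passed when calling the pipeline directly. See [inference.py](./inference.py) for how the pipeline and enhancer models are actually constructed.

```python
prompt = "A clear, turquoise river flows through a rocky canyon"

# Threshold taken from configs/ltxv-13b-0.9.7-*.yaml (prompt_enhancement_words_threshold: 120).
prompt_enhancement_words_threshold = 120

if len(prompt.split()) < prompt_enhancement_words_threshold:
    print("Short prompt: inference.py would run the prompt enhancer on it.")

# When calling LTXVideoPipeline directly, enhancement is opt-in, e.g.:
#   video = pipeline(prompt=prompt, enhance_prompt=True, ...)
```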
230 |
231 | ## 🎮 Parameter Guide
232 |
233 | * Resolution Preset: Higher resolutions for detailed scenes, lower for faster generation and simpler scenes. The model works on resolutions that are divisible by 32 and frame counts that are a multiple of 8 plus 1 (e.g., 257). If the resolution or frame count does not meet these constraints, the input is padded with -1 and then cropped to the desired resolution and number of frames. The model works best at resolutions under 720 × 1280 and frame counts below 257 (see the sketch after this list).
234 | * Seed: Save seed values to recreate specific styles or compositions you like
235 | * Guidance Scale: 3-3.5 are the recommended values
236 | * Inference Steps: More steps (40+) for quality, fewer steps (20-30) for speed
237 |
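As a convenience, here is a small illustrative helper (not part of the repository) that snaps requested dimensions to values satisfying the constraints above:

```python
def snap_dimensions(height: int, width: int, num_frames: int) -> tuple[int, int, int]:
    """Round height/width down to multiples of 32 and the frame count down to 8*k + 1."""
    height = (height // 32) * 32
    width = (width // 32) * 32
    num_frames = ((num_frames - 1) // 8) * 8 + 1
    return height, width, num_frames

print(snap_dimensions(720, 1280, 260))  # -> (704, 1280, 257)
```
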
238 | 📝 For advanced parameter usage, please see `python inference.py --help`.
239 |
240 | ## Community Contribution
241 |
242 | ### ComfyUI-LTXTricks 🛠️
243 |
244 | A community project providing additional nodes for enhanced control over the LTX Video model. It includes implementations of advanced techniques like RF-Inversion, RF-Edit, FlowEdit, and more. These nodes enable workflows such as Image and Video to Video (I+V2V), enhanced sampling via Spatiotemporal Skip Guidance (STG), and interpolation with precise frame settings.
245 |
246 | - **Repository:** [ComfyUI-LTXTricks](https://github.com/logtd/ComfyUI-LTXTricks)
247 | - **Features:**
248 | - 🔄 **RF-Inversion:** Implements [RF-Inversion](https://rf-inversion.github.io/) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_inversion.json).
249 | - ✂️ **RF-Edit:** Implements [RF-Solver-Edit](https://github.com/wangjiangshan0725/RF-Solver-Edit) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_rf_edit.json).
250 | - 🌊 **FlowEdit:** Implements [FlowEdit](https://github.com/fallenshock/FlowEdit) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_flow_edit.json).
251 | - 🎥 **I+V2V:** Enables Video to Video with a reference image. [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_iv2v.json).
252 | - ✨ **Enhance:** Partial implementation of [STGuidance](https://junhahyung.github.io/STGuidance/). [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltxv_stg.json).
253 | - 🖼️ **Interpolation and Frame Setting:** Nodes for precise control of latents per frame. [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_interpolation.json).
254 |
255 |
256 | ### LTX-VideoQ8 🎱
257 |
258 | **LTX-VideoQ8** is an 8-bit optimized version of [LTX-Video](https://github.com/Lightricks/LTX-Video), designed for faster performance on NVIDIA ADA GPUs.
259 |
260 | - **Repository:** [LTX-VideoQ8](https://github.com/KONAKONA666/LTX-Video)
261 | - **Features:**
262 | - 🚀 Up to 3X speed-up with no accuracy loss
263 | - 🎥 Generate 720x480x121 videos in under a minute on RTX 4060 (8GB VRAM)
264 | - 🛠️ Fine-tune 2B transformer models with precalculated latents
265 | - **Community Discussion:** [Reddit Thread](https://www.reddit.com/r/StableDiffusion/comments/1h79ks2/fast_ltx_video_on_rtx_4060_and_other_ada_gpus/)
266 | - **Diffusers integration:** A diffusers integration for the 8-bit model is already out! [Details here](https://github.com/sayakpaul/q8-ltx-video)
267 |
268 |
269 | ### TeaCache for LTX-Video 🍵
270 |
271 | **TeaCache** is a training-free caching approach that leverages timestep differences across model outputs to accelerate LTX-Video inference by up to 2x without significant visual quality degradation.
272 |
273 | - **Repository:** [TeaCache4LTX-Video](https://github.com/ali-vilab/TeaCache/tree/main/TeaCache4LTX-Video)
274 | - **Features:**
275 | - 🚀 Speeds up LTX-Video inference.
276 | - 📊 Adjustable trade-offs between speed (up to 2x) and visual quality using configurable parameters.
277 | - 🛠️ No retraining required: Works directly with existing models.
278 |
279 | ### Your Contribution
280 |
281 | ...is welcome! If you have a project or tool that integrates with LTX-Video,
282 | please let us know by opening an issue or pull request.
283 |
284 | # ⚡️ Training
285 |
286 | We provide an open-source repository for fine-tuning the LTX-Video model: [LTX-Video-Trainer](https://github.com/Lightricks/LTX-Video-Trainer).
287 | This repository supports both the 2B and 13B model variants, enabling full fine-tuning as well as LoRA (Low-Rank Adaptation) fine-tuning for more efficient training.
288 |
289 | Explore the repository to customize the model for your specific use cases!
290 | More information and training instructions can be found in the [README](https://github.com/Lightricks/LTX-Video-Trainer/blob/main/README.md).
291 |
292 |
293 | # 🚀 Join Us
294 |
295 | Want to work on cutting-edge AI research and make a real impact on millions of users worldwide?
296 |
297 | At **Lightricks**, an AI-first company, we're revolutionizing how visual content is created.
298 |
299 | If you are passionate about AI, computer vision, and video generation, we would love to hear from you!
300 |
301 | Please visit our [careers page](https://careers.lightricks.com/careers?query=&office=all&department=R%26D) for more information.
302 |
303 | # Acknowledgement
304 |
305 | We are grateful to the following awesome projects, which we drew upon when implementing LTX-Video:
306 | * [DiT](https://github.com/facebookresearch/DiT) and [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha): vision transformers for image generation.
307 |
308 |
309 | ## Citation
310 |
311 | 📄 Our tech report is out! If you find our work helpful, please ⭐️ star the repository and cite our paper.
312 |
313 | ```
314 | @article{HaCohen2024LTXVideo,
315 | title={LTX-Video: Realtime Video Latent Diffusion},
316 | author={HaCohen, Yoav and Chiprut, Nisan and Brazowski, Benny and Shalem, Daniel and Moshe, Dudu and Richardson, Eitan and Levin, Eran and Shiran, Guy and Zabari, Nir and Gordon, Ori and Panet, Poriya and Weissbuch, Sapir and Kulikov, Victor and Bitterman, Yaki and Melumian, Zeev and Bibi, Ofir},
317 | journal={arXiv preprint arXiv:2501.00103},
318 | year={2024}
319 | }
320 | ```
321 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightricks/LTX-Video/34625c3a1fb4e9a7d2f091938f6989816343e270/__init__.py
--------------------------------------------------------------------------------
/configs/ltxv-13b-0.9.7-dev.yaml:
--------------------------------------------------------------------------------
1 | pipeline_type: multi-scale
2 | checkpoint_path: "ltxv-13b-0.9.7-dev.safetensors"
3 | downscale_factor: 0.6666666
4 | spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors"
5 | stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
6 | decode_timestep: 0.05
7 | decode_noise_scale: 0.025
8 | text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
9 | precision: "bfloat16"
10 | sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
11 | prompt_enhancement_words_threshold: 120
12 | prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
13 | prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
14 | stochastic_sampling: false
15 |
16 | first_pass:
17 | guidance_scale: [1, 1, 6, 8, 6, 1, 1]
18 | stg_scale: [0, 0, 4, 4, 4, 2, 1]
19 | rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
20 | guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
21 | skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
22 | num_inference_steps: 30
23 | skip_final_inference_steps: 3
24 | cfg_star_rescale: true
25 |
26 | second_pass:
27 | guidance_scale: [1]
28 | stg_scale: [1]
29 | rescaling_scale: [1]
30 | guidance_timesteps: [1.0]
31 | skip_block_list: [27]
32 | num_inference_steps: 30
33 | skip_initial_inference_steps: 17
34 | cfg_star_rescale: true
--------------------------------------------------------------------------------
/configs/ltxv-13b-0.9.7-distilled.yaml:
--------------------------------------------------------------------------------
1 | pipeline_type: multi-scale
2 | checkpoint_path: "ltxv-13b-0.9.7-distilled.safetensors"
3 | downscale_factor: 0.6666666
4 | spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors"
5 | stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
6 | decode_timestep: 0.05
7 | decode_noise_scale: 0.025
8 | text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
9 | precision: "bfloat16"
10 | sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
11 | prompt_enhancement_words_threshold: 120
12 | prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
13 | prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
14 | stochastic_sampling: false
15 |
16 | first_pass:
17 | timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
18 | guidance_scale: 1
19 | stg_scale: 0
20 | rescaling_scale: 1
21 | skip_block_list: [42]
22 |
23 | second_pass:
24 | timesteps: [0.9094, 0.7250, 0.4219]
25 | guidance_scale: 1
26 | stg_scale: 0
27 | rescaling_scale: 1
28 | skip_block_list: [42]
29 |
--------------------------------------------------------------------------------
/configs/ltxv-2b-0.9.1.yaml:
--------------------------------------------------------------------------------
1 | pipeline_type: base
2 | checkpoint_path: "ltx-video-2b-v0.9.1.safetensors"
3 | guidance_scale: 3
4 | stg_scale: 1
5 | rescaling_scale: 0.7
6 | skip_block_list: [19]
7 | num_inference_steps: 40
8 | stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
9 | decode_timestep: 0.05
10 | decode_noise_scale: 0.025
11 | text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
12 | precision: "bfloat16"
13 | sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
14 | prompt_enhancement_words_threshold: 120
15 | prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
16 | prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
17 | stochastic_sampling: false
--------------------------------------------------------------------------------
/configs/ltxv-2b-0.9.5.yaml:
--------------------------------------------------------------------------------
1 | pipeline_type: base
2 | checkpoint_path: "ltx-video-2b-v0.9.5.safetensors"
3 | guidance_scale: 3
4 | stg_scale: 1
5 | rescaling_scale: 0.7
6 | skip_block_list: [19]
7 | num_inference_steps: 40
8 | stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
9 | decode_timestep: 0.05
10 | decode_noise_scale: 0.025
11 | text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
12 | precision: "bfloat16"
13 | sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
14 | prompt_enhancement_words_threshold: 120
15 | prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
16 | prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
17 | stochastic_sampling: false
--------------------------------------------------------------------------------
/configs/ltxv-2b-0.9.6-dev.yaml:
--------------------------------------------------------------------------------
1 | pipeline_type: base
2 | checkpoint_path: "ltxv-2b-0.9.6-dev-04-25.safetensors"
3 | guidance_scale: 3
4 | stg_scale: 1
5 | rescaling_scale: 0.7
6 | skip_block_list: [19]
7 | num_inference_steps: 40
8 | stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
9 | decode_timestep: 0.05
10 | decode_noise_scale: 0.025
11 | text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
12 | precision: "bfloat16"
13 | sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
14 | prompt_enhancement_words_threshold: 120
15 | prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
16 | prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
17 | stochastic_sampling: false
--------------------------------------------------------------------------------
/configs/ltxv-2b-0.9.6-distilled.yaml:
--------------------------------------------------------------------------------
1 | pipeline_type: base
2 | checkpoint_path: "ltxv-2b-0.9.6-distilled-04-25.safetensors"
3 | guidance_scale: 1
4 | stg_scale: 0
5 | rescaling_scale: 1
6 | num_inference_steps: 8
7 | stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
8 | decode_timestep: 0.05
9 | decode_noise_scale: 0.025
10 | text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
11 | precision: "bfloat16"
12 | sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
13 | prompt_enhancement_words_threshold: 120
14 | prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
15 | prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
16 | stochastic_sampling: true
--------------------------------------------------------------------------------
/configs/ltxv-2b-0.9.yaml:
--------------------------------------------------------------------------------
1 | pipeline_type: base
2 | checkpoint_path: "ltx-video-2b-v0.9.safetensors"
3 | guidance_scale: 3
4 | stg_scale: 1
5 | rescaling_scale: 0.7
6 | skip_block_list: [19]
7 | num_inference_steps: 40
8 | stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
9 | decode_timestep: 0.05
10 | decode_noise_scale: 0.025
11 | text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
12 | precision: "bfloat16"
13 | sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
14 | prompt_enhancement_words_threshold: 120
15 | prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
16 | prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
17 | stochastic_sampling: false
--------------------------------------------------------------------------------
/docs/_static/ltx-video_example_00001.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:b679f14a09d2321b7e34b3ecd23bc01c2cfa75c8d4214a1e59af09826003e2ec
3 | size 7963919
4 |
--------------------------------------------------------------------------------
/docs/_static/ltx-video_example_00005.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:87fdb9556c1218db4b929994e9b807d1d63f4676defef5b418a4edb1ddaa8422
3 | size 5732587
4 |
--------------------------------------------------------------------------------
/docs/_static/ltx-video_example_00006.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:f56f3dcc84a871ab4ef1510120f7a4586c7044c5609a897d8177ae8d52eb3eae
3 | size 4239543
4 |
--------------------------------------------------------------------------------
/docs/_static/ltx-video_example_00007.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:a08a06681334856db516e969a9ae4290acfd7550f7b970331e87d0223e282bcc
3 | size 7829259
4 |
--------------------------------------------------------------------------------
/docs/_static/ltx-video_example_00010.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:bcf1e084e936a75eaae73a29f60935c469b1fc34eb3f5ad89483e88b3a2eaffe
3 | size 6193172
4 |
--------------------------------------------------------------------------------
/docs/_static/ltx-video_example_00011.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:3e3d04f5763ecb416b3b80c3488e48c49991d80661c94e8f08dddd7b890b1b75
3 | size 5345673
4 |
--------------------------------------------------------------------------------
/docs/_static/ltx-video_example_00013.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:aa7eb790b43f8a55c01d1fbed4c7a7f657fb2ca78a9685833cf9cb558d2002c1
3 | size 9024843
4 |
--------------------------------------------------------------------------------
/docs/_static/ltx-video_example_00014.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:4f7afc4b498a927dcc4e1492548db5c32fa76d117e0410d11e1e0b1929153e54
3 | size 7434241
4 |
--------------------------------------------------------------------------------
/docs/_static/ltx-video_example_00015.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:d897c9656e0cba89512ab9d2cbe2d2c0f2ddf907dcab5f7eadab4b96b1cb1efe
3 | size 6556457
4 |
--------------------------------------------------------------------------------
/docs/_static/ltx-video_i2v_example_00001.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:51e7f2ef92a9a2296e7ce9b40a118033b26372838befdd3d2281819c438ee928
3 | size 9285683
4 |
--------------------------------------------------------------------------------
/docs/_static/ltx-video_i2v_example_00002.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:0bd2c1314587efc0b8b7656d6634fbcc3a2045801441bf139bab7726275fa353
3 | size 20393804
4 |
--------------------------------------------------------------------------------
/docs/_static/ltx-video_i2v_example_00003.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:bb62a79340882b490d11e05de0427e4efad3dca19a55e003ef92889613b67825
3 | size 9825156
4 |
--------------------------------------------------------------------------------
/docs/_static/ltx-video_i2v_example_00004.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:151cb985a1622512b656288b1a1cba7906a34678a2fd9ae6e25611e330a0f9bb
3 | size 15691608
4 |
--------------------------------------------------------------------------------
/docs/_static/ltx-video_i2v_example_00005.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:6280efa40b66a50a8a32fddc0e81a55081e8c8efef56a54229ea4e7f2ae4309d
3 | size 10329925
4 |
--------------------------------------------------------------------------------
/docs/_static/ltx-video_i2v_example_00006.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:da23e79524a6427959b20809e392555ab4336881ade20001e5ae276d662ed291
3 | size 11936674
4 |
--------------------------------------------------------------------------------
/docs/_static/ltx-video_i2v_example_00007.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:9bcfe336e823925ace248f3ad517df791986c00a1a3c375bac2cb433154ca133
3 | size 11755718
4 |
--------------------------------------------------------------------------------
/docs/_static/ltx-video_i2v_example_00008.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:706368111eda5331dca089e9231bb272a2e61a6b231bd13876ac442cb9d78019
3 | size 14716658
4 |
--------------------------------------------------------------------------------
/docs/_static/ltx-video_i2v_example_00009.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:a9c13259e01485c16148dc6681922cc3a748c2f7b52ff7e208c3ffc5ab71397d
3 | size 16848341
4 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | from datetime import datetime
5 | from pathlib import Path
6 | from diffusers.utils import logging
7 | from typing import Optional, List, Union
8 | import yaml
9 |
10 | import imageio
11 | import json
12 | import numpy as np
13 | import torch
14 | import cv2
15 | from safetensors import safe_open
16 | from PIL import Image
17 | from transformers import (
18 | T5EncoderModel,
19 | T5Tokenizer,
20 | AutoModelForCausalLM,
21 | AutoProcessor,
22 | AutoTokenizer,
23 | )
24 | from huggingface_hub import hf_hub_download
25 |
26 | from ltx_video.models.autoencoders.causal_video_autoencoder import (
27 | CausalVideoAutoencoder,
28 | )
29 | from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
30 | from ltx_video.models.transformers.transformer3d import Transformer3DModel
31 | from ltx_video.pipelines.pipeline_ltx_video import (
32 | ConditioningItem,
33 | LTXVideoPipeline,
34 | LTXMultiScalePipeline,
35 | )
36 | from ltx_video.schedulers.rf import RectifiedFlowScheduler
37 | from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
38 | from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
39 | import ltx_video.pipelines.crf_compressor as crf_compressor
40 |
41 | MAX_HEIGHT = 720
42 | MAX_WIDTH = 1280
43 | MAX_NUM_FRAMES = 257
44 |
45 | logger = logging.get_logger("LTX-Video")
46 |
47 |
48 | def get_total_gpu_memory():
49 | if torch.cuda.is_available():
50 | total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
51 | return total_memory
52 | return 0
53 |
54 |
55 | def get_device():
56 | if torch.cuda.is_available():
57 | return "cuda"
58 | elif torch.backends.mps.is_available():
59 | return "mps"
60 | return "cpu"
61 |
62 |
63 | def load_image_to_tensor_with_resize_and_crop(
64 | image_input: Union[str, Image.Image],
65 | target_height: int = 512,
66 | target_width: int = 768,
67 | just_crop: bool = False,
68 | ) -> torch.Tensor:
69 | """Load and process an image into a tensor.
70 |
71 | Args:
72 | image_input: Either a file path (str) or a PIL Image object
73 | target_height: Desired height of output tensor
74 | target_width: Desired width of output tensor
75 | just_crop: If True, only crop the image to the target size without resizing
76 | """
77 | if isinstance(image_input, str):
78 | image = Image.open(image_input).convert("RGB")
79 | elif isinstance(image_input, Image.Image):
80 | image = image_input
81 | else:
82 | raise ValueError("image_input must be either a file path or a PIL Image object")
83 |
84 | input_width, input_height = image.size
85 | aspect_ratio_target = target_width / target_height
86 | aspect_ratio_frame = input_width / input_height
87 | if aspect_ratio_frame > aspect_ratio_target:
88 | new_width = int(input_height * aspect_ratio_target)
89 | new_height = input_height
90 | x_start = (input_width - new_width) // 2
91 | y_start = 0
92 | else:
93 | new_width = input_width
94 | new_height = int(input_width / aspect_ratio_target)
95 | x_start = 0
96 | y_start = (input_height - new_height) // 2
97 |
98 | image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
99 | if not just_crop:
100 | image = image.resize((target_width, target_height))
101 |
102 | image = np.array(image)
103 | image = cv2.GaussianBlur(image, (3, 3), 0)
104 | frame_tensor = torch.from_numpy(image).float()
105 | frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0
106 | frame_tensor = frame_tensor.permute(2, 0, 1)
107 | frame_tensor = (frame_tensor / 127.5) - 1.0
108 | # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
109 | return frame_tensor.unsqueeze(0).unsqueeze(2)
110 |
111 |
112 | def calculate_padding(
113 | source_height: int, source_width: int, target_height: int, target_width: int
114 | ) -> tuple[int, int, int, int]:
115 |
116 | # Calculate total padding needed
117 | pad_height = target_height - source_height
118 | pad_width = target_width - source_width
119 |
120 | # Calculate padding for each side
121 | pad_top = pad_height // 2
122 | pad_bottom = pad_height - pad_top # Handles odd padding
123 | pad_left = pad_width // 2
124 | pad_right = pad_width - pad_left # Handles odd padding
125 |
126 |     # Return the amount of padding needed for each side (not a padded tensor)
127 | # Padding format is (left, right, top, bottom)
128 | padding = (pad_left, pad_right, pad_top, pad_bottom)
129 | return padding
130 |
131 |
132 | def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
133 |     # Keep only letters and spaces, converted to lowercase
134 | clean_text = "".join(
135 | char.lower() for char in text if char.isalpha() or char.isspace()
136 | )
137 |
138 | # Split into words
139 | words = clean_text.split()
140 |
141 | # Build result string keeping track of length
142 | result = []
143 | current_length = 0
144 |
145 | for word in words:
146 |         # Add the word only if the total length stays within max_len
147 | new_length = current_length + len(word)
148 |
149 | if new_length <= max_len:
150 | result.append(word)
151 | current_length += len(word)
152 | else:
153 | break
154 |
155 | return "-".join(result)
156 |
157 |
158 | # Generate output video name
159 | def get_unique_filename(
160 | base: str,
161 | ext: str,
162 | prompt: str,
163 | seed: int,
164 | resolution: tuple[int, int, int],
165 | dir: Path,
166 | endswith=None,
167 | index_range=1000,
168 | ) -> Path:
169 | base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
170 | for i in range(index_range):
171 | filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
172 | if not os.path.exists(filename):
173 | return filename
174 | raise FileExistsError(
175 | f"Could not find a unique filename after {index_range} attempts."
176 | )
177 |
178 |
179 | def seed_everything(seed: int):
180 | random.seed(seed)
181 | np.random.seed(seed)
182 | torch.manual_seed(seed)
183 | if torch.cuda.is_available():
184 | torch.cuda.manual_seed(seed)
185 | if torch.backends.mps.is_available():
186 | torch.mps.manual_seed(seed)
187 |
188 |
189 | def main():
190 | parser = argparse.ArgumentParser(
191 | description="Load models from separate directories and run the pipeline."
192 | )
193 |
194 | # Directories
195 | parser.add_argument(
196 | "--output_path",
197 | type=str,
198 | default=None,
199 |         help="Path to the folder in which to save the output video; if None, saves to the outputs/ directory.",
200 | )
201 |     parser.add_argument("--seed", type=int, default=171198)
202 |
203 | # Pipeline parameters
204 | parser.add_argument(
205 | "--num_images_per_prompt",
206 | type=int,
207 | default=1,
208 | help="Number of images per prompt",
209 | )
210 | parser.add_argument(
211 | "--image_cond_noise_scale",
212 | type=float,
213 | default=0.15,
214 | help="Amount of noise to add to the conditioned image",
215 | )
216 | parser.add_argument(
217 | "--height",
218 | type=int,
219 | default=704,
220 |         help="Height of the output video frames. Optional if an input image is provided.",
221 | )
222 | parser.add_argument(
223 | "--width",
224 | type=int,
225 | default=1216,
226 |         help="Width of the output video frames. If None, it will be inferred from the input image.",
227 | )
228 | parser.add_argument(
229 | "--num_frames",
230 | type=int,
231 | default=121,
232 | help="Number of frames to generate in the output video",
233 | )
234 | parser.add_argument(
235 | "--frame_rate", type=int, default=30, help="Frame rate for the output video"
236 | )
237 | parser.add_argument(
238 | "--device",
239 | default=None,
240 | help="Device to run inference on. If not specified, will automatically detect and use CUDA or MPS if available, else CPU.",
241 | )
242 | parser.add_argument(
243 | "--pipeline_config",
244 | type=str,
245 | default="configs/ltxv-13b-0.9.7-dev.yaml",
246 |         help="Path to the pipeline config file containing the pipeline parameters.",
247 | )
248 |
249 | # Prompts
250 | parser.add_argument(
251 | "--prompt",
252 | type=str,
253 | help="Text prompt to guide generation",
254 | )
255 | parser.add_argument(
256 | "--negative_prompt",
257 | type=str,
258 | default="worst quality, inconsistent motion, blurry, jittery, distorted",
259 | help="Negative prompt for undesired features",
260 | )
261 |
262 | parser.add_argument(
263 | "--offload_to_cpu",
264 | action="store_true",
265 |         help="Offload unnecessary computations to the CPU.",
266 | )
267 |
268 | # video-to-video arguments:
269 | parser.add_argument(
270 | "--input_media_path",
271 | type=str,
272 | default=None,
273 |         help="Path to the input video (or image) to be modified using the video-to-video pipeline",
274 | )
275 |
276 | # Conditioning arguments
277 | parser.add_argument(
278 | "--conditioning_media_paths",
279 | type=str,
280 | nargs="*",
281 | help="List of paths to conditioning media (images or videos). Each path will be used as a conditioning item.",
282 | )
283 | parser.add_argument(
284 | "--conditioning_strengths",
285 | type=float,
286 | nargs="*",
287 | help="List of conditioning strengths (between 0 and 1) for each conditioning item. Must match the number of conditioning items.",
288 | )
289 | parser.add_argument(
290 | "--conditioning_start_frames",
291 | type=int,
292 | nargs="*",
293 | help="List of frame indices where each conditioning item should be applied. Must match the number of conditioning items.",
294 | )
295 |
296 | args = parser.parse_args()
297 | logger.warning(f"Running generation with arguments: {args}")
298 | infer(**vars(args))
299 |
300 |
301 | def create_ltx_video_pipeline(
302 | ckpt_path: str,
303 | precision: str,
304 | text_encoder_model_name_or_path: str,
305 | sampler: Optional[str] = None,
306 | device: Optional[str] = None,
307 | enhance_prompt: bool = False,
308 | prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None,
309 | prompt_enhancer_llm_model_name_or_path: Optional[str] = None,
310 | ) -> LTXVideoPipeline:
311 | ckpt_path = Path(ckpt_path)
312 | assert os.path.exists(
313 | ckpt_path
314 |     ), f"Checkpoint path provided ({ckpt_path}) does not exist"
315 |
316 | with safe_open(ckpt_path, framework="pt") as f:
317 | metadata = f.metadata()
318 | config_str = metadata.get("config")
319 | configs = json.loads(config_str)
320 | allowed_inference_steps = configs.get("allowed_inference_steps", None)
321 |
322 | vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
323 | transformer = Transformer3DModel.from_pretrained(ckpt_path)
324 |
325 | # Use constructor if sampler is specified, otherwise use from_pretrained
326 | if sampler == "from_checkpoint" or not sampler:
327 | scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
328 | else:
329 | scheduler = RectifiedFlowScheduler(
330 | sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
331 | )
332 |
333 | text_encoder = T5EncoderModel.from_pretrained(
334 | text_encoder_model_name_or_path, subfolder="text_encoder"
335 | )
336 | patchifier = SymmetricPatchifier(patch_size=1)
337 | tokenizer = T5Tokenizer.from_pretrained(
338 | text_encoder_model_name_or_path, subfolder="tokenizer"
339 | )
340 |
341 | transformer = transformer.to(device)
342 | vae = vae.to(device)
343 | text_encoder = text_encoder.to(device)
344 |
345 | if enhance_prompt:
346 | prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
347 | prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
348 | )
349 | prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
350 | prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
351 | )
352 | prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
353 | prompt_enhancer_llm_model_name_or_path,
354 | torch_dtype="bfloat16",
355 | )
356 | prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
357 | prompt_enhancer_llm_model_name_or_path,
358 | )
359 | else:
360 | prompt_enhancer_image_caption_model = None
361 | prompt_enhancer_image_caption_processor = None
362 | prompt_enhancer_llm_model = None
363 | prompt_enhancer_llm_tokenizer = None
364 |
365 | vae = vae.to(torch.bfloat16)
366 | if precision == "bfloat16" and transformer.dtype != torch.bfloat16:
367 | transformer = transformer.to(torch.bfloat16)
368 | text_encoder = text_encoder.to(torch.bfloat16)
369 |
370 | # Use submodels for the pipeline
371 | submodel_dict = {
372 | "transformer": transformer,
373 | "patchifier": patchifier,
374 | "text_encoder": text_encoder,
375 | "tokenizer": tokenizer,
376 | "scheduler": scheduler,
377 | "vae": vae,
378 | "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
379 | "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
380 | "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
381 | "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
382 | "allowed_inference_steps": allowed_inference_steps,
383 | }
384 |
385 | pipeline = LTXVideoPipeline(**submodel_dict)
386 | pipeline = pipeline.to(device)
387 | return pipeline
388 |
389 |
390 | def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
391 | latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
392 | latent_upsampler.to(device)
393 | latent_upsampler.eval()
394 | return latent_upsampler
395 |
396 |
397 | def infer(
398 | output_path: Optional[str],
399 | seed: int,
400 | pipeline_config: str,
401 | image_cond_noise_scale: float,
402 | height: Optional[int],
403 | width: Optional[int],
404 | num_frames: int,
405 | frame_rate: int,
406 | prompt: str,
407 | negative_prompt: str,
408 | offload_to_cpu: bool,
409 | input_media_path: Optional[str] = None,
410 | conditioning_media_paths: Optional[List[str]] = None,
411 | conditioning_strengths: Optional[List[float]] = None,
412 | conditioning_start_frames: Optional[List[int]] = None,
413 | device: Optional[str] = None,
414 | **kwargs,
415 | ):
416 | # check if pipeline_config is a file
417 | if not os.path.isfile(pipeline_config):
418 | raise ValueError(f"Pipeline config file {pipeline_config} does not exist")
419 | with open(pipeline_config, "r") as f:
420 | pipeline_config = yaml.safe_load(f)
421 |
422 | models_dir = "MODEL_DIR"
423 |
424 | ltxv_model_name_or_path = pipeline_config["checkpoint_path"]
425 | if not os.path.isfile(ltxv_model_name_or_path):
426 | ltxv_model_path = hf_hub_download(
427 | repo_id="Lightricks/LTX-Video",
428 | filename=ltxv_model_name_or_path,
429 | local_dir=models_dir,
430 | repo_type="model",
431 | )
432 | else:
433 | ltxv_model_path = ltxv_model_name_or_path
434 |
435 | spatial_upscaler_model_name_or_path = pipeline_config.get(
436 | "spatial_upscaler_model_path"
437 | )
438 | if spatial_upscaler_model_name_or_path and not os.path.isfile(
439 | spatial_upscaler_model_name_or_path
440 | ):
441 | spatial_upscaler_model_path = hf_hub_download(
442 | repo_id="Lightricks/LTX-Video",
443 | filename=spatial_upscaler_model_name_or_path,
444 | local_dir=models_dir,
445 | repo_type="model",
446 | )
447 | else:
448 | spatial_upscaler_model_path = spatial_upscaler_model_name_or_path
449 |
450 | if kwargs.get("input_image_path", None):
451 | logger.warning(
452 | "Please use conditioning_media_paths instead of input_image_path."
453 | )
454 | assert not conditioning_media_paths and not conditioning_start_frames
455 | conditioning_media_paths = [kwargs["input_image_path"]]
456 | conditioning_start_frames = [0]
457 |
458 | # Validate conditioning arguments
459 | if conditioning_media_paths:
460 | # Use default strengths of 1.0
461 | if not conditioning_strengths:
462 | conditioning_strengths = [1.0] * len(conditioning_media_paths)
463 | if not conditioning_start_frames:
464 | raise ValueError(
465 | "If `conditioning_media_paths` is provided, "
466 | "`conditioning_start_frames` must also be provided"
467 | )
468 | if len(conditioning_media_paths) != len(conditioning_strengths) or len(
469 | conditioning_media_paths
470 | ) != len(conditioning_start_frames):
471 | raise ValueError(
472 | "`conditioning_media_paths`, `conditioning_strengths`, "
473 | "and `conditioning_start_frames` must have the same length"
474 | )
475 | if any(s < 0 or s > 1 for s in conditioning_strengths):
476 | raise ValueError("All conditioning strengths must be between 0 and 1")
477 | if any(f < 0 or f >= num_frames for f in conditioning_start_frames):
478 | raise ValueError(
479 | f"All conditioning start frames must be between 0 and {num_frames-1}"
480 | )
481 |
482 |     seed_everything(seed)
483 | if offload_to_cpu and not torch.cuda.is_available():
484 | logger.warning(
485 | "offload_to_cpu is set to True, but offloading will not occur since the model is already running on CPU."
486 | )
487 | offload_to_cpu = False
488 | else:
489 | offload_to_cpu = offload_to_cpu and get_total_gpu_memory() < 30
490 |
491 | output_dir = (
492 | Path(output_path)
493 | if output_path
494 | else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
495 | )
496 | output_dir.mkdir(parents=True, exist_ok=True)
497 |
498 | # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
499 | height_padded = ((height - 1) // 32 + 1) * 32
500 | width_padded = ((width - 1) // 32 + 1) * 32
501 | num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1
502 |
503 | padding = calculate_padding(height, width, height_padded, width_padded)
504 |
505 | logger.warning(
506 | f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
507 | )
508 |
509 | prompt_enhancement_words_threshold = pipeline_config[
510 | "prompt_enhancement_words_threshold"
511 | ]
512 |
513 | prompt_word_count = len(prompt.split())
514 | enhance_prompt = (
515 | prompt_enhancement_words_threshold > 0
516 | and prompt_word_count < prompt_enhancement_words_threshold
517 | )
518 |
519 | if prompt_enhancement_words_threshold > 0 and not enhance_prompt:
520 | logger.info(
521 | f"Prompt has {prompt_word_count} words, which exceeds the threshold of {prompt_enhancement_words_threshold}. Prompt enhancement disabled."
522 | )
523 |
524 | precision = pipeline_config["precision"]
525 | text_encoder_model_name_or_path = pipeline_config["text_encoder_model_name_or_path"]
526 | sampler = pipeline_config["sampler"]
527 | prompt_enhancer_image_caption_model_name_or_path = pipeline_config[
528 | "prompt_enhancer_image_caption_model_name_or_path"
529 | ]
530 | prompt_enhancer_llm_model_name_or_path = pipeline_config[
531 | "prompt_enhancer_llm_model_name_or_path"
532 | ]
533 |
534 | pipeline = create_ltx_video_pipeline(
535 | ckpt_path=ltxv_model_path,
536 | precision=precision,
537 | text_encoder_model_name_or_path=text_encoder_model_name_or_path,
538 | sampler=sampler,
539 | device=kwargs.get("device", get_device()),
540 | enhance_prompt=enhance_prompt,
541 | prompt_enhancer_image_caption_model_name_or_path=prompt_enhancer_image_caption_model_name_or_path,
542 | prompt_enhancer_llm_model_name_or_path=prompt_enhancer_llm_model_name_or_path,
543 | )
544 |
545 | if pipeline_config.get("pipeline_type", None) == "multi-scale":
546 | if not spatial_upscaler_model_path:
547 | raise ValueError(
548 |                 "Spatial upscaler model path is missing from the pipeline config file; it is required for multi-scale rendering"
549 | )
550 | latent_upsampler = create_latent_upsampler(
551 | spatial_upscaler_model_path, pipeline.device
552 | )
553 | pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)
554 |
555 | media_item = None
556 | if input_media_path:
557 | media_item = load_media_file(
558 | media_path=input_media_path,
559 | height=height,
560 | width=width,
561 | max_frames=num_frames_padded,
562 | padding=padding,
563 | )
564 |
565 | conditioning_items = (
566 | prepare_conditioning(
567 | conditioning_media_paths=conditioning_media_paths,
568 | conditioning_strengths=conditioning_strengths,
569 | conditioning_start_frames=conditioning_start_frames,
570 | height=height,
571 | width=width,
572 | num_frames=num_frames,
573 | padding=padding,
574 | pipeline=pipeline,
575 | )
576 | if conditioning_media_paths
577 | else None
578 | )
579 |
580 |     # Pop stg_mode so it is not forwarded to the pipeline call below
581 |     stg_mode = pipeline_config.pop("stg_mode", "attention_values")
582 | if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values":
583 | skip_layer_strategy = SkipLayerStrategy.AttentionValues
584 | elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip":
585 | skip_layer_strategy = SkipLayerStrategy.AttentionSkip
586 | elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual":
587 | skip_layer_strategy = SkipLayerStrategy.Residual
588 | elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block":
589 | skip_layer_strategy = SkipLayerStrategy.TransformerBlock
590 | else:
591 | raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")
592 |
593 | # Prepare input for the pipeline
594 | sample = {
595 | "prompt": prompt,
596 | "prompt_attention_mask": None,
597 | "negative_prompt": negative_prompt,
598 | "negative_prompt_attention_mask": None,
599 | }
600 |
601 | device = device or get_device()
602 | generator = torch.Generator(device=device).manual_seed(seed)
603 |
604 | images = pipeline(
605 | **pipeline_config,
606 | skip_layer_strategy=skip_layer_strategy,
607 | generator=generator,
608 | output_type="pt",
609 | callback_on_step_end=None,
610 | height=height_padded,
611 | width=width_padded,
612 | num_frames=num_frames_padded,
613 | frame_rate=frame_rate,
614 | **sample,
615 | media_items=media_item,
616 | conditioning_items=conditioning_items,
617 | is_video=True,
618 | vae_per_channel_normalize=True,
619 | image_cond_noise_scale=image_cond_noise_scale,
620 | mixed_precision=(precision == "mixed_precision"),
621 | offload_to_cpu=offload_to_cpu,
622 | device=device,
623 | enhance_prompt=enhance_prompt,
624 | ).images
625 |
626 | # Crop the padded images to the desired resolution and number of frames
627 | (pad_left, pad_right, pad_top, pad_bottom) = padding
628 | pad_bottom = -pad_bottom
629 | pad_right = -pad_right
630 | if pad_bottom == 0:
631 | pad_bottom = images.shape[3]
632 | if pad_right == 0:
633 | pad_right = images.shape[4]
634 | images = images[:, :, :num_frames, pad_top:pad_bottom, pad_left:pad_right]
635 |
636 | for i in range(images.shape[0]):
637 | # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
638 | video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
639 | # Unnormalizing images to [0, 255] range
640 | video_np = (video_np * 255).astype(np.uint8)
641 | fps = frame_rate
642 | height, width = video_np.shape[1:3]
643 | # In case a single image is generated
644 | if video_np.shape[0] == 1:
645 | output_filename = get_unique_filename(
646 | f"image_output_{i}",
647 | ".png",
648 | prompt=prompt,
649 | seed=seed,
650 | resolution=(height, width, num_frames),
651 | dir=output_dir,
652 | )
653 | imageio.imwrite(output_filename, video_np[0])
654 | else:
655 | output_filename = get_unique_filename(
656 | f"video_output_{i}",
657 | ".mp4",
658 | prompt=prompt,
659 | seed=seed,
660 | resolution=(height, width, num_frames),
661 | dir=output_dir,
662 | )
663 |
664 | # Write video
665 | with imageio.get_writer(output_filename, fps=fps) as video:
666 | for frame in video_np:
667 | video.append_data(frame)
668 |
669 | logger.warning(f"Output saved to {output_filename}")
670 |
671 |
672 | def prepare_conditioning(
673 | conditioning_media_paths: List[str],
674 | conditioning_strengths: List[float],
675 | conditioning_start_frames: List[int],
676 | height: int,
677 | width: int,
678 | num_frames: int,
679 | padding: tuple[int, int, int, int],
680 | pipeline: LTXVideoPipeline,
681 | ) -> Optional[List[ConditioningItem]]:
682 | """Prepare conditioning items based on input media paths and their parameters.
683 |
684 | Args:
685 | conditioning_media_paths: List of paths to conditioning media (images or videos)
686 | conditioning_strengths: List of conditioning strengths for each media item
687 | conditioning_start_frames: List of frame indices where each item should be applied
688 | height: Height of the output frames
689 | width: Width of the output frames
690 | num_frames: Number of frames in the output video
691 | padding: Padding to apply to the frames
692 | pipeline: LTXVideoPipeline object used for condition video trimming
693 |
694 | Returns:
695 | A list of ConditioningItem objects.
696 | """
697 | conditioning_items = []
698 | for path, strength, start_frame in zip(
699 | conditioning_media_paths, conditioning_strengths, conditioning_start_frames
700 | ):
701 | num_input_frames = orig_num_input_frames = get_media_num_frames(path)
702 | if hasattr(pipeline, "trim_conditioning_sequence") and callable(
703 | getattr(pipeline, "trim_conditioning_sequence")
704 | ):
705 | num_input_frames = pipeline.trim_conditioning_sequence(
706 | start_frame, orig_num_input_frames, num_frames
707 | )
708 | if num_input_frames < orig_num_input_frames:
709 | logger.warning(
710 | f"Trimming conditioning video {path} from {orig_num_input_frames} to {num_input_frames} frames."
711 | )
712 |
713 | media_tensor = load_media_file(
714 | media_path=path,
715 | height=height,
716 | width=width,
717 | max_frames=num_input_frames,
718 | padding=padding,
719 | just_crop=True,
720 | )
721 | conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength))
722 | return conditioning_items
723 |
724 |
725 | def get_media_num_frames(media_path: str) -> int:
726 | is_video = any(
727 | media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
728 | )
729 | num_frames = 1
730 | if is_video:
731 | reader = imageio.get_reader(media_path)
732 | num_frames = reader.count_frames()
733 | reader.close()
734 | return num_frames
735 |
736 |
737 | def load_media_file(
738 | media_path: str,
739 | height: int,
740 | width: int,
741 | max_frames: int,
742 | padding: tuple[int, int, int, int],
743 | just_crop: bool = False,
744 | ) -> torch.Tensor:
745 | is_video = any(
746 | media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
747 | )
748 | if is_video:
749 | reader = imageio.get_reader(media_path)
750 | num_input_frames = min(reader.count_frames(), max_frames)
751 |
752 | # Read and preprocess the relevant frames from the video file.
753 | frames = []
754 | for i in range(num_input_frames):
755 | frame = Image.fromarray(reader.get_data(i))
756 | frame_tensor = load_image_to_tensor_with_resize_and_crop(
757 | frame, height, width, just_crop=just_crop
758 | )
759 | frame_tensor = torch.nn.functional.pad(frame_tensor, padding)
760 | frames.append(frame_tensor)
761 | reader.close()
762 |
763 | # Stack frames along the temporal dimension
764 | media_tensor = torch.cat(frames, dim=2)
765 | else: # Input image
766 | media_tensor = load_image_to_tensor_with_resize_and_crop(
767 | media_path, height, width, just_crop=just_crop
768 | )
769 | media_tensor = torch.nn.functional.pad(media_tensor, padding)
770 | return media_tensor
771 |
772 |
773 | if __name__ == "__main__":
774 | main()
775 |
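
A worked example of the dimension handling in infer() above (a standalone sketch; the first three values mirror the script's defaults): spatial dimensions are rounded up to multiples of 32 and the frame count to the nearest value of the form N * 8 + 1, after which calculate_padding() splits the difference between the two sides.

# Sketch only: the rounding performed by infer() before calculate_padding().
height, width, num_frames = 704, 1216, 121

print(((height - 1) // 32 + 1) * 32)        # 704  (already a multiple of 32)
print(((width - 1) // 32 + 1) * 32)         # 1216
print(((num_frames - 2) // 8 + 1) * 8 + 1)  # 121  (already of the form N * 8 + 1)

# A non-aligned request such as 700x1200x100 is padded up:
print(((700 - 1) // 32 + 1) * 32)           # 704
print(((1200 - 1) // 32 + 1) * 32)          # 1216
print(((100 - 2) // 8 + 1) * 8 + 1)         # 105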
--------------------------------------------------------------------------------
/ltx_video/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightricks/LTX-Video/34625c3a1fb4e9a7d2f091938f6989816343e270/ltx_video/__init__.py
--------------------------------------------------------------------------------
/ltx_video/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightricks/LTX-Video/34625c3a1fb4e9a7d2f091938f6989816343e270/ltx_video/models/__init__.py
--------------------------------------------------------------------------------
/ltx_video/models/autoencoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightricks/LTX-Video/34625c3a1fb4e9a7d2f091938f6989816343e270/ltx_video/models/autoencoders/__init__.py
--------------------------------------------------------------------------------
/ltx_video/models/autoencoders/causal_conv3d.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class CausalConv3d(nn.Module):
8 | def __init__(
9 | self,
10 | in_channels,
11 | out_channels,
12 | kernel_size: int = 3,
13 | stride: Union[int, Tuple[int]] = 1,
14 | dilation: int = 1,
15 | groups: int = 1,
16 | spatial_padding_mode: str = "zeros",
17 | **kwargs,
18 | ):
19 | super().__init__()
20 |
21 | self.in_channels = in_channels
22 | self.out_channels = out_channels
23 |
24 | kernel_size = (kernel_size, kernel_size, kernel_size)
25 | self.time_kernel_size = kernel_size[0]
26 |
27 | dilation = (dilation, 1, 1)
28 |
29 | height_pad = kernel_size[1] // 2
30 | width_pad = kernel_size[2] // 2
31 | padding = (0, height_pad, width_pad)
32 |
33 | self.conv = nn.Conv3d(
34 | in_channels,
35 | out_channels,
36 | kernel_size,
37 | stride=stride,
38 | dilation=dilation,
39 | padding=padding,
40 | padding_mode=spatial_padding_mode,
41 | groups=groups,
42 | )
43 |
44 | def forward(self, x, causal: bool = True):
45 | if causal:
46 | first_frame_pad = x[:, :, :1, :, :].repeat(
47 | (1, 1, self.time_kernel_size - 1, 1, 1)
48 | )
49 | x = torch.concatenate((first_frame_pad, x), dim=2)
50 | else:
51 | first_frame_pad = x[:, :, :1, :, :].repeat(
52 | (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
53 | )
54 | last_frame_pad = x[:, :, -1:, :, :].repeat(
55 | (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
56 | )
57 | x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
58 | x = self.conv(x)
59 | return x
60 |
61 | @property
62 | def weight(self):
63 | return self.conv.weight
64 |
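
A minimal usage sketch (assuming the package is importable, e.g. the repository root is on PYTHONPATH): with causal=True the module pads in time by repeating the first frame kernel_size - 1 times, so a stride-1 convolution preserves the frame count without looking at future frames; with causal=False the padding is split symmetrically between the first and last frames.

# Sketch only, not part of the module.
import torch
from ltx_video.models.autoencoders.causal_conv3d import CausalConv3d

conv = CausalConv3d(in_channels=4, out_channels=8, kernel_size=3, stride=1)
x = torch.randn(1, 4, 5, 16, 16)    # (batch, channels, frames, height, width)
print(conv(x, causal=True).shape)   # torch.Size([1, 8, 5, 16, 16])
print(conv(x, causal=False).shape)  # torch.Size([1, 8, 5, 16, 16])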
--------------------------------------------------------------------------------
/ltx_video/models/autoencoders/conv_nd_factory.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Union
2 |
3 | import torch
4 |
5 | from ltx_video.models.autoencoders.dual_conv3d import DualConv3d
6 | from ltx_video.models.autoencoders.causal_conv3d import CausalConv3d
7 |
8 |
9 | def make_conv_nd(
10 | dims: Union[int, Tuple[int, int]],
11 | in_channels: int,
12 | out_channels: int,
13 | kernel_size: int,
14 | stride=1,
15 | padding=0,
16 | dilation=1,
17 | groups=1,
18 | bias=True,
19 | causal=False,
20 | spatial_padding_mode="zeros",
21 | temporal_padding_mode="zeros",
22 | ):
23 | if not (spatial_padding_mode == temporal_padding_mode or causal):
24 |         raise NotImplementedError("spatial and temporal padding modes must be equal (unless the convolution is causal)")
25 | if dims == 2:
26 | return torch.nn.Conv2d(
27 | in_channels=in_channels,
28 | out_channels=out_channels,
29 | kernel_size=kernel_size,
30 | stride=stride,
31 | padding=padding,
32 | dilation=dilation,
33 | groups=groups,
34 | bias=bias,
35 | padding_mode=spatial_padding_mode,
36 | )
37 | elif dims == 3:
38 | if causal:
39 | return CausalConv3d(
40 | in_channels=in_channels,
41 | out_channels=out_channels,
42 | kernel_size=kernel_size,
43 | stride=stride,
44 | padding=padding,
45 | dilation=dilation,
46 | groups=groups,
47 | bias=bias,
48 | spatial_padding_mode=spatial_padding_mode,
49 | )
50 | return torch.nn.Conv3d(
51 | in_channels=in_channels,
52 | out_channels=out_channels,
53 | kernel_size=kernel_size,
54 | stride=stride,
55 | padding=padding,
56 | dilation=dilation,
57 | groups=groups,
58 | bias=bias,
59 | padding_mode=spatial_padding_mode,
60 | )
61 | elif dims == (2, 1):
62 | return DualConv3d(
63 | in_channels=in_channels,
64 | out_channels=out_channels,
65 | kernel_size=kernel_size,
66 | stride=stride,
67 | padding=padding,
68 | bias=bias,
69 | padding_mode=spatial_padding_mode,
70 | )
71 | else:
72 | raise ValueError(f"unsupported dimensions: {dims}")
73 |
74 |
75 | def make_linear_nd(
76 | dims: int,
77 | in_channels: int,
78 | out_channels: int,
79 | bias=True,
80 | ):
81 | if dims == 2:
82 | return torch.nn.Conv2d(
83 | in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
84 | )
85 | elif dims == 3 or dims == (2, 1):
86 | return torch.nn.Conv3d(
87 | in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
88 | )
89 | else:
90 | raise ValueError(f"unsupported dimensions: {dims}")
91 |
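
A short sketch of the dispatch (assuming the package is importable): dims=2 returns nn.Conv2d, dims=3 returns nn.Conv3d or CausalConv3d when causal=True, and dims=(2, 1) returns the factorized DualConv3d; make_linear_nd builds the matching kernel-size-1 projection.

# Sketch only, not part of the module.
import torch
from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd

causal3d = make_conv_nd(3, in_channels=8, out_channels=8, kernel_size=3, causal=True)
dual = make_conv_nd((2, 1), in_channels=8, out_channels=8, kernel_size=3, padding=1)
proj = make_linear_nd(3, in_channels=8, out_channels=16)

x = torch.randn(1, 8, 5, 16, 16)
print(causal3d(x).shape)  # torch.Size([1, 8, 5, 16, 16])
print(dual(x).shape)      # torch.Size([1, 8, 5, 16, 16])
print(proj(x).shape)      # torch.Size([1, 16, 5, 16, 16])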
--------------------------------------------------------------------------------
/ltx_video/models/autoencoders/dual_conv3d.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Tuple, Union
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from einops import rearrange
8 |
9 |
10 | class DualConv3d(nn.Module):
11 | def __init__(
12 | self,
13 | in_channels,
14 | out_channels,
15 | kernel_size,
16 | stride: Union[int, Tuple[int, int, int]] = 1,
17 | padding: Union[int, Tuple[int, int, int]] = 0,
18 | dilation: Union[int, Tuple[int, int, int]] = 1,
19 | groups=1,
20 | bias=True,
21 | padding_mode="zeros",
22 | ):
23 | super(DualConv3d, self).__init__()
24 |
25 | self.in_channels = in_channels
26 | self.out_channels = out_channels
27 | self.padding_mode = padding_mode
28 | # Ensure kernel_size, stride, padding, and dilation are tuples of length 3
29 | if isinstance(kernel_size, int):
30 | kernel_size = (kernel_size, kernel_size, kernel_size)
31 | if kernel_size == (1, 1, 1):
32 | raise ValueError(
33 | "kernel_size must be greater than 1. Use make_linear_nd instead."
34 | )
35 | if isinstance(stride, int):
36 | stride = (stride, stride, stride)
37 | if isinstance(padding, int):
38 | padding = (padding, padding, padding)
39 | if isinstance(dilation, int):
40 | dilation = (dilation, dilation, dilation)
41 |
42 | # Set parameters for convolutions
43 | self.groups = groups
44 | self.bias = bias
45 |
46 | # Define the size of the channels after the first convolution
47 | intermediate_channels = (
48 | out_channels if in_channels < out_channels else in_channels
49 | )
50 |
51 | # Define parameters for the first convolution
52 | self.weight1 = nn.Parameter(
53 | torch.Tensor(
54 | intermediate_channels,
55 | in_channels // groups,
56 | 1,
57 | kernel_size[1],
58 | kernel_size[2],
59 | )
60 | )
61 | self.stride1 = (1, stride[1], stride[2])
62 | self.padding1 = (0, padding[1], padding[2])
63 | self.dilation1 = (1, dilation[1], dilation[2])
64 | if bias:
65 | self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
66 | else:
67 | self.register_parameter("bias1", None)
68 |
69 | # Define parameters for the second convolution
70 | self.weight2 = nn.Parameter(
71 | torch.Tensor(
72 | out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
73 | )
74 | )
75 | self.stride2 = (stride[0], 1, 1)
76 | self.padding2 = (padding[0], 0, 0)
77 | self.dilation2 = (dilation[0], 1, 1)
78 | if bias:
79 | self.bias2 = nn.Parameter(torch.Tensor(out_channels))
80 | else:
81 | self.register_parameter("bias2", None)
82 |
83 | # Initialize weights and biases
84 | self.reset_parameters()
85 |
86 | def reset_parameters(self):
87 | nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
88 | nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
89 | if self.bias:
90 | fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
91 | bound1 = 1 / math.sqrt(fan_in1)
92 | nn.init.uniform_(self.bias1, -bound1, bound1)
93 | fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
94 | bound2 = 1 / math.sqrt(fan_in2)
95 | nn.init.uniform_(self.bias2, -bound2, bound2)
96 |
97 | def forward(self, x, use_conv3d=False, skip_time_conv=False):
98 | if use_conv3d:
99 | return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
100 | else:
101 | return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
102 |
103 | def forward_with_3d(self, x, skip_time_conv):
104 | # First convolution
105 | x = F.conv3d(
106 | x,
107 | self.weight1,
108 | self.bias1,
109 | self.stride1,
110 | self.padding1,
111 | self.dilation1,
112 | self.groups,
113 |             # note: F.conv3d has no padding_mode argument; zero padding is used
114 | )
115 |
116 | if skip_time_conv:
117 | return x
118 |
119 | # Second convolution
120 | x = F.conv3d(
121 | x,
122 | self.weight2,
123 | self.bias2,
124 | self.stride2,
125 | self.padding2,
126 | self.dilation2,
127 | self.groups,
128 |             # note: F.conv3d has no padding_mode argument; zero padding is used
129 | )
130 |
131 | return x
132 |
133 | def forward_with_2d(self, x, skip_time_conv):
134 | b, c, d, h, w = x.shape
135 |
136 | # First 2D convolution
137 | x = rearrange(x, "b c d h w -> (b d) c h w")
138 | # Squeeze the depth dimension out of weight1 since it's 1
139 | weight1 = self.weight1.squeeze(2)
140 | # Select stride, padding, and dilation for the 2D convolution
141 | stride1 = (self.stride1[1], self.stride1[2])
142 | padding1 = (self.padding1[1], self.padding1[2])
143 | dilation1 = (self.dilation1[1], self.dilation1[2])
144 | x = F.conv2d(
145 | x,
146 | weight1,
147 | self.bias1,
148 | stride1,
149 | padding1,
150 | dilation1,
151 | self.groups,
152 |             # note: F.conv2d has no padding_mode argument; zero padding is used
153 | )
154 |
155 | _, _, h, w = x.shape
156 |
157 | if skip_time_conv:
158 | x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
159 | return x
160 |
161 | # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
162 | x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
163 |
164 | # Reshape weight2 to match the expected dimensions for conv1d
165 | weight2 = self.weight2.squeeze(-1).squeeze(-1)
166 | # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
167 | stride2 = self.stride2[0]
168 | padding2 = self.padding2[0]
169 | dilation2 = self.dilation2[0]
170 | x = F.conv1d(
171 | x,
172 | weight2,
173 | self.bias2,
174 | stride2,
175 | padding2,
176 | dilation2,
177 | self.groups,
178 |             # note: F.conv1d has no padding_mode argument; zero padding is used
179 | )
180 | x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
181 |
182 | return x
183 |
184 | @property
185 | def weight(self):
186 | return self.weight2
187 |
188 |
189 | def test_dual_conv3d_consistency():
190 | # Initialize parameters
191 | in_channels = 3
192 | out_channels = 5
193 | kernel_size = (3, 3, 3)
194 | stride = (2, 2, 2)
195 | padding = (1, 1, 1)
196 |
197 | # Create an instance of the DualConv3d class
198 | dual_conv3d = DualConv3d(
199 | in_channels=in_channels,
200 | out_channels=out_channels,
201 | kernel_size=kernel_size,
202 | stride=stride,
203 | padding=padding,
204 | bias=True,
205 | )
206 |
207 | # Example input tensor
208 | test_input = torch.randn(1, 3, 10, 10, 10)
209 |
210 | # Perform forward passes with both 3D and 2D settings
211 | output_conv3d = dual_conv3d(test_input, use_conv3d=True)
212 | output_2d = dual_conv3d(test_input, use_conv3d=False)
213 |
214 | # Assert that the outputs from both methods are sufficiently close
215 | assert torch.allclose(
216 | output_conv3d, output_2d, atol=1e-6
217 | ), "Outputs are not consistent between 3D and 2D convolutions."
218 |
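
For reference, a shape sketch (assuming the package is importable): the factorized path applies a (1, k, k) spatial convolution followed by a (k, 1, 1) temporal one, so with stride 2 and padding 1 every dimension is halved, and skip_time_conv stops after the spatial stage.

# Sketch only, not part of the module.
import torch
from ltx_video.models.autoencoders.dual_conv3d import DualConv3d

conv = DualConv3d(in_channels=3, out_channels=5, kernel_size=3, stride=2, padding=1)
x = torch.randn(1, 3, 10, 16, 16)
print(conv(x).shape)                       # torch.Size([1, 5, 5, 8, 8])
print(conv(x, skip_time_conv=True).shape)  # torch.Size([1, 5, 10, 8, 8])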
--------------------------------------------------------------------------------
/ltx_video/models/autoencoders/latent_upsampler.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 | from pathlib import Path
3 | import os
4 | import json
5 |
6 | import torch
7 | import torch.nn as nn
8 | from einops import rearrange
9 | from diffusers import ConfigMixin, ModelMixin
10 | from safetensors.torch import safe_open
11 |
12 | from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND
13 |
14 |
15 | class ResBlock(nn.Module):
16 | def __init__(
17 | self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
18 | ):
19 | super().__init__()
20 | if mid_channels is None:
21 | mid_channels = channels
22 |
23 | Conv = nn.Conv2d if dims == 2 else nn.Conv3d
24 |
25 | self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
26 | self.norm1 = nn.GroupNorm(32, mid_channels)
27 | self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
28 | self.norm2 = nn.GroupNorm(32, channels)
29 | self.activation = nn.SiLU()
30 |
31 | def forward(self, x: torch.Tensor) -> torch.Tensor:
32 | residual = x
33 | x = self.conv1(x)
34 | x = self.norm1(x)
35 | x = self.activation(x)
36 | x = self.conv2(x)
37 | x = self.norm2(x)
38 | x = self.activation(x + residual)
39 | return x
40 |
41 |
42 | class LatentUpsampler(ModelMixin, ConfigMixin):
43 | """
44 | Model to spatially upsample VAE latents.
45 |
46 | Args:
47 | in_channels (`int`): Number of channels in the input latent
48 | mid_channels (`int`): Number of channels in the middle layers
49 | num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
50 | dims (`int`): Number of dimensions for convolutions (2 or 3)
51 | spatial_upsample (`bool`): Whether to spatially upsample the latent
52 | temporal_upsample (`bool`): Whether to temporally upsample the latent
53 | """
54 |
55 | def __init__(
56 | self,
57 | in_channels: int = 128,
58 | mid_channels: int = 512,
59 | num_blocks_per_stage: int = 4,
60 | dims: int = 3,
61 | spatial_upsample: bool = True,
62 | temporal_upsample: bool = False,
63 | ):
64 | super().__init__()
65 |
66 | self.in_channels = in_channels
67 | self.mid_channels = mid_channels
68 | self.num_blocks_per_stage = num_blocks_per_stage
69 | self.dims = dims
70 | self.spatial_upsample = spatial_upsample
71 | self.temporal_upsample = temporal_upsample
72 |
73 | Conv = nn.Conv2d if dims == 2 else nn.Conv3d
74 |
75 | self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
76 | self.initial_norm = nn.GroupNorm(32, mid_channels)
77 | self.initial_activation = nn.SiLU()
78 |
79 | self.res_blocks = nn.ModuleList(
80 | [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
81 | )
82 |
83 | if spatial_upsample and temporal_upsample:
84 | self.upsampler = nn.Sequential(
85 | nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
86 | PixelShuffleND(3),
87 | )
88 | elif spatial_upsample:
89 | self.upsampler = nn.Sequential(
90 | nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
91 | PixelShuffleND(2),
92 | )
93 | elif temporal_upsample:
94 | self.upsampler = nn.Sequential(
95 | nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
96 | PixelShuffleND(1),
97 | )
98 | else:
99 | raise ValueError(
100 | "Either spatial_upsample or temporal_upsample must be True"
101 | )
102 |
103 | self.post_upsample_res_blocks = nn.ModuleList(
104 | [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
105 | )
106 |
107 | self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)
108 |
109 | def forward(self, latent: torch.Tensor) -> torch.Tensor:
110 | b, c, f, h, w = latent.shape
111 |
112 | if self.dims == 2:
113 | x = rearrange(latent, "b c f h w -> (b f) c h w")
114 | x = self.initial_conv(x)
115 | x = self.initial_norm(x)
116 | x = self.initial_activation(x)
117 |
118 | for block in self.res_blocks:
119 | x = block(x)
120 |
121 | x = self.upsampler(x)
122 |
123 | for block in self.post_upsample_res_blocks:
124 | x = block(x)
125 |
126 | x = self.final_conv(x)
127 | x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
128 | else:
129 | x = self.initial_conv(latent)
130 | x = self.initial_norm(x)
131 | x = self.initial_activation(x)
132 |
133 | for block in self.res_blocks:
134 | x = block(x)
135 |
136 | if self.temporal_upsample:
137 | x = self.upsampler(x)
138 | x = x[:, :, 1:, :, :]
139 | else:
140 | x = rearrange(x, "b c f h w -> (b f) c h w")
141 | x = self.upsampler(x)
142 | x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
143 |
144 | for block in self.post_upsample_res_blocks:
145 | x = block(x)
146 |
147 | x = self.final_conv(x)
148 |
149 | return x
150 |
151 | @classmethod
152 | def from_config(cls, config):
153 | return cls(
154 | in_channels=config.get("in_channels", 4),
155 | mid_channels=config.get("mid_channels", 128),
156 | num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
157 | dims=config.get("dims", 2),
158 | spatial_upsample=config.get("spatial_upsample", True),
159 | temporal_upsample=config.get("temporal_upsample", False),
160 | )
161 |
162 | def config(self):
163 | return {
164 | "_class_name": "LatentUpsampler",
165 | "in_channels": self.in_channels,
166 | "mid_channels": self.mid_channels,
167 | "num_blocks_per_stage": self.num_blocks_per_stage,
168 | "dims": self.dims,
169 | "spatial_upsample": self.spatial_upsample,
170 | "temporal_upsample": self.temporal_upsample,
171 | }
172 |
173 | @classmethod
174 | def from_pretrained(
175 | cls,
176 | pretrained_model_path: Optional[Union[str, os.PathLike]],
177 | *args,
178 | **kwargs,
179 | ):
180 | pretrained_model_path = Path(pretrained_model_path)
181 | if pretrained_model_path.is_file() and str(pretrained_model_path).endswith(
182 | ".safetensors"
183 | ):
184 | state_dict = {}
185 | with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
186 | metadata = f.metadata()
187 | for k in f.keys():
188 | state_dict[k] = f.get_tensor(k)
189 | config = json.loads(metadata["config"])
190 | with torch.device("meta"):
191 | latent_upsampler = LatentUpsampler.from_config(config)
192 | latent_upsampler.load_state_dict(state_dict, assign=True)
193 | return latent_upsampler
194 |
195 |
196 | if __name__ == "__main__":
197 | latent_upsampler = LatentUpsampler(num_blocks_per_stage=4, dims=3)
198 | print(latent_upsampler)
199 | total_params = sum(p.numel() for p in latent_upsampler.parameters())
200 | print(f"Total number of parameters: {total_params:,}")
201 | latent = torch.randn(1, 128, 9, 16, 16)
202 | upsampled_latent = latent_upsampler(latent)
203 | print(f"Upsampled latent shape: {upsampled_latent.shape}")
204 |
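
The __main__ block above already exercises the default configuration; for completeness, the spatial-only path rearranges to 2D, applies PixelShuffleND(2), and restores the frame axis, so a (1, 128, 9, 16, 16) latent comes back with height and width doubled and channels and frame count unchanged (a sketch, assuming the package is importable).

# Sketch only, not part of the module.
import torch
from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler

upsampler = LatentUpsampler(in_channels=128, mid_channels=512, dims=3).eval()
with torch.no_grad():
    out = upsampler(torch.randn(1, 128, 9, 16, 16))
print(out.shape)  # torch.Size([1, 128, 9, 32, 32])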
--------------------------------------------------------------------------------
/ltx_video/models/autoencoders/pixel_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class PixelNorm(nn.Module):
6 | def __init__(self, dim=1, eps=1e-8):
7 | super(PixelNorm, self).__init__()
8 | self.dim = dim
9 | self.eps = eps
10 |
11 | def forward(self, x):
12 | return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)
13 |
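
A tiny sketch of the effect (assuming the package is importable): each position is rescaled to unit RMS across the channel dimension (dim=1 by default).

# Sketch only, not part of the module.
import torch
from ltx_video.models.autoencoders.pixel_norm import PixelNorm

x = torch.randn(2, 8, 4, 4)
y = PixelNorm()(x)
print(y.pow(2).mean(dim=1).sqrt())  # ~1.0 at every position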
--------------------------------------------------------------------------------
/ltx_video/models/autoencoders/pixel_shuffle.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from einops import rearrange
3 |
4 |
5 | class PixelShuffleND(nn.Module):
6 | def __init__(self, dims, upscale_factors=(2, 2, 2)):
7 | super().__init__()
8 | assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
9 | self.dims = dims
10 | self.upscale_factors = upscale_factors
11 |
12 | def forward(self, x):
13 | if self.dims == 3:
14 | return rearrange(
15 | x,
16 | "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
17 | p1=self.upscale_factors[0],
18 | p2=self.upscale_factors[1],
19 | p3=self.upscale_factors[2],
20 | )
21 | elif self.dims == 2:
22 | return rearrange(
23 | x,
24 | "b (c p1 p2) h w -> b c (h p1) (w p2)",
25 | p1=self.upscale_factors[0],
26 | p2=self.upscale_factors[1],
27 | )
28 | elif self.dims == 1:
29 | return rearrange(
30 | x,
31 | "b (c p1) f h w -> b c (f p1) h w",
32 | p1=self.upscale_factors[0],
33 | )
34 |
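
A shape sketch (assuming the package is importable): the shuffle trades channels for resolution, so dims=2 turns (b, 4c, h, w) into (b, c, 2h, 2w), and dims=3 consumes a factor of 8 in channels to double depth, height, and width.

# Sketch only, not part of the module.
import torch
from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND

print(PixelShuffleND(2)(torch.randn(1, 64, 8, 8)).shape)     # torch.Size([1, 16, 16, 16])
print(PixelShuffleND(3)(torch.randn(1, 64, 4, 8, 8)).shape)  # torch.Size([1, 8, 8, 16, 16])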
--------------------------------------------------------------------------------
/ltx_video/models/autoencoders/vae.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | import torch
4 | import inspect
5 | import math
6 | import torch.nn as nn
7 | from diffusers import ConfigMixin, ModelMixin
8 | from diffusers.models.autoencoders.vae import (
9 | DecoderOutput,
10 | DiagonalGaussianDistribution,
11 | )
12 | from diffusers.models.modeling_outputs import AutoencoderKLOutput
13 | from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd
14 |
15 |
16 | class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
17 | """Variational Autoencoder (VAE) model with KL loss.
18 |
19 | VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling.
20 | This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss.
21 |
22 | Args:
23 | encoder (`nn.Module`):
24 | Encoder module.
25 | decoder (`nn.Module`):
26 | Decoder module.
27 | latent_channels (`int`, *optional*, defaults to 4):
28 | Number of latent channels.
29 | """
30 |
31 | def __init__(
32 | self,
33 | encoder: nn.Module,
34 | decoder: nn.Module,
35 | latent_channels: int = 4,
36 | dims: int = 2,
37 | sample_size=512,
38 | use_quant_conv: bool = True,
39 | normalize_latent_channels: bool = False,
40 | ):
41 | super().__init__()
42 |
43 | # pass init params to Encoder
44 | self.encoder = encoder
45 | self.use_quant_conv = use_quant_conv
46 | self.normalize_latent_channels = normalize_latent_channels
47 |
48 | # pass init params to Decoder
49 | quant_dims = 2 if dims == 2 else 3
50 | self.decoder = decoder
51 | if use_quant_conv:
52 | self.quant_conv = make_conv_nd(
53 | quant_dims, 2 * latent_channels, 2 * latent_channels, 1
54 | )
55 | self.post_quant_conv = make_conv_nd(
56 | quant_dims, latent_channels, latent_channels, 1
57 | )
58 | else:
59 | self.quant_conv = nn.Identity()
60 | self.post_quant_conv = nn.Identity()
61 |
62 | if normalize_latent_channels:
63 | if dims == 2:
64 | self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False)
65 | else:
66 | self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False)
67 | else:
68 | self.latent_norm_out = nn.Identity()
69 | self.use_z_tiling = False
70 | self.use_hw_tiling = False
71 | self.dims = dims
72 | self.z_sample_size = 1
73 |
74 | self.decoder_params = inspect.signature(self.decoder.forward).parameters
75 |
76 | # only relevant if vae tiling is enabled
77 | self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25)
78 |
79 | def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25):
80 | self.tile_sample_min_size = sample_size
81 | num_blocks = len(self.encoder.down_blocks)
82 | self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1)))
83 | self.tile_overlap_factor = overlap_factor
84 |
85 | def enable_z_tiling(self, z_sample_size: int = 8):
86 | r"""
87 | Enable tiling during VAE decoding.
88 |
89 | When this option is enabled, the VAE will split the input tensor in tiles to compute decoding in several
90 | steps. This is useful to save some memory and allow larger batch sizes.
91 | """
92 | self.use_z_tiling = z_sample_size > 1
93 | self.z_sample_size = z_sample_size
94 | assert (
95 | z_sample_size % 8 == 0 or z_sample_size == 1
96 | ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}."
97 |
98 | def disable_z_tiling(self):
99 | r"""
100 |         Disable tiling during VAE decoding. If `enable_z_tiling` was previously invoked, this method will go back to computing
101 | decoding in one step.
102 | """
103 | self.use_z_tiling = False
104 |
105 | def enable_hw_tiling(self):
106 | r"""
107 | Enable tiling during VAE decoding along the height and width dimension.
108 | """
109 | self.use_hw_tiling = True
110 |
111 | def disable_hw_tiling(self):
112 | r"""
113 | Disable tiling during VAE decoding along the height and width dimension.
114 | """
115 | self.use_hw_tiling = False
116 |
117 | def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True):
118 | overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
119 | blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
120 | row_limit = self.tile_latent_min_size - blend_extent
121 |
122 | # Split the image into 512x512 tiles and encode them separately.
123 | rows = []
124 | for i in range(0, x.shape[3], overlap_size):
125 | row = []
126 | for j in range(0, x.shape[4], overlap_size):
127 | tile = x[
128 | :,
129 | :,
130 | :,
131 | i : i + self.tile_sample_min_size,
132 | j : j + self.tile_sample_min_size,
133 | ]
134 | tile = self.encoder(tile)
135 | tile = self.quant_conv(tile)
136 | row.append(tile)
137 | rows.append(row)
138 | result_rows = []
139 | for i, row in enumerate(rows):
140 | result_row = []
141 | for j, tile in enumerate(row):
142 | # blend the above tile and the left tile
143 | # to the current tile and add the current tile to the result row
144 | if i > 0:
145 | tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
146 | if j > 0:
147 | tile = self.blend_h(row[j - 1], tile, blend_extent)
148 | result_row.append(tile[:, :, :, :row_limit, :row_limit])
149 | result_rows.append(torch.cat(result_row, dim=4))
150 |
151 | moments = torch.cat(result_rows, dim=3)
152 | return moments
153 |
154 | def blend_z(
155 | self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
156 | ) -> torch.Tensor:
157 | blend_extent = min(a.shape[2], b.shape[2], blend_extent)
158 | for z in range(blend_extent):
159 | b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (
160 | 1 - z / blend_extent
161 | ) + b[:, :, z, :, :] * (z / blend_extent)
162 | return b
163 |
164 | def blend_v(
165 | self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
166 | ) -> torch.Tensor:
167 | blend_extent = min(a.shape[3], b.shape[3], blend_extent)
168 | for y in range(blend_extent):
169 | b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
170 | 1 - y / blend_extent
171 | ) + b[:, :, :, y, :] * (y / blend_extent)
172 | return b
173 |
174 | def blend_h(
175 | self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
176 | ) -> torch.Tensor:
177 | blend_extent = min(a.shape[4], b.shape[4], blend_extent)
178 | for x in range(blend_extent):
179 | b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
180 | 1 - x / blend_extent
181 | ) + b[:, :, :, :, x] * (x / blend_extent)
182 | return b
183 |
184 | def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape):
185 | overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
186 | blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
187 | row_limit = self.tile_sample_min_size - blend_extent
188 | tile_target_shape = (
189 | *target_shape[:3],
190 | self.tile_sample_min_size,
191 | self.tile_sample_min_size,
192 | )
193 | # Split z into overlapping tiles of size tile_latent_min_size and decode them separately.
194 | # The tiles have an overlap to avoid seams between tiles.
195 | rows = []
196 | for i in range(0, z.shape[3], overlap_size):
197 | row = []
198 | for j in range(0, z.shape[4], overlap_size):
199 | tile = z[
200 | :,
201 | :,
202 | :,
203 | i : i + self.tile_latent_min_size,
204 | j : j + self.tile_latent_min_size,
205 | ]
206 | tile = self.post_quant_conv(tile)
207 | decoded = self.decoder(tile, target_shape=tile_target_shape)
208 | row.append(decoded)
209 | rows.append(row)
210 | result_rows = []
211 | for i, row in enumerate(rows):
212 | result_row = []
213 | for j, tile in enumerate(row):
214 | # blend the above tile and the left tile
215 | # to the current tile and add the current tile to the result row
216 | if i > 0:
217 | tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
218 | if j > 0:
219 | tile = self.blend_h(row[j - 1], tile, blend_extent)
220 | result_row.append(tile[:, :, :, :row_limit, :row_limit])
221 | result_rows.append(torch.cat(result_row, dim=4))
222 |
223 | dec = torch.cat(result_rows, dim=3)
224 | return dec
225 |
226 | def encode(
227 | self, z: torch.FloatTensor, return_dict: bool = True
228 | ) -> Union[DecoderOutput, torch.FloatTensor]:
229 | if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
230 | num_splits = z.shape[2] // self.z_sample_size
231 | sizes = [self.z_sample_size] * num_splits
232 | sizes = (
233 | sizes + [z.shape[2] - sum(sizes)]
234 | if z.shape[2] - sum(sizes) > 0
235 | else sizes
236 | )
237 | tiles = z.split(sizes, dim=2)
238 | moments_tiles = [
239 | (
240 | self._hw_tiled_encode(z_tile, return_dict)
241 | if self.use_hw_tiling
242 | else self._encode(z_tile)
243 | )
244 | for z_tile in tiles
245 | ]
246 | moments = torch.cat(moments_tiles, dim=2)
247 |
248 | else:
249 | moments = (
250 | self._hw_tiled_encode(z, return_dict)
251 | if self.use_hw_tiling
252 | else self._encode(z)
253 | )
254 |
255 | posterior = DiagonalGaussianDistribution(moments)
256 | if not return_dict:
257 | return (posterior,)
258 |
259 | return AutoencoderKLOutput(latent_dist=posterior)
260 |
261 | def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
262 | if isinstance(self.latent_norm_out, nn.BatchNorm3d):
263 | _, c, _, _, _ = z.shape
264 | z = torch.cat(
265 | [
266 | self.latent_norm_out(z[:, : c // 2, :, :, :]),
267 | z[:, c // 2 :, :, :, :],
268 | ],
269 | dim=1,
270 | )
271 | elif isinstance(self.latent_norm_out, nn.BatchNorm2d):
272 | raise NotImplementedError("BatchNorm2d not supported")
273 | return z
274 |
275 | def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
276 | if isinstance(self.latent_norm_out, nn.BatchNorm3d):
277 | running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1)
278 | running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1)
279 | eps = self.latent_norm_out.eps
280 |
281 | z = z * torch.sqrt(running_var + eps) + running_mean
282 | elif isinstance(self.latent_norm_out, nn.BatchNorm2d):
283 | raise NotImplementedError("BatchNorm2d not supported")
284 | return z
285 |
286 | def _encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
287 | h = self.encoder(x)
288 | moments = self.quant_conv(h)
289 | moments = self._normalize_latent_channels(moments)
290 | return moments
291 |
292 | def _decode(
293 | self,
294 | z: torch.FloatTensor,
295 | target_shape=None,
296 | timestep: Optional[torch.Tensor] = None,
297 | ) -> Union[DecoderOutput, torch.FloatTensor]:
298 | z = self._unnormalize_latent_channels(z)
299 | z = self.post_quant_conv(z)
300 | if "timestep" in self.decoder_params:
301 | dec = self.decoder(z, target_shape=target_shape, timestep=timestep)
302 | else:
303 | dec = self.decoder(z, target_shape=target_shape)
304 | return dec
305 |
306 | def decode(
307 | self,
308 | z: torch.FloatTensor,
309 | return_dict: bool = True,
310 | target_shape=None,
311 | timestep: Optional[torch.Tensor] = None,
312 | ) -> Union[DecoderOutput, torch.FloatTensor]:
313 | assert target_shape is not None, "target_shape must be provided for decoding"
314 | if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
315 | reduction_factor = int(
316 | self.encoder.patch_size_t
317 | * 2
318 | ** (
319 | len(self.encoder.down_blocks)
320 | - 1
321 | - math.sqrt(self.encoder.patch_size)
322 | )
323 | )
324 | split_size = self.z_sample_size // reduction_factor
325 | num_splits = z.shape[2] // split_size
326 |
327 | # copy target shape, and divide frame dimension (=2) by the context size
328 | target_shape_split = list(target_shape)
329 | target_shape_split[2] = target_shape[2] // num_splits
330 |
331 | decoded_tiles = [
332 | (
333 | self._hw_tiled_decode(z_tile, target_shape_split)
334 | if self.use_hw_tiling
335 | else self._decode(z_tile, target_shape=target_shape_split)
336 | )
337 | for z_tile in torch.tensor_split(z, num_splits, dim=2)
338 | ]
339 | decoded = torch.cat(decoded_tiles, dim=2)
340 | else:
341 | decoded = (
342 | self._hw_tiled_decode(z, target_shape)
343 | if self.use_hw_tiling
344 | else self._decode(z, target_shape=target_shape, timestep=timestep)
345 | )
346 |
347 | if not return_dict:
348 | return (decoded,)
349 |
350 | return DecoderOutput(sample=decoded)
351 |
352 | def forward(
353 | self,
354 | sample: torch.FloatTensor,
355 | sample_posterior: bool = False,
356 | return_dict: bool = True,
357 | generator: Optional[torch.Generator] = None,
358 | ) -> Union[DecoderOutput, torch.FloatTensor]:
359 | r"""
360 | Args:
361 | sample (`torch.FloatTensor`): Input sample.
362 | sample_posterior (`bool`, *optional*, defaults to `False`):
363 | Whether to sample from the posterior.
364 | return_dict (`bool`, *optional*, defaults to `True`):
365 | Whether to return a [`DecoderOutput`] instead of a plain tuple.
366 | generator (`torch.Generator`, *optional*):
367 | Generator used to sample from the posterior.
368 | """
369 | x = sample
370 | posterior = self.encode(x).latent_dist
371 | if sample_posterior:
372 | z = posterior.sample(generator=generator)
373 | else:
374 | z = posterior.mode()
375 | dec = self.decode(z, target_shape=sample.shape).sample
376 |
377 | if not return_dict:
378 | return (dec,)
379 |
380 | return DecoderOutput(sample=dec)
381 |
--------------------------------------------------------------------------------
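For readers following the tiled decoding above, here is a minimal standalone sketch (not part of the repository) of the same linear cross-fade that `blend_h`/`blend_v` implement; the helper name `blend_horizontal` and the toy tensor shapes are illustrative only.

# Standalone sketch of the seam blending used by the hw-tiled encode/decode:
# the trailing `blend_extent` columns of tile `a` fade out while the leading
# columns of tile `b` fade in, hiding the seam between neighbouring tiles.
import torch


def blend_horizontal(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Blend the right edge of `a` into the left edge of `b` along the last dim."""
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    for x in range(blend_extent):
        w = x / blend_extent  # 0 -> keep `a`, 1 -> keep `b`
        b[..., x] = a[..., -blend_extent + x] * (1 - w) + b[..., x] * w
    return b


if __name__ == "__main__":
    left = torch.zeros(1, 3, 4, 8, 16)   # (b, c, f, h, w) tile filled with 0
    right = torch.ones(1, 3, 4, 8, 16)   # neighbouring tile filled with 1
    blended = blend_horizontal(left, right.clone(), blend_extent=8)
    print(blended[0, 0, 0, 0, :8])  # ramps smoothly from 0.0 up towards 1.0
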
/ltx_video/models/autoencoders/vae_encode.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import torch
3 | from diffusers import AutoencoderKL
4 | from einops import rearrange
5 | from torch import Tensor
6 |
7 |
8 | from ltx_video.models.autoencoders.causal_video_autoencoder import (
9 | CausalVideoAutoencoder,
10 | )
11 | from ltx_video.models.autoencoders.video_autoencoder import (
12 | Downsample3D,
13 | VideoAutoencoder,
14 | )
15 |
16 | try:
17 | import torch_xla.core.xla_model as xm
18 | except ImportError:
19 | xm = None
20 |
21 |
22 | def vae_encode(
23 | media_items: Tensor,
24 | vae: AutoencoderKL,
25 | split_size: int = 1,
26 | vae_per_channel_normalize=False,
27 | ) -> Tensor:
28 | """
29 | Encodes media items (images or videos) into latent representations using a specified VAE model.
30 | The function supports processing batches of images or video frames and can handle the processing
31 | in smaller sub-batches if needed.
32 |
33 | Args:
34 | media_items (Tensor): A torch Tensor containing the media items to encode. The expected
35 | shape is (batch_size, channels, height, width) for images or (batch_size, channels,
36 | frames, height, width) for videos.
37 | vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library,
38 | pre-configured and loaded with the appropriate model weights.
39 | split_size (int, optional): The number of sub-batches to split the input batch into for encoding.
40 | If set to more than 1, the input media items are processed in smaller batches according to
41 | this value. Defaults to 1, which processes all items in a single batch.
42 |
43 | Returns:
44 | Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted
45 | to match the input shape, scaled by the model's configuration.
46 |
47 | Examples:
48 | >>> import torch
49 | >>> from diffusers import AutoencoderKL
50 | >>> vae = AutoencoderKL.from_pretrained('your-model-name')
51 | >>> images = torch.rand(10, 3, 8, 256, 256)  # Example tensor with 10 videos of 8 frames.
52 | >>> latents = vae_encode(images, vae)
53 | >>> print(latents.shape) # Output shape will depend on the model's latent configuration.
54 |
55 | Note:
56 | In the case of a video, the function encodes the media item frame by frame.
57 | """
58 | is_video_shaped = media_items.dim() == 5
59 | batch_size, channels = media_items.shape[0:2]
60 |
61 | if channels != 3:
62 | raise ValueError(f"Expects tensors with 3 channels, got {channels}.")
63 |
64 | if is_video_shaped and not isinstance(
65 | vae, (VideoAutoencoder, CausalVideoAutoencoder)
66 | ):
67 | media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
68 | if split_size > 1:
69 | if len(media_items) % split_size != 0:
70 | raise ValueError(
71 | "Error: The batch size must be divisible by 'train.vae_bs_split"
72 | )
73 | encode_bs = len(media_items) // split_size
74 | # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
75 | latents = []
76 | if media_items.device.type == "xla":
77 | xm.mark_step()
78 | for image_batch in media_items.split(encode_bs):
79 | latents.append(vae.encode(image_batch).latent_dist.sample())
80 | if media_items.device.type == "xla":
81 | xm.mark_step()
82 | latents = torch.cat(latents, dim=0)
83 | else:
84 | latents = vae.encode(media_items).latent_dist.sample()
85 |
86 | latents = normalize_latents(latents, vae, vae_per_channel_normalize)
87 | if is_video_shaped and not isinstance(
88 | vae, (VideoAutoencoder, CausalVideoAutoencoder)
89 | ):
90 | latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
91 | return latents
92 |
93 |
94 | def vae_decode(
95 | latents: Tensor,
96 | vae: AutoencoderKL,
97 | is_video: bool = True,
98 | split_size: int = 1,
99 | vae_per_channel_normalize=False,
100 | timestep=None,
101 | ) -> Tensor:
102 | is_video_shaped = latents.dim() == 5
103 | batch_size = latents.shape[0]
104 |
105 | if is_video_shaped and not isinstance(
106 | vae, (VideoAutoencoder, CausalVideoAutoencoder)
107 | ):
108 | latents = rearrange(latents, "b c n h w -> (b n) c h w")
109 | if split_size > 1:
110 | if len(latents) % split_size != 0:
111 | raise ValueError(
112 | "Error: The batch size must be divisible by 'train.vae_bs_split"
113 | )
114 | encode_bs = len(latents) // split_size
115 | image_batch = [
116 | _run_decoder(
117 | latent_batch, vae, is_video, vae_per_channel_normalize, timestep
118 | )
119 | for latent_batch in latents.split(encode_bs)
120 | ]
121 | images = torch.cat(image_batch, dim=0)
122 | else:
123 | images = _run_decoder(
124 | latents, vae, is_video, vae_per_channel_normalize, timestep
125 | )
126 |
127 | if is_video_shaped and not isinstance(
128 | vae, (VideoAutoencoder, CausalVideoAutoencoder)
129 | ):
130 | images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
131 | return images
132 |
133 |
134 | def _run_decoder(
135 | latents: Tensor,
136 | vae: AutoencoderKL,
137 | is_video: bool,
138 | vae_per_channel_normalize=False,
139 | timestep=None,
140 | ) -> Tensor:
141 | if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
142 | *_, fl, hl, wl = latents.shape
143 | temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
144 | latents = latents.to(vae.dtype)
145 | vae_decode_kwargs = {}
146 | if timestep is not None:
147 | vae_decode_kwargs["timestep"] = timestep
148 | image = vae.decode(
149 | un_normalize_latents(latents, vae, vae_per_channel_normalize),
150 | return_dict=False,
151 | target_shape=(
152 | 1,
153 | 3,
154 | fl * temporal_scale if is_video else 1,
155 | hl * spatial_scale,
156 | wl * spatial_scale,
157 | ),
158 | **vae_decode_kwargs,
159 | )[0]
160 | else:
161 | image = vae.decode(
162 | un_normalize_latents(latents, vae, vae_per_channel_normalize),
163 | return_dict=False,
164 | )[0]
165 | return image
166 |
167 |
168 | def get_vae_size_scale_factor(vae: AutoencoderKL) -> Tuple[float, float, float]:
169 | if isinstance(vae, CausalVideoAutoencoder):
170 | spatial = vae.spatial_downscale_factor
171 | temporal = vae.temporal_downscale_factor
172 | else:
173 | down_blocks = len(
174 | [
175 | block
176 | for block in vae.encoder.down_blocks
177 | if isinstance(block.downsample, Downsample3D)
178 | ]
179 | )
180 | spatial = vae.config.patch_size * 2**down_blocks
181 | temporal = (
182 | vae.config.patch_size_t * 2**down_blocks
183 | if isinstance(vae, VideoAutoencoder)
184 | else 1
185 | )
186 |
187 | return (temporal, spatial, spatial)
188 |
189 |
190 | def latent_to_pixel_coords(
191 | latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False
192 | ) -> Tensor:
193 | """
194 | Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
195 | configuration.
196 |
197 | Args:
198 | latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
199 | containing the latent corner coordinates of each token.
200 | vae (AutoencoderKL): The VAE model
201 | causal_fix (bool): Whether to take into account the different temporal scale
202 | of the first frame. Default = False for backwards compatibility.
203 | Returns:
204 | Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
205 | """
206 |
207 | scale_factors = get_vae_size_scale_factor(vae)
208 | causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix
209 | pixel_coords = latent_to_pixel_coords_from_factors(
210 | latent_coords, scale_factors, causal_fix
211 | )
212 | return pixel_coords
213 |
214 |
215 | def latent_to_pixel_coords_from_factors(
216 | latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False
217 | ) -> Tensor:
218 | pixel_coords = (
219 | latent_coords
220 | * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
221 | )
222 | if causal_fix:
223 | # Fix temporal scale for first frame to 1 due to causality
224 | pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
225 | return pixel_coords
226 |
227 |
228 | def normalize_latents(
229 | latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
230 | ) -> Tensor:
231 | return (
232 | (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
233 | / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
234 | if vae_per_channel_normalize
235 | else latents * vae.config.scaling_factor
236 | )
237 |
238 |
239 | def un_normalize_latents(
240 | latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
241 | ) -> Tensor:
242 | return (
243 | latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
244 | + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
245 | if vae_per_channel_normalize
246 | else latents / vae.config.scaling_factor
247 | )
248 |
--------------------------------------------------------------------------------
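As a quick illustration of the coordinate scaling performed by `latent_to_pixel_coords_from_factors` above, here is a self-contained numeric sketch; the scale factors (8, 32, 32) are illustrative values, not read from any model config.

# Latent (frame, y, x) indices are multiplied by the VAE down-scale factors; the
# causal fix then shifts temporal coordinates so the first latent frame maps to
# pixel frame 0 instead of temporal_scale - 1.
import torch

scale_factors = (8, 32, 32)                     # (temporal, height, width), illustrative
latent_coords = torch.tensor([[[0, 1, 2],       # frame index per token
                               [0, 0, 0],       # y index per token
                               [0, 1, 2]]])     # x index per token, shape (b, 3, n)

pixel = latent_coords * torch.tensor(scale_factors)[None, :, None]
# Causal fix: pull the temporal coordinates back so the first frame starts at 0.
pixel[:, 0] = (pixel[:, 0] + 1 - scale_factors[0]).clamp(min=0)
print(pixel[0, 0])  # tensor([0, 1, 9])
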
/ltx_video/models/transformers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightricks/LTX-Video/34625c3a1fb4e9a7d2f091938f6989816343e270/ltx_video/models/transformers/__init__.py
--------------------------------------------------------------------------------
/ltx_video/models/transformers/embeddings.py:
--------------------------------------------------------------------------------
1 | # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py
2 | import math
3 |
4 | import numpy as np
5 | import torch
6 | from einops import rearrange
7 | from torch import nn
8 |
9 |
10 | def get_timestep_embedding(
11 | timesteps: torch.Tensor,
12 | embedding_dim: int,
13 | flip_sin_to_cos: bool = False,
14 | downscale_freq_shift: float = 1,
15 | scale: float = 1,
16 | max_period: int = 10000,
17 | ):
18 | """
19 | This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
20 |
21 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
22 | These may be fractional.
23 | :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
24 | embeddings. :return: an [N x dim] Tensor of positional embeddings.
25 | """
26 | assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
27 |
28 | half_dim = embedding_dim // 2
29 | exponent = -math.log(max_period) * torch.arange(
30 | start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
31 | )
32 | exponent = exponent / (half_dim - downscale_freq_shift)
33 |
34 | emb = torch.exp(exponent)
35 | emb = timesteps[:, None].float() * emb[None, :]
36 |
37 | # scale embeddings
38 | emb = scale * emb
39 |
40 | # concat sine and cosine embeddings
41 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
42 |
43 | # flip sine and cosine embeddings
44 | if flip_sin_to_cos:
45 | emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
46 |
47 | # zero pad
48 | if embedding_dim % 2 == 1:
49 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
50 | return emb
51 |
52 |
53 | def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f):
54 | """
55 | embed_dim: embedding dimension. grid: coordinate grid flattened as (f h w). w, h, f: grid width, height and
56 | number of frames. return: pos_embed of shape [f*h*w, embed_dim].
57 | """
58 | grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w)
59 | grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w)
60 | grid = grid.reshape([3, 1, w, h, f])
61 | pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
62 | pos_embed = pos_embed.transpose(1, 0, 2, 3)
63 | return rearrange(pos_embed, "h w f c -> (f h w) c")
64 |
65 |
66 | def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
67 | if embed_dim % 3 != 0:
68 | raise ValueError("embed_dim must be divisible by 3")
69 |
70 | # use a third of the dimensions to encode each of the frame, height and width axes
71 | emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (H*W*T, D/3)
72 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (H*W*T, D/3)
73 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (H*W*T, D/3)
74 |
75 | emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1) # (H*W*T, D)
76 | return emb
77 |
78 |
79 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
80 | """
81 | embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
82 | """
83 | if embed_dim % 2 != 0:
84 | raise ValueError("embed_dim must be divisible by 2")
85 |
86 | omega = np.arange(embed_dim // 2, dtype=np.float64)
87 | omega /= embed_dim / 2.0
88 | omega = 1.0 / 10000**omega # (D/2,)
89 |
90 | pos_shape = pos.shape
91 |
92 | pos = pos.reshape(-1)
93 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
94 | out = out.reshape([*pos_shape, -1])[0]
95 |
96 | emb_sin = np.sin(out) # (M, D/2)
97 | emb_cos = np.cos(out) # (M, D/2)
98 |
99 | emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (M, D)
100 | return emb
101 |
102 |
103 | class SinusoidalPositionalEmbedding(nn.Module):
104 | """Apply positional information to a sequence of embeddings.
105 |
106 | Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
107 | them
108 |
109 | Args:
110 | embed_dim: (int): Dimension of the positional embedding.
111 | max_seq_length: Maximum sequence length to apply positional embeddings
112 |
113 | """
114 |
115 | def __init__(self, embed_dim: int, max_seq_length: int = 32):
116 | super().__init__()
117 | position = torch.arange(max_seq_length).unsqueeze(1)
118 | div_term = torch.exp(
119 | torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
120 | )
121 | pe = torch.zeros(1, max_seq_length, embed_dim)
122 | pe[0, :, 0::2] = torch.sin(position * div_term)
123 | pe[0, :, 1::2] = torch.cos(position * div_term)
124 | self.register_buffer("pe", pe)
125 |
126 | def forward(self, x):
127 | _, seq_length, _ = x.shape
128 | x = x + self.pe[:, :seq_length]
129 | return x
130 |
--------------------------------------------------------------------------------
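For intuition about the sinusoidal timestep embedding above, here is a simplified standalone check (it drops the `downscale_freq_shift` and `scale` options handled by `get_timestep_embedding`); the dimension and timestep values are illustrative.

# Frequencies decay geometrically from 1 down to roughly 1/max_period, and the
# embedding concatenates the sine and cosine halves.
import math
import torch

t = torch.tensor([10.0, 500.0])          # two example timesteps
embedding_dim, max_period = 8, 10000
half = embedding_dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
emb = torch.cat([torch.sin(t[:, None] * freqs), torch.cos(t[:, None] * freqs)], dim=-1)
print(emb.shape)  # torch.Size([2, 8])
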
/ltx_video/models/transformers/symmetric_patchifier.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Tuple
3 |
4 | import torch
5 | from diffusers.configuration_utils import ConfigMixin
6 | from einops import rearrange
7 | from torch import Tensor
8 |
9 |
10 | class Patchifier(ConfigMixin, ABC):
11 | def __init__(self, patch_size: int):
12 | super().__init__()
13 | self._patch_size = (1, patch_size, patch_size)
14 |
15 | @abstractmethod
16 | def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
17 | raise NotImplementedError("Patchify method not implemented")
18 |
19 | @abstractmethod
20 | def unpatchify(
21 | self,
22 | latents: Tensor,
23 | output_height: int,
24 | output_width: int,
25 | out_channels: int,
26 | ) -> Tuple[Tensor, Tensor]:
27 | pass
28 |
29 | @property
30 | def patch_size(self):
31 | return self._patch_size
32 |
33 | def get_latent_coords(
34 | self, latent_num_frames, latent_height, latent_width, batch_size, device
35 | ):
36 | """
37 | Return a tensor of shape [batch_size, 3, num_patches] containing the
38 | top-left corner latent coordinates of each latent patch.
39 | The tensor is repeated for each batch element.
40 | """
41 | latent_sample_coords = torch.meshgrid(
42 | torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
43 | torch.arange(0, latent_height, self._patch_size[1], device=device),
44 | torch.arange(0, latent_width, self._patch_size[2], device=device),
45 | )
46 | latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
47 | latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
48 | latent_coords = rearrange(
49 | latent_coords, "b c f h w -> b c (f h w)", b=batch_size
50 | )
51 | return latent_coords
52 |
53 |
54 | class SymmetricPatchifier(Patchifier):
55 | def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
56 | b, _, f, h, w = latents.shape
57 | latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
58 | latents = rearrange(
59 | latents,
60 | "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
61 | p1=self._patch_size[0],
62 | p2=self._patch_size[1],
63 | p3=self._patch_size[2],
64 | )
65 | return latents, latent_coords
66 |
67 | def unpatchify(
68 | self,
69 | latents: Tensor,
70 | output_height: int,
71 | output_width: int,
72 | out_channels: int,
73 | ) -> Tuple[Tensor, Tensor]:
74 | output_height = output_height // self._patch_size[1]
75 | output_width = output_width // self._patch_size[2]
76 | latents = rearrange(
77 | latents,
78 | "b (f h w) (c p q) -> b c f (h p) (w q)",
79 | h=output_height,
80 | w=output_width,
81 | p=self._patch_size[1],
82 | q=self._patch_size[2],
83 | )
84 | return latents
85 |
--------------------------------------------------------------------------------
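The patchify/unpatchify pair above is a pure einops reshuffle; the following standalone round-trip uses the same rearrange patterns with an illustrative patch size of (1, 2, 2) and toy tensor shapes (not model defaults).

import torch
from einops import rearrange

p1, p2, p3 = 1, 2, 2
latents = torch.randn(1, 8, 4, 6, 6)  # (b, c, f, h, w)

# Fold each (p1, p2, p3) patch into the channel dimension of one token.
tokens = rearrange(
    latents, "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", p1=p1, p2=p2, p3=p3
)
print(tokens.shape)  # torch.Size([1, 36, 32]) -> 4*3*3 tokens, 8*1*2*2 channels each

# The inverse pattern restores the original layout exactly.
restored = rearrange(
    tokens, "b (f h w) (c p1 p2 p3) -> b c (f p1) (h p2) (w p3)",
    f=4, h=3, w=3, p1=p1, p2=p2, p3=p3,
)
print(torch.equal(restored, latents))  # True
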
/ltx_video/models/transformers/transformer3d.py:
--------------------------------------------------------------------------------
1 | # Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py
2 | import math
3 | from dataclasses import dataclass
4 | from typing import Any, Dict, List, Optional, Union
5 | import os
6 | import json
7 | import glob
8 | from pathlib import Path
9 |
10 | import torch
11 | from diffusers.configuration_utils import ConfigMixin, register_to_config
12 | from diffusers.models.embeddings import PixArtAlphaTextProjection
13 | from diffusers.models.modeling_utils import ModelMixin
14 | from diffusers.models.normalization import AdaLayerNormSingle
15 | from diffusers.utils import BaseOutput, is_torch_version
16 | from diffusers.utils import logging
17 | from torch import nn
18 | from safetensors import safe_open
19 |
20 |
21 | from ltx_video.models.transformers.attention import BasicTransformerBlock
22 | from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
23 |
24 | from ltx_video.utils.diffusers_config_mapping import (
25 | diffusers_and_ours_config_mapping,
26 | make_hashable_key,
27 | TRANSFORMER_KEYS_RENAME_DICT,
28 | )
29 |
30 |
31 | logger = logging.get_logger(__name__)
32 |
33 |
34 | @dataclass
35 | class Transformer3DModelOutput(BaseOutput):
36 | """
37 | The output of [`Transformer3DModel`].
38 |
39 | Args:
40 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
41 | The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
42 | distributions for the unnoised latent pixels.
43 | """
44 |
45 | sample: torch.FloatTensor
46 |
47 |
48 | class Transformer3DModel(ModelMixin, ConfigMixin):
49 | _supports_gradient_checkpointing = True
50 |
51 | @register_to_config
52 | def __init__(
53 | self,
54 | num_attention_heads: int = 16,
55 | attention_head_dim: int = 88,
56 | in_channels: Optional[int] = None,
57 | out_channels: Optional[int] = None,
58 | num_layers: int = 1,
59 | dropout: float = 0.0,
60 | norm_num_groups: int = 32,
61 | cross_attention_dim: Optional[int] = None,
62 | attention_bias: bool = False,
63 | num_vector_embeds: Optional[int] = None,
64 | activation_fn: str = "geglu",
65 | num_embeds_ada_norm: Optional[int] = None,
66 | use_linear_projection: bool = False,
67 | only_cross_attention: bool = False,
68 | double_self_attention: bool = False,
69 | upcast_attention: bool = False,
70 | adaptive_norm: str = "single_scale_shift", # 'single_scale_shift' or 'single_scale'
71 | standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm'
72 | norm_elementwise_affine: bool = True,
73 | norm_eps: float = 1e-5,
74 | attention_type: str = "default",
75 | caption_channels: int = None,
76 | use_tpu_flash_attention: bool = False, # if True uses the TPU attention offload ('flash attention')
77 | qk_norm: Optional[str] = None,
78 | positional_embedding_type: str = "rope",
79 | positional_embedding_theta: Optional[float] = None,
80 | positional_embedding_max_pos: Optional[List[int]] = None,
81 | timestep_scale_multiplier: Optional[float] = None,
82 | causal_temporal_positioning: bool = False, # For backward compatibility, will be deprecated
83 | ):
84 | super().__init__()
85 | self.use_tpu_flash_attention = (
86 | use_tpu_flash_attention # FIXME: push config down to the attention modules
87 | )
88 | self.use_linear_projection = use_linear_projection
89 | self.num_attention_heads = num_attention_heads
90 | self.attention_head_dim = attention_head_dim
91 | inner_dim = num_attention_heads * attention_head_dim
92 | self.inner_dim = inner_dim
93 | self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True)
94 | self.positional_embedding_type = positional_embedding_type
95 | self.positional_embedding_theta = positional_embedding_theta
96 | self.positional_embedding_max_pos = positional_embedding_max_pos
97 | self.use_rope = self.positional_embedding_type == "rope"
98 | self.timestep_scale_multiplier = timestep_scale_multiplier
99 |
100 | if self.positional_embedding_type == "absolute":
101 | raise ValueError("Absolute positional embedding is no longer supported")
102 | elif self.positional_embedding_type == "rope":
103 | if positional_embedding_theta is None:
104 | raise ValueError(
105 | "If `positional_embedding_type` type is rope, `positional_embedding_theta` must also be defined"
106 | )
107 | if positional_embedding_max_pos is None:
108 | raise ValueError(
109 | "If `positional_embedding_type` type is rope, `positional_embedding_max_pos` must also be defined"
110 | )
111 |
112 | # 3. Define transformers blocks
113 | self.transformer_blocks = nn.ModuleList(
114 | [
115 | BasicTransformerBlock(
116 | inner_dim,
117 | num_attention_heads,
118 | attention_head_dim,
119 | dropout=dropout,
120 | cross_attention_dim=cross_attention_dim,
121 | activation_fn=activation_fn,
122 | num_embeds_ada_norm=num_embeds_ada_norm,
123 | attention_bias=attention_bias,
124 | only_cross_attention=only_cross_attention,
125 | double_self_attention=double_self_attention,
126 | upcast_attention=upcast_attention,
127 | adaptive_norm=adaptive_norm,
128 | standardization_norm=standardization_norm,
129 | norm_elementwise_affine=norm_elementwise_affine,
130 | norm_eps=norm_eps,
131 | attention_type=attention_type,
132 | use_tpu_flash_attention=use_tpu_flash_attention,
133 | qk_norm=qk_norm,
134 | use_rope=self.use_rope,
135 | )
136 | for d in range(num_layers)
137 | ]
138 | )
139 |
140 | # 4. Define output layers
141 | self.out_channels = in_channels if out_channels is None else out_channels
142 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
143 | self.scale_shift_table = nn.Parameter(
144 | torch.randn(2, inner_dim) / inner_dim**0.5
145 | )
146 | self.proj_out = nn.Linear(inner_dim, self.out_channels)
147 |
148 | self.adaln_single = AdaLayerNormSingle(
149 | inner_dim, use_additional_conditions=False
150 | )
151 | if adaptive_norm == "single_scale":
152 | self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True)
153 |
154 | self.caption_projection = None
155 | if caption_channels is not None:
156 | self.caption_projection = PixArtAlphaTextProjection(
157 | in_features=caption_channels, hidden_size=inner_dim
158 | )
159 |
160 | self.gradient_checkpointing = False
161 |
162 | def set_use_tpu_flash_attention(self):
163 | r"""
164 | Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
165 | attention kernel.
166 | """
167 | logger.info("ENABLE TPU FLASH ATTENTION -> TRUE")
168 | self.use_tpu_flash_attention = True
169 | # push config down to the attention modules
170 | for block in self.transformer_blocks:
171 | block.set_use_tpu_flash_attention()
172 |
173 | def create_skip_layer_mask(
174 | self,
175 | batch_size: int,
176 | num_conds: int,
177 | ptb_index: int,
178 | skip_block_list: Optional[List[int]] = None,
179 | ):
180 | if skip_block_list is None or len(skip_block_list) == 0:
181 | return None
182 | num_layers = len(self.transformer_blocks)
183 | mask = torch.ones(
184 | (num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype
185 | )
186 | for block_idx in skip_block_list:
187 | mask[block_idx, ptb_index::num_conds] = 0
188 | return mask
189 |
190 | def _set_gradient_checkpointing(self, module, value=False):
191 | if hasattr(module, "gradient_checkpointing"):
192 | module.gradient_checkpointing = value
193 |
194 | def get_fractional_positions(self, indices_grid):
195 | fractional_positions = torch.stack(
196 | [
197 | indices_grid[:, i] / self.positional_embedding_max_pos[i]
198 | for i in range(3)
199 | ],
200 | dim=-1,
201 | )
202 | return fractional_positions
203 |
204 | def precompute_freqs_cis(self, indices_grid, spacing="exp"):
205 | dtype = torch.float32 # We need full precision in the freqs_cis computation.
206 | dim = self.inner_dim
207 | theta = self.positional_embedding_theta
208 |
209 | fractional_positions = self.get_fractional_positions(indices_grid)
210 |
211 | start = 1
212 | end = theta
213 | device = fractional_positions.device
214 | if spacing == "exp":
215 | indices = theta ** (
216 | torch.linspace(
217 | math.log(start, theta),
218 | math.log(end, theta),
219 | dim // 6,
220 | device=device,
221 | dtype=dtype,
222 | )
223 | )
224 | indices = indices.to(dtype=dtype)
225 | elif spacing == "exp_2":
226 | indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim)
227 | indices = indices.to(dtype=dtype)
228 | elif spacing == "linear":
229 | indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype)
230 | elif spacing == "sqrt":
231 | indices = torch.linspace(
232 | start**2, end**2, dim // 6, device=device, dtype=dtype
233 | ).sqrt()
234 |
235 | indices = indices * math.pi / 2
236 |
237 | if spacing == "exp_2":
238 | freqs = (
239 | (indices * fractional_positions.unsqueeze(-1))
240 | .transpose(-1, -2)
241 | .flatten(2)
242 | )
243 | else:
244 | freqs = (
245 | (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
246 | .transpose(-1, -2)
247 | .flatten(2)
248 | )
249 |
250 | cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
251 | sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
252 | if dim % 6 != 0:
253 | cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
254 | sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
255 | cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
256 | sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
257 | return cos_freq.to(self.dtype), sin_freq.to(self.dtype)
258 |
259 | def load_state_dict(
260 | self,
261 | state_dict: Dict,
262 | *args,
263 | **kwargs,
264 | ):
265 | if any([key.startswith("model.diffusion_model.") for key in state_dict.keys()]):
266 | state_dict = {
267 | key.replace("model.diffusion_model.", ""): value
268 | for key, value in state_dict.items()
269 | if key.startswith("model.diffusion_model.")
270 | }
271 | super().load_state_dict(state_dict, **kwargs)
272 |
273 | @classmethod
274 | def from_pretrained(
275 | cls,
276 | pretrained_model_path: Optional[Union[str, os.PathLike]],
277 | *args,
278 | **kwargs,
279 | ):
280 | pretrained_model_path = Path(pretrained_model_path)
281 | if pretrained_model_path.is_dir():
282 | config_path = pretrained_model_path / "transformer" / "config.json"
283 | with open(config_path, "r") as f:
284 | config = make_hashable_key(json.load(f))
285 |
286 | assert config in diffusers_and_ours_config_mapping, (
287 | "Provided diffusers checkpoint config for transformer is not suppported. "
288 | "We only support diffusers configs found in Lightricks/LTX-Video."
289 | )
290 |
291 | config = diffusers_and_ours_config_mapping[config]
292 | state_dict = {}
293 | ckpt_paths = (
294 | pretrained_model_path
295 | / "transformer"
296 | / "diffusion_pytorch_model*.safetensors"
297 | )
298 | dict_list = glob.glob(str(ckpt_paths))
299 | for dict_path in dict_list:
300 | part_dict = {}
301 | with safe_open(dict_path, framework="pt", device="cpu") as f:
302 | for k in f.keys():
303 | part_dict[k] = f.get_tensor(k)
304 | state_dict.update(part_dict)
305 |
306 | for key in list(state_dict.keys()):
307 | new_key = key
308 | for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
309 | new_key = new_key.replace(replace_key, rename_key)
310 | state_dict[new_key] = state_dict.pop(key)
311 |
312 | with torch.device("meta"):
313 | transformer = cls.from_config(config)
314 | transformer.load_state_dict(state_dict, assign=True, strict=True)
315 | elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith(
316 | ".safetensors"
317 | ):
318 | comfy_single_file_state_dict = {}
319 | with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
320 | metadata = f.metadata()
321 | for k in f.keys():
322 | comfy_single_file_state_dict[k] = f.get_tensor(k)
323 | configs = json.loads(metadata["config"])
324 | transformer_config = configs["transformer"]
325 | with torch.device("meta"):
326 | transformer = Transformer3DModel.from_config(transformer_config)
327 | transformer.load_state_dict(comfy_single_file_state_dict, assign=True)
328 | return transformer
329 |
330 | def forward(
331 | self,
332 | hidden_states: torch.Tensor,
333 | indices_grid: torch.Tensor,
334 | encoder_hidden_states: Optional[torch.Tensor] = None,
335 | timestep: Optional[torch.LongTensor] = None,
336 | class_labels: Optional[torch.LongTensor] = None,
337 | cross_attention_kwargs: Dict[str, Any] = None,
338 | attention_mask: Optional[torch.Tensor] = None,
339 | encoder_attention_mask: Optional[torch.Tensor] = None,
340 | skip_layer_mask: Optional[torch.Tensor] = None,
341 | skip_layer_strategy: Optional[SkipLayerStrategy] = None,
342 | return_dict: bool = True,
343 | ):
344 | """
345 | The [`Transformer3DModel`] forward method.
346 |
347 | Args:
348 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
349 | Input `hidden_states`.
350 | indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`):
351 | encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
352 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
353 | self-attention.
354 | timestep ( `torch.LongTensor`, *optional*):
355 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
356 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
357 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
358 | `AdaLayerZeroNorm`.
359 | cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
360 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
361 | `self.processor` in
362 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
363 | attention_mask ( `torch.Tensor`, *optional*):
364 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
365 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
366 | negative values to the attention scores corresponding to "discard" tokens.
367 | encoder_attention_mask ( `torch.Tensor`, *optional*):
368 | Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
369 |
370 | * Mask `(batch, sequence_length)` True = keep, False = discard.
371 | * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
372 |
373 | If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
374 | above. This bias will be added to the cross-attention scores.
375 | skip_layer_mask ( `torch.Tensor`, *optional*):
376 | A mask of shape `(num_layers, batch)` that indicates which layers to skip. `0` at position
377 | `layer, batch_idx` indicates that the layer should be skipped for the corresponding batch index.
378 | skip_layer_strategy ( `SkipLayerStrategy`, *optional*, defaults to `None`):
379 | Controls which layers are skipped when calculating a perturbed latent for spatiotemporal guidance.
380 | return_dict (`bool`, *optional*, defaults to `True`):
381 | Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
382 | tuple.
383 |
384 | Returns:
385 | If `return_dict` is True, a [`Transformer3DModelOutput`] is returned, otherwise a
386 | `tuple` where the first element is the sample tensor.
387 | """
388 | # for tpu attention offload 2d token masks are used. No need to transform.
389 | if not self.use_tpu_flash_attention:
390 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
391 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
392 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
393 | # expects mask of shape:
394 | # [batch, key_tokens]
395 | # adds singleton query_tokens dimension:
396 | # [batch, 1, key_tokens]
397 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
398 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
399 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
400 | if attention_mask is not None and attention_mask.ndim == 2:
401 | # assume that mask is expressed as:
402 | # (1 = keep, 0 = discard)
403 | # convert mask into a bias that can be added to attention scores:
404 | # (keep = +0, discard = -10000.0)
405 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
406 | attention_mask = attention_mask.unsqueeze(1)
407 |
408 | # convert encoder_attention_mask to a bias the same way we do for attention_mask
409 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
410 | encoder_attention_mask = (
411 | 1 - encoder_attention_mask.to(hidden_states.dtype)
412 | ) * -10000.0
413 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
414 |
415 | # 1. Input
416 | hidden_states = self.patchify_proj(hidden_states)
417 |
418 | if self.timestep_scale_multiplier:
419 | timestep = self.timestep_scale_multiplier * timestep
420 |
421 | freqs_cis = self.precompute_freqs_cis(indices_grid)
422 |
423 | batch_size = hidden_states.shape[0]
424 | timestep, embedded_timestep = self.adaln_single(
425 | timestep.flatten(),
426 | {"resolution": None, "aspect_ratio": None},
427 | batch_size=batch_size,
428 | hidden_dtype=hidden_states.dtype,
429 | )
430 | # Second dimension is 1 or number of tokens (if timestep_per_token)
431 | timestep = timestep.view(batch_size, -1, timestep.shape[-1])
432 | embedded_timestep = embedded_timestep.view(
433 | batch_size, -1, embedded_timestep.shape[-1]
434 | )
435 |
436 | # 2. Blocks
437 | if self.caption_projection is not None:
438 | batch_size = hidden_states.shape[0]
439 | encoder_hidden_states = self.caption_projection(encoder_hidden_states)
440 | encoder_hidden_states = encoder_hidden_states.view(
441 | batch_size, -1, hidden_states.shape[-1]
442 | )
443 |
444 | for block_idx, block in enumerate(self.transformer_blocks):
445 | if self.training and self.gradient_checkpointing:
446 |
447 | def create_custom_forward(module, return_dict=None):
448 | def custom_forward(*inputs):
449 | if return_dict is not None:
450 | return module(*inputs, return_dict=return_dict)
451 | else:
452 | return module(*inputs)
453 |
454 | return custom_forward
455 |
456 | ckpt_kwargs: Dict[str, Any] = (
457 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
458 | )
459 | hidden_states = torch.utils.checkpoint.checkpoint(
460 | create_custom_forward(block),
461 | hidden_states,
462 | freqs_cis,
463 | attention_mask,
464 | encoder_hidden_states,
465 | encoder_attention_mask,
466 | timestep,
467 | cross_attention_kwargs,
468 | class_labels,
469 | (
470 | skip_layer_mask[block_idx]
471 | if skip_layer_mask is not None
472 | else None
473 | ),
474 | skip_layer_strategy,
475 | **ckpt_kwargs,
476 | )
477 | else:
478 | hidden_states = block(
479 | hidden_states,
480 | freqs_cis=freqs_cis,
481 | attention_mask=attention_mask,
482 | encoder_hidden_states=encoder_hidden_states,
483 | encoder_attention_mask=encoder_attention_mask,
484 | timestep=timestep,
485 | cross_attention_kwargs=cross_attention_kwargs,
486 | class_labels=class_labels,
487 | skip_layer_mask=(
488 | skip_layer_mask[block_idx]
489 | if skip_layer_mask is not None
490 | else None
491 | ),
492 | skip_layer_strategy=skip_layer_strategy,
493 | )
494 |
495 | # 3. Output
496 | scale_shift_values = (
497 | self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
498 | )
499 | shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
500 | hidden_states = self.norm_out(hidden_states)
501 | # Modulation
502 | hidden_states = hidden_states * (1 + scale) + shift
503 | hidden_states = self.proj_out(hidden_states)
504 | if not return_dict:
505 | return (hidden_states,)
506 |
507 | return Transformer3DModelOutput(sample=hidden_states)
508 |
--------------------------------------------------------------------------------
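The mask-to-bias conversion described in `Transformer3DModel.forward` above can be checked in isolation; this small sketch (not part of the repository) reproduces just that step on a toy mask.

# A 2D keep/discard mask (1 = keep, 0 = discard) becomes an additive bias
# (0.0 = keep, -10000.0 = discard) with a singleton query-token dimension so it
# broadcasts over attention scores.
import torch

attention_mask = torch.tensor([[1, 1, 0, 0]], dtype=torch.float32)  # (batch, key_tokens)
bias = (1 - attention_mask) * -10000.0
bias = bias.unsqueeze(1)  # (batch, 1, key_tokens)
print(bias)  # 0.0 for kept tokens, -10000.0 for discarded ones
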
/ltx_video/pipelines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightricks/LTX-Video/34625c3a1fb4e9a7d2f091938f6989816343e270/ltx_video/pipelines/__init__.py
--------------------------------------------------------------------------------
/ltx_video/pipelines/crf_compressor.py:
--------------------------------------------------------------------------------
1 | import av
2 | import torch
3 | import io
4 | import numpy as np
5 |
6 |
7 | def _encode_single_frame(output_file, image_array: np.ndarray, crf):
8 | container = av.open(output_file, "w", format="mp4")
9 | try:
10 | stream = container.add_stream(
11 | "libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
12 | )
13 | stream.height = image_array.shape[0]
14 | stream.width = image_array.shape[1]
15 | av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
16 | format="yuv420p"
17 | )
18 | container.mux(stream.encode(av_frame))
19 | container.mux(stream.encode())
20 | finally:
21 | container.close()
22 |
23 |
24 | def _decode_single_frame(video_file):
25 | container = av.open(video_file)
26 | try:
27 | stream = next(s for s in container.streams if s.type == "video")
28 | frame = next(container.decode(stream))
29 | finally:
30 | container.close()
31 | return frame.to_ndarray(format="rgb24")
32 |
33 |
34 | def compress(image: torch.Tensor, crf=29):
35 | if crf == 0:
36 | return image
37 |
38 | image_array = (
39 | (image[: (image.shape[0] // 2) * 2, : (image.shape[1] // 2) * 2] * 255.0)
40 | .byte()
41 | .cpu()
42 | .numpy()
43 | )
44 | with io.BytesIO() as output_file:
45 | _encode_single_frame(output_file, image_array, crf)
46 | video_bytes = output_file.getvalue()
47 | with io.BytesIO(video_bytes) as video_file:
48 | image_array = _decode_single_frame(video_file)
49 | tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
50 | return tensor
51 |
--------------------------------------------------------------------------------
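A possible use of the CRF compressor above is round-tripping a single frame through an H.264 encode/decode as a light degradation step; this usage sketch assumes PyAV is installed and the package is importable as `ltx_video.pipelines.crf_compressor`.

# Round-trip one HWC frame in [0, 1] through libx264 at CRF 29.
import torch
from ltx_video.pipelines.crf_compressor import compress

frame = torch.rand(256, 256, 3)          # (H, W, C), values in [0, 1]
degraded = compress(frame, crf=29)       # same shape and dtype, with compression artifacts
print(frame.shape, degraded.shape)
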
/ltx_video/schedulers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightricks/LTX-Video/34625c3a1fb4e9a7d2f091938f6989816343e270/ltx_video/schedulers/__init__.py
--------------------------------------------------------------------------------
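The rectified-flow scheduler in the rf.py file that follows builds its sigma schedule from simple closed-form shifts; as a quick standalone check of the `time_shift` formula defined there (reproduced locally here, with illustrative sample values):

import math
import torch

def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    # Same closed form as rf.py: exp(mu) / (exp(mu) + (1/t - 1) ** sigma)
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

t = torch.linspace(0.1, 1.0, 5)
print(time_shift(0.0, 1.0, t))  # mu = 0 is the identity: output equals t
print(time_shift(1.0, 1.0, t))  # larger mu pushes timesteps toward the noisy end (1.0)
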
/ltx_video/schedulers/rf.py:
--------------------------------------------------------------------------------
1 | import math
2 | from abc import ABC, abstractmethod
3 | from dataclasses import dataclass
4 | from typing import Callable, Optional, Tuple, Union
5 | import json
6 | import os
7 | from pathlib import Path
8 |
9 | import torch
10 | from diffusers.configuration_utils import ConfigMixin, register_to_config
11 | from diffusers.schedulers.scheduling_utils import SchedulerMixin
12 | from diffusers.utils import BaseOutput
13 | from torch import Tensor
14 | from safetensors import safe_open
15 |
16 |
17 | from ltx_video.utils.torch_utils import append_dims
18 |
19 | from ltx_video.utils.diffusers_config_mapping import (
20 | diffusers_and_ours_config_mapping,
21 | make_hashable_key,
22 | )
23 |
24 |
25 | def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_steps=None):
26 | if num_steps == 1:
27 | return torch.tensor([1.0])
28 | if linear_steps is None:
29 | linear_steps = num_steps // 2
30 | linear_sigma_schedule = [
31 | i * threshold_noise / linear_steps for i in range(linear_steps)
32 | ]
33 | threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
34 | quadratic_steps = num_steps - linear_steps
35 | quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
36 | linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (
37 | quadratic_steps**2
38 | )
39 | const = quadratic_coef * (linear_steps**2)
40 | quadratic_sigma_schedule = [
41 | quadratic_coef * (i**2) + linear_coef * i + const
42 | for i in range(linear_steps, num_steps)
43 | ]
44 | sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
45 | sigma_schedule = [1.0 - x for x in sigma_schedule]
46 | return torch.tensor(sigma_schedule[:-1])
47 |
48 |
49 | def simple_diffusion_resolution_dependent_timestep_shift(
50 | samples_shape: torch.Size,
51 | timesteps: Tensor,
52 | n: int = 32 * 32,
53 | ) -> Tensor:
54 | if len(samples_shape) == 3:
55 | _, m, _ = samples_shape
56 | elif len(samples_shape) in [4, 5]:
57 | m = math.prod(samples_shape[2:])
58 | else:
59 | raise ValueError(
60 | "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)"
61 | )
62 | snr = (timesteps / (1 - timesteps)) ** 2
63 | shift_snr = torch.log(snr) + 2 * math.log(m / n)
64 | shifted_timesteps = torch.sigmoid(0.5 * shift_snr)
65 |
66 | return shifted_timesteps
67 |
68 |
69 | def time_shift(mu: float, sigma: float, t: Tensor):
70 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
71 |
72 |
73 | def get_normal_shift(
74 | n_tokens: int,
75 | min_tokens: int = 1024,
76 | max_tokens: int = 4096,
77 | min_shift: float = 0.95,
78 | max_shift: float = 2.05,
79 | ) -> Callable[[float], float]:
80 | m = (max_shift - min_shift) / (max_tokens - min_tokens)
81 | b = min_shift - m * min_tokens
82 | return m * n_tokens + b
83 |
84 |
85 | def strech_shifts_to_terminal(shifts: Tensor, terminal=0.1):
86 | """
87 | Stretch a function (given as sampled shifts) so that its final value matches the given terminal value
88 | using the provided formula.
89 |
90 | Parameters:
91 | - shifts (Tensor): The samples of the function to be stretched (PyTorch Tensor).
92 | - terminal (float): The desired terminal value (value at the last sample).
93 |
94 | Returns:
95 | - Tensor: The stretched shifts such that the final value equals `terminal`.
96 | """
97 | if shifts.numel() == 0:
98 | raise ValueError("The 'shifts' tensor must not be empty.")
99 |
100 | # Ensure terminal value is valid
101 | if terminal <= 0 or terminal >= 1:
102 | raise ValueError("The terminal value must be between 0 and 1 (exclusive).")
103 |
104 | # Transform the shifts using the given formula
105 | one_minus_z = 1 - shifts
106 | scale_factor = one_minus_z[-1] / (1 - terminal)
107 | stretched_shifts = 1 - (one_minus_z / scale_factor)
108 |
109 | return stretched_shifts
110 |
111 |
112 | def sd3_resolution_dependent_timestep_shift(
113 | samples_shape: torch.Size,
114 | timesteps: Tensor,
115 | target_shift_terminal: Optional[float] = None,
116 | ) -> Tensor:
117 | """
118 | Shifts the timestep schedule as a function of the generated resolution.
119 |
120 | In the SD3 paper, the authors empirically derive how to shift the timesteps based on the resolution of the target images.
121 | For more details: https://arxiv.org/pdf/2403.03206
122 |
123 | In Flux they later propose a more dynamic resolution dependent timestep shift, see:
124 | https://github.com/black-forest-labs/flux/blob/87f6fff727a377ea1c378af692afb41ae84cbe04/src/flux/sampling.py#L66
125 |
126 |
127 | Args:
128 | samples_shape (torch.Size): The samples batch shape (batch_size, channels, height, width) or
129 | (batch_size, channels, frame, height, width).
130 | timesteps (Tensor): A batch of timesteps with shape (batch_size,).
131 | target_shift_terminal (float): The target terminal value for the shifted timesteps.
132 |
133 | Returns:
134 | Tensor: The shifted timesteps.
135 | """
136 | if len(samples_shape) == 3:
137 | _, m, _ = samples_shape
138 | elif len(samples_shape) in [4, 5]:
139 | m = math.prod(samples_shape[2:])
140 | else:
141 | raise ValueError(
142 | "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)"
143 | )
144 |
145 | shift = get_normal_shift(m)
146 | time_shifts = time_shift(shift, 1, timesteps)
147 | if target_shift_terminal is not None: # Stretch the shifts to the target terminal
148 | time_shifts = strech_shifts_to_terminal(time_shifts, target_shift_terminal)
149 | return time_shifts
150 |
151 |
152 | class TimestepShifter(ABC):
153 | @abstractmethod
154 | def shift_timesteps(self, samples_shape: torch.Size, timesteps: Tensor) -> Tensor:
155 | pass
156 |
157 |
158 | @dataclass
159 | class RectifiedFlowSchedulerOutput(BaseOutput):
160 | """
161 | Output class for the scheduler's step function output.
162 |
163 | Args:
164 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
165 | Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
166 | denoising loop.
167 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
168 | The predicted denoised sample (x_{0}) based on the model output from the current timestep.
169 | `pred_original_sample` can be used to preview progress or for guidance.
170 | """
171 |
172 | prev_sample: torch.FloatTensor
173 | pred_original_sample: Optional[torch.FloatTensor] = None
174 |
175 |
176 | class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
177 | order = 1
178 |
179 | @register_to_config
180 | def __init__(
181 | self,
182 | num_train_timesteps=1000,
183 | shifting: Optional[str] = None,
184 | base_resolution: int = 32**2,
185 | target_shift_terminal: Optional[float] = None,
186 | sampler: Optional[str] = "Uniform",
187 | shift: Optional[float] = None,
188 | ):
189 | super().__init__()
190 | self.init_noise_sigma = 1.0
191 | self.num_inference_steps = None
192 | self.sampler = sampler
193 | self.shifting = shifting
194 | self.base_resolution = base_resolution
195 | self.target_shift_terminal = target_shift_terminal
196 | self.timesteps = self.sigmas = self.get_initial_timesteps(
197 | num_train_timesteps, shift=shift
198 | )
199 | self.shift = shift
200 |
201 | def get_initial_timesteps(
202 | self, num_timesteps: int, shift: Optional[float] = None
203 | ) -> Tensor:
204 | if self.sampler == "Uniform":
205 | return torch.linspace(1, 1 / num_timesteps, num_timesteps)
206 | elif self.sampler == "LinearQuadratic":
207 | return linear_quadratic_schedule(num_timesteps)
208 | elif self.sampler == "Constant":
209 | assert (
210 | shift is not None
211 | ), "Shift must be provided for constant time shift sampler."
212 | return time_shift(
213 | shift, 1, torch.linspace(1, 1 / num_timesteps, num_timesteps)
214 | )
215 |
216 | def shift_timesteps(self, samples_shape: torch.Size, timesteps: Tensor) -> Tensor:
217 | if self.shifting == "SD3":
218 | return sd3_resolution_dependent_timestep_shift(
219 | samples_shape, timesteps, self.target_shift_terminal
220 | )
221 | elif self.shifting == "SimpleDiffusion":
222 | return simple_diffusion_resolution_dependent_timestep_shift(
223 | samples_shape, timesteps, self.base_resolution
224 | )
225 | return timesteps
226 |
227 | def set_timesteps(
228 | self,
229 | num_inference_steps: Optional[int] = None,
230 | samples_shape: Optional[torch.Size] = None,
231 | timesteps: Optional[Tensor] = None,
232 | device: Union[str, torch.device] = None,
233 | ):
234 | """
235 | Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
236 | If `timesteps` are provided, they will be used instead of the scheduled timesteps.
237 |
238 | Args:
239 | num_inference_steps (`int` *optional*): The number of diffusion steps used when generating samples.
240 | samples_shape (`torch.Size` *optional*): The samples batch shape, used for shifting.
241 |             timesteps (`torch.Tensor`, *optional*): Specific timesteps to use instead of scheduled timesteps.
242 | device (`Union[str, torch.device]`, *optional*): The device to which the timesteps tensor will be moved.
243 | """
244 | if timesteps is not None and num_inference_steps is not None:
245 | raise ValueError(
246 | "You cannot provide both `timesteps` and `num_inference_steps`."
247 | )
248 | if timesteps is None:
249 | num_inference_steps = min(
250 | self.config.num_train_timesteps, num_inference_steps
251 | )
252 | timesteps = self.get_initial_timesteps(
253 | num_inference_steps, shift=self.shift
254 | ).to(device)
255 | timesteps = self.shift_timesteps(samples_shape, timesteps)
256 | else:
257 | timesteps = torch.Tensor(timesteps).to(device)
258 | num_inference_steps = len(timesteps)
259 | self.timesteps = timesteps
260 | self.num_inference_steps = num_inference_steps
261 | self.sigmas = self.timesteps
262 |
263 | @staticmethod
264 | def from_pretrained(pretrained_model_path: Union[str, os.PathLike]):
265 | pretrained_model_path = Path(pretrained_model_path)
266 | if pretrained_model_path.is_file():
267 | comfy_single_file_state_dict = {}
268 | with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
269 | metadata = f.metadata()
270 | for k in f.keys():
271 | comfy_single_file_state_dict[k] = f.get_tensor(k)
272 | configs = json.loads(metadata["config"])
273 | config = configs["scheduler"]
274 | del comfy_single_file_state_dict
275 |
276 | elif pretrained_model_path.is_dir():
277 | diffusers_noise_scheduler_config_path = (
278 | pretrained_model_path / "scheduler" / "scheduler_config.json"
279 | )
280 |
281 | with open(diffusers_noise_scheduler_config_path, "r") as f:
282 | scheduler_config = json.load(f)
283 | hashable_config = make_hashable_key(scheduler_config)
284 | if hashable_config in diffusers_and_ours_config_mapping:
285 | config = diffusers_and_ours_config_mapping[hashable_config]
286 | return RectifiedFlowScheduler.from_config(config)
287 |
288 | def scale_model_input(
289 | self, sample: torch.FloatTensor, timestep: Optional[int] = None
290 | ) -> torch.FloatTensor:
291 | # pylint: disable=unused-argument
292 | """
293 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
294 | current timestep.
295 |
296 | Args:
297 | sample (`torch.FloatTensor`): input sample
298 | timestep (`int`, optional): current timestep
299 |
300 | Returns:
301 | `torch.FloatTensor`: scaled input sample
302 | """
303 | return sample
304 |
305 | def step(
306 | self,
307 | model_output: torch.FloatTensor,
308 | timestep: torch.FloatTensor,
309 | sample: torch.FloatTensor,
310 | return_dict: bool = True,
311 | stochastic_sampling: Optional[bool] = False,
312 | **kwargs,
313 | ) -> Union[RectifiedFlowSchedulerOutput, Tuple]:
314 | """
315 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
316 |         process from the learned model outputs (here, the predicted velocity).
317 |         z_{t-1} = z_t - \Delta_t * v
318 | The method finds the next timestep that is lower than the input timestep(s) and denoises the latents
319 | to that level. The input timestep(s) are not required to be one of the predefined timesteps.
320 |
321 | Args:
322 | model_output (`torch.FloatTensor`):
323 |                 The direct output from the learned diffusion model (the velocity).
324 | timestep (`float`):
325 | The current discrete timestep in the diffusion chain (global or per-token).
326 | sample (`torch.FloatTensor`):
327 |                 The current latent tokens to be denoised.
328 | return_dict (`bool`, *optional*, defaults to `True`):
329 |                 Whether or not to return a [`~schedulers.rf.RectifiedFlowSchedulerOutput`] or a plain `tuple`.
330 | stochastic_sampling (`bool`, *optional*, defaults to `False`):
331 | Whether to use stochastic sampling for the sampling process.
332 |
333 | Returns:
334 |             [`~schedulers.rf.RectifiedFlowSchedulerOutput`] or `tuple`:
335 |                 If return_dict is `True`, [`~schedulers.rf.RectifiedFlowSchedulerOutput`] is returned,
336 | otherwise a tuple is returned where the first element is the sample tensor.
337 | """
338 | if self.num_inference_steps is None:
339 | raise ValueError(
340 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
341 | )
342 | t_eps = 1e-6 # Small epsilon to avoid numerical issues in timestep values
343 |
344 | timesteps_padded = torch.cat(
345 | [self.timesteps, torch.zeros(1, device=self.timesteps.device)]
346 | )
347 |
348 | # Find the next lower timestep(s) and compute the dt from the current timestep(s)
349 | if timestep.ndim == 0:
350 | # Global timestep case
351 | lower_mask = timesteps_padded < timestep - t_eps
352 | lower_timestep = timesteps_padded[lower_mask][0] # Closest lower timestep
353 | dt = timestep - lower_timestep
354 |
355 | else:
356 | # Per-token case
357 | assert timestep.ndim == 2
358 | lower_mask = timesteps_padded[:, None, None] < timestep[None] - t_eps
359 | lower_timestep = lower_mask * timesteps_padded[:, None, None]
360 | lower_timestep, _ = lower_timestep.max(dim=0)
361 | dt = (timestep - lower_timestep)[..., None]
362 |
363 | # Compute previous sample
364 | if stochastic_sampling:
365 | x0 = sample - timestep[..., None] * model_output
366 | next_timestep = timestep[..., None] - dt
367 | prev_sample = self.add_noise(x0, torch.randn_like(sample), next_timestep)
368 | else:
369 | prev_sample = sample - dt * model_output
370 |
371 | if not return_dict:
372 | return (prev_sample,)
373 |
374 | return RectifiedFlowSchedulerOutput(prev_sample=prev_sample)
375 |
376 | def add_noise(
377 | self,
378 | original_samples: torch.FloatTensor,
379 | noise: torch.FloatTensor,
380 | timesteps: torch.FloatTensor,
381 | ) -> torch.FloatTensor:
382 | sigmas = timesteps
383 | sigmas = append_dims(sigmas, original_samples.ndim)
384 | alphas = 1 - sigmas
385 | noisy_samples = alphas * original_samples + sigmas * noise
386 | return noisy_samples
387 |
--------------------------------------------------------------------------------
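
The scheduler above is driven by a plain Euler loop: `set_timesteps` builds the timestep/sigma schedule for the latent shape, and each `step` applies `z_{t-1} = z_t - dt * v`. A minimal sketch mirroring `tests/test_scheduler.py`, with a random tensor standing in for a real transformer velocity prediction (the shapes and the stand-in output are assumptions, not part of the library):

import torch
from ltx_video.schedulers.rf import RectifiedFlowScheduler

# Latents laid out as (batch, tokens, channels), the 3-D case handled by shift_timesteps.
latents = torch.randn(2, 4096, 128)

scheduler = RectifiedFlowScheduler(sampler="Uniform")
scheduler.set_timesteps(num_inference_steps=20, samples_shape=latents.shape)

for t in scheduler.timesteps:
    velocity = torch.randn_like(latents)  # stand-in for the model's predicted velocity
    latents = scheduler.step(velocity, t, latents, return_dict=False)[0]

Per-token timesteps work the same way: passing a `(batch, tokens)` tensor of timesteps to `step` denoises each token to its own next-lower level, as exercised in `tests/test_scheduler.py`.
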
/ltx_video/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lightricks/LTX-Video/34625c3a1fb4e9a7d2f091938f6989816343e270/ltx_video/utils/__init__.py
--------------------------------------------------------------------------------
/ltx_video/utils/diffusers_config_mapping.py:
--------------------------------------------------------------------------------
1 | def make_hashable_key(dict_key):
2 | def convert_value(value):
3 | if isinstance(value, list):
4 | return tuple(value)
5 | elif isinstance(value, dict):
6 | return tuple(sorted((k, convert_value(v)) for k, v in value.items()))
7 | else:
8 | return value
9 |
10 | return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items()))
11 |
12 |
13 | DIFFUSERS_SCHEDULER_CONFIG = {
14 | "_class_name": "FlowMatchEulerDiscreteScheduler",
15 | "_diffusers_version": "0.32.0.dev0",
16 | "base_image_seq_len": 1024,
17 | "base_shift": 0.95,
18 | "invert_sigmas": False,
19 | "max_image_seq_len": 4096,
20 | "max_shift": 2.05,
21 | "num_train_timesteps": 1000,
22 | "shift": 1.0,
23 | "shift_terminal": 0.1,
24 | "use_beta_sigmas": False,
25 | "use_dynamic_shifting": True,
26 | "use_exponential_sigmas": False,
27 | "use_karras_sigmas": False,
28 | }
29 | DIFFUSERS_TRANSFORMER_CONFIG = {
30 | "_class_name": "LTXVideoTransformer3DModel",
31 | "_diffusers_version": "0.32.0.dev0",
32 | "activation_fn": "gelu-approximate",
33 | "attention_bias": True,
34 | "attention_head_dim": 64,
35 | "attention_out_bias": True,
36 | "caption_channels": 4096,
37 | "cross_attention_dim": 2048,
38 | "in_channels": 128,
39 | "norm_elementwise_affine": False,
40 | "norm_eps": 1e-06,
41 | "num_attention_heads": 32,
42 | "num_layers": 28,
43 | "out_channels": 128,
44 | "patch_size": 1,
45 | "patch_size_t": 1,
46 | "qk_norm": "rms_norm_across_heads",
47 | }
48 | DIFFUSERS_VAE_CONFIG = {
49 | "_class_name": "AutoencoderKLLTXVideo",
50 | "_diffusers_version": "0.32.0.dev0",
51 | "block_out_channels": [128, 256, 512, 512],
52 | "decoder_causal": False,
53 | "encoder_causal": True,
54 | "in_channels": 3,
55 | "latent_channels": 128,
56 | "layers_per_block": [4, 3, 3, 3, 4],
57 | "out_channels": 3,
58 | "patch_size": 4,
59 | "patch_size_t": 1,
60 | "resnet_norm_eps": 1e-06,
61 | "scaling_factor": 1.0,
62 | "spatio_temporal_scaling": [True, True, True, False],
63 | }
64 |
65 | OURS_SCHEDULER_CONFIG = {
66 | "_class_name": "RectifiedFlowScheduler",
67 | "_diffusers_version": "0.25.1",
68 | "num_train_timesteps": 1000,
69 | "shifting": "SD3",
70 | "base_resolution": None,
71 | "target_shift_terminal": 0.1,
72 | }
73 |
74 | OURS_TRANSFORMER_CONFIG = {
75 | "_class_name": "Transformer3DModel",
76 | "_diffusers_version": "0.25.1",
77 | "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256",
78 | "activation_fn": "gelu-approximate",
79 | "attention_bias": True,
80 | "attention_head_dim": 64,
81 | "attention_type": "default",
82 | "caption_channels": 4096,
83 | "cross_attention_dim": 2048,
84 | "double_self_attention": False,
85 | "dropout": 0.0,
86 | "in_channels": 128,
87 | "norm_elementwise_affine": False,
88 | "norm_eps": 1e-06,
89 | "norm_num_groups": 32,
90 | "num_attention_heads": 32,
91 | "num_embeds_ada_norm": 1000,
92 | "num_layers": 28,
93 | "num_vector_embeds": None,
94 | "only_cross_attention": False,
95 | "out_channels": 128,
96 | "project_to_2d_pos": True,
97 | "upcast_attention": False,
98 | "use_linear_projection": False,
99 | "qk_norm": "rms_norm",
100 | "standardization_norm": "rms_norm",
101 | "positional_embedding_type": "rope",
102 | "positional_embedding_theta": 10000.0,
103 | "positional_embedding_max_pos": [20, 2048, 2048],
104 | "timestep_scale_multiplier": 1000,
105 | }
106 | OURS_VAE_CONFIG = {
107 | "_class_name": "CausalVideoAutoencoder",
108 | "dims": 3,
109 | "in_channels": 3,
110 | "out_channels": 3,
111 | "latent_channels": 128,
112 | "blocks": [
113 | ["res_x", 4],
114 | ["compress_all", 1],
115 | ["res_x_y", 1],
116 | ["res_x", 3],
117 | ["compress_all", 1],
118 | ["res_x_y", 1],
119 | ["res_x", 3],
120 | ["compress_all", 1],
121 | ["res_x", 3],
122 | ["res_x", 4],
123 | ],
124 | "scaling_factor": 1.0,
125 | "norm_layer": "pixel_norm",
126 | "patch_size": 4,
127 | "latent_log_var": "uniform",
128 | "use_quant_conv": False,
129 | "causal_decoder": False,
130 | }
131 |
132 |
133 | diffusers_and_ours_config_mapping = {
134 | make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG,
135 | make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG,
136 | make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG,
137 | }
138 |
139 |
140 | TRANSFORMER_KEYS_RENAME_DICT = {
141 | "proj_in": "patchify_proj",
142 | "time_embed": "adaln_single",
143 | "norm_q": "q_norm",
144 | "norm_k": "k_norm",
145 | }
146 |
147 |
148 | VAE_KEYS_RENAME_DICT = {
149 | "decoder.up_blocks.3.conv_in": "decoder.up_blocks.7",
150 | "decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8",
151 | "decoder.up_blocks.3": "decoder.up_blocks.9",
152 | "decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5",
153 | "decoder.up_blocks.2.conv_in": "decoder.up_blocks.4",
154 | "decoder.up_blocks.2": "decoder.up_blocks.6",
155 | "decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2",
156 | "decoder.up_blocks.1": "decoder.up_blocks.3",
157 | "decoder.up_blocks.0": "decoder.up_blocks.1",
158 | "decoder.mid_block": "decoder.up_blocks.0",
159 | "encoder.down_blocks.3": "encoder.down_blocks.8",
160 | "encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7",
161 | "encoder.down_blocks.2": "encoder.down_blocks.6",
162 | "encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4",
163 | "encoder.down_blocks.1.conv_out": "encoder.down_blocks.5",
164 | "encoder.down_blocks.1": "encoder.down_blocks.3",
165 | "encoder.down_blocks.0.conv_out": "encoder.down_blocks.2",
166 | "encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1",
167 | "encoder.down_blocks.0": "encoder.down_blocks.0",
168 | "encoder.mid_block": "encoder.down_blocks.9",
169 | "conv_shortcut.conv": "conv_shortcut",
170 | "resnets": "res_blocks",
171 | "norm3": "norm3.norm",
172 | "latents_mean": "per_channel_statistics.mean-of-means",
173 | "latents_std": "per_channel_statistics.std-of-means",
174 | }
175 |
--------------------------------------------------------------------------------
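
A quick illustration of how this mapping is used: a Diffusers config dict is normalized by `make_hashable_key` into nested sorted tuples so it can act as a dictionary key, then looked up in `diffusers_and_ours_config_mapping` to recover the native config. A minimal sketch that simply reuses the constants defined above:

from ltx_video.utils.diffusers_config_mapping import (
    DIFFUSERS_SCHEDULER_CONFIG,
    diffusers_and_ours_config_mapping,
    make_hashable_key,
)

# Lists and nested dicts are converted to tuples, making the config hashable.
key = make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG)
ours_config = diffusers_and_ours_config_mapping[key]
assert ours_config["_class_name"] == "RectifiedFlowScheduler"
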
/ltx_video/utils/prompt_enhance_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Union, List, Optional
3 |
4 | import torch
5 | from PIL import Image
6 |
7 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
8 |
9 | T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award-winning movies. When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
10 | Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
11 | Start directly with the action, and keep descriptions literal and precise.
12 | Think like a cinematographer describing a shot list.
13 | Do not change the user input intent, just enhance it.
14 | Keep within 150 words.
15 | For best results, build your prompts using this structure:
16 | Start with main action in a single sentence
17 | Add specific details about movements and gestures
18 | Describe character/object appearances precisely
19 | Include background and environment details
20 | Specify camera angles and movements
21 | Describe lighting and colors
22 | Note any changes or sudden events
23 | Do not exceed the 150 word limit!
24 | Output the enhanced prompt only.
25 | """
26 |
27 | I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award-winning movies. When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
28 | Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
29 | Start directly with the action, and keep descriptions literal and precise.
30 | Think like a cinematographer describing a shot list.
31 | Keep within 150 words.
32 | For best results, build your prompts using this structure:
33 | Describe the image first and then add the user input. Image description should be in first priority! Align to the image caption if it contradicts the user text input.
34 | Start with main action in a single sentence
35 | Add specific details about movements and gestures
36 | Describe character/object appearances precisely
37 | Include background and environment details
38 | Specify camera angles and movements
39 | Describe lighting and colors
40 | Note any changes or sudden events
41 | Align to the image caption if it contradicts the user text input.
42 | Do not exceed the 150 word limit!
43 | Output the enhanced prompt only.
44 | """
45 |
46 |
47 | def tensor_to_pil(tensor):
48 | # Ensure tensor is in range [-1, 1]
49 | assert tensor.min() >= -1 and tensor.max() <= 1
50 |
51 | # Convert from [-1, 1] to [0, 1]
52 | tensor = (tensor + 1) / 2
53 |
54 | # Rearrange from [C, H, W] to [H, W, C]
55 | tensor = tensor.permute(1, 2, 0)
56 |
57 | # Convert to numpy array and then to uint8 range [0, 255]
58 | numpy_image = (tensor.cpu().numpy() * 255).astype("uint8")
59 |
60 | # Convert to PIL Image
61 | return Image.fromarray(numpy_image)
62 |
63 |
64 | def generate_cinematic_prompt(
65 | image_caption_model,
66 | image_caption_processor,
67 | prompt_enhancer_model,
68 | prompt_enhancer_tokenizer,
69 | prompt: Union[str, List[str]],
70 | conditioning_items: Optional[List] = None,
71 | max_new_tokens: int = 256,
72 | ) -> List[str]:
73 | prompts = [prompt] if isinstance(prompt, str) else prompt
74 |
75 | if conditioning_items is None:
76 | prompts = _generate_t2v_prompt(
77 | prompt_enhancer_model,
78 | prompt_enhancer_tokenizer,
79 | prompts,
80 | max_new_tokens,
81 | T2V_CINEMATIC_PROMPT,
82 | )
83 | else:
84 | if len(conditioning_items) > 1 or conditioning_items[0].media_frame_number != 0:
85 | logger.warning(
86 |                 "prompt enhancement only supports unconditional generation or conditioning on the first frame; returning original prompts"
87 | )
88 | return prompts
89 |
90 | first_frame_conditioning_item = conditioning_items[0]
91 | first_frames = _get_first_frames_from_conditioning_item(
92 | first_frame_conditioning_item
93 | )
94 |
95 | assert len(first_frames) == len(
96 | prompts
97 | ), "Number of conditioning frames must match number of prompts"
98 |
99 | prompts = _generate_i2v_prompt(
100 | image_caption_model,
101 | image_caption_processor,
102 | prompt_enhancer_model,
103 | prompt_enhancer_tokenizer,
104 | prompts,
105 | first_frames,
106 | max_new_tokens,
107 | I2V_CINEMATIC_PROMPT,
108 | )
109 |
110 | return prompts
111 |
112 |
113 | def _get_first_frames_from_conditioning_item(conditioning_item) -> List[Image.Image]:
114 | frames_tensor = conditioning_item.media_item
115 | return [
116 | tensor_to_pil(frames_tensor[i, :, 0, :, :])
117 | for i in range(frames_tensor.shape[0])
118 | ]
119 |
120 |
121 | def _generate_t2v_prompt(
122 | prompt_enhancer_model,
123 | prompt_enhancer_tokenizer,
124 | prompts: List[str],
125 | max_new_tokens: int,
126 | system_prompt: str,
127 | ) -> List[str]:
128 | messages = [
129 | [
130 | {"role": "system", "content": system_prompt},
131 | {"role": "user", "content": f"user_prompt: {p}"},
132 | ]
133 | for p in prompts
134 | ]
135 |
136 | texts = [
137 | prompt_enhancer_tokenizer.apply_chat_template(
138 | m, tokenize=False, add_generation_prompt=True
139 | )
140 | for m in messages
141 | ]
142 | model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to(
143 | prompt_enhancer_model.device
144 | )
145 |
146 | return _generate_and_decode_prompts(
147 | prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens
148 | )
149 |
150 |
151 | def _generate_i2v_prompt(
152 | image_caption_model,
153 | image_caption_processor,
154 | prompt_enhancer_model,
155 | prompt_enhancer_tokenizer,
156 | prompts: List[str],
157 | first_frames: List[Image.Image],
158 | max_new_tokens: int,
159 | system_prompt: str,
160 | ) -> List[str]:
161 | image_captions = _generate_image_captions(
162 | image_caption_model, image_caption_processor, first_frames
163 | )
164 |
165 | messages = [
166 | [
167 | {"role": "system", "content": system_prompt},
168 | {"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"},
169 | ]
170 | for p, c in zip(prompts, image_captions)
171 | ]
172 |
173 | texts = [
174 | prompt_enhancer_tokenizer.apply_chat_template(
175 | m, tokenize=False, add_generation_prompt=True
176 | )
177 | for m in messages
178 | ]
179 | model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to(
180 | prompt_enhancer_model.device
181 | )
182 |
183 | return _generate_and_decode_prompts(
184 | prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens
185 | )
186 |
187 |
188 | def _generate_image_captions(
189 | image_caption_model,
190 | image_caption_processor,
191 | images: List[Image.Image],
192 | system_prompt: str = "",
193 | ) -> List[str]:
194 | image_caption_prompts = [system_prompt] * len(images)
195 | inputs = image_caption_processor(
196 | image_caption_prompts, images, return_tensors="pt"
197 | ).to(image_caption_model.device)
198 |
199 | with torch.inference_mode():
200 | generated_ids = image_caption_model.generate(
201 | input_ids=inputs["input_ids"],
202 | pixel_values=inputs["pixel_values"],
203 | max_new_tokens=1024,
204 | do_sample=False,
205 | num_beams=3,
206 | )
207 |
208 | return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True)
209 |
210 |
211 | def _generate_and_decode_prompts(
212 | prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int
213 | ) -> List[str]:
214 | with torch.inference_mode():
215 | outputs = prompt_enhancer_model.generate(
216 | **model_inputs, max_new_tokens=max_new_tokens
217 | )
218 | generated_ids = [
219 | output_ids[len(input_ids) :]
220 | for input_ids, output_ids in zip(model_inputs.input_ids, outputs)
221 | ]
222 | decoded_prompts = prompt_enhancer_tokenizer.batch_decode(
223 | generated_ids, skip_special_tokens=True
224 | )
225 |
226 | return decoded_prompts
227 |
--------------------------------------------------------------------------------
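
On the text-to-video path (`conditioning_items=None`) only the enhancer LLM and its tokenizer are used, so the caption-model arguments can be left as `None`. A hedged sketch, assuming a chat-style instruct model such as the `unsloth/Llama-3.2-3B-Instruct` default used by the test suite (loading with `device_map="auto"` additionally assumes `accelerate` is installed):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt

llm_name = "unsloth/Llama-3.2-3B-Instruct"  # same default as tests/conftest.py
tokenizer = AutoTokenizer.from_pretrained(llm_name)
model = AutoModelForCausalLM.from_pretrained(
    llm_name, torch_dtype=torch.bfloat16, device_map="auto"
)

enhanced = generate_cinematic_prompt(
    image_caption_model=None,       # unused on the T2V path
    image_caption_processor=None,   # unused on the T2V path
    prompt_enhancer_model=model,
    prompt_enhancer_tokenizer=tokenizer,
    prompt="A cat sitting on a windowsill",
    conditioning_items=None,
)
print(enhanced[0])
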
/ltx_video/utils/skip_layer_strategy.py:
--------------------------------------------------------------------------------
1 | from enum import Enum, auto
2 |
3 |
4 | class SkipLayerStrategy(Enum):
5 | AttentionSkip = auto()
6 | AttentionValues = auto()
7 | Residual = auto()
8 | TransformerBlock = auto()
9 |
--------------------------------------------------------------------------------
/ltx_video/utils/torch_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
7 | dims_to_append = target_dims - x.ndim
8 | if dims_to_append < 0:
9 | raise ValueError(
10 | f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
11 | )
12 | elif dims_to_append == 0:
13 | return x
14 | return x[(...,) + (None,) * dims_to_append]
15 |
16 |
17 | class Identity(nn.Module):
18 | """A placeholder identity operator that is argument-insensitive."""
19 |
20 | def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument
21 | super().__init__()
22 |
23 | # pylint: disable=unused-argument
24 | def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
25 | return x
26 |
--------------------------------------------------------------------------------
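
`append_dims` is the broadcasting helper behind `RectifiedFlowScheduler.add_noise`: per-sample sigmas of shape `(b,)` gain trailing singleton dimensions until they can multiply a `(b, t, c)` latent tensor. A minimal sketch:

import torch
from ltx_video.utils.torch_utils import append_dims

latents = torch.randn(2, 4096, 128)      # (b, t, c)
sigmas = torch.tensor([0.3, 0.7])        # one sigma per batch element, shape (b,)

sigmas = append_dims(sigmas, latents.ndim)  # -> shape (2, 1, 1)
noisy = (1 - sigmas) * latents + sigmas * torch.randn_like(latents)
assert noisy.shape == latents.shape
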
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=42", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "ltx-video"
7 | version = "0.1.2"
8 | description = "A package for the LTX-Video model"
9 | authors = [
10 | { name = "Sapir Weissbuch", email = "sapir@lightricks.com" }
11 | ]
12 | requires-python = ">=3.10"
13 | readme = "README.md"
14 | classifiers = [
15 | "Programming Language :: Python :: 3",
16 | "Operating System :: OS Independent"
17 | ]
18 | dependencies = [
19 | "torch>=2.1.0",
20 | "diffusers>=0.28.2",
21 | "transformers>=4.47.2",
22 | "sentencepiece>=0.1.96",
23 | "huggingface-hub~=0.30",
24 | "einops",
25 | "timm"
26 | ]
27 |
28 | [project.optional-dependencies]
29 | # Instead of thinking of them as optional, think of them as specific modes
30 | inference-script = [
31 | "accelerate",
32 | "matplotlib",
33 | "imageio[ffmpeg]",
34 | "av",
35 | "opencv-python"
36 | ]
37 | test = [
38 | "pytest",
39 | ]
40 |
41 | [tool.setuptools.packages.find]
42 | include = ["ltx_video*"]
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import shutil
3 | import os
4 |
5 |
6 | @pytest.fixture(scope="session")
7 | def test_paths(request, pytestconfig):
8 | try:
9 | output_path = "output"
10 | ckpt_path = request.param # This will get the current parameterized item
11 | text_encoder_model_name_or_path = pytestconfig.getoption(
12 | "text_encoder_model_name_or_path"
13 | )
14 | input_image_path = pytestconfig.getoption("input_image_path")
15 | input_video_path = pytestconfig.getoption("input_video_path")
16 | prompt_enhancer_image_caption_model_name_or_path = pytestconfig.getoption(
17 | "prompt_enhancer_image_caption_model_name_or_path"
18 | )
19 | prompt_enhancer_llm_model_name_or_path = pytestconfig.getoption(
20 | "prompt_enhancer_llm_model_name_or_path"
21 | )
22 | prompt_enhancement_words_threshold = pytestconfig.getoption(
23 | "prompt_enhancement_words_threshold"
24 | )
25 |
26 | config = {
27 | "ckpt_path": ckpt_path,
28 | "input_image_path": input_image_path,
29 | "input_video_path": input_video_path,
30 | "output_path": output_path,
31 | "text_encoder_model_name_or_path": text_encoder_model_name_or_path,
32 | "prompt_enhancer_image_caption_model_name_or_path": prompt_enhancer_image_caption_model_name_or_path,
33 | "prompt_enhancer_llm_model_name_or_path": prompt_enhancer_llm_model_name_or_path,
34 | "prompt_enhancement_words_threshold": prompt_enhancement_words_threshold,
35 | }
36 |
37 | yield config
38 |
39 | finally:
40 | if os.path.exists(output_path):
41 | shutil.rmtree(output_path)
42 |
43 |
44 | def pytest_generate_tests(metafunc):
45 | if "test_paths" in metafunc.fixturenames:
46 | ckpt_paths = metafunc.config.getoption("ckpt_path")
47 | metafunc.parametrize("test_paths", ckpt_paths, indirect=True)
48 |
49 |
50 | def pytest_addoption(parser):
51 | parser.addoption(
52 | "--ckpt_path",
53 | action="append",
54 | default=[],
55 | help="Path to checkpoint files (can specify multiple)",
56 | )
57 | parser.addoption(
58 | "--text_encoder_model_name_or_path",
59 | action="store",
60 | default="PixArt-alpha/PixArt-XL-2-1024-MS",
61 |         help="Name or path of the text encoder model.",
62 | )
63 | parser.addoption(
64 | "--input_image_path",
65 | action="store",
66 | default="tests/utils/woman.jpeg",
67 | help="Path to input image file.",
68 | )
69 | parser.addoption(
70 | "--input_video_path",
71 | action="store",
72 | default="tests/utils/woman.mp4",
73 | help="Path to input video file.",
74 | )
75 | parser.addoption(
76 | "--prompt_enhancer_image_caption_model_name_or_path",
77 | action="store",
78 | default="MiaoshouAI/Florence-2-large-PromptGen-v2.0",
79 | help="Path to prompt_enhancer_image_caption_model.",
80 | )
81 | parser.addoption(
82 | "--prompt_enhancer_llm_model_name_or_path",
83 | action="store",
84 | default="unsloth/Llama-3.2-3B-Instruct",
85 | help="Path to LLM model for prompt enhancement.",
86 | )
87 | parser.addoption(
88 | "--prompt_enhancement_words_threshold",
89 | type=int,
90 | default=50,
91 | help="Enable prompt enhancement only if input prompt has fewer words than this threshold. Set to 0 to disable enhancement completely.",
92 | )
93 |
--------------------------------------------------------------------------------
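
These options are read from the pytest command line, and the inference tests are parametrized over `--ckpt_path`, so at least one checkpoint must be supplied. A minimal sketch of invoking the suite programmatically (the checkpoint path below is a placeholder, not a real file):

import pytest

# Equivalent to:
#   pytest tests/test_inference.py --ckpt_path /path/to/ltxv-checkpoint.safetensors
exit_code = pytest.main(
    [
        "tests/test_inference.py",
        "--ckpt_path",
        "/path/to/ltxv-checkpoint.safetensors",  # placeholder; supply a real checkpoint
    ]
)
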
/tests/test_inference.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import yaml
4 | from inference import infer, create_ltx_video_pipeline
5 | from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
6 |
7 |
8 | def pytest_make_parametrize_id(config, val, argname):
9 | if isinstance(val, str):
10 | return f"{argname}-{val}"
11 | return f"{argname}-{repr(val)}"
12 |
13 |
14 | @pytest.mark.parametrize(
15 | "conditioning_test_mode",
16 | ["unconditional", "first-frame", "first-sequence", "sequence-and-frame"],
17 | ids=lambda x: f"conditioning_test_mode={x}",
18 | )
19 | def test_infer_runs_on_real_path(tmp_path, test_paths, conditioning_test_mode):
20 | conditioning_params = {}
21 | if conditioning_test_mode == "unconditional":
22 | pass
23 | elif conditioning_test_mode == "first-frame":
24 | conditioning_params["conditioning_media_paths"] = [
25 | test_paths["input_image_path"]
26 | ]
27 | conditioning_params["conditioning_start_frames"] = [0]
28 | elif conditioning_test_mode == "first-sequence":
29 | conditioning_params["conditioning_media_paths"] = [
30 | test_paths["input_video_path"]
31 | ]
32 | conditioning_params["conditioning_start_frames"] = [0]
33 | elif conditioning_test_mode == "sequence-and-frame":
34 | conditioning_params["conditioning_media_paths"] = [
35 | test_paths["input_video_path"],
36 | test_paths["input_image_path"],
37 | ]
38 | conditioning_params["conditioning_start_frames"] = [16, 67]
39 | else:
40 | raise ValueError(f"Unknown conditioning mode: {conditioning_test_mode}")
41 | test_paths = {
42 | k: v
43 | for k, v in test_paths.items()
44 | if k not in ["input_image_path", "input_video_path"]
45 | }
46 |
47 | params = {
48 | "seed": 42,
49 | "num_inference_steps": 1,
50 | "height": 512,
51 | "width": 768,
52 | "num_frames": 121,
53 | "frame_rate": 25,
54 | "prompt": "A young woman with wavy, shoulder-length light brown hair stands outdoors on a foggy day. She wears a cozy pink turtleneck sweater, with a serene expression and piercing blue eyes. A wooden fence and a misty, grassy field fade into the background, evoking a calm and introspective mood.",
55 | "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
56 | "offload_to_cpu": False,
57 | "output_path": tmp_path,
58 | "image_cond_noise_scale": 0.15,
59 | }
60 |
61 | config = {
62 | "pipeline_type": "base",
63 | "num_images_per_prompt": 1,
64 | "guidance_scale": 2.5,
65 | "stg_scale": 1,
66 | "stg_rescale": 0.7,
67 | "stg_mode": "attention_values",
68 | "stg_skip_layers": "1,2,3",
69 | "precision": "bfloat16",
70 | "decode_timestep": 0.05,
71 | "decode_noise_scale": 0.025,
72 | "checkpoint_path": test_paths["ckpt_path"],
73 | "text_encoder_model_name_or_path": test_paths[
74 | "text_encoder_model_name_or_path"
75 | ],
76 | "prompt_enhancer_image_caption_model_name_or_path": test_paths[
77 | "prompt_enhancer_image_caption_model_name_or_path"
78 | ],
79 | "prompt_enhancer_llm_model_name_or_path": test_paths[
80 | "prompt_enhancer_llm_model_name_or_path"
81 | ],
82 | "prompt_enhancement_words_threshold": 120,
83 | "stochastic_sampling": False,
84 | "sampler": "from_checkpoint",
85 | }
86 |
87 | temp_config_path = tmp_path / "config.yaml"
88 | with open(temp_config_path, "w") as f:
89 | yaml.dump(config, f)
90 |
91 | infer(**{**conditioning_params, **params, "pipeline_config": temp_config_path})
92 |
93 |
94 | def test_vid2vid(tmp_path, test_paths):
95 | params = {
96 | "seed": 42,
97 | "image_cond_noise_scale": 0.15,
98 | "height": 512,
99 | "width": 768,
100 | "num_frames": 25,
101 | "frame_rate": 25,
102 | "prompt": "A young woman with wavy, shoulder-length light brown hair stands outdoors on a foggy day. She wears a cozy pink turtleneck sweater, with a serene expression and piercing blue eyes. A wooden fence and a misty, grassy field fade into the background, evoking a calm and introspective mood.",
103 | "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
104 | "offload_to_cpu": False,
105 | "input_media_path": test_paths["input_video_path"],
106 | }
107 |
108 | config = {
109 | "num_inference_steps": 3,
110 | "skip_initial_inference_steps": 1,
111 | "guidance_scale": 2.5,
112 | "stg_scale": 1,
113 | "stg_rescale": 0.7,
114 | "stg_mode": "attention_values",
115 | "stg_skip_layers": "1,2,3",
116 | "precision": "bfloat16",
117 | "decode_timestep": 0.05,
118 | "decode_noise_scale": 0.025,
119 | "sampler": "from_checkpoint",
120 | "checkpoint_path": test_paths["ckpt_path"],
121 | "text_encoder_model_name_or_path": test_paths[
122 | "text_encoder_model_name_or_path"
123 | ],
124 | "prompt_enhancer_image_caption_model_name_or_path": test_paths[
125 | "prompt_enhancer_image_caption_model_name_or_path"
126 | ],
127 | "prompt_enhancer_llm_model_name_or_path": test_paths[
128 | "prompt_enhancer_llm_model_name_or_path"
129 | ],
130 | "prompt_enhancement_words_threshold": 120,
131 | }
132 | test_paths = {
133 | k: v
134 | for k, v in test_paths.items()
135 | if k not in ["input_image_path", "input_video_path"]
136 | }
137 | temp_config_path = tmp_path / "config.yaml"
138 | with open(temp_config_path, "w") as f:
139 | yaml.dump(config, f)
140 |
141 | infer(**{**test_paths, **params, "pipeline_config": temp_config_path})
142 |
143 |
144 | def get_device():
145 | if torch.cuda.is_available():
146 | return "cuda"
147 | elif torch.backends.mps.is_available():
148 | return "mps"
149 | return "cpu"
150 |
151 |
152 | def test_pipeline_on_batch(tmp_path, test_paths):
153 | device = get_device()
154 | pipeline = create_ltx_video_pipeline(
155 | ckpt_path=test_paths["ckpt_path"],
156 | device=device,
157 | precision="bfloat16",
158 | text_encoder_model_name_or_path=test_paths["text_encoder_model_name_or_path"],
159 | enhance_prompt=False,
160 | prompt_enhancer_image_caption_model_name_or_path=test_paths[
161 | "prompt_enhancer_image_caption_model_name_or_path"
162 | ],
163 | prompt_enhancer_llm_model_name_or_path=test_paths[
164 | "prompt_enhancer_llm_model_name_or_path"
165 | ],
166 | )
167 |
168 | params = {
169 | "seed": 42,
170 | "image_cond_noise_scale": 0.15,
171 | "height": 512,
172 | "width": 768,
173 | "num_frames": 1,
174 | "frame_rate": 25,
175 | "offload_to_cpu": False,
176 | "output_type": "pt",
177 | "is_video": False,
178 | "vae_per_channel_normalize": True,
179 | "mixed_precision": False,
180 | }
181 |
182 | config = {
183 | "num_inference_steps": 2,
184 | "guidance_scale": 2.5,
185 | "stg_scale": 1,
186 | "rescaling_scale": 0.7,
187 | "skip_block_list": [1, 2],
188 | "decode_timestep": 0.05,
189 | "decode_noise_scale": 0.025,
190 | }
191 |
192 | temp_config_path = tmp_path / "config.yaml"
193 | with open(temp_config_path, "w") as f:
194 | yaml.dump(config, f)
195 |
196 | first_prompt = "A vintage yellow car drives along a wet mountain road, its rear wheels kicking up a light spray as it moves. The camera follows close behind, capturing the curvature of the road as it winds through rocky cliffs and lush green hills. The sunlight pierces through scattered clouds, reflecting off the car's rain-speckled surface, creating a dynamic, cinematic moment. The scene conveys a sense of freedom and exploration as the car disappears into the distance."
197 | second_prompt = "A woman with blonde hair styled up, wearing a black dress with sequins and pearl earrings, looks down with a sad expression on her face. The camera remains stationary, focused on the woman's face. The lighting is dim, casting soft shadows on her face. The scene appears to be from a movie or TV show."
198 |
199 | sample = {
200 | "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
201 | "prompt_attention_mask": None,
202 | "negative_prompt_attention_mask": None,
203 | "media_items": None,
204 | }
205 |
206 | def get_images(prompts):
207 | generators = [
208 | torch.Generator(device=device).manual_seed(params["seed"]) for _ in range(2)
209 | ]
210 | torch.manual_seed(params["seed"])
211 |
212 | images = pipeline(
213 | prompt=prompts,
214 | generator=generators,
215 | **sample,
216 | **params,
217 | pipeline_config=temp_config_path,
218 | ).images
219 | return images
220 |
221 | batch_diff_images = get_images([first_prompt, second_prompt])
222 | batch_same_images = get_images([second_prompt, second_prompt])
223 |
224 | # Take the second image from both runs
225 | image2_not_same = batch_diff_images[1, :, 0, :, :]
226 | image2_same = batch_same_images[1, :, 0, :, :]
227 |
228 | # Compute mean absolute difference, should be 0
229 | mad = torch.mean(torch.abs(image2_not_same - image2_same)).item()
230 | print(f"Mean absolute difference: {mad}")
231 |
232 | assert torch.allclose(image2_not_same, image2_same)
233 |
234 |
235 | def test_prompt_enhancement(tmp_path, test_paths, monkeypatch):
236 | # Create pipeline with prompt enhancement enabled
237 | device = get_device()
238 | pipeline = create_ltx_video_pipeline(
239 | ckpt_path=test_paths["ckpt_path"],
240 | device=device,
241 | precision="bfloat16",
242 | text_encoder_model_name_or_path=test_paths["text_encoder_model_name_or_path"],
243 | enhance_prompt=True,
244 | prompt_enhancer_image_caption_model_name_or_path=test_paths[
245 | "prompt_enhancer_image_caption_model_name_or_path"
246 | ],
247 | prompt_enhancer_llm_model_name_or_path=test_paths[
248 | "prompt_enhancer_llm_model_name_or_path"
249 | ],
250 | )
251 |
252 | original_prompt = "A cat sitting on a windowsill"
253 |
254 | # Mock the pipeline's _encode_prompt method to verify the prompt being used
255 | original_encode_prompt = pipeline.encode_prompt
256 |
257 | prompts_used = []
258 |
259 | def mock_encode_prompt(prompt, *args, **kwargs):
260 | prompts_used.append(prompt[0] if isinstance(prompt, list) else prompt)
261 | return original_encode_prompt(prompt, *args, **kwargs)
262 |
263 | pipeline.encode_prompt = mock_encode_prompt
264 |
265 | # Set up minimal parameters for a quick test
266 | params = {
267 | "seed": 42,
268 | "image_cond_noise_scale": 0.15,
269 | "height": 512,
270 | "width": 768,
271 | "skip_layer_strategy": SkipLayerStrategy.AttentionValues,
272 | "num_frames": 1,
273 | "frame_rate": 25,
274 | "offload_to_cpu": False,
275 | "output_type": "pt",
276 | "is_video": False,
277 | "vae_per_channel_normalize": True,
278 | "mixed_precision": False,
279 | }
280 |
281 | config = {
282 | "pipeline_type": "base",
283 | "num_inference_steps": 1,
284 | "guidance_scale": 2.5,
285 | "stg_scale": 1,
286 | "rescaling_scale": 0.7,
287 | "skip_block_list": [1, 2],
288 | "decode_timestep": 0.05,
289 | "decode_noise_scale": 0.025,
290 | }
291 |
292 | temp_config_path = tmp_path / "config.yaml"
293 | with open(temp_config_path, "w") as f:
294 | yaml.dump(config, f)
295 |
296 | # Run pipeline with prompt enhancement enabled
297 | _ = pipeline(
298 | prompt=original_prompt,
299 | negative_prompt="worst quality",
300 | enhance_prompt=True,
301 | **params,
302 | pipeline_config=temp_config_path,
303 | )
304 |
305 | # Verify that the enhanced prompt was used
306 | assert len(prompts_used) > 0
307 | assert (
308 | prompts_used[0] != original_prompt
309 | ), f"Expected enhanced prompt to be different from original prompt, but got: {original_prompt}"
310 |
311 | # Run pipeline with prompt enhancement disabled
312 | prompts_used.clear()
313 | _ = pipeline(
314 | prompt=original_prompt,
315 | negative_prompt="worst quality",
316 | enhance_prompt=False,
317 | **params,
318 | pipeline_config=temp_config_path,
319 | )
320 |
321 | # Verify that the original prompt was used
322 | assert len(prompts_used) > 0
323 | assert (
324 | prompts_used[0] == original_prompt
325 | ), f"Expected original prompt to be used, but got: {prompts_used[0]}"
326 |
--------------------------------------------------------------------------------
/tests/test_scheduler.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from ltx_video.schedulers.rf import RectifiedFlowScheduler
4 |
5 |
6 | def init_latents_and_scheduler(sampler):
7 | batch_size, n_tokens, n_channels = 2, 4096, 128
8 | num_steps = 20
9 | scheduler = RectifiedFlowScheduler(
10 | sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
11 | )
12 | latents = torch.randn(size=(batch_size, n_tokens, n_channels))
13 | scheduler.set_timesteps(num_inference_steps=num_steps, samples_shape=latents.shape)
14 | return scheduler, latents
15 |
16 |
17 | @pytest.mark.parametrize("sampler", ["LinearQuadratic", "Uniform"])
18 | def test_scheduler_default_behavior(sampler):
19 | """
20 | Test the case of a single timestep from the list of timesteps.
21 | """
22 | scheduler, latents = init_latents_and_scheduler(sampler)
23 |
24 | for i, t in enumerate(scheduler.timesteps):
25 | noise_pred = torch.randn_like(latents)
26 | denoised_latents = scheduler.step(
27 | noise_pred,
28 | t,
29 | latents,
30 | return_dict=False,
31 | )[0]
32 |
33 | # Verify the denoising
34 | next_t = scheduler.timesteps[i + 1] if i < len(scheduler.timesteps) - 1 else 0.0
35 | dt = t - next_t
36 | expected_denoised_latents = latents - dt * noise_pred
37 | assert torch.allclose(denoised_latents, expected_denoised_latents, atol=1e-06)
38 |
39 |
40 | @pytest.mark.parametrize("sampler", ["LinearQuadratic", "Uniform"])
41 | def test_scheduler_per_token(sampler):
42 | """
43 | Test the case of a timestep per token (from the list of timesteps).
44 | Some tokens are set with timestep of 0.
45 | """
46 | scheduler, latents = init_latents_and_scheduler(sampler)
47 | batch_size, n_tokens = latents.shape[:2]
48 | for i, t in enumerate(scheduler.timesteps):
49 | timesteps = torch.full((batch_size, n_tokens), t)
50 | timesteps[:, 0] = 0.0
51 | noise_pred = torch.randn_like(latents)
52 | denoised_latents = scheduler.step(
53 | noise_pred,
54 | timesteps,
55 | latents,
56 | return_dict=False,
57 | )[0]
58 |
59 | # Verify the denoising
60 | next_t = scheduler.timesteps[i + 1] if i < len(scheduler.timesteps) - 1 else 0.0
61 | next_timesteps = torch.full((batch_size, n_tokens), next_t)
62 | dt = timesteps - next_timesteps
63 | expected_denoised_latents = latents - dt.unsqueeze(-1) * noise_pred
64 | assert torch.allclose(
65 | denoised_latents[:, 1:], expected_denoised_latents[:, 1:], atol=1e-06
66 | )
67 | assert torch.allclose(denoised_latents[:, 0], latents[:, 0], atol=1e-06)
68 |
69 |
70 | @pytest.mark.parametrize("sampler", ["LinearQuadratic", "Uniform"])
71 | def test_scheduler_t_not_in_list(sampler):
72 | """
73 | Test the case of a timestep per token NOT from the list of timesteps.
74 | """
75 | scheduler, latents = init_latents_and_scheduler(sampler)
76 | batch_size, n_tokens = latents.shape[:2]
77 | for i in range(len(scheduler.timesteps)):
78 | if i < len(scheduler.timesteps) - 1:
79 | t = (scheduler.timesteps[i] + scheduler.timesteps[i + 1]) / 2
80 | else:
81 | t = scheduler.timesteps[i] / 2
82 | timesteps = torch.full((batch_size, n_tokens), t)
83 | noise_pred = torch.randn_like(latents)
84 | denoised_latents = scheduler.step(
85 | noise_pred,
86 | timesteps,
87 | latents,
88 | return_dict=False,
89 | )[0]
90 |
91 | # Verify the denoising
92 | next_t = scheduler.timesteps[i + 1] if i < len(scheduler.timesteps) - 1 else 0.0
93 | next_timesteps = torch.full((batch_size, n_tokens), next_t)
94 | dt = timesteps - next_timesteps
95 | expected_denoised_latents = latents - dt.unsqueeze(-1) * noise_pred
96 | assert torch.allclose(denoised_latents, expected_denoised_latents, atol=1e-06)
97 |
--------------------------------------------------------------------------------
/tests/test_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ltx_video.models.autoencoders.causal_video_autoencoder import (
4 | CausalVideoAutoencoder,
5 | create_video_autoencoder_demo_config,
6 | )
7 |
8 |
9 | def test_vae():
10 | # create vae and run a forward and backward pass
11 | config = create_video_autoencoder_demo_config(latent_channels=16)
12 | video_autoencoder = CausalVideoAutoencoder.from_config(config)
13 | video_autoencoder.eval()
14 | input_videos = torch.randn(2, 3, 17, 64, 64)
15 | latent = video_autoencoder.encode(input_videos).latent_dist.mode()
16 | assert latent.shape == (2, 16, 3, 2, 2)
17 | timestep = torch.ones(input_videos.shape[0]) * 0.1
18 | reconstructed_videos = video_autoencoder.decode(
19 | latent, target_shape=input_videos.shape, timestep=timestep
20 | ).sample
21 | assert input_videos.shape == reconstructed_videos.shape
22 | loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)
23 | loss.backward()
24 |
25 | # validate temporal causality in encoder
26 | input_image = input_videos[:, :, :1, :, :]
27 | image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
28 | assert torch.allclose(image_latent, latent[:, :, :1, :, :], atol=1e-6)
29 |
30 | input_sequence = input_videos[:, :, :9, :, :]
31 | sequence_latent = video_autoencoder.encode(input_sequence).latent_dist.mode()
32 | assert torch.allclose(sequence_latent, latent[:, :, :2, :, :], atol=1e-6)
33 |
--------------------------------------------------------------------------------
/tests/utils/.gitattributes:
--------------------------------------------------------------------------------
1 | *.mp4 filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/tests/utils/woman.jpeg:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:acf0a6de9f7fb5551702fa8065ab423811e2011c36338f49636f84b033715df8
3 | size 33249
4 |
--------------------------------------------------------------------------------
/tests/utils/woman.mp4:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:833c2a0afe03ad60cb0c6361433eb11acf2ed58890f6b411cd306c01069e199f
3 | size 73228
4 |
--------------------------------------------------------------------------------