├── .gitattributes ├── .github └── workflows │ └── pylint.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── __init__.py ├── configs ├── ltxv-13b-0.9.7-dev.yaml ├── ltxv-13b-0.9.7-distilled.yaml ├── ltxv-2b-0.9.1.yaml ├── ltxv-2b-0.9.5.yaml ├── ltxv-2b-0.9.6-dev.yaml ├── ltxv-2b-0.9.6-distilled.yaml └── ltxv-2b-0.9.yaml ├── docs └── _static │ ├── ltx-video_example_00001.gif │ ├── ltx-video_example_00005.gif │ ├── ltx-video_example_00006.gif │ ├── ltx-video_example_00007.gif │ ├── ltx-video_example_00010.gif │ ├── ltx-video_example_00011.gif │ ├── ltx-video_example_00013.gif │ ├── ltx-video_example_00014.gif │ ├── ltx-video_example_00015.gif │ ├── ltx-video_i2v_example_00001.gif │ ├── ltx-video_i2v_example_00002.gif │ ├── ltx-video_i2v_example_00003.gif │ ├── ltx-video_i2v_example_00004.gif │ ├── ltx-video_i2v_example_00005.gif │ ├── ltx-video_i2v_example_00006.gif │ ├── ltx-video_i2v_example_00007.gif │ ├── ltx-video_i2v_example_00008.gif │ └── ltx-video_i2v_example_00009.gif ├── inference.py ├── ltx_video ├── __init__.py ├── models │ ├── __init__.py │ ├── autoencoders │ │ ├── __init__.py │ │ ├── causal_conv3d.py │ │ ├── causal_video_autoencoder.py │ │ ├── conv_nd_factory.py │ │ ├── dual_conv3d.py │ │ ├── latent_upsampler.py │ │ ├── pixel_norm.py │ │ ├── pixel_shuffle.py │ │ ├── vae.py │ │ ├── vae_encode.py │ │ └── video_autoencoder.py │ └── transformers │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── embeddings.py │ │ ├── symmetric_patchifier.py │ │ └── transformer3d.py ├── pipelines │ ├── __init__.py │ ├── crf_compressor.py │ └── pipeline_ltx_video.py ├── schedulers │ ├── __init__.py │ └── rf.py └── utils │ ├── __init__.py │ ├── diffusers_config_mapping.py │ ├── prompt_enhance_utils.py │ ├── skip_layer_strategy.py │ └── torch_utils.py ├── pyproject.toml └── tests ├── conftest.py ├── test_inference.py ├── test_scheduler.py ├── test_vae.py └── utils ├── .gitattributes ├── woman.jpeg └── woman.mp4 /.gitattributes: -------------------------------------------------------------------------------- 1 | *.jpg filter=lfs diff=lfs merge=lfs -text 2 | *.jpeg filter=lfs diff=lfs merge=lfs -text 3 | *.png filter=lfs diff=lfs merge=lfs -text 4 | *.gif filter=lfs diff=lfs merge=lfs -text 5 | -------------------------------------------------------------------------------- /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: Ruff 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ["3.10"] 11 | steps: 12 | - name: Checkout repository and submodules 13 | uses: actions/checkout@v3 14 | - name: Set up Python ${{ matrix.python-version }} 15 | uses: actions/setup-python@v3 16 | with: 17 | python-version: ${{ matrix.python-version }} 18 | - name: Install dependencies 19 | run: | 20 | python -m pip install --upgrade pip 21 | pip install ruff==0.2.2 black==24.2.0 22 | - name: Analyzing the code with ruff 23 | run: | 24 | ruff $(git ls-files '*.py') 25 | - name: Verify that no Black changes are required 26 | run: | 27 | black --check $(git ls-files '*.py') 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | 
downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | .idea/ 163 | 164 | # From inference.py 165 | outputs/ 166 | video_output_*.mp4 -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.2.2 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | args: [--fix] # Automatically fix issues if possible. 9 | types: [python] # Ensure it only runs on .py files. 10 | 11 | - repo: https://github.com/psf/black 12 | rev: 24.2.0 # Specify the version of Black you want 13 | hooks: 14 | - id: black 15 | name: Black code formatter 16 | language_version: python3 # Use the Python version you're targeting (e.g., 3.10) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # LTX-Video 4 | 5 | This is the official repository for LTX-Video. 6 | 7 | [Website](https://www.lightricks.com/ltxv) | 8 | [Model](https://huggingface.co/Lightricks/LTX-Video) | 9 | [Demo](https://app.ltx.studio/motion-workspace?videoModel=ltxv-13b) | 10 | [Paper](https://arxiv.org/abs/2501.00103) | 11 | [Trainer](https://github.com/Lightricks/LTX-Video-Trainer) | 12 | [Discord](https://discord.gg/Mn8BRgUKKy) 13 | 14 |
15 | 16 | ## Table of Contents 17 | 18 | - [Introduction](#introduction) 19 | - [What's new](#news) 20 | - [Models & Workflows](#models--workflows) 21 | - [Quick Start Guide](#quick-start-guide) 22 | - [Use online](#online-inference) 23 | - [Run locally](#run-locally) 24 | - [Installation](#installation) 25 | - [Inference](#inference) 26 | - [ComfyUI Integration](#comfyui-integration) 27 | - [Diffusers Integration](#diffusers-integration) 28 | - [Model User Guide](#model-user-guide) 29 | - [Community Contribution](#community-contribution) 30 | - [Training](#⚡️-training) 31 | - [Join Us!](#🚀-join-us) 32 | - [Acknowledgement](#acknowledgement) 33 | 34 | # Introduction 35 | 36 | LTX-Video is the first DiT-based video generation model that can generate high-quality videos in *real-time*. 37 | It can generate 30 FPS videos at 1216×704 resolution, faster than it takes to watch them. 38 | The model is trained on a large-scale dataset of diverse videos and can generate high-resolution videos 39 | with realistic and diverse content. 40 | 41 | The model supports text-to-image, image-to-video, keyframe-based animation, video extension (both forward and backward), video-to-video transformations, and any combination of these features. 42 | 43 | ### Image to video examples 44 | | | | | 45 | |:---:|:---:|:---:| 46 | | ![example1](./docs/_static/ltx-video_i2v_example_00001.gif) | ![example2](./docs/_static/ltx-video_i2v_example_00002.gif) | ![example3](./docs/_static/ltx-video_i2v_example_00003.gif) | 47 | | ![example4](./docs/_static/ltx-video_i2v_example_00004.gif) | ![example5](./docs/_static/ltx-video_i2v_example_00005.gif) | ![example6](./docs/_static/ltx-video_i2v_example_00006.gif) | 48 | | ![example7](./docs/_static/ltx-video_i2v_example_00007.gif) | ![example8](./docs/_static/ltx-video_i2v_example_00008.gif) | ![example9](./docs/_static/ltx-video_i2v_example_00009.gif) | 49 | 50 | 51 | ### Text to video examples 52 | | | | | 53 | |:---:|:---:|:---:| 54 | | ![example1](./docs/_static/ltx-video_example_00001.gif)
A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage.
| ![example10](./docs/_static/ltx-video_example_00010.gif)
A clear, turquoise river flows through a rocky canyon, cascading over a small waterfall and forming a pool of water at the bottom. The river is the main focus of the scene, with its clear water reflecting the surrounding trees and rocks. The canyon walls are steep and rocky, with some vegetation growing on them. The trees are mostly pine trees, with their green needles contrasting with the brown and gray rocks. The overall tone of the scene is one of peace and tranquility.
| ![example3](./docs/_static/ltx-video_example_00015.gif)
Two police officers in dark blue uniforms and matching hats enter a dimly lit room through a doorway on the left side of the frame. The first officer, with short brown hair and a mustache, steps inside first, followed by his partner, who has a shaved head and a goatee. Both officers have serious expressions and maintain a steady pace as they move deeper into the room. The camera remains stationary, capturing them from a slightly low angle as they enter. The room has exposed brick walls and a corrugated metal ceiling, with a barred window visible in the background. The lighting is low-key, casting shadows on the officers' faces and emphasizing the grim atmosphere. The scene appears to be from a film or television show.
| 55 | | ![example5](./docs/_static/ltx-video_example_00005.gif)
A woman with light skin, wearing a blue jacket and a black hat with a veil, looks down and to her right, then back up as she speaks; she has brown hair styled in an updo, light brown eyebrows, and is wearing a white collared shirt under her jacket; the camera remains stationary on her face as she speaks; the background is out of focus, but shows trees and people in period clothing; the scene is captured in real-life footage.
| ![example6](./docs/_static/ltx-video_example_00006.gif)
A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie.
| ![example7](./docs/_static/ltx-video_example_00007.gif)
A prison guard unlocks and opens a cell door to reveal a young man sitting at a table with a woman. The guard, wearing a dark blue uniform with a badge on his left chest, unlocks the cell door with a key held in his right hand and pulls it open; he has short brown hair, light skin, and a neutral expression. The young man, wearing a black and white striped shirt, sits at a table covered with a white tablecloth, facing the woman; he has short brown hair, light skin, and a neutral expression. The woman, wearing a dark blue shirt, sits opposite the young man, her face turned towards him; she has short blonde hair and light skin. The camera remains stationary, capturing the scene from a medium distance, positioned slightly to the right of the guard. The room is dimly lit, with a single light fixture illuminating the table and the two figures. The walls are made of large, grey concrete blocks, and a metal door is visible in the background. The scene is captured in real-life footage.
| 56 | | ![example2](./docs/_static/ltx-video_example_00014.gif)
A man walks towards a window, looks out, and then turns around. He has short, dark hair, dark skin, and is wearing a brown coat over a red and gray scarf. He walks from left to right towards a window, his gaze fixed on something outside. The camera follows him from behind at a medium distance. The room is brightly lit, with white walls and a large window covered by a white curtain. As he approaches the window, he turns his head slightly to the left, then back to the right. He then turns his entire body to the right, facing the window. The camera remains stationary as he stands in front of the window. The scene is captured in real-life footage.
| ![example13](./docs/_static/ltx-video_example_00013.gif)
The camera pans across a cityscape of tall buildings with a circular building in the center. The camera moves from left to right, showing the tops of the buildings and the circular building in the center. The buildings are various shades of gray and white, and the circular building has a green roof. The camera angle is high, looking down at the city. The lighting is bright, with the sun shining from the upper left, casting shadows from the buildings. The scene is computer-generated imagery.
| ![example11](./docs/_static/ltx-video_example_00011.gif)
A man in a suit enters a room and speaks to two women sitting on a couch. The man, wearing a dark suit with a gold tie, enters the room from the left and walks towards the center of the frame. He has short gray hair, light skin, and a serious expression. He places his right hand on the back of a chair as he approaches the couch. Two women are seated on a light-colored couch in the background. The woman on the left wears a light blue sweater and has short blonde hair. The woman on the right wears a white sweater and has short blonde hair. The camera remains stationary, focusing on the man as he enters the room. The room is brightly lit, with warm tones reflecting off the walls and furniture. The scene appears to be from a film or television show.
| 57 | 58 | # News 59 | 60 | ## May, 14th, 2025: New distilled model 13B v0.9.7: 61 | - Release a new 13B distilled model [ltxv-13b-0.9.7-distilled](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled.safetensors) 62 | * Amazing for iterative work - generates HD videos in 10 seconds, with low-res preview after just 3 seconds (on H100)! 63 | * Does not require classifier-free guidance and spatio-temporal guidance. 64 | * Supports sampling with 8 (recommended), or less diffusion steps. 65 | * Also released a LoRA version of the distilled model, [ltxv-13b-0.9.7-distilled-lora128](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled-lora128.safetensors) 66 | * Requires only 1GB of VRAM 67 | * Can be used with the full 13B model for fast inference 68 | - Release a new quantized distilled model [ltxv-13b-0.9.7-distilled-fp8](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled-fp8.safetensors) for *real-time* generation (on H100) with even less VRAM (Supported in the [official ComfyUI workflow](https://github.com/Lightricks/ComfyUI-LTXVideo/)) 69 | 70 | ## May, 5th, 2025: New model 13B v0.9.7: 71 | - Release a new 13B model [ltxv-13b-0.9.7-dev](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev.safetensors) 72 | - Release a new quantized model [ltxv-13b-0.9.7-dev-fp8](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev-fp8.safetensors) for faster inference with less VRAM (Supported in the [official ComfyUI workflow](https://github.com/Lightricks/ComfyUI-LTXVideo/)) 73 | - Release a new upscalers 74 | * [ltxv-temporal-upscaler-0.9.7](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-temporal-upscaler-0.9.7.safetensors) 75 | * [ltxv-spatial-upscaler-0.9.7](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-spatial-upscaler-0.9.7.safetensors) 76 | - Breakthrough prompt adherence and physical understanding. 77 | - New Pipeline for multi-scale video rendering for fast and high quality results 78 | 79 | 80 | ## April, 15th, 2025: New checkpoints v0.9.6: 81 | - Release a new checkpoint [ltxv-2b-0.9.6-dev-04-25](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-2b-0.9.6-dev-04-25.safetensors) with improved quality 82 | - Release a new distilled model [ltxv-2b-0.9.6-distilled-04-25](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-2b-0.9.6-distilled-04-25.safetensors) 83 | * 15x faster inference than non-distilled model. 84 | * Does not require classifier-free guidance and spatio-temporal guidance. 85 | * Supports sampling with 8 (recommended), or less diffusion steps. 86 | - Improved prompt adherence, motion quality and fine details. 87 | - New default resolution and FPS: 1216 × 704 pixels at 30 FPS 88 | * Still real time on H100 with the distilled model. 89 | * Other resolutions and FPS are still supported. 
90 | - Support stochastic inference (can improve visual quality when using the distilled model) 91 | 92 | ## March, 5th, 2025: New checkpoint v0.9.5 93 | - New license for commercial use ([OpenRail-M](https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.5.license.txt)) 94 | - Release a new checkpoint v0.9.5 with improved quality 95 | - Support keyframes and video extension 96 | - Support higher resolutions 97 | - Improved prompt understanding 98 | - Improved VAE 99 | - New online web app in [LTX-Studio](https://app.ltx.studio/ltx-video) 100 | - Automatic prompt enhancement 101 | 102 | ## February, 20th, 2025: More inference options 103 | - Improve STG (Spatiotemporal Guidance) for LTX-Video 104 | - Support MPS on macOS with PyTorch 2.3.0 105 | - Add support for 8-bit model, LTX-VideoQ8 106 | - Add TeaCache for LTX-Video 107 | - Add [ComfyUI-LTXTricks](#comfyui-integration) 108 | - Add Diffusion-Pipe 109 | 110 | ## December 31st, 2024: Research paper 111 | - Release the [research paper](https://arxiv.org/abs/2501.00103) 112 | 113 | ## December 20th, 2024: New checkpoint v0.9.1 114 | - Release a new checkpoint v0.9.1 with improved quality 115 | - Support for STG / PAG 116 | - Support loading checkpoints of LTX-Video in Diffusers format (conversion is done on-the-fly) 117 | - Support offloading unused parts to CPU 118 | - Support the new timestep-conditioned VAE decoder 119 | - Reference contributions from the community in the readme file 120 | - Relax transformers dependency 121 | 122 | ## November 21th, 2024: Initial release v0.9.0 123 | - Initial release of LTX-Video 124 | - Support text-to-video and image-to-video generation 125 | 126 | 127 | # Models & Workflows 128 | 129 | | Name | Notes | inference.py config | ComfyUI workflow (Recommended) | 130 | |-------------------------|--------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------| 131 | | ltxv-13b-0.9.7-dev | Highest quality, requires more VRAM | [ltxv-13b-0.9.7-dev.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.7-dev.yaml) | [ltxv-13b-i2v-base.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ltxv-13b-i2v-base.json) | 132 | | [ltxv-13b-0.9.7-mix](https://app.ltx.studio/motion-workspace?videoModel=ltxv-13b) | Mix ltxv-13b-dev and ltxv-13b-distilled in the same multi-scale rendering workflow for balanced speed-quality | N/A | [ltxv-13b-i2v-mixed-multiscale.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ltxv-13b-i2v-mixed-multiscale.json) | 133 | [ltxv-13b-0.9.7-distilled](https://app.ltx.studio/motion-workspace?videoModel=ltxv) | Faster, less VRAM usage, slight quality reduction compared to 13b. 
Ideal for rapid iterations | [ltxv-13b-0.9.7-distilled.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-13b-0.9.7-dev.yaml) | [ltxv-13b-dist-i2v-base.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/13b-distilled/ltxv-13b-dist-i2v-base.json) | 134 | | [ltxv-13b-0.9.7-distilled-lora128](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled-lora128.safetensors) | LoRA to make ltxv-13b-dev behave like the distilled model | N/A | N/A | 135 | | ltxv-13b-0.9.7-fp8 | Quantized version of ltxv-13b | Coming soon | [ltxv-13b-i2v-base-fp8.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/ltxv-13b-i2v-base-fp8.json) | 136 | | ltxv-13b-0.9.7-distilled-fp8 | Quantized version of ltxv-13b-distilled | Coming soon | [ltxv-13b-dist-i2v-base-fp8.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/13b-distilled/ltxv-13b-dist-i2v-base-fp8.json) | 137 | | ltxv-2b-0.9.6 | Good quality, lower VRAM requirement than ltxv-13b | [ltxv-2b-0.9.6-dev.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.6-dev.yaml) | [ltxvideo-i2v.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/low_level/ltxvideo-i2v.json) | 138 | | ltxv-2b-0.9.6-distilled | 15× faster, real-time capable, fewer steps needed, no STG/CFG required | [ltxv-2b-0.9.6-distilled.yaml](https://github.com/Lightricks/LTX-Video/blob/main/configs/ltxv-2b-0.9.6-distilled.yaml) | [ltxvideo-i2v-distilled.json](https://github.com/Lightricks/ComfyUI-LTXVideo/blob/master/example_workflows/low_level/ltxvideo-i2v-distilled.json) | 139 | 140 | 141 | # Quick Start Guide 142 | 143 | ## Online inference 144 | The model is accessible right away via the following links: 145 | - [LTX-Studio image-to-video (13B-mix)](https://app.ltx.studio/motion-workspace?videoModel=ltxv-13b) 146 | - [LTX-Studio image-to-video (13B distilled)](https://app.ltx.studio/motion-workspace?videoModel=ltxv) 147 | - [Fal.ai text-to-video](https://fal.ai/models/fal-ai/ltx-video) 148 | - [Fal.ai image-to-video](https://fal.ai/models/fal-ai/ltx-video/image-to-video) 149 | - [Replicate text-to-video and image-to-video](https://replicate.com/lightricks/ltx-video) 150 | 151 | ## Run locally 152 | 153 | ### Installation 154 | The codebase was tested with Python 3.10.5, CUDA version 12.2, and supports PyTorch >= 2.1.2. 155 | On macos, MPS was tested with PyTorch 2.3.0, and should support PyTorch == 2.3 or >= 2.6. 156 | 157 | ```bash 158 | git clone https://github.com/Lightricks/LTX-Video.git 159 | cd LTX-Video 160 | 161 | # create env 162 | python -m venv env 163 | source env/bin/activate 164 | python -m pip install -e .\[inference-script\] 165 | ``` 166 | 167 | ### Inference 168 | 169 | 📝 **Note:** For best results, we recommend using our [ComfyUI](#comfyui-integration) workflow. We’re working on updating the inference.py script to match the high quality and output fidelity of ComfyUI. 
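Before running the commands below, it can help to confirm that your environment matches the versions noted in the Installation section (PyTorch >= 2.1.2 with CUDA, or PyTorch 2.3.0 with MPS on macOS). The following is a minimal, optional sanity-check sketch and is not part of the repository:

```python
# Optional environment check before running inference.
import torch

print(f"PyTorch version: {torch.__version__}")  # tested with >= 2.1.2 (2.3.0 for MPS)
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
elif torch.backends.mps.is_available():
    print("MPS backend available (macOS)")
else:
    print("No GPU backend detected; generation on CPU will be very slow.")
```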
170 | 171 | To use our model, please follow the inference code in [inference.py](./inference.py): 172 | 173 | #### For text-to-video generation: 174 | 175 | ```bash 176 | python inference.py --prompt "PROMPT" --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.7-distilled.yaml 177 | ``` 178 | 179 | #### For image-to-video generation: 180 | 181 | ```bash 182 | python inference.py --prompt "PROMPT" --conditioning_media_paths IMAGE_PATH --conditioning_start_frames 0 --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.7-distilled.yaml 183 | ``` 184 | 185 | #### Extending a video: 186 | 187 | 📝 **Note:** Input video segments must contain a multiple of 8 frames plus 1 (e.g., 9, 17, 25, etc.), and the target frame number should be a multiple of 8. 188 | 189 | 190 | ```bash 191 | python inference.py --prompt "PROMPT" --conditioning_media_paths VIDEO_PATH --conditioning_start_frames START_FRAME --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.7-distilled.yaml 192 | ``` 193 | 194 | #### For video generation with multiple conditions: 195 | 196 | You can now generate a video conditioned on a set of images and/or short video segments. 197 | Simply provide a list of paths to the images or video segments you want to condition on, along with their target frame numbers in the generated video. You can also specify the conditioning strength for each item (default: 1.0). 198 | 199 | ```bash 200 | python inference.py --prompt "PROMPT" --conditioning_media_paths IMAGE_OR_VIDEO_PATH_1 IMAGE_OR_VIDEO_PATH_2 --conditioning_start_frames TARGET_FRAME_1 TARGET_FRAME_2 --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED --pipeline_config configs/ltxv-13b-0.9.7-distilled.yaml 201 | ``` 202 | 203 | ## ComfyUI Integration 204 | To use our model with ComfyUI, please follow the instructions at [https://github.com/Lightricks/ComfyUI-LTXVideo/](https://github.com/Lightricks/ComfyUI-LTXVideo/). 205 | 206 | ## Diffusers Integration 207 | To use our model with the Diffusers Python library, check out the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video). 208 | 209 | Diffusers also support an 8-bit version of LTX-Video, [see details below](#ltx-videoq8) 210 | 211 | # Model User Guide 212 | 213 | ## 📝 Prompt Engineering 214 | 215 | When writing prompts, focus on detailed, chronological descriptions of actions and scenes. Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. Start directly with the action, and keep descriptions literal and precise. Think like a cinematographer describing a shot list. Keep within 200 words. For best results, build your prompts using this structure: 216 | 217 | * Start with main action in a single sentence 218 | * Add specific details about movements and gestures 219 | * Describe character/object appearances precisely 220 | * Include background and environment details 221 | * Specify camera angles and movements 222 | * Describe lighting and colors 223 | * Note any changes or sudden events 224 | * See [examples](#introduction) for more inspiration. 225 | 226 | ### Automatic Prompt Enhancement 227 | When using `inference.py`, shorts prompts (below `prompt_enhancement_words_threshold` words) are automatically enhanced by a language model. This is supported with text-to-video and image-to-video (first-frame conditioning). 
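As an illustration of the rule above, the sketch below shows the word-count gate; the helper name is hypothetical, and the default threshold mirrors `prompt_enhancement_words_threshold: 120` from the bundled configs:

```python
def should_enhance_prompt(prompt: str, threshold: int = 120) -> bool:
    # Hypothetical helper: prompts shorter than the configured word threshold
    # are rewritten by the prompt-enhancer language model before generation.
    return len(prompt.split()) < threshold


print(should_enhance_prompt("A red fox runs through fresh snow at dawn."))  # True -> would be enhanced
```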
228 | 229 | When using `LTXVideoPipeline` directly, you can enable prompt enhancement by setting `enhance_prompt=True`. 230 | 231 | ## 🎮 Parameter Guide 232 | 233 | * Resolution Preset: Higher resolutions for detailed scenes, lower for faster generation and simpler scenes. The model works on resolutions that are divisible by 32 and number of frames that are divisible by 8 + 1 (e.g. 257). In case the resolution or number of frames are not divisible by 32 or 8 + 1, the input will be padded with -1 and then cropped to the desired resolution and number of frames. The model works best on resolutions under 720 x 1280 and number of frames below 257 234 | * Seed: Save seed values to recreate specific styles or compositions you like 235 | * Guidance Scale: 3-3.5 are the recommended values 236 | * Inference Steps: More steps (40+) for quality, fewer steps (20-30) for speed 237 | 238 | 📝 For advanced parameters usage, please see `python inference.py --help` 239 | 240 | ## Community Contribution 241 | 242 | ### ComfyUI-LTXTricks 🛠️ 243 | 244 | A community project providing additional nodes for enhanced control over the LTX Video model. It includes implementations of advanced techniques like RF-Inversion, RF-Edit, FlowEdit, and more. These nodes enable workflows such as Image and Video to Video (I+V2V), enhanced sampling via Spatiotemporal Skip Guidance (STG), and interpolation with precise frame settings. 245 | 246 | - **Repository:** [ComfyUI-LTXTricks](https://github.com/logtd/ComfyUI-LTXTricks) 247 | - **Features:** 248 | - 🔄 **RF-Inversion:** Implements [RF-Inversion](https://rf-inversion.github.io/) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_inversion.json). 249 | - ✂️ **RF-Edit:** Implements [RF-Solver-Edit](https://github.com/wangjiangshan0725/RF-Solver-Edit) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_rf_edit.json). 250 | - 🌊 **FlowEdit:** Implements [FlowEdit](https://github.com/fallenshock/FlowEdit) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_flow_edit.json). 251 | - 🎥 **I+V2V:** Enables Video to Video with a reference image. [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_iv2v.json). 252 | - ✨ **Enhance:** Partial implementation of [STGuidance](https://junhahyung.github.io/STGuidance/). [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltxv_stg.json). 253 | - 🖼️ **Interpolation and Frame Setting:** Nodes for precise control of latents per frame. [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_interpolation.json). 254 | 255 | 256 | ### LTX-VideoQ8 🎱 257 | 258 | **LTX-VideoQ8** is an 8-bit optimized version of [LTX-Video](https://github.com/Lightricks/LTX-Video), designed for faster performance on NVIDIA ADA GPUs. 
259 | 260 | - **Repository:** [LTX-VideoQ8](https://github.com/KONAKONA666/LTX-Video) 261 | - **Features:** 262 | - 🚀 Up to 3X speed-up with no accuracy loss 263 | - 🎥 Generate 720x480x121 videos in under a minute on RTX 4060 (8GB VRAM) 264 | - 🛠️ Fine-tune 2B transformer models with precalculated latents 265 | - **Community Discussion:** [Reddit Thread](https://www.reddit.com/r/StableDiffusion/comments/1h79ks2/fast_ltx_video_on_rtx_4060_and_other_ada_gpus/) 266 | - **Diffusers integration:** A diffusers integration for the 8-bit model is already out! [Details here](https://github.com/sayakpaul/q8-ltx-video) 267 | 268 | 269 | ### TeaCache for LTX-Video 🍵 270 | 271 | **TeaCache** is a training-free caching approach that leverages timestep differences across model outputs to accelerate LTX-Video inference by up to 2x without significant visual quality degradation. 272 | 273 | - **Repository:** [TeaCache4LTX-Video](https://github.com/ali-vilab/TeaCache/tree/main/TeaCache4LTX-Video) 274 | - **Features:** 275 | - 🚀 Speeds up LTX-Video inference. 276 | - 📊 Adjustable trade-offs between speed (up to 2x) and visual quality using configurable parameters. 277 | - 🛠️ No retraining required: Works directly with existing models. 278 | 279 | ### Your Contribution 280 | 281 | ...is welcome! If you have a project or tool that integrates with LTX-Video, 282 | please let us know by opening an issue or pull request. 283 | 284 | # ⚡️ Training 285 | 286 | We provide an open-source repository for fine-tuning the LTX-Video model: [LTX-Video-Trainer](https://github.com/Lightricks/LTX-Video-Trainer). 287 | This repository supports both the 2B and 13B model variants, enabling full fine-tuning as well as LoRA (Low-Rank Adaptation) fine-tuning for more efficient training. 288 | 289 | Explore the repository to customize the model for your specific use cases! 290 | More information and training instructions can be found in the [README](https://github.com/Lightricks/LTX-Video-Trainer/blob/main/README.md). 291 | 292 | 293 | # 🚀 Join Us 294 | 295 | Want to work on cutting-edge AI research and make a real impact on millions of users worldwide? 296 | 297 | At **Lightricks**, an AI-first company, we're revolutionizing how visual content is created. 298 | 299 | If you are passionate about AI, computer vision, and video generation, we would love to hear from you! 300 | 301 | Please visit our [careers page](https://careers.lightricks.com/careers?query=&office=all&department=R%26D) for more information. 302 | 303 | # Acknowledgement 304 | 305 | We are grateful for the following awesome projects when implementing LTX-Video: 306 | * [DiT](https://github.com/facebookresearch/DiT) and [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha): vision transformers for image generation. 307 | 308 | 309 | ## Citation 310 | 311 | 📄 Our tech report is out! If you find our work helpful, please ⭐️ star the repository and cite our paper. 
312 | 313 | ``` 314 | @article{HaCohen2024LTXVideo, 315 | title={LTX-Video: Realtime Video Latent Diffusion}, 316 | author={HaCohen, Yoav and Chiprut, Nisan and Brazowski, Benny and Shalem, Daniel and Moshe, Dudu and Richardson, Eitan and Levin, Eran and Shiran, Guy and Zabari, Nir and Gordon, Ori and Panet, Poriya and Weissbuch, Sapir and Kulikov, Victor and Bitterman, Yaki and Melumian, Zeev and Bibi, Ofir}, 317 | journal={arXiv preprint arXiv:2501.00103}, 318 | year={2024} 319 | } 320 | ``` 321 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightricks/LTX-Video/34625c3a1fb4e9a7d2f091938f6989816343e270/__init__.py -------------------------------------------------------------------------------- /configs/ltxv-13b-0.9.7-dev.yaml: -------------------------------------------------------------------------------- 1 | pipeline_type: multi-scale 2 | checkpoint_path: "ltxv-13b-0.9.7-dev.safetensors" 3 | downscale_factor: 0.6666666 4 | spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors" 5 | stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" 6 | decode_timestep: 0.05 7 | decode_noise_scale: 0.025 8 | text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" 9 | precision: "bfloat16" 10 | sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" 11 | prompt_enhancement_words_threshold: 120 12 | prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" 13 | prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" 14 | stochastic_sampling: false 15 | 16 | first_pass: 17 | guidance_scale: [1, 1, 6, 8, 6, 1, 1] 18 | stg_scale: [0, 0, 4, 4, 4, 2, 1] 19 | rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1] 20 | guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180] 21 | skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]] 22 | num_inference_steps: 30 23 | skip_final_inference_steps: 3 24 | cfg_star_rescale: true 25 | 26 | second_pass: 27 | guidance_scale: [1] 28 | stg_scale: [1] 29 | rescaling_scale: [1] 30 | guidance_timesteps: [1.0] 31 | skip_block_list: [27] 32 | num_inference_steps: 30 33 | skip_initial_inference_steps: 17 34 | cfg_star_rescale: true -------------------------------------------------------------------------------- /configs/ltxv-13b-0.9.7-distilled.yaml: -------------------------------------------------------------------------------- 1 | pipeline_type: multi-scale 2 | checkpoint_path: "ltxv-13b-0.9.7-distilled.safetensors" 3 | downscale_factor: 0.6666666 4 | spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors" 5 | stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" 6 | decode_timestep: 0.05 7 | decode_noise_scale: 0.025 8 | text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" 9 | precision: "bfloat16" 10 | sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" 11 | prompt_enhancement_words_threshold: 120 12 | prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" 13 | prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" 14 | stochastic_sampling: false 15 | 16 | first_pass: 17 | timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250] 18 
| guidance_scale: 1 19 | stg_scale: 0 20 | rescaling_scale: 1 21 | skip_block_list: [42] 22 | 23 | second_pass: 24 | timesteps: [0.9094, 0.7250, 0.4219] 25 | guidance_scale: 1 26 | stg_scale: 0 27 | rescaling_scale: 1 28 | skip_block_list: [42] 29 | -------------------------------------------------------------------------------- /configs/ltxv-2b-0.9.1.yaml: -------------------------------------------------------------------------------- 1 | pipeline_type: base 2 | checkpoint_path: "ltx-video-2b-v0.9.1.safetensors" 3 | guidance_scale: 3 4 | stg_scale: 1 5 | rescaling_scale: 0.7 6 | skip_block_list: [19] 7 | num_inference_steps: 40 8 | stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" 9 | decode_timestep: 0.05 10 | decode_noise_scale: 0.025 11 | text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" 12 | precision: "bfloat16" 13 | sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" 14 | prompt_enhancement_words_threshold: 120 15 | prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" 16 | prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" 17 | stochastic_sampling: false -------------------------------------------------------------------------------- /configs/ltxv-2b-0.9.5.yaml: -------------------------------------------------------------------------------- 1 | pipeline_type: base 2 | checkpoint_path: "ltx-video-2b-v0.9.5.safetensors" 3 | guidance_scale: 3 4 | stg_scale: 1 5 | rescaling_scale: 0.7 6 | skip_block_list: [19] 7 | num_inference_steps: 40 8 | stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" 9 | decode_timestep: 0.05 10 | decode_noise_scale: 0.025 11 | text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" 12 | precision: "bfloat16" 13 | sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" 14 | prompt_enhancement_words_threshold: 120 15 | prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" 16 | prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" 17 | stochastic_sampling: false -------------------------------------------------------------------------------- /configs/ltxv-2b-0.9.6-dev.yaml: -------------------------------------------------------------------------------- 1 | pipeline_type: base 2 | checkpoint_path: "ltxv-2b-0.9.6-dev-04-25.safetensors" 3 | guidance_scale: 3 4 | stg_scale: 1 5 | rescaling_scale: 0.7 6 | skip_block_list: [19] 7 | num_inference_steps: 40 8 | stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" 9 | decode_timestep: 0.05 10 | decode_noise_scale: 0.025 11 | text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" 12 | precision: "bfloat16" 13 | sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" 14 | prompt_enhancement_words_threshold: 120 15 | prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" 16 | prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" 17 | stochastic_sampling: false -------------------------------------------------------------------------------- /configs/ltxv-2b-0.9.6-distilled.yaml: -------------------------------------------------------------------------------- 1 | pipeline_type: base 2 | checkpoint_path: 
"ltxv-2b-0.9.6-distilled-04-25.safetensors" 3 | guidance_scale: 1 4 | stg_scale: 0 5 | rescaling_scale: 1 6 | num_inference_steps: 8 7 | stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" 8 | decode_timestep: 0.05 9 | decode_noise_scale: 0.025 10 | text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" 11 | precision: "bfloat16" 12 | sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" 13 | prompt_enhancement_words_threshold: 120 14 | prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" 15 | prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" 16 | stochastic_sampling: true -------------------------------------------------------------------------------- /configs/ltxv-2b-0.9.yaml: -------------------------------------------------------------------------------- 1 | pipeline_type: base 2 | checkpoint_path: "ltx-video-2b-v0.9.safetensors" 3 | guidance_scale: 3 4 | stg_scale: 1 5 | rescaling_scale: 0.7 6 | skip_block_list: [19] 7 | num_inference_steps: 40 8 | stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block" 9 | decode_timestep: 0.05 10 | decode_noise_scale: 0.025 11 | text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS" 12 | precision: "bfloat16" 13 | sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint" 14 | prompt_enhancement_words_threshold: 120 15 | prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" 16 | prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" 17 | stochastic_sampling: false -------------------------------------------------------------------------------- /docs/_static/ltx-video_example_00001.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b679f14a09d2321b7e34b3ecd23bc01c2cfa75c8d4214a1e59af09826003e2ec 3 | size 7963919 4 | -------------------------------------------------------------------------------- /docs/_static/ltx-video_example_00005.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:87fdb9556c1218db4b929994e9b807d1d63f4676defef5b418a4edb1ddaa8422 3 | size 5732587 4 | -------------------------------------------------------------------------------- /docs/_static/ltx-video_example_00006.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:f56f3dcc84a871ab4ef1510120f7a4586c7044c5609a897d8177ae8d52eb3eae 3 | size 4239543 4 | -------------------------------------------------------------------------------- /docs/_static/ltx-video_example_00007.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a08a06681334856db516e969a9ae4290acfd7550f7b970331e87d0223e282bcc 3 | size 7829259 4 | -------------------------------------------------------------------------------- /docs/_static/ltx-video_example_00010.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:bcf1e084e936a75eaae73a29f60935c469b1fc34eb3f5ad89483e88b3a2eaffe 3 | size 6193172 4 | 
-------------------------------------------------------------------------------- /docs/_static/ltx-video_example_00011.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3e3d04f5763ecb416b3b80c3488e48c49991d80661c94e8f08dddd7b890b1b75 3 | size 5345673 4 | -------------------------------------------------------------------------------- /docs/_static/ltx-video_example_00013.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:aa7eb790b43f8a55c01d1fbed4c7a7f657fb2ca78a9685833cf9cb558d2002c1 3 | size 9024843 4 | -------------------------------------------------------------------------------- /docs/_static/ltx-video_example_00014.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4f7afc4b498a927dcc4e1492548db5c32fa76d117e0410d11e1e0b1929153e54 3 | size 7434241 4 | -------------------------------------------------------------------------------- /docs/_static/ltx-video_example_00015.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d897c9656e0cba89512ab9d2cbe2d2c0f2ddf907dcab5f7eadab4b96b1cb1efe 3 | size 6556457 4 | -------------------------------------------------------------------------------- /docs/_static/ltx-video_i2v_example_00001.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:51e7f2ef92a9a2296e7ce9b40a118033b26372838befdd3d2281819c438ee928 3 | size 9285683 4 | -------------------------------------------------------------------------------- /docs/_static/ltx-video_i2v_example_00002.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:0bd2c1314587efc0b8b7656d6634fbcc3a2045801441bf139bab7726275fa353 3 | size 20393804 4 | -------------------------------------------------------------------------------- /docs/_static/ltx-video_i2v_example_00003.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:bb62a79340882b490d11e05de0427e4efad3dca19a55e003ef92889613b67825 3 | size 9825156 4 | -------------------------------------------------------------------------------- /docs/_static/ltx-video_i2v_example_00004.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:151cb985a1622512b656288b1a1cba7906a34678a2fd9ae6e25611e330a0f9bb 3 | size 15691608 4 | -------------------------------------------------------------------------------- /docs/_static/ltx-video_i2v_example_00005.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6280efa40b66a50a8a32fddc0e81a55081e8c8efef56a54229ea4e7f2ae4309d 3 | size 10329925 4 | -------------------------------------------------------------------------------- /docs/_static/ltx-video_i2v_example_00006.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:da23e79524a6427959b20809e392555ab4336881ade20001e5ae276d662ed291 3 | size 11936674 
4 | -------------------------------------------------------------------------------- /docs/_static/ltx-video_i2v_example_00007.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:9bcfe336e823925ace248f3ad517df791986c00a1a3c375bac2cb433154ca133 3 | size 11755718 4 | -------------------------------------------------------------------------------- /docs/_static/ltx-video_i2v_example_00008.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:706368111eda5331dca089e9231bb272a2e61a6b231bd13876ac442cb9d78019 3 | size 14716658 4 | -------------------------------------------------------------------------------- /docs/_static/ltx-video_i2v_example_00009.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a9c13259e01485c16148dc6681922cc3a748c2f7b52ff7e208c3ffc5ab71397d 3 | size 16848341 4 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | from datetime import datetime 5 | from pathlib import Path 6 | from diffusers.utils import logging 7 | from typing import Optional, List, Union 8 | import yaml 9 | 10 | import imageio 11 | import json 12 | import numpy as np 13 | import torch 14 | import cv2 15 | from safetensors import safe_open 16 | from PIL import Image 17 | from transformers import ( 18 | T5EncoderModel, 19 | T5Tokenizer, 20 | AutoModelForCausalLM, 21 | AutoProcessor, 22 | AutoTokenizer, 23 | ) 24 | from huggingface_hub import hf_hub_download 25 | 26 | from ltx_video.models.autoencoders.causal_video_autoencoder import ( 27 | CausalVideoAutoencoder, 28 | ) 29 | from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier 30 | from ltx_video.models.transformers.transformer3d import Transformer3DModel 31 | from ltx_video.pipelines.pipeline_ltx_video import ( 32 | ConditioningItem, 33 | LTXVideoPipeline, 34 | LTXMultiScalePipeline, 35 | ) 36 | from ltx_video.schedulers.rf import RectifiedFlowScheduler 37 | from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy 38 | from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler 39 | import ltx_video.pipelines.crf_compressor as crf_compressor 40 | 41 | MAX_HEIGHT = 720 42 | MAX_WIDTH = 1280 43 | MAX_NUM_FRAMES = 257 44 | 45 | logger = logging.get_logger("LTX-Video") 46 | 47 | 48 | def get_total_gpu_memory(): 49 | if torch.cuda.is_available(): 50 | total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) 51 | return total_memory 52 | return 0 53 | 54 | 55 | def get_device(): 56 | if torch.cuda.is_available(): 57 | return "cuda" 58 | elif torch.backends.mps.is_available(): 59 | return "mps" 60 | return "cpu" 61 | 62 | 63 | def load_image_to_tensor_with_resize_and_crop( 64 | image_input: Union[str, Image.Image], 65 | target_height: int = 512, 66 | target_width: int = 768, 67 | just_crop: bool = False, 68 | ) -> torch.Tensor: 69 | """Load and process an image into a tensor. 
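    The image is center-cropped to the target aspect ratio, resized unless `just_crop`
    is set, lightly blurred, CRF-compressed, and normalized to the [-1, 1] range.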
70 | 71 | Args: 72 | image_input: Either a file path (str) or a PIL Image object 73 | target_height: Desired height of output tensor 74 | target_width: Desired width of output tensor 75 | just_crop: If True, only crop the image to the target size without resizing 76 | """ 77 | if isinstance(image_input, str): 78 | image = Image.open(image_input).convert("RGB") 79 | elif isinstance(image_input, Image.Image): 80 | image = image_input 81 | else: 82 | raise ValueError("image_input must be either a file path or a PIL Image object") 83 | 84 | input_width, input_height = image.size 85 | aspect_ratio_target = target_width / target_height 86 | aspect_ratio_frame = input_width / input_height 87 | if aspect_ratio_frame > aspect_ratio_target: 88 | new_width = int(input_height * aspect_ratio_target) 89 | new_height = input_height 90 | x_start = (input_width - new_width) // 2 91 | y_start = 0 92 | else: 93 | new_width = input_width 94 | new_height = int(input_width / aspect_ratio_target) 95 | x_start = 0 96 | y_start = (input_height - new_height) // 2 97 | 98 | image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height)) 99 | if not just_crop: 100 | image = image.resize((target_width, target_height)) 101 | 102 | image = np.array(image) 103 | image = cv2.GaussianBlur(image, (3, 3), 0) 104 | frame_tensor = torch.from_numpy(image).float() 105 | frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0 106 | frame_tensor = frame_tensor.permute(2, 0, 1) 107 | frame_tensor = (frame_tensor / 127.5) - 1.0 108 | # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width) 109 | return frame_tensor.unsqueeze(0).unsqueeze(2) 110 | 111 | 112 | def calculate_padding( 113 | source_height: int, source_width: int, target_height: int, target_width: int 114 | ) -> tuple[int, int, int, int]: 115 | 116 | # Calculate total padding needed 117 | pad_height = target_height - source_height 118 | pad_width = target_width - source_width 119 | 120 | # Calculate padding for each side 121 | pad_top = pad_height // 2 122 | pad_bottom = pad_height - pad_top # Handles odd padding 123 | pad_left = pad_width // 2 124 | pad_right = pad_width - pad_left # Handles odd padding 125 | 126 | # Return padded tensor 127 | # Padding format is (left, right, top, bottom) 128 | padding = (pad_left, pad_right, pad_top, pad_bottom) 129 | return padding 130 | 131 | 132 | def convert_prompt_to_filename(text: str, max_len: int = 20) -> str: 133 | # Remove non-letters and convert to lowercase 134 | clean_text = "".join( 135 | char.lower() for char in text if char.isalpha() or char.isspace() 136 | ) 137 | 138 | # Split into words 139 | words = clean_text.split() 140 | 141 | # Build result string keeping track of length 142 | result = [] 143 | current_length = 0 144 | 145 | for word in words: 146 | # Add word length plus 1 for underscore (except for first word) 147 | new_length = current_length + len(word) 148 | 149 | if new_length <= max_len: 150 | result.append(word) 151 | current_length += len(word) 152 | else: 153 | break 154 | 155 | return "-".join(result) 156 | 157 | 158 | # Generate output video name 159 | def get_unique_filename( 160 | base: str, 161 | ext: str, 162 | prompt: str, 163 | seed: int, 164 | resolution: tuple[int, int, int], 165 | dir: Path, 166 | endswith=None, 167 | index_range=1000, 168 | ) -> Path: 169 | base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}" 170 | for i in range(index_range): 171 | filename = dir / 
f"{base_filename}_{i}{endswith if endswith else ''}{ext}" 172 | if not os.path.exists(filename): 173 | return filename 174 | raise FileExistsError( 175 | f"Could not find a unique filename after {index_range} attempts." 176 | ) 177 | 178 | 179 | def seed_everething(seed: int): 180 | random.seed(seed) 181 | np.random.seed(seed) 182 | torch.manual_seed(seed) 183 | if torch.cuda.is_available(): 184 | torch.cuda.manual_seed(seed) 185 | if torch.backends.mps.is_available(): 186 | torch.mps.manual_seed(seed) 187 | 188 | 189 | def main(): 190 | parser = argparse.ArgumentParser( 191 | description="Load models from separate directories and run the pipeline." 192 | ) 193 | 194 | # Directories 195 | parser.add_argument( 196 | "--output_path", 197 | type=str, 198 | default=None, 199 | help="Path to the folder to save output video, if None will save in outputs/ directory.", 200 | ) 201 | parser.add_argument("--seed", type=int, default="171198") 202 | 203 | # Pipeline parameters 204 | parser.add_argument( 205 | "--num_images_per_prompt", 206 | type=int, 207 | default=1, 208 | help="Number of images per prompt", 209 | ) 210 | parser.add_argument( 211 | "--image_cond_noise_scale", 212 | type=float, 213 | default=0.15, 214 | help="Amount of noise to add to the conditioned image", 215 | ) 216 | parser.add_argument( 217 | "--height", 218 | type=int, 219 | default=704, 220 | help="Height of the output video frames. Optional if an input image provided.", 221 | ) 222 | parser.add_argument( 223 | "--width", 224 | type=int, 225 | default=1216, 226 | help="Width of the output video frames. If None will infer from input image.", 227 | ) 228 | parser.add_argument( 229 | "--num_frames", 230 | type=int, 231 | default=121, 232 | help="Number of frames to generate in the output video", 233 | ) 234 | parser.add_argument( 235 | "--frame_rate", type=int, default=30, help="Frame rate for the output video" 236 | ) 237 | parser.add_argument( 238 | "--device", 239 | default=None, 240 | help="Device to run inference on. If not specified, will automatically detect and use CUDA or MPS if available, else CPU.", 241 | ) 242 | parser.add_argument( 243 | "--pipeline_config", 244 | type=str, 245 | default="configs/ltxv-13b-0.9.7-dev.yaml", 246 | help="The path to the config file for the pipeline, which contains the parameters for the pipeline", 247 | ) 248 | 249 | # Prompts 250 | parser.add_argument( 251 | "--prompt", 252 | type=str, 253 | help="Text prompt to guide generation", 254 | ) 255 | parser.add_argument( 256 | "--negative_prompt", 257 | type=str, 258 | default="worst quality, inconsistent motion, blurry, jittery, distorted", 259 | help="Negative prompt for undesired features", 260 | ) 261 | 262 | parser.add_argument( 263 | "--offload_to_cpu", 264 | action="store_true", 265 | help="Offloading unnecessary computations to CPU.", 266 | ) 267 | 268 | # video-to-video arguments: 269 | parser.add_argument( 270 | "--input_media_path", 271 | type=str, 272 | default=None, 273 | help="Path to the input video (or imaage) to be modified using the video-to-video pipeline", 274 | ) 275 | 276 | # Conditioning arguments 277 | parser.add_argument( 278 | "--conditioning_media_paths", 279 | type=str, 280 | nargs="*", 281 | help="List of paths to conditioning media (images or videos). Each path will be used as a conditioning item.", 282 | ) 283 | parser.add_argument( 284 | "--conditioning_strengths", 285 | type=float, 286 | nargs="*", 287 | help="List of conditioning strengths (between 0 and 1) for each conditioning item. 
Must match the number of conditioning items.", 288 | ) 289 | parser.add_argument( 290 | "--conditioning_start_frames", 291 | type=int, 292 | nargs="*", 293 | help="List of frame indices where each conditioning item should be applied. Must match the number of conditioning items.", 294 | ) 295 | 296 | args = parser.parse_args() 297 | logger.warning(f"Running generation with arguments: {args}") 298 | infer(**vars(args)) 299 | 300 | 301 | def create_ltx_video_pipeline( 302 | ckpt_path: str, 303 | precision: str, 304 | text_encoder_model_name_or_path: str, 305 | sampler: Optional[str] = None, 306 | device: Optional[str] = None, 307 | enhance_prompt: bool = False, 308 | prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None, 309 | prompt_enhancer_llm_model_name_or_path: Optional[str] = None, 310 | ) -> LTXVideoPipeline: 311 | ckpt_path = Path(ckpt_path) 312 | assert os.path.exists( 313 | ckpt_path 314 | ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist" 315 | 316 | with safe_open(ckpt_path, framework="pt") as f: 317 | metadata = f.metadata() 318 | config_str = metadata.get("config") 319 | configs = json.loads(config_str) 320 | allowed_inference_steps = configs.get("allowed_inference_steps", None) 321 | 322 | vae = CausalVideoAutoencoder.from_pretrained(ckpt_path) 323 | transformer = Transformer3DModel.from_pretrained(ckpt_path) 324 | 325 | # Use constructor if sampler is specified, otherwise use from_pretrained 326 | if sampler == "from_checkpoint" or not sampler: 327 | scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path) 328 | else: 329 | scheduler = RectifiedFlowScheduler( 330 | sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic") 331 | ) 332 | 333 | text_encoder = T5EncoderModel.from_pretrained( 334 | text_encoder_model_name_or_path, subfolder="text_encoder" 335 | ) 336 | patchifier = SymmetricPatchifier(patch_size=1) 337 | tokenizer = T5Tokenizer.from_pretrained( 338 | text_encoder_model_name_or_path, subfolder="tokenizer" 339 | ) 340 | 341 | transformer = transformer.to(device) 342 | vae = vae.to(device) 343 | text_encoder = text_encoder.to(device) 344 | 345 | if enhance_prompt: 346 | prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( 347 | prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True 348 | ) 349 | prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( 350 | prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True 351 | ) 352 | prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained( 353 | prompt_enhancer_llm_model_name_or_path, 354 | torch_dtype="bfloat16", 355 | ) 356 | prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained( 357 | prompt_enhancer_llm_model_name_or_path, 358 | ) 359 | else: 360 | prompt_enhancer_image_caption_model = None 361 | prompt_enhancer_image_caption_processor = None 362 | prompt_enhancer_llm_model = None 363 | prompt_enhancer_llm_tokenizer = None 364 | 365 | vae = vae.to(torch.bfloat16) 366 | if precision == "bfloat16" and transformer.dtype != torch.bfloat16: 367 | transformer = transformer.to(torch.bfloat16) 368 | text_encoder = text_encoder.to(torch.bfloat16) 369 | 370 | # Use submodels for the pipeline 371 | submodel_dict = { 372 | "transformer": transformer, 373 | "patchifier": patchifier, 374 | "text_encoder": text_encoder, 375 | "tokenizer": tokenizer, 376 | "scheduler": scheduler, 377 | "vae": vae, 378 | "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model, 379 | 
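        # The prompt-enhancer entries (caption model/processor, LLM model/tokenizer)
        # are None when enhance_prompt is False.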
"prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor, 380 | "prompt_enhancer_llm_model": prompt_enhancer_llm_model, 381 | "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer, 382 | "allowed_inference_steps": allowed_inference_steps, 383 | } 384 | 385 | pipeline = LTXVideoPipeline(**submodel_dict) 386 | pipeline = pipeline.to(device) 387 | return pipeline 388 | 389 | 390 | def create_latent_upsampler(latent_upsampler_model_path: str, device: str): 391 | latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path) 392 | latent_upsampler.to(device) 393 | latent_upsampler.eval() 394 | return latent_upsampler 395 | 396 | 397 | def infer( 398 | output_path: Optional[str], 399 | seed: int, 400 | pipeline_config: str, 401 | image_cond_noise_scale: float, 402 | height: Optional[int], 403 | width: Optional[int], 404 | num_frames: int, 405 | frame_rate: int, 406 | prompt: str, 407 | negative_prompt: str, 408 | offload_to_cpu: bool, 409 | input_media_path: Optional[str] = None, 410 | conditioning_media_paths: Optional[List[str]] = None, 411 | conditioning_strengths: Optional[List[float]] = None, 412 | conditioning_start_frames: Optional[List[int]] = None, 413 | device: Optional[str] = None, 414 | **kwargs, 415 | ): 416 | # check if pipeline_config is a file 417 | if not os.path.isfile(pipeline_config): 418 | raise ValueError(f"Pipeline config file {pipeline_config} does not exist") 419 | with open(pipeline_config, "r") as f: 420 | pipeline_config = yaml.safe_load(f) 421 | 422 | models_dir = "MODEL_DIR" 423 | 424 | ltxv_model_name_or_path = pipeline_config["checkpoint_path"] 425 | if not os.path.isfile(ltxv_model_name_or_path): 426 | ltxv_model_path = hf_hub_download( 427 | repo_id="Lightricks/LTX-Video", 428 | filename=ltxv_model_name_or_path, 429 | local_dir=models_dir, 430 | repo_type="model", 431 | ) 432 | else: 433 | ltxv_model_path = ltxv_model_name_or_path 434 | 435 | spatial_upscaler_model_name_or_path = pipeline_config.get( 436 | "spatial_upscaler_model_path" 437 | ) 438 | if spatial_upscaler_model_name_or_path and not os.path.isfile( 439 | spatial_upscaler_model_name_or_path 440 | ): 441 | spatial_upscaler_model_path = hf_hub_download( 442 | repo_id="Lightricks/LTX-Video", 443 | filename=spatial_upscaler_model_name_or_path, 444 | local_dir=models_dir, 445 | repo_type="model", 446 | ) 447 | else: 448 | spatial_upscaler_model_path = spatial_upscaler_model_name_or_path 449 | 450 | if kwargs.get("input_image_path", None): 451 | logger.warning( 452 | "Please use conditioning_media_paths instead of input_image_path." 
453 | ) 454 | assert not conditioning_media_paths and not conditioning_start_frames 455 | conditioning_media_paths = [kwargs["input_image_path"]] 456 | conditioning_start_frames = [0] 457 | 458 | # Validate conditioning arguments 459 | if conditioning_media_paths: 460 | # Use default strengths of 1.0 461 | if not conditioning_strengths: 462 | conditioning_strengths = [1.0] * len(conditioning_media_paths) 463 | if not conditioning_start_frames: 464 | raise ValueError( 465 | "If `conditioning_media_paths` is provided, " 466 | "`conditioning_start_frames` must also be provided" 467 | ) 468 | if len(conditioning_media_paths) != len(conditioning_strengths) or len( 469 | conditioning_media_paths 470 | ) != len(conditioning_start_frames): 471 | raise ValueError( 472 | "`conditioning_media_paths`, `conditioning_strengths`, " 473 | "and `conditioning_start_frames` must have the same length" 474 | ) 475 | if any(s < 0 or s > 1 for s in conditioning_strengths): 476 | raise ValueError("All conditioning strengths must be between 0 and 1") 477 | if any(f < 0 or f >= num_frames for f in conditioning_start_frames): 478 | raise ValueError( 479 | f"All conditioning start frames must be between 0 and {num_frames-1}" 480 | ) 481 | 482 | seed_everething(seed) 483 | if offload_to_cpu and not torch.cuda.is_available(): 484 | logger.warning( 485 | "offload_to_cpu is set to True, but offloading will not occur since the model is already running on CPU." 486 | ) 487 | offload_to_cpu = False 488 | else: 489 | offload_to_cpu = offload_to_cpu and get_total_gpu_memory() < 30 490 | 491 | output_dir = ( 492 | Path(output_path) 493 | if output_path 494 | else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}") 495 | ) 496 | output_dir.mkdir(parents=True, exist_ok=True) 497 | 498 | # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1) 499 | height_padded = ((height - 1) // 32 + 1) * 32 500 | width_padded = ((width - 1) // 32 + 1) * 32 501 | num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1 502 | 503 | padding = calculate_padding(height, width, height_padded, width_padded) 504 | 505 | logger.warning( 506 | f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}" 507 | ) 508 | 509 | prompt_enhancement_words_threshold = pipeline_config[ 510 | "prompt_enhancement_words_threshold" 511 | ] 512 | 513 | prompt_word_count = len(prompt.split()) 514 | enhance_prompt = ( 515 | prompt_enhancement_words_threshold > 0 516 | and prompt_word_count < prompt_enhancement_words_threshold 517 | ) 518 | 519 | if prompt_enhancement_words_threshold > 0 and not enhance_prompt: 520 | logger.info( 521 | f"Prompt has {prompt_word_count} words, which exceeds the threshold of {prompt_enhancement_words_threshold}. Prompt enhancement disabled." 
522 | ) 523 | 524 | precision = pipeline_config["precision"] 525 | text_encoder_model_name_or_path = pipeline_config["text_encoder_model_name_or_path"] 526 | sampler = pipeline_config["sampler"] 527 | prompt_enhancer_image_caption_model_name_or_path = pipeline_config[ 528 | "prompt_enhancer_image_caption_model_name_or_path" 529 | ] 530 | prompt_enhancer_llm_model_name_or_path = pipeline_config[ 531 | "prompt_enhancer_llm_model_name_or_path" 532 | ] 533 | 534 | pipeline = create_ltx_video_pipeline( 535 | ckpt_path=ltxv_model_path, 536 | precision=precision, 537 | text_encoder_model_name_or_path=text_encoder_model_name_or_path, 538 | sampler=sampler, 539 | device=kwargs.get("device", get_device()), 540 | enhance_prompt=enhance_prompt, 541 | prompt_enhancer_image_caption_model_name_or_path=prompt_enhancer_image_caption_model_name_or_path, 542 | prompt_enhancer_llm_model_name_or_path=prompt_enhancer_llm_model_name_or_path, 543 | ) 544 | 545 | if pipeline_config.get("pipeline_type", None) == "multi-scale": 546 | if not spatial_upscaler_model_path: 547 | raise ValueError( 548 | "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" 549 | ) 550 | latent_upsampler = create_latent_upsampler( 551 | spatial_upscaler_model_path, pipeline.device 552 | ) 553 | pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler) 554 | 555 | media_item = None 556 | if input_media_path: 557 | media_item = load_media_file( 558 | media_path=input_media_path, 559 | height=height, 560 | width=width, 561 | max_frames=num_frames_padded, 562 | padding=padding, 563 | ) 564 | 565 | conditioning_items = ( 566 | prepare_conditioning( 567 | conditioning_media_paths=conditioning_media_paths, 568 | conditioning_strengths=conditioning_strengths, 569 | conditioning_start_frames=conditioning_start_frames, 570 | height=height, 571 | width=width, 572 | num_frames=num_frames, 573 | padding=padding, 574 | pipeline=pipeline, 575 | ) 576 | if conditioning_media_paths 577 | else None 578 | ) 579 | 580 | stg_mode = pipeline_config.get("stg_mode", "attention_values") 581 | del pipeline_config["stg_mode"] 582 | if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values": 583 | skip_layer_strategy = SkipLayerStrategy.AttentionValues 584 | elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip": 585 | skip_layer_strategy = SkipLayerStrategy.AttentionSkip 586 | elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual": 587 | skip_layer_strategy = SkipLayerStrategy.Residual 588 | elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block": 589 | skip_layer_strategy = SkipLayerStrategy.TransformerBlock 590 | else: 591 | raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}") 592 | 593 | # Prepare input for the pipeline 594 | sample = { 595 | "prompt": prompt, 596 | "prompt_attention_mask": None, 597 | "negative_prompt": negative_prompt, 598 | "negative_prompt_attention_mask": None, 599 | } 600 | 601 | device = device or get_device() 602 | generator = torch.Generator(device=device).manual_seed(seed) 603 | 604 | images = pipeline( 605 | **pipeline_config, 606 | skip_layer_strategy=skip_layer_strategy, 607 | generator=generator, 608 | output_type="pt", 609 | callback_on_step_end=None, 610 | height=height_padded, 611 | width=width_padded, 612 | num_frames=num_frames_padded, 613 | frame_rate=frame_rate, 614 | **sample, 615 | media_items=media_item, 616 | conditioning_items=conditioning_items, 617 | 
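        # Generation runs at the padded height/width/num_frames; outputs are
        # cropped back to the requested resolution and frame count further below.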
is_video=True, 618 | vae_per_channel_normalize=True, 619 | image_cond_noise_scale=image_cond_noise_scale, 620 | mixed_precision=(precision == "mixed_precision"), 621 | offload_to_cpu=offload_to_cpu, 622 | device=device, 623 | enhance_prompt=enhance_prompt, 624 | ).images 625 | 626 | # Crop the padded images to the desired resolution and number of frames 627 | (pad_left, pad_right, pad_top, pad_bottom) = padding 628 | pad_bottom = -pad_bottom 629 | pad_right = -pad_right 630 | if pad_bottom == 0: 631 | pad_bottom = images.shape[3] 632 | if pad_right == 0: 633 | pad_right = images.shape[4] 634 | images = images[:, :, :num_frames, pad_top:pad_bottom, pad_left:pad_right] 635 | 636 | for i in range(images.shape[0]): 637 | # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C 638 | video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy() 639 | # Unnormalizing images to [0, 255] range 640 | video_np = (video_np * 255).astype(np.uint8) 641 | fps = frame_rate 642 | height, width = video_np.shape[1:3] 643 | # In case a single image is generated 644 | if video_np.shape[0] == 1: 645 | output_filename = get_unique_filename( 646 | f"image_output_{i}", 647 | ".png", 648 | prompt=prompt, 649 | seed=seed, 650 | resolution=(height, width, num_frames), 651 | dir=output_dir, 652 | ) 653 | imageio.imwrite(output_filename, video_np[0]) 654 | else: 655 | output_filename = get_unique_filename( 656 | f"video_output_{i}", 657 | ".mp4", 658 | prompt=prompt, 659 | seed=seed, 660 | resolution=(height, width, num_frames), 661 | dir=output_dir, 662 | ) 663 | 664 | # Write video 665 | with imageio.get_writer(output_filename, fps=fps) as video: 666 | for frame in video_np: 667 | video.append_data(frame) 668 | 669 | logger.warning(f"Output saved to {output_filename}") 670 | 671 | 672 | def prepare_conditioning( 673 | conditioning_media_paths: List[str], 674 | conditioning_strengths: List[float], 675 | conditioning_start_frames: List[int], 676 | height: int, 677 | width: int, 678 | num_frames: int, 679 | padding: tuple[int, int, int, int], 680 | pipeline: LTXVideoPipeline, 681 | ) -> Optional[List[ConditioningItem]]: 682 | """Prepare conditioning items based on input media paths and their parameters. 683 | 684 | Args: 685 | conditioning_media_paths: List of paths to conditioning media (images or videos) 686 | conditioning_strengths: List of conditioning strengths for each media item 687 | conditioning_start_frames: List of frame indices where each item should be applied 688 | height: Height of the output frames 689 | width: Width of the output frames 690 | num_frames: Number of frames in the output video 691 | padding: Padding to apply to the frames 692 | pipeline: LTXVideoPipeline object used for condition video trimming 693 | 694 | Returns: 695 | A list of ConditioningItem objects. 696 | """ 697 | conditioning_items = [] 698 | for path, strength, start_frame in zip( 699 | conditioning_media_paths, conditioning_strengths, conditioning_start_frames 700 | ): 701 | num_input_frames = orig_num_input_frames = get_media_num_frames(path) 702 | if hasattr(pipeline, "trim_conditioning_sequence") and callable( 703 | getattr(pipeline, "trim_conditioning_sequence") 704 | ): 705 | num_input_frames = pipeline.trim_conditioning_sequence( 706 | start_frame, orig_num_input_frames, num_frames 707 | ) 708 | if num_input_frames < orig_num_input_frames: 709 | logger.warning( 710 | f"Trimming conditioning video {path} from {orig_num_input_frames} to {num_input_frames} frames." 
711 | ) 712 | 713 | media_tensor = load_media_file( 714 | media_path=path, 715 | height=height, 716 | width=width, 717 | max_frames=num_input_frames, 718 | padding=padding, 719 | just_crop=True, 720 | ) 721 | conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength)) 722 | return conditioning_items 723 | 724 | 725 | def get_media_num_frames(media_path: str) -> int: 726 | is_video = any( 727 | media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"] 728 | ) 729 | num_frames = 1 730 | if is_video: 731 | reader = imageio.get_reader(media_path) 732 | num_frames = reader.count_frames() 733 | reader.close() 734 | return num_frames 735 | 736 | 737 | def load_media_file( 738 | media_path: str, 739 | height: int, 740 | width: int, 741 | max_frames: int, 742 | padding: tuple[int, int, int, int], 743 | just_crop: bool = False, 744 | ) -> torch.Tensor: 745 | is_video = any( 746 | media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"] 747 | ) 748 | if is_video: 749 | reader = imageio.get_reader(media_path) 750 | num_input_frames = min(reader.count_frames(), max_frames) 751 | 752 | # Read and preprocess the relevant frames from the video file. 753 | frames = [] 754 | for i in range(num_input_frames): 755 | frame = Image.fromarray(reader.get_data(i)) 756 | frame_tensor = load_image_to_tensor_with_resize_and_crop( 757 | frame, height, width, just_crop=just_crop 758 | ) 759 | frame_tensor = torch.nn.functional.pad(frame_tensor, padding) 760 | frames.append(frame_tensor) 761 | reader.close() 762 | 763 | # Stack frames along the temporal dimension 764 | media_tensor = torch.cat(frames, dim=2) 765 | else: # Input image 766 | media_tensor = load_image_to_tensor_with_resize_and_crop( 767 | media_path, height, width, just_crop=just_crop 768 | ) 769 | media_tensor = torch.nn.functional.pad(media_tensor, padding) 770 | return media_tensor 771 | 772 | 773 | if __name__ == "__main__": 774 | main() 775 | -------------------------------------------------------------------------------- /ltx_video/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightricks/LTX-Video/34625c3a1fb4e9a7d2f091938f6989816343e270/ltx_video/__init__.py -------------------------------------------------------------------------------- /ltx_video/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightricks/LTX-Video/34625c3a1fb4e9a7d2f091938f6989816343e270/ltx_video/models/__init__.py -------------------------------------------------------------------------------- /ltx_video/models/autoencoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightricks/LTX-Video/34625c3a1fb4e9a7d2f091938f6989816343e270/ltx_video/models/autoencoders/__init__.py -------------------------------------------------------------------------------- /ltx_video/models/autoencoders/causal_conv3d.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class CausalConv3d(nn.Module): 8 | def __init__( 9 | self, 10 | in_channels, 11 | out_channels, 12 | kernel_size: int = 3, 13 | stride: Union[int, Tuple[int]] = 1, 14 | dilation: int = 1, 15 | groups: int = 1, 16 | spatial_padding_mode: str = "zeros", 17 | **kwargs, 18 | ): 19 | super().__init__() 20 | 21 | 
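        # The underlying Conv3d pads only spatially; temporal padding is applied
        # explicitly in forward() by replicating edge frames (causally by default).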
self.in_channels = in_channels 22 | self.out_channels = out_channels 23 | 24 | kernel_size = (kernel_size, kernel_size, kernel_size) 25 | self.time_kernel_size = kernel_size[0] 26 | 27 | dilation = (dilation, 1, 1) 28 | 29 | height_pad = kernel_size[1] // 2 30 | width_pad = kernel_size[2] // 2 31 | padding = (0, height_pad, width_pad) 32 | 33 | self.conv = nn.Conv3d( 34 | in_channels, 35 | out_channels, 36 | kernel_size, 37 | stride=stride, 38 | dilation=dilation, 39 | padding=padding, 40 | padding_mode=spatial_padding_mode, 41 | groups=groups, 42 | ) 43 | 44 | def forward(self, x, causal: bool = True): 45 | if causal: 46 | first_frame_pad = x[:, :, :1, :, :].repeat( 47 | (1, 1, self.time_kernel_size - 1, 1, 1) 48 | ) 49 | x = torch.concatenate((first_frame_pad, x), dim=2) 50 | else: 51 | first_frame_pad = x[:, :, :1, :, :].repeat( 52 | (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) 53 | ) 54 | last_frame_pad = x[:, :, -1:, :, :].repeat( 55 | (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) 56 | ) 57 | x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) 58 | x = self.conv(x) 59 | return x 60 | 61 | @property 62 | def weight(self): 63 | return self.conv.weight 64 | -------------------------------------------------------------------------------- /ltx_video/models/autoencoders/conv_nd_factory.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | import torch 4 | 5 | from ltx_video.models.autoencoders.dual_conv3d import DualConv3d 6 | from ltx_video.models.autoencoders.causal_conv3d import CausalConv3d 7 | 8 | 9 | def make_conv_nd( 10 | dims: Union[int, Tuple[int, int]], 11 | in_channels: int, 12 | out_channels: int, 13 | kernel_size: int, 14 | stride=1, 15 | padding=0, 16 | dilation=1, 17 | groups=1, 18 | bias=True, 19 | causal=False, 20 | spatial_padding_mode="zeros", 21 | temporal_padding_mode="zeros", 22 | ): 23 | if not (spatial_padding_mode == temporal_padding_mode or causal): 24 | raise NotImplementedError("spatial and temporal padding modes must be equal") 25 | if dims == 2: 26 | return torch.nn.Conv2d( 27 | in_channels=in_channels, 28 | out_channels=out_channels, 29 | kernel_size=kernel_size, 30 | stride=stride, 31 | padding=padding, 32 | dilation=dilation, 33 | groups=groups, 34 | bias=bias, 35 | padding_mode=spatial_padding_mode, 36 | ) 37 | elif dims == 3: 38 | if causal: 39 | return CausalConv3d( 40 | in_channels=in_channels, 41 | out_channels=out_channels, 42 | kernel_size=kernel_size, 43 | stride=stride, 44 | padding=padding, 45 | dilation=dilation, 46 | groups=groups, 47 | bias=bias, 48 | spatial_padding_mode=spatial_padding_mode, 49 | ) 50 | return torch.nn.Conv3d( 51 | in_channels=in_channels, 52 | out_channels=out_channels, 53 | kernel_size=kernel_size, 54 | stride=stride, 55 | padding=padding, 56 | dilation=dilation, 57 | groups=groups, 58 | bias=bias, 59 | padding_mode=spatial_padding_mode, 60 | ) 61 | elif dims == (2, 1): 62 | return DualConv3d( 63 | in_channels=in_channels, 64 | out_channels=out_channels, 65 | kernel_size=kernel_size, 66 | stride=stride, 67 | padding=padding, 68 | bias=bias, 69 | padding_mode=spatial_padding_mode, 70 | ) 71 | else: 72 | raise ValueError(f"unsupported dimensions: {dims}") 73 | 74 | 75 | def make_linear_nd( 76 | dims: int, 77 | in_channels: int, 78 | out_channels: int, 79 | bias=True, 80 | ): 81 | if dims == 2: 82 | return torch.nn.Conv2d( 83 | in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias 84 | ) 85 | elif dims == 3 or 
dims == (2, 1): 86 | return torch.nn.Conv3d( 87 | in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias 88 | ) 89 | else: 90 | raise ValueError(f"unsupported dimensions: {dims}") 91 | -------------------------------------------------------------------------------- /ltx_video/models/autoencoders/dual_conv3d.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | 9 | 10 | class DualConv3d(nn.Module): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | stride: Union[int, Tuple[int, int, int]] = 1, 17 | padding: Union[int, Tuple[int, int, int]] = 0, 18 | dilation: Union[int, Tuple[int, int, int]] = 1, 19 | groups=1, 20 | bias=True, 21 | padding_mode="zeros", 22 | ): 23 | super(DualConv3d, self).__init__() 24 | 25 | self.in_channels = in_channels 26 | self.out_channels = out_channels 27 | self.padding_mode = padding_mode 28 | # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 29 | if isinstance(kernel_size, int): 30 | kernel_size = (kernel_size, kernel_size, kernel_size) 31 | if kernel_size == (1, 1, 1): 32 | raise ValueError( 33 | "kernel_size must be greater than 1. Use make_linear_nd instead." 34 | ) 35 | if isinstance(stride, int): 36 | stride = (stride, stride, stride) 37 | if isinstance(padding, int): 38 | padding = (padding, padding, padding) 39 | if isinstance(dilation, int): 40 | dilation = (dilation, dilation, dilation) 41 | 42 | # Set parameters for convolutions 43 | self.groups = groups 44 | self.bias = bias 45 | 46 | # Define the size of the channels after the first convolution 47 | intermediate_channels = ( 48 | out_channels if in_channels < out_channels else in_channels 49 | ) 50 | 51 | # Define parameters for the first convolution 52 | self.weight1 = nn.Parameter( 53 | torch.Tensor( 54 | intermediate_channels, 55 | in_channels // groups, 56 | 1, 57 | kernel_size[1], 58 | kernel_size[2], 59 | ) 60 | ) 61 | self.stride1 = (1, stride[1], stride[2]) 62 | self.padding1 = (0, padding[1], padding[2]) 63 | self.dilation1 = (1, dilation[1], dilation[2]) 64 | if bias: 65 | self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels)) 66 | else: 67 | self.register_parameter("bias1", None) 68 | 69 | # Define parameters for the second convolution 70 | self.weight2 = nn.Parameter( 71 | torch.Tensor( 72 | out_channels, intermediate_channels // groups, kernel_size[0], 1, 1 73 | ) 74 | ) 75 | self.stride2 = (stride[0], 1, 1) 76 | self.padding2 = (padding[0], 0, 0) 77 | self.dilation2 = (dilation[0], 1, 1) 78 | if bias: 79 | self.bias2 = nn.Parameter(torch.Tensor(out_channels)) 80 | else: 81 | self.register_parameter("bias2", None) 82 | 83 | # Initialize weights and biases 84 | self.reset_parameters() 85 | 86 | def reset_parameters(self): 87 | nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5)) 88 | nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5)) 89 | if self.bias: 90 | fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) 91 | bound1 = 1 / math.sqrt(fan_in1) 92 | nn.init.uniform_(self.bias1, -bound1, bound1) 93 | fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) 94 | bound2 = 1 / math.sqrt(fan_in2) 95 | nn.init.uniform_(self.bias2, -bound2, bound2) 96 | 97 | def forward(self, x, use_conv3d=False, skip_time_conv=False): 98 | if use_conv3d: 99 | return self.forward_with_3d(x=x, 
skip_time_conv=skip_time_conv) 100 | else: 101 | return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv) 102 | 103 | def forward_with_3d(self, x, skip_time_conv): 104 | # First convolution 105 | x = F.conv3d( 106 | x, 107 | self.weight1, 108 | self.bias1, 109 | self.stride1, 110 | self.padding1, 111 | self.dilation1, 112 | self.groups, 113 | padding_mode=self.padding_mode, 114 | ) 115 | 116 | if skip_time_conv: 117 | return x 118 | 119 | # Second convolution 120 | x = F.conv3d( 121 | x, 122 | self.weight2, 123 | self.bias2, 124 | self.stride2, 125 | self.padding2, 126 | self.dilation2, 127 | self.groups, 128 | padding_mode=self.padding_mode, 129 | ) 130 | 131 | return x 132 | 133 | def forward_with_2d(self, x, skip_time_conv): 134 | b, c, d, h, w = x.shape 135 | 136 | # First 2D convolution 137 | x = rearrange(x, "b c d h w -> (b d) c h w") 138 | # Squeeze the depth dimension out of weight1 since it's 1 139 | weight1 = self.weight1.squeeze(2) 140 | # Select stride, padding, and dilation for the 2D convolution 141 | stride1 = (self.stride1[1], self.stride1[2]) 142 | padding1 = (self.padding1[1], self.padding1[2]) 143 | dilation1 = (self.dilation1[1], self.dilation1[2]) 144 | x = F.conv2d( 145 | x, 146 | weight1, 147 | self.bias1, 148 | stride1, 149 | padding1, 150 | dilation1, 151 | self.groups, 152 | padding_mode=self.padding_mode, 153 | ) 154 | 155 | _, _, h, w = x.shape 156 | 157 | if skip_time_conv: 158 | x = rearrange(x, "(b d) c h w -> b c d h w", b=b) 159 | return x 160 | 161 | # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension 162 | x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) 163 | 164 | # Reshape weight2 to match the expected dimensions for conv1d 165 | weight2 = self.weight2.squeeze(-1).squeeze(-1) 166 | # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution 167 | stride2 = self.stride2[0] 168 | padding2 = self.padding2[0] 169 | dilation2 = self.dilation2[0] 170 | x = F.conv1d( 171 | x, 172 | weight2, 173 | self.bias2, 174 | stride2, 175 | padding2, 176 | dilation2, 177 | self.groups, 178 | padding_mode=self.padding_mode, 179 | ) 180 | x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) 181 | 182 | return x 183 | 184 | @property 185 | def weight(self): 186 | return self.weight2 187 | 188 | 189 | def test_dual_conv3d_consistency(): 190 | # Initialize parameters 191 | in_channels = 3 192 | out_channels = 5 193 | kernel_size = (3, 3, 3) 194 | stride = (2, 2, 2) 195 | padding = (1, 1, 1) 196 | 197 | # Create an instance of the DualConv3d class 198 | dual_conv3d = DualConv3d( 199 | in_channels=in_channels, 200 | out_channels=out_channels, 201 | kernel_size=kernel_size, 202 | stride=stride, 203 | padding=padding, 204 | bias=True, 205 | ) 206 | 207 | # Example input tensor 208 | test_input = torch.randn(1, 3, 10, 10, 10) 209 | 210 | # Perform forward passes with both 3D and 2D settings 211 | output_conv3d = dual_conv3d(test_input, use_conv3d=True) 212 | output_2d = dual_conv3d(test_input, use_conv3d=False) 213 | 214 | # Assert that the outputs from both methods are sufficiently close 215 | assert torch.allclose( 216 | output_conv3d, output_2d, atol=1e-6 217 | ), "Outputs are not consistent between 3D and 2D convolutions." 
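# Note: DualConv3d is normally constructed via make_conv_nd(dims=(2, 1), ...) in
# conv_nd_factory.py, which factorizes a full 3D convolution into a spatial 2D
# convolution followed by a temporal 1D convolution over the frame dimension.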
218 | -------------------------------------------------------------------------------- /ltx_video/models/autoencoders/latent_upsampler.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | from pathlib import Path 3 | import os 4 | import json 5 | 6 | import torch 7 | import torch.nn as nn 8 | from einops import rearrange 9 | from diffusers import ConfigMixin, ModelMixin 10 | from safetensors.torch import safe_open 11 | 12 | from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND 13 | 14 | 15 | class ResBlock(nn.Module): 16 | def __init__( 17 | self, channels: int, mid_channels: Optional[int] = None, dims: int = 3 18 | ): 19 | super().__init__() 20 | if mid_channels is None: 21 | mid_channels = channels 22 | 23 | Conv = nn.Conv2d if dims == 2 else nn.Conv3d 24 | 25 | self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) 26 | self.norm1 = nn.GroupNorm(32, mid_channels) 27 | self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) 28 | self.norm2 = nn.GroupNorm(32, channels) 29 | self.activation = nn.SiLU() 30 | 31 | def forward(self, x: torch.Tensor) -> torch.Tensor: 32 | residual = x 33 | x = self.conv1(x) 34 | x = self.norm1(x) 35 | x = self.activation(x) 36 | x = self.conv2(x) 37 | x = self.norm2(x) 38 | x = self.activation(x + residual) 39 | return x 40 | 41 | 42 | class LatentUpsampler(ModelMixin, ConfigMixin): 43 | """ 44 | Model to spatially upsample VAE latents. 45 | 46 | Args: 47 | in_channels (`int`): Number of channels in the input latent 48 | mid_channels (`int`): Number of channels in the middle layers 49 | num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) 50 | dims (`int`): Number of dimensions for convolutions (2 or 3) 51 | spatial_upsample (`bool`): Whether to spatially upsample the latent 52 | temporal_upsample (`bool`): Whether to temporally upsample the latent 53 | """ 54 | 55 | def __init__( 56 | self, 57 | in_channels: int = 128, 58 | mid_channels: int = 512, 59 | num_blocks_per_stage: int = 4, 60 | dims: int = 3, 61 | spatial_upsample: bool = True, 62 | temporal_upsample: bool = False, 63 | ): 64 | super().__init__() 65 | 66 | self.in_channels = in_channels 67 | self.mid_channels = mid_channels 68 | self.num_blocks_per_stage = num_blocks_per_stage 69 | self.dims = dims 70 | self.spatial_upsample = spatial_upsample 71 | self.temporal_upsample = temporal_upsample 72 | 73 | Conv = nn.Conv2d if dims == 2 else nn.Conv3d 74 | 75 | self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1) 76 | self.initial_norm = nn.GroupNorm(32, mid_channels) 77 | self.initial_activation = nn.SiLU() 78 | 79 | self.res_blocks = nn.ModuleList( 80 | [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] 81 | ) 82 | 83 | if spatial_upsample and temporal_upsample: 84 | self.upsampler = nn.Sequential( 85 | nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), 86 | PixelShuffleND(3), 87 | ) 88 | elif spatial_upsample: 89 | self.upsampler = nn.Sequential( 90 | nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), 91 | PixelShuffleND(2), 92 | ) 93 | elif temporal_upsample: 94 | self.upsampler = nn.Sequential( 95 | nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), 96 | PixelShuffleND(1), 97 | ) 98 | else: 99 | raise ValueError( 100 | "Either spatial_upsample or temporal_upsample must be True" 101 | ) 102 | 103 | self.post_upsample_res_blocks = nn.ModuleList( 104 
| [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] 105 | ) 106 | 107 | self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1) 108 | 109 | def forward(self, latent: torch.Tensor) -> torch.Tensor: 110 | b, c, f, h, w = latent.shape 111 | 112 | if self.dims == 2: 113 | x = rearrange(latent, "b c f h w -> (b f) c h w") 114 | x = self.initial_conv(x) 115 | x = self.initial_norm(x) 116 | x = self.initial_activation(x) 117 | 118 | for block in self.res_blocks: 119 | x = block(x) 120 | 121 | x = self.upsampler(x) 122 | 123 | for block in self.post_upsample_res_blocks: 124 | x = block(x) 125 | 126 | x = self.final_conv(x) 127 | x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) 128 | else: 129 | x = self.initial_conv(latent) 130 | x = self.initial_norm(x) 131 | x = self.initial_activation(x) 132 | 133 | for block in self.res_blocks: 134 | x = block(x) 135 | 136 | if self.temporal_upsample: 137 | x = self.upsampler(x) 138 | x = x[:, :, 1:, :, :] 139 | else: 140 | x = rearrange(x, "b c f h w -> (b f) c h w") 141 | x = self.upsampler(x) 142 | x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) 143 | 144 | for block in self.post_upsample_res_blocks: 145 | x = block(x) 146 | 147 | x = self.final_conv(x) 148 | 149 | return x 150 | 151 | @classmethod 152 | def from_config(cls, config): 153 | return cls( 154 | in_channels=config.get("in_channels", 4), 155 | mid_channels=config.get("mid_channels", 128), 156 | num_blocks_per_stage=config.get("num_blocks_per_stage", 4), 157 | dims=config.get("dims", 2), 158 | spatial_upsample=config.get("spatial_upsample", True), 159 | temporal_upsample=config.get("temporal_upsample", False), 160 | ) 161 | 162 | def config(self): 163 | return { 164 | "_class_name": "LatentUpsampler", 165 | "in_channels": self.in_channels, 166 | "mid_channels": self.mid_channels, 167 | "num_blocks_per_stage": self.num_blocks_per_stage, 168 | "dims": self.dims, 169 | "spatial_upsample": self.spatial_upsample, 170 | "temporal_upsample": self.temporal_upsample, 171 | } 172 | 173 | @classmethod 174 | def from_pretrained( 175 | cls, 176 | pretrained_model_path: Optional[Union[str, os.PathLike]], 177 | *args, 178 | **kwargs, 179 | ): 180 | pretrained_model_path = Path(pretrained_model_path) 181 | if pretrained_model_path.is_file() and str(pretrained_model_path).endswith( 182 | ".safetensors" 183 | ): 184 | state_dict = {} 185 | with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: 186 | metadata = f.metadata() 187 | for k in f.keys(): 188 | state_dict[k] = f.get_tensor(k) 189 | config = json.loads(metadata["config"]) 190 | with torch.device("meta"): 191 | latent_upsampler = LatentUpsampler.from_config(config) 192 | latent_upsampler.load_state_dict(state_dict, assign=True) 193 | return latent_upsampler 194 | 195 | 196 | if __name__ == "__main__": 197 | latent_upsampler = LatentUpsampler(num_blocks_per_stage=4, dims=3) 198 | print(latent_upsampler) 199 | total_params = sum(p.numel() for p in latent_upsampler.parameters()) 200 | print(f"Total number of parameters: {total_params:,}") 201 | latent = torch.randn(1, 128, 9, 16, 16) 202 | upsampled_latent = latent_upsampler(latent) 203 | print(f"Upsampled latent shape: {upsampled_latent.shape}") 204 | -------------------------------------------------------------------------------- /ltx_video/models/autoencoders/pixel_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class PixelNorm(nn.Module): 
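    """Normalize each pixel vector to unit RMS across the channel dimension:
    x / sqrt(mean(x**2, dim=dim) + eps)."""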
6 | def __init__(self, dim=1, eps=1e-8): 7 | super(PixelNorm, self).__init__() 8 | self.dim = dim 9 | self.eps = eps 10 | 11 | def forward(self, x): 12 | return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps) 13 | -------------------------------------------------------------------------------- /ltx_video/models/autoencoders/pixel_shuffle.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from einops import rearrange 3 | 4 | 5 | class PixelShuffleND(nn.Module): 6 | def __init__(self, dims, upscale_factors=(2, 2, 2)): 7 | super().__init__() 8 | assert dims in [1, 2, 3], "dims must be 1, 2, or 3" 9 | self.dims = dims 10 | self.upscale_factors = upscale_factors 11 | 12 | def forward(self, x): 13 | if self.dims == 3: 14 | return rearrange( 15 | x, 16 | "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", 17 | p1=self.upscale_factors[0], 18 | p2=self.upscale_factors[1], 19 | p3=self.upscale_factors[2], 20 | ) 21 | elif self.dims == 2: 22 | return rearrange( 23 | x, 24 | "b (c p1 p2) h w -> b c (h p1) (w p2)", 25 | p1=self.upscale_factors[0], 26 | p2=self.upscale_factors[1], 27 | ) 28 | elif self.dims == 1: 29 | return rearrange( 30 | x, 31 | "b (c p1) f h w -> b c (f p1) h w", 32 | p1=self.upscale_factors[0], 33 | ) 34 | -------------------------------------------------------------------------------- /ltx_video/models/autoencoders/vae.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import torch 4 | import inspect 5 | import math 6 | import torch.nn as nn 7 | from diffusers import ConfigMixin, ModelMixin 8 | from diffusers.models.autoencoders.vae import ( 9 | DecoderOutput, 10 | DiagonalGaussianDistribution, 11 | ) 12 | from diffusers.models.modeling_outputs import AutoencoderKLOutput 13 | from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd 14 | 15 | 16 | class AutoencoderKLWrapper(ModelMixin, ConfigMixin): 17 | """Variational Autoencoder (VAE) model with KL loss. 18 | 19 | VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling. 20 | This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss. 21 | 22 | Args: 23 | encoder (`nn.Module`): 24 | Encoder module. 25 | decoder (`nn.Module`): 26 | Decoder module. 27 | latent_channels (`int`, *optional*, defaults to 4): 28 | Number of latent channels. 
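        dims (`int`, *optional*, defaults to 2):
            Convolution dimensionality; 2 selects 2D quant convolutions and latent
            normalization, anything else selects their 3D counterparts.
        sample_size (*optional*, defaults to 512):
            Sample size used to derive the default tiling parameters.
        use_quant_conv (`bool`, *optional*, defaults to `True`):
            Whether to wrap the latents with quant / post-quant convolutions.
        normalize_latent_channels (`bool`, *optional*, defaults to `False`):
            Whether to normalize the first half of the latent channels (the means)
            with a non-affine BatchNorm.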
29 | """ 30 | 31 | def __init__( 32 | self, 33 | encoder: nn.Module, 34 | decoder: nn.Module, 35 | latent_channels: int = 4, 36 | dims: int = 2, 37 | sample_size=512, 38 | use_quant_conv: bool = True, 39 | normalize_latent_channels: bool = False, 40 | ): 41 | super().__init__() 42 | 43 | # pass init params to Encoder 44 | self.encoder = encoder 45 | self.use_quant_conv = use_quant_conv 46 | self.normalize_latent_channels = normalize_latent_channels 47 | 48 | # pass init params to Decoder 49 | quant_dims = 2 if dims == 2 else 3 50 | self.decoder = decoder 51 | if use_quant_conv: 52 | self.quant_conv = make_conv_nd( 53 | quant_dims, 2 * latent_channels, 2 * latent_channels, 1 54 | ) 55 | self.post_quant_conv = make_conv_nd( 56 | quant_dims, latent_channels, latent_channels, 1 57 | ) 58 | else: 59 | self.quant_conv = nn.Identity() 60 | self.post_quant_conv = nn.Identity() 61 | 62 | if normalize_latent_channels: 63 | if dims == 2: 64 | self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False) 65 | else: 66 | self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False) 67 | else: 68 | self.latent_norm_out = nn.Identity() 69 | self.use_z_tiling = False 70 | self.use_hw_tiling = False 71 | self.dims = dims 72 | self.z_sample_size = 1 73 | 74 | self.decoder_params = inspect.signature(self.decoder.forward).parameters 75 | 76 | # only relevant if vae tiling is enabled 77 | self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25) 78 | 79 | def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25): 80 | self.tile_sample_min_size = sample_size 81 | num_blocks = len(self.encoder.down_blocks) 82 | self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1))) 83 | self.tile_overlap_factor = overlap_factor 84 | 85 | def enable_z_tiling(self, z_sample_size: int = 8): 86 | r""" 87 | Enable tiling during VAE decoding. 88 | 89 | When this option is enabled, the VAE will split the input tensor in tiles to compute decoding in several 90 | steps. This is useful to save some memory and allow larger batch sizes. 91 | """ 92 | self.use_z_tiling = z_sample_size > 1 93 | self.z_sample_size = z_sample_size 94 | assert ( 95 | z_sample_size % 8 == 0 or z_sample_size == 1 96 | ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}." 97 | 98 | def disable_z_tiling(self): 99 | r""" 100 | Disable tiling during VAE decoding. If `use_tiling` was previously invoked, this method will go back to computing 101 | decoding in one step. 102 | """ 103 | self.use_z_tiling = False 104 | 105 | def enable_hw_tiling(self): 106 | r""" 107 | Enable tiling during VAE decoding along the height and width dimension. 108 | """ 109 | self.use_hw_tiling = True 110 | 111 | def disable_hw_tiling(self): 112 | r""" 113 | Disable tiling during VAE decoding along the height and width dimension. 114 | """ 115 | self.use_hw_tiling = False 116 | 117 | def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True): 118 | overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) 119 | blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) 120 | row_limit = self.tile_latent_min_size - blend_extent 121 | 122 | # Split the image into 512x512 tiles and encode them separately. 
123 | rows = [] 124 | for i in range(0, x.shape[3], overlap_size): 125 | row = [] 126 | for j in range(0, x.shape[4], overlap_size): 127 | tile = x[ 128 | :, 129 | :, 130 | :, 131 | i : i + self.tile_sample_min_size, 132 | j : j + self.tile_sample_min_size, 133 | ] 134 | tile = self.encoder(tile) 135 | tile = self.quant_conv(tile) 136 | row.append(tile) 137 | rows.append(row) 138 | result_rows = [] 139 | for i, row in enumerate(rows): 140 | result_row = [] 141 | for j, tile in enumerate(row): 142 | # blend the above tile and the left tile 143 | # to the current tile and add the current tile to the result row 144 | if i > 0: 145 | tile = self.blend_v(rows[i - 1][j], tile, blend_extent) 146 | if j > 0: 147 | tile = self.blend_h(row[j - 1], tile, blend_extent) 148 | result_row.append(tile[:, :, :, :row_limit, :row_limit]) 149 | result_rows.append(torch.cat(result_row, dim=4)) 150 | 151 | moments = torch.cat(result_rows, dim=3) 152 | return moments 153 | 154 | def blend_z( 155 | self, a: torch.Tensor, b: torch.Tensor, blend_extent: int 156 | ) -> torch.Tensor: 157 | blend_extent = min(a.shape[2], b.shape[2], blend_extent) 158 | for z in range(blend_extent): 159 | b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * ( 160 | 1 - z / blend_extent 161 | ) + b[:, :, z, :, :] * (z / blend_extent) 162 | return b 163 | 164 | def blend_v( 165 | self, a: torch.Tensor, b: torch.Tensor, blend_extent: int 166 | ) -> torch.Tensor: 167 | blend_extent = min(a.shape[3], b.shape[3], blend_extent) 168 | for y in range(blend_extent): 169 | b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * ( 170 | 1 - y / blend_extent 171 | ) + b[:, :, :, y, :] * (y / blend_extent) 172 | return b 173 | 174 | def blend_h( 175 | self, a: torch.Tensor, b: torch.Tensor, blend_extent: int 176 | ) -> torch.Tensor: 177 | blend_extent = min(a.shape[4], b.shape[4], blend_extent) 178 | for x in range(blend_extent): 179 | b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * ( 180 | 1 - x / blend_extent 181 | ) + b[:, :, :, :, x] * (x / blend_extent) 182 | return b 183 | 184 | def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape): 185 | overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) 186 | blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) 187 | row_limit = self.tile_sample_min_size - blend_extent 188 | tile_target_shape = ( 189 | *target_shape[:3], 190 | self.tile_sample_min_size, 191 | self.tile_sample_min_size, 192 | ) 193 | # Split z into overlapping 64x64 tiles and decode them separately. 194 | # The tiles have an overlap to avoid seams between tiles. 
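        # Mirrors _hw_tiled_encode: decoded tiles are blended across their
        # overlapping borders before being concatenated into the full frames.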
195 | rows = [] 196 | for i in range(0, z.shape[3], overlap_size): 197 | row = [] 198 | for j in range(0, z.shape[4], overlap_size): 199 | tile = z[ 200 | :, 201 | :, 202 | :, 203 | i : i + self.tile_latent_min_size, 204 | j : j + self.tile_latent_min_size, 205 | ] 206 | tile = self.post_quant_conv(tile) 207 | decoded = self.decoder(tile, target_shape=tile_target_shape) 208 | row.append(decoded) 209 | rows.append(row) 210 | result_rows = [] 211 | for i, row in enumerate(rows): 212 | result_row = [] 213 | for j, tile in enumerate(row): 214 | # blend the above tile and the left tile 215 | # to the current tile and add the current tile to the result row 216 | if i > 0: 217 | tile = self.blend_v(rows[i - 1][j], tile, blend_extent) 218 | if j > 0: 219 | tile = self.blend_h(row[j - 1], tile, blend_extent) 220 | result_row.append(tile[:, :, :, :row_limit, :row_limit]) 221 | result_rows.append(torch.cat(result_row, dim=4)) 222 | 223 | dec = torch.cat(result_rows, dim=3) 224 | return dec 225 | 226 | def encode( 227 | self, z: torch.FloatTensor, return_dict: bool = True 228 | ) -> Union[DecoderOutput, torch.FloatTensor]: 229 | if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1: 230 | num_splits = z.shape[2] // self.z_sample_size 231 | sizes = [self.z_sample_size] * num_splits 232 | sizes = ( 233 | sizes + [z.shape[2] - sum(sizes)] 234 | if z.shape[2] - sum(sizes) > 0 235 | else sizes 236 | ) 237 | tiles = z.split(sizes, dim=2) 238 | moments_tiles = [ 239 | ( 240 | self._hw_tiled_encode(z_tile, return_dict) 241 | if self.use_hw_tiling 242 | else self._encode(z_tile) 243 | ) 244 | for z_tile in tiles 245 | ] 246 | moments = torch.cat(moments_tiles, dim=2) 247 | 248 | else: 249 | moments = ( 250 | self._hw_tiled_encode(z, return_dict) 251 | if self.use_hw_tiling 252 | else self._encode(z) 253 | ) 254 | 255 | posterior = DiagonalGaussianDistribution(moments) 256 | if not return_dict: 257 | return (posterior,) 258 | 259 | return AutoencoderKLOutput(latent_dist=posterior) 260 | 261 | def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor: 262 | if isinstance(self.latent_norm_out, nn.BatchNorm3d): 263 | _, c, _, _, _ = z.shape 264 | z = torch.cat( 265 | [ 266 | self.latent_norm_out(z[:, : c // 2, :, :, :]), 267 | z[:, c // 2 :, :, :, :], 268 | ], 269 | dim=1, 270 | ) 271 | elif isinstance(self.latent_norm_out, nn.BatchNorm2d): 272 | raise NotImplementedError("BatchNorm2d not supported") 273 | return z 274 | 275 | def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor: 276 | if isinstance(self.latent_norm_out, nn.BatchNorm3d): 277 | running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1) 278 | running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1) 279 | eps = self.latent_norm_out.eps 280 | 281 | z = z * torch.sqrt(running_var + eps) + running_mean 282 | elif isinstance(self.latent_norm_out, nn.BatchNorm3d): 283 | raise NotImplementedError("BatchNorm2d not supported") 284 | return z 285 | 286 | def _encode(self, x: torch.FloatTensor) -> AutoencoderKLOutput: 287 | h = self.encoder(x) 288 | moments = self.quant_conv(h) 289 | moments = self._normalize_latent_channels(moments) 290 | return moments 291 | 292 | def _decode( 293 | self, 294 | z: torch.FloatTensor, 295 | target_shape=None, 296 | timestep: Optional[torch.Tensor] = None, 297 | ) -> Union[DecoderOutput, torch.FloatTensor]: 298 | z = self._unnormalize_latent_channels(z) 299 | z = self.post_quant_conv(z) 300 | if "timestep" in self.decoder_params: 301 | dec = 
self.decoder(z, target_shape=target_shape, timestep=timestep) 302 | else: 303 | dec = self.decoder(z, target_shape=target_shape) 304 | return dec 305 | 306 | def decode( 307 | self, 308 | z: torch.FloatTensor, 309 | return_dict: bool = True, 310 | target_shape=None, 311 | timestep: Optional[torch.Tensor] = None, 312 | ) -> Union[DecoderOutput, torch.FloatTensor]: 313 | assert target_shape is not None, "target_shape must be provided for decoding" 314 | if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1: 315 | reduction_factor = int( 316 | self.encoder.patch_size_t 317 | * 2 318 | ** ( 319 | len(self.encoder.down_blocks) 320 | - 1 321 | - math.sqrt(self.encoder.patch_size) 322 | ) 323 | ) 324 | split_size = self.z_sample_size // reduction_factor 325 | num_splits = z.shape[2] // split_size 326 | 327 | # copy target shape, and divide frame dimension (=2) by the context size 328 | target_shape_split = list(target_shape) 329 | target_shape_split[2] = target_shape[2] // num_splits 330 | 331 | decoded_tiles = [ 332 | ( 333 | self._hw_tiled_decode(z_tile, target_shape_split) 334 | if self.use_hw_tiling 335 | else self._decode(z_tile, target_shape=target_shape_split) 336 | ) 337 | for z_tile in torch.tensor_split(z, num_splits, dim=2) 338 | ] 339 | decoded = torch.cat(decoded_tiles, dim=2) 340 | else: 341 | decoded = ( 342 | self._hw_tiled_decode(z, target_shape) 343 | if self.use_hw_tiling 344 | else self._decode(z, target_shape=target_shape, timestep=timestep) 345 | ) 346 | 347 | if not return_dict: 348 | return (decoded,) 349 | 350 | return DecoderOutput(sample=decoded) 351 | 352 | def forward( 353 | self, 354 | sample: torch.FloatTensor, 355 | sample_posterior: bool = False, 356 | return_dict: bool = True, 357 | generator: Optional[torch.Generator] = None, 358 | ) -> Union[DecoderOutput, torch.FloatTensor]: 359 | r""" 360 | Args: 361 | sample (`torch.FloatTensor`): Input sample. 362 | sample_posterior (`bool`, *optional*, defaults to `False`): 363 | Whether to sample from the posterior. 364 | return_dict (`bool`, *optional*, defaults to `True`): 365 | Whether to return a [`DecoderOutput`] instead of a plain tuple. 366 | generator (`torch.Generator`, *optional*): 367 | Generator used to sample from the posterior. 
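            Returns:
                [`DecoderOutput`] or `tuple`:
                    If `return_dict` is `True`, a [`DecoderOutput`] wrapping the reconstructed sample is returned,
                    otherwise a plain tuple whose first element is the reconstructed sample tensor.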
368 | """ 369 | x = sample 370 | posterior = self.encode(x).latent_dist 371 | if sample_posterior: 372 | z = posterior.sample(generator=generator) 373 | else: 374 | z = posterior.mode() 375 | dec = self.decode(z, target_shape=sample.shape).sample 376 | 377 | if not return_dict: 378 | return (dec,) 379 | 380 | return DecoderOutput(sample=dec) 381 | -------------------------------------------------------------------------------- /ltx_video/models/autoencoders/vae_encode.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | from diffusers import AutoencoderKL 4 | from einops import rearrange 5 | from torch import Tensor 6 | 7 | 8 | from ltx_video.models.autoencoders.causal_video_autoencoder import ( 9 | CausalVideoAutoencoder, 10 | ) 11 | from ltx_video.models.autoencoders.video_autoencoder import ( 12 | Downsample3D, 13 | VideoAutoencoder, 14 | ) 15 | 16 | try: 17 | import torch_xla.core.xla_model as xm 18 | except ImportError: 19 | xm = None 20 | 21 | 22 | def vae_encode( 23 | media_items: Tensor, 24 | vae: AutoencoderKL, 25 | split_size: int = 1, 26 | vae_per_channel_normalize=False, 27 | ) -> Tensor: 28 | """ 29 | Encodes media items (images or videos) into latent representations using a specified VAE model. 30 | The function supports processing batches of images or video frames and can handle the processing 31 | in smaller sub-batches if needed. 32 | 33 | Args: 34 | media_items (Tensor): A torch Tensor containing the media items to encode. The expected 35 | shape is (batch_size, channels, height, width) for images or (batch_size, channels, 36 | frames, height, width) for videos. 37 | vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library, 38 | pre-configured and loaded with the appropriate model weights. 39 | split_size (int, optional): The number of sub-batches to split the input batch into for encoding. 40 | If set to more than 1, the input media items are processed in smaller batches according to 41 | this value. Defaults to 1, which processes all items in a single batch. 42 | 43 | Returns: 44 | Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted 45 | to match the input shape, scaled by the model's configuration. 46 | 47 | Examples: 48 | >>> import torch 49 | >>> from diffusers import AutoencoderKL 50 | >>> vae = AutoencoderKL.from_pretrained('your-model-name') 51 | >>> images = torch.rand(10, 3, 8 256, 256) # Example tensor with 10 videos of 8 frames. 52 | >>> latents = vae_encode(images, vae) 53 | >>> print(latents.shape) # Output shape will depend on the model's latent configuration. 54 | 55 | Note: 56 | In case of a video, the function encodes the media item frame-by frame. 
57 | """ 58 | is_video_shaped = media_items.dim() == 5 59 | batch_size, channels = media_items.shape[0:2] 60 | 61 | if channels != 3: 62 | raise ValueError(f"Expects tensors with 3 channels, got {channels}.") 63 | 64 | if is_video_shaped and not isinstance( 65 | vae, (VideoAutoencoder, CausalVideoAutoencoder) 66 | ): 67 | media_items = rearrange(media_items, "b c n h w -> (b n) c h w") 68 | if split_size > 1: 69 | if len(media_items) % split_size != 0: 70 | raise ValueError( 71 | "Error: The batch size must be divisible by 'train.vae_bs_split" 72 | ) 73 | encode_bs = len(media_items) // split_size 74 | # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)] 75 | latents = [] 76 | if media_items.device.type == "xla": 77 | xm.mark_step() 78 | for image_batch in media_items.split(encode_bs): 79 | latents.append(vae.encode(image_batch).latent_dist.sample()) 80 | if media_items.device.type == "xla": 81 | xm.mark_step() 82 | latents = torch.cat(latents, dim=0) 83 | else: 84 | latents = vae.encode(media_items).latent_dist.sample() 85 | 86 | latents = normalize_latents(latents, vae, vae_per_channel_normalize) 87 | if is_video_shaped and not isinstance( 88 | vae, (VideoAutoencoder, CausalVideoAutoencoder) 89 | ): 90 | latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size) 91 | return latents 92 | 93 | 94 | def vae_decode( 95 | latents: Tensor, 96 | vae: AutoencoderKL, 97 | is_video: bool = True, 98 | split_size: int = 1, 99 | vae_per_channel_normalize=False, 100 | timestep=None, 101 | ) -> Tensor: 102 | is_video_shaped = latents.dim() == 5 103 | batch_size = latents.shape[0] 104 | 105 | if is_video_shaped and not isinstance( 106 | vae, (VideoAutoencoder, CausalVideoAutoencoder) 107 | ): 108 | latents = rearrange(latents, "b c n h w -> (b n) c h w") 109 | if split_size > 1: 110 | if len(latents) % split_size != 0: 111 | raise ValueError( 112 | "Error: The batch size must be divisible by 'train.vae_bs_split" 113 | ) 114 | encode_bs = len(latents) // split_size 115 | image_batch = [ 116 | _run_decoder( 117 | latent_batch, vae, is_video, vae_per_channel_normalize, timestep 118 | ) 119 | for latent_batch in latents.split(encode_bs) 120 | ] 121 | images = torch.cat(image_batch, dim=0) 122 | else: 123 | images = _run_decoder( 124 | latents, vae, is_video, vae_per_channel_normalize, timestep 125 | ) 126 | 127 | if is_video_shaped and not isinstance( 128 | vae, (VideoAutoencoder, CausalVideoAutoencoder) 129 | ): 130 | images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size) 131 | return images 132 | 133 | 134 | def _run_decoder( 135 | latents: Tensor, 136 | vae: AutoencoderKL, 137 | is_video: bool, 138 | vae_per_channel_normalize=False, 139 | timestep=None, 140 | ) -> Tensor: 141 | if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): 142 | *_, fl, hl, wl = latents.shape 143 | temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae) 144 | latents = latents.to(vae.dtype) 145 | vae_decode_kwargs = {} 146 | if timestep is not None: 147 | vae_decode_kwargs["timestep"] = timestep 148 | image = vae.decode( 149 | un_normalize_latents(latents, vae, vae_per_channel_normalize), 150 | return_dict=False, 151 | target_shape=( 152 | 1, 153 | 3, 154 | fl * temporal_scale if is_video else 1, 155 | hl * spatial_scale, 156 | wl * spatial_scale, 157 | ), 158 | **vae_decode_kwargs, 159 | )[0] 160 | else: 161 | image = vae.decode( 162 | un_normalize_latents(latents, vae, vae_per_channel_normalize), 163 | return_dict=False, 164 | 
)[0] 165 | return image 166 | 167 | 168 | def get_vae_size_scale_factor(vae: AutoencoderKL) -> float: 169 | if isinstance(vae, CausalVideoAutoencoder): 170 | spatial = vae.spatial_downscale_factor 171 | temporal = vae.temporal_downscale_factor 172 | else: 173 | down_blocks = len( 174 | [ 175 | block 176 | for block in vae.encoder.down_blocks 177 | if isinstance(block.downsample, Downsample3D) 178 | ] 179 | ) 180 | spatial = vae.config.patch_size * 2**down_blocks 181 | temporal = ( 182 | vae.config.patch_size_t * 2**down_blocks 183 | if isinstance(vae, VideoAutoencoder) 184 | else 1 185 | ) 186 | 187 | return (temporal, spatial, spatial) 188 | 189 | 190 | def latent_to_pixel_coords( 191 | latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False 192 | ) -> Tensor: 193 | """ 194 | Converts latent coordinates to pixel coordinates by scaling them according to the VAE's 195 | configuration. 196 | 197 | Args: 198 | latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents] 199 | containing the latent corner coordinates of each token. 200 | vae (AutoencoderKL): The VAE model 201 | causal_fix (bool): Whether to take into account the different temporal scale 202 | of the first frame. Default = False for backwards compatibility. 203 | Returns: 204 | Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. 205 | """ 206 | 207 | scale_factors = get_vae_size_scale_factor(vae) 208 | causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix 209 | pixel_coords = latent_to_pixel_coords_from_factors( 210 | latent_coords, scale_factors, causal_fix 211 | ) 212 | return pixel_coords 213 | 214 | 215 | def latent_to_pixel_coords_from_factors( 216 | latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False 217 | ) -> Tensor: 218 | pixel_coords = ( 219 | latent_coords 220 | * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] 221 | ) 222 | if causal_fix: 223 | # Fix temporal scale for first frame to 1 due to causality 224 | pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0) 225 | return pixel_coords 226 | 227 | 228 | def normalize_latents( 229 | latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False 230 | ) -> Tensor: 231 | return ( 232 | (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)) 233 | / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) 234 | if vae_per_channel_normalize 235 | else latents * vae.config.scaling_factor 236 | ) 237 | 238 | 239 | def un_normalize_latents( 240 | latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False 241 | ) -> Tensor: 242 | return ( 243 | latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) 244 | + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) 245 | if vae_per_channel_normalize 246 | else latents / vae.config.scaling_factor 247 | ) 248 | -------------------------------------------------------------------------------- /ltx_video/models/transformers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightricks/LTX-Video/34625c3a1fb4e9a7d2f091938f6989816343e270/ltx_video/models/transformers/__init__.py -------------------------------------------------------------------------------- /ltx_video/models/transformers/embeddings.py: -------------------------------------------------------------------------------- 1 | # Adapted from: 
https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | from einops import rearrange 7 | from torch import nn 8 | 9 | 10 | def get_timestep_embedding( 11 | timesteps: torch.Tensor, 12 | embedding_dim: int, 13 | flip_sin_to_cos: bool = False, 14 | downscale_freq_shift: float = 1, 15 | scale: float = 1, 16 | max_period: int = 10000, 17 | ): 18 | """ 19 | This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. 20 | 21 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 22 | These may be fractional. 23 | :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the 24 | embeddings. :return: an [N x dim] Tensor of positional embeddings. 25 | """ 26 | assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" 27 | 28 | half_dim = embedding_dim // 2 29 | exponent = -math.log(max_period) * torch.arange( 30 | start=0, end=half_dim, dtype=torch.float32, device=timesteps.device 31 | ) 32 | exponent = exponent / (half_dim - downscale_freq_shift) 33 | 34 | emb = torch.exp(exponent) 35 | emb = timesteps[:, None].float() * emb[None, :] 36 | 37 | # scale embeddings 38 | emb = scale * emb 39 | 40 | # concat sine and cosine embeddings 41 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) 42 | 43 | # flip sine and cosine embeddings 44 | if flip_sin_to_cos: 45 | emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) 46 | 47 | # zero pad 48 | if embedding_dim % 2 == 1: 49 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 50 | return emb 51 | 52 | 53 | def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f): 54 | """ 55 | grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or 56 | [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 57 | """ 58 | grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w) 59 | grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w) 60 | grid = grid.reshape([3, 1, w, h, f]) 61 | pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid) 62 | pos_embed = pos_embed.transpose(1, 0, 2, 3) 63 | return rearrange(pos_embed, "h w f c -> (f h w) c") 64 | 65 | 66 | def get_3d_sincos_pos_embed_from_grid(embed_dim, grid): 67 | if embed_dim % 3 != 0: 68 | raise ValueError("embed_dim must be divisible by 3") 69 | 70 | # use half of dimensions to encode grid_h 71 | emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (H*W*T, D/3) 72 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (H*W*T, D/3) 73 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (H*W*T, D/3) 74 | 75 | emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1) # (H*W*T, D) 76 | return emb 77 | 78 | 79 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 80 | """ 81 | embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) 82 | """ 83 | if embed_dim % 2 != 0: 84 | raise ValueError("embed_dim must be divisible by 2") 85 | 86 | omega = np.arange(embed_dim // 2, dtype=np.float64) 87 | omega /= embed_dim / 2.0 88 | omega = 1.0 / 10000**omega # (D/2,) 89 | 90 | pos_shape = pos.shape 91 | 92 | pos = pos.reshape(-1) 93 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 94 | out = out.reshape([*pos_shape, -1])[0] 95 | 96 | emb_sin = np.sin(out) # (M, D/2) 97 | emb_cos = np.cos(out) # (M, D/2) 98 | 99 | emb = 
np.concatenate([emb_sin, emb_cos], axis=-1) # (M, D) 100 | return emb 101 | 102 | 103 | class SinusoidalPositionalEmbedding(nn.Module): 104 | """Apply positional information to a sequence of embeddings. 105 | 106 | Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to 107 | them 108 | 109 | Args: 110 | embed_dim: (int): Dimension of the positional embedding. 111 | max_seq_length: Maximum sequence length to apply positional embeddings 112 | 113 | """ 114 | 115 | def __init__(self, embed_dim: int, max_seq_length: int = 32): 116 | super().__init__() 117 | position = torch.arange(max_seq_length).unsqueeze(1) 118 | div_term = torch.exp( 119 | torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim) 120 | ) 121 | pe = torch.zeros(1, max_seq_length, embed_dim) 122 | pe[0, :, 0::2] = torch.sin(position * div_term) 123 | pe[0, :, 1::2] = torch.cos(position * div_term) 124 | self.register_buffer("pe", pe) 125 | 126 | def forward(self, x): 127 | _, seq_length, _ = x.shape 128 | x = x + self.pe[:, :seq_length] 129 | return x 130 | -------------------------------------------------------------------------------- /ltx_video/models/transformers/symmetric_patchifier.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Tuple 3 | 4 | import torch 5 | from diffusers.configuration_utils import ConfigMixin 6 | from einops import rearrange 7 | from torch import Tensor 8 | 9 | 10 | class Patchifier(ConfigMixin, ABC): 11 | def __init__(self, patch_size: int): 12 | super().__init__() 13 | self._patch_size = (1, patch_size, patch_size) 14 | 15 | @abstractmethod 16 | def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: 17 | raise NotImplementedError("Patchify method not implemented") 18 | 19 | @abstractmethod 20 | def unpatchify( 21 | self, 22 | latents: Tensor, 23 | output_height: int, 24 | output_width: int, 25 | out_channels: int, 26 | ) -> Tuple[Tensor, Tensor]: 27 | pass 28 | 29 | @property 30 | def patch_size(self): 31 | return self._patch_size 32 | 33 | def get_latent_coords( 34 | self, latent_num_frames, latent_height, latent_width, batch_size, device 35 | ): 36 | """ 37 | Return a tensor of shape [batch_size, 3, num_patches] containing the 38 | top-left corner latent coordinates of each latent patch. 39 | The tensor is repeated for each batch element. 
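        For example, with latent_num_frames=2, latent_height=latent_width=4, batch_size=1 and a patch size
        of (1, 2, 2), the returned tensor has shape [1, 3, 2 * 2 * 2] = [1, 3, 8].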
40 | """ 41 | latent_sample_coords = torch.meshgrid( 42 | torch.arange(0, latent_num_frames, self._patch_size[0], device=device), 43 | torch.arange(0, latent_height, self._patch_size[1], device=device), 44 | torch.arange(0, latent_width, self._patch_size[2], device=device), 45 | ) 46 | latent_sample_coords = torch.stack(latent_sample_coords, dim=0) 47 | latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) 48 | latent_coords = rearrange( 49 | latent_coords, "b c f h w -> b c (f h w)", b=batch_size 50 | ) 51 | return latent_coords 52 | 53 | 54 | class SymmetricPatchifier(Patchifier): 55 | def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]: 56 | b, _, f, h, w = latents.shape 57 | latent_coords = self.get_latent_coords(f, h, w, b, latents.device) 58 | latents = rearrange( 59 | latents, 60 | "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", 61 | p1=self._patch_size[0], 62 | p2=self._patch_size[1], 63 | p3=self._patch_size[2], 64 | ) 65 | return latents, latent_coords 66 | 67 | def unpatchify( 68 | self, 69 | latents: Tensor, 70 | output_height: int, 71 | output_width: int, 72 | out_channels: int, 73 | ) -> Tuple[Tensor, Tensor]: 74 | output_height = output_height // self._patch_size[1] 75 | output_width = output_width // self._patch_size[2] 76 | latents = rearrange( 77 | latents, 78 | "b (f h w) (c p q) -> b c f (h p) (w q)", 79 | h=output_height, 80 | w=output_width, 81 | p=self._patch_size[1], 82 | q=self._patch_size[2], 83 | ) 84 | return latents 85 | -------------------------------------------------------------------------------- /ltx_video/models/transformers/transformer3d.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py 2 | import math 3 | from dataclasses import dataclass 4 | from typing import Any, Dict, List, Optional, Union 5 | import os 6 | import json 7 | import glob 8 | from pathlib import Path 9 | 10 | import torch 11 | from diffusers.configuration_utils import ConfigMixin, register_to_config 12 | from diffusers.models.embeddings import PixArtAlphaTextProjection 13 | from diffusers.models.modeling_utils import ModelMixin 14 | from diffusers.models.normalization import AdaLayerNormSingle 15 | from diffusers.utils import BaseOutput, is_torch_version 16 | from diffusers.utils import logging 17 | from torch import nn 18 | from safetensors import safe_open 19 | 20 | 21 | from ltx_video.models.transformers.attention import BasicTransformerBlock 22 | from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy 23 | 24 | from ltx_video.utils.diffusers_config_mapping import ( 25 | diffusers_and_ours_config_mapping, 26 | make_hashable_key, 27 | TRANSFORMER_KEYS_RENAME_DICT, 28 | ) 29 | 30 | 31 | logger = logging.get_logger(__name__) 32 | 33 | 34 | @dataclass 35 | class Transformer3DModelOutput(BaseOutput): 36 | """ 37 | The output of [`Transformer2DModel`]. 38 | 39 | Args: 40 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): 41 | The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability 42 | distributions for the unnoised latent pixels. 
43 | """ 44 | 45 | sample: torch.FloatTensor 46 | 47 | 48 | class Transformer3DModel(ModelMixin, ConfigMixin): 49 | _supports_gradient_checkpointing = True 50 | 51 | @register_to_config 52 | def __init__( 53 | self, 54 | num_attention_heads: int = 16, 55 | attention_head_dim: int = 88, 56 | in_channels: Optional[int] = None, 57 | out_channels: Optional[int] = None, 58 | num_layers: int = 1, 59 | dropout: float = 0.0, 60 | norm_num_groups: int = 32, 61 | cross_attention_dim: Optional[int] = None, 62 | attention_bias: bool = False, 63 | num_vector_embeds: Optional[int] = None, 64 | activation_fn: str = "geglu", 65 | num_embeds_ada_norm: Optional[int] = None, 66 | use_linear_projection: bool = False, 67 | only_cross_attention: bool = False, 68 | double_self_attention: bool = False, 69 | upcast_attention: bool = False, 70 | adaptive_norm: str = "single_scale_shift", # 'single_scale_shift' or 'single_scale' 71 | standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm' 72 | norm_elementwise_affine: bool = True, 73 | norm_eps: float = 1e-5, 74 | attention_type: str = "default", 75 | caption_channels: int = None, 76 | use_tpu_flash_attention: bool = False, # if True uses the TPU attention offload ('flash attention') 77 | qk_norm: Optional[str] = None, 78 | positional_embedding_type: str = "rope", 79 | positional_embedding_theta: Optional[float] = None, 80 | positional_embedding_max_pos: Optional[List[int]] = None, 81 | timestep_scale_multiplier: Optional[float] = None, 82 | causal_temporal_positioning: bool = False, # For backward compatibility, will be deprecated 83 | ): 84 | super().__init__() 85 | self.use_tpu_flash_attention = ( 86 | use_tpu_flash_attention # FIXME: push config down to the attention modules 87 | ) 88 | self.use_linear_projection = use_linear_projection 89 | self.num_attention_heads = num_attention_heads 90 | self.attention_head_dim = attention_head_dim 91 | inner_dim = num_attention_heads * attention_head_dim 92 | self.inner_dim = inner_dim 93 | self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True) 94 | self.positional_embedding_type = positional_embedding_type 95 | self.positional_embedding_theta = positional_embedding_theta 96 | self.positional_embedding_max_pos = positional_embedding_max_pos 97 | self.use_rope = self.positional_embedding_type == "rope" 98 | self.timestep_scale_multiplier = timestep_scale_multiplier 99 | 100 | if self.positional_embedding_type == "absolute": 101 | raise ValueError("Absolute positional embedding is no longer supported") 102 | elif self.positional_embedding_type == "rope": 103 | if positional_embedding_theta is None: 104 | raise ValueError( 105 | "If `positional_embedding_type` type is rope, `positional_embedding_theta` must also be defined" 106 | ) 107 | if positional_embedding_max_pos is None: 108 | raise ValueError( 109 | "If `positional_embedding_type` type is rope, `positional_embedding_max_pos` must also be defined" 110 | ) 111 | 112 | # 3. 
Define transformers blocks 113 | self.transformer_blocks = nn.ModuleList( 114 | [ 115 | BasicTransformerBlock( 116 | inner_dim, 117 | num_attention_heads, 118 | attention_head_dim, 119 | dropout=dropout, 120 | cross_attention_dim=cross_attention_dim, 121 | activation_fn=activation_fn, 122 | num_embeds_ada_norm=num_embeds_ada_norm, 123 | attention_bias=attention_bias, 124 | only_cross_attention=only_cross_attention, 125 | double_self_attention=double_self_attention, 126 | upcast_attention=upcast_attention, 127 | adaptive_norm=adaptive_norm, 128 | standardization_norm=standardization_norm, 129 | norm_elementwise_affine=norm_elementwise_affine, 130 | norm_eps=norm_eps, 131 | attention_type=attention_type, 132 | use_tpu_flash_attention=use_tpu_flash_attention, 133 | qk_norm=qk_norm, 134 | use_rope=self.use_rope, 135 | ) 136 | for d in range(num_layers) 137 | ] 138 | ) 139 | 140 | # 4. Define output layers 141 | self.out_channels = in_channels if out_channels is None else out_channels 142 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) 143 | self.scale_shift_table = nn.Parameter( 144 | torch.randn(2, inner_dim) / inner_dim**0.5 145 | ) 146 | self.proj_out = nn.Linear(inner_dim, self.out_channels) 147 | 148 | self.adaln_single = AdaLayerNormSingle( 149 | inner_dim, use_additional_conditions=False 150 | ) 151 | if adaptive_norm == "single_scale": 152 | self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True) 153 | 154 | self.caption_projection = None 155 | if caption_channels is not None: 156 | self.caption_projection = PixArtAlphaTextProjection( 157 | in_features=caption_channels, hidden_size=inner_dim 158 | ) 159 | 160 | self.gradient_checkpointing = False 161 | 162 | def set_use_tpu_flash_attention(self): 163 | r""" 164 | Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU 165 | attention kernel. 166 | """ 167 | logger.info("ENABLE TPU FLASH ATTENTION -> TRUE") 168 | self.use_tpu_flash_attention = True 169 | # push config down to the attention modules 170 | for block in self.transformer_blocks: 171 | block.set_use_tpu_flash_attention() 172 | 173 | def create_skip_layer_mask( 174 | self, 175 | batch_size: int, 176 | num_conds: int, 177 | ptb_index: int, 178 | skip_block_list: Optional[List[int]] = None, 179 | ): 180 | if skip_block_list is None or len(skip_block_list) == 0: 181 | return None 182 | num_layers = len(self.transformer_blocks) 183 | mask = torch.ones( 184 | (num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype 185 | ) 186 | for block_idx in skip_block_list: 187 | mask[block_idx, ptb_index::num_conds] = 0 188 | return mask 189 | 190 | def _set_gradient_checkpointing(self, module, value=False): 191 | if hasattr(module, "gradient_checkpointing"): 192 | module.gradient_checkpointing = value 193 | 194 | def get_fractional_positions(self, indices_grid): 195 | fractional_positions = torch.stack( 196 | [ 197 | indices_grid[:, i] / self.positional_embedding_max_pos[i] 198 | for i in range(3) 199 | ], 200 | dim=-1, 201 | ) 202 | return fractional_positions 203 | 204 | def precompute_freqs_cis(self, indices_grid, spacing="exp"): 205 | dtype = torch.float32 # We need full precision in the freqs_cis computation. 
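        # Rotary embedding layout: each of the three axes of the fractional positions (frame, height,
        # width) gets dim // 6 frequencies, so the cos/sin tables below span 2 * 3 * (dim // 6) channels
        # after repeat_interleave. When dim is not divisible by 6, the first dim % 6 channels are padded
        # with cos = 1 and sin = 0, i.e. they are left unrotated.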
206 | dim = self.inner_dim 207 | theta = self.positional_embedding_theta 208 | 209 | fractional_positions = self.get_fractional_positions(indices_grid) 210 | 211 | start = 1 212 | end = theta 213 | device = fractional_positions.device 214 | if spacing == "exp": 215 | indices = theta ** ( 216 | torch.linspace( 217 | math.log(start, theta), 218 | math.log(end, theta), 219 | dim // 6, 220 | device=device, 221 | dtype=dtype, 222 | ) 223 | ) 224 | indices = indices.to(dtype=dtype) 225 | elif spacing == "exp_2": 226 | indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim) 227 | indices = indices.to(dtype=dtype) 228 | elif spacing == "linear": 229 | indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype) 230 | elif spacing == "sqrt": 231 | indices = torch.linspace( 232 | start**2, end**2, dim // 6, device=device, dtype=dtype 233 | ).sqrt() 234 | 235 | indices = indices * math.pi / 2 236 | 237 | if spacing == "exp_2": 238 | freqs = ( 239 | (indices * fractional_positions.unsqueeze(-1)) 240 | .transpose(-1, -2) 241 | .flatten(2) 242 | ) 243 | else: 244 | freqs = ( 245 | (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) 246 | .transpose(-1, -2) 247 | .flatten(2) 248 | ) 249 | 250 | cos_freq = freqs.cos().repeat_interleave(2, dim=-1) 251 | sin_freq = freqs.sin().repeat_interleave(2, dim=-1) 252 | if dim % 6 != 0: 253 | cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6]) 254 | sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6]) 255 | cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) 256 | sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) 257 | return cos_freq.to(self.dtype), sin_freq.to(self.dtype) 258 | 259 | def load_state_dict( 260 | self, 261 | state_dict: Dict, 262 | *args, 263 | **kwargs, 264 | ): 265 | if any([key.startswith("model.diffusion_model.") for key in state_dict.keys()]): 266 | state_dict = { 267 | key.replace("model.diffusion_model.", ""): value 268 | for key, value in state_dict.items() 269 | if key.startswith("model.diffusion_model.") 270 | } 271 | super().load_state_dict(state_dict, **kwargs) 272 | 273 | @classmethod 274 | def from_pretrained( 275 | cls, 276 | pretrained_model_path: Optional[Union[str, os.PathLike]], 277 | *args, 278 | **kwargs, 279 | ): 280 | pretrained_model_path = Path(pretrained_model_path) 281 | if pretrained_model_path.is_dir(): 282 | config_path = pretrained_model_path / "transformer" / "config.json" 283 | with open(config_path, "r") as f: 284 | config = make_hashable_key(json.load(f)) 285 | 286 | assert config in diffusers_and_ours_config_mapping, ( 287 | "Provided diffusers checkpoint config for transformer is not suppported. " 288 | "We only support diffusers configs found in Lightricks/LTX-Video." 
289 | ) 290 | 291 | config = diffusers_and_ours_config_mapping[config] 292 | state_dict = {} 293 | ckpt_paths = ( 294 | pretrained_model_path 295 | / "transformer" 296 | / "diffusion_pytorch_model*.safetensors" 297 | ) 298 | dict_list = glob.glob(str(ckpt_paths)) 299 | for dict_path in dict_list: 300 | part_dict = {} 301 | with safe_open(dict_path, framework="pt", device="cpu") as f: 302 | for k in f.keys(): 303 | part_dict[k] = f.get_tensor(k) 304 | state_dict.update(part_dict) 305 | 306 | for key in list(state_dict.keys()): 307 | new_key = key 308 | for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): 309 | new_key = new_key.replace(replace_key, rename_key) 310 | state_dict[new_key] = state_dict.pop(key) 311 | 312 | with torch.device("meta"): 313 | transformer = cls.from_config(config) 314 | transformer.load_state_dict(state_dict, assign=True, strict=True) 315 | elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith( 316 | ".safetensors" 317 | ): 318 | comfy_single_file_state_dict = {} 319 | with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: 320 | metadata = f.metadata() 321 | for k in f.keys(): 322 | comfy_single_file_state_dict[k] = f.get_tensor(k) 323 | configs = json.loads(metadata["config"]) 324 | transformer_config = configs["transformer"] 325 | with torch.device("meta"): 326 | transformer = Transformer3DModel.from_config(transformer_config) 327 | transformer.load_state_dict(comfy_single_file_state_dict, assign=True) 328 | return transformer 329 | 330 | def forward( 331 | self, 332 | hidden_states: torch.Tensor, 333 | indices_grid: torch.Tensor, 334 | encoder_hidden_states: Optional[torch.Tensor] = None, 335 | timestep: Optional[torch.LongTensor] = None, 336 | class_labels: Optional[torch.LongTensor] = None, 337 | cross_attention_kwargs: Dict[str, Any] = None, 338 | attention_mask: Optional[torch.Tensor] = None, 339 | encoder_attention_mask: Optional[torch.Tensor] = None, 340 | skip_layer_mask: Optional[torch.Tensor] = None, 341 | skip_layer_strategy: Optional[SkipLayerStrategy] = None, 342 | return_dict: bool = True, 343 | ): 344 | """ 345 | The [`Transformer2DModel`] forward method. 346 | 347 | Args: 348 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): 349 | Input `hidden_states`. 350 | indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`): 351 | encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): 352 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 353 | self-attention. 354 | timestep ( `torch.LongTensor`, *optional*): 355 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. 356 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): 357 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in 358 | `AdaLayerZeroNorm`. 359 | cross_attention_kwargs ( `Dict[str, Any]`, *optional*): 360 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 361 | `self.processor` in 362 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
363 | attention_mask ( `torch.Tensor`, *optional*): 364 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask 365 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large 366 | negative values to the attention scores corresponding to "discard" tokens. 367 | encoder_attention_mask ( `torch.Tensor`, *optional*): 368 | Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: 369 | 370 | * Mask `(batch, sequence_length)` True = keep, False = discard. 371 | * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. 372 | 373 | If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format 374 | above. This bias will be added to the cross-attention scores. 375 | skip_layer_mask ( `torch.Tensor`, *optional*): 376 | A mask of shape `(num_layers, batch)` that indicates which layers to skip. `0` at position 377 | `layer, batch_idx` indicates that the layer should be skipped for the corresponding batch index. 378 | skip_layer_strategy ( `SkipLayerStrategy`, *optional*, defaults to `None`): 379 | Controls which layers are skipped when calculating a perturbed latent for spatiotemporal guidance. 380 | return_dict (`bool`, *optional*, defaults to `True`): 381 | Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain 382 | tuple. 383 | 384 | Returns: 385 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 386 | `tuple` where the first element is the sample tensor. 387 | """ 388 | # for tpu attention offload 2d token masks are used. No need to transform. 389 | if not self.use_tpu_flash_attention: 390 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. 391 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. 392 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 393 | # expects mask of shape: 394 | # [batch, key_tokens] 395 | # adds singleton query_tokens dimension: 396 | # [batch, 1, key_tokens] 397 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 398 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 399 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) 400 | if attention_mask is not None and attention_mask.ndim == 2: 401 | # assume that mask is expressed as: 402 | # (1 = keep, 0 = discard) 403 | # convert mask into a bias that can be added to attention scores: 404 | # (keep = +0, discard = -10000.0) 405 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 406 | attention_mask = attention_mask.unsqueeze(1) 407 | 408 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 409 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 410 | encoder_attention_mask = ( 411 | 1 - encoder_attention_mask.to(hidden_states.dtype) 412 | ) * -10000.0 413 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 414 | 415 | # 1. 
Input 416 | hidden_states = self.patchify_proj(hidden_states) 417 | 418 | if self.timestep_scale_multiplier: 419 | timestep = self.timestep_scale_multiplier * timestep 420 | 421 | freqs_cis = self.precompute_freqs_cis(indices_grid) 422 | 423 | batch_size = hidden_states.shape[0] 424 | timestep, embedded_timestep = self.adaln_single( 425 | timestep.flatten(), 426 | {"resolution": None, "aspect_ratio": None}, 427 | batch_size=batch_size, 428 | hidden_dtype=hidden_states.dtype, 429 | ) 430 | # Second dimension is 1 or number of tokens (if timestep_per_token) 431 | timestep = timestep.view(batch_size, -1, timestep.shape[-1]) 432 | embedded_timestep = embedded_timestep.view( 433 | batch_size, -1, embedded_timestep.shape[-1] 434 | ) 435 | 436 | # 2. Blocks 437 | if self.caption_projection is not None: 438 | batch_size = hidden_states.shape[0] 439 | encoder_hidden_states = self.caption_projection(encoder_hidden_states) 440 | encoder_hidden_states = encoder_hidden_states.view( 441 | batch_size, -1, hidden_states.shape[-1] 442 | ) 443 | 444 | for block_idx, block in enumerate(self.transformer_blocks): 445 | if self.training and self.gradient_checkpointing: 446 | 447 | def create_custom_forward(module, return_dict=None): 448 | def custom_forward(*inputs): 449 | if return_dict is not None: 450 | return module(*inputs, return_dict=return_dict) 451 | else: 452 | return module(*inputs) 453 | 454 | return custom_forward 455 | 456 | ckpt_kwargs: Dict[str, Any] = ( 457 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 458 | ) 459 | hidden_states = torch.utils.checkpoint.checkpoint( 460 | create_custom_forward(block), 461 | hidden_states, 462 | freqs_cis, 463 | attention_mask, 464 | encoder_hidden_states, 465 | encoder_attention_mask, 466 | timestep, 467 | cross_attention_kwargs, 468 | class_labels, 469 | ( 470 | skip_layer_mask[block_idx] 471 | if skip_layer_mask is not None 472 | else None 473 | ), 474 | skip_layer_strategy, 475 | **ckpt_kwargs, 476 | ) 477 | else: 478 | hidden_states = block( 479 | hidden_states, 480 | freqs_cis=freqs_cis, 481 | attention_mask=attention_mask, 482 | encoder_hidden_states=encoder_hidden_states, 483 | encoder_attention_mask=encoder_attention_mask, 484 | timestep=timestep, 485 | cross_attention_kwargs=cross_attention_kwargs, 486 | class_labels=class_labels, 487 | skip_layer_mask=( 488 | skip_layer_mask[block_idx] 489 | if skip_layer_mask is not None 490 | else None 491 | ), 492 | skip_layer_strategy=skip_layer_strategy, 493 | ) 494 | 495 | # 3. 
Output 496 | scale_shift_values = ( 497 | self.scale_shift_table[None, None] + embedded_timestep[:, :, None] 498 | ) 499 | shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] 500 | hidden_states = self.norm_out(hidden_states) 501 | # Modulation 502 | hidden_states = hidden_states * (1 + scale) + shift 503 | hidden_states = self.proj_out(hidden_states) 504 | if not return_dict: 505 | return (hidden_states,) 506 | 507 | return Transformer3DModelOutput(sample=hidden_states) 508 | -------------------------------------------------------------------------------- /ltx_video/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightricks/LTX-Video/34625c3a1fb4e9a7d2f091938f6989816343e270/ltx_video/pipelines/__init__.py -------------------------------------------------------------------------------- /ltx_video/pipelines/crf_compressor.py: -------------------------------------------------------------------------------- 1 | import av 2 | import torch 3 | import io 4 | import numpy as np 5 | 6 | 7 | def _encode_single_frame(output_file, image_array: np.ndarray, crf): 8 | container = av.open(output_file, "w", format="mp4") 9 | try: 10 | stream = container.add_stream( 11 | "libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"} 12 | ) 13 | stream.height = image_array.shape[0] 14 | stream.width = image_array.shape[1] 15 | av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat( 16 | format="yuv420p" 17 | ) 18 | container.mux(stream.encode(av_frame)) 19 | container.mux(stream.encode()) 20 | finally: 21 | container.close() 22 | 23 | 24 | def _decode_single_frame(video_file): 25 | container = av.open(video_file) 26 | try: 27 | stream = next(s for s in container.streams if s.type == "video") 28 | frame = next(container.decode(stream)) 29 | finally: 30 | container.close() 31 | return frame.to_ndarray(format="rgb24") 32 | 33 | 34 | def compress(image: torch.Tensor, crf=29): 35 | if crf == 0: 36 | return image 37 | 38 | image_array = ( 39 | (image[: (image.shape[0] // 2) * 2, : (image.shape[1] // 2) * 2] * 255.0) 40 | .byte() 41 | .cpu() 42 | .numpy() 43 | ) 44 | with io.BytesIO() as output_file: 45 | _encode_single_frame(output_file, image_array, crf) 46 | video_bytes = output_file.getvalue() 47 | with io.BytesIO(video_bytes) as video_file: 48 | image_array = _decode_single_frame(video_file) 49 | tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0 50 | return tensor 51 | -------------------------------------------------------------------------------- /ltx_video/schedulers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightricks/LTX-Video/34625c3a1fb4e9a7d2f091938f6989816343e270/ltx_video/schedulers/__init__.py -------------------------------------------------------------------------------- /ltx_video/schedulers/rf.py: -------------------------------------------------------------------------------- 1 | import math 2 | from abc import ABC, abstractmethod 3 | from dataclasses import dataclass 4 | from typing import Callable, Optional, Tuple, Union 5 | import json 6 | import os 7 | from pathlib import Path 8 | 9 | import torch 10 | from diffusers.configuration_utils import ConfigMixin, register_to_config 11 | from diffusers.schedulers.scheduling_utils import SchedulerMixin 12 | from diffusers.utils import BaseOutput 13 | from torch import Tensor 14 | from safetensors 
import safe_open 15 | 16 | 17 | from ltx_video.utils.torch_utils import append_dims 18 | 19 | from ltx_video.utils.diffusers_config_mapping import ( 20 | diffusers_and_ours_config_mapping, 21 | make_hashable_key, 22 | ) 23 | 24 | 25 | def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_steps=None): 26 | if num_steps == 1: 27 | return torch.tensor([1.0]) 28 | if linear_steps is None: 29 | linear_steps = num_steps // 2 30 | linear_sigma_schedule = [ 31 | i * threshold_noise / linear_steps for i in range(linear_steps) 32 | ] 33 | threshold_noise_step_diff = linear_steps - threshold_noise * num_steps 34 | quadratic_steps = num_steps - linear_steps 35 | quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) 36 | linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / ( 37 | quadratic_steps**2 38 | ) 39 | const = quadratic_coef * (linear_steps**2) 40 | quadratic_sigma_schedule = [ 41 | quadratic_coef * (i**2) + linear_coef * i + const 42 | for i in range(linear_steps, num_steps) 43 | ] 44 | sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0] 45 | sigma_schedule = [1.0 - x for x in sigma_schedule] 46 | return torch.tensor(sigma_schedule[:-1]) 47 | 48 | 49 | def simple_diffusion_resolution_dependent_timestep_shift( 50 | samples_shape: torch.Size, 51 | timesteps: Tensor, 52 | n: int = 32 * 32, 53 | ) -> Tensor: 54 | if len(samples_shape) == 3: 55 | _, m, _ = samples_shape 56 | elif len(samples_shape) in [4, 5]: 57 | m = math.prod(samples_shape[2:]) 58 | else: 59 | raise ValueError( 60 | "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)" 61 | ) 62 | snr = (timesteps / (1 - timesteps)) ** 2 63 | shift_snr = torch.log(snr) + 2 * math.log(m / n) 64 | shifted_timesteps = torch.sigmoid(0.5 * shift_snr) 65 | 66 | return shifted_timesteps 67 | 68 | 69 | def time_shift(mu: float, sigma: float, t: Tensor): 70 | return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) 71 | 72 | 73 | def get_normal_shift( 74 | n_tokens: int, 75 | min_tokens: int = 1024, 76 | max_tokens: int = 4096, 77 | min_shift: float = 0.95, 78 | max_shift: float = 2.05, 79 | ) -> Callable[[float], float]: 80 | m = (max_shift - min_shift) / (max_tokens - min_tokens) 81 | b = min_shift - m * min_tokens 82 | return m * n_tokens + b 83 | 84 | 85 | def strech_shifts_to_terminal(shifts: Tensor, terminal=0.1): 86 | """ 87 | Stretch a function (given as sampled shifts) so that its final value matches the given terminal value 88 | using the provided formula. 89 | 90 | Parameters: 91 | - shifts (Tensor): The samples of the function to be stretched (PyTorch Tensor). 92 | - terminal (float): The desired terminal value (value at the last sample). 93 | 94 | Returns: 95 | - Tensor: The stretched shifts such that the final value equals `terminal`. 
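    Example:
        If the last shift is 0.0 and terminal = 0.1, every (1 - shift) value is rescaled by
        (1 - terminal) / (1 - shifts[-1]) = 0.9, so the final stretched shift lands exactly on 0.1.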
96 | """ 97 | if shifts.numel() == 0: 98 | raise ValueError("The 'shifts' tensor must not be empty.") 99 | 100 | # Ensure terminal value is valid 101 | if terminal <= 0 or terminal >= 1: 102 | raise ValueError("The terminal value must be between 0 and 1 (exclusive).") 103 | 104 | # Transform the shifts using the given formula 105 | one_minus_z = 1 - shifts 106 | scale_factor = one_minus_z[-1] / (1 - terminal) 107 | stretched_shifts = 1 - (one_minus_z / scale_factor) 108 | 109 | return stretched_shifts 110 | 111 | 112 | def sd3_resolution_dependent_timestep_shift( 113 | samples_shape: torch.Size, 114 | timesteps: Tensor, 115 | target_shift_terminal: Optional[float] = None, 116 | ) -> Tensor: 117 | """ 118 | Shifts the timestep schedule as a function of the generated resolution. 119 | 120 | In the SD3 paper, the authors empirically how to shift the timesteps based on the resolution of the target images. 121 | For more details: https://arxiv.org/pdf/2403.03206 122 | 123 | In Flux they later propose a more dynamic resolution dependent timestep shift, see: 124 | https://github.com/black-forest-labs/flux/blob/87f6fff727a377ea1c378af692afb41ae84cbe04/src/flux/sampling.py#L66 125 | 126 | 127 | Args: 128 | samples_shape (torch.Size): The samples batch shape (batch_size, channels, height, width) or 129 | (batch_size, channels, frame, height, width). 130 | timesteps (Tensor): A batch of timesteps with shape (batch_size,). 131 | target_shift_terminal (float): The target terminal value for the shifted timesteps. 132 | 133 | Returns: 134 | Tensor: The shifted timesteps. 135 | """ 136 | if len(samples_shape) == 3: 137 | _, m, _ = samples_shape 138 | elif len(samples_shape) in [4, 5]: 139 | m = math.prod(samples_shape[2:]) 140 | else: 141 | raise ValueError( 142 | "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)" 143 | ) 144 | 145 | shift = get_normal_shift(m) 146 | time_shifts = time_shift(shift, 1, timesteps) 147 | if target_shift_terminal is not None: # Stretch the shifts to the target terminal 148 | time_shifts = strech_shifts_to_terminal(time_shifts, target_shift_terminal) 149 | return time_shifts 150 | 151 | 152 | class TimestepShifter(ABC): 153 | @abstractmethod 154 | def shift_timesteps(self, samples_shape: torch.Size, timesteps: Tensor) -> Tensor: 155 | pass 156 | 157 | 158 | @dataclass 159 | class RectifiedFlowSchedulerOutput(BaseOutput): 160 | """ 161 | Output class for the scheduler's step function output. 162 | 163 | Args: 164 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 165 | Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the 166 | denoising loop. 167 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 168 | The predicted denoised sample (x_{0}) based on the model output from the current timestep. 169 | `pred_original_sample` can be used to preview progress or for guidance. 
170 | """ 171 | 172 | prev_sample: torch.FloatTensor 173 | pred_original_sample: Optional[torch.FloatTensor] = None 174 | 175 | 176 | class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter): 177 | order = 1 178 | 179 | @register_to_config 180 | def __init__( 181 | self, 182 | num_train_timesteps=1000, 183 | shifting: Optional[str] = None, 184 | base_resolution: int = 32**2, 185 | target_shift_terminal: Optional[float] = None, 186 | sampler: Optional[str] = "Uniform", 187 | shift: Optional[float] = None, 188 | ): 189 | super().__init__() 190 | self.init_noise_sigma = 1.0 191 | self.num_inference_steps = None 192 | self.sampler = sampler 193 | self.shifting = shifting 194 | self.base_resolution = base_resolution 195 | self.target_shift_terminal = target_shift_terminal 196 | self.timesteps = self.sigmas = self.get_initial_timesteps( 197 | num_train_timesteps, shift=shift 198 | ) 199 | self.shift = shift 200 | 201 | def get_initial_timesteps( 202 | self, num_timesteps: int, shift: Optional[float] = None 203 | ) -> Tensor: 204 | if self.sampler == "Uniform": 205 | return torch.linspace(1, 1 / num_timesteps, num_timesteps) 206 | elif self.sampler == "LinearQuadratic": 207 | return linear_quadratic_schedule(num_timesteps) 208 | elif self.sampler == "Constant": 209 | assert ( 210 | shift is not None 211 | ), "Shift must be provided for constant time shift sampler." 212 | return time_shift( 213 | shift, 1, torch.linspace(1, 1 / num_timesteps, num_timesteps) 214 | ) 215 | 216 | def shift_timesteps(self, samples_shape: torch.Size, timesteps: Tensor) -> Tensor: 217 | if self.shifting == "SD3": 218 | return sd3_resolution_dependent_timestep_shift( 219 | samples_shape, timesteps, self.target_shift_terminal 220 | ) 221 | elif self.shifting == "SimpleDiffusion": 222 | return simple_diffusion_resolution_dependent_timestep_shift( 223 | samples_shape, timesteps, self.base_resolution 224 | ) 225 | return timesteps 226 | 227 | def set_timesteps( 228 | self, 229 | num_inference_steps: Optional[int] = None, 230 | samples_shape: Optional[torch.Size] = None, 231 | timesteps: Optional[Tensor] = None, 232 | device: Union[str, torch.device] = None, 233 | ): 234 | """ 235 | Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. 236 | If `timesteps` are provided, they will be used instead of the scheduled timesteps. 237 | 238 | Args: 239 | num_inference_steps (`int` *optional*): The number of diffusion steps used when generating samples. 240 | samples_shape (`torch.Size` *optional*): The samples batch shape, used for shifting. 241 | timesteps ('torch.Tensor' *optional*): Specific timesteps to use instead of scheduled timesteps. 242 | device (`Union[str, torch.device]`, *optional*): The device to which the timesteps tensor will be moved. 243 | """ 244 | if timesteps is not None and num_inference_steps is not None: 245 | raise ValueError( 246 | "You cannot provide both `timesteps` and `num_inference_steps`." 
247 | ) 248 | if timesteps is None: 249 | num_inference_steps = min( 250 | self.config.num_train_timesteps, num_inference_steps 251 | ) 252 | timesteps = self.get_initial_timesteps( 253 | num_inference_steps, shift=self.shift 254 | ).to(device) 255 | timesteps = self.shift_timesteps(samples_shape, timesteps) 256 | else: 257 | timesteps = torch.Tensor(timesteps).to(device) 258 | num_inference_steps = len(timesteps) 259 | self.timesteps = timesteps 260 | self.num_inference_steps = num_inference_steps 261 | self.sigmas = self.timesteps 262 | 263 | @staticmethod 264 | def from_pretrained(pretrained_model_path: Union[str, os.PathLike]): 265 | pretrained_model_path = Path(pretrained_model_path) 266 | if pretrained_model_path.is_file(): 267 | comfy_single_file_state_dict = {} 268 | with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: 269 | metadata = f.metadata() 270 | for k in f.keys(): 271 | comfy_single_file_state_dict[k] = f.get_tensor(k) 272 | configs = json.loads(metadata["config"]) 273 | config = configs["scheduler"] 274 | del comfy_single_file_state_dict 275 | 276 | elif pretrained_model_path.is_dir(): 277 | diffusers_noise_scheduler_config_path = ( 278 | pretrained_model_path / "scheduler" / "scheduler_config.json" 279 | ) 280 | 281 | with open(diffusers_noise_scheduler_config_path, "r") as f: 282 | scheduler_config = json.load(f) 283 | hashable_config = make_hashable_key(scheduler_config) 284 | if hashable_config in diffusers_and_ours_config_mapping: 285 | config = diffusers_and_ours_config_mapping[hashable_config] 286 | return RectifiedFlowScheduler.from_config(config) 287 | 288 | def scale_model_input( 289 | self, sample: torch.FloatTensor, timestep: Optional[int] = None 290 | ) -> torch.FloatTensor: 291 | # pylint: disable=unused-argument 292 | """ 293 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the 294 | current timestep. 295 | 296 | Args: 297 | sample (`torch.FloatTensor`): input sample 298 | timestep (`int`, optional): current timestep 299 | 300 | Returns: 301 | `torch.FloatTensor`: scaled input sample 302 | """ 303 | return sample 304 | 305 | def step( 306 | self, 307 | model_output: torch.FloatTensor, 308 | timestep: torch.FloatTensor, 309 | sample: torch.FloatTensor, 310 | return_dict: bool = True, 311 | stochastic_sampling: Optional[bool] = False, 312 | **kwargs, 313 | ) -> Union[RectifiedFlowSchedulerOutput, Tuple]: 314 | """ 315 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 316 | process from the learned model outputs (most often the predicted noise). 317 | z_{t_1} = z_t - \Delta_t * v 318 | The method finds the next timestep that is lower than the input timestep(s) and denoises the latents 319 | to that level. The input timestep(s) are not required to be one of the predefined timesteps. 320 | 321 | Args: 322 | model_output (`torch.FloatTensor`): 323 | The direct output from learned diffusion model - the velocity, 324 | timestep (`float`): 325 | The current discrete timestep in the diffusion chain (global or per-token). 326 | sample (`torch.FloatTensor`): 327 | A current latent tokens to be de-noised. 328 | return_dict (`bool`, *optional*, defaults to `True`): 329 | Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. 330 | stochastic_sampling (`bool`, *optional*, defaults to `False`): 331 | Whether to use stochastic sampling for the sampling process. 
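                When enabled, the step first reconstructs x0 = sample - t * velocity and then re-noises it
                to the next (lower) timestep with fresh Gaussian noise via `add_noise`, instead of taking
                the deterministic update sample - dt * velocity.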
332 | 333 | Returns: 334 | [`~schedulers.scheduling_utils.RectifiedFlowSchedulerOutput`] or `tuple`: 335 | If return_dict is `True`, [`~schedulers.rf_scheduler.RectifiedFlowSchedulerOutput`] is returned, 336 | otherwise a tuple is returned where the first element is the sample tensor. 337 | """ 338 | if self.num_inference_steps is None: 339 | raise ValueError( 340 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" 341 | ) 342 | t_eps = 1e-6 # Small epsilon to avoid numerical issues in timestep values 343 | 344 | timesteps_padded = torch.cat( 345 | [self.timesteps, torch.zeros(1, device=self.timesteps.device)] 346 | ) 347 | 348 | # Find the next lower timestep(s) and compute the dt from the current timestep(s) 349 | if timestep.ndim == 0: 350 | # Global timestep case 351 | lower_mask = timesteps_padded < timestep - t_eps 352 | lower_timestep = timesteps_padded[lower_mask][0] # Closest lower timestep 353 | dt = timestep - lower_timestep 354 | 355 | else: 356 | # Per-token case 357 | assert timestep.ndim == 2 358 | lower_mask = timesteps_padded[:, None, None] < timestep[None] - t_eps 359 | lower_timestep = lower_mask * timesteps_padded[:, None, None] 360 | lower_timestep, _ = lower_timestep.max(dim=0) 361 | dt = (timestep - lower_timestep)[..., None] 362 | 363 | # Compute previous sample 364 | if stochastic_sampling: 365 | x0 = sample - timestep[..., None] * model_output 366 | next_timestep = timestep[..., None] - dt 367 | prev_sample = self.add_noise(x0, torch.randn_like(sample), next_timestep) 368 | else: 369 | prev_sample = sample - dt * model_output 370 | 371 | if not return_dict: 372 | return (prev_sample,) 373 | 374 | return RectifiedFlowSchedulerOutput(prev_sample=prev_sample) 375 | 376 | def add_noise( 377 | self, 378 | original_samples: torch.FloatTensor, 379 | noise: torch.FloatTensor, 380 | timesteps: torch.FloatTensor, 381 | ) -> torch.FloatTensor: 382 | sigmas = timesteps 383 | sigmas = append_dims(sigmas, original_samples.ndim) 384 | alphas = 1 - sigmas 385 | noisy_samples = alphas * original_samples + sigmas * noise 386 | return noisy_samples 387 | -------------------------------------------------------------------------------- /ltx_video/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lightricks/LTX-Video/34625c3a1fb4e9a7d2f091938f6989816343e270/ltx_video/utils/__init__.py -------------------------------------------------------------------------------- /ltx_video/utils/diffusers_config_mapping.py: -------------------------------------------------------------------------------- 1 | def make_hashable_key(dict_key): 2 | def convert_value(value): 3 | if isinstance(value, list): 4 | return tuple(value) 5 | elif isinstance(value, dict): 6 | return tuple(sorted((k, convert_value(v)) for k, v in value.items())) 7 | else: 8 | return value 9 | 10 | return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items())) 11 | 12 | 13 | DIFFUSERS_SCHEDULER_CONFIG = { 14 | "_class_name": "FlowMatchEulerDiscreteScheduler", 15 | "_diffusers_version": "0.32.0.dev0", 16 | "base_image_seq_len": 1024, 17 | "base_shift": 0.95, 18 | "invert_sigmas": False, 19 | "max_image_seq_len": 4096, 20 | "max_shift": 2.05, 21 | "num_train_timesteps": 1000, 22 | "shift": 1.0, 23 | "shift_terminal": 0.1, 24 | "use_beta_sigmas": False, 25 | "use_dynamic_shifting": True, 26 | "use_exponential_sigmas": False, 27 | "use_karras_sigmas": False, 28 | } 29 | 
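# Note (added for clarity, not part of the original module): make_hashable_key()
# above recursively converts nested dicts and lists into sorted tuples so that a
# full scheduler/transformer/VAE config can serve as a dictionary key. For
# example, make_hashable_key({"a": [1, 2], "b": {"c": 3}}) evaluates to
# (("a", (1, 2)), ("b", (("c", 3),))). The diffusers_and_ours_config_mapping at
# the bottom of this file uses such keys to translate a loaded Diffusers config
# (e.g. DIFFUSERS_SCHEDULER_CONFIG above) into the matching in-house config
# (e.g. OURS_SCHEDULER_CONFIG below), which is how
# RectifiedFlowScheduler.from_pretrained() recognizes Diffusers-format folders.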
DIFFUSERS_TRANSFORMER_CONFIG = { 30 | "_class_name": "LTXVideoTransformer3DModel", 31 | "_diffusers_version": "0.32.0.dev0", 32 | "activation_fn": "gelu-approximate", 33 | "attention_bias": True, 34 | "attention_head_dim": 64, 35 | "attention_out_bias": True, 36 | "caption_channels": 4096, 37 | "cross_attention_dim": 2048, 38 | "in_channels": 128, 39 | "norm_elementwise_affine": False, 40 | "norm_eps": 1e-06, 41 | "num_attention_heads": 32, 42 | "num_layers": 28, 43 | "out_channels": 128, 44 | "patch_size": 1, 45 | "patch_size_t": 1, 46 | "qk_norm": "rms_norm_across_heads", 47 | } 48 | DIFFUSERS_VAE_CONFIG = { 49 | "_class_name": "AutoencoderKLLTXVideo", 50 | "_diffusers_version": "0.32.0.dev0", 51 | "block_out_channels": [128, 256, 512, 512], 52 | "decoder_causal": False, 53 | "encoder_causal": True, 54 | "in_channels": 3, 55 | "latent_channels": 128, 56 | "layers_per_block": [4, 3, 3, 3, 4], 57 | "out_channels": 3, 58 | "patch_size": 4, 59 | "patch_size_t": 1, 60 | "resnet_norm_eps": 1e-06, 61 | "scaling_factor": 1.0, 62 | "spatio_temporal_scaling": [True, True, True, False], 63 | } 64 | 65 | OURS_SCHEDULER_CONFIG = { 66 | "_class_name": "RectifiedFlowScheduler", 67 | "_diffusers_version": "0.25.1", 68 | "num_train_timesteps": 1000, 69 | "shifting": "SD3", 70 | "base_resolution": None, 71 | "target_shift_terminal": 0.1, 72 | } 73 | 74 | OURS_TRANSFORMER_CONFIG = { 75 | "_class_name": "Transformer3DModel", 76 | "_diffusers_version": "0.25.1", 77 | "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256", 78 | "activation_fn": "gelu-approximate", 79 | "attention_bias": True, 80 | "attention_head_dim": 64, 81 | "attention_type": "default", 82 | "caption_channels": 4096, 83 | "cross_attention_dim": 2048, 84 | "double_self_attention": False, 85 | "dropout": 0.0, 86 | "in_channels": 128, 87 | "norm_elementwise_affine": False, 88 | "norm_eps": 1e-06, 89 | "norm_num_groups": 32, 90 | "num_attention_heads": 32, 91 | "num_embeds_ada_norm": 1000, 92 | "num_layers": 28, 93 | "num_vector_embeds": None, 94 | "only_cross_attention": False, 95 | "out_channels": 128, 96 | "project_to_2d_pos": True, 97 | "upcast_attention": False, 98 | "use_linear_projection": False, 99 | "qk_norm": "rms_norm", 100 | "standardization_norm": "rms_norm", 101 | "positional_embedding_type": "rope", 102 | "positional_embedding_theta": 10000.0, 103 | "positional_embedding_max_pos": [20, 2048, 2048], 104 | "timestep_scale_multiplier": 1000, 105 | } 106 | OURS_VAE_CONFIG = { 107 | "_class_name": "CausalVideoAutoencoder", 108 | "dims": 3, 109 | "in_channels": 3, 110 | "out_channels": 3, 111 | "latent_channels": 128, 112 | "blocks": [ 113 | ["res_x", 4], 114 | ["compress_all", 1], 115 | ["res_x_y", 1], 116 | ["res_x", 3], 117 | ["compress_all", 1], 118 | ["res_x_y", 1], 119 | ["res_x", 3], 120 | ["compress_all", 1], 121 | ["res_x", 3], 122 | ["res_x", 4], 123 | ], 124 | "scaling_factor": 1.0, 125 | "norm_layer": "pixel_norm", 126 | "patch_size": 4, 127 | "latent_log_var": "uniform", 128 | "use_quant_conv": False, 129 | "causal_decoder": False, 130 | } 131 | 132 | 133 | diffusers_and_ours_config_mapping = { 134 | make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG, 135 | make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG, 136 | make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG, 137 | } 138 | 139 | 140 | TRANSFORMER_KEYS_RENAME_DICT = { 141 | "proj_in": "patchify_proj", 142 | "time_embed": "adaln_single", 143 | "norm_q": "q_norm", 144 | "norm_k": "k_norm", 145 | } 146 | 147 | 148 | 
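# Illustrative sketch (an assumption, not the repository's actual conversion
# routine): rename maps such as TRANSFORMER_KEYS_RENAME_DICT above are typically
# applied by substring replacement over every key of a loaded Diffusers
# state_dict, turning Diffusers parameter-name fragments into the in-house names
# used by this repo's models. The helper below only documents that idea; the
# real loading logic lives in the model classes elsewhere in the repository.
def _rename_state_dict_keys_sketch(state_dict: dict, rename_dict: dict) -> dict:
    renamed = {}
    for key, value in state_dict.items():
        # Replace every known Diffusers key fragment with its in-house counterpart.
        for old_fragment, new_fragment in rename_dict.items():
            key = key.replace(old_fragment, new_fragment)
        renamed[key] = value
    return renamed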
VAE_KEYS_RENAME_DICT = { 149 | "decoder.up_blocks.3.conv_in": "decoder.up_blocks.7", 150 | "decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8", 151 | "decoder.up_blocks.3": "decoder.up_blocks.9", 152 | "decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5", 153 | "decoder.up_blocks.2.conv_in": "decoder.up_blocks.4", 154 | "decoder.up_blocks.2": "decoder.up_blocks.6", 155 | "decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2", 156 | "decoder.up_blocks.1": "decoder.up_blocks.3", 157 | "decoder.up_blocks.0": "decoder.up_blocks.1", 158 | "decoder.mid_block": "decoder.up_blocks.0", 159 | "encoder.down_blocks.3": "encoder.down_blocks.8", 160 | "encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7", 161 | "encoder.down_blocks.2": "encoder.down_blocks.6", 162 | "encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4", 163 | "encoder.down_blocks.1.conv_out": "encoder.down_blocks.5", 164 | "encoder.down_blocks.1": "encoder.down_blocks.3", 165 | "encoder.down_blocks.0.conv_out": "encoder.down_blocks.2", 166 | "encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1", 167 | "encoder.down_blocks.0": "encoder.down_blocks.0", 168 | "encoder.mid_block": "encoder.down_blocks.9", 169 | "conv_shortcut.conv": "conv_shortcut", 170 | "resnets": "res_blocks", 171 | "norm3": "norm3.norm", 172 | "latents_mean": "per_channel_statistics.mean-of-means", 173 | "latents_std": "per_channel_statistics.std-of-means", 174 | } 175 | -------------------------------------------------------------------------------- /ltx_video/utils/prompt_enhance_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Union, List, Optional 3 | 4 | import torch 5 | from PIL import Image 6 | 7 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 8 | 9 | T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. 10 | Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. 11 | Start directly with the action, and keep descriptions literal and precise. 12 | Think like a cinematographer describing a shot list. 13 | Do not change the user input intent, just enhance it. 14 | Keep within 150 words. 15 | For best results, build your prompts using this structure: 16 | Start with main action in a single sentence 17 | Add specific details about movements and gestures 18 | Describe character/object appearances precisely 19 | Include background and environment details 20 | Specify camera angles and movements 21 | Describe lighting and colors 22 | Note any changes or sudden events 23 | Do not exceed the 150 word limit! 24 | Output the enhanced prompt only. 25 | """ 26 | 27 | I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes. 28 | Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. 29 | Start directly with the action, and keep descriptions literal and precise. 30 | Think like a cinematographer describing a shot list. 31 | Keep within 150 words. 32 | For best results, build your prompts using this structure: 33 | Describe the image first and then add the user input. 
Image description should be in first priority! Align to the image caption if it contradicts the user text input. 34 | Start with main action in a single sentence 35 | Add specific details about movements and gestures 36 | Describe character/object appearances precisely 37 | Include background and environment details 38 | Specify camera angles and movements 39 | Describe lighting and colors 40 | Note any changes or sudden events 41 | Align to the image caption if it contradicts the user text input. 42 | Do not exceed the 150 word limit! 43 | Output the enhanced prompt only. 44 | """ 45 | 46 | 47 | def tensor_to_pil(tensor): 48 | # Ensure tensor is in range [-1, 1] 49 | assert tensor.min() >= -1 and tensor.max() <= 1 50 | 51 | # Convert from [-1, 1] to [0, 1] 52 | tensor = (tensor + 1) / 2 53 | 54 | # Rearrange from [C, H, W] to [H, W, C] 55 | tensor = tensor.permute(1, 2, 0) 56 | 57 | # Convert to numpy array and then to uint8 range [0, 255] 58 | numpy_image = (tensor.cpu().numpy() * 255).astype("uint8") 59 | 60 | # Convert to PIL Image 61 | return Image.fromarray(numpy_image) 62 | 63 | 64 | def generate_cinematic_prompt( 65 | image_caption_model, 66 | image_caption_processor, 67 | prompt_enhancer_model, 68 | prompt_enhancer_tokenizer, 69 | prompt: Union[str, List[str]], 70 | conditioning_items: Optional[List] = None, 71 | max_new_tokens: int = 256, 72 | ) -> List[str]: 73 | prompts = [prompt] if isinstance(prompt, str) else prompt 74 | 75 | if conditioning_items is None: 76 | prompts = _generate_t2v_prompt( 77 | prompt_enhancer_model, 78 | prompt_enhancer_tokenizer, 79 | prompts, 80 | max_new_tokens, 81 | T2V_CINEMATIC_PROMPT, 82 | ) 83 | else: 84 | if len(conditioning_items) > 1 or conditioning_items[0].media_frame_number != 0: 85 | logger.warning( 86 | "prompt enhancement does only support unconditional or first frame of conditioning items, returning original prompts" 87 | ) 88 | return prompts 89 | 90 | first_frame_conditioning_item = conditioning_items[0] 91 | first_frames = _get_first_frames_from_conditioning_item( 92 | first_frame_conditioning_item 93 | ) 94 | 95 | assert len(first_frames) == len( 96 | prompts 97 | ), "Number of conditioning frames must match number of prompts" 98 | 99 | prompts = _generate_i2v_prompt( 100 | image_caption_model, 101 | image_caption_processor, 102 | prompt_enhancer_model, 103 | prompt_enhancer_tokenizer, 104 | prompts, 105 | first_frames, 106 | max_new_tokens, 107 | I2V_CINEMATIC_PROMPT, 108 | ) 109 | 110 | return prompts 111 | 112 | 113 | def _get_first_frames_from_conditioning_item(conditioning_item) -> List[Image.Image]: 114 | frames_tensor = conditioning_item.media_item 115 | return [ 116 | tensor_to_pil(frames_tensor[i, :, 0, :, :]) 117 | for i in range(frames_tensor.shape[0]) 118 | ] 119 | 120 | 121 | def _generate_t2v_prompt( 122 | prompt_enhancer_model, 123 | prompt_enhancer_tokenizer, 124 | prompts: List[str], 125 | max_new_tokens: int, 126 | system_prompt: str, 127 | ) -> List[str]: 128 | messages = [ 129 | [ 130 | {"role": "system", "content": system_prompt}, 131 | {"role": "user", "content": f"user_prompt: {p}"}, 132 | ] 133 | for p in prompts 134 | ] 135 | 136 | texts = [ 137 | prompt_enhancer_tokenizer.apply_chat_template( 138 | m, tokenize=False, add_generation_prompt=True 139 | ) 140 | for m in messages 141 | ] 142 | model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to( 143 | prompt_enhancer_model.device 144 | ) 145 | 146 | return _generate_and_decode_prompts( 147 | prompt_enhancer_model, prompt_enhancer_tokenizer, 
model_inputs, max_new_tokens 148 | ) 149 | 150 | 151 | def _generate_i2v_prompt( 152 | image_caption_model, 153 | image_caption_processor, 154 | prompt_enhancer_model, 155 | prompt_enhancer_tokenizer, 156 | prompts: List[str], 157 | first_frames: List[Image.Image], 158 | max_new_tokens: int, 159 | system_prompt: str, 160 | ) -> List[str]: 161 | image_captions = _generate_image_captions( 162 | image_caption_model, image_caption_processor, first_frames 163 | ) 164 | 165 | messages = [ 166 | [ 167 | {"role": "system", "content": system_prompt}, 168 | {"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"}, 169 | ] 170 | for p, c in zip(prompts, image_captions) 171 | ] 172 | 173 | texts = [ 174 | prompt_enhancer_tokenizer.apply_chat_template( 175 | m, tokenize=False, add_generation_prompt=True 176 | ) 177 | for m in messages 178 | ] 179 | model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to( 180 | prompt_enhancer_model.device 181 | ) 182 | 183 | return _generate_and_decode_prompts( 184 | prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens 185 | ) 186 | 187 | 188 | def _generate_image_captions( 189 | image_caption_model, 190 | image_caption_processor, 191 | images: List[Image.Image], 192 | system_prompt: str = "", 193 | ) -> List[str]: 194 | image_caption_prompts = [system_prompt] * len(images) 195 | inputs = image_caption_processor( 196 | image_caption_prompts, images, return_tensors="pt" 197 | ).to(image_caption_model.device) 198 | 199 | with torch.inference_mode(): 200 | generated_ids = image_caption_model.generate( 201 | input_ids=inputs["input_ids"], 202 | pixel_values=inputs["pixel_values"], 203 | max_new_tokens=1024, 204 | do_sample=False, 205 | num_beams=3, 206 | ) 207 | 208 | return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True) 209 | 210 | 211 | def _generate_and_decode_prompts( 212 | prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int 213 | ) -> List[str]: 214 | with torch.inference_mode(): 215 | outputs = prompt_enhancer_model.generate( 216 | **model_inputs, max_new_tokens=max_new_tokens 217 | ) 218 | generated_ids = [ 219 | output_ids[len(input_ids) :] 220 | for input_ids, output_ids in zip(model_inputs.input_ids, outputs) 221 | ] 222 | decoded_prompts = prompt_enhancer_tokenizer.batch_decode( 223 | generated_ids, skip_special_tokens=True 224 | ) 225 | 226 | return decoded_prompts 227 | -------------------------------------------------------------------------------- /ltx_video/utils/skip_layer_strategy.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, auto 2 | 3 | 4 | class SkipLayerStrategy(Enum): 5 | AttentionSkip = auto() 6 | AttentionValues = auto() 7 | Residual = auto() 8 | TransformerBlock = auto() 9 | -------------------------------------------------------------------------------- /ltx_video/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 7 | dims_to_append = target_dims - x.ndim 8 | if dims_to_append < 0: 9 | raise ValueError( 10 | f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" 11 | ) 12 | elif dims_to_append == 0: 13 | return x 14 | return x[(...,) + (None,) * dims_to_append] 15 | 16 | 17 | class 
Identity(nn.Module): 18 | """A placeholder identity operator that is argument-insensitive.""" 19 | 20 | def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument 21 | super().__init__() 22 | 23 | # pylint: disable=unused-argument 24 | def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: 25 | return x 26 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "ltx-video" 7 | version = "0.1.2" 8 | description = "A package for LTX-Video model" 9 | authors = [ 10 | { name = "Sapir Weissbuch", email = "sapir@lightricks.com" } 11 | ] 12 | requires-python = ">=3.10" 13 | readme = "README.md" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "Operating System :: OS Independent" 17 | ] 18 | dependencies = [ 19 | "torch>=2.1.0", 20 | "diffusers>=0.28.2", 21 | "transformers>=4.47.2", 22 | "sentencepiece>=0.1.96", 23 | "huggingface-hub~=0.30", 24 | "einops", 25 | "timm" 26 | ] 27 | 28 | [project.optional-dependencies] 29 | # Instead of thinking of them as optional, think of them as specific modes 30 | inference-script = [ 31 | "accelerate", 32 | "matplotlib", 33 | "imageio[ffmpeg]", 34 | "av", 35 | "opencv-python" 36 | ] 37 | test = [ 38 | "pytest", 39 | ] 40 | 41 | [tool.setuptools.packages.find] 42 | include = ["ltx_video*"] -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import shutil 3 | import os 4 | 5 | 6 | @pytest.fixture(scope="session") 7 | def test_paths(request, pytestconfig): 8 | try: 9 | output_path = "output" 10 | ckpt_path = request.param # This will get the current parameterized item 11 | text_encoder_model_name_or_path = pytestconfig.getoption( 12 | "text_encoder_model_name_or_path" 13 | ) 14 | input_image_path = pytestconfig.getoption("input_image_path") 15 | input_video_path = pytestconfig.getoption("input_video_path") 16 | prompt_enhancer_image_caption_model_name_or_path = pytestconfig.getoption( 17 | "prompt_enhancer_image_caption_model_name_or_path" 18 | ) 19 | prompt_enhancer_llm_model_name_or_path = pytestconfig.getoption( 20 | "prompt_enhancer_llm_model_name_or_path" 21 | ) 22 | prompt_enhancement_words_threshold = pytestconfig.getoption( 23 | "prompt_enhancement_words_threshold" 24 | ) 25 | 26 | config = { 27 | "ckpt_path": ckpt_path, 28 | "input_image_path": input_image_path, 29 | "input_video_path": input_video_path, 30 | "output_path": output_path, 31 | "text_encoder_model_name_or_path": text_encoder_model_name_or_path, 32 | "prompt_enhancer_image_caption_model_name_or_path": prompt_enhancer_image_caption_model_name_or_path, 33 | "prompt_enhancer_llm_model_name_or_path": prompt_enhancer_llm_model_name_or_path, 34 | "prompt_enhancement_words_threshold": prompt_enhancement_words_threshold, 35 | } 36 | 37 | yield config 38 | 39 | finally: 40 | if os.path.exists(output_path): 41 | shutil.rmtree(output_path) 42 | 43 | 44 | def pytest_generate_tests(metafunc): 45 | if "test_paths" in metafunc.fixturenames: 46 | ckpt_paths = metafunc.config.getoption("ckpt_path") 47 | metafunc.parametrize("test_paths", ckpt_paths, indirect=True) 48 | 49 | 50 | def pytest_addoption(parser): 51 | parser.addoption( 52 | "--ckpt_path", 
53 | action="append", 54 | default=[], 55 | help="Path to checkpoint files (can specify multiple)", 56 | ) 57 | parser.addoption( 58 | "--text_encoder_model_name_or_path", 59 | action="store", 60 | default="PixArt-alpha/PixArt-XL-2-1024-MS", 61 | help="Path to the checkpoint file", 62 | ) 63 | parser.addoption( 64 | "--input_image_path", 65 | action="store", 66 | default="tests/utils/woman.jpeg", 67 | help="Path to input image file.", 68 | ) 69 | parser.addoption( 70 | "--input_video_path", 71 | action="store", 72 | default="tests/utils/woman.mp4", 73 | help="Path to input video file.", 74 | ) 75 | parser.addoption( 76 | "--prompt_enhancer_image_caption_model_name_or_path", 77 | action="store", 78 | default="MiaoshouAI/Florence-2-large-PromptGen-v2.0", 79 | help="Path to prompt_enhancer_image_caption_model.", 80 | ) 81 | parser.addoption( 82 | "--prompt_enhancer_llm_model_name_or_path", 83 | action="store", 84 | default="unsloth/Llama-3.2-3B-Instruct", 85 | help="Path to LLM model for prompt enhancement.", 86 | ) 87 | parser.addoption( 88 | "--prompt_enhancement_words_threshold", 89 | type=int, 90 | default=50, 91 | help="Enable prompt enhancement only if input prompt has fewer words than this threshold. Set to 0 to disable enhancement completely.", 92 | ) 93 | -------------------------------------------------------------------------------- /tests/test_inference.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import yaml 4 | from inference import infer, create_ltx_video_pipeline 5 | from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy 6 | 7 | 8 | def pytest_make_parametrize_id(config, val, argname): 9 | if isinstance(val, str): 10 | return f"{argname}-{val}" 11 | return f"{argname}-{repr(val)}" 12 | 13 | 14 | @pytest.mark.parametrize( 15 | "conditioning_test_mode", 16 | ["unconditional", "first-frame", "first-sequence", "sequence-and-frame"], 17 | ids=lambda x: f"conditioning_test_mode={x}", 18 | ) 19 | def test_infer_runs_on_real_path(tmp_path, test_paths, conditioning_test_mode): 20 | conditioning_params = {} 21 | if conditioning_test_mode == "unconditional": 22 | pass 23 | elif conditioning_test_mode == "first-frame": 24 | conditioning_params["conditioning_media_paths"] = [ 25 | test_paths["input_image_path"] 26 | ] 27 | conditioning_params["conditioning_start_frames"] = [0] 28 | elif conditioning_test_mode == "first-sequence": 29 | conditioning_params["conditioning_media_paths"] = [ 30 | test_paths["input_video_path"] 31 | ] 32 | conditioning_params["conditioning_start_frames"] = [0] 33 | elif conditioning_test_mode == "sequence-and-frame": 34 | conditioning_params["conditioning_media_paths"] = [ 35 | test_paths["input_video_path"], 36 | test_paths["input_image_path"], 37 | ] 38 | conditioning_params["conditioning_start_frames"] = [16, 67] 39 | else: 40 | raise ValueError(f"Unknown conditioning mode: {conditioning_test_mode}") 41 | test_paths = { 42 | k: v 43 | for k, v in test_paths.items() 44 | if k not in ["input_image_path", "input_video_path"] 45 | } 46 | 47 | params = { 48 | "seed": 42, 49 | "num_inference_steps": 1, 50 | "height": 512, 51 | "width": 768, 52 | "num_frames": 121, 53 | "frame_rate": 25, 54 | "prompt": "A young woman with wavy, shoulder-length light brown hair stands outdoors on a foggy day. She wears a cozy pink turtleneck sweater, with a serene expression and piercing blue eyes. 
A wooden fence and a misty, grassy field fade into the background, evoking a calm and introspective mood.", 55 | "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted", 56 | "offload_to_cpu": False, 57 | "output_path": tmp_path, 58 | "image_cond_noise_scale": 0.15, 59 | } 60 | 61 | config = { 62 | "pipeline_type": "base", 63 | "num_images_per_prompt": 1, 64 | "guidance_scale": 2.5, 65 | "stg_scale": 1, 66 | "stg_rescale": 0.7, 67 | "stg_mode": "attention_values", 68 | "stg_skip_layers": "1,2,3", 69 | "precision": "bfloat16", 70 | "decode_timestep": 0.05, 71 | "decode_noise_scale": 0.025, 72 | "checkpoint_path": test_paths["ckpt_path"], 73 | "text_encoder_model_name_or_path": test_paths[ 74 | "text_encoder_model_name_or_path" 75 | ], 76 | "prompt_enhancer_image_caption_model_name_or_path": test_paths[ 77 | "prompt_enhancer_image_caption_model_name_or_path" 78 | ], 79 | "prompt_enhancer_llm_model_name_or_path": test_paths[ 80 | "prompt_enhancer_llm_model_name_or_path" 81 | ], 82 | "prompt_enhancement_words_threshold": 120, 83 | "stochastic_sampling": False, 84 | "sampler": "from_checkpoint", 85 | } 86 | 87 | temp_config_path = tmp_path / "config.yaml" 88 | with open(temp_config_path, "w") as f: 89 | yaml.dump(config, f) 90 | 91 | infer(**{**conditioning_params, **params, "pipeline_config": temp_config_path}) 92 | 93 | 94 | def test_vid2vid(tmp_path, test_paths): 95 | params = { 96 | "seed": 42, 97 | "image_cond_noise_scale": 0.15, 98 | "height": 512, 99 | "width": 768, 100 | "num_frames": 25, 101 | "frame_rate": 25, 102 | "prompt": "A young woman with wavy, shoulder-length light brown hair stands outdoors on a foggy day. She wears a cozy pink turtleneck sweater, with a serene expression and piercing blue eyes. A wooden fence and a misty, grassy field fade into the background, evoking a calm and introspective mood.", 103 | "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted", 104 | "offload_to_cpu": False, 105 | "input_media_path": test_paths["input_video_path"], 106 | } 107 | 108 | config = { 109 | "num_inference_steps": 3, 110 | "skip_initial_inference_steps": 1, 111 | "guidance_scale": 2.5, 112 | "stg_scale": 1, 113 | "stg_rescale": 0.7, 114 | "stg_mode": "attention_values", 115 | "stg_skip_layers": "1,2,3", 116 | "precision": "bfloat16", 117 | "decode_timestep": 0.05, 118 | "decode_noise_scale": 0.025, 119 | "sampler": "from_checkpoint", 120 | "checkpoint_path": test_paths["ckpt_path"], 121 | "text_encoder_model_name_or_path": test_paths[ 122 | "text_encoder_model_name_or_path" 123 | ], 124 | "prompt_enhancer_image_caption_model_name_or_path": test_paths[ 125 | "prompt_enhancer_image_caption_model_name_or_path" 126 | ], 127 | "prompt_enhancer_llm_model_name_or_path": test_paths[ 128 | "prompt_enhancer_llm_model_name_or_path" 129 | ], 130 | "prompt_enhancement_words_threshold": 120, 131 | } 132 | test_paths = { 133 | k: v 134 | for k, v in test_paths.items() 135 | if k not in ["input_image_path", "input_video_path"] 136 | } 137 | temp_config_path = tmp_path / "config.yaml" 138 | with open(temp_config_path, "w") as f: 139 | yaml.dump(config, f) 140 | 141 | infer(**{**test_paths, **params, "pipeline_config": temp_config_path}) 142 | 143 | 144 | def get_device(): 145 | if torch.cuda.is_available(): 146 | return "cuda" 147 | elif torch.backends.mps.is_available(): 148 | return "mps" 149 | return "cpu" 150 | 151 | 152 | def test_pipeline_on_batch(tmp_path, test_paths): 153 | device = get_device() 154 | pipeline = 
create_ltx_video_pipeline( 155 | ckpt_path=test_paths["ckpt_path"], 156 | device=device, 157 | precision="bfloat16", 158 | text_encoder_model_name_or_path=test_paths["text_encoder_model_name_or_path"], 159 | enhance_prompt=False, 160 | prompt_enhancer_image_caption_model_name_or_path=test_paths[ 161 | "prompt_enhancer_image_caption_model_name_or_path" 162 | ], 163 | prompt_enhancer_llm_model_name_or_path=test_paths[ 164 | "prompt_enhancer_llm_model_name_or_path" 165 | ], 166 | ) 167 | 168 | params = { 169 | "seed": 42, 170 | "image_cond_noise_scale": 0.15, 171 | "height": 512, 172 | "width": 768, 173 | "num_frames": 1, 174 | "frame_rate": 25, 175 | "offload_to_cpu": False, 176 | "output_type": "pt", 177 | "is_video": False, 178 | "vae_per_channel_normalize": True, 179 | "mixed_precision": False, 180 | } 181 | 182 | config = { 183 | "num_inference_steps": 2, 184 | "guidance_scale": 2.5, 185 | "stg_scale": 1, 186 | "rescaling_scale": 0.7, 187 | "skip_block_list": [1, 2], 188 | "decode_timestep": 0.05, 189 | "decode_noise_scale": 0.025, 190 | } 191 | 192 | temp_config_path = tmp_path / "config.yaml" 193 | with open(temp_config_path, "w") as f: 194 | yaml.dump(config, f) 195 | 196 | first_prompt = "A vintage yellow car drives along a wet mountain road, its rear wheels kicking up a light spray as it moves. The camera follows close behind, capturing the curvature of the road as it winds through rocky cliffs and lush green hills. The sunlight pierces through scattered clouds, reflecting off the car's rain-speckled surface, creating a dynamic, cinematic moment. The scene conveys a sense of freedom and exploration as the car disappears into the distance." 197 | second_prompt = "A woman with blonde hair styled up, wearing a black dress with sequins and pearl earrings, looks down with a sad expression on her face. The camera remains stationary, focused on the woman's face. The lighting is dim, casting soft shadows on her face. The scene appears to be from a movie or TV show." 
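# The comparison below runs the pipeline twice with the same seed and the same
# per-sample generators: once on [first_prompt, second_prompt] and once on
# [second_prompt, second_prompt]. The second sample of each batch should match
# (within torch.allclose tolerance), i.e. results must not leak across batch entries.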
198 | 199 | sample = { 200 | "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted", 201 | "prompt_attention_mask": None, 202 | "negative_prompt_attention_mask": None, 203 | "media_items": None, 204 | } 205 | 206 | def get_images(prompts): 207 | generators = [ 208 | torch.Generator(device=device).manual_seed(params["seed"]) for _ in range(2) 209 | ] 210 | torch.manual_seed(params["seed"]) 211 | 212 | images = pipeline( 213 | prompt=prompts, 214 | generator=generators, 215 | **sample, 216 | **params, 217 | pipeline_config=temp_config_path, 218 | ).images 219 | return images 220 | 221 | batch_diff_images = get_images([first_prompt, second_prompt]) 222 | batch_same_images = get_images([second_prompt, second_prompt]) 223 | 224 | # Take the second image from both runs 225 | image2_not_same = batch_diff_images[1, :, 0, :, :] 226 | image2_same = batch_same_images[1, :, 0, :, :] 227 | 228 | # Compute mean absolute difference, should be 0 229 | mad = torch.mean(torch.abs(image2_not_same - image2_same)).item() 230 | print(f"Mean absolute difference: {mad}") 231 | 232 | assert torch.allclose(image2_not_same, image2_same) 233 | 234 | 235 | def test_prompt_enhancement(tmp_path, test_paths, monkeypatch): 236 | # Create pipeline with prompt enhancement enabled 237 | device = get_device() 238 | pipeline = create_ltx_video_pipeline( 239 | ckpt_path=test_paths["ckpt_path"], 240 | device=device, 241 | precision="bfloat16", 242 | text_encoder_model_name_or_path=test_paths["text_encoder_model_name_or_path"], 243 | enhance_prompt=True, 244 | prompt_enhancer_image_caption_model_name_or_path=test_paths[ 245 | "prompt_enhancer_image_caption_model_name_or_path" 246 | ], 247 | prompt_enhancer_llm_model_name_or_path=test_paths[ 248 | "prompt_enhancer_llm_model_name_or_path" 249 | ], 250 | ) 251 | 252 | original_prompt = "A cat sitting on a windowsill" 253 | 254 | # Mock the pipeline's _encode_prompt method to verify the prompt being used 255 | original_encode_prompt = pipeline.encode_prompt 256 | 257 | prompts_used = [] 258 | 259 | def mock_encode_prompt(prompt, *args, **kwargs): 260 | prompts_used.append(prompt[0] if isinstance(prompt, list) else prompt) 261 | return original_encode_prompt(prompt, *args, **kwargs) 262 | 263 | pipeline.encode_prompt = mock_encode_prompt 264 | 265 | # Set up minimal parameters for a quick test 266 | params = { 267 | "seed": 42, 268 | "image_cond_noise_scale": 0.15, 269 | "height": 512, 270 | "width": 768, 271 | "skip_layer_strategy": SkipLayerStrategy.AttentionValues, 272 | "num_frames": 1, 273 | "frame_rate": 25, 274 | "offload_to_cpu": False, 275 | "output_type": "pt", 276 | "is_video": False, 277 | "vae_per_channel_normalize": True, 278 | "mixed_precision": False, 279 | } 280 | 281 | config = { 282 | "pipeline_type": "base", 283 | "num_inference_steps": 1, 284 | "guidance_scale": 2.5, 285 | "stg_scale": 1, 286 | "rescaling_scale": 0.7, 287 | "skip_block_list": [1, 2], 288 | "decode_timestep": 0.05, 289 | "decode_noise_scale": 0.025, 290 | } 291 | 292 | temp_config_path = tmp_path / "config.yaml" 293 | with open(temp_config_path, "w") as f: 294 | yaml.dump(config, f) 295 | 296 | # Run pipeline with prompt enhancement enabled 297 | _ = pipeline( 298 | prompt=original_prompt, 299 | negative_prompt="worst quality", 300 | enhance_prompt=True, 301 | **params, 302 | pipeline_config=temp_config_path, 303 | ) 304 | 305 | # Verify that the enhanced prompt was used 306 | assert len(prompts_used) > 0 307 | assert ( 308 | prompts_used[0] != original_prompt 309 | ), 
f"Expected enhanced prompt to be different from original prompt, but got: {original_prompt}" 310 | 311 | # Run pipeline with prompt enhancement disabled 312 | prompts_used.clear() 313 | _ = pipeline( 314 | prompt=original_prompt, 315 | negative_prompt="worst quality", 316 | enhance_prompt=False, 317 | **params, 318 | pipeline_config=temp_config_path, 319 | ) 320 | 321 | # Verify that the original prompt was used 322 | assert len(prompts_used) > 0 323 | assert ( 324 | prompts_used[0] == original_prompt 325 | ), f"Expected original prompt to be used, but got: {prompts_used[0]}" 326 | -------------------------------------------------------------------------------- /tests/test_scheduler.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from ltx_video.schedulers.rf import RectifiedFlowScheduler 4 | 5 | 6 | def init_latents_and_scheduler(sampler): 7 | batch_size, n_tokens, n_channels = 2, 4096, 128 8 | num_steps = 20 9 | scheduler = RectifiedFlowScheduler( 10 | sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic") 11 | ) 12 | latents = torch.randn(size=(batch_size, n_tokens, n_channels)) 13 | scheduler.set_timesteps(num_inference_steps=num_steps, samples_shape=latents.shape) 14 | return scheduler, latents 15 | 16 | 17 | @pytest.mark.parametrize("sampler", ["LinearQuadratic", "Uniform"]) 18 | def test_scheduler_default_behavior(sampler): 19 | """ 20 | Test the case of a single timestep from the list of timesteps. 21 | """ 22 | scheduler, latents = init_latents_and_scheduler(sampler) 23 | 24 | for i, t in enumerate(scheduler.timesteps): 25 | noise_pred = torch.randn_like(latents) 26 | denoised_latents = scheduler.step( 27 | noise_pred, 28 | t, 29 | latents, 30 | return_dict=False, 31 | )[0] 32 | 33 | # Verify the denoising 34 | next_t = scheduler.timesteps[i + 1] if i < len(scheduler.timesteps) - 1 else 0.0 35 | dt = t - next_t 36 | expected_denoised_latents = latents - dt * noise_pred 37 | assert torch.allclose(denoised_latents, expected_denoised_latents, atol=1e-06) 38 | 39 | 40 | @pytest.mark.parametrize("sampler", ["LinearQuadratic", "Uniform"]) 41 | def test_scheduler_per_token(sampler): 42 | """ 43 | Test the case of a timestep per token (from the list of timesteps). 44 | Some tokens are set with timestep of 0. 45 | """ 46 | scheduler, latents = init_latents_and_scheduler(sampler) 47 | batch_size, n_tokens = latents.shape[:2] 48 | for i, t in enumerate(scheduler.timesteps): 49 | timesteps = torch.full((batch_size, n_tokens), t) 50 | timesteps[:, 0] = 0.0 51 | noise_pred = torch.randn_like(latents) 52 | denoised_latents = scheduler.step( 53 | noise_pred, 54 | timesteps, 55 | latents, 56 | return_dict=False, 57 | )[0] 58 | 59 | # Verify the denoising 60 | next_t = scheduler.timesteps[i + 1] if i < len(scheduler.timesteps) - 1 else 0.0 61 | next_timesteps = torch.full((batch_size, n_tokens), next_t) 62 | dt = timesteps - next_timesteps 63 | expected_denoised_latents = latents - dt.unsqueeze(-1) * noise_pred 64 | assert torch.allclose( 65 | denoised_latents[:, 1:], expected_denoised_latents[:, 1:], atol=1e-06 66 | ) 67 | assert torch.allclose(denoised_latents[:, 0], latents[:, 0], atol=1e-06) 68 | 69 | 70 | @pytest.mark.parametrize("sampler", ["LinearQuadratic", "Uniform"]) 71 | def test_scheduler_t_not_in_list(sampler): 72 | """ 73 | Test the case of a timestep per token NOT from the list of timesteps. 
74 | """ 75 | scheduler, latents = init_latents_and_scheduler(sampler) 76 | batch_size, n_tokens = latents.shape[:2] 77 | for i in range(len(scheduler.timesteps)): 78 | if i < len(scheduler.timesteps) - 1: 79 | t = (scheduler.timesteps[i] + scheduler.timesteps[i + 1]) / 2 80 | else: 81 | t = scheduler.timesteps[i] / 2 82 | timesteps = torch.full((batch_size, n_tokens), t) 83 | noise_pred = torch.randn_like(latents) 84 | denoised_latents = scheduler.step( 85 | noise_pred, 86 | timesteps, 87 | latents, 88 | return_dict=False, 89 | )[0] 90 | 91 | # Verify the denoising 92 | next_t = scheduler.timesteps[i + 1] if i < len(scheduler.timesteps) - 1 else 0.0 93 | next_timesteps = torch.full((batch_size, n_tokens), next_t) 94 | dt = timesteps - next_timesteps 95 | expected_denoised_latents = latents - dt.unsqueeze(-1) * noise_pred 96 | assert torch.allclose(denoised_latents, expected_denoised_latents, atol=1e-06) 97 | -------------------------------------------------------------------------------- /tests/test_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ltx_video.models.autoencoders.causal_video_autoencoder import ( 4 | CausalVideoAutoencoder, 5 | create_video_autoencoder_demo_config, 6 | ) 7 | 8 | 9 | def test_vae(): 10 | # create vae and run a forward and backward pass 11 | config = create_video_autoencoder_demo_config(latent_channels=16) 12 | video_autoencoder = CausalVideoAutoencoder.from_config(config) 13 | video_autoencoder.eval() 14 | input_videos = torch.randn(2, 3, 17, 64, 64) 15 | latent = video_autoencoder.encode(input_videos).latent_dist.mode() 16 | assert latent.shape == (2, 16, 3, 2, 2) 17 | timestep = torch.ones(input_videos.shape[0]) * 0.1 18 | reconstructed_videos = video_autoencoder.decode( 19 | latent, target_shape=input_videos.shape, timestep=timestep 20 | ).sample 21 | assert input_videos.shape == reconstructed_videos.shape 22 | loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) 23 | loss.backward() 24 | 25 | # validate temporal causality in encoder 26 | input_image = input_videos[:, :, :1, :, :] 27 | image_latent = video_autoencoder.encode(input_image).latent_dist.mode() 28 | assert torch.allclose(image_latent, latent[:, :, :1, :, :], atol=1e-6) 29 | 30 | input_sequence = input_videos[:, :, :9, :, :] 31 | sequence_latent = video_autoencoder.encode(input_sequence).latent_dist.mode() 32 | assert torch.allclose(sequence_latent, latent[:, :, :2, :, :], atol=1e-6) 33 | -------------------------------------------------------------------------------- /tests/utils/.gitattributes: -------------------------------------------------------------------------------- 1 | *.mp4 filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /tests/utils/woman.jpeg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:acf0a6de9f7fb5551702fa8065ab423811e2011c36338f49636f84b033715df8 3 | size 33249 4 | -------------------------------------------------------------------------------- /tests/utils/woman.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:833c2a0afe03ad60cb0c6361433eb11acf2ed58890f6b411cd306c01069e199f 3 | size 73228 4 | --------------------------------------------------------------------------------