├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── TeaCache4CogVideoX1.5 ├── README.md └── teacache_sample_video.py ├── TeaCache4ConsisID ├── README.md └── teacache_sample_video.py ├── TeaCache4Cosmos ├── README.md ├── teacache_sample_video_i2v.py └── teacache_sample_video_t2v.py ├── TeaCache4FLUX ├── README.md └── teacache_flux.py ├── TeaCache4HiDream-I1 ├── README.md └── teacache_hidream_i1.py ├── TeaCache4HunyuanVideo ├── README.md └── teacache_sample_video.py ├── TeaCache4LTX-Video ├── README.md └── teacache_ltx.py ├── TeaCache4Lumina-T2X ├── README.md └── teacache_lumina_next.py ├── TeaCache4Lumina2 ├── README.md └── teacache_lumina2.py ├── TeaCache4Mochi ├── README.md └── teacache_mochi.py ├── TeaCache4TangoFlux ├── README.md └── teacache_tango_flux.py ├── TeaCache4Wan2.1 ├── README.md └── teacache_generate.py ├── assets ├── TeaCache4FLUX.png ├── TeaCache4HiDream-I1.png ├── TeaCache4LuminaT2X.png └── tisser.png ├── eval └── teacache │ ├── README.md │ ├── common_metrics │ ├── README.md │ ├── __init__.py │ ├── batch_eval.py │ ├── calculate_lpips.py │ ├── calculate_psnr.py │ ├── calculate_ssim.py │ └── eval.py │ ├── experiments │ ├── __init__.py │ ├── cogvideox.py │ ├── latte.py │ ├── opensora.py │ ├── opensora_plan.py │ └── utils.py │ └── vbench │ ├── VBench_full_info.json │ ├── cal_vbench.py │ └── run_vbench.py ├── requirements.txt ├── setup.py └── videosys ├── __init__.py ├── core ├── __init__.py ├── comm.py ├── engine.py ├── mp_utils.py ├── pab_mgr.py ├── parallel_mgr.py ├── pipeline.py └── shardformer │ ├── __init__.py │ └── t5 │ ├── __init__.py │ ├── modeling.py │ └── policy.py ├── models ├── __init__.py ├── autoencoders │ ├── __init__.py │ ├── autoencoder_kl_cogvideox.py │ ├── autoencoder_kl_open_sora.py │ ├── autoencoder_kl_open_sora_plan.py │ ├── autoencoder_kl_open_sora_plan_v110.py │ └── autoencoder_kl_open_sora_plan_v120.py ├── modules │ ├── __init__.py │ ├── activations.py │ ├── attentions.py │ ├── downsampling.py │ ├── embeddings.py │ ├── normalization.py │ └── upsampling.py └── transformers │ ├── __init__.py │ ├── cogvideox_transformer_3d.py │ ├── latte_transformer_3d.py │ ├── open_sora_plan_transformer_3d.py │ ├── open_sora_plan_v110_transformer_3d.py │ ├── open_sora_plan_v120_transformer_3d.py │ ├── open_sora_transformer_3d.py │ └── vchitect_transformer_3d.py ├── pipelines ├── __init__.py ├── cogvideox │ ├── __init__.py │ └── pipeline_cogvideox.py ├── latte │ ├── __init__.py │ └── pipeline_latte.py ├── open_sora │ ├── __init__.py │ ├── data_process.py │ └── pipeline_open_sora.py ├── open_sora_plan │ ├── __init__.py │ └── pipeline_open_sora_plan.py └── vchitect │ ├── __init__.py │ └── pipeline_vchitect.py ├── schedulers ├── __init__.py ├── scheduling_ddim_cogvideox.py ├── scheduling_dpm_cogvideox.py └── scheduling_rflow_open_sora.py └── utils ├── __init__.py ├── logging.py ├── test.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | outputs/ 2 | processed/ 3 | profile/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files 
are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | docs/.build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # IDE 137 | .idea/ 138 | .vscode/ 139 | 140 | # macos 141 | *.DS_Store 142 | #data/ 143 | 144 | docs/.build 145 | 146 | # pytorch checkpoint 147 | *.pt 148 | 149 | # ignore any kernel build files 150 | .o 151 | .so 152 | 153 | # ignore python interface defition file 154 | .pyi 155 | 156 | # ignore coverage test file 157 | coverage.lcov 158 | coverage.xml 159 | 160 | # ignore testmon and coverage files 161 | .coverage 162 | .testmondata* 163 | 164 | pretrained 165 | samples 166 | cache_dir 167 | test_outputs 168 | datasets 169 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | line_length = 120 3 | multi_line_output=3 4 | include_trailing_comma = true 5 | ignore_comments = true 6 | profile = black 7 | honor_noqa = true 8 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | 3 | - repo: https://github.com/PyCQA/autoflake 4 | rev: v2.2.1 5 | hooks: 6 | - id: autoflake 7 | name: autoflake (python) 8 | args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports'] 9 | 10 | - repo: https://github.com/pycqa/isort 11 | rev: 5.12.0 12 | hooks: 13 | - id: isort 14 | name: sort all imports (python) 15 | 16 | - repo: https://github.com/psf/black-pre-commit-mirror 17 | rev: 23.9.1 18 | hooks: 19 | - id: black 20 | name: 
black formatter 21 | args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310'] 22 | 23 | - repo: https://github.com/pre-commit/mirrors-clang-format 24 | rev: v13.0.1 25 | hooks: 26 | - id: clang-format 27 | name: clang formatter 28 | types_or: [c++, c] 29 | 30 | - repo: https://github.com/pre-commit/pre-commit-hooks 31 | rev: v4.3.0 32 | hooks: 33 | - id: check-yaml 34 | - id: check-merge-conflict 35 | - id: check-case-conflict 36 | - id: trailing-whitespace 37 | - id: end-of-file-fixer 38 | - id: mixed-line-ending 39 | args: ['--fix=lf'] 40 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Coding Standards 2 | 3 | ### Unit Tests 4 | We use [PyTest](https://docs.pytest.org/en/latest/) to execute tests. You can install pytest by `pip install pytest`. As some of the tests require initialization of the distributed backend, GPUs are needed to execute these tests. 5 | 6 | To set up the environment for unit testing, first change your current directory to the root directory of your local clone of this repository, then run 7 | ```bash 8 | pip install -r requirements/requirements-test.txt 9 | ``` 10 | If you encounter an error saying "Could not find a version that satisfies the requirement fbgemm-gpu==0.2.0", please downgrade your Python version to 3.8 or 3.9 and try again. 11 | 12 | If you only want to run CPU tests, you can run 13 | 14 | ```bash 15 | pytest -m cpu tests/ 16 | ``` 17 | 18 | If you have 8 GPUs on your machine, you can run the full test suite 19 | 20 | ```bash 21 | pytest tests/ 22 | ``` 23 | 24 | If you do not have 8 GPUs on your machine, do not worry. Unit testing will be automatically conducted when you put up a pull request to the main branch. 25 | 26 | 27 | ### Code Style 28 | 29 | Static checks run when you commit your code changes; please make sure you pass all of them and that your coding style meets our requirements. We use pre-commit hooks to keep the code aligned with the writing standard. To set up the code style checking, follow the steps below. 30 | 31 | ```shell 32 | # these commands are executed under the repository root directory 33 | pip install pre-commit 34 | pre-commit install 35 | ``` 36 | 37 | Code format checking will be automatically executed when you commit your changes. 38 | -------------------------------------------------------------------------------- /TeaCache4CogVideoX1.5/README.md: -------------------------------------------------------------------------------- 1 | 2 | # TeaCache4CogVideoX1.5 3 | 4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [CogVideoX1.5](https://github.com/THUDM/CogVideo) 1.8x without much visual quality degradation, in a training-free manner. The following video shows the results generated by TeaCache-CogVideoX1.5 with various `rel_l1_thresh` values: 0 (original), 0.1 (1.3x speedup), 0.2 (1.8x speedup), and 0.3 (2.1x speedup). Additionally, the image-to-video (i2v) results are also demonstrated, with the following speedups: 0.1 (1.5x speedup), 0.2 (2.2x speedup), and 0.3 (2.7x speedup).
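Across the model-specific scripts in this repository, `rel_l1_thresh` feeds one and the same skip rule: the relative L1 change of the first transformer block's timestep-modulated input is rescaled by a model-specific polynomial, accumulated over steps, and the cached residual is reused until the accumulated value crosses the threshold. The sketch below distills that rule from the `teacache_*.py` scripts further down; the helper name and the `state` dict are illustrative, and the `coefficients` argument stands in for the model-specific values hard-coded in each script.

```python
import numpy as np

def teacache_should_compute(state, modulated_inp, coefficients, rel_l1_thresh):
    """Decide whether this denoising step runs the transformer blocks or reuses the cached residual."""
    if state["cnt"] == 0 or state["cnt"] == state["num_steps"] - 1:
        should_calc = True                                  # always compute the first and last steps
        state["accumulated_rel_l1_distance"] = 0.0
    else:
        rescale_func = np.poly1d(coefficients)              # model-specific polynomial fit
        rel_l1 = ((modulated_inp - state["previous_modulated_input"]).abs().mean()
                  / state["previous_modulated_input"].abs().mean()).item()
        state["accumulated_rel_l1_distance"] += rescale_func(rel_l1)
        if state["accumulated_rel_l1_distance"] < rel_l1_thresh:
            should_calc = False                             # small accumulated change: reuse cached residual
        else:
            should_calc = True                              # budget exceeded: recompute and reset
            state["accumulated_rel_l1_distance"] = 0.0
    state["previous_modulated_input"] = modulated_inp
    state["cnt"] = (state["cnt"] + 1) % state["num_steps"]
    return should_calc
```

When the rule returns `False`, the scripts skip the transformer blocks entirely and add the residual cached from the last fully computed step, which is where the speedup comes from; raising `rel_l1_thresh` lets more steps be skipped at some cost in fidelity.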
5 | 6 | https://github.com/user-attachments/assets/21261b03-71c6-47bf-9769-2a81c8dc452f 7 | 8 | https://github.com/user-attachments/assets/5e98e646-4034-4ae7-9680-a65ecd88dac9 9 | 10 | ## 📈 Inference Latency Comparisons on a Single H100 GPU 11 | 12 | | CogVideoX1.5-t2v | TeaCache (0.1) | TeaCache (0.2) | TeaCache (0.3) | 13 | | :--------------: | :------------: | :------------: | :------------: | 14 | | ~465 s | ~322 s | ~260 s | ~204 s | 15 | 16 | | CogVideoX1.5-i2v | TeaCache (0.1) | TeaCache (0.2) | TeaCache (0.3) | 17 | | :--------------: | :------------: | :------------: | :------------: | 18 | | ~475 s | ~316 s | ~239 s | ~204 s | 19 | 20 | ## Installation 21 | 22 | ```shell 23 | pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece imageio imageio-ffmpeg 24 | ``` 25 | 26 | ## Usage 27 | 28 | You can modify the `rel_l1_thresh` to obtain your desired trade-off between latency and visul quality, and change the `ckpts_path`, `prompt`, `image_path` to customize your identity-preserving video. 29 | 30 | For T2V inference, you can use the following command: 31 | 32 | ```bash 33 | cd TeaCache4CogVideoX1.5 34 | 35 | python3 teacache_sample_video.py \ 36 | --rel_l1_thresh 0.2 \ 37 | --ckpts_path THUDM/CogVideoX1.5-5B \ 38 | --prompt "A clear, turquoise river flows through a rocky canyon, cascading over a small waterfall and forming a pool of water at the bottom. The river is the main focus of the scene, with its clear water reflecting the surrounding trees and rocks. The canyon walls are steep and rocky, with some vegetation growing on them. The trees are mostly pine trees, with their green needles contrasting with the brown and gray rocks. The overall tone of the scene is one of peace and tranquility." \ 39 | --seed 42 \ 40 | --num_inference_steps 50 \ 41 | --output_path ./teacache_results 42 | ``` 43 | 44 | For I2V inference, you can use the following command: 45 | 46 | ```bash 47 | cd TeaCache4CogVideoX1.5 48 | 49 | python3 teacache_sample_video.py \ 50 | --rel_l1_thresh 0.1 \ 51 | --ckpts_path THUDM/CogVideoX1.5-5B-I2V \ 52 | --prompt "A girl gazed at the camera and smiled, her hair drifting in the wind." \ 53 | --seed 42 \ 54 | --num_inference_steps 50 \ 55 | --output_path ./teacache_results \ 56 | --image_path ./image/path \ 57 | ``` 58 | 59 | ## Citation 60 | 61 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry. 62 | 63 | ``` 64 | @article{liu2024timestep, 65 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model}, 66 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang}, 67 | journal={arXiv preprint arXiv:2411.19108}, 68 | year={2024} 69 | } 70 | ``` 71 | 72 | 73 | ## Acknowledgements 74 | 75 | We would like to thank the contributors to the [CogVideoX](https://github.com/THUDM/CogVideo) and [Diffusers](https://github.com/huggingface/diffusers). 76 | -------------------------------------------------------------------------------- /TeaCache4ConsisID/README.md: -------------------------------------------------------------------------------- 1 | 2 | # TeaCache4ConsisID 3 | 4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) 2.1x without much visual quality degradation, in a training-free manner. 
The following video shows the results generated by TeaCache-ConsisID with various `rel_l1_thresh` values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup), and 0.2 (2.7x speedup). 5 | 6 | https://github.com/user-attachments/assets/501d71ef-0e71-4ae9-bceb-51cc18fa33d8 7 | 8 | ## 📈 Inference Latency Comparisons on a Single H100 GPU 9 | 10 | | ConsisID | TeaCache (0.1) | TeaCache (0.15) | TeaCache (0.20) | 11 | | :------: | :------------: | :-------------: | :-------------: | 12 | | ~110 s | ~70 s | ~53 s | ~41 s | 13 | 14 | 15 | ## Usage 16 | 17 | Follow [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) to clone the repo and finish the installation, then you can modify the `rel_l1_thresh` to obtain your desired trade-off between latency and visul quality, and change the `ckpts_path`, `prompt`, `image` to customize your identity-preserving video. 18 | 19 | For single-gpu inference, you can use the following command: 20 | 21 | ```bash 22 | cd TeaCache4ConsisID 23 | 24 | python3 teacache_sample_video.py \ 25 | --rel_l1_thresh 0.1 \ 26 | --ckpts_path BestWishYsh/ConsisID-preview \ 27 | --image "https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/2.png?raw=true" \ 28 | --prompt "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy\'s path, adding depth to the scene. The lighting highlights the boy\'s subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel." \ 29 | --seed 42 \ 30 | --num_infer_steps 50 \ 31 | --output_path ./teacache_results 32 | ``` 33 | 34 | To generate a video with 8 GPUs, you can following [here](https://github.com/PKU-YuanGroup/ConsisID/tree/main/tools). 35 | 36 | ## Resources 37 | 38 | Learn more about ConsisID with the following resources. 39 | - A [video](https://www.youtube.com/watch?v=PhlgC-bI5SQ) demonstrating ConsisID's main features. 40 | - The research paper, [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://hf.co/papers/2411.17440) for more details. 41 | 42 | ## Citation 43 | 44 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry. 45 | 46 | ``` 47 | @article{liu2024timestep, 48 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model}, 49 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang}, 50 | journal={arXiv preprint arXiv:2411.19108}, 51 | year={2024} 52 | } 53 | ``` 54 | 55 | 56 | ## Acknowledgements 57 | 58 | We would like to thank the contributors to the [ConsisID](https://github.com/PKU-YuanGroup/ConsisID). 
59 | -------------------------------------------------------------------------------- /TeaCache4Cosmos/README.md: -------------------------------------------------------------------------------- 1 | 2 | # TeaCache4Cosmos 3 | 4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [Cosmos](https://github.com/NVIDIA/Cosmos) 2.0x without much visual quality degradation, in a training-free manner. The following video shows the results generated by TeaCache-Cosmos with various `rel_l1_thresh` values: 0 (original), 0.3 (1.4x speedup), and 0.4 (2.0x speedup). 5 | 6 | https://github.com/user-attachments/assets/28570179-0f22-42ee-8958-88bb48d209b4 7 | 8 | https://github.com/user-attachments/assets/21341bd4-c0d5-4b5a-8b7d-cb2103897d2c 9 | 10 | ## 📈 Inference Latency Comparisons on a Single H800 GPU 11 | 12 | | Cosmos-t2v | TeaCache (0.3) | TeaCache (0.4) | 13 | | :--------: | :------------: | :------------: | 14 | | ~449 s | ~327 s | ~227 s | 15 | 16 | | Cosmos-i2v | TeaCache (0.3) | TeaCache (0.4) | 17 | | :--------: | :------------: | :------------: | 18 | | ~453 s | ~330 s | ~229 s | 19 | 20 | ## Usage 21 | 22 | Follow [Cosmos](https://github.com/NVIDIA/Cosmos) to clone the repo and finish the installation, then you can modify the `rel_l1_thresh` to obtain your desired trade-off between latency and visual quality, and change the `checkpoint_dir`, `prompt`, `input_image_or_video_path` to customize your video. 23 | 24 | You need to copy the running scripts to the Cosmos project folder: 25 | 26 | ```bash 27 | cd TeaCache4Cosmos 28 | cp *.py /path/to/Cosmos/ 29 | ``` 30 | 31 | For T2V inference, you can use the following command: 32 | 33 | ```bash 34 | cd /path/to/Cosmos/ 35 | 36 | python3 teacache_sample_video_t2v.py \ 37 | --rel_l1_thresh 0.3 \ 38 | --checkpoint_dir checkpoints \ 39 | --diffusion_transformer_dir Cosmos-1.0-Diffusion-7B-Text2World \ 40 | --prompt "Inside the cozy ambiance of a bustling coffee house, a young woman with wavy chestnut hair and wearing a rust-colored cozy sweater stands amid the chatter and clinking of cups. She smiles warmly at the camera, her green eyes glinting with joy and subtle hints of laughter. The camera frames her elegantly, emphasizing the soft glow of the lighting on her smooth, clear skin and the detailed textures of her woolen attire. Her genuine smile is the centerpiece of the shot, showcasing her enjoyment in the quaint café setting, with steaming mugs and blurred patrons in the background." \ 41 | --disable_prompt_upsampler \ 42 | --offload_prompt_upsampler 43 | ``` 44 | 45 | For I2V inference, you can use the following command: 46 | 47 | ```bash 48 | cd /path/to/Cosmos/ 49 | 50 | python3 teacache_sample_video_i2v.py \ 51 | --rel_l1_thresh 0.4 \ 52 | --checkpoint_dir checkpoints \ 53 | --diffusion_transformer_dir Cosmos-1.0-Diffusion-7B-Video2World \ 54 | --prompt "A girl gazed at the camera and smiled, her hair drifting in the wind." \ 55 | --input_image_or_video_path image/path \ 56 | --num_input_frames 1 \ 57 | --disable_prompt_upsampler \ 58 | --offload_prompt_upsampler 59 | ``` 60 | 61 | ## Citation 62 | 63 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry.
64 | 65 | ``` 66 | @article{liu2024timestep, 67 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model}, 68 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang}, 69 | journal={arXiv preprint arXiv:2411.19108}, 70 | year={2024} 71 | } 72 | ``` 73 | 74 | 75 | ## Acknowledgements 76 | 77 | We would like to thank the contributors to the [Cosmos](https://github.com/NVIDIA/Cosmos). 78 | -------------------------------------------------------------------------------- /TeaCache4FLUX/README.md: -------------------------------------------------------------------------------- 1 | 2 | # TeaCache4FLUX 3 | 4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [FLUX](https://github.com/black-forest-labs/flux) 2x without much visual quality degradation, in a training-free manner. The following image shows the results generated by TeaCache-FLUX with various `rel_l1_thresh` values: 0 (original), 0.25 (1.5x speedup), 0.4 (1.8x speedup), 0.6 (2.0x speedup), and 0.8 (2.25x speedup). 5 | 6 | ![visualization](../assets/TeaCache4FLUX.png) 7 | 8 | ## 📈 Inference Latency Comparisons on a Single A800 9 | 10 | 11 | | FLUX.1 [dev] | TeaCache (0.25) | TeaCache (0.4) | TeaCache (0.6) | TeaCache (0.8) | 12 | |:-----------------------:|:----------------------------:|:--------------------:|:---------------------:|:---------------------:| 13 | | ~18 s | ~12 s | ~10 s | ~9 s | ~8 s | 14 | 15 | ## Installation 16 | 17 | ```shell 18 | pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece 19 | ``` 20 | 21 | ## Usage 22 | 23 | You can modify the `rel_l1_thresh` in line 320 to obtain your desired trade-off between latency and visul quality. For single-gpu inference, you can use the following command: 24 | 25 | ```bash 26 | python teacache_flux.py 27 | ``` 28 | 29 | ## Citation 30 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry. 31 | 32 | ``` 33 | @article{liu2024timestep, 34 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model}, 35 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang}, 36 | journal={arXiv preprint arXiv:2411.19108}, 37 | year={2024} 38 | } 39 | ``` 40 | 41 | ## Acknowledgements 42 | 43 | We would like to thank the contributors to the [FLUX](https://github.com/black-forest-labs/flux) and [Diffusers](https://github.com/huggingface/diffusers). -------------------------------------------------------------------------------- /TeaCache4HiDream-I1/README.md: -------------------------------------------------------------------------------- 1 | 2 | # TeaCache4HiDream-I1 3 | 4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [HiDream-I1](https://github.com/HiDream-ai/HiDream-I1) 2x without much visual quality degradation, in a training-free manner. The following image shows the results generated by TeaCache-HiDream-I1-Full with various `rel_l1_thresh` values: 0 (original), 0.17 (1.5x speedup), 0.25 (1.7x speedup), 0.3 (2.0x speedup), and 0.45 (2.6x speedup). 
5 | 6 | ![visualization](../assets/TeaCache4HiDream-I1.png) 7 | 8 | ## 📈 Inference Latency Comparisons on a Single A100 9 | 10 | | HiDream-I1-Full | TeaCache (0.17) | TeaCache (0.25) | TeaCache (0.3) | TeaCache (0.45) | 11 | |:-----------------------:|:----------------------------:|:--------------------:|:---------------------:|:--------------------:| 12 | | ~50 s | ~34 s | ~29 s | ~25 s | ~19 s | 13 | 14 | ## Installation 15 | 16 | ```shell 17 | pip install git+https://github.com/huggingface/diffusers 18 | pip install --upgrade transformers protobuf tiktoken tokenizers sentencepiece 19 | ``` 20 | 21 | ## Usage 22 | 23 | You can modify the `rel_l1_thresh` in line 297 to obtain your desired trade-off between latency and visul quality. For single-gpu inference, you can use the following command: 24 | 25 | ```bash 26 | python teacache_hidream_i1.py 27 | ``` 28 | 29 | ## Citation 30 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry. 31 | 32 | ``` 33 | @article{liu2024timestep, 34 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model}, 35 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang}, 36 | journal={arXiv preprint arXiv:2411.19108}, 37 | year={2024} 38 | } 39 | ``` 40 | 41 | ## Acknowledgements 42 | 43 | We would like to thank the contributors to the [HiDream-I1](https://github.com/HiDream-ai/HiDream-I1) and [Diffusers](https://github.com/huggingface/diffusers). -------------------------------------------------------------------------------- /TeaCache4HunyuanVideo/README.md: -------------------------------------------------------------------------------- 1 | 2 | # TeaCache4HunyuanVideo 3 | 4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) 2x without much visual quality degradation, in a training-free manner. The following video shows the results generated by TeaCache-HunyuanVideo with various `rel_l1_thresh` values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup). 5 | 6 | https://github.com/user-attachments/assets/34b5dab0-5b0f-48a0-968d-88af18b84803 7 | 8 | 9 | ## 📈 Inference Latency Comparisons on a Single A800 GPU 10 | 11 | 12 | | Resolution | HunyuanVideo | TeaCache (0.1) | TeaCache (0.15) | 13 | |:---------------------:|:-------------------------:|:--------------------:|:----------------------:| 14 | | 540p | ~18 min | ~11 min | ~8 min | 15 | | 720p | ~50 min | ~30 min | ~23 min | 16 | 17 | 18 | ## Usage 19 | 20 | Follow [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) to clone the repo and finish the installation, then copy 'teacache_sample_video.py' in this repo to the HunyuanVideo repo. You can modify the '`rel_l1_thresh`' in line 220 to obtain your desired trade-off between latency and visul quality. 21 | 22 | For single-gpu inference, you can use the following command: 23 | 24 | ```bash 25 | cd HunyuanVideo 26 | 27 | python3 teacache_sample_video.py \ 28 | --video-size 720 1280 \ 29 | --video-length 129 \ 30 | --infer-steps 50 \ 31 | --prompt "A cat walks on the grass, realistic style." 
\ 32 | --flow-reverse \ 33 | --use-cpu-offload \ 34 | --save-path ./teacache_results 35 | ``` 36 | 37 | To generate a video with 8 GPUs, you can use the following command: 38 | 39 | ```bash 40 | cd HunyuanVideo 41 | 42 | torchrun --nproc_per_node=8 teacache_sample_video.py \ 43 | --video-size 1280 720 \ 44 | --video-length 129 \ 45 | --infer-steps 50 \ 46 | --prompt "A cat walks on the grass, realistic style." \ 47 | --flow-reverse \ 48 | --seed 42 \ 49 | --ulysses-degree 8 \ 50 | --ring-degree 1 \ 51 | --save-path ./teacache_results 52 | ``` 53 | 54 | For FP8 inference, you must explicitly specify the FP8 weight path. For example, to generate a video with fp8 weights, you can use the following command: 55 | 56 | ```bash 57 | cd HunyuanVideo 58 | 59 | DIT_CKPT_PATH={PATH_TO_FP8_WEIGHTS}/{WEIGHT_NAME}_fp8.pt 60 | 61 | python3 teacache_sample_video.py \ 62 | --dit-weight ${DIT_CKPT_PATH} \ 63 | --video-size 1280 720 \ 64 | --video-length 129 \ 65 | --infer-steps 50 \ 66 | --prompt "A cat walks on the grass, realistic style." \ 67 | --seed 42 \ 68 | --embedded-cfg-scale 6.0 \ 69 | --flow-shift 7.0 \ 70 | --flow-reverse \ 71 | --use-cpu-offload \ 72 | --use-fp8 \ 73 | --save-path ./teacache_fp8_results 74 | ``` 75 | 76 | ## Citation 77 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry. 78 | 79 | ``` 80 | @article{liu2024timestep, 81 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model}, 82 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang}, 83 | journal={arXiv preprint arXiv:2411.19108}, 84 | year={2024} 85 | } 86 | ``` 87 | 88 | 89 | ## Acknowledgements 90 | 91 | We would like to thank the contributors to the [HunyuanVideo](https://github.com/Tencent/HunyuanVideo). 92 | -------------------------------------------------------------------------------- /TeaCache4HunyuanVideo/teacache_sample_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from pathlib import Path 4 | from loguru import logger 5 | from datetime import datetime 6 | 7 | from hyvideo.utils.file_utils import save_videos_grid 8 | from hyvideo.config import parse_args 9 | from hyvideo.inference import HunyuanVideoSampler 10 | 11 | from hyvideo.modules.modulate_layers import modulate 12 | from hyvideo.modules.attenion import attention, parallel_attention, get_cu_seqlens 13 | from typing import Any, List, Tuple, Optional, Union, Dict 14 | import torch 15 | import json 16 | import numpy as np 17 | 18 | 19 | 20 | 21 | 22 | def teacache_forward( 23 | self, 24 | x: torch.Tensor, 25 | t: torch.Tensor, # Should be in range(0, 1000). 26 | text_states: torch.Tensor = None, 27 | text_mask: torch.Tensor = None, # Now we don't use it. 28 | text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation. 29 | freqs_cos: Optional[torch.Tensor] = None, 30 | freqs_sin: Optional[torch.Tensor] = None, 31 | guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000. 32 | return_dict: bool = True, 33 | ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: 34 | out = {} 35 | img = x 36 | txt = text_states 37 | _, _, ot, oh, ow = x.shape 38 | tt, th, tw = ( 39 | ot // self.patch_size[0], 40 | oh // self.patch_size[1], 41 | ow // self.patch_size[2], 42 | ) 43 | 44 | # Prepare modulation vectors. 
45 | vec = self.time_in(t) 46 | 47 | # text modulation 48 | vec = vec + self.vector_in(text_states_2) 49 | 50 | # guidance modulation 51 | if self.guidance_embed: 52 | if guidance is None: 53 | raise ValueError( 54 | "Didn't get guidance strength for guidance distilled model." 55 | ) 56 | 57 | # our timestep_embedding is merged into guidance_in(TimestepEmbedder) 58 | vec = vec + self.guidance_in(guidance) 59 | 60 | # Embed image and text. 61 | img = self.img_in(img) 62 | if self.text_projection == "linear": 63 | txt = self.txt_in(txt) 64 | elif self.text_projection == "single_refiner": 65 | txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None) 66 | else: 67 | raise NotImplementedError( 68 | f"Unsupported text_projection: {self.text_projection}" 69 | ) 70 | 71 | txt_seq_len = txt.shape[1] 72 | img_seq_len = img.shape[1] 73 | 74 | # Compute cu_squlens and max_seqlen for flash attention 75 | cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len) 76 | cu_seqlens_kv = cu_seqlens_q 77 | max_seqlen_q = img_seq_len + txt_seq_len 78 | max_seqlen_kv = max_seqlen_q 79 | 80 | freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None 81 | 82 | if self.enable_teacache: 83 | inp = img.clone() 84 | vec_ = vec.clone() 85 | txt_ = txt.clone() 86 | ( 87 | img_mod1_shift, 88 | img_mod1_scale, 89 | img_mod1_gate, 90 | img_mod2_shift, 91 | img_mod2_scale, 92 | img_mod2_gate, 93 | ) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1) 94 | normed_inp = self.double_blocks[0].img_norm1(inp) 95 | modulated_inp = modulate( 96 | normed_inp, shift=img_mod1_shift, scale=img_mod1_scale 97 | ) 98 | if self.cnt == 0 or self.cnt == self.num_steps-1: 99 | should_calc = True 100 | self.accumulated_rel_l1_distance = 0 101 | else: 102 | coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02] 103 | rescale_func = np.poly1d(coefficients) 104 | self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) 105 | if self.accumulated_rel_l1_distance < self.rel_l1_thresh: 106 | should_calc = False 107 | else: 108 | should_calc = True 109 | self.accumulated_rel_l1_distance = 0 110 | self.previous_modulated_input = modulated_inp 111 | self.cnt += 1 112 | if self.cnt == self.num_steps: 113 | self.cnt = 0 114 | 115 | if self.enable_teacache: 116 | if not should_calc: 117 | img += self.previous_residual 118 | else: 119 | ori_img = img.clone() 120 | # --------------------- Pass through DiT blocks ------------------------ 121 | for _, block in enumerate(self.double_blocks): 122 | double_block_args = [ 123 | img, 124 | txt, 125 | vec, 126 | cu_seqlens_q, 127 | cu_seqlens_kv, 128 | max_seqlen_q, 129 | max_seqlen_kv, 130 | freqs_cis, 131 | ] 132 | 133 | img, txt = block(*double_block_args) 134 | 135 | # Merge txt and img to pass through single stream blocks. 136 | x = torch.cat((img, txt), 1) 137 | if len(self.single_blocks) > 0: 138 | for _, block in enumerate(self.single_blocks): 139 | single_block_args = [ 140 | x, 141 | vec, 142 | txt_seq_len, 143 | cu_seqlens_q, 144 | cu_seqlens_kv, 145 | max_seqlen_q, 146 | max_seqlen_kv, 147 | (freqs_cos, freqs_sin), 148 | ] 149 | 150 | x = block(*single_block_args) 151 | 152 | img = x[:, :img_seq_len, ...] 
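            # Cache the residual produced by the full pass through the double/single blocks;
            # steps that skip computation reuse it via `img += self.previous_residual` above.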
153 | self.previous_residual = img - ori_img 154 | else: 155 | # --------------------- Pass through DiT blocks ------------------------ 156 | for _, block in enumerate(self.double_blocks): 157 | double_block_args = [ 158 | img, 159 | txt, 160 | vec, 161 | cu_seqlens_q, 162 | cu_seqlens_kv, 163 | max_seqlen_q, 164 | max_seqlen_kv, 165 | freqs_cis, 166 | ] 167 | 168 | img, txt = block(*double_block_args) 169 | 170 | # Merge txt and img to pass through single stream blocks. 171 | x = torch.cat((img, txt), 1) 172 | if len(self.single_blocks) > 0: 173 | for _, block in enumerate(self.single_blocks): 174 | single_block_args = [ 175 | x, 176 | vec, 177 | txt_seq_len, 178 | cu_seqlens_q, 179 | cu_seqlens_kv, 180 | max_seqlen_q, 181 | max_seqlen_kv, 182 | (freqs_cos, freqs_sin), 183 | ] 184 | 185 | x = block(*single_block_args) 186 | 187 | img = x[:, :img_seq_len, ...] 188 | 189 | # ---------------------------- Final layer ------------------------------ 190 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) 191 | 192 | img = self.unpatchify(img, tt, th, tw) 193 | if return_dict: 194 | out["x"] = img 195 | return out 196 | return img 197 | 198 | 199 | def main(): 200 | args = parse_args() 201 | print(args) 202 | models_root_path = Path(args.model_base) 203 | if not models_root_path.exists(): 204 | raise ValueError(f"`models_root` not exists: {models_root_path}") 205 | 206 | # Create save folder to save the samples 207 | save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}' 208 | if not os.path.exists(args.save_path): 209 | os.makedirs(save_path, exist_ok=True) 210 | 211 | # Load models 212 | hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args) 213 | 214 | # Get the updated args 215 | args = hunyuan_video_sampler.args 216 | 217 | 218 | # TeaCache 219 | hunyuan_video_sampler.pipeline.transformer.__class__.enable_teacache = True 220 | hunyuan_video_sampler.pipeline.transformer.__class__.cnt = 0 221 | hunyuan_video_sampler.pipeline.transformer.__class__.num_steps = args.infer_steps 222 | hunyuan_video_sampler.pipeline.transformer.__class__.rel_l1_thresh = 0.15 # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup 223 | hunyuan_video_sampler.pipeline.transformer.__class__.accumulated_rel_l1_distance = 0 224 | hunyuan_video_sampler.pipeline.transformer.__class__.previous_modulated_input = None 225 | hunyuan_video_sampler.pipeline.transformer.__class__.previous_residual = None 226 | hunyuan_video_sampler.pipeline.transformer.__class__.forward = teacache_forward 227 | 228 | # Start sampling 229 | # TODO: batch inference check 230 | outputs = hunyuan_video_sampler.predict( 231 | prompt=args.prompt, 232 | height=args.video_size[0], 233 | width=args.video_size[1], 234 | video_length=args.video_length, 235 | seed=args.seed, 236 | negative_prompt=args.neg_prompt, 237 | infer_steps=args.infer_steps, 238 | guidance_scale=args.cfg_scale, 239 | num_videos_per_prompt=args.num_videos, 240 | flow_shift=args.flow_shift, 241 | batch_size=args.batch_size, 242 | embedded_guidance_scale=args.embedded_cfg_scale 243 | ) 244 | samples = outputs['samples'] 245 | 246 | # Save samples 247 | if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0: 248 | for i, sample in enumerate(samples): 249 | sample = samples[i].unsqueeze(0) 250 | time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S") 251 | save_path = 
f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4" 252 | save_videos_grid(sample, save_path, fps=24) 253 | logger.info(f'Sample save to: {save_path}') 254 | 255 | if __name__ == "__main__": 256 | main() 257 | -------------------------------------------------------------------------------- /TeaCache4LTX-Video/README.md: -------------------------------------------------------------------------------- 1 | 2 | # TeaCache4LTX-Video 3 | 4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [LTX-Video](https://github.com/Lightricks/LTX-Video) 2x without much visual quality degradation, in a training-free manner. The following video presents the videos generated by TeaCache-LTX-Video with various rel_l1_thresh values: 0 (original), 0.03 (1.6x speedup), 0.05 (2.1x speedup). 5 | 6 | https://github.com/user-attachments/assets/1f4cf26c-b8c6-45e3-b402-840bcd6ba00e 7 | 8 | ## 📈 Inference Latency Comparisons on a Single A800 9 | 10 | 11 | | LTX-Video-0.9.1 | TeaCache (0.03) | TeaCache (0.05) | 12 | |:--------------------------:|:----------------------------:|:---------------------:| 13 | | ~32 s | ~20 s | ~16 s | 14 | 15 | ## Installation 16 | 17 | ```shell 18 | pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece imageio 19 | ``` 20 | 21 | ## Usage 22 | 23 | You can modify the thresh in line 187 to obtain your desired trade-off between latency and visul quality. For single-gpu inference, you can use the following command: 24 | 25 | ```bash 26 | python teacache_ltx.py 27 | ``` 28 | 29 | ## Citation 30 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry. 31 | 32 | ``` 33 | @article{liu2024timestep, 34 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model}, 35 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang}, 36 | journal={arXiv preprint arXiv:2411.19108}, 37 | year={2024} 38 | } 39 | ``` 40 | 41 | ## Acknowledgements 42 | 43 | We would like to thank the contributors to the [LTX-Video](https://github.com/Lightricks/LTX-Video) and [Diffusers](https://github.com/huggingface/diffusers). 
-------------------------------------------------------------------------------- /TeaCache4LTX-Video/teacache_ltx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import LTXPipeline 3 | from diffusers.models.transformers import LTXVideoTransformer3DModel 4 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers 5 | from diffusers.utils import export_to_video 6 | from typing import Any, Dict, Optional, Tuple 7 | import numpy as np 8 | 9 | 10 | def teacache_forward( 11 | self, 12 | hidden_states: torch.Tensor, 13 | encoder_hidden_states: torch.Tensor, 14 | timestep: torch.LongTensor, 15 | encoder_attention_mask: torch.Tensor, 16 | num_frames: int, 17 | height: int, 18 | width: int, 19 | rope_interpolation_scale: Optional[Tuple[float, float, float]] = None, 20 | attention_kwargs: Optional[Dict[str, Any]] = None, 21 | return_dict: bool = True, 22 | ) -> torch.Tensor: 23 | if attention_kwargs is not None: 24 | attention_kwargs = attention_kwargs.copy() 25 | lora_scale = attention_kwargs.pop("scale", 1.0) 26 | else: 27 | lora_scale = 1.0 28 | 29 | if USE_PEFT_BACKEND: 30 | # weight the lora layers by setting `lora_scale` for each PEFT layer 31 | scale_lora_layers(self, lora_scale) 32 | else: 33 | if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: 34 | logger.warning( 35 | "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 36 | ) 37 | 38 | image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale) 39 | 40 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 41 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 42 | encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 43 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 44 | 45 | batch_size = hidden_states.size(0) 46 | hidden_states = self.proj_in(hidden_states) 47 | 48 | temb, embedded_timestep = self.time_embed( 49 | timestep.flatten(), 50 | batch_size=batch_size, 51 | hidden_dtype=hidden_states.dtype, 52 | ) 53 | 54 | temb = temb.view(batch_size, -1, temb.size(-1)) 55 | embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) 56 | 57 | encoder_hidden_states = self.caption_projection(encoder_hidden_states) 58 | encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) 59 | 60 | if self.enable_teacache: 61 | inp = hidden_states.clone() 62 | temb_ = temb.clone() 63 | inp = self.transformer_blocks[0].norm1(inp) 64 | num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0] 65 | ada_values = self.transformer_blocks[0].scale_shift_table[None, None] + temb_.reshape(batch_size, temb_.size(1), num_ada_params, -1) 66 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) 67 | modulated_inp = inp * (1 + scale_msa) + shift_msa 68 | if self.cnt == 0 or self.cnt == self.num_steps-1: 69 | should_calc = True 70 | self.accumulated_rel_l1_distance = 0 71 | else: 72 | coefficients = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03] 73 | rescale_func = np.poly1d(coefficients) 74 | self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) 75 | if 
self.accumulated_rel_l1_distance < self.rel_l1_thresh: 76 | should_calc = False 77 | else: 78 | should_calc = True 79 | self.accumulated_rel_l1_distance = 0 80 | self.previous_modulated_input = modulated_inp 81 | self.cnt += 1 82 | if self.cnt == self.num_steps: 83 | self.cnt = 0 84 | 85 | if self.enable_teacache: 86 | if not should_calc: 87 | hidden_states += self.previous_residual 88 | else: 89 | ori_hidden_states = hidden_states.clone() 90 | for block in self.transformer_blocks: 91 | if torch.is_grad_enabled() and self.gradient_checkpointing: 92 | 93 | def create_custom_forward(module, return_dict=None): 94 | def custom_forward(*inputs): 95 | if return_dict is not None: 96 | return module(*inputs, return_dict=return_dict) 97 | else: 98 | return module(*inputs) 99 | 100 | return custom_forward 101 | 102 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 103 | hidden_states = torch.utils.checkpoint.checkpoint( 104 | create_custom_forward(block), 105 | hidden_states, 106 | encoder_hidden_states, 107 | temb, 108 | image_rotary_emb, 109 | encoder_attention_mask, 110 | **ckpt_kwargs, 111 | ) 112 | else: 113 | hidden_states = block( 114 | hidden_states=hidden_states, 115 | encoder_hidden_states=encoder_hidden_states, 116 | temb=temb, 117 | image_rotary_emb=image_rotary_emb, 118 | encoder_attention_mask=encoder_attention_mask, 119 | ) 120 | 121 | scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] 122 | shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] 123 | 124 | hidden_states = self.norm_out(hidden_states) 125 | hidden_states = hidden_states * (1 + scale) + shift 126 | self.previous_residual = hidden_states - ori_hidden_states 127 | else: 128 | for block in self.transformer_blocks: 129 | if torch.is_grad_enabled() and self.gradient_checkpointing: 130 | 131 | def create_custom_forward(module, return_dict=None): 132 | def custom_forward(*inputs): 133 | if return_dict is not None: 134 | return module(*inputs, return_dict=return_dict) 135 | else: 136 | return module(*inputs) 137 | 138 | return custom_forward 139 | 140 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 141 | hidden_states = torch.utils.checkpoint.checkpoint( 142 | create_custom_forward(block), 143 | hidden_states, 144 | encoder_hidden_states, 145 | temb, 146 | image_rotary_emb, 147 | encoder_attention_mask, 148 | **ckpt_kwargs, 149 | ) 150 | else: 151 | hidden_states = block( 152 | hidden_states=hidden_states, 153 | encoder_hidden_states=encoder_hidden_states, 154 | temb=temb, 155 | image_rotary_emb=image_rotary_emb, 156 | encoder_attention_mask=encoder_attention_mask, 157 | ) 158 | 159 | scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] 160 | shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] 161 | 162 | hidden_states = self.norm_out(hidden_states) 163 | hidden_states = hidden_states * (1 + scale) + shift 164 | 165 | 166 | output = self.proj_out(hidden_states) 167 | 168 | if USE_PEFT_BACKEND: 169 | # remove `lora_scale` from each PEFT layer 170 | unscale_lora_layers(self, lora_scale) 171 | 172 | if not return_dict: 173 | return (output,) 174 | return Transformer2DModelOutput(sample=output) 175 | 176 | LTXVideoTransformer3DModel.forward = teacache_forward 177 | prompt = "A clear, turquoise river flows through a rocky canyon, cascading over a small waterfall and forming a pool of water at the bottom.The river is the 
main focus of the scene, with its clear water reflecting the surrounding trees and rocks. The canyon walls are steep and rocky, with some vegetation growing on them. The trees are mostly pine trees, with their green needles contrasting with the brown and gray rocks. The overall tone of the scene is one of peace and tranquility." 178 | negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" 179 | seed = 42 180 | num_inference_steps = 50 181 | pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16) 182 | 183 | # TeaCache 184 | pipe.transformer.__class__.enable_teacache = True 185 | pipe.transformer.__class__.cnt = 0 186 | pipe.transformer.__class__.num_steps = num_inference_steps 187 | pipe.transformer.__class__.rel_l1_thresh = 0.05 # 0.03 for 1.6x speedup, 0.05 for 2.1x speedup 188 | pipe.transformer.__class__.accumulated_rel_l1_distance = 0 189 | pipe.transformer.__class__.previous_modulated_input = None 190 | pipe.transformer.__class__.previous_residual = None 191 | 192 | pipe.to("cuda") 193 | video = pipe( 194 | prompt=prompt, 195 | negative_prompt=negative_prompt, 196 | width=768, 197 | height=512, 198 | num_frames=161, 199 | decode_timestep=0.03, 200 | decode_noise_scale=0.025, 201 | num_inference_steps=num_inference_steps, 202 | generator=torch.Generator("cuda").manual_seed(seed) 203 | ).frames[0] 204 | export_to_video(video, "teacache_ltx_{}.mp4".format(prompt[:50]), fps=24) -------------------------------------------------------------------------------- /TeaCache4Lumina-T2X/README.md: -------------------------------------------------------------------------------- 1 | 2 | # TeaCache4LuminaT2X 3 | 4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X) 2x without much visual quality degradation, in a training-free manner. The following image shows the results generated by TeaCache-Lumina-Next with various rel_l1_thresh values: 0 (original), 0.2 (1.5x speedup), 0.3 (1.9x speedup), 0.4 (2.4x speedup), and 0.5 (2.8x speedup). 5 | 6 | ![visualization](../assets/TeaCache4LuminaT2X.png) 7 | 8 | ## 📈 Inference Latency Comparisons on a Single A800 9 | 10 | 11 | | Lumina-Next-SFT | TeaCache (0.2) | TeaCache (0.3) | TeaCache (0.4) | TeaCache (0.5) | 12 | |:-------------------------:|:---------------------------:|:--------------------:|:---------------------:|:---------------------:| 13 | | ~17 s | ~11 s | ~9 s | ~7 s | ~6 s | 14 | 15 | ## Installation 16 | 17 | ```shell 18 | pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece 19 | pip install flash-attn --no-build-isolation 20 | ``` 21 | 22 | ## Usage 23 | 24 | You can modify the thresh in line 113 to obtain your desired trade-off between latency and visul quality. For single-gpu inference, you can use the following command: 25 | 26 | ```bash 27 | python teacache_lumina_next.py 28 | ``` 29 | 30 | ## Citation 31 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry. 
32 | 33 | ``` 34 | @article{liu2024timestep, 35 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model}, 36 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang}, 37 | journal={arXiv preprint arXiv:2411.19108}, 38 | year={2024} 39 | } 40 | ``` 41 | 42 | ## Acknowledgements 43 | 44 | We would like to thank the contributors to the [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X) and [Diffusers](https://github.com/huggingface/diffusers). -------------------------------------------------------------------------------- /TeaCache4Lumina-T2X/teacache_lumina_next.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Any, Dict, Optional, Tuple, Union 3 | from diffusers import LuminaText2ImgPipeline 4 | from diffusers.models import LuminaNextDiT2DModel 5 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 6 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers 7 | import numpy as np 8 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 9 | 10 | def teacache_forward( 11 | self, 12 | hidden_states: torch.Tensor, 13 | timestep: torch.Tensor, 14 | encoder_hidden_states: torch.Tensor, 15 | encoder_mask: torch.Tensor, 16 | image_rotary_emb: torch.Tensor, 17 | cross_attention_kwargs: Dict[str, Any] = None, 18 | return_dict=True, 19 | ) -> torch.Tensor: 20 | """ 21 | Forward pass of LuminaNextDiT. 22 | 23 | Parameters: 24 | hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W). 25 | timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,). 26 | encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D). 27 | encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L). 
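        image_rotary_emb (torch.Tensor): Precomputed rotary positional embeddings, passed to the patch embedder.
        cross_attention_kwargs (Dict[str, Any], optional): Extra keyword arguments forwarded to every transformer layer.
        return_dict (bool): Whether to return a `Transformer2DModelOutput` instead of a plain tuple.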
28 | """ 29 | hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb) 30 | image_rotary_emb = image_rotary_emb.to(hidden_states.device) 31 | 32 | temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask) 33 | 34 | encoder_mask = encoder_mask.bool() 35 | if self.enable_teacache: 36 | inp = hidden_states.clone() 37 | temb_ = temb.clone() 38 | modulated_inp, gate_msa, scale_mlp, gate_mlp = self.layers[0].norm1(inp, temb_) 39 | if self.cnt == 0 or self.cnt == self.num_steps-1: 40 | should_calc = True 41 | self.accumulated_rel_l1_distance = 0 42 | else: 43 | coefficients = [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344] 44 | rescale_func = np.poly1d(coefficients) 45 | self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) 46 | if self.accumulated_rel_l1_distance < self.rel_l1_thresh: 47 | should_calc = False 48 | else: 49 | should_calc = True 50 | self.accumulated_rel_l1_distance = 0 51 | self.previous_modulated_input = modulated_inp 52 | self.cnt += 1 53 | if self.cnt == self.num_steps: 54 | self.cnt = 0 55 | 56 | if self.enable_teacache: 57 | if not should_calc: 58 | hidden_states += self.previous_residual 59 | else: 60 | ori_hidden_states = hidden_states.clone() 61 | for layer in self.layers: 62 | hidden_states = layer( 63 | hidden_states, 64 | mask, 65 | image_rotary_emb, 66 | encoder_hidden_states, 67 | encoder_mask, 68 | temb=temb, 69 | cross_attention_kwargs=cross_attention_kwargs, 70 | ) 71 | self.previous_residual = hidden_states - ori_hidden_states 72 | 73 | else: 74 | for layer in self.layers: 75 | hidden_states = layer( 76 | hidden_states, 77 | mask, 78 | image_rotary_emb, 79 | encoder_hidden_states, 80 | encoder_mask, 81 | temb=temb, 82 | cross_attention_kwargs=cross_attention_kwargs, 83 | ) 84 | 85 | hidden_states = self.norm_out(hidden_states, temb) 86 | 87 | # unpatchify 88 | height_tokens = width_tokens = self.patch_size 89 | height, width = img_size[0] 90 | batch_size = hidden_states.size(0) 91 | sequence_length = (height // height_tokens) * (width // width_tokens) 92 | hidden_states = hidden_states[:, :sequence_length].view( 93 | batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels 94 | ) 95 | output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3) 96 | 97 | if not return_dict: 98 | return (output,) 99 | 100 | return Transformer2DModelOutput(sample=output) 101 | 102 | 103 | LuminaNextDiT2DModel.forward = teacache_forward 104 | num_inference_steps = 30 105 | seed = 1024 106 | prompt = "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. 
" 107 | pipeline = LuminaText2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16).to("cuda") 108 | 109 | # TeaCache 110 | pipeline.transformer.__class__.enable_teacache = True 111 | pipeline.transformer.__class__.cnt = 0 112 | pipeline.transformer.__class__.num_steps = num_inference_steps 113 | pipeline.transformer.__class__.rel_l1_thresh = 0.3 # 0.2 for 1.5x speedup, 0.3 for 1.9x speedup, 0.4 for 2.4x speedup, 0.5 for 2.8x speedup 114 | pipeline.transformer.__class__.accumulated_rel_l1_distance = 0 115 | pipeline.transformer.__class__.previous_modulated_input = None 116 | pipeline.transformer.__class__.previous_residual = None 117 | 118 | image = pipeline( 119 | prompt=prompt, 120 | num_inference_steps=num_inference_steps, 121 | generator=torch.Generator("cpu").manual_seed(seed) 122 | ).images[0] 123 | image.save("teacache_lumina_{}.png".format(prompt)) -------------------------------------------------------------------------------- /TeaCache4Lumina2/README.md: -------------------------------------------------------------------------------- 1 | 2 | # TeaCache4Lumina2 3 | 4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [Lumina-Image-2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0) without much visual quality degradation, in a training-free manner. The following image shows the results generated by TeaCache-Lumina-Image-2.0 with various rel_l1_thresh values: 0 (original), 0.2 (1.25x speedup), 0.3 (1.5625x speedup), 0.4 (2.0833x speedup), 0.5 (2.5x speedup). 5 | 6 |

7 | <!-- comparison images for rel_l1_thresh = 0, 0.2, 0.3, 0.4, 0.5 --> 8 | 9 | 10 | 11 | 12 |

13 | 14 | ## 📈 Inference Latency Comparisons on a single 4090 (step 50) 15 | 16 | 17 | | Lumina-Image-2.0 | TeaCache (0.2) | TeaCache (0.3) | TeaCache (0.4) | TeaCache (0.5) | 18 | |:-------------------------:|:---------------------------:|:--------------------:|:---------------------:|:---------------------:| 19 | | ~25 s | ~20 s | ~16 s | ~12 s | ~10 s | 20 | 21 | ## Installation 22 | 23 | ```shell 24 | pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece 25 | pip install flash-attn --no-build-isolation 26 | ``` 27 | 28 | ## Usage 29 | 30 | You can modify the thresh in line 154 to obtain your desired trade-off between latency and visul quality. For single-gpu inference, you can use the following command: 31 | 32 | ```bash 33 | python teacache_lumina2.py 34 | ``` 35 | 36 | ## Citation 37 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry. 38 | 39 | ``` 40 | @article{liu2024timestep, 41 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model}, 42 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang}, 43 | journal={arXiv preprint arXiv:2411.19108}, 44 | year={2024} 45 | } 46 | ``` 47 | 48 | ## Acknowledgements 49 | 50 | We would like to thank the contributors to the [Lumina-Image-2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0) and [Diffusers](https://github.com/huggingface/diffusers). 51 | -------------------------------------------------------------------------------- /TeaCache4Lumina2/teacache_lumina2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from typing import Any, Dict, Optional, Tuple, Union, List 5 | 6 | from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline 7 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 8 | from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers 9 | 10 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 11 | 12 | def teacache_forward_working( 13 | self, 14 | hidden_states: torch.Tensor, 15 | timestep: torch.Tensor, 16 | encoder_hidden_states: torch.Tensor, 17 | encoder_attention_mask: torch.Tensor, 18 | attention_kwargs: Optional[Dict[str, Any]] = None, 19 | return_dict: bool = True, 20 | ) -> Union[torch.Tensor, Transformer2DModelOutput]: 21 | if attention_kwargs is not None: 22 | attention_kwargs = attention_kwargs.copy() 23 | lora_scale = attention_kwargs.pop("scale", 1.0) 24 | else: 25 | lora_scale = 1.0 26 | if USE_PEFT_BACKEND: 27 | scale_lora_layers(self, lora_scale) 28 | 29 | batch_size, _, height, width = hidden_states.shape 30 | temb, encoder_hidden_states_processed = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states) 31 | (image_patch_embeddings, context_rotary_emb, noise_rotary_emb, joint_rotary_emb, 32 | encoder_seq_lengths, seq_lengths) = self.rope_embedder(hidden_states, encoder_attention_mask) 33 | image_patch_embeddings = self.x_embedder(image_patch_embeddings) 34 | for layer in self.context_refiner: 35 | encoder_hidden_states_processed = layer(encoder_hidden_states_processed, encoder_attention_mask, context_rotary_emb) 36 | for layer in self.noise_refiner: 37 | image_patch_embeddings = layer(image_patch_embeddings, None, noise_rotary_emb, temb) 38 | 39 | max_seq_len = 
max(seq_lengths) 40 | input_to_main_loop = image_patch_embeddings.new_zeros(batch_size, max_seq_len, self.config.hidden_size) 41 | for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)): 42 | input_to_main_loop[i, :enc_len] = encoder_hidden_states_processed[i, :enc_len] 43 | input_to_main_loop[i, enc_len:seq_len_val] = image_patch_embeddings[i] 44 | 45 | use_mask = len(set(seq_lengths)) > 1 46 | attention_mask_for_main_loop_arg = None 47 | if use_mask: 48 | mask = input_to_main_loop.new_zeros(batch_size, max_seq_len, dtype=torch.bool) 49 | for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)): 50 | mask[i, :seq_len_val] = True 51 | attention_mask_for_main_loop_arg = mask 52 | 53 | should_calc = True 54 | if self.enable_teacache: 55 | cache_key = max_seq_len 56 | if cache_key not in self.cache: 57 | self.cache[cache_key] = { 58 | "accumulated_rel_l1_distance": 0.0, 59 | "previous_modulated_input": None, 60 | "previous_residual": None, 61 | } 62 | 63 | current_cache = self.cache[cache_key] 64 | modulated_inp, _, _, _ = self.layers[0].norm1(input_to_main_loop.clone(), temb.clone()) 65 | 66 | if self.cnt == 0 or self.cnt == self.num_steps - 1: 67 | should_calc = True 68 | current_cache["accumulated_rel_l1_distance"] = 0.0 69 | else: 70 | if current_cache["previous_modulated_input"] is not None: 71 | coefficients = [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344] # taken from teacache_lumina_next.py 72 | rescale_func = np.poly1d(coefficients) 73 | prev_mod_input = current_cache["previous_modulated_input"] 74 | prev_mean = prev_mod_input.abs().mean() 75 | 76 | if prev_mean.item() > 1e-9: 77 | rel_l1_change = ((modulated_inp - prev_mod_input).abs().mean() / prev_mean).cpu().item() 78 | else: 79 | rel_l1_change = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float('inf') 80 | 81 | current_cache["accumulated_rel_l1_distance"] += rescale_func(rel_l1_change) 82 | 83 | if current_cache["accumulated_rel_l1_distance"] < self.rel_l1_thresh: 84 | should_calc = False 85 | else: 86 | should_calc = True 87 | current_cache["accumulated_rel_l1_distance"] = 0.0 88 | else: 89 | should_calc = True 90 | current_cache["accumulated_rel_l1_distance"] = 0.0 91 | 92 | current_cache["previous_modulated_input"] = modulated_inp.clone() 93 | 94 | if not hasattr(self, 'uncond_seq_len'): 95 | self.uncond_seq_len = cache_key 96 | if cache_key != self.uncond_seq_len: 97 | self.cnt += 1 98 | if self.cnt >= self.num_steps: 99 | self.cnt = 0 100 | 101 | if self.enable_teacache and not should_calc: 102 | processed_hidden_states = input_to_main_loop + self.cache[max_seq_len]["previous_residual"] 103 | else: 104 | ori_input = input_to_main_loop.clone() 105 | current_processing_states = input_to_main_loop 106 | for layer in self.layers: 107 | current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb) 108 | 109 | if self.enable_teacache: 110 | self.cache[max_seq_len]["previous_residual"] = current_processing_states - ori_input 111 | processed_hidden_states = current_processing_states 112 | 113 | output_after_norm = self.norm_out(processed_hidden_states, temb) 114 | p = self.config.patch_size 115 | final_output_list = [] 116 | for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)): 117 | image_part = output_after_norm[i][enc_len:seq_len_val] 118 | h_p, w_p = height // p, width // p 119 | reconstructed_image = image_part.view(h_p, w_p, p, p, self.out_channels) \ 120 | 
.permute(4, 0, 2, 1, 3) \ 121 | .flatten(3, 4) \ 122 | .flatten(1, 2) 123 | final_output_list.append(reconstructed_image) 124 | 125 | final_output_tensor = torch.stack(final_output_list, dim=0) 126 | 127 | if USE_PEFT_BACKEND: 128 | unscale_lora_layers(self, lora_scale) 129 | 130 | return Transformer2DModelOutput(sample=final_output_tensor) 131 | 132 | 133 | Lumina2Transformer2DModel.forward = teacache_forward_working 134 | 135 | ckpt_path = "NietaAniLumina_Alpha_full_round5_ep5_s182000.pth" 136 | transformer = Lumina2Transformer2DModel.from_single_file( 137 | ckpt_path, torch_dtype=torch.bfloat16 138 | ) 139 | pipeline = Lumina2Pipeline.from_pretrained( 140 | "Alpha-VLLM/Lumina-Image-2.0", 141 | transformer=transformer, 142 | torch_dtype=torch.bfloat16 143 | ).to("cuda") 144 | 145 | num_inference_steps = 30 146 | seed = 1024 147 | prompt = "a cat holding a sign that says hello" 148 | output_filename = f"teacache_lumina2_output.png" 149 | 150 | # TeaCache 151 | pipeline.transformer.__class__.enable_teacache = True 152 | pipeline.transformer.__class__.cnt = 0 153 | pipeline.transformer.__class__.num_steps = num_inference_steps 154 | pipeline.transformer.__class__.rel_l1_thresh = 0.3 # taken from teacache_lumina_next.py, 0.2 for 1.5x speedup, 0.3 for 1.9x speedup, 0.4 for 2.4x speedup, 0.5 for 2.8x speedup 155 | pipeline.transformer.__class__.cache = {} 156 | pipeline.transformer.__class__.uncond_seq_len = None 157 | 158 | 159 | pipeline.enable_model_cpu_offload() 160 | image = pipeline( 161 | prompt=prompt, 162 | num_inference_steps=num_inference_steps, 163 | generator=torch.Generator("cuda").manual_seed(seed) 164 | ).images[0] 165 | 166 | image.save(output_filename) -------------------------------------------------------------------------------- /TeaCache4Mochi/README.md: -------------------------------------------------------------------------------- 1 | 2 | # TeaCache4Mochi 3 | 4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [Mochi](https://github.com/genmoai/mochi) 2x without much visual quality degradation, in a training-free manner. The following video shows the results generated by TeaCache-Mochi with various rel_l1_thresh values: 0 (original), 0.06 (1.5x speedup), 0.09 (2.1x speedup). 5 | 6 | https://github.com/user-attachments/assets/29a81380-46b3-414f-a96b-6e3acc71b6c4 7 | 8 | ## 📈 Inference Latency Comparisons on a Single A800 9 | 10 | 11 | | mochi-1-preview | TeaCache (0.06) | TeaCache (0.09) | 12 | |:--------------------------:|:----------------------------:|:--------------------:| 13 | | ~30 min | ~20 min | ~14 min | 14 | 15 | ## Installation 16 | 17 | ```shell 18 | pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece imageio 19 | ``` 20 | 21 | ## Usage 22 | 23 | You can modify the `rel_l1_thresh` in line 174 to obtain your desired trade-off between latency and visual quality. For single-GPU inference, you can use the following command: 24 | 25 | ```bash 26 | python teacache_mochi.py 27 | ``` 28 | 29 | ## Citation 30 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry. 
31 | 32 | ``` 33 | @article{liu2024timestep, 34 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model}, 35 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang}, 36 | journal={arXiv preprint arXiv:2411.19108}, 37 | year={2024} 38 | } 39 | ``` 40 | 41 | ## Acknowledgements 42 | 43 | We would like to thank the contributors to the [Mochi](https://github.com/genmoai/mochi) and [Diffusers](https://github.com/huggingface/diffusers). -------------------------------------------------------------------------------- /TeaCache4Mochi/teacache_mochi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.attention import SDPBackend, sdpa_kernel 3 | from diffusers import MochiPipeline 4 | from diffusers.models.transformers import MochiTransformer3DModel 5 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers 6 | from diffusers.utils import export_to_video 7 | from diffusers.video_processor import VideoProcessor 8 | from typing import Any, Dict, Optional, Tuple 9 | import numpy as np 10 | from diffusers.models.modeling_outputs import Transformer2DModelOutput  # needed for the Transformer2DModelOutput returned at the end of teacache_forward 11 | logger = logging.get_logger(__name__)  # module-level logger used in the non-PEFT warning branch below 12 | def teacache_forward( 13 | self, 14 | hidden_states: torch.Tensor, 15 | encoder_hidden_states: torch.Tensor, 16 | timestep: torch.LongTensor, 17 | encoder_attention_mask: torch.Tensor, 18 | attention_kwargs: Optional[Dict[str, Any]] = None, 19 | return_dict: bool = True, 20 | ) -> torch.Tensor: 21 | if attention_kwargs is not None: 22 | attention_kwargs = attention_kwargs.copy() 23 | lora_scale = attention_kwargs.pop("scale", 1.0) 24 | else: 25 | lora_scale = 1.0 26 | 27 | if USE_PEFT_BACKEND: 28 | # weight the lora layers by setting `lora_scale` for each PEFT layer 29 | scale_lora_layers(self, lora_scale) 30 | else: 31 | if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: 32 | logger.warning( 33 | "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
34 | ) 35 | 36 | batch_size, num_channels, num_frames, height, width = hidden_states.shape 37 | p = self.config.patch_size 38 | 39 | post_patch_height = height // p 40 | post_patch_width = width // p 41 | 42 | temb, encoder_hidden_states = self.time_embed( 43 | timestep, 44 | encoder_hidden_states, 45 | encoder_attention_mask, 46 | hidden_dtype=hidden_states.dtype, 47 | ) 48 | 49 | hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) 50 | hidden_states = self.patch_embed(hidden_states) 51 | hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) 52 | 53 | image_rotary_emb = self.rope( 54 | self.pos_frequencies, 55 | num_frames, 56 | post_patch_height, 57 | post_patch_width, 58 | device=hidden_states.device, 59 | dtype=torch.float32, 60 | ) 61 | 62 | if self.enable_teacache: 63 | inp = hidden_states.clone() 64 | temb_ = temb.clone() 65 | modulated_inp, gate_msa, scale_mlp, gate_mlp = self.transformer_blocks[0].norm1(inp, temb_) 66 | if self.cnt == 0 or self.cnt == self.num_steps-1: 67 | should_calc = True 68 | self.accumulated_rel_l1_distance = 0 69 | else: 70 | coefficients = [-3.51241319e+03, 8.11675948e+02, -6.09400215e+01, 2.42429681e+00, 3.05291719e-03] 71 | rescale_func = np.poly1d(coefficients) 72 | self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) 73 | if self.accumulated_rel_l1_distance < self.rel_l1_thresh: 74 | should_calc = False 75 | else: 76 | should_calc = True 77 | self.accumulated_rel_l1_distance = 0 78 | self.previous_modulated_input = modulated_inp 79 | self.cnt += 1 80 | if self.cnt == self.num_steps: 81 | self.cnt = 0 82 | 83 | if self.enable_teacache: 84 | if not should_calc: 85 | hidden_states += self.previous_residual 86 | else: 87 | ori_hidden_states = hidden_states.clone() 88 | for i, block in enumerate(self.transformer_blocks): 89 | if torch.is_grad_enabled() and self.gradient_checkpointing: 90 | 91 | def create_custom_forward(module): 92 | def custom_forward(*inputs): 93 | return module(*inputs) 94 | 95 | return custom_forward 96 | 97 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 98 | hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( 99 | create_custom_forward(block), 100 | hidden_states, 101 | encoder_hidden_states, 102 | temb, 103 | encoder_attention_mask, 104 | image_rotary_emb, 105 | **ckpt_kwargs, 106 | ) 107 | else: 108 | hidden_states, encoder_hidden_states = block( 109 | hidden_states=hidden_states, 110 | encoder_hidden_states=encoder_hidden_states, 111 | temb=temb, 112 | encoder_attention_mask=encoder_attention_mask, 113 | image_rotary_emb=image_rotary_emb, 114 | ) 115 | hidden_states = self.norm_out(hidden_states, temb) 116 | self.previous_residual = hidden_states - ori_hidden_states 117 | else: 118 | for i, block in enumerate(self.transformer_blocks): 119 | if torch.is_grad_enabled() and self.gradient_checkpointing: 120 | def create_custom_forward(module): 121 | def custom_forward(*inputs): 122 | return module(*inputs) 123 | 124 | return custom_forward 125 | 126 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 127 | hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( 128 | create_custom_forward(block), 129 | hidden_states, 130 | encoder_hidden_states, 131 | temb, 132 | encoder_attention_mask, 133 | image_rotary_emb, 134 | **ckpt_kwargs, 135 | ) 136 | 
else: 137 | hidden_states, encoder_hidden_states = block( 138 | hidden_states=hidden_states, 139 | encoder_hidden_states=encoder_hidden_states, 140 | temb=temb, 141 | encoder_attention_mask=encoder_attention_mask, 142 | image_rotary_emb=image_rotary_emb, 143 | ) 144 | hidden_states = self.norm_out(hidden_states, temb) 145 | 146 | hidden_states = self.proj_out(hidden_states) 147 | 148 | hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1) 149 | hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5) 150 | output = hidden_states.reshape(batch_size, -1, num_frames, height, width) 151 | 152 | if USE_PEFT_BACKEND: 153 | # remove `lora_scale` from each PEFT layer 154 | unscale_lora_layers(self, lora_scale) 155 | 156 | if not return_dict: 157 | return (output,) 158 | return Transformer2DModelOutput(sample=output) 159 | 160 | 161 | MochiTransformer3DModel.forward = teacache_forward 162 | prompt = "A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl filled with lemons and sprigs of mint against a peach-colored background. The hand gently tosses the lemon up and catches it, showcasing its smooth texture. \ 163 | A beige string bag sits beside the bowl, adding a rustic touch to the scene. \ 164 | Additional lemons, one halved, are scattered around the base of the bowl. \ 165 | The even lighting enhances the vibrant colors and creates a fresh, \ 166 | inviting atmosphere." 167 | num_inference_steps = 64 168 | pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", force_zeros_for_empty_prompt=True) 169 | 170 | # TeaCache 171 | pipe.transformer.__class__.enable_teacache = True 172 | pipe.transformer.__class__.cnt = 0 173 | pipe.transformer.__class__.num_steps = num_inference_steps 174 | pipe.transformer.__class__.rel_l1_thresh = 0.09 # 0.06 for 1.5x speedup, 0.09 for 2.1x speedup 175 | pipe.transformer.__class__.accumulated_rel_l1_distance = 0 176 | pipe.transformer.__class__.previous_modulated_input = None 177 | pipe.transformer.__class__.previous_residual = None 178 | 179 | # Enable memory savings 180 | pipe.enable_vae_tiling() 181 | pipe.enable_model_cpu_offload() 182 | 183 | with torch.no_grad(): 184 | prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = ( 185 | pipe.encode_prompt(prompt=prompt) 186 | ) 187 | 188 | with torch.autocast("cuda", torch.bfloat16): 189 | with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION): 190 | frames = pipe( 191 | prompt_embeds=prompt_embeds, 192 | prompt_attention_mask=prompt_attention_mask, 193 | negative_prompt_embeds=negative_prompt_embeds, 194 | negative_prompt_attention_mask=negative_prompt_attention_mask, 195 | guidance_scale=4.5, 196 | num_inference_steps=num_inference_steps, 197 | height=480, 198 | width=848, 199 | num_frames=163, 200 | generator=torch.Generator("cuda").manual_seed(0), 201 | output_type="latent", 202 | return_dict=False, 203 | )[0] 204 | 205 | video_processor = VideoProcessor(vae_scale_factor=8) 206 | has_latents_mean = hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None 207 | has_latents_std = hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None 208 | if has_latents_mean and has_latents_std: 209 | latents_mean = ( 210 | torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype) 211 | ) 212 | latents_std = ( 213 | torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype) 214 | ) 
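    # Undo the VAE's latent normalization before decoding: the expression on the next line inverts the encode-side mapping z = (x - latents_mean) * scaling_factor / latents_std, so pipe.vae.decode receives unnormalized latents.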
215 | frames = frames * latents_std / pipe.vae.config.scaling_factor + latents_mean 216 | else: 217 | frames = frames / pipe.vae.config.scaling_factor 218 | 219 | with torch.no_grad(): 220 | video = pipe.vae.decode(frames.to(pipe.vae.dtype), return_dict=False)[0] 221 | 222 | video = video_processor.postprocess_video(video)[0] 223 | export_to_video(video, "teacache_mochi__{}.mp4".format(prompt[:50]), fps=30) -------------------------------------------------------------------------------- /TeaCache4TangoFlux/README.md: -------------------------------------------------------------------------------- 1 | 2 | # TeaCache4TangoFlux 3 | 4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [TangoFlux](https://github.com/declare-lab/TangoFlux) 2x without much audio quality degradation, in a training-free manner. 5 | 6 | ## 📈 Inference Latency Comparisons on a Single A800 7 | 8 | 9 | | TangoFlux | TeaCache (0.25) | TeaCache (0.4) | 10 | |:-------------------:|:----------------------------:|:--------------------:| 11 | | ~4.08 s | ~2.42 s | ~1.95 s | 12 | 13 | ## Installation 14 | 15 | ```shell 16 | pip install git+https://github.com/declare-lab/TangoFlux 17 | ``` 18 | 19 | ## Usage 20 | 21 | You can modify the thresh in line 266 to obtain your desired trade-off between latency and audio quality. For single-gpu inference, you can use the following command: 22 | 23 | ```bash 24 | python teacache_tango_flux.py 25 | ``` 26 | 27 | ## Citation 28 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry. 29 | 30 | ``` 31 | @article{liu2024timestep, 32 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model}, 33 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang}, 34 | journal={arXiv preprint arXiv:2411.19108}, 35 | year={2024} 36 | } 37 | ``` 38 | 39 | ## Acknowledgements 40 | 41 | We would like to thank the contributors to the [TangoFlux](https://github.com/declare-lab/TangoFlux). -------------------------------------------------------------------------------- /TeaCache4Wan2.1/README.md: -------------------------------------------------------------------------------- 1 | 2 | # TeaCache4Wan2.1 3 | 4 | [TeaCache](https://github.com/ali-vilab/TeaCache) can speedup [Wan2.1](https://github.com/Wan-Video/Wan2.1) 2x without much visual quality degradation, in a training-free manner. The following video shows the results generated by TeaCache-Wan2.1 with various teacache_thresh values. The corresponding teacache_thresh values are shown in the following table. 
5 | 6 | https://github.com/user-attachments/assets/5ae5d6dd-bf87-4f8f-91b8-ccc5980c56ad 7 | 8 | https://github.com/user-attachments/assets/dfd047a9-e3ca-4a73-a282-4dadda8dbd43 9 | 10 | https://github.com/user-attachments/assets/7c20bd54-96a8-4bd7-b4fa-ea4c9da81562 11 | 12 | https://github.com/user-attachments/assets/72085f45-6b78-4fae-b58f-492360a6e55e 13 | ## 📈 Inference Latency Comparisons on a Single A800 14 | 15 | 16 | | Wan2.1 t2v 1.3B | TeaCache (0.05) | TeaCache (0.07) | TeaCache (0.08) | 17 | |:--------------------------:|:----------------------------:|:---------------------:|:---------------------:| 18 | | ~175 s | ~117 s | ~110 s | ~88 s | 19 | 20 | | Wan2.1 t2v 14B | TeaCache (0.14) | TeaCache (0.15) | TeaCache (0.2) | 21 | |:--------------------------:|:----------------------------:|:---------------------:|:---------------------:| 22 | | ~55 min | ~38 min | ~30 min | ~27 min | 23 | 24 | | Wan2.1 i2v 480P | TeaCache (0.13) | TeaCache (0.19) | TeaCache (0.26) | 25 | |:--------------------------:|:----------------------------:|:---------------------:|:---------------------:| 26 | | ~735 s | ~464 s | ~372 s | ~300 s | 27 | 28 | | Wan2.1 i2v 720P | TeaCache (0.18) | TeaCache (0.2) | TeaCache (0.3) | 29 | |:--------------------------:|:----------------------------:|:---------------------:|:---------------------:| 30 | | ~29 min | ~17 min | ~15 min | ~12 min | 31 | 32 | ## Usage 33 | 34 | Follow [Wan2.1](https://github.com/Wan-Video/Wan2.1) to clone the repo and finish the installation, then copy 'teacache_generate.py' in this repo to the Wan2.1 repo. 35 | 36 | For T2V with 1.3B model, you can use the following command: 37 | 38 | ```bash 39 | python teacache_generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." --base_seed 42 --offload_model True --t5_cpu --teacache_thresh 0.08 40 | ``` 41 | 42 | For T2V with 14B model, you can use the following command: 43 | 44 | ```bash 45 | python teacache_generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." --base_seed 42 --offload_model True --t5_cpu --teacache_thresh 0.2 46 | ``` 47 | 48 | For I2V with 480P resolution, you can use the following command: 49 | 50 | ```bash 51 | python teacache_generate.py --task i2v-14B --size 832*480 --ckpt_dir ./Wan2.1-I2V-14B-480P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --base_seed 42 --offload_model True --t5_cpu --teacache_thresh 0.26 52 | ``` 53 | 54 | For I2V with 720P resolution, you can use the following command: 55 | 56 | ```bash 57 | python teacache_generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. 
Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --base_seed 42 --offload_model True --t5_cpu --frame_num 61 --teacache_thresh 0.3 58 | ``` 59 | 60 | ## Faster Video Generation Using the `use_ret_steps` Parameter 61 | 62 | Enabling retention steps (`--use_ret_steps`) results in faster generation and better quality (except for t2v-1.3B). 63 | 64 | https://github.com/user-attachments/assets/f241b5f5-1044-4223-b2a4-449dc6dc1ad7 65 | 66 | https://github.com/user-attachments/assets/01db60f9-4aaf-43c4-8f1b-6e050cfa1180 67 | 68 | https://github.com/user-attachments/assets/e03621f2-1085-4571-8eca-51889f47ce18 69 | 70 | https://github.com/user-attachments/assets/d1340197-20c1-4f9e-a780-31f789af0893 71 | 72 | 73 | | use_ret_steps | Wan2.1 t2v 1.3B (thresh) | Slow (thresh) | Fast (thresh) | 74 | |:--------------------------:|:----------------------------:|:---------------------:|:---------------------:| 75 | | False | ~97 s (0.00) | ~64 s (0.05) | ~49 s (0.08) | 76 | | True | ~97 s (0.00) | ~61 s (0.05) | ~41 s (0.10) | 77 | 78 | | use_ret_steps | Wan2.1 t2v 14B (thresh) | Slow (thresh) | Fast (thresh) | 79 | |:--------------------------:|:----------------------------:|:---------------------:|:---------------------:| 80 | | False | ~1829 s (0.00) | ~1234 s (0.14) | ~909 s (0.20) | 81 | | True | ~1829 s (0.00) | ~915 s (0.10) | ~578 s (0.20) | 82 | 83 | | use_ret_steps | Wan2.1 i2v 480P (thresh) | Slow (thresh) | Fast (thresh) | 84 | |:--------------------------:|:----------------------------:|:---------------------:|:---------------------:| 85 | | False | ~385 s (0.00) | ~241 s (0.13) | ~156 s (0.26) | 86 | | True | ~385 s (0.00) | ~212 s (0.20) | ~164 s (0.30) | 87 | 88 | | use_ret_steps | Wan2.1 i2v 720P (thresh) | Slow (thresh) | Fast (thresh) | 89 | |:--------------------------:|:----------------------------:|:---------------------:|:---------------------:| 90 | | False | ~903 s (0.00) | ~476 s (0.20) | ~363 s (0.30) | 91 | | True | ~903 s (0.00) | ~430 s (0.20) | ~340 s (0.30) | 92 | 93 | 94 | Follow the video generation instructions above and add the `--use_ret_steps` flag to speed up generation while keeping results closer to [Wan2.1](https://github.com/Wan-Video/Wan2.1). Simply append `--use_ret_steps` to the original command and adjust `--teacache_thresh`; suitable values for each model and setting can be taken from the tables above. 95 | 96 | ### Example Command: 97 | 98 | ```bash 99 | python teacache_generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." --base_seed 42 --offload_model True --t5_cpu --teacache_thresh 0.3 --use_ret_steps 100 | ``` 101 | 102 | 103 | ## Acknowledgements 104 | 105 | We would like to thank the contributors to the [Wan2.1](https://github.com/Wan-Video/Wan2.1). 
-------------------------------------------------------------------------------- /assets/TeaCache4FLUX.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/assets/TeaCache4FLUX.png -------------------------------------------------------------------------------- /assets/TeaCache4HiDream-I1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/assets/TeaCache4HiDream-I1.png -------------------------------------------------------------------------------- /assets/TeaCache4LuminaT2X.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/assets/TeaCache4LuminaT2X.png -------------------------------------------------------------------------------- /assets/tisser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/assets/tisser.png -------------------------------------------------------------------------------- /eval/teacache/README.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | Prerequisites: 4 | 5 | - Python >= 3.10 6 | - PyTorch >= 1.13 (we recommend a version >= 2.0) 7 | - CUDA >= 11.6 8 | 9 | We strongly recommend using Anaconda to create a new environment (Python >= 3.10) to run our examples: 10 | 11 | ```shell 12 | conda create -n teacache python=3.10 -y 13 | conda activate teacache 14 | ``` 15 | 16 | Install TeaCache: 17 | 18 | ```shell 19 | git clone https://github.com/LiewFeng/TeaCache 20 | cd TeaCache 21 | pip install -e . 22 | ``` 23 | 24 | 25 | ## Evaluation of TeaCache 26 | 27 | We first generate videos according to VBench's prompts. 28 | 29 | We then calculate VBench, PSNR, LPIPS and SSIM scores on the generated videos. 30 | 31 | 1. Generate videos 32 | ``` 33 | cd eval/teacache 34 | python experiments/latte.py 35 | python experiments/opensora.py 36 | python experiments/opensora_plan.py 37 | python experiments/cogvideox.py 38 | ``` 39 | 40 | 2. Calculate VBench score 41 | ``` 42 | # VBench is calculated independently 43 | # get scores for all metrics 44 | python vbench/run_vbench.py --video_path aaa --save_path bbb 45 | # calculate final score 46 | python vbench/cal_vbench.py --score_dir bbb 47 | ``` 48 | 49 | 3. Calculate other metrics 50 | ``` 51 | # these metrics are calculated by comparing against the original model 52 | # gt video is the output of the original model 53 | # generated video is our method's result 54 | python common_metrics/eval.py --gt_video_dir aa --generated_video_dir bb 55 | ``` 56 | 57 | 58 | 59 | ## Citation 60 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry. 
61 | 62 | ``` 63 | @article{liu2024timestep, 64 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model}, 65 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang}, 66 | journal={arXiv preprint arXiv:2411.19108}, 67 | year={2024} 68 | } 69 | ``` 70 | 71 | ## Acknowledgements 72 | We would like to thank the contributors to the [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo) and [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys). 73 | -------------------------------------------------------------------------------- /eval/teacache/common_metrics/README.md: -------------------------------------------------------------------------------- 1 | Common metrics 2 | 3 | Include LPIPS, PSNR and SSIM. 4 | 5 | The code is adapted from [common_metrics_on_video_quality 6 | ](https://github.com/JunyaoHu/common_metrics_on_video_quality). 7 | -------------------------------------------------------------------------------- /eval/teacache/common_metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/eval/teacache/common_metrics/__init__.py -------------------------------------------------------------------------------- /eval/teacache/common_metrics/batch_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import imageio 5 | import torch 6 | import torchvision.transforms.functional as F 7 | import tqdm 8 | from calculate_lpips import calculate_lpips 9 | from calculate_psnr import calculate_psnr 10 | from calculate_ssim import calculate_ssim 11 | 12 | 13 | def load_video(video_path): 14 | """ 15 | Load a video from the given path and convert it to a PyTorch tensor. 
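    Frames are read with imageio, moved to the GPU, and stacked into a (T, C, H, W) uint8 tensor.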
16 | """ 17 | # Read the video using imageio 18 | reader = imageio.get_reader(video_path, "ffmpeg") 19 | 20 | # Extract frames and convert to a list of tensors 21 | frames = [] 22 | for frame in reader: 23 | # Convert the frame to a tensor and permute the dimensions to match (C, H, W) 24 | frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1) 25 | frames.append(frame_tensor) 26 | 27 | # Stack the list of tensors into a single tensor with shape (T, C, H, W) 28 | video_tensor = torch.stack(frames) 29 | 30 | return video_tensor 31 | 32 | 33 | def resize_video(video, target_height, target_width): 34 | resized_frames = [] 35 | for frame in video: 36 | resized_frame = F.resize(frame, [target_height, target_width]) 37 | resized_frames.append(resized_frame) 38 | return torch.stack(resized_frames) 39 | 40 | 41 | def resize_gt_video(gt_video, gen_video): 42 | gen_video_shape = gen_video.shape 43 | T_gen, _, H_gen, W_gen = gen_video_shape 44 | T_eval, _, H_eval, W_eval = gt_video.shape 45 | 46 | if T_eval < T_gen: 47 | raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).") 48 | 49 | if H_eval < H_gen or W_eval < W_gen: 50 | # Resize the video maintaining the aspect ratio 51 | resize_height = max(H_gen, int(H_gen * (H_eval / W_eval))) 52 | resize_width = max(W_gen, int(W_gen * (W_eval / H_eval))) 53 | gt_video = resize_video(gt_video, resize_height, resize_width) 54 | # Recalculate the dimensions 55 | T_eval, _, H_eval, W_eval = gt_video.shape 56 | 57 | # Center crop 58 | start_h = (H_eval - H_gen) // 2 59 | start_w = (W_eval - W_gen) // 2 60 | cropped_video = gt_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen] 61 | 62 | return cropped_video 63 | 64 | 65 | def get_video_ids(gt_video_dirs, gen_video_dirs): 66 | video_ids = [] 67 | for f in os.listdir(gt_video_dirs[0]): 68 | if f.endswith(f".mp4"): 69 | video_ids.append(f.replace(f".mp4", "")) 70 | video_ids.sort() 71 | 72 | for video_dir in gt_video_dirs + gen_video_dirs: 73 | tmp_video_ids = [] 74 | for f in os.listdir(video_dir): 75 | if f.endswith(f".mp4"): 76 | tmp_video_ids.append(f.replace(f".mp4", "")) 77 | tmp_video_ids.sort() 78 | if tmp_video_ids != video_ids: 79 | raise ValueError(f"Video IDs in {video_dir} are different.") 80 | return video_ids 81 | 82 | 83 | def get_videos(video_ids, gt_video_dirs, gen_video_dirs): 84 | gt_videos = {} 85 | generated_videos = {} 86 | 87 | for gt_video_dir in gt_video_dirs: 88 | tmp_gt_videos_tensor = [] 89 | for video_id in video_ids: 90 | gt_video = load_video(os.path.join(gt_video_dir, f"{video_id}.mp4")) 91 | tmp_gt_videos_tensor.append(gt_video) 92 | gt_videos[gt_video_dir] = tmp_gt_videos_tensor 93 | 94 | for generated_video_dir in gen_video_dirs: 95 | tmp_generated_videos_tensor = [] 96 | for video_id in video_ids: 97 | generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.mp4")) 98 | tmp_generated_videos_tensor.append(generated_video) 99 | generated_videos[generated_video_dir] = tmp_generated_videos_tensor 100 | 101 | return gt_videos, generated_videos 102 | 103 | 104 | def print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs): 105 | out_str = "" 106 | 107 | for gt_video_dir in gt_video_dirs: 108 | for generated_video_dir in gen_video_dirs: 109 | if gt_video_dir == generated_video_dir: 110 | continue 111 | lpips = sum(lpips_results[gt_video_dir][generated_video_dir]) / len( 112 | lpips_results[gt_video_dir][generated_video_dir] 113 | ) 114 | psnr = 
sum(psnr_results[gt_video_dir][generated_video_dir]) / len( 115 | psnr_results[gt_video_dir][generated_video_dir] 116 | ) 117 | ssim = sum(ssim_results[gt_video_dir][generated_video_dir]) / len( 118 | ssim_results[gt_video_dir][generated_video_dir] 119 | ) 120 | out_str += f"\ngt: {gt_video_dir} -> gen: {generated_video_dir}, lpips: {lpips:.4f}, psnr: {psnr:.4f}, ssim: {ssim:.4f}" 121 | 122 | return out_str 123 | 124 | 125 | def main(args): 126 | device = "cuda" 127 | gt_video_dirs = args.gt_video_dirs 128 | gen_video_dirs = args.gen_video_dirs 129 | 130 | video_ids = get_video_ids(gt_video_dirs, gen_video_dirs) 131 | print(f"Find {len(video_ids)} videos") 132 | 133 | prompt_interval = 1 134 | batch_size = 8 135 | calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True 136 | 137 | lpips_results = {} 138 | psnr_results = {} 139 | ssim_results = {} 140 | for gt_video_dir in gt_video_dirs: 141 | lpips_results[gt_video_dir] = {} 142 | psnr_results[gt_video_dir] = {} 143 | ssim_results[gt_video_dir] = {} 144 | for generated_video_dir in gen_video_dirs: 145 | lpips_results[gt_video_dir][generated_video_dir] = [] 146 | psnr_results[gt_video_dir][generated_video_dir] = [] 147 | ssim_results[gt_video_dir][generated_video_dir] = [] 148 | 149 | total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0) 150 | 151 | for idx in tqdm.tqdm(range(total_len)): 152 | video_ids_batch = video_ids[idx * batch_size : (idx + 1) * batch_size] 153 | gt_videos, generated_videos = get_videos(video_ids_batch, gt_video_dirs, gen_video_dirs) 154 | 155 | for gt_video_dir, gt_videos_tensor in gt_videos.items(): 156 | for generated_video_dir, generated_videos_tensor in generated_videos.items(): 157 | if gt_video_dir == generated_video_dir: 158 | continue 159 | 160 | if not isinstance(gt_videos_tensor, torch.Tensor): 161 | for i in range(len(gt_videos_tensor)): 162 | gt_videos_tensor[i] = resize_gt_video(gt_videos_tensor[i], generated_videos_tensor[0]) 163 | gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu() 164 | 165 | generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu() 166 | 167 | if calculate_lpips_flag: 168 | result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device) 169 | result = result["value"].values() 170 | result = float(sum(result) / len(result)) 171 | lpips_results[gt_video_dir][generated_video_dir].append(result) 172 | 173 | if calculate_psnr_flag: 174 | result = calculate_psnr(gt_videos_tensor, generated_videos_tensor) 175 | result = result["value"].values() 176 | result = float(sum(result) / len(result)) 177 | psnr_results[gt_video_dir][generated_video_dir].append(result) 178 | 179 | if calculate_ssim_flag: 180 | result = calculate_ssim(gt_videos_tensor, generated_videos_tensor) 181 | result = result["value"].values() 182 | result = float(sum(result) / len(result)) 183 | ssim_results[gt_video_dir][generated_video_dir].append(result) 184 | 185 | if (idx + 1) % prompt_interval == 0: 186 | out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs) 187 | print(f"Processed {idx + 1} / {total_len} videos. {out_str}") 188 | 189 | out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs) 190 | 191 | # save 192 | with open(f"./batch_eval.txt", "w+") as f: 193 | f.write(out_str) 194 | 195 | print(f"Processed all videos. 
{out_str}") 196 | 197 | 198 | if __name__ == "__main__": 199 | parser = argparse.ArgumentParser() 200 | parser.add_argument("--gt_video_dirs", type=str, nargs="+") 201 | parser.add_argument("--gen_video_dirs", type=str, nargs="+") 202 | 203 | args = parser.parse_args() 204 | 205 | main(args) 206 | -------------------------------------------------------------------------------- /eval/teacache/common_metrics/calculate_lpips.py: -------------------------------------------------------------------------------- 1 | import lpips 2 | import numpy as np 3 | import torch 4 | 5 | spatial = True # Return a spatial map of perceptual distance. 6 | 7 | # Linearly calibrated models (LPIPS) 8 | loss_fn = lpips.LPIPS(net="alex", spatial=spatial) # Can also set net = 'squeeze' or 'vgg' 9 | # loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg' 10 | 11 | 12 | def trans(x): 13 | # if greyscale images add channel 14 | if x.shape[-3] == 1: 15 | x = x.repeat(1, 1, 3, 1, 1) 16 | 17 | # value range [0, 1] -> [-1, 1] 18 | x = x * 2 - 1 19 | 20 | return x 21 | 22 | 23 | def calculate_lpips(videos1, videos2, device): 24 | # image should be RGB, IMPORTANT: normalized to [-1,1] 25 | 26 | assert videos1.shape == videos2.shape 27 | 28 | # videos [batch_size, timestamps, channel, h, w] 29 | 30 | # support grayscale input, if grayscale -> channel*3 31 | # value range [0, 1] -> [-1, 1] 32 | videos1 = trans(videos1) 33 | videos2 = trans(videos2) 34 | 35 | lpips_results = [] 36 | 37 | for video_num in range(videos1.shape[0]): 38 | # get a video 39 | # video [timestamps, channel, h, w] 40 | video1 = videos1[video_num] 41 | video2 = videos2[video_num] 42 | 43 | lpips_results_of_a_video = [] 44 | for clip_timestamp in range(len(video1)): 45 | # get a img 46 | # img [timestamps[x], channel, h, w] 47 | # img [channel, h, w] tensor 48 | 49 | img1 = video1[clip_timestamp].unsqueeze(0).to(device) 50 | img2 = video2[clip_timestamp].unsqueeze(0).to(device) 51 | 52 | loss_fn.to(device) 53 | 54 | # calculate lpips of a video 55 | lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist()) 56 | lpips_results.append(lpips_results_of_a_video) 57 | 58 | lpips_results = np.array(lpips_results) 59 | 60 | lpips = {} 61 | lpips_std = {} 62 | 63 | for clip_timestamp in range(len(video1)): 64 | lpips[clip_timestamp] = np.mean(lpips_results[:, clip_timestamp]) 65 | lpips_std[clip_timestamp] = np.std(lpips_results[:, clip_timestamp]) 66 | 67 | result = { 68 | "value": lpips, 69 | "value_std": lpips_std, 70 | "video_setting": video1.shape, 71 | "video_setting_name": "time, channel, heigth, width", 72 | } 73 | 74 | return result 75 | 76 | 77 | # test code / using example 78 | 79 | 80 | def main(): 81 | NUMBER_OF_VIDEOS = 8 82 | VIDEO_LENGTH = 50 83 | CHANNEL = 3 84 | SIZE = 64 85 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 86 | videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 87 | device = torch.device("cuda") 88 | # device = torch.device("cpu") 89 | 90 | import json 91 | 92 | result = calculate_lpips(videos1, videos2, device) 93 | print(json.dumps(result, indent=4)) 94 | 95 | 96 | if __name__ == "__main__": 97 | main() 98 | -------------------------------------------------------------------------------- /eval/teacache/common_metrics/calculate_psnr.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as 
np 4 | import torch 5 | 6 | 7 | def img_psnr(img1, img2): 8 | # [0,1] 9 | # compute mse 10 | # mse = np.mean((img1-img2)**2) 11 | mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2) 12 | # compute psnr 13 | if mse < 1e-10: 14 | return 100 15 | psnr = 20 * math.log10(1 / math.sqrt(mse)) 16 | return psnr 17 | 18 | 19 | def trans(x): 20 | return x 21 | 22 | 23 | def calculate_psnr(videos1, videos2): 24 | # videos [batch_size, timestamps, channel, h, w] 25 | 26 | assert videos1.shape == videos2.shape 27 | 28 | videos1 = trans(videos1) 29 | videos2 = trans(videos2) 30 | 31 | psnr_results = [] 32 | 33 | for video_num in range(videos1.shape[0]): 34 | # get a video 35 | # video [timestamps, channel, h, w] 36 | video1 = videos1[video_num] 37 | video2 = videos2[video_num] 38 | 39 | psnr_results_of_a_video = [] 40 | for clip_timestamp in range(len(video1)): 41 | # get a img 42 | # img [timestamps[x], channel, h, w] 43 | # img [channel, h, w] numpy 44 | 45 | img1 = video1[clip_timestamp].numpy() 46 | img2 = video2[clip_timestamp].numpy() 47 | 48 | # calculate psnr of a video 49 | psnr_results_of_a_video.append(img_psnr(img1, img2)) 50 | 51 | psnr_results.append(psnr_results_of_a_video) 52 | 53 | psnr_results = np.array(psnr_results) 54 | 55 | psnr = {} 56 | psnr_std = {} 57 | 58 | for clip_timestamp in range(len(video1)): 59 | psnr[clip_timestamp] = np.mean(psnr_results[:, clip_timestamp]) 60 | psnr_std[clip_timestamp] = np.std(psnr_results[:, clip_timestamp]) 61 | 62 | result = { 63 | "value": psnr, 64 | "value_std": psnr_std, 65 | "video_setting": video1.shape, 66 | "video_setting_name": "time, channel, heigth, width", 67 | } 68 | 69 | return result 70 | 71 | 72 | # test code / using example 73 | 74 | 75 | def main(): 76 | NUMBER_OF_VIDEOS = 8 77 | VIDEO_LENGTH = 50 78 | CHANNEL = 3 79 | SIZE = 64 80 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 81 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 82 | 83 | import json 84 | 85 | result = calculate_psnr(videos1, videos2) 86 | print(json.dumps(result, indent=4)) 87 | 88 | 89 | if __name__ == "__main__": 90 | main() 91 | -------------------------------------------------------------------------------- /eval/teacache/common_metrics/calculate_ssim.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def ssim(img1, img2): 7 | C1 = 0.01**2 8 | C2 = 0.03**2 9 | img1 = img1.astype(np.float64) 10 | img2 = img2.astype(np.float64) 11 | kernel = cv2.getGaussianKernel(11, 1.5) 12 | window = np.outer(kernel, kernel.transpose()) 13 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 14 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 15 | mu1_sq = mu1**2 16 | mu2_sq = mu2**2 17 | mu1_mu2 = mu1 * mu2 18 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 19 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 20 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 21 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 22 | return ssim_map.mean() 23 | 24 | 25 | def calculate_ssim_function(img1, img2): 26 | # [0,1] 27 | # ssim is the only metric extremely sensitive to gray being compared to b/w 28 | if not img1.shape == img2.shape: 29 | raise ValueError("Input images must have the same dimensions.") 30 | if img1.ndim == 2: 31 | return ssim(img1, img2) 32 | 
elif img1.ndim == 3: 33 | if img1.shape[0] == 3: 34 | ssims = [] 35 | for i in range(3): 36 | ssims.append(ssim(img1[i], img2[i])) 37 | return np.array(ssims).mean() 38 | elif img1.shape[0] == 1: 39 | return ssim(np.squeeze(img1), np.squeeze(img2)) 40 | else: 41 | raise ValueError("Wrong input image dimensions.") 42 | 43 | 44 | def trans(x): 45 | return x 46 | 47 | 48 | def calculate_ssim(videos1, videos2): 49 | # videos [batch_size, timestamps, channel, h, w] 50 | 51 | assert videos1.shape == videos2.shape 52 | 53 | videos1 = trans(videos1) 54 | videos2 = trans(videos2) 55 | 56 | ssim_results = [] 57 | 58 | for video_num in range(videos1.shape[0]): 59 | # get a video 60 | # video [timestamps, channel, h, w] 61 | video1 = videos1[video_num] 62 | video2 = videos2[video_num] 63 | 64 | ssim_results_of_a_video = [] 65 | for clip_timestamp in range(len(video1)): 66 | # get a img 67 | # img [timestamps[x], channel, h, w] 68 | # img [channel, h, w] numpy 69 | 70 | img1 = video1[clip_timestamp].numpy() 71 | img2 = video2[clip_timestamp].numpy() 72 | 73 | # calculate ssim of a video 74 | ssim_results_of_a_video.append(calculate_ssim_function(img1, img2)) 75 | 76 | ssim_results.append(ssim_results_of_a_video) 77 | 78 | ssim_results = np.array(ssim_results) 79 | 80 | ssim = {} 81 | ssim_std = {} 82 | 83 | for clip_timestamp in range(len(video1)): 84 | ssim[clip_timestamp] = np.mean(ssim_results[:, clip_timestamp]) 85 | ssim_std[clip_timestamp] = np.std(ssim_results[:, clip_timestamp]) 86 | 87 | result = { 88 | "value": ssim, 89 | "value_std": ssim_std, 90 | "video_setting": video1.shape, 91 | "video_setting_name": "time, channel, heigth, width", 92 | } 93 | 94 | return result 95 | 96 | 97 | # test code / using example 98 | 99 | 100 | def main(): 101 | NUMBER_OF_VIDEOS = 8 102 | VIDEO_LENGTH = 50 103 | CHANNEL = 3 104 | SIZE = 64 105 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 106 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False) 107 | torch.device("cuda") 108 | 109 | import json 110 | 111 | result = calculate_ssim(videos1, videos2) 112 | print(json.dumps(result, indent=4)) 113 | 114 | 115 | if __name__ == "__main__": 116 | main() 117 | -------------------------------------------------------------------------------- /eval/teacache/common_metrics/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import imageio 5 | import torch 6 | import torchvision.transforms.functional as F 7 | import tqdm 8 | from calculate_lpips import calculate_lpips 9 | from calculate_psnr import calculate_psnr 10 | from calculate_ssim import calculate_ssim 11 | 12 | 13 | def load_videos(directory, video_ids, file_extension): 14 | videos = [] 15 | for video_id in video_ids: 16 | video_path = os.path.join(directory, f"{video_id}.{file_extension}") 17 | if os.path.exists(video_path): 18 | video = load_video(video_path) # Define load_video based on how videos are stored 19 | videos.append(video) 20 | else: 21 | raise ValueError(f"Video {video_id}.{file_extension} not found in {directory}") 22 | return videos 23 | 24 | 25 | def load_video(video_path): 26 | """ 27 | Load a video from the given path and convert it to a PyTorch tensor. 
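    Returns a (T, C, H, W) uint8 tensor on the GPU, one frame per time step.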
28 | """ 29 | # Read the video using imageio 30 | reader = imageio.get_reader(video_path, "ffmpeg") 31 | 32 | # Extract frames and convert to a list of tensors 33 | frames = [] 34 | for frame in reader: 35 | # Convert the frame to a tensor and permute the dimensions to match (C, H, W) 36 | frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1) 37 | frames.append(frame_tensor) 38 | 39 | # Stack the list of tensors into a single tensor with shape (T, C, H, W) 40 | video_tensor = torch.stack(frames) 41 | 42 | return video_tensor 43 | 44 | 45 | def resize_video(video, target_height, target_width): 46 | resized_frames = [] 47 | for frame in video: 48 | resized_frame = F.resize(frame, [target_height, target_width]) 49 | resized_frames.append(resized_frame) 50 | return torch.stack(resized_frames) 51 | 52 | 53 | def preprocess_eval_video(eval_video, generated_video_shape): 54 | T_gen, _, H_gen, W_gen = generated_video_shape 55 | T_eval, _, H_eval, W_eval = eval_video.shape 56 | 57 | if T_eval < T_gen: 58 | raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).") 59 | 60 | if H_eval < H_gen or W_eval < W_gen: 61 | # Resize the video maintaining the aspect ratio 62 | resize_height = max(H_gen, int(H_gen * (H_eval / W_eval))) 63 | resize_width = max(W_gen, int(W_gen * (W_eval / H_eval))) 64 | eval_video = resize_video(eval_video, resize_height, resize_width) 65 | # Recalculate the dimensions 66 | T_eval, _, H_eval, W_eval = eval_video.shape 67 | 68 | # Center crop 69 | start_h = (H_eval - H_gen) // 2 70 | start_w = (W_eval - W_gen) // 2 71 | cropped_video = eval_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen] 72 | 73 | return cropped_video 74 | 75 | 76 | def main(args): 77 | device = "cuda" 78 | gt_video_dir = args.gt_video_dir 79 | generated_video_dir = args.generated_video_dir 80 | 81 | video_ids = [] 82 | file_extension = "mp4" 83 | for f in os.listdir(generated_video_dir): 84 | if f.endswith(f".{file_extension}"): 85 | video_ids.append(f.replace(f".{file_extension}", "")) 86 | if not video_ids: 87 | raise ValueError("No videos found in the generated video dataset. 
Exiting.") 88 | 89 | print(f"Find {len(video_ids)} videos") 90 | prompt_interval = 1 91 | batch_size = 16 92 | calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True 93 | 94 | lpips_results = [] 95 | psnr_results = [] 96 | ssim_results = [] 97 | 98 | total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0) 99 | 100 | for idx, video_id in enumerate(tqdm.tqdm(range(total_len))): 101 | gt_videos_tensor = [] 102 | generated_videos_tensor = [] 103 | for i in range(batch_size): 104 | video_idx = idx * batch_size + i 105 | if video_idx >= len(video_ids): 106 | break 107 | video_id = video_ids[video_idx] 108 | generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.{file_extension}")) 109 | generated_videos_tensor.append(generated_video) 110 | eval_video = load_video(os.path.join(gt_video_dir, f"{video_id}.{file_extension}")) 111 | gt_videos_tensor.append(eval_video) 112 | gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu() 113 | generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu() 114 | 115 | if calculate_lpips_flag: 116 | result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device) 117 | result = result["value"].values() 118 | result = sum(result) / len(result) 119 | lpips_results.append(result) 120 | 121 | if calculate_psnr_flag: 122 | result = calculate_psnr(gt_videos_tensor, generated_videos_tensor) 123 | result = result["value"].values() 124 | result = sum(result) / len(result) 125 | psnr_results.append(result) 126 | 127 | if calculate_ssim_flag: 128 | result = calculate_ssim(gt_videos_tensor, generated_videos_tensor) 129 | result = result["value"].values() 130 | result = sum(result) / len(result) 131 | ssim_results.append(result) 132 | 133 | if (idx + 1) % prompt_interval == 0: 134 | out_str = "" 135 | for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]): 136 | result = sum(results) / len(results) 137 | out_str += f"{name}: {result:.4f}, " 138 | print(f"Processed {idx + 1} videos. {out_str[:-2]}") 139 | 140 | out_str = "" 141 | for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]): 142 | result = sum(results) / len(results) 143 | out_str += f"{name}: {result:.4f}, " 144 | out_str = out_str[:-2] 145 | 146 | # save 147 | with open(f"./{os.path.basename(generated_video_dir)}.txt", "w+") as f: 148 | f.write(out_str) 149 | 150 | print(f"Processed all videos. 
{out_str}") 151 | 152 | 153 | if __name__ == "__main__": 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument("--gt_video_dir", type=str) 156 | parser.add_argument("--generated_video_dir", type=str) 157 | 158 | args = parser.parse_args() 159 | 160 | main(args) 161 | -------------------------------------------------------------------------------- /eval/teacache/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/eval/teacache/experiments/__init__.py -------------------------------------------------------------------------------- /eval/teacache/experiments/cogvideox.py: -------------------------------------------------------------------------------- 1 | from utils import generate_func, read_prompt_list 2 | from videosys import CogVideoXConfig, VideoSysEngine 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import rearrange, repeat 6 | import numpy as np 7 | from typing import Any, Dict, Optional, Tuple, Union 8 | from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence 9 | from videosys.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput 10 | from videosys.utils.utils import batch_func 11 | from functools import partial 12 | from diffusers.utils import is_torch_version 13 | 14 | def teacache_forward( 15 | self, 16 | hidden_states: torch.Tensor, 17 | encoder_hidden_states: torch.Tensor, 18 | timestep: Union[int, float, torch.LongTensor], 19 | timestep_cond: Optional[torch.Tensor] = None, 20 | image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 21 | return_dict: bool = True, 22 | all_timesteps=None 23 | ): 24 | if self.parallel_manager.cp_size > 1: 25 | ( 26 | hidden_states, 27 | encoder_hidden_states, 28 | timestep, 29 | timestep_cond, 30 | image_rotary_emb, 31 | ) = batch_func( 32 | partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0), 33 | hidden_states, 34 | encoder_hidden_states, 35 | timestep, 36 | timestep_cond, 37 | image_rotary_emb, 38 | ) 39 | 40 | batch_size, num_frames, channels, height, width = hidden_states.shape 41 | 42 | # 1. Time embedding 43 | timesteps = timestep 44 | org_timestep = timestep 45 | t_emb = self.time_proj(timesteps) 46 | 47 | # timesteps does not contain any weights and will always return f32 tensors 48 | # but time_embedding might actually be running in fp16. so we need to cast here. 49 | # there might be better ways to encapsulate this. 50 | t_emb = t_emb.to(dtype=hidden_states.dtype) 51 | emb = self.time_embedding(t_emb, timestep_cond) 52 | 53 | # 2. Patch embedding 54 | hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) 55 | 56 | # 3. 
Position embedding 57 | text_seq_length = encoder_hidden_states.shape[1] 58 | if not self.config.use_rotary_positional_embeddings: 59 | seq_length = height * width * num_frames // (self.config.patch_size**2) 60 | 61 | pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length] 62 | hidden_states = hidden_states + pos_embeds 63 | hidden_states = self.embedding_dropout(hidden_states) 64 | 65 | encoder_hidden_states = hidden_states[:, :text_seq_length] 66 | hidden_states = hidden_states[:, text_seq_length:] 67 | 68 | if self.enable_teacache: 69 | if org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]: 70 | should_calc = True 71 | self.accumulated_rel_l1_distance = 0 72 | else: 73 | if not self.config.use_rotary_positional_embeddings: 74 | # CogVideoX-2B 75 | coefficients = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03] 76 | else: 77 | # CogVideoX-5B 78 | coefficients = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02] 79 | rescale_func = np.poly1d(coefficients) 80 | self.accumulated_rel_l1_distance += rescale_func(((emb-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) 81 | if self.accumulated_rel_l1_distance < self.rel_l1_thresh: 82 | should_calc = False 83 | else: 84 | should_calc = True 85 | self.accumulated_rel_l1_distance = 0 86 | self.previous_modulated_input = emb 87 | 88 | if self.enable_teacache: 89 | if not should_calc: 90 | hidden_states += self.previous_residual 91 | encoder_hidden_states += self.previous_residual_encoder 92 | else: 93 | if self.parallel_manager.sp_size > 1: 94 | set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group) 95 | hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad")) 96 | ori_hidden_states = hidden_states.clone() 97 | ori_encoder_hidden_states = encoder_hidden_states.clone() 98 | # 4. Transformer blocks 99 | for i, block in enumerate(self.transformer_blocks): 100 | if self.training and self.gradient_checkpointing: 101 | 102 | def create_custom_forward(module): 103 | def custom_forward(*inputs): 104 | return module(*inputs) 105 | 106 | return custom_forward 107 | 108 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 109 | hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( 110 | create_custom_forward(block), 111 | hidden_states, 112 | encoder_hidden_states, 113 | emb, 114 | image_rotary_emb, 115 | **ckpt_kwargs, 116 | ) 117 | else: 118 | hidden_states, encoder_hidden_states = block( 119 | hidden_states=hidden_states, 120 | encoder_hidden_states=encoder_hidden_states, 121 | temb=emb, 122 | image_rotary_emb=image_rotary_emb, 123 | timestep=timesteps if False else None, 124 | ) 125 | self.previous_residual = hidden_states - ori_hidden_states 126 | self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states 127 | else: 128 | if self.parallel_manager.sp_size > 1: 129 | set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group) 130 | hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad")) 131 | 132 | # 4. 
Transformer blocks 133 | for i, block in enumerate(self.transformer_blocks): 134 | if self.training and self.gradient_checkpointing: 135 | 136 | def create_custom_forward(module): 137 | def custom_forward(*inputs): 138 | return module(*inputs) 139 | 140 | return custom_forward 141 | 142 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 143 | hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( 144 | create_custom_forward(block), 145 | hidden_states, 146 | encoder_hidden_states, 147 | emb, 148 | image_rotary_emb, 149 | **ckpt_kwargs, 150 | ) 151 | else: 152 | hidden_states, encoder_hidden_states = block( 153 | hidden_states=hidden_states, 154 | encoder_hidden_states=encoder_hidden_states, 155 | temb=emb, 156 | image_rotary_emb=image_rotary_emb, 157 | timestep=timesteps if False else None, 158 | ) 159 | 160 | if self.parallel_manager.sp_size > 1: 161 | if self.enable_teacache: 162 | if should_calc: 163 | hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad")) 164 | self.previous_residual = gather_sequence(self.previous_residual, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad")) 165 | else: 166 | hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad")) 167 | 168 | if not self.config.use_rotary_positional_embeddings: 169 | # CogVideoX-2B 170 | hidden_states = self.norm_final(hidden_states) 171 | else: 172 | # CogVideoX-5B 173 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 174 | hidden_states = self.norm_final(hidden_states) 175 | hidden_states = hidden_states[:, text_seq_length:] 176 | 177 | # 5. Final block 178 | hidden_states = self.norm_out(hidden_states, temb=emb) 179 | hidden_states = self.proj_out(hidden_states) 180 | 181 | # 6. 
Unpatchify 182 | p = self.config.patch_size 183 | output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p) 184 | output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) 185 | 186 | if self.parallel_manager.cp_size > 1: 187 | output = gather_sequence(output, self.parallel_manager.cp_group, dim=0) 188 | 189 | if not return_dict: 190 | return (output,) 191 | return Transformer2DModelOutput(sample=output) 192 | 193 | 194 | def eval_teacache_slow(prompt_list): 195 | config = CogVideoXConfig() 196 | engine = VideoSysEngine(config) 197 | engine.driver_worker.transformer.__class__.enable_teacache = True 198 | engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.1 199 | engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0 200 | engine.driver_worker.transformer.__class__.previous_modulated_input = None 201 | engine.driver_worker.transformer.__class__.previous_residual = None 202 | engine.driver_worker.transformer.__class__.previous_residual_encoder = None 203 | engine.driver_worker.transformer.__class__.forward = teacache_forward 204 | generate_func(engine, prompt_list, "./samples/cogvideox_teacache_slow", loop=5) 205 | 206 | def eval_teacache_fast(prompt_list): 207 | config = CogVideoXConfig() 208 | engine = VideoSysEngine(config) 209 | engine.driver_worker.transformer.__class__.enable_teacache = True 210 | engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.2 211 | engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0 212 | engine.driver_worker.transformer.__class__.previous_modulated_input = None 213 | engine.driver_worker.transformer.__class__.previous_residual = None 214 | engine.driver_worker.transformer.__class__.previous_residual_encoder = None 215 | engine.driver_worker.transformer.__class__.forward = teacache_forward 216 | generate_func(engine, prompt_list, "./samples/cogvideox_teacache_fast", loop=5) 217 | 218 | 219 | def eval_base(prompt_list): 220 | config = CogVideoXConfig() 221 | engine = VideoSysEngine(config) 222 | generate_func(engine, prompt_list, "./samples/cogvideox_base", loop=5) 223 | 224 | 225 | if __name__ == "__main__": 226 | prompt_list = read_prompt_list("vbench/VBench_full_info.json") 227 | eval_base(prompt_list) 228 | eval_teacache_slow(prompt_list) 229 | eval_teacache_fast(prompt_list) 230 | -------------------------------------------------------------------------------- /eval/teacache/experiments/opensora.py: -------------------------------------------------------------------------------- 1 | from utils import generate_func, read_prompt_list 2 | from videosys import OpenSoraConfig, VideoSysEngine 3 | import torch 4 | from einops import rearrange 5 | from videosys.models.transformers.open_sora_transformer_3d import t2i_modulate, auto_grad_checkpoint 6 | from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence 7 | import numpy as np 8 | from videosys.utils.utils import batch_func 9 | from functools import partial 10 | 11 | def teacache_forward( 12 | self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs 13 | ): 14 | # === Split batch === 15 | if self.parallel_manager.cp_size > 1: 16 | x, timestep, y, x_mask, mask = batch_func( 17 | partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0), 18 | x, 19 | timestep, 20 | y, 21 | x_mask, 22 | mask, 23 | ) 24 | 25 | dtype = self.x_embedder.proj.weight.dtype 26 | B = x.size(0) 27 | x = x.to(dtype) 28 
| timestep = timestep.to(dtype) 29 | y = y.to(dtype) 30 | 31 | # === get pos embed === 32 | _, _, Tx, Hx, Wx = x.size() 33 | T, H, W = self.get_dynamic_size(x) 34 | S = H * W 35 | base_size = round(S**0.5) 36 | resolution_sq = (height[0].item() * width[0].item()) ** 0.5 37 | scale = resolution_sq / self.input_sq_size 38 | pos_emb = self.pos_embed(x, H, W, scale=scale, base_size=base_size) 39 | 40 | # === get timestep embed === 41 | t = self.t_embedder(timestep, dtype=x.dtype) # [B, C] 42 | fps = self.fps_embedder(fps.unsqueeze(1), B) 43 | t = t + fps 44 | t_mlp = self.t_block(t) 45 | t0 = t0_mlp = None 46 | if x_mask is not None: 47 | t0_timestep = torch.zeros_like(timestep) 48 | t0 = self.t_embedder(t0_timestep, dtype=x.dtype) 49 | t0 = t0 + fps 50 | t0_mlp = self.t_block(t0) 51 | 52 | # === get y embed === 53 | if self.config.skip_y_embedder: 54 | y_lens = mask 55 | if isinstance(y_lens, torch.Tensor): 56 | y_lens = y_lens.long().tolist() 57 | else: 58 | y, y_lens = self.encode_text(y, mask) 59 | 60 | # === get x embed === 61 | x = self.x_embedder(x) # [B, N, C] 62 | x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S) 63 | x = x + pos_emb 64 | 65 | if self.enable_teacache: 66 | inp = x.clone() 67 | inp = rearrange(inp, "B T S C -> B (T S) C", T=T, S=S) 68 | B, N, C = inp.shape 69 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 70 | self.spatial_blocks[0].scale_shift_table[None] + t_mlp.reshape(B, 6, -1) 71 | ).chunk(6, dim=1) 72 | modulated_inp = t2i_modulate(self.spatial_blocks[0].norm1(inp), shift_msa, scale_msa) 73 | if timestep[0] == all_timesteps[0] or timestep[0] == all_timesteps[-1]: 74 | should_calc = True 75 | self.accumulated_rel_l1_distance = 0 76 | else: 77 | coefficients = [2.17546007e+02, -1.18329252e+02, 2.68662585e+01, -4.59364272e-02, 4.84426240e-02] 78 | rescale_func = np.poly1d(coefficients) 79 | self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) 80 | if self.accumulated_rel_l1_distance < self.rel_l1_thresh: 81 | should_calc = False 82 | else: 83 | should_calc = True 84 | self.accumulated_rel_l1_distance = 0 85 | self.previous_modulated_input = modulated_inp 86 | 87 | # === blocks === 88 | if self.enable_teacache: 89 | if not should_calc: 90 | x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S) 91 | x += self.previous_residual 92 | else: 93 | # shard over the sequence dim if sp is enabled 94 | if self.parallel_manager.sp_size > 1: 95 | set_pad("temporal", T, self.parallel_manager.sp_group) 96 | set_pad("spatial", S, self.parallel_manager.sp_group) 97 | x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")) 98 | T = x.shape[1] 99 | x_mask_org = x_mask 100 | x_mask = split_sequence( 101 | x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal") 102 | ) 103 | x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S) 104 | origin_x = x.clone().detach() 105 | for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks): 106 | x = auto_grad_checkpoint( 107 | spatial_block, 108 | x, 109 | y, 110 | t_mlp, 111 | y_lens, 112 | x_mask, 113 | t0_mlp, 114 | T, 115 | S, 116 | timestep, 117 | all_timesteps=all_timesteps, 118 | ) 119 | 120 | x = auto_grad_checkpoint( 121 | temporal_block, 122 | x, 123 | y, 124 | t_mlp, 125 | y_lens, 126 | x_mask, 127 | t0_mlp, 128 | T, 129 | S, 130 | timestep, 131 | all_timesteps=all_timesteps, 132 | ) 133 | 
self.previous_residual = x - origin_x 134 | else: 135 | # shard over the sequence dim if sp is enabled 136 | if self.parallel_manager.sp_size > 1: 137 | set_pad("temporal", T, self.parallel_manager.sp_group) 138 | set_pad("spatial", S, self.parallel_manager.sp_group) 139 | x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")) 140 | T = x.shape[1] 141 | x_mask_org = x_mask 142 | x_mask = split_sequence( 143 | x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal") 144 | ) 145 | x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S) 146 | 147 | for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks): 148 | x = auto_grad_checkpoint( 149 | spatial_block, 150 | x, 151 | y, 152 | t_mlp, 153 | y_lens, 154 | x_mask, 155 | t0_mlp, 156 | T, 157 | S, 158 | timestep, 159 | all_timesteps=all_timesteps, 160 | ) 161 | 162 | x = auto_grad_checkpoint( 163 | temporal_block, 164 | x, 165 | y, 166 | t_mlp, 167 | y_lens, 168 | x_mask, 169 | t0_mlp, 170 | T, 171 | S, 172 | timestep, 173 | all_timesteps=all_timesteps, 174 | ) 175 | 176 | if self.parallel_manager.sp_size > 1: 177 | if self.enable_teacache: 178 | if should_calc: 179 | x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S) 180 | self.previous_residual = rearrange(self.previous_residual, "B (T S) C -> B T S C", T=T, S=S) 181 | x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal")) 182 | self.previous_residual = gather_sequence(self.previous_residual, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal")) 183 | T, S = x.shape[1], x.shape[2] 184 | x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S) 185 | self.previous_residual = rearrange(self.previous_residual, "B T S C -> B (T S) C", T=T, S=S) 186 | x_mask = x_mask_org 187 | else: 188 | x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S) 189 | x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal")) 190 | T, S = x.shape[1], x.shape[2] 191 | x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S) 192 | x_mask = x_mask_org 193 | # === final layer === 194 | x = self.final_layer(x, t, x_mask, t0, T, S) 195 | x = self.unpatchify(x, T, H, W, Tx, Hx, Wx) 196 | 197 | # cast to float32 for better accuracy 198 | x = x.to(torch.float32) 199 | 200 | # === Gather Output === 201 | if self.parallel_manager.cp_size > 1: 202 | x = gather_sequence(x, self.parallel_manager.cp_group, dim=0) 203 | 204 | return x 205 | 206 | def eval_base(prompt_list): 207 | config = OpenSoraConfig() 208 | engine = VideoSysEngine(config) 209 | generate_func(engine, prompt_list, "./samples/opensora_base", loop=5) 210 | 211 | def eval_teacache_slow(prompt_list): 212 | config = OpenSoraConfig() 213 | engine = VideoSysEngine(config) 214 | engine.driver_worker.transformer.__class__.enable_teacache = True 215 | engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.1 216 | engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0 217 | engine.driver_worker.transformer.__class__.previous_modulated_input = None 218 | engine.driver_worker.transformer.__class__.previous_residual = None 219 | engine.driver_worker.transformer.__class__.forward = teacache_forward 220 | generate_func(engine, prompt_list, "./samples/opensora_teacache_slow", loop=5) 221 | 222 | def eval_teacache_fast(prompt_list): 223 | config = OpenSoraConfig() 224 | engine = VideoSysEngine(config) 225 | 
engine.driver_worker.transformer.__class__.enable_teacache = True 226 | engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.2 227 | engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0 228 | engine.driver_worker.transformer.__class__.previous_modulated_input = None 229 | engine.driver_worker.transformer.__class__.previous_residual = None 230 | engine.driver_worker.transformer.__class__.forward = teacache_forward 231 | generate_func(engine, prompt_list, "./samples/opensora_teacache_fast", loop=5) 232 | 233 | 234 | if __name__ == "__main__": 235 | prompt_list = read_prompt_list("vbench/VBench_full_info.json") 236 | eval_base(prompt_list) 237 | eval_teacache_slow(prompt_list) 238 | eval_teacache_fast(prompt_list) 239 | -------------------------------------------------------------------------------- /eval/teacache/experiments/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import tqdm 5 | 6 | from videosys.utils.utils import set_seed 7 | 8 | 9 | def generate_func(pipeline, prompt_list, output_dir, loop: int = 5, kwargs: dict = {}): 10 | kwargs["verbose"] = False 11 | for prompt in tqdm.tqdm(prompt_list): 12 | for l in range(loop): 13 | video = pipeline.generate(prompt, seed=l, **kwargs).video[0] 14 | pipeline.save_video(video, os.path.join(output_dir, f"{prompt}-{l}.mp4")) 15 | 16 | 17 | def read_prompt_list(prompt_list_path): 18 | with open(prompt_list_path, "r") as f: 19 | prompt_list = json.load(f) 20 | prompt_list = [prompt["prompt_en"] for prompt in prompt_list] 21 | return prompt_list 22 | -------------------------------------------------------------------------------- /eval/teacache/vbench/cal_vbench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | SEMANTIC_WEIGHT = 1 6 | QUALITY_WEIGHT = 4 7 | 8 | QUALITY_LIST = [ 9 | "subject consistency", 10 | "background consistency", 11 | "temporal flickering", 12 | "motion smoothness", 13 | "aesthetic quality", 14 | "imaging quality", 15 | "dynamic degree", 16 | ] 17 | 18 | SEMANTIC_LIST = [ 19 | "object class", 20 | "multiple objects", 21 | "human action", 22 | "color", 23 | "spatial relationship", 24 | "scene", 25 | "appearance style", 26 | "temporal style", 27 | "overall consistency", 28 | ] 29 | 30 | NORMALIZE_DIC = { 31 | "subject consistency": {"Min": 0.1462, "Max": 1.0}, 32 | "background consistency": {"Min": 0.2615, "Max": 1.0}, 33 | "temporal flickering": {"Min": 0.6293, "Max": 1.0}, 34 | "motion smoothness": {"Min": 0.706, "Max": 0.9975}, 35 | "dynamic degree": {"Min": 0.0, "Max": 1.0}, 36 | "aesthetic quality": {"Min": 0.0, "Max": 1.0}, 37 | "imaging quality": {"Min": 0.0, "Max": 1.0}, 38 | "object class": {"Min": 0.0, "Max": 1.0}, 39 | "multiple objects": {"Min": 0.0, "Max": 1.0}, 40 | "human action": {"Min": 0.0, "Max": 1.0}, 41 | "color": {"Min": 0.0, "Max": 1.0}, 42 | "spatial relationship": {"Min": 0.0, "Max": 1.0}, 43 | "scene": {"Min": 0.0, "Max": 0.8222}, 44 | "appearance style": {"Min": 0.0009, "Max": 0.2855}, 45 | "temporal style": {"Min": 0.0, "Max": 0.364}, 46 | "overall consistency": {"Min": 0.0, "Max": 0.364}, 47 | } 48 | 49 | DIM_WEIGHT = { 50 | "subject consistency": 1, 51 | "background consistency": 1, 52 | "temporal flickering": 1, 53 | "motion smoothness": 1, 54 | "aesthetic quality": 1, 55 | "imaging quality": 1, 56 | "dynamic degree": 0.5, 57 | "object class": 1, 58 | "multiple objects": 1, 59 | "human action": 1, 
60 | "color": 1, 61 | "spatial relationship": 1, 62 | "scene": 1, 63 | "appearance style": 1, 64 | "temporal style": 1, 65 | "overall consistency": 1, 66 | } 67 | 68 | ordered_scaled_res = [ 69 | "total score", 70 | "quality score", 71 | "semantic score", 72 | "subject consistency", 73 | "background consistency", 74 | "temporal flickering", 75 | "motion smoothness", 76 | "dynamic degree", 77 | "aesthetic quality", 78 | "imaging quality", 79 | "object class", 80 | "multiple objects", 81 | "human action", 82 | "color", 83 | "spatial relationship", 84 | "scene", 85 | "appearance style", 86 | "temporal style", 87 | "overall consistency", 88 | ] 89 | 90 | 91 | def parse_args(): 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument("--score_dir", required=True, type=str) 94 | args = parser.parse_args() 95 | return args 96 | 97 | 98 | if __name__ == "__main__": 99 | args = parse_args() 100 | res_postfix = "_eval_results.json" 101 | info_postfix = "_full_info.json" 102 | files = os.listdir(args.score_dir) 103 | res_files = [x for x in files if res_postfix in x] 104 | info_files = [x for x in files if info_postfix in x] 105 | assert len(res_files) == len(info_files), f"got {len(res_files)} res files, but {len(info_files)} info files" 106 | 107 | full_results = {} 108 | for res_file in res_files: 109 | # first check if results is normal 110 | info_file = res_file.split(res_postfix)[0] + info_postfix 111 | with open(os.path.join(args.score_dir, info_file), "r", encoding="utf-8") as f: 112 | info = json.load(f) 113 | assert len(info[0]["video_list"]) > 0, f"Error: {info_file} has 0 video list" 114 | # read results 115 | with open(os.path.join(args.score_dir, res_file), "r", encoding="utf-8") as f: 116 | data = json.load(f) 117 | for key, val in data.items(): 118 | full_results[key] = format(val[0], ".4f") 119 | 120 | scaled_results = {} 121 | dims = set() 122 | for key, val in full_results.items(): 123 | dim = key.replace("_", " ") if "_" in key else key 124 | scaled_score = (float(val) - NORMALIZE_DIC[dim]["Min"]) / ( 125 | NORMALIZE_DIC[dim]["Max"] - NORMALIZE_DIC[dim]["Min"] 126 | ) 127 | scaled_score *= DIM_WEIGHT[dim] 128 | scaled_results[dim] = scaled_score 129 | dims.add(dim) 130 | 131 | assert len(dims) == len(NORMALIZE_DIC), f"{set(NORMALIZE_DIC.keys())-dims} not calculated yet" 132 | 133 | quality_score = sum([scaled_results[i] for i in QUALITY_LIST]) / sum([DIM_WEIGHT[i] for i in QUALITY_LIST]) 134 | semantic_score = sum([scaled_results[i] for i in SEMANTIC_LIST]) / sum([DIM_WEIGHT[i] for i in SEMANTIC_LIST]) 135 | scaled_results["quality score"] = quality_score 136 | scaled_results["semantic score"] = semantic_score 137 | scaled_results["total score"] = (quality_score * QUALITY_WEIGHT + semantic_score * SEMANTIC_WEIGHT) / ( 138 | QUALITY_WEIGHT + SEMANTIC_WEIGHT 139 | ) 140 | 141 | formated_scaled_results = {"items": []} 142 | for key in ordered_scaled_res: 143 | formated_score = format(scaled_results[key] * 100, ".2f") + "%" 144 | formated_scaled_results["items"].append({key: formated_score}) 145 | 146 | output_file_path = os.path.join(args.score_dir, "all_results.json") 147 | with open(output_file_path, "w") as outfile: 148 | json.dump(full_results, outfile, indent=4, sort_keys=True) 149 | print(f"results saved to: {output_file_path}") 150 | 151 | scaled_file_path = os.path.join(args.score_dir, "scaled_results.json") 152 | with open(scaled_file_path, "w") as outfile: 153 | json.dump(formated_scaled_results, outfile, indent=4, sort_keys=True) 154 | print(f"results saved to: 
{scaled_file_path}") 155 | -------------------------------------------------------------------------------- /eval/teacache/vbench/run_vbench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from vbench import VBench 5 | 6 | full_info_path = "./vbench/VBench_full_info.json" 7 | 8 | dimensions = [ 9 | "subject_consistency", 10 | "imaging_quality", 11 | "background_consistency", 12 | "motion_smoothness", 13 | "overall_consistency", 14 | "human_action", 15 | "multiple_objects", 16 | "spatial_relationship", 17 | "object_class", 18 | "color", 19 | "aesthetic_quality", 20 | "appearance_style", 21 | "temporal_flickering", 22 | "scene", 23 | "temporal_style", 24 | "dynamic_degree", 25 | ] 26 | 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--video_path", required=True, type=str) 31 | parser.add_argument("--save_path", required=True, type=str) 32 | args = parser.parse_args() 33 | return args 34 | 35 | 36 | if __name__ == "__main__": 37 | args = parse_args() 38 | 39 | kwargs = {} 40 | kwargs["imaging_quality_preprocessing_mode"] = "longer" # use VBench/evaluate.py default 41 | 42 | for dimension in dimensions: 43 | my_VBench = VBench(torch.device("cuda"), full_info_path, args.save_path) 44 | my_VBench.evaluate( 45 | videos_path=args.video_path, 46 | name=dimension, 47 | local=False, 48 | read_frame=False, 49 | dimension_list=[dimension], 50 | mode="vbench_standard", 51 | **kwargs, 52 | ) 53 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate>0.17.0 2 | bs4 3 | click 4 | colossalai==0.4.0 5 | diffusers==0.30.0 6 | einops 7 | fabric 8 | ftfy 9 | imageio 10 | imageio-ffmpeg 11 | matplotlib 12 | ninja 13 | numpy<2.0.0 14 | omegaconf 15 | packaging 16 | psutil 17 | pydantic 18 | ray 19 | rich 20 | safetensors 21 | sentencepiece 22 | timm 23 | torch>=1.13 24 | tqdm 25 | peft==0.13.2 26 | transformers==4.39.3 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from setuptools import find_packages, setup 4 | from setuptools.command.develop import develop 5 | from setuptools.command.egg_info import egg_info 6 | from setuptools.command.install import install 7 | 8 | 9 | def fetch_requirements(path) -> List[str]: 10 | """ 11 | This function reads the requirements file. 12 | 13 | Args: 14 | path (str): the path to the requirements file. 15 | 16 | Returns: 17 | The lines in the requirements file. 18 | """ 19 | with open(path, "r") as fd: 20 | requirements = [r.strip() for r in fd.readlines()] 21 | # requirements.remove("colossalai") 22 | return requirements 23 | 24 | 25 | def fetch_readme() -> str: 26 | """ 27 | This function reads the README.md file in the current directory. 28 | 29 | Returns: 30 | The lines in the README file. 
31 | """ 32 | with open("README.md", encoding="utf-8") as f: 33 | return f.read() 34 | 35 | 36 | def custom_install(): 37 | return ["pip", "install", "colossalai", "--no-deps"] 38 | 39 | 40 | class CustomInstallCommand(install): 41 | def run(self): 42 | install.run(self) 43 | self.spawn(custom_install()) 44 | 45 | 46 | class CustomDevelopCommand(develop): 47 | def run(self): 48 | develop.run(self) 49 | self.spawn(custom_install()) 50 | 51 | 52 | class CustomEggInfoCommand(egg_info): 53 | def run(self): 54 | egg_info.run(self) 55 | self.spawn(custom_install()) 56 | 57 | 58 | setup( 59 | name="videosys", 60 | version="2.0.0", 61 | packages=find_packages( 62 | exclude=( 63 | "videos", 64 | "tests", 65 | "figure", 66 | "*.egg-info", 67 | ) 68 | ), 69 | description="TeaCache", 70 | long_description=fetch_readme(), 71 | long_description_content_type="text/markdown", 72 | license="Apache Software License 2.0", 73 | install_requires=fetch_requirements("requirements.txt"), 74 | python_requires=">=3.7", 75 | # cmdclass={ 76 | # "install": CustomInstallCommand, 77 | # "develop": CustomDevelopCommand, 78 | # "egg_info": CustomEggInfoCommand, 79 | # }, 80 | classifiers=[ 81 | "Programming Language :: Python :: 3", 82 | "License :: OSI Approved :: Apache Software License", 83 | "Environment :: GPU :: NVIDIA CUDA", 84 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 85 | "Topic :: System :: Distributed Computing", 86 | ], 87 | ) 88 | -------------------------------------------------------------------------------- /videosys/__init__.py: -------------------------------------------------------------------------------- 1 | from .core.engine import VideoSysEngine 2 | from .core.parallel_mgr import initialize 3 | from .pipelines.cogvideox import CogVideoXConfig, CogVideoXPABConfig, CogVideoXPipeline 4 | from .pipelines.latte import LatteConfig, LattePABConfig, LattePipeline 5 | from .pipelines.open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline 6 | from .pipelines.open_sora_plan import ( 7 | OpenSoraPlanConfig, 8 | OpenSoraPlanPipeline, 9 | OpenSoraPlanV110PABConfig, 10 | OpenSoraPlanV120PABConfig, 11 | ) 12 | from .pipelines.vchitect import VchitectConfig, VchitectPABConfig, VchitectXLPipeline 13 | 14 | __all__ = [ 15 | "initialize", 16 | "VideoSysEngine", 17 | "LattePipeline", "LatteConfig", "LattePABConfig", 18 | "OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanV110PABConfig", "OpenSoraPlanV120PABConfig", 19 | "OpenSoraPipeline", "OpenSoraConfig", "OpenSoraPABConfig", 20 | "CogVideoXPipeline", "CogVideoXConfig", "CogVideoXPABConfig", 21 | "VchitectXLPipeline", "VchitectConfig", "VchitectPABConfig" 22 | ] # fmt: skip 23 | -------------------------------------------------------------------------------- /videosys/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/core/__init__.py -------------------------------------------------------------------------------- /videosys/core/engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | from typing import Any, Optional 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | import videosys 9 | 10 | from .mp_utils import ProcessWorkerWrapper, ResultHandler, WorkerMonitor, get_distributed_init_method, get_open_port 11 | 12 | 13 | class VideoSysEngine: 14 | """ 15 | this is partly 
inspired by vllm 16 | """ 17 | 18 | def __init__(self, config): 19 | self.config = config 20 | self.parallel_worker_tasks = None 21 | self._init_worker(config.pipeline_cls) 22 | 23 | def _init_worker(self, pipeline_cls): 24 | world_size = self.config.num_gpus 25 | 26 | # Disable torch async compiling which won't work with daemonic processes 27 | os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" 28 | 29 | # Set OMP_NUM_THREADS to 1 if it is not set explicitly, avoids CPU 30 | # contention amongst the shards 31 | if "OMP_NUM_THREADS" not in os.environ: 32 | os.environ["OMP_NUM_THREADS"] = "1" 33 | 34 | # NOTE: The two following lines need adaption for multi-node 35 | assert world_size <= torch.cuda.device_count() 36 | 37 | # change addr for multi-node 38 | distributed_init_method = get_distributed_init_method("127.0.0.1", get_open_port()) 39 | 40 | if world_size == 1: 41 | self.workers = [] 42 | self.worker_monitor = None 43 | else: 44 | result_handler = ResultHandler() 45 | self.workers = [ 46 | ProcessWorkerWrapper( 47 | result_handler, 48 | partial( 49 | self._create_pipeline, 50 | pipeline_cls=pipeline_cls, 51 | rank=rank, 52 | local_rank=rank, 53 | distributed_init_method=distributed_init_method, 54 | ), 55 | ) 56 | for rank in range(1, world_size) 57 | ] 58 | 59 | self.worker_monitor = WorkerMonitor(self.workers, result_handler) 60 | result_handler.start() 61 | self.worker_monitor.start() 62 | 63 | self.driver_worker = self._create_pipeline( 64 | pipeline_cls=pipeline_cls, distributed_init_method=distributed_init_method 65 | ) 66 | 67 | # TODO: add more options here for pipeline, or wrap all options into config 68 | def _create_pipeline(self, pipeline_cls, rank=0, local_rank=0, distributed_init_method=None): 69 | videosys.initialize(rank=rank, world_size=self.config.num_gpus, init_method=distributed_init_method) 70 | 71 | pipeline = pipeline_cls(self.config) 72 | return pipeline 73 | 74 | def _run_workers( 75 | self, 76 | method: str, 77 | *args, 78 | async_run_tensor_parallel_workers_only: bool = False, 79 | max_concurrent_workers: Optional[int] = None, 80 | **kwargs, 81 | ) -> Any: 82 | """Runs the given method on all workers.""" 83 | 84 | # Start the workers first. 85 | worker_outputs = [worker.execute_method(method, *args, **kwargs) for worker in self.workers] 86 | 87 | if async_run_tensor_parallel_workers_only: 88 | # Just return futures 89 | return worker_outputs 90 | 91 | driver_worker_method = getattr(self.driver_worker, method) 92 | driver_worker_output = driver_worker_method(*args, **kwargs) 93 | 94 | # Get the results of the workers. 
95 | return [driver_worker_output] + [output.get() for output in worker_outputs] 96 | 97 | def _driver_execute_model(self, *args, **kwargs): 98 | return self.driver_worker.generate(*args, **kwargs) 99 | 100 | def generate(self, *args, **kwargs): 101 | return self._run_workers("generate", *args, **kwargs)[0] 102 | 103 | def stop_remote_worker_execution_loop(self) -> None: 104 | if self.parallel_worker_tasks is None: 105 | return 106 | 107 | parallel_worker_tasks = self.parallel_worker_tasks 108 | self.parallel_worker_tasks = None 109 | # Ensure that workers exit model loop cleanly 110 | # (this will raise otherwise) 111 | self._wait_for_tasks_completion(parallel_worker_tasks) 112 | 113 | def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: 114 | """Wait for futures returned from _run_workers() with 115 | async_run_tensor_parallel_workers_only to complete.""" 116 | for result in parallel_worker_tasks: 117 | result.get() 118 | 119 | def save_video(self, video, output_path): 120 | return self.driver_worker.save_video(video, output_path) 121 | 122 | def shutdown(self): 123 | if (worker_monitor := getattr(self, "worker_monitor", None)) is not None: 124 | worker_monitor.close() 125 | dist.destroy_process_group() 126 | 127 | def __del__(self): 128 | self.shutdown() 129 | -------------------------------------------------------------------------------- /videosys/core/mp_utils.py: -------------------------------------------------------------------------------- 1 | # adapted from vllm 2 | # https://github.com/vllm-project/vllm/blob/main/vllm/executor/multiproc_worker_utils.py 3 | 4 | import asyncio 5 | import multiprocessing 6 | import os 7 | import socket 8 | import sys 9 | import threading 10 | import traceback 11 | import uuid 12 | from dataclasses import dataclass 13 | from multiprocessing import Queue 14 | from multiprocessing.connection import wait 15 | from typing import Any, Callable, Dict, Generic, List, Optional, TextIO, TypeVar, Union 16 | 17 | from videosys.utils.logging import create_logger 18 | 19 | T = TypeVar("T") 20 | _TERMINATE = "TERMINATE" # sentinel 21 | # ANSI color codes 22 | CYAN = "\033[1;36m" 23 | RESET = "\033[0;0m" 24 | JOIN_TIMEOUT_S = 2 25 | 26 | mp_method = "spawn" # fork can't work here 27 | mp = multiprocessing.get_context(mp_method) 28 | 29 | logger = create_logger() 30 | 31 | 32 | def get_distributed_init_method(ip: str, port: int) -> str: 33 | # Brackets are not permitted in ipv4 addresses, 34 | # see https://github.com/python/cpython/issues/103848 35 | return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" 36 | 37 | 38 | def get_open_port() -> int: 39 | # try ipv4 40 | try: 41 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 42 | s.bind(("", 0)) 43 | return s.getsockname()[1] 44 | except OSError: 45 | # try ipv6 46 | with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: 47 | s.bind(("", 0)) 48 | return s.getsockname()[1] 49 | 50 | 51 | @dataclass 52 | class Result(Generic[T]): 53 | """Result of task dispatched to worker""" 54 | 55 | task_id: uuid.UUID 56 | value: Optional[T] = None 57 | exception: Optional[BaseException] = None 58 | 59 | 60 | class ResultFuture(threading.Event, Generic[T]): 61 | """Synchronous future for non-async case""" 62 | 63 | def __init__(self): 64 | super().__init__() 65 | self.result: Optional[Result[T]] = None 66 | 67 | def set_result(self, result: Result[T]): 68 | self.result = result 69 | self.set() 70 | 71 | def get(self) -> T: 72 | self.wait() 73 | assert self.result is not None 74 | if 
self.result.exception is not None: 75 | raise self.result.exception 76 | return self.result.value # type: ignore[return-value] 77 | 78 | 79 | def _set_future_result(future: Union[ResultFuture, asyncio.Future], result: Result): 80 | if isinstance(future, ResultFuture): 81 | future.set_result(result) 82 | return 83 | loop = future.get_loop() 84 | if not loop.is_closed(): 85 | if result.exception is not None: 86 | loop.call_soon_threadsafe(future.set_exception, result.exception) 87 | else: 88 | loop.call_soon_threadsafe(future.set_result, result.value) 89 | 90 | 91 | class ResultHandler(threading.Thread): 92 | """Handle results from all workers (in background thread)""" 93 | 94 | def __init__(self) -> None: 95 | super().__init__(daemon=True) 96 | self.result_queue = mp.Queue() 97 | self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {} 98 | 99 | def run(self): 100 | for result in iter(self.result_queue.get, _TERMINATE): 101 | future = self.tasks.pop(result.task_id) 102 | _set_future_result(future, result) 103 | # Ensure that all waiters will receive an exception 104 | for task_id, future in self.tasks.items(): 105 | _set_future_result(future, Result(task_id=task_id, exception=ChildProcessError("worker died"))) 106 | 107 | def close(self): 108 | self.result_queue.put(_TERMINATE) 109 | 110 | 111 | class WorkerMonitor(threading.Thread): 112 | """Monitor worker status (in background thread)""" 113 | 114 | def __init__(self, workers: List["ProcessWorkerWrapper"], result_handler: ResultHandler): 115 | super().__init__(daemon=True) 116 | self.workers = workers 117 | self.result_handler = result_handler 118 | self._close = False 119 | 120 | def run(self) -> None: 121 | # Blocks until any worker exits 122 | dead_sentinels = wait([w.process.sentinel for w in self.workers]) 123 | if not self._close: 124 | self._close = True 125 | 126 | # Kill / cleanup all workers 127 | for worker in self.workers: 128 | process = worker.process 129 | if process.sentinel in dead_sentinels: 130 | process.join(JOIN_TIMEOUT_S) 131 | if process.exitcode is not None and process.exitcode != 0: 132 | logger.error("Worker %s pid %s died, exit code: %s", process.name, process.pid, process.exitcode) 133 | # Cleanup any remaining workers 134 | logger.info("Killing local worker processes") 135 | for worker in self.workers: 136 | worker.kill_worker() 137 | # Must be done after worker task queues are all closed 138 | self.result_handler.close() 139 | 140 | for worker in self.workers: 141 | worker.process.join(JOIN_TIMEOUT_S) 142 | 143 | def close(self): 144 | if self._close: 145 | return 146 | self._close = True 147 | logger.info("Terminating local worker processes") 148 | for worker in self.workers: 149 | worker.terminate_worker() 150 | # Must be done after worker task queues are all closed 151 | self.result_handler.close() 152 | 153 | 154 | def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: 155 | """Prepend each output line with process-specific prefix""" 156 | 157 | prefix = f"{CYAN}({worker_name} pid={pid}){RESET} " 158 | file_write = file.write 159 | 160 | def write_with_prefix(s: str): 161 | if not s: 162 | return 163 | if file.start_new_line: # type: ignore[attr-defined] 164 | file_write(prefix) 165 | idx = 0 166 | while (next_idx := s.find("\n", idx)) != -1: 167 | next_idx += 1 168 | file_write(s[idx:next_idx]) 169 | if next_idx == len(s): 170 | file.start_new_line = True # type: ignore[attr-defined] 171 | return 172 | file_write(prefix) 173 | idx = next_idx 174 | file_write(s[idx:]) 175 | 
file.start_new_line = False # type: ignore[attr-defined] 176 | 177 | file.start_new_line = True # type: ignore[attr-defined] 178 | file.write = write_with_prefix # type: ignore[method-assign] 179 | 180 | 181 | def _run_worker_process( 182 | worker_factory: Callable[[], Any], 183 | task_queue: Queue, 184 | result_queue: Queue, 185 | ) -> None: 186 | """Worker process event loop""" 187 | 188 | # Add process-specific prefix to stdout and stderr 189 | process_name = mp.current_process().name 190 | pid = os.getpid() 191 | _add_prefix(sys.stdout, process_name, pid) 192 | _add_prefix(sys.stderr, process_name, pid) 193 | 194 | # Initialize worker 195 | worker = worker_factory() 196 | del worker_factory 197 | 198 | # Accept tasks from the engine in task_queue 199 | # and return task output in result_queue 200 | logger.info("Worker ready; awaiting tasks") 201 | try: 202 | for items in iter(task_queue.get, _TERMINATE): 203 | output = None 204 | exception = None 205 | task_id, method, args, kwargs = items 206 | try: 207 | executor = getattr(worker, method) 208 | output = executor(*args, **kwargs) 209 | except BaseException as e: 210 | tb = traceback.format_exc() 211 | logger.error("Exception in worker %s while processing method %s: %s, %s", process_name, method, e, tb) 212 | exception = e 213 | result_queue.put(Result(task_id=task_id, value=output, exception=exception)) 214 | except KeyboardInterrupt: 215 | pass 216 | except Exception: 217 | logger.exception("Worker failed") 218 | 219 | logger.info("Worker exiting") 220 | 221 | 222 | class ProcessWorkerWrapper: 223 | """Local process wrapper for handling single-node multi-GPU.""" 224 | 225 | def __init__(self, result_handler: ResultHandler, worker_factory: Callable[[], Any]) -> None: 226 | self._task_queue = mp.Queue() 227 | self.result_queue = result_handler.result_queue 228 | self.tasks = result_handler.tasks 229 | self.process = mp.Process( # type: ignore[attr-defined] 230 | target=_run_worker_process, 231 | name="VideoSysWorkerProcess", 232 | kwargs=dict( 233 | worker_factory=worker_factory, 234 | task_queue=self._task_queue, 235 | result_queue=self.result_queue, 236 | ), 237 | daemon=True, 238 | ) 239 | 240 | self.process.start() 241 | 242 | def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], method: str, args, kwargs): 243 | task_id = uuid.uuid4() 244 | self.tasks[task_id] = future 245 | try: 246 | self._task_queue.put((task_id, method, args, kwargs)) 247 | except BaseException as e: 248 | del self.tasks[task_id] 249 | raise ChildProcessError("worker died") from e 250 | 251 | def execute_method(self, method: str, *args, **kwargs): 252 | future: ResultFuture = ResultFuture() 253 | self._enqueue_task(future, method, args, kwargs) 254 | return future 255 | 256 | async def execute_method_async(self, method: str, *args, **kwargs): 257 | future = asyncio.get_running_loop().create_future() 258 | self._enqueue_task(future, method, args, kwargs) 259 | return await future 260 | 261 | def terminate_worker(self): 262 | try: 263 | self._task_queue.put(_TERMINATE) 264 | except ValueError: 265 | self.process.kill() 266 | self._task_queue.close() 267 | 268 | def kill_worker(self): 269 | self._task_queue.close() 270 | self.process.kill() 271 | -------------------------------------------------------------------------------- /videosys/core/pab_mgr.py: -------------------------------------------------------------------------------- 1 | from videosys.utils.logging import logger 2 | 3 | PAB_MANAGER = None 4 | 5 | 6 | class PABConfig: 7 | def 
__init__( 8 | self, 9 | cross_broadcast: bool = False, 10 | cross_threshold: list = None, 11 | cross_range: int = None, 12 | spatial_broadcast: bool = False, 13 | spatial_threshold: list = None, 14 | spatial_range: int = None, 15 | temporal_broadcast: bool = False, 16 | temporal_threshold: list = None, 17 | temporal_range: int = None, 18 | mlp_broadcast: bool = False, 19 | mlp_spatial_broadcast_config: dict = None, 20 | mlp_temporal_broadcast_config: dict = None, 21 | ): 22 | self.steps = None 23 | 24 | self.cross_broadcast = cross_broadcast 25 | self.cross_threshold = cross_threshold 26 | self.cross_range = cross_range 27 | 28 | self.spatial_broadcast = spatial_broadcast 29 | self.spatial_threshold = spatial_threshold 30 | self.spatial_range = spatial_range 31 | 32 | self.temporal_broadcast = temporal_broadcast 33 | self.temporal_threshold = temporal_threshold 34 | self.temporal_range = temporal_range 35 | 36 | self.mlp_broadcast = mlp_broadcast 37 | self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config 38 | self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config 39 | self.mlp_temporal_outputs = {} 40 | self.mlp_spatial_outputs = {} 41 | 42 | 43 | class PABManager: 44 | def __init__(self, config: PABConfig): 45 | self.config: PABConfig = config 46 | 47 | init_prompt = f"Init Pyramid Attention Broadcast." 48 | init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}." 49 | init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}." 50 | init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}." 51 | init_prompt += f" mlp broadcast: {config.mlp_broadcast}." 
52 | logger.info(init_prompt) 53 | 54 | def if_broadcast_cross(self, timestep: int, count: int): 55 | if ( 56 | self.config.cross_broadcast 57 | and (timestep is not None) 58 | and (count % self.config.cross_range != 0) 59 | and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1]) 60 | ): 61 | flag = True 62 | else: 63 | flag = False 64 | count = (count + 1) % self.config.steps 65 | return flag, count 66 | 67 | def if_broadcast_temporal(self, timestep: int, count: int): 68 | if ( 69 | self.config.temporal_broadcast 70 | and (timestep is not None) 71 | and (count % self.config.temporal_range != 0) 72 | and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1]) 73 | ): 74 | flag = True 75 | else: 76 | flag = False 77 | count = (count + 1) % self.config.steps 78 | return flag, count 79 | 80 | def if_broadcast_spatial(self, timestep: int, count: int): 81 | if ( 82 | self.config.spatial_broadcast 83 | and (timestep is not None) 84 | and (count % self.config.spatial_range != 0) 85 | and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1]) 86 | ): 87 | flag = True 88 | else: 89 | flag = False 90 | count = (count + 1) % self.config.steps 91 | return flag, count 92 | 93 | @staticmethod 94 | def _is_t_in_skip_config(all_timesteps, timestep, config): 95 | is_t_in_skip_config = False 96 | skip_range = None 97 | for key in config: 98 | if key not in all_timesteps: 99 | continue 100 | index = all_timesteps.index(key) 101 | skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])] 102 | if timestep in skip_range: 103 | is_t_in_skip_config = True 104 | skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]] 105 | break 106 | return is_t_in_skip_config, skip_range 107 | 108 | def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False): 109 | if not self.config.mlp_broadcast: 110 | return False, None, False, None 111 | 112 | if is_temporal: 113 | cur_config = self.config.mlp_temporal_broadcast_config 114 | else: 115 | cur_config = self.config.mlp_spatial_broadcast_config 116 | 117 | is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config) 118 | next_flag = False 119 | if ( 120 | self.config.mlp_broadcast 121 | and (timestep is not None) 122 | and (timestep in cur_config) 123 | and (block_idx in cur_config[timestep]["block"]) 124 | ): 125 | flag = False 126 | next_flag = True 127 | count = count + 1 128 | elif ( 129 | self.config.mlp_broadcast 130 | and (timestep is not None) 131 | and (is_t_in_skip_config) 132 | and (block_idx in cur_config[skip_range[0]]["block"]) 133 | ): 134 | flag = True 135 | count = 0 136 | else: 137 | flag = False 138 | 139 | return flag, count, next_flag, skip_range 140 | 141 | def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False): 142 | if is_temporal: 143 | self.config.mlp_temporal_outputs[(timestep, block_idx)] = ff_output 144 | else: 145 | self.config.mlp_spatial_outputs[(timestep, block_idx)] = ff_output 146 | 147 | def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False): 148 | skip_start_t = skip_range[0] 149 | if is_temporal: 150 | skip_output = ( 151 | self.config.mlp_temporal_outputs.get((skip_start_t, block_idx), None) 152 | if self.config.mlp_temporal_outputs is not None 153 | else None 154 | ) 155 | else: 156 | skip_output = ( 157 | self.config.mlp_spatial_outputs.get((skip_start_t, block_idx), None) 158 | if 
self.config.mlp_spatial_outputs is not None 159 | else None 160 | ) 161 | 162 | if skip_output is not None: 163 | if timestep == skip_range[-1]: 164 | # TODO: save memory 165 | if is_temporal: 166 | del self.config.mlp_temporal_outputs[(skip_start_t, block_idx)] 167 | else: 168 | del self.config.mlp_spatial_outputs[(skip_start_t, block_idx)] 169 | else: 170 | raise ValueError( 171 | f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}" 172 | ) 173 | 174 | return skip_output 175 | 176 | def get_spatial_mlp_outputs(self): 177 | return self.config.mlp_spatial_outputs 178 | 179 | def get_temporal_mlp_outputs(self): 180 | return self.config.mlp_temporal_outputs 181 | 182 | 183 | def set_pab_manager(config: PABConfig): 184 | global PAB_MANAGER 185 | PAB_MANAGER = PABManager(config) 186 | 187 | 188 | def enable_pab(): 189 | if PAB_MANAGER is None: 190 | return False 191 | return ( 192 | PAB_MANAGER.config.cross_broadcast 193 | or PAB_MANAGER.config.spatial_broadcast 194 | or PAB_MANAGER.config.temporal_broadcast 195 | ) 196 | 197 | 198 | def update_steps(steps: int): 199 | if PAB_MANAGER is not None: 200 | PAB_MANAGER.config.steps = steps 201 | 202 | 203 | def if_broadcast_cross(timestep: int, count: int): 204 | if not enable_pab(): 205 | return False, count 206 | return PAB_MANAGER.if_broadcast_cross(timestep, count) 207 | 208 | 209 | def if_broadcast_temporal(timestep: int, count: int): 210 | if not enable_pab(): 211 | return False, count 212 | return PAB_MANAGER.if_broadcast_temporal(timestep, count) 213 | 214 | 215 | def if_broadcast_spatial(timestep: int, count: int): 216 | if not enable_pab(): 217 | return False, count 218 | return PAB_MANAGER.if_broadcast_spatial(timestep, count) 219 | 220 | 221 | def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False): 222 | if not enable_pab(): 223 | return False, count 224 | return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal) 225 | 226 | 227 | def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False): 228 | return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal) 229 | 230 | 231 | def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False): 232 | return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal) 233 | -------------------------------------------------------------------------------- /videosys/core/parallel_mgr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from colossalai.cluster.process_group_mesh import ProcessGroupMesh 4 | from torch.distributed import ProcessGroup 5 | 6 | from videosys.utils.logging import init_dist_logger, logger 7 | 8 | 9 | class ParallelManager(ProcessGroupMesh): 10 | def __init__(self, dp_size, cp_size, sp_size): 11 | super().__init__(dp_size, cp_size, sp_size) 12 | dp_axis, cp_axis, sp_axis = 0, 1, 2 13 | 14 | self.dp_size = dp_size 15 | self.dp_group: ProcessGroup = self.get_group_along_axis(dp_axis) 16 | self.dp_rank = dist.get_rank(self.dp_group) 17 | 18 | self.cp_size = cp_size 19 | if cp_size > 1: 20 | self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis) 21 | self.cp_rank = dist.get_rank(self.cp_group) 22 | else: 23 | self.cp_group = None 24 | self.cp_rank = None 25 | 26 | self.sp_size = sp_size 27 | if sp_size > 1: 28 | self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis) 29 | self.sp_rank = 
dist.get_rank(self.sp_group) 30 | else: 31 | self.sp_group = None 32 | self.sp_rank = None 33 | 34 | logger.info(f"Init parallel manager with dp_size: {dp_size}, cp_size: {cp_size}, sp_size: {sp_size}") 35 | 36 | 37 | def initialize( 38 | rank=0, 39 | world_size=1, 40 | init_method=None, 41 | ): 42 | if not dist.is_initialized(): 43 | try: 44 | dist.destroy_process_group() 45 | except Exception: 46 | pass 47 | dist.init_process_group(backend="nccl", init_method=init_method, world_size=world_size, rank=rank) 48 | torch.cuda.set_device(rank) 49 | init_dist_logger() 50 | torch.backends.cuda.matmul.allow_tf32 = True 51 | torch.backends.cudnn.allow_tf32 = True 52 | -------------------------------------------------------------------------------- /videosys/core/pipeline.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from abc import abstractmethod 3 | from dataclasses import dataclass 4 | 5 | import torch 6 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 7 | from diffusers.utils import BaseOutput 8 | 9 | 10 | class VideoSysPipeline(DiffusionPipeline): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | @staticmethod 15 | def set_eval_and_device(device: torch.device, *modules): 16 | modules = list(modules) 17 | for i in range(len(modules)): 18 | modules[i] = modules[i].eval() 19 | modules[i] = modules[i].to(device) 20 | 21 | @abstractmethod 22 | def generate(self, *args, **kwargs): 23 | pass 24 | 25 | def __call__(self, *args, **kwargs): 26 | """ 27 | In diffusers, it is a convention to call the pipeline object. 28 | But in VideoSys, we will use the generate method for better prompt. 29 | This is a wrapper for the generate method to support the diffusers usage. 30 | """ 31 | return self.generate(*args, **kwargs) 32 | 33 | @classmethod 34 | def _get_signature_keys(cls, obj): 35 | parameters = inspect.signature(obj.__init__).parameters 36 | required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} 37 | optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) 38 | expected_modules = set(required_parameters.keys()) - {"self"} 39 | # modify: remove the config module from the expected modules 40 | expected_modules = expected_modules - {"config"} 41 | 42 | optional_names = list(optional_parameters) 43 | for name in optional_names: 44 | if name in cls._optional_components: 45 | expected_modules.add(name) 46 | optional_parameters.remove(name) 47 | 48 | return expected_modules, optional_parameters 49 | 50 | 51 | @dataclass 52 | class VideoSysPipelineOutput(BaseOutput): 53 | video: torch.Tensor 54 | -------------------------------------------------------------------------------- /videosys/core/shardformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/core/shardformer/__init__.py -------------------------------------------------------------------------------- /videosys/core/shardformer/t5/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/core/shardformer/t5/__init__.py -------------------------------------------------------------------------------- /videosys/core/shardformer/t5/modeling.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class T5LayerNorm(nn.Module): 6 | def __init__(self, hidden_size, eps=1e-6): 7 | """ 8 | Construct a layernorm module in the T5 style. No bias and no subtraction of mean. 9 | """ 10 | super().__init__() 11 | self.weight = nn.Parameter(torch.ones(hidden_size)) 12 | self.variance_epsilon = eps 13 | 14 | def forward(self, hidden_states): 15 | # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean 16 | # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated 17 | # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for 18 | # half-precision inputs is done in fp32 19 | 20 | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) 21 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 22 | 23 | # convert into half-precision if necessary 24 | if self.weight.dtype in [torch.float16, torch.bfloat16]: 25 | hidden_states = hidden_states.to(self.weight.dtype) 26 | 27 | return self.weight * hidden_states 28 | 29 | @staticmethod 30 | def from_native_module(module, *args, **kwargs): 31 | assert module.__class__.__name__ == "FusedRMSNorm", ( 32 | "Recovering T5LayerNorm requires the original layer to be apex's Fused RMS Norm." 33 | "Apex's fused norm is automatically used by Hugging Face Transformers https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L265C5-L265C48" 34 | ) 35 | 36 | layer_norm = T5LayerNorm(module.normalized_shape, eps=module.eps) 37 | layer_norm.weight.data.copy_(module.weight.data) 38 | layer_norm = layer_norm.to(module.weight.device) 39 | return layer_norm 40 | -------------------------------------------------------------------------------- /videosys/core/shardformer/t5/policy.py: -------------------------------------------------------------------------------- 1 | from colossalai.shardformer.modeling.jit import get_jit_fused_dropout_add_func 2 | from colossalai.shardformer.modeling.t5 import get_jit_fused_T5_layer_ff_forward, get_T5_layer_self_attention_forward 3 | from colossalai.shardformer.policies.base_policy import Policy, SubModuleReplacementDescription 4 | 5 | 6 | class T5EncoderPolicy(Policy): 7 | def config_sanity_check(self): 8 | assert not self.shard_config.enable_tensor_parallelism 9 | assert not self.shard_config.enable_flash_attention 10 | 11 | def preprocess(self): 12 | return self.model 13 | 14 | def module_policy(self): 15 | from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerSelfAttention, T5Stack 16 | 17 | policy = {} 18 | 19 | # check whether apex is installed 20 | try: 21 | from apex.normalization import FusedRMSNorm # noqa 22 | from videosys.core.shardformer.t5.modeling import T5LayerNorm 23 | 24 | # recover hf from fused rms norm to T5 norm which is faster 25 | self.append_or_create_submodule_replacement( 26 | description=SubModuleReplacementDescription( 27 | suffix="layer_norm", 28 | target_module=T5LayerNorm, 29 | ), 30 | policy=policy, 31 | target_key=T5LayerFF, 32 | ) 33 | self.append_or_create_submodule_replacement( 34 | description=SubModuleReplacementDescription(suffix="layer_norm", target_module=T5LayerNorm), 35 | policy=policy, 36 | target_key=T5LayerSelfAttention, 37 | ) 38 | self.append_or_create_submodule_replacement( 39 | description=SubModuleReplacementDescription(suffix="final_layer_norm", 
target_module=T5LayerNorm), 40 | policy=policy, 41 | target_key=T5Stack, 42 | ) 43 | except (ImportError, ModuleNotFoundError): 44 | pass 45 | 46 | # use jit operator 47 | if self.shard_config.enable_jit_fused: 48 | self.append_or_create_method_replacement( 49 | description={ 50 | "forward": get_jit_fused_T5_layer_ff_forward(), 51 | "dropout_add": get_jit_fused_dropout_add_func(), 52 | }, 53 | policy=policy, 54 | target_key=T5LayerFF, 55 | ) 56 | self.append_or_create_method_replacement( 57 | description={ 58 | "forward": get_T5_layer_self_attention_forward(), 59 | "dropout_add": get_jit_fused_dropout_add_func(), 60 | }, 61 | policy=policy, 62 | target_key=T5LayerSelfAttention, 63 | ) 64 | 65 | return policy 66 | 67 | def postprocess(self): 68 | return self.model 69 | -------------------------------------------------------------------------------- /videosys/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/models/__init__.py -------------------------------------------------------------------------------- /videosys/models/autoencoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/models/autoencoders/__init__.py -------------------------------------------------------------------------------- /videosys/models/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/models/modules/__init__.py -------------------------------------------------------------------------------- /videosys/models/modules/activations.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | approx_gelu = lambda: nn.GELU(approximate="tanh") 4 | -------------------------------------------------------------------------------- /videosys/models/modules/downsampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class CogVideoXDownsample3D(nn.Module): 7 | # Todo: Wait for paper release. 8 | r""" 9 | A 3D Downsampling layer used in [CogVideoX]() by Tsinghua University & ZhipuAI 10 | 11 | Args: 12 | in_channels (`int`): 13 | Number of channels in the input image. 14 | out_channels (`int`): 15 | Number of channels produced by the convolution. 16 | kernel_size (`int`, defaults to `3`): 17 | Size of the convolving kernel. 18 | stride (`int`, defaults to `2`): 19 | Stride of the convolution. 20 | padding (`int`, defaults to `0`): 21 | Padding added to all four sides of the input. 22 | compress_time (`bool`, defaults to `False`): 23 | Whether or not to compress the time dimension. 
24 | """ 25 | 26 | def __init__( 27 | self, 28 | in_channels: int, 29 | out_channels: int, 30 | kernel_size: int = 3, 31 | stride: int = 2, 32 | padding: int = 0, 33 | compress_time: bool = False, 34 | ): 35 | super().__init__() 36 | 37 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) 38 | self.compress_time = compress_time 39 | 40 | def forward(self, x: torch.Tensor) -> torch.Tensor: 41 | if self.compress_time: 42 | batch_size, channels, frames, height, width = x.shape 43 | 44 | # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames) 45 | x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames) 46 | 47 | if x.shape[-1] % 2 == 1: 48 | x_first, x_rest = x[..., 0], x[..., 1:] 49 | if x_rest.shape[-1] > 0: 50 | # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2) 51 | x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) 52 | 53 | x = torch.cat([x_first[..., None], x_rest], dim=-1) 54 | # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width) 55 | x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2) 56 | else: 57 | # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2) 58 | x = F.avg_pool1d(x, kernel_size=2, stride=2) 59 | # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width) 60 | x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2) 61 | 62 | # Pad the tensor 63 | pad = (0, 1, 0, 1) 64 | x = F.pad(x, pad, mode="constant", value=0) 65 | batch_size, channels, frames, height, width = x.shape 66 | # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width) 67 | x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width) 68 | x = self.conv(x) 69 | # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width) 70 | x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) 71 | return x 72 | -------------------------------------------------------------------------------- /videosys/models/modules/normalization.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class LlamaRMSNorm(nn.Module): 9 | def __init__(self, hidden_size, eps=1e-6): 10 | """ 11 | LlamaRMSNorm is equivalent to T5LayerNorm 12 | """ 13 | super().__init__() 14 | self.weight = nn.Parameter(torch.ones(hidden_size)) 15 | self.variance_epsilon = eps 16 | 17 | def forward(self, hidden_states): 18 | input_dtype = hidden_states.dtype 19 | hidden_states = hidden_states.to(torch.float32) 20 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 21 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 22 | return self.weight * hidden_states.to(input_dtype) 23 | 24 | 25 | class CogVideoXLayerNormZero(nn.Module): 26 | def __init__( 27 | self, 28 
| conditioning_dim: int, 29 | embedding_dim: int, 30 | elementwise_affine: bool = True, 31 | eps: float = 1e-5, 32 | bias: bool = True, 33 | ) -> None: 34 | super().__init__() 35 | 36 | self.silu = nn.SiLU() 37 | self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias) 38 | self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) 39 | 40 | def forward( 41 | self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor 42 | ) -> Tuple[torch.Tensor, torch.Tensor]: 43 | shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1) 44 | hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] 45 | encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] 46 | return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :] 47 | 48 | 49 | class AdaLayerNorm(nn.Module): 50 | r""" 51 | Norm layer modified to incorporate timestep embeddings. 52 | 53 | Parameters: 54 | embedding_dim (`int`): The size of each embedding vector. 55 | num_embeddings (`int`, *optional*): The size of the embeddings dictionary. 56 | output_dim (`int`, *optional*): The dimension of the projected conditioning; defaults to `2 * embedding_dim`. 57 | norm_elementwise_affine (`bool`, defaults to `False`): Whether the layer norm has learnable elementwise affine parameters. 58 | norm_eps (`float`, defaults to `1e-5`): The epsilon used by the layer norm. 59 | chunk_dim (`int`, defaults to `0`): The dimension along which the projected conditioning is chunked into scale and shift. 60 | """ 61 | 62 | def __init__( 63 | self, 64 | embedding_dim: int, 65 | num_embeddings: Optional[int] = None, 66 | output_dim: Optional[int] = None, 67 | norm_elementwise_affine: bool = False, 68 | norm_eps: float = 1e-5, 69 | chunk_dim: int = 0, 70 | ): 71 | super().__init__() 72 | 73 | self.chunk_dim = chunk_dim 74 | output_dim = output_dim or embedding_dim * 2 75 | 76 | if num_embeddings is not None: 77 | self.emb = nn.Embedding(num_embeddings, embedding_dim) 78 | else: 79 | self.emb = None 80 | 81 | self.silu = nn.SiLU() 82 | self.linear = nn.Linear(embedding_dim, output_dim) 83 | self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) 84 | 85 | def forward( 86 | self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None 87 | ) -> torch.Tensor: 88 | if self.emb is not None: 89 | temb = self.emb(timestep) 90 | 91 | temb = self.linear(self.silu(temb)) 92 | 93 | if self.chunk_dim == 1: 94 | # Note: the chunk order is "shift, scale" here but "scale, shift" in the other 95 | # if-branch. This branch is specific to CogVideoX for now. 96 | shift, scale = temb.chunk(2, dim=1) 97 | shift = shift[:, None, :] 98 | scale = scale[:, None, :] 99 | else: 100 | scale, shift = temb.chunk(2, dim=0) 101 | 102 | x = self.norm(x) * (1 + scale) + shift 103 | return x 104 | 105 | 106 | class VchitectSpatialNorm(nn.Module): 107 | """ 108 | Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. 109 | 110 | Args: 111 | f_channels (`int`): 112 | The number of channels of the input to the group normalization layer, and of the output of the spatial norm layer. 113 | zq_channels (`int`): 114 | The number of channels for the quantized vector as described in the paper.
115 | """ 116 | 117 | def __init__( 118 | self, 119 | f_channels: int, 120 | zq_channels: int, 121 | ): 122 | super().__init__() 123 | self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) 124 | self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) 125 | self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) 126 | 127 | def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: 128 | f_size = f.shape[-2:] 129 | zq = F.interpolate(zq, size=f_size, mode="nearest") 130 | norm_f = self.norm_layer(f) 131 | new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) 132 | return new_f 133 | -------------------------------------------------------------------------------- /videosys/models/modules/upsampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class CogVideoXUpsample3D(nn.Module): 7 | r""" 8 | A 3D Upsample layer used in CogVideoX by Tsinghua University & ZhipuAI. # Todo: Wait for paper release. 9 | 10 | Args: 11 | in_channels (`int`): 12 | Number of channels in the input image. 13 | out_channels (`int`): 14 | Number of channels produced by the convolution. 15 | kernel_size (`int`, defaults to `3`): 16 | Size of the convolving kernel. 17 | stride (`int`, defaults to `1`): 18 | Stride of the convolution. 19 | padding (`int`, defaults to `1`): 20 | Padding added to all four sides of the input. 21 | compress_time (`bool`, defaults to `False`): 22 | Whether or not to compress the time dimension. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | in_channels: int, 28 | out_channels: int, 29 | kernel_size: int = 3, 30 | stride: int = 1, 31 | padding: int = 1, 32 | compress_time: bool = False, 33 | ) -> None: 34 | super().__init__() 35 | 36 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) 37 | self.compress_time = compress_time 38 | 39 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 40 | if self.compress_time: 41 | if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1: 42 | # split first frame 43 | x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:] 44 | 45 | x_first = F.interpolate(x_first, scale_factor=2.0) 46 | x_rest = F.interpolate(x_rest, scale_factor=2.0) 47 | x_first = x_first[:, :, None, :, :] 48 | inputs = torch.cat([x_first, x_rest], dim=2) 49 | elif inputs.shape[2] > 1: 50 | inputs = F.interpolate(inputs, scale_factor=2.0) 51 | else: 52 | inputs = inputs.squeeze(2) 53 | inputs = F.interpolate(inputs, scale_factor=2.0) 54 | inputs = inputs[:, :, None, :, :] 55 | else: 56 | # only interpolate 2D 57 | b, c, t, h, w = inputs.shape 58 | inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) 59 | inputs = F.interpolate(inputs, scale_factor=2.0) 60 | inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4) 61 | 62 | b, c, t, h, w = inputs.shape 63 | inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) 64 | inputs = self.conv(inputs) 65 | inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4) 66 | 67 | return inputs 68 | -------------------------------------------------------------------------------- /videosys/models/transformers/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/models/transformers/__init__.py -------------------------------------------------------------------------------- /videosys/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/pipelines/__init__.py -------------------------------------------------------------------------------- /videosys/pipelines/cogvideox/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_cogvideox import CogVideoXConfig, CogVideoXPABConfig, CogVideoXPipeline 2 | 3 | __all__ = ["CogVideoXConfig", "CogVideoXPipeline", "CogVideoXPABConfig"] 4 | -------------------------------------------------------------------------------- /videosys/pipelines/latte/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_latte import LatteConfig, LattePABConfig, LattePipeline 2 | 3 | __all__ = ["LatteConfig", "LattePipeline", "LattePABConfig"] 4 | -------------------------------------------------------------------------------- /videosys/pipelines/open_sora/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline 2 | 3 | __all__ = ["OpenSoraConfig", "OpenSoraPipeline", "OpenSoraPABConfig"] 4 | -------------------------------------------------------------------------------- /videosys/pipelines/open_sora_plan/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_open_sora_plan import ( 2 | OpenSoraPlanConfig, 3 | OpenSoraPlanPipeline, 4 | OpenSoraPlanV110PABConfig, 5 | OpenSoraPlanV120PABConfig, 6 | ) 7 | 8 | __all__ = ["OpenSoraPlanConfig", "OpenSoraPlanPipeline", "OpenSoraPlanV110PABConfig", "OpenSoraPlanV120PABConfig"] 9 | -------------------------------------------------------------------------------- /videosys/pipelines/vchitect/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_vchitect import VchitectConfig, VchitectPABConfig, VchitectXLPipeline 2 | 3 | __all__ = ["VchitectXLPipeline", "VchitectConfig", "VchitectPABConfig"] 4 | -------------------------------------------------------------------------------- /videosys/schedulers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/schedulers/__init__.py -------------------------------------------------------------------------------- /videosys/schedulers/scheduling_rflow_open_sora.py: -------------------------------------------------------------------------------- 1 | # Adapted from OpenSora 2 | 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | # -------------------------------------------------------- 6 | # References: 7 | # OpenSora: https://github.com/hpcaitech/Open-Sora 8 | # -------------------------------------------------------- 9 | 10 | import torch 11 | import torch.distributed as dist 12 | from einops import rearrange 13 | from torch.distributions import LogisticNormal 14 | from tqdm import tqdm 15 | 16 | 17 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 18 | """ 19 | Extract values from a 1-D numpy array for a batch of indices. 20 | :param arr: the 1-D numpy array. 21 | :param timesteps: a tensor of indices into the array to extract. 22 | :param broadcast_shape: a larger shape of K dimensions with the batch 23 | dimension equal to the length of timesteps. 24 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 25 | """ 26 | res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 27 | while len(res.shape) < len(broadcast_shape): 28 | res = res[..., None] 29 | return res + torch.zeros(broadcast_shape, device=timesteps.device) 30 | 31 | 32 | def mean_flat(tensor: torch.Tensor, mask=None): 33 | """ 34 | Take the mean over all non-batch dimensions. 35 | """ 36 | if mask is None: 37 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 38 | else: 39 | assert tensor.dim() == 5 40 | assert tensor.shape[2] == mask.shape[1] 41 | tensor = rearrange(tensor, "b c t h w -> b t (c h w)") 42 | denom = mask.sum(dim=1) * tensor.shape[-1] 43 | loss = (tensor * mask.unsqueeze(2)).sum(dim=1).sum(dim=1) / denom 44 | return loss 45 | 46 | 47 | def timestep_transform( 48 | t, 49 | model_kwargs, 50 | base_resolution=512 * 512, 51 | base_num_frames=1, 52 | scale=1.0, 53 | num_timesteps=1, 54 | ): 55 | t = t / num_timesteps 56 | resolution = model_kwargs["height"] * model_kwargs["width"] 57 | ratio_space = (resolution / base_resolution).sqrt() 58 | # NOTE: currently, we do not take fps into account 59 | # NOTE: temporal_reduction is hardcoded, this should be equal to the temporal reduction factor of the vae 60 | if model_kwargs["num_frames"][0] == 1: 61 | num_frames = torch.ones_like(model_kwargs["num_frames"]) 62 | else: 63 | num_frames = model_kwargs["num_frames"] // 17 * 5 64 | ratio_time = (num_frames / base_num_frames).sqrt() 65 | 66 | ratio = ratio_space * ratio_time * scale 67 | new_t = ratio * t / (1 + (ratio - 1) * t) 68 | 69 | new_t = new_t * num_timesteps 70 | return new_t 71 | 72 | 73 | class RFlowScheduler: 74 | def __init__( 75 | self, 76 | num_timesteps=1000, 77 | num_sampling_steps=10, 78 | use_discrete_timesteps=False, 79 | sample_method="uniform", 80 | loc=0.0, 81 | scale=1.0, 82 | use_timestep_transform=False, 83 | transform_scale=1.0, 84 | ): 85 | self.num_timesteps = num_timesteps 86 | self.num_sampling_steps = num_sampling_steps 87 | self.use_discrete_timesteps = use_discrete_timesteps 88 | 89 | # sample method 90 | assert sample_method in ["uniform", "logit-normal"] 91 | assert ( 92 | sample_method == "uniform" or not use_discrete_timesteps 93 | ), "Only uniform sampling is supported for discrete timesteps" 94 | self.sample_method = sample_method 95 | if sample_method == "logit-normal": 96 | self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale])) 97 | self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device) 98 | 99 | # timestep transform 100 | self.use_timestep_transform = use_timestep_transform 101 | self.transform_scale = transform_scale 102 | 103 | def training_losses(self, model, x_start, 
model_kwargs=None, noise=None, mask=None, weights=None, t=None): 104 | """ 105 | Compute training losses for a single timestep. 106 | Arguments format copied from opensora/schedulers/iddpm/gaussian_diffusion.py/training_losses 107 | Note: t is int tensor and should be rescaled from [0, num_timesteps-1] to [1,0] 108 | """ 109 | if t is None: 110 | if self.use_discrete_timesteps: 111 | t = torch.randint(0, self.num_timesteps, (x_start.shape[0],), device=x_start.device) 112 | elif self.sample_method == "uniform": 113 | t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_timesteps 114 | elif self.sample_method == "logit-normal": 115 | t = self.sample_t(x_start) * self.num_timesteps 116 | 117 | if self.use_timestep_transform: 118 | t = timestep_transform(t, model_kwargs, scale=self.transform_scale, num_timesteps=self.num_timesteps) 119 | 120 | if model_kwargs is None: 121 | model_kwargs = {} 122 | if noise is None: 123 | noise = torch.randn_like(x_start) 124 | assert noise.shape == x_start.shape 125 | 126 | x_t = self.add_noise(x_start, noise, t) 127 | if mask is not None: 128 | t0 = torch.zeros_like(t) 129 | x_t0 = self.add_noise(x_start, noise, t0) 130 | x_t = torch.where(mask[:, None, :, None, None], x_t, x_t0) 131 | 132 | terms = {} 133 | model_output = model(x_t, t, **model_kwargs) 134 | velocity_pred = model_output.chunk(2, dim=1)[0] 135 | if weights is None: 136 | loss = mean_flat((velocity_pred - (x_start - noise)).pow(2), mask=mask) 137 | else: 138 | weight = _extract_into_tensor(weights, t, x_start.shape) 139 | loss = mean_flat(weight * (velocity_pred - (x_start - noise)).pow(2), mask=mask) 140 | terms["loss"] = loss 141 | 142 | return terms 143 | 144 | def add_noise( 145 | self, 146 | original_samples: torch.FloatTensor, 147 | noise: torch.FloatTensor, 148 | timesteps: torch.IntTensor, 149 | ) -> torch.FloatTensor: 150 | """ 151 | compatible with diffusers add_noise() 152 | """ 153 | timepoints = timesteps.float() / self.num_timesteps 154 | timepoints = 1 - timepoints # [1,1/1000] 155 | 156 | # timepoint (bsz) noise: (bsz, 4, frame, w ,h) 157 | # expand timepoint to noise shape 158 | timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1) 159 | timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4]) 160 | 161 | return timepoints * original_samples + (1 - timepoints) * noise 162 | 163 | 164 | class RFLOW: 165 | def __init__( 166 | self, 167 | num_sampling_steps=10, 168 | num_timesteps=1000, 169 | cfg_scale=4.0, 170 | use_discrete_timesteps=False, 171 | use_timestep_transform=False, 172 | **kwargs, 173 | ): 174 | self.num_sampling_steps = num_sampling_steps 175 | self.num_timesteps = num_timesteps 176 | self.cfg_scale = cfg_scale 177 | self.use_discrete_timesteps = use_discrete_timesteps 178 | self.use_timestep_transform = use_timestep_transform 179 | 180 | self.scheduler = RFlowScheduler( 181 | num_timesteps=num_timesteps, 182 | num_sampling_steps=num_sampling_steps, 183 | use_discrete_timesteps=use_discrete_timesteps, 184 | use_timestep_transform=use_timestep_transform, 185 | **kwargs, 186 | ) 187 | 188 | def sample( 189 | self, 190 | model, 191 | z, 192 | model_args, 193 | y_null, 194 | device, 195 | mask=None, 196 | guidance_scale=None, 197 | progress=True, 198 | verbose=False, 199 | ): 200 | # if no specific guidance scale is provided, use the default scale when initializing the scheduler 201 | if guidance_scale is None: 202 | guidance_scale = self.cfg_scale 203 | 204 | # text encoding 205 | 
model_args["y"] = torch.cat([model_args["y"], y_null], 0) 206 | 207 | # prepare timesteps 208 | timesteps = [(1.0 - i / self.num_sampling_steps) * self.num_timesteps for i in range(self.num_sampling_steps)] 209 | if self.use_discrete_timesteps: 210 | timesteps = [int(round(t)) for t in timesteps] 211 | timesteps = [torch.tensor([t] * z.shape[0], device=device) for t in timesteps] 212 | if self.use_timestep_transform: 213 | timesteps = [timestep_transform(t, model_args, num_timesteps=self.num_timesteps) for t in timesteps] 214 | 215 | if mask is not None: 216 | noise_added = torch.zeros_like(mask, dtype=torch.bool) 217 | noise_added = noise_added | (mask == 1) 218 | 219 | progress_wrap = tqdm if progress and dist.get_rank() == 0 else (lambda x: x) 220 | 221 | dtype = model.x_embedder.proj.weight.dtype 222 | all_timesteps = [int(t.to(dtype).item()) for t in timesteps] 223 | for i, t in progress_wrap(list(enumerate(timesteps))): 224 | # mask for adding noise 225 | if mask is not None: 226 | mask_t = mask * self.num_timesteps 227 | x0 = z.clone() 228 | x_noise = self.scheduler.add_noise(x0, torch.randn_like(x0), t) 229 | 230 | mask_t_upper = mask_t >= t.unsqueeze(1) 231 | model_args["x_mask"] = mask_t_upper.repeat(2, 1) 232 | mask_add_noise = mask_t_upper & ~noise_added 233 | 234 | z = torch.where(mask_add_noise[:, None, :, None, None], x_noise, x0) 235 | noise_added = mask_t_upper 236 | 237 | # classifier-free guidance 238 | z_in = torch.cat([z, z], 0) 239 | t = torch.cat([t, t], 0) 240 | 241 | # pred = model(z_in, t, **model_args).chunk(2, dim=1)[0] 242 | output = model(z_in, t, all_timesteps, **model_args) 243 | 244 | pred = output.chunk(2, dim=1)[0] 245 | pred_cond, pred_uncond = pred.chunk(2, dim=0) 246 | v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) 247 | 248 | # update z 249 | dt = timesteps[i] - timesteps[i + 1] if i < len(timesteps) - 1 else timesteps[i] 250 | dt = dt / self.num_timesteps 251 | z = z + v_pred * dt[:, None, None, None, None] 252 | 253 | if mask is not None: 254 | z = torch.where(mask_t_upper[:, None, :, None, None], z, x0) 255 | 256 | return z 257 | 258 | def training_losses(self, model, x_start, model_kwargs=None, noise=None, mask=None, weights=None, t=None): 259 | return self.scheduler.training_losses(model, x_start, model_kwargs, noise, mask, weights, t) 260 | -------------------------------------------------------------------------------- /videosys/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/utils/__init__.py -------------------------------------------------------------------------------- /videosys/utils/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.distributed as dist 4 | from rich.logging import RichHandler 5 | 6 | 7 | def create_logger(): 8 | """ 9 | Create a logger that writes to a log file and stdout. 10 | """ 11 | logger = logging.getLogger(__name__) 12 | return logger 13 | 14 | 15 | def init_dist_logger(): 16 | """ 17 | Update the logger to write to a log file. 
18 | """ 19 | global logger 20 | if dist.get_rank() == 0: 21 | logger = logging.getLogger(__name__) 22 | handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True) 23 | formatter = logging.Formatter("VideoSys - %(levelname)s: %(message)s") 24 | handler.setFormatter(formatter) 25 | logger.addHandler(handler) 26 | logger.setLevel(logging.INFO) 27 | else: # dummy logger (does nothing) 28 | logger = logging.getLogger(__name__) 29 | logger.addHandler(logging.NullHandler()) 30 | 31 | 32 | logger = create_logger() 33 | -------------------------------------------------------------------------------- /videosys/utils/test.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | 5 | 6 | def empty_cache(func): 7 | @functools.wraps(func) 8 | def wrapper(*args, **kwargs): 9 | torch.cuda.empty_cache() 10 | return func(*args, **kwargs) 11 | 12 | return wrapper 13 | -------------------------------------------------------------------------------- /videosys/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import imageio 5 | import numpy as np 6 | import torch 7 | import torch.distributed as dist 8 | from omegaconf import DictConfig, ListConfig, OmegaConf 9 | 10 | 11 | def requires_grad(model: torch.nn.Module, flag: bool = True) -> None: 12 | """ 13 | Set requires_grad flag for all parameters in a model. 14 | """ 15 | for p in model.parameters(): 16 | p.requires_grad = flag 17 | 18 | 19 | def set_seed(seed, dp_rank=None): 20 | if seed == -1: 21 | seed = random.randint(0, 1000000) 22 | 23 | if dp_rank is not None: 24 | seed = torch.tensor(seed, dtype=torch.int64).cuda() 25 | if dist.get_world_size() > 1: 26 | dist.broadcast(seed, 0) 27 | seed = seed + dp_rank 28 | 29 | seed = int(seed) 30 | random.seed(seed) 31 | os.environ["PYTHONHASHSEED"] = str(seed) 32 | np.random.seed(seed) 33 | torch.manual_seed(seed) 34 | torch.cuda.manual_seed(seed) 35 | 36 | 37 | def str_to_dtype(x: str): 38 | if x == "fp32": 39 | return torch.float32 40 | elif x == "fp16": 41 | return torch.float16 42 | elif x == "bf16": 43 | return torch.bfloat16 44 | else: 45 | raise RuntimeError(f"Only fp32, fp16 and bf16 are supported, but got {x}") 46 | 47 | 48 | def batch_func(func, *args): 49 | """ 50 | Apply a function to each element of a batch. 51 | """ 52 | batch = [] 53 | for arg in args: 54 | if isinstance(arg, torch.Tensor) and arg.shape[0] == 2: 55 | batch.append(func(arg)) 56 | else: 57 | batch.append(arg) 58 | 59 | return batch 60 | 61 | 62 | def merge_args(args1, args2): 63 | """ 64 | Merge two argparse Namespace objects. 65 | """ 66 | if args2 is None: 67 | return args1 68 | 69 | for k in args2._content.keys(): 70 | if k in args1.__dict__: 71 | v = getattr(args2, k) 72 | if isinstance(v, ListConfig) or isinstance(v, DictConfig): 73 | v = OmegaConf.to_object(v) 74 | setattr(args1, k, v) 75 | else: 76 | raise RuntimeError(f"Unknown argument {k}") 77 | 78 | return args1 79 | 80 | 81 | def all_exists(paths): 82 | return all(os.path.exists(path) for path in paths) 83 | 84 | 85 | def save_video(video, output_path, fps): 86 | """ 87 | Save a video to disk. 88 | """ 89 | if dist.is_initialized() and dist.get_rank() != 0: 90 | return 91 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 92 | imageio.mimwrite(output_path, video, fps=fps) 93 | --------------------------------------------------------------------------------
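
A minimal usage sketch of the pieces listed above, assuming the package is installed as videosys: RFlowScheduler.add_noise is a plain linear interpolation between the clean latent and Gaussian noise, so t = 0 returns the clean sample and t = num_timesteps returns pure noise, which is also why RFLOW.sample walks the timesteps from num_timesteps down toward zero. The snippet below is not a file from this repository; the latent shape, seed, and dtype string are arbitrary placeholders chosen only for illustration.

import torch

from videosys.schedulers.scheduling_rflow_open_sora import RFlowScheduler
from videosys.utils.utils import set_seed, str_to_dtype

set_seed(0)                                      # seeds python, numpy and torch (CPU and CUDA)
dtype = str_to_dtype("fp32")                     # "fp16" and "bf16" are also accepted

scheduler = RFlowScheduler(num_timesteps=1000, num_sampling_steps=10)
x0 = torch.randn(2, 4, 8, 16, 16, dtype=dtype)   # placeholder latent: (batch, channels, frames, height, width)
noise = torch.randn_like(x0)

t = torch.tensor([0, 1000])                      # one clean sample, one fully noised sample
x_t = scheduler.add_noise(x0, noise, t)          # x_t = (1 - t/num_timesteps) * x0 + (t/num_timesteps) * noise

assert torch.allclose(x_t[0], x0[0])             # t = 0             -> clean latent
assert torch.allclose(x_t[1], noise[1])          # t = num_timesteps -> pure noise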