├── .gitignore
├── .isort.cfg
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── TeaCache4CogVideoX1.5
│   ├── README.md
│   └── teacache_sample_video.py
├── TeaCache4ConsisID
│   ├── README.md
│   └── teacache_sample_video.py
├── TeaCache4Cosmos
│   ├── README.md
│   ├── teacache_sample_video_i2v.py
│   └── teacache_sample_video_t2v.py
├── TeaCache4FLUX
│   ├── README.md
│   └── teacache_flux.py
├── TeaCache4HiDream-I1
│   ├── README.md
│   └── teacache_hidream_i1.py
├── TeaCache4HunyuanVideo
│   ├── README.md
│   └── teacache_sample_video.py
├── TeaCache4LTX-Video
│   ├── README.md
│   └── teacache_ltx.py
├── TeaCache4Lumina-T2X
│   ├── README.md
│   └── teacache_lumina_next.py
├── TeaCache4Lumina2
│   ├── README.md
│   └── teacache_lumina2.py
├── TeaCache4Mochi
│   ├── README.md
│   └── teacache_mochi.py
├── TeaCache4TangoFlux
│   ├── README.md
│   └── teacache_tango_flux.py
├── TeaCache4Wan2.1
│   ├── README.md
│   └── teacache_generate.py
├── assets
│   ├── TeaCache4FLUX.png
│   ├── TeaCache4HiDream-I1.png
│   ├── TeaCache4LuminaT2X.png
│   └── tisser.png
├── eval
│   └── teacache
│       ├── README.md
│       ├── common_metrics
│       │   ├── README.md
│       │   ├── __init__.py
│       │   ├── batch_eval.py
│       │   ├── calculate_lpips.py
│       │   ├── calculate_psnr.py
│       │   ├── calculate_ssim.py
│       │   └── eval.py
│       ├── experiments
│       │   ├── __init__.py
│       │   ├── cogvideox.py
│       │   ├── latte.py
│       │   ├── opensora.py
│       │   ├── opensora_plan.py
│       │   └── utils.py
│       └── vbench
│           ├── VBench_full_info.json
│           ├── cal_vbench.py
│           └── run_vbench.py
├── requirements.txt
├── setup.py
└── videosys
    ├── __init__.py
    ├── core
    │   ├── __init__.py
    │   ├── comm.py
    │   ├── engine.py
    │   ├── mp_utils.py
    │   ├── pab_mgr.py
    │   ├── parallel_mgr.py
    │   ├── pipeline.py
    │   └── shardformer
    │       ├── __init__.py
    │       └── t5
    │           ├── __init__.py
    │           ├── modeling.py
    │           └── policy.py
    ├── models
    │   ├── __init__.py
    │   ├── autoencoders
    │   │   ├── __init__.py
    │   │   ├── autoencoder_kl_cogvideox.py
    │   │   ├── autoencoder_kl_open_sora.py
    │   │   ├── autoencoder_kl_open_sora_plan.py
    │   │   ├── autoencoder_kl_open_sora_plan_v110.py
    │   │   └── autoencoder_kl_open_sora_plan_v120.py
    │   ├── modules
    │   │   ├── __init__.py
    │   │   ├── activations.py
    │   │   ├── attentions.py
    │   │   ├── downsampling.py
    │   │   ├── embeddings.py
    │   │   ├── normalization.py
    │   │   └── upsampling.py
    │   └── transformers
    │       ├── __init__.py
    │       ├── cogvideox_transformer_3d.py
    │       ├── latte_transformer_3d.py
    │       ├── open_sora_plan_transformer_3d.py
    │       ├── open_sora_plan_v110_transformer_3d.py
    │       ├── open_sora_plan_v120_transformer_3d.py
    │       ├── open_sora_transformer_3d.py
    │       └── vchitect_transformer_3d.py
    ├── pipelines
    │   ├── __init__.py
    │   ├── cogvideox
    │   │   ├── __init__.py
    │   │   └── pipeline_cogvideox.py
    │   ├── latte
    │   │   ├── __init__.py
    │   │   └── pipeline_latte.py
    │   ├── open_sora
    │   │   ├── __init__.py
    │   │   ├── data_process.py
    │   │   └── pipeline_open_sora.py
    │   ├── open_sora_plan
    │   │   ├── __init__.py
    │   │   └── pipeline_open_sora_plan.py
    │   └── vchitect
    │       ├── __init__.py
    │       └── pipeline_vchitect.py
    ├── schedulers
    │   ├── __init__.py
    │   ├── scheduling_ddim_cogvideox.py
    │   ├── scheduling_dpm_cogvideox.py
    │   └── scheduling_rflow_open_sora.py
    └── utils
        ├── __init__.py
        ├── logging.py
        ├── test.py
        └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | outputs/
2 | processed/
3 | profile/
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | pip-wheel-metadata/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 | docs/.build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 |
122 | # Rope project settings
123 | .ropeproject
124 |
125 | # mkdocs documentation
126 | /site
127 |
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 |
133 | # Pyre type checker
134 | .pyre/
135 |
136 | # IDE
137 | .idea/
138 | .vscode/
139 |
140 | # macos
141 | *.DS_Store
142 | #data/
143 |
144 | docs/.build
145 |
146 | # pytorch checkpoint
147 | *.pt
148 |
149 | # ignore any kernel build files
150 | .o
151 | .so
152 |
153 | # ignore python interface definition files
154 | .pyi
155 |
156 | # ignore coverage test file
157 | coverage.lcov
158 | coverage.xml
159 |
160 | # ignore testmon and coverage files
161 | .coverage
162 | .testmondata*
163 |
164 | pretrained
165 | samples
166 | cache_dir
167 | test_outputs
168 | datasets
169 |
--------------------------------------------------------------------------------
/.isort.cfg:
--------------------------------------------------------------------------------
1 | [settings]
2 | line_length = 120
3 | multi_line_output=3
4 | include_trailing_comma = true
5 | ignore_comments = true
6 | profile = black
7 | honor_noqa = true
8 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |
3 | - repo: https://github.com/PyCQA/autoflake
4 | rev: v2.2.1
5 | hooks:
6 | - id: autoflake
7 | name: autoflake (python)
8 | args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']
9 |
10 | - repo: https://github.com/pycqa/isort
11 | rev: 5.12.0
12 | hooks:
13 | - id: isort
14 | name: sort all imports (python)
15 |
16 | - repo: https://github.com/psf/black-pre-commit-mirror
17 | rev: 23.9.1
18 | hooks:
19 | - id: black
20 | name: black formatter
21 | args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
22 |
23 | - repo: https://github.com/pre-commit/mirrors-clang-format
24 | rev: v13.0.1
25 | hooks:
26 | - id: clang-format
27 | name: clang formatter
28 | types_or: [c++, c]
29 |
30 | - repo: https://github.com/pre-commit/pre-commit-hooks
31 | rev: v4.3.0
32 | hooks:
33 | - id: check-yaml
34 | - id: check-merge-conflict
35 | - id: check-case-conflict
36 | - id: trailing-whitespace
37 | - id: end-of-file-fixer
38 | - id: mixed-line-ending
39 | args: ['--fix=lf']
40 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ## Coding Standards
2 |
3 | ### Unit Tests
4 | We use [PyTest](https://docs.pytest.org/en/latest/) to execute tests. You can install pytest by `pip install pytest`. As some of the tests require initialization of the distributed backend, GPUs are needed to execute these tests.
5 |
6 | To set up the environment for unit testing, first change your current directory to the root directory of your local clone of this repository, then run
7 | ```bash
8 | pip install -r requirements/requirements-test.txt
9 | ```
10 | If you encounter an error saying "Could not find a version that satisfies the requirement fbgemm-gpu==0.2.0", please downgrade your Python version to 3.8 or 3.9 and try again.
11 |
12 | If you only want to run CPU tests, you can run
13 |
14 | ```bash
15 | pytest -m cpu tests/
16 | ```
17 |
18 | If you have 8 GPUs on your machine, you can run the full test
19 |
20 | ```bash
21 | pytest tests/
22 | ```
23 |
24 | If you do not have 8 GPUs on your machine, do not worry. Unit testing will be automatically conducted when you put up a pull request to the main branch.
25 |
26 |
27 | ### Code Style
28 |
29 | We run static checks when you commit your code changes. Please make sure you pass all of them and that your coding style meets our requirements. We use pre-commit hooks to keep the code aligned with our writing standard. To set up the code style checking, follow the steps below.
30 |
31 | ```shell
32 | # run these commands from the root directory of this repository
33 | pip install pre-commit
34 | pre-commit install
35 | ```
36 |
37 | Code format checking will be automatically executed when you commit your changes.
38 |
--------------------------------------------------------------------------------
/TeaCache4CogVideoX1.5/README.md:
--------------------------------------------------------------------------------
1 |
2 | # TeaCache4CogVideoX1.5
3 |
4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [CogVideoX1.5](https://github.com/THUDM/CogVideo) 1.8x without much visual quality degradation, in a training-free manner. The following videos show the results generated by TeaCache-CogVideoX1.5 with various `rel_l1_thresh` values: 0 (original), 0.1 (1.3x speedup), 0.2 (1.8x speedup), and 0.3 (2.1x speedup). Image-to-video (i2v) results are also shown, with the following speedups: 0.1 (1.5x speedup), 0.2 (2.2x speedup), and 0.3 (2.7x speedup).
5 |
6 | https://github.com/user-attachments/assets/21261b03-71c6-47bf-9769-2a81c8dc452f
7 |
8 | https://github.com/user-attachments/assets/5e98e646-4034-4ae7-9680-a65ecd88dac9
9 |
10 | ## 📈 Inference Latency Comparisons on a Single H100 GPU
11 |
12 | | CogVideoX1.5-t2v | TeaCache (0.1) | TeaCache (0.2) | TeaCache (0.3) |
13 | | :--------------: | :------------: | :------------: | :------------: |
14 | | ~465 s | ~322 s | ~260 s | ~204 s |
15 |
16 | | CogVideoX1.5-i2v | TeaCache (0.1) | TeaCache (0.2) | TeaCache (0.3) |
17 | | :--------------: | :------------: | :------------: | :------------: |
18 | | ~475 s | ~316 s | ~239 s | ~204 s |
19 |
20 | ## Installation
21 |
22 | ```shell
23 | pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece imageio imageio-ffmpeg
24 | ```
25 |
26 | ## Usage
27 |
28 | You can modify `rel_l1_thresh` to obtain your desired trade-off between latency and visual quality, and change `ckpts_path`, `prompt`, and `image_path` to customize your video.
29 |
30 | For T2V inference, you can use the following command:
31 |
32 | ```bash
33 | cd TeaCache4CogVideoX1.5
34 |
35 | python3 teacache_sample_video.py \
36 | --rel_l1_thresh 0.2 \
37 | --ckpts_path THUDM/CogVideoX1.5-5B \
38 | --prompt "A clear, turquoise river flows through a rocky canyon, cascading over a small waterfall and forming a pool of water at the bottom. The river is the main focus of the scene, with its clear water reflecting the surrounding trees and rocks. The canyon walls are steep and rocky, with some vegetation growing on them. The trees are mostly pine trees, with their green needles contrasting with the brown and gray rocks. The overall tone of the scene is one of peace and tranquility." \
39 | --seed 42 \
40 | --num_inference_steps 50 \
41 | --output_path ./teacache_results
42 | ```
43 |
44 | For I2V inference, you can use the following command:
45 |
46 | ```bash
47 | cd TeaCache4CogVideoX1.5
48 |
49 | python3 teacache_sample_video.py \
50 | --rel_l1_thresh 0.1 \
51 | --ckpts_path THUDM/CogVideoX1.5-5B-I2V \
52 | --prompt "A girl gazed at the camera and smiled, her hair drifting in the wind." \
53 | --seed 42 \
54 | --num_inference_steps 50 \
55 | --output_path ./teacache_results \
56 | --image_path ./image/path
57 | ```
58 |
59 | ## Citation
60 |
61 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry.
62 |
63 | ```
64 | @article{liu2024timestep,
65 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
66 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
67 | journal={arXiv preprint arXiv:2411.19108},
68 | year={2024}
69 | }
70 | ```
71 |
72 |
73 | ## Acknowledgements
74 |
75 | We would like to thank the contributors to the [CogVideoX](https://github.com/THUDM/CogVideo) and [Diffusers](https://github.com/huggingface/diffusers) repositories.
76 |
--------------------------------------------------------------------------------
/TeaCache4ConsisID/README.md:
--------------------------------------------------------------------------------
1 |
2 | # TeaCache4ConsisID
3 |
4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) 2.1x without much visual quality degradation, in a training-free manner. The following video shows the results generated by TeaCache-ConsisID with various `rel_l1_thresh` values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup), and 0.2 (2.7x speedup).
5 |
6 | https://github.com/user-attachments/assets/501d71ef-0e71-4ae9-bceb-51cc18fa33d8
7 |
8 | ## 📈 Inference Latency Comparisons on a Single H100 GPU
9 |
10 | | ConsisID | TeaCache (0.1) | TeaCache (0.15) | TeaCache (0.20) |
11 | | :------: | :------------: | :-------------: | :-------------: |
12 | | ~110 s | ~70 s | ~53 s | ~41 s |
13 |
14 |
15 | ## Usage
16 |
17 | Follow [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) to clone the repo and complete the installation. You can then modify `rel_l1_thresh` to obtain your desired trade-off between latency and visual quality, and change `ckpts_path`, `prompt`, and `image` to customize your identity-preserving video.
18 |
19 | For single-gpu inference, you can use the following command:
20 |
21 | ```bash
22 | cd TeaCache4ConsisID
23 |
24 | python3 teacache_sample_video.py \
25 | --rel_l1_thresh 0.1 \
26 | --ckpts_path BestWishYsh/ConsisID-preview \
27 | --image "https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/2.png?raw=true" \
28 | --prompt "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy\'s path, adding depth to the scene. The lighting highlights the boy\'s subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel." \
29 | --seed 42 \
30 | --num_infer_steps 50 \
31 | --output_path ./teacache_results
32 | ```
33 |
34 | To generate a video with 8 GPUs, you can follow the instructions [here](https://github.com/PKU-YuanGroup/ConsisID/tree/main/tools).
35 |
36 | ## Resources
37 |
38 | Learn more about ConsisID with the following resources.
39 | - A [video](https://www.youtube.com/watch?v=PhlgC-bI5SQ) demonstrating ConsisID's main features.
40 | - The research paper, [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://hf.co/papers/2411.17440) for more details.
41 |
42 | ## Citation
43 |
44 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry.
45 |
46 | ```
47 | @article{liu2024timestep,
48 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
49 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
50 | journal={arXiv preprint arXiv:2411.19108},
51 | year={2024}
52 | }
53 | ```
54 |
55 |
56 | ## Acknowledgements
57 |
58 | We would like to thank the contributors to the [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) repository.
59 |
--------------------------------------------------------------------------------
/TeaCache4Cosmos/README.md:
--------------------------------------------------------------------------------
1 |
2 | # TeaCache4Cosmos
3 |
4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [Cosmos](https://github.com/NVIDIA/Cosmos) 2.0x without much visual quality degradation, in a training-free manner. The following videos show the results generated by TeaCache-Cosmos with various `rel_l1_thresh` values: 0 (original), 0.3 (1.4x speedup), and 0.4 (2.0x speedup).
5 |
6 | https://github.com/user-attachments/assets/28570179-0f22-42ee-8958-88bb48d209b4
7 |
8 | https://github.com/user-attachments/assets/21341bd4-c0d5-4b5a-8b7d-cb2103897d2c
9 |
10 | ## 📈 Inference Latency Comparisons on a Single H800 GPU
11 |
12 | | Cosmos-t2v | TeaCache (0.3) | TeaCache (0.4) |
13 | | :--------: | :------------: | :------------: |
14 | | ~449 s | ~327 s | ~227 s |
15 |
16 | | Cosmos-i2v | TeaCache (0.3) | TeaCache (0.4) |
17 | | :--------: | :------------: | :------------: |
18 | | ~453 s | ~330 s | ~229 s |
19 |
20 | ## Usage
21 |
22 | Follow [Cosmos](https://github.com/NVIDIA/Cosmos) to clone the repo and complete the installation. You can then modify `rel_l1_thresh` to obtain your desired trade-off between latency and visual quality, and change `checkpoint_dir`, `prompt`, and `input_image_or_video_path` to customize your video.
23 |
24 | You need to copy the running script to the Cosmos project folder:
25 |
26 | ```bash
27 | cd TeaCache4Cosmos
28 | cp *.py /path/to/Cosmos/
29 | ```
30 |
31 | For T2V inference, you can use the following command:
32 |
33 | ```bash
34 | cd /path/to/Cosmos/
35 |
36 | python3 teacache_sample_video_t2v.py \
37 | --rel_l1_thresh 0.3 \
38 | --checkpoint_dir checkpoints \
39 | --diffusion_transformer_dir Cosmos-1.0-Diffusion-7B-Text2World \
40 | --prompt "Inside the cozy ambiance of a bustling coffee house, a young woman with wavy chestnut hair and wearing a rust-colored cozy sweater stands amid the chatter and clinking of cups. She smiles warmly at the camera, her green eyes glinting with joy and subtle hints of laughter. The camera frames her elegantly, emphasizing the soft glow of the lighting on her smooth, clear skin and the detailed textures of her woolen attire. Her genuine smile is the centerpiece of the shot, showcasing her enjoyment in the quaint café setting, with steaming mugs and blurred patrons in the background." \
41 | --disable_prompt_upsampler \
42 | --offload_prompt_upsampler
43 | ```
44 |
45 | For I2V inference, you can use the following command:
46 |
47 | ```bash
48 | cd /path/to/Cosmos/
49 |
50 | python3 teacache_sample_video_i2v.py \
51 | --rel_l1_thresh 0.4 \
52 | --checkpoint_dir checkpoints \
53 | --diffusion_transformer_dir Cosmos-1.0-Diffusion-7B-Video2World \
54 | --prompt "A girl gazed at the camera and smiled, her hair drifting in the wind." \
55 | --input_image_or_video_path image/path \
56 | --num_input_frames 1 \
57 | --disable_prompt_upsampler \
58 | --offload_prompt_upsampler
59 | ```
60 |
61 | ## Citation
62 |
63 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry.
64 |
65 | ```
66 | @article{liu2024timestep,
67 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
68 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
69 | journal={arXiv preprint arXiv:2411.19108},
70 | year={2024}
71 | }
72 | ```
73 |
74 |
75 | ## Acknowledgements
76 |
77 | We would like to thank the contributors to the [Cosmos](https://github.com/NVIDIA/Cosmos) repository.
78 |
--------------------------------------------------------------------------------
/TeaCache4FLUX/README.md:
--------------------------------------------------------------------------------
1 |
2 | # TeaCache4FLUX
3 |
4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [FLUX](https://github.com/black-forest-labs/flux) 2x without much visual quality degradation, in a training-free manner. The following image shows the results generated by TeaCache-FLUX with various `rel_l1_thresh` values: 0 (original), 0.25 (1.5x speedup), 0.4 (1.8x speedup), 0.6 (2.0x speedup), and 0.8 (2.25x speedup).
5 |
6 | 
7 |
8 | ## 📈 Inference Latency Comparisons on a Single A800
9 |
10 |
11 | | FLUX.1 [dev] | TeaCache (0.25) | TeaCache (0.4) | TeaCache (0.6) | TeaCache (0.8) |
12 | |:-----------------------:|:----------------------------:|:--------------------:|:---------------------:|:---------------------:|
13 | | ~18 s | ~12 s | ~10 s | ~9 s | ~8 s |
14 |
15 | ## Installation
16 |
17 | ```shell
18 | pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece
19 | ```
20 |
21 | ## Usage
22 |
23 | You can modify `rel_l1_thresh` in line 320 of `teacache_flux.py` to obtain your desired trade-off between latency and visual quality. For single-gpu inference, you can use the following command:
24 |
25 | ```bash
26 | python teacache_flux.py
27 | ```
28 |
29 | ## Citation
30 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry.
31 |
32 | ```
33 | @article{liu2024timestep,
34 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
35 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
36 | journal={arXiv preprint arXiv:2411.19108},
37 | year={2024}
38 | }
39 | ```
40 |
41 | ## Acknowledgements
42 |
43 | We would like to thank the contributors to the [FLUX](https://github.com/black-forest-labs/flux) and [Diffusers](https://github.com/huggingface/diffusers) repositories.
--------------------------------------------------------------------------------
/TeaCache4HiDream-I1/README.md:
--------------------------------------------------------------------------------
1 |
2 | # TeaCache4HiDream-I1
3 |
4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [HiDream-I1](https://github.com/HiDream-ai/HiDream-I1) 2x without much visual quality degradation, in a training-free manner. The following image shows the results generated by TeaCache-HiDream-I1-Full with various `rel_l1_thresh` values: 0 (original), 0.17 (1.5x speedup), 0.25 (1.7x speedup), 0.3 (2.0x speedup), and 0.45 (2.6x speedup).
5 |
6 | 
7 |
8 | ## 📈 Inference Latency Comparisons on a Single A100
9 |
10 | | HiDream-I1-Full | TeaCache (0.17) | TeaCache (0.25) | TeaCache (0.3) | TeaCache (0.45) |
11 | |:-----------------------:|:----------------------------:|:--------------------:|:---------------------:|:--------------------:|
12 | | ~50 s | ~34 s | ~29 s | ~25 s | ~19 s |
13 |
14 | ## Installation
15 |
16 | ```shell
17 | pip install git+https://github.com/huggingface/diffusers
18 | pip install --upgrade transformers protobuf tiktoken tokenizers sentencepiece
19 | ```
20 |
21 | ## Usage
22 |
23 | You can modify `rel_l1_thresh` in line 297 of `teacache_hidream_i1.py` to obtain your desired trade-off between latency and visual quality. For single-gpu inference, you can use the following command:
24 |
25 | ```bash
26 | python teacache_hidream_i1.py
27 | ```
28 |
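`teacache_hidream_i1.py` uses the same class-attribute pattern as the other scripts in this repo. The sketch below only illustrates that pattern, assuming Diffusers' `HiDreamImagePipeline` and the Llama-3.1 text encoder it expects (per the Diffusers documentation); the patched forward and the exact pipeline construction live in the script, so treat the names and values here as assumptions and check `teacache_hidream_i1.py` for the authoritative setup.

```python
import torch
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
from diffusers import HiDreamImagePipeline

# Sketch only: teacache_hidream_i1.py patches the HiDream transformer's forward;
# without that patch, the TeaCache attributes set below have no effect.
num_inference_steps = 50

# HiDream-I1 expects the Llama-3.1-8B-Instruct text encoder to be passed in
# explicitly (this follows the Diffusers documentation for the pipeline).
tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
).to("cuda")

# TeaCache knobs (same pattern as the other TeaCache scripts in this repo)
pipe.transformer.__class__.enable_teacache = True
pipe.transformer.__class__.cnt = 0
pipe.transformer.__class__.num_steps = num_inference_steps
pipe.transformer.__class__.rel_l1_thresh = 0.3  # see the latency table above for the speed/quality trade-off
pipe.transformer.__class__.accumulated_rel_l1_distance = 0
pipe.transformer.__class__.previous_modulated_input = None
pipe.transformer.__class__.previous_residual = None

image = pipe(
    "A cat holding a sign that says TeaCache.",
    num_inference_steps=num_inference_steps,
    generator=torch.Generator("cuda").manual_seed(42),
).images[0]
image.save("teacache_hidream_i1.png")
```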
29 | ## Citation
30 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry.
31 |
32 | ```
33 | @article{liu2024timestep,
34 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
35 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
36 | journal={arXiv preprint arXiv:2411.19108},
37 | year={2024}
38 | }
39 | ```
40 |
41 | ## Acknowledgements
42 |
43 | We would like to thank the contributors to the [HiDream-I1](https://github.com/HiDream-ai/HiDream-I1) and [Diffusers](https://github.com/huggingface/diffusers) repositories.
--------------------------------------------------------------------------------
/TeaCache4HunyuanVideo/README.md:
--------------------------------------------------------------------------------
1 |
2 | # TeaCache4HunyuanVideo
3 |
4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) 2x without much visual quality degradation, in a training-free manner. The following video shows the results generated by TeaCache-HunyuanVideo with various `rel_l1_thresh` values: 0 (original), 0.1 (1.6x speedup), and 0.15 (2.1x speedup).
5 |
6 | https://github.com/user-attachments/assets/34b5dab0-5b0f-48a0-968d-88af18b84803
7 |
8 |
9 | ## 📈 Inference Latency Comparisons on a Single A800 GPU
10 |
11 |
12 | | Resolution | HunyuanVideo | TeaCache (0.1) | TeaCache (0.15) |
13 | |:---------------------:|:-------------------------:|:--------------------:|:----------------------:|
14 | | 540p | ~18 min | ~11 min | ~8 min |
15 | | 720p | ~50 min | ~30 min | ~23 min |
16 |
17 |
18 | ## Usage
19 |
20 | Follow [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) to clone the repo and complete the installation, then copy `teacache_sample_video.py` from this repo into the HunyuanVideo repo. You can modify `rel_l1_thresh` in line 222 to obtain your desired trade-off between latency and visual quality.
21 |
22 | For single-gpu inference, you can use the following command:
23 |
24 | ```bash
25 | cd HunyuanVideo
26 |
27 | python3 teacache_sample_video.py \
28 | --video-size 720 1280 \
29 | --video-length 129 \
30 | --infer-steps 50 \
31 | --prompt "A cat walks on the grass, realistic style." \
32 | --flow-reverse \
33 | --use-cpu-offload \
34 | --save-path ./teacache_results
35 | ```
36 |
37 | To generate a video with 8 GPUs, you can use the following command:
38 |
39 | ```bash
40 | cd HunyuanVideo
41 |
42 | torchrun --nproc_per_node=8 teacache_sample_video.py \
43 | --video-size 1280 720 \
44 | --video-length 129 \
45 | --infer-steps 50 \
46 | --prompt "A cat walks on the grass, realistic style." \
47 | --flow-reverse \
48 | --seed 42 \
49 | --ulysses-degree 8 \
50 | --ring-degree 1 \
51 | --save-path ./teacache_results
52 | ```
53 |
54 | For FP8 inference, you must explicitly specify the FP8 weight path. For example, to generate a video with FP8 weights, you can use the following command:
55 |
56 | ```bash
57 | cd HunyuanVideo
58 |
59 | DIT_CKPT_PATH={PATH_TO_FP8_WEIGHTS}/{WEIGHT_NAME}_fp8.pt
60 |
61 | python3 teacache_sample_video.py \
62 | --dit-weight ${DIT_CKPT_PATH} \
63 | --video-size 1280 720 \
64 | --video-length 129 \
65 | --infer-steps 50 \
66 | --prompt "A cat walks on the grass, realistic style." \
67 | --seed 42 \
68 | --embedded-cfg-scale 6.0 \
69 | --flow-shift 7.0 \
70 | --flow-reverse \
71 | --use-cpu-offload \
72 | --use-fp8 \
73 | --save-path ./teacache_fp8_results
74 | ```
75 |
76 | ## Citation
77 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry.
78 |
79 | ```
80 | @article{liu2024timestep,
81 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
82 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
83 | journal={arXiv preprint arXiv:2411.19108},
84 | year={2024}
85 | }
86 | ```
87 |
88 |
89 | ## Acknowledgements
90 |
91 | We would like to thank the contributors to the [HunyuanVideo](https://github.com/Tencent/HunyuanVideo) repository.
92 |
--------------------------------------------------------------------------------
/TeaCache4HunyuanVideo/teacache_sample_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from pathlib import Path
4 | from loguru import logger
5 | from datetime import datetime
6 |
7 | from hyvideo.utils.file_utils import save_videos_grid
8 | from hyvideo.config import parse_args
9 | from hyvideo.inference import HunyuanVideoSampler
10 |
11 | from hyvideo.modules.modulate_layers import modulate
12 | from hyvideo.modules.attenion import attention, parallel_attention, get_cu_seqlens
13 | from typing import Any, List, Tuple, Optional, Union, Dict
14 | import torch
15 | import json
16 | import numpy as np
17 |
18 |
19 |
20 |
21 |
22 | def teacache_forward(
23 | self,
24 | x: torch.Tensor,
25 | t: torch.Tensor, # Should be in range(0, 1000).
26 | text_states: torch.Tensor = None,
27 | text_mask: torch.Tensor = None, # Now we don't use it.
28 | text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
29 | freqs_cos: Optional[torch.Tensor] = None,
30 | freqs_sin: Optional[torch.Tensor] = None,
31 | guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
32 | return_dict: bool = True,
33 | ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
34 | out = {}
35 | img = x
36 | txt = text_states
37 | _, _, ot, oh, ow = x.shape
38 | tt, th, tw = (
39 | ot // self.patch_size[0],
40 | oh // self.patch_size[1],
41 | ow // self.patch_size[2],
42 | )
43 |
44 | # Prepare modulation vectors.
45 | vec = self.time_in(t)
46 |
47 | # text modulation
48 | vec = vec + self.vector_in(text_states_2)
49 |
50 | # guidance modulation
51 | if self.guidance_embed:
52 | if guidance is None:
53 | raise ValueError(
54 | "Didn't get guidance strength for guidance distilled model."
55 | )
56 |
57 | # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
58 | vec = vec + self.guidance_in(guidance)
59 |
60 | # Embed image and text.
61 | img = self.img_in(img)
62 | if self.text_projection == "linear":
63 | txt = self.txt_in(txt)
64 | elif self.text_projection == "single_refiner":
65 | txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
66 | else:
67 | raise NotImplementedError(
68 | f"Unsupported text_projection: {self.text_projection}"
69 | )
70 |
71 | txt_seq_len = txt.shape[1]
72 | img_seq_len = img.shape[1]
73 |
74 | # Compute cu_squlens and max_seqlen for flash attention
75 | cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
76 | cu_seqlens_kv = cu_seqlens_q
77 | max_seqlen_q = img_seq_len + txt_seq_len
78 | max_seqlen_kv = max_seqlen_q
79 |
80 | freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
81 |
82 | if self.enable_teacache:
83 | inp = img.clone()
84 | vec_ = vec.clone()
85 | txt_ = txt.clone()
86 | (
87 | img_mod1_shift,
88 | img_mod1_scale,
89 | img_mod1_gate,
90 | img_mod2_shift,
91 | img_mod2_scale,
92 | img_mod2_gate,
93 | ) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1)
94 | normed_inp = self.double_blocks[0].img_norm1(inp)
95 | modulated_inp = modulate(
96 | normed_inp, shift=img_mod1_shift, scale=img_mod1_scale
97 | )
98 | if self.cnt == 0 or self.cnt == self.num_steps-1:
99 | should_calc = True
100 | self.accumulated_rel_l1_distance = 0
101 | else:
102 | coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
103 | rescale_func = np.poly1d(coefficients)
104 | self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
105 | if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
106 | should_calc = False
107 | else:
108 | should_calc = True
109 | self.accumulated_rel_l1_distance = 0
110 | self.previous_modulated_input = modulated_inp
111 | self.cnt += 1
112 | if self.cnt == self.num_steps:
113 | self.cnt = 0
114 |
115 | if self.enable_teacache:
116 | if not should_calc:
117 | img += self.previous_residual
118 | else:
119 | ori_img = img.clone()
120 | # --------------------- Pass through DiT blocks ------------------------
121 | for _, block in enumerate(self.double_blocks):
122 | double_block_args = [
123 | img,
124 | txt,
125 | vec,
126 | cu_seqlens_q,
127 | cu_seqlens_kv,
128 | max_seqlen_q,
129 | max_seqlen_kv,
130 | freqs_cis,
131 | ]
132 |
133 | img, txt = block(*double_block_args)
134 |
135 | # Merge txt and img to pass through single stream blocks.
136 | x = torch.cat((img, txt), 1)
137 | if len(self.single_blocks) > 0:
138 | for _, block in enumerate(self.single_blocks):
139 | single_block_args = [
140 | x,
141 | vec,
142 | txt_seq_len,
143 | cu_seqlens_q,
144 | cu_seqlens_kv,
145 | max_seqlen_q,
146 | max_seqlen_kv,
147 | (freqs_cos, freqs_sin),
148 | ]
149 |
150 | x = block(*single_block_args)
151 |
152 | img = x[:, :img_seq_len, ...]
153 | self.previous_residual = img - ori_img
154 | else:
155 | # --------------------- Pass through DiT blocks ------------------------
156 | for _, block in enumerate(self.double_blocks):
157 | double_block_args = [
158 | img,
159 | txt,
160 | vec,
161 | cu_seqlens_q,
162 | cu_seqlens_kv,
163 | max_seqlen_q,
164 | max_seqlen_kv,
165 | freqs_cis,
166 | ]
167 |
168 | img, txt = block(*double_block_args)
169 |
170 | # Merge txt and img to pass through single stream blocks.
171 | x = torch.cat((img, txt), 1)
172 | if len(self.single_blocks) > 0:
173 | for _, block in enumerate(self.single_blocks):
174 | single_block_args = [
175 | x,
176 | vec,
177 | txt_seq_len,
178 | cu_seqlens_q,
179 | cu_seqlens_kv,
180 | max_seqlen_q,
181 | max_seqlen_kv,
182 | (freqs_cos, freqs_sin),
183 | ]
184 |
185 | x = block(*single_block_args)
186 |
187 | img = x[:, :img_seq_len, ...]
188 |
189 | # ---------------------------- Final layer ------------------------------
190 | img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
191 |
192 | img = self.unpatchify(img, tt, th, tw)
193 | if return_dict:
194 | out["x"] = img
195 | return out
196 | return img
197 |
198 |
199 | def main():
200 | args = parse_args()
201 | print(args)
202 | models_root_path = Path(args.model_base)
203 | if not models_root_path.exists():
204 | raise ValueError(f"`models_root` not exists: {models_root_path}")
205 |
206 | # Create save folder to save the samples
207 | save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}'
208 |     if not os.path.exists(save_path):
209 | os.makedirs(save_path, exist_ok=True)
210 |
211 | # Load models
212 | hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
213 |
214 | # Get the updated args
215 | args = hunyuan_video_sampler.args
216 |
217 |
218 | # TeaCache
219 | hunyuan_video_sampler.pipeline.transformer.__class__.enable_teacache = True
220 | hunyuan_video_sampler.pipeline.transformer.__class__.cnt = 0
221 | hunyuan_video_sampler.pipeline.transformer.__class__.num_steps = args.infer_steps
222 | hunyuan_video_sampler.pipeline.transformer.__class__.rel_l1_thresh = 0.15 # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
223 | hunyuan_video_sampler.pipeline.transformer.__class__.accumulated_rel_l1_distance = 0
224 | hunyuan_video_sampler.pipeline.transformer.__class__.previous_modulated_input = None
225 | hunyuan_video_sampler.pipeline.transformer.__class__.previous_residual = None
226 | hunyuan_video_sampler.pipeline.transformer.__class__.forward = teacache_forward
227 |
228 | # Start sampling
229 | # TODO: batch inference check
230 | outputs = hunyuan_video_sampler.predict(
231 | prompt=args.prompt,
232 | height=args.video_size[0],
233 | width=args.video_size[1],
234 | video_length=args.video_length,
235 | seed=args.seed,
236 | negative_prompt=args.neg_prompt,
237 | infer_steps=args.infer_steps,
238 | guidance_scale=args.cfg_scale,
239 | num_videos_per_prompt=args.num_videos,
240 | flow_shift=args.flow_shift,
241 | batch_size=args.batch_size,
242 | embedded_guidance_scale=args.embedded_cfg_scale
243 | )
244 | samples = outputs['samples']
245 |
246 | # Save samples
247 | if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0:
248 | for i, sample in enumerate(samples):
249 | sample = samples[i].unsqueeze(0)
250 | time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
251 |             cur_save_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4"
252 |             save_videos_grid(sample, cur_save_path, fps=24)
253 |             logger.info(f'Sample save to: {cur_save_path}')
254 |
255 | if __name__ == "__main__":
256 | main()
257 |
--------------------------------------------------------------------------------
/TeaCache4LTX-Video/README.md:
--------------------------------------------------------------------------------
1 |
2 | # TeaCache4LTX-Video
3 |
4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [LTX-Video](https://github.com/Lightricks/LTX-Video) 2x without much visual quality degradation, in a training-free manner. The following video shows the results generated by TeaCache-LTX-Video with various `rel_l1_thresh` values: 0 (original), 0.03 (1.6x speedup), and 0.05 (2.1x speedup).
5 |
6 | https://github.com/user-attachments/assets/1f4cf26c-b8c6-45e3-b402-840bcd6ba00e
7 |
8 | ## 📈 Inference Latency Comparisons on a Single A800
9 |
10 |
11 | | LTX-Video-0.9.1 | TeaCache (0.03) | TeaCache (0.05) |
12 | |:--------------------------:|:----------------------------:|:---------------------:|
13 | | ~32 s | ~20 s | ~16 s |
14 |
15 | ## Installation
16 |
17 | ```shell
18 | pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece imageio
19 | ```
20 |
21 | ## Usage
22 |
23 | You can modify `rel_l1_thresh` in line 187 of `teacache_ltx.py` to obtain your desired trade-off between latency and visual quality. For single-gpu inference, you can use the following command:
24 |
25 | ```bash
26 | python teacache_ltx.py
27 | ```
28 |
29 | ## Citation
30 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry.
31 |
32 | ```
33 | @article{liu2024timestep,
34 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
35 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
36 | journal={arXiv preprint arXiv:2411.19108},
37 | year={2024}
38 | }
39 | ```
40 |
41 | ## Acknowledgements
42 |
43 | We would like to thank the contributors to the [LTX-Video](https://github.com/Lightricks/LTX-Video) and [Diffusers](https://github.com/huggingface/diffusers) repositories.
--------------------------------------------------------------------------------
/TeaCache4LTX-Video/teacache_ltx.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers import LTXPipeline
3 | from diffusers.models.transformers import LTXVideoTransformer3DModel
4 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
5 | from diffusers.utils import export_to_video
6 | from typing import Any, Dict, Optional, Tuple
7 | import numpy as np
8 | from diffusers.models.modeling_outputs import Transformer2DModelOutput  # needed for the return value below
9 | logger = logging.get_logger(__name__)  # used by the warning in the non-PEFT branch
10 | def teacache_forward(
11 | self,
12 | hidden_states: torch.Tensor,
13 | encoder_hidden_states: torch.Tensor,
14 | timestep: torch.LongTensor,
15 | encoder_attention_mask: torch.Tensor,
16 | num_frames: int,
17 | height: int,
18 | width: int,
19 | rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
20 | attention_kwargs: Optional[Dict[str, Any]] = None,
21 | return_dict: bool = True,
22 | ) -> torch.Tensor:
23 | if attention_kwargs is not None:
24 | attention_kwargs = attention_kwargs.copy()
25 | lora_scale = attention_kwargs.pop("scale", 1.0)
26 | else:
27 | lora_scale = 1.0
28 |
29 | if USE_PEFT_BACKEND:
30 | # weight the lora layers by setting `lora_scale` for each PEFT layer
31 | scale_lora_layers(self, lora_scale)
32 | else:
33 | if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
34 | logger.warning(
35 | "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
36 | )
37 |
38 | image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
39 |
40 | # convert encoder_attention_mask to a bias the same way we do for attention_mask
41 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
42 | encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
43 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
44 |
45 | batch_size = hidden_states.size(0)
46 | hidden_states = self.proj_in(hidden_states)
47 |
48 | temb, embedded_timestep = self.time_embed(
49 | timestep.flatten(),
50 | batch_size=batch_size,
51 | hidden_dtype=hidden_states.dtype,
52 | )
53 |
54 | temb = temb.view(batch_size, -1, temb.size(-1))
55 | embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
56 |
57 | encoder_hidden_states = self.caption_projection(encoder_hidden_states)
58 | encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
59 |
60 | if self.enable_teacache:
61 | inp = hidden_states.clone()
62 | temb_ = temb.clone()
63 | inp = self.transformer_blocks[0].norm1(inp)
64 | num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0]
65 | ada_values = self.transformer_blocks[0].scale_shift_table[None, None] + temb_.reshape(batch_size, temb_.size(1), num_ada_params, -1)
66 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
67 | modulated_inp = inp * (1 + scale_msa) + shift_msa
68 | if self.cnt == 0 or self.cnt == self.num_steps-1:
69 | should_calc = True
70 | self.accumulated_rel_l1_distance = 0
71 | else:
72 | coefficients = [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03]
73 | rescale_func = np.poly1d(coefficients)
74 | self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
75 | if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
76 | should_calc = False
77 | else:
78 | should_calc = True
79 | self.accumulated_rel_l1_distance = 0
80 | self.previous_modulated_input = modulated_inp
81 | self.cnt += 1
82 | if self.cnt == self.num_steps:
83 | self.cnt = 0
84 |
85 | if self.enable_teacache:
86 | if not should_calc:
87 | hidden_states += self.previous_residual
88 | else:
89 | ori_hidden_states = hidden_states.clone()
90 | for block in self.transformer_blocks:
91 | if torch.is_grad_enabled() and self.gradient_checkpointing:
92 |
93 | def create_custom_forward(module, return_dict=None):
94 | def custom_forward(*inputs):
95 | if return_dict is not None:
96 | return module(*inputs, return_dict=return_dict)
97 | else:
98 | return module(*inputs)
99 |
100 | return custom_forward
101 |
102 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
103 | hidden_states = torch.utils.checkpoint.checkpoint(
104 | create_custom_forward(block),
105 | hidden_states,
106 | encoder_hidden_states,
107 | temb,
108 | image_rotary_emb,
109 | encoder_attention_mask,
110 | **ckpt_kwargs,
111 | )
112 | else:
113 | hidden_states = block(
114 | hidden_states=hidden_states,
115 | encoder_hidden_states=encoder_hidden_states,
116 | temb=temb,
117 | image_rotary_emb=image_rotary_emb,
118 | encoder_attention_mask=encoder_attention_mask,
119 | )
120 |
121 | scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
122 | shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
123 |
124 | hidden_states = self.norm_out(hidden_states)
125 | hidden_states = hidden_states * (1 + scale) + shift
126 | self.previous_residual = hidden_states - ori_hidden_states
127 | else:
128 | for block in self.transformer_blocks:
129 | if torch.is_grad_enabled() and self.gradient_checkpointing:
130 |
131 | def create_custom_forward(module, return_dict=None):
132 | def custom_forward(*inputs):
133 | if return_dict is not None:
134 | return module(*inputs, return_dict=return_dict)
135 | else:
136 | return module(*inputs)
137 |
138 | return custom_forward
139 |
140 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
141 | hidden_states = torch.utils.checkpoint.checkpoint(
142 | create_custom_forward(block),
143 | hidden_states,
144 | encoder_hidden_states,
145 | temb,
146 | image_rotary_emb,
147 | encoder_attention_mask,
148 | **ckpt_kwargs,
149 | )
150 | else:
151 | hidden_states = block(
152 | hidden_states=hidden_states,
153 | encoder_hidden_states=encoder_hidden_states,
154 | temb=temb,
155 | image_rotary_emb=image_rotary_emb,
156 | encoder_attention_mask=encoder_attention_mask,
157 | )
158 |
159 | scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
160 | shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
161 |
162 | hidden_states = self.norm_out(hidden_states)
163 | hidden_states = hidden_states * (1 + scale) + shift
164 |
165 |
166 | output = self.proj_out(hidden_states)
167 |
168 | if USE_PEFT_BACKEND:
169 | # remove `lora_scale` from each PEFT layer
170 | unscale_lora_layers(self, lora_scale)
171 |
172 | if not return_dict:
173 | return (output,)
174 | return Transformer2DModelOutput(sample=output)
175 |
176 | LTXVideoTransformer3DModel.forward = teacache_forward
177 | prompt = "A clear, turquoise river flows through a rocky canyon, cascading over a small waterfall and forming a pool of water at the bottom.The river is the main focus of the scene, with its clear water reflecting the surrounding trees and rocks. The canyon walls are steep and rocky, with some vegetation growing on them. The trees are mostly pine trees, with their green needles contrasting with the brown and gray rocks. The overall tone of the scene is one of peace and tranquility."
178 | negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
179 | seed = 42
180 | num_inference_steps = 50
181 | pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
182 |
183 | # TeaCache
184 | pipe.transformer.__class__.enable_teacache = True
185 | pipe.transformer.__class__.cnt = 0
186 | pipe.transformer.__class__.num_steps = num_inference_steps
187 | pipe.transformer.__class__.rel_l1_thresh = 0.05 # 0.03 for 1.6x speedup, 0.05 for 2.1x speedup
188 | pipe.transformer.__class__.accumulated_rel_l1_distance = 0
189 | pipe.transformer.__class__.previous_modulated_input = None
190 | pipe.transformer.__class__.previous_residual = None
191 |
192 | pipe.to("cuda")
193 | video = pipe(
194 | prompt=prompt,
195 | negative_prompt=negative_prompt,
196 | width=768,
197 | height=512,
198 | num_frames=161,
199 | decode_timestep=0.03,
200 | decode_noise_scale=0.025,
201 | num_inference_steps=num_inference_steps,
202 | generator=torch.Generator("cuda").manual_seed(seed)
203 | ).frames[0]
204 | export_to_video(video, "teacache_ltx_{}.mp4".format(prompt[:50]), fps=24)
--------------------------------------------------------------------------------
/TeaCache4Lumina-T2X/README.md:
--------------------------------------------------------------------------------
1 |
2 | # TeaCache4LuminaT2X
3 |
4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X) 2x without much visual quality degradation, in a training-free manner. The following image shows the results generated by TeaCache-Lumina-Next with various `rel_l1_thresh` values: 0 (original), 0.2 (1.5x speedup), 0.3 (1.9x speedup), 0.4 (2.4x speedup), and 0.5 (2.8x speedup).
5 |
6 | 
7 |
8 | ## 📈 Inference Latency Comparisons on a Single A800
9 |
10 |
11 | | Lumina-Next-SFT | TeaCache (0.2) | TeaCache (0.3) | TeaCache (0.4) | TeaCache (0.5) |
12 | |:-------------------------:|:---------------------------:|:--------------------:|:---------------------:|:---------------------:|
13 | | ~17 s | ~11 s | ~9 s | ~7 s | ~6 s |
14 |
15 | ## Installation
16 |
17 | ```shell
18 | pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece
19 | pip install flash-attn --no-build-isolation
20 | ```
21 |
22 | ## Usage
23 |
24 | You can modify `rel_l1_thresh` in line 113 of `teacache_lumina_next.py` to obtain your desired trade-off between latency and visual quality. For single-gpu inference, you can use the following command:
25 |
26 | ```bash
27 | python teacache_lumina_next.py
28 | ```
29 |
30 | ## Citation
31 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry.
32 |
33 | ```
34 | @article{liu2024timestep,
35 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
36 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
37 | journal={arXiv preprint arXiv:2411.19108},
38 | year={2024}
39 | }
40 | ```
41 |
42 | ## Acknowledgements
43 |
44 | We would like to thank the contributors to the [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X) and [Diffusers](https://github.com/huggingface/diffusers) repositories.
--------------------------------------------------------------------------------
/TeaCache4Lumina-T2X/teacache_lumina_next.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Any, Dict, Optional, Tuple, Union
3 | from diffusers import LuminaText2ImgPipeline
4 | from diffusers.models import LuminaNextDiT2DModel
5 | from diffusers.models.modeling_outputs import Transformer2DModelOutput
6 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
7 | import numpy as np
8 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
9 |
10 | def teacache_forward(
11 | self,
12 | hidden_states: torch.Tensor,
13 | timestep: torch.Tensor,
14 | encoder_hidden_states: torch.Tensor,
15 | encoder_mask: torch.Tensor,
16 | image_rotary_emb: torch.Tensor,
17 | cross_attention_kwargs: Dict[str, Any] = None,
18 | return_dict=True,
19 | ) -> torch.Tensor:
20 | """
21 | Forward pass of LuminaNextDiT.
22 |
23 | Parameters:
24 | hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
25 | timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,).
26 | encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D).
27 | encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L).
28 | """
29 | hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)
30 | image_rotary_emb = image_rotary_emb.to(hidden_states.device)
31 |
32 | temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask)
33 |
34 | encoder_mask = encoder_mask.bool()
35 | if self.enable_teacache:
36 | inp = hidden_states.clone()
37 | temb_ = temb.clone()
38 | modulated_inp, gate_msa, scale_mlp, gate_mlp = self.layers[0].norm1(inp, temb_)
39 | if self.cnt == 0 or self.cnt == self.num_steps-1:
40 | should_calc = True
41 | self.accumulated_rel_l1_distance = 0
42 | else:
43 | coefficients = [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344]
44 | rescale_func = np.poly1d(coefficients)
45 | self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
46 | if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
47 | should_calc = False
48 | else:
49 | should_calc = True
50 | self.accumulated_rel_l1_distance = 0
51 | self.previous_modulated_input = modulated_inp
52 | self.cnt += 1
53 | if self.cnt == self.num_steps:
54 | self.cnt = 0
55 |
56 | if self.enable_teacache:
57 | if not should_calc:
58 | hidden_states += self.previous_residual
59 | else:
60 | ori_hidden_states = hidden_states.clone()
61 | for layer in self.layers:
62 | hidden_states = layer(
63 | hidden_states,
64 | mask,
65 | image_rotary_emb,
66 | encoder_hidden_states,
67 | encoder_mask,
68 | temb=temb,
69 | cross_attention_kwargs=cross_attention_kwargs,
70 | )
71 | self.previous_residual = hidden_states - ori_hidden_states
72 |
73 | else:
74 | for layer in self.layers:
75 | hidden_states = layer(
76 | hidden_states,
77 | mask,
78 | image_rotary_emb,
79 | encoder_hidden_states,
80 | encoder_mask,
81 | temb=temb,
82 | cross_attention_kwargs=cross_attention_kwargs,
83 | )
84 |
85 | hidden_states = self.norm_out(hidden_states, temb)
86 |
87 | # unpatchify
88 | height_tokens = width_tokens = self.patch_size
89 | height, width = img_size[0]
90 | batch_size = hidden_states.size(0)
91 | sequence_length = (height // height_tokens) * (width // width_tokens)
92 | hidden_states = hidden_states[:, :sequence_length].view(
93 | batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels
94 | )
95 | output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
96 |
97 | if not return_dict:
98 | return (output,)
99 |
100 | return Transformer2DModelOutput(sample=output)
101 |
102 |
103 | LuminaNextDiT2DModel.forward = teacache_forward
104 | num_inference_steps = 30
105 | seed = 1024
106 | prompt = "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. "
107 | pipeline = LuminaText2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16).to("cuda")
108 |
109 | # TeaCache
110 | pipeline.transformer.__class__.enable_teacache = True
111 | pipeline.transformer.__class__.cnt = 0
112 | pipeline.transformer.__class__.num_steps = num_inference_steps
113 | pipeline.transformer.__class__.rel_l1_thresh = 0.3 # 0.2 for 1.5x speedup, 0.3 for 1.9x speedup, 0.4 for 2.4x speedup, 0.5 for 2.8x speedup
114 | pipeline.transformer.__class__.accumulated_rel_l1_distance = 0
115 | pipeline.transformer.__class__.previous_modulated_input = None
116 | pipeline.transformer.__class__.previous_residual = None
117 |
118 | image = pipeline(
119 | prompt=prompt,
120 | num_inference_steps=num_inference_steps,
121 | generator=torch.Generator("cpu").manual_seed(seed)
122 | ).images[0]
123 | image.save("teacache_lumina_{}.png".format(prompt))
--------------------------------------------------------------------------------
/TeaCache4Lumina2/README.md:
--------------------------------------------------------------------------------
1 |
2 | # TeaCache4Lumina2
3 |
4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [Lumina-Image-2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0) without much visual quality degradation, in a training-free manner. The following image shows the results generated by TeaCache-Lumina-Image-2.0 with various `rel_l1_thresh` values: 0 (original), 0.2 (1.25x speedup), 0.3 (1.5625x speedup), 0.4 (2.0833x speedup), and 0.5 (2.5x speedup).
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | ## 📈 Inference Latency Comparisons on a Single 4090 GPU (50 steps)
15 |
16 |
17 | | Lumina-Image-2.0 | TeaCache (0.2) | TeaCache (0.3) | TeaCache (0.4) | TeaCache (0.5) |
18 | |:-------------------------:|:---------------------------:|:--------------------:|:---------------------:|:---------------------:|
19 | | ~25 s | ~20 s | ~16 s | ~12 s | ~10 s |
20 |
21 | ## Installation
22 |
23 | ```shell
24 | pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece
25 | pip install flash-attn --no-build-isolation
26 | ```
27 |
28 | ## Usage
29 |
30 | You can modify `rel_l1_thresh` in line 154 of `teacache_lumina2.py` to obtain your desired trade-off between latency and visual quality. For single-gpu inference, you can use the following command:
31 |
32 | ```bash
33 | python teacache_lumina2.py
34 | ```
35 |
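`teacache_lumina2.py` follows the same overall pattern as the other scripts here, but keeps its per-sequence-length TeaCache state in a `cache` dict inside the patched forward (`teacache_forward_working`). The sketch below only illustrates how such knobs are typically wired up on Diffusers' `Lumina2Pipeline` (the import the script itself uses); the model id, prompt, and threshold value are assumptions for illustration, so check `teacache_lumina2.py` for the exact setup around line 154.

```python
import torch
from diffusers import Lumina2Pipeline

# Sketch only: teacache_lumina2.py patches the Lumina2 transformer's forward;
# without that patch, the attributes set below are not read by anything.
num_inference_steps = 50
seed = 1024
prompt = "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps."

# Model id assumed from the Lumina-Image-2.0 release on the Hugging Face Hub.
pipe = Lumina2Pipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16).to("cuda")

# TeaCache knobs; the script accumulates its state per sequence length in `cache`
# and may set additional counters as well, so treat this list as a sketch.
pipe.transformer.__class__.enable_teacache = True
pipe.transformer.__class__.rel_l1_thresh = 0.3  # see the latency table above
pipe.transformer.__class__.cache = {}

image = pipe(
    prompt=prompt,
    num_inference_steps=num_inference_steps,
    generator=torch.Generator("cpu").manual_seed(seed),
).images[0]
image.save("teacache_lumina2.png")
```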
36 | ## Citation
37 | If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry.
38 |
39 | ```
40 | @article{liu2024timestep,
41 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
42 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
43 | journal={arXiv preprint arXiv:2411.19108},
44 | year={2024}
45 | }
46 | ```
47 |
48 | ## Acknowledgements
49 |
50 | We would like to thank the contributors to [Lumina-Image-2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0) and [Diffusers](https://github.com/huggingface/diffusers).
51 |
--------------------------------------------------------------------------------
/TeaCache4Lumina2/teacache_lumina2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from typing import Any, Dict, Optional, Tuple, Union, List
5 |
6 | from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline
7 | from diffusers.models.modeling_outputs import Transformer2DModelOutput
8 | from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
9 |
10 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
11 |
12 | def teacache_forward_working(
13 | self,
14 | hidden_states: torch.Tensor,
15 | timestep: torch.Tensor,
16 | encoder_hidden_states: torch.Tensor,
17 | encoder_attention_mask: torch.Tensor,
18 | attention_kwargs: Optional[Dict[str, Any]] = None,
19 | return_dict: bool = True,
20 | ) -> Union[torch.Tensor, Transformer2DModelOutput]:
21 | if attention_kwargs is not None:
22 | attention_kwargs = attention_kwargs.copy()
23 | lora_scale = attention_kwargs.pop("scale", 1.0)
24 | else:
25 | lora_scale = 1.0
26 | if USE_PEFT_BACKEND:
27 | scale_lora_layers(self, lora_scale)
28 |
29 | batch_size, _, height, width = hidden_states.shape
30 | temb, encoder_hidden_states_processed = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
31 | (image_patch_embeddings, context_rotary_emb, noise_rotary_emb, joint_rotary_emb,
32 | encoder_seq_lengths, seq_lengths) = self.rope_embedder(hidden_states, encoder_attention_mask)
33 | image_patch_embeddings = self.x_embedder(image_patch_embeddings)
34 | for layer in self.context_refiner:
35 | encoder_hidden_states_processed = layer(encoder_hidden_states_processed, encoder_attention_mask, context_rotary_emb)
36 | for layer in self.noise_refiner:
37 | image_patch_embeddings = layer(image_patch_embeddings, None, noise_rotary_emb, temb)
38 |
39 | max_seq_len = max(seq_lengths)
40 | input_to_main_loop = image_patch_embeddings.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
41 | for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
42 | input_to_main_loop[i, :enc_len] = encoder_hidden_states_processed[i, :enc_len]
43 | input_to_main_loop[i, enc_len:seq_len_val] = image_patch_embeddings[i]
44 |
45 | use_mask = len(set(seq_lengths)) > 1
46 | attention_mask_for_main_loop_arg = None
47 | if use_mask:
48 | mask = input_to_main_loop.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
49 | for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
50 | mask[i, :seq_len_val] = True
51 | attention_mask_for_main_loop_arg = mask
52 |
53 | should_calc = True
54 | if self.enable_teacache:
55 | cache_key = max_seq_len
56 | if cache_key not in self.cache:
57 | self.cache[cache_key] = {
58 | "accumulated_rel_l1_distance": 0.0,
59 | "previous_modulated_input": None,
60 | "previous_residual": None,
61 | }
62 |
63 | current_cache = self.cache[cache_key]
64 | modulated_inp, _, _, _ = self.layers[0].norm1(input_to_main_loop.clone(), temb.clone())
65 |
66 | if self.cnt == 0 or self.cnt == self.num_steps - 1:
67 | should_calc = True
68 | current_cache["accumulated_rel_l1_distance"] = 0.0
69 | else:
70 | if current_cache["previous_modulated_input"] is not None:
71 | coefficients = [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344] # taken from teacache_lumina_next.py
72 | rescale_func = np.poly1d(coefficients)
73 | prev_mod_input = current_cache["previous_modulated_input"]
74 | prev_mean = prev_mod_input.abs().mean()
75 |
76 | if prev_mean.item() > 1e-9:
77 | rel_l1_change = ((modulated_inp - prev_mod_input).abs().mean() / prev_mean).cpu().item()
78 | else:
79 | rel_l1_change = 0.0 if modulated_inp.abs().mean().item() < 1e-9 else float('inf')
80 |
81 | current_cache["accumulated_rel_l1_distance"] += rescale_func(rel_l1_change)
82 |
83 | if current_cache["accumulated_rel_l1_distance"] < self.rel_l1_thresh:
84 | should_calc = False
85 | else:
86 | should_calc = True
87 | current_cache["accumulated_rel_l1_distance"] = 0.0
88 | else:
89 | should_calc = True
90 | current_cache["accumulated_rel_l1_distance"] = 0.0
91 |
92 | current_cache["previous_modulated_input"] = modulated_inp.clone()
93 |
94 | if not hasattr(self, 'uncond_seq_len'):
95 | self.uncond_seq_len = cache_key
96 | if cache_key != self.uncond_seq_len:
97 | self.cnt += 1
98 | if self.cnt >= self.num_steps:
99 | self.cnt = 0
100 |
101 | if self.enable_teacache and not should_calc:
102 | processed_hidden_states = input_to_main_loop + self.cache[max_seq_len]["previous_residual"]
103 | else:
104 | ori_input = input_to_main_loop.clone()
105 | current_processing_states = input_to_main_loop
106 | for layer in self.layers:
107 | current_processing_states = layer(current_processing_states, attention_mask_for_main_loop_arg, joint_rotary_emb, temb)
108 |
109 | if self.enable_teacache:
110 | self.cache[max_seq_len]["previous_residual"] = current_processing_states - ori_input
111 | processed_hidden_states = current_processing_states
112 |
113 | output_after_norm = self.norm_out(processed_hidden_states, temb)
114 | p = self.config.patch_size
115 | final_output_list = []
116 | for i, (enc_len, seq_len_val) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
117 | image_part = output_after_norm[i][enc_len:seq_len_val]
118 | h_p, w_p = height // p, width // p
119 | reconstructed_image = image_part.view(h_p, w_p, p, p, self.out_channels) \
120 | .permute(4, 0, 2, 1, 3) \
121 | .flatten(3, 4) \
122 | .flatten(1, 2)
123 | final_output_list.append(reconstructed_image)
124 |
125 | final_output_tensor = torch.stack(final_output_list, dim=0)
126 |
127 | if USE_PEFT_BACKEND:
128 | unscale_lora_layers(self, lora_scale)
129 |
130 | return Transformer2DModelOutput(sample=final_output_tensor)
131 |
132 |
133 | Lumina2Transformer2DModel.forward = teacache_forward_working
134 |
135 | ckpt_path = "NietaAniLumina_Alpha_full_round5_ep5_s182000.pth"
136 | transformer = Lumina2Transformer2DModel.from_single_file(
137 | ckpt_path, torch_dtype=torch.bfloat16
138 | )
139 | pipeline = Lumina2Pipeline.from_pretrained(
140 | "Alpha-VLLM/Lumina-Image-2.0",
141 | transformer=transformer,
142 | torch_dtype=torch.bfloat16
143 | ).to("cuda")
144 |
145 | num_inference_steps = 30
146 | seed = 1024
147 | prompt = "a cat holding a sign that says hello"
148 | output_filename = "teacache_lumina2_output.png"
149 |
150 | # TeaCache
151 | pipeline.transformer.__class__.enable_teacache = True
152 | pipeline.transformer.__class__.cnt = 0
153 | pipeline.transformer.__class__.num_steps = num_inference_steps
154 | pipeline.transformer.__class__.rel_l1_thresh = 0.3 # taken from teacache_lumina_next.py, 0.2 for 1.5x speedup, 0.3 for 1.9x speedup, 0.4 for 2.4x speedup, 0.5 for 2.8x speedup
155 | pipeline.transformer.__class__.cache = {}
156 | pipeline.transformer.__class__.uncond_seq_len = None
157 |
158 |
159 | pipeline.enable_model_cpu_offload()
160 | image = pipeline(
161 | prompt=prompt,
162 | num_inference_steps=num_inference_steps,
163 | generator=torch.Generator("cuda").manual_seed(seed)
164 | ).images[0]
165 |
166 | image.save(output_filename)
--------------------------------------------------------------------------------
/TeaCache4Mochi/README.md:
--------------------------------------------------------------------------------
1 |
2 | # TeaCache4Mochi
3 |
4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [Mochi](https://github.com/genmoai/mochi) 2x without much visual quality degradation, in a training-free manner. The following video shows the results generated by TeaCache-Mochi with various rel_l1_thresh values: 0 (original), 0.06 (1.5x speedup), 0.09 (2.1x speedup).
5 |
6 | https://github.com/user-attachments/assets/29a81380-46b3-414f-a96b-6e3acc71b6c4
7 |
8 | ## 📈 Inference Latency Comparisons on a Single A800
9 |
10 |
11 | | mochi-1-preview | TeaCache (0.06) | TeaCache (0.09) |
12 | |:--------------------------:|:----------------------------:|:--------------------:|
13 | | ~30 min | ~20 min | ~14 min |
14 |
15 | ## Installation
16 |
17 | ```shell
18 | pip install --upgrade diffusers[torch] transformers protobuf tokenizers sentencepiece imageio
19 | ```
20 |
21 | ## Usage
22 |
23 | You can modify the `rel_l1_thresh` in line 174 to obtain your desired trade-off between latency and visual quality. For single-GPU inference, you can use the following command:
24 |
25 | ```bash
26 | python teacache_mochi.py
27 | ```
28 |
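29 | For reference, line 174 of `teacache_mochi.py` sets the threshold as a class attribute, so changing that single line is enough (a minimal sketch using one of the values benchmarked above):
30 |
31 | ```python
32 | # 0.06 for ~1.5x speedup, 0.09 for ~2.1x speedup
33 | pipe.transformer.__class__.rel_l1_thresh = 0.06
34 | ```
35 |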
29 | ## Citation
30 | If you find TeaCache useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.
31 |
32 | ```
33 | @article{liu2024timestep,
34 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
35 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
36 | journal={arXiv preprint arXiv:2411.19108},
37 | year={2024}
38 | }
39 | ```
40 |
41 | ## Acknowledgements
42 |
43 | We would like to thank the contributors to [Mochi](https://github.com/genmoai/mochi) and [Diffusers](https://github.com/huggingface/diffusers).
--------------------------------------------------------------------------------
/TeaCache4Mochi/teacache_mochi.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.attention import SDPBackend, sdpa_kernel
3 | from diffusers import MochiPipeline
4 | from diffusers.models.transformers import MochiTransformer3DModel
5 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
6 | from diffusers.utils import export_to_video
7 | from diffusers.video_processor import VideoProcessor
8 | from typing import Any, Dict, Optional, Tuple
9 | import numpy as np
10 | from diffusers.models.modeling_outputs import Transformer2DModelOutput  # used by the return value below
11 | logger = logging.get_logger(__name__)  # used when warning about `scale` without the PEFT backend
12 | def teacache_forward(
13 | self,
14 | hidden_states: torch.Tensor,
15 | encoder_hidden_states: torch.Tensor,
16 | timestep: torch.LongTensor,
17 | encoder_attention_mask: torch.Tensor,
18 | attention_kwargs: Optional[Dict[str, Any]] = None,
19 | return_dict: bool = True,
20 | ) -> torch.Tensor:
21 | if attention_kwargs is not None:
22 | attention_kwargs = attention_kwargs.copy()
23 | lora_scale = attention_kwargs.pop("scale", 1.0)
24 | else:
25 | lora_scale = 1.0
26 |
27 | if USE_PEFT_BACKEND:
28 | # weight the lora layers by setting `lora_scale` for each PEFT layer
29 | scale_lora_layers(self, lora_scale)
30 | else:
31 | if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
32 | logger.warning(
33 | "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
34 | )
35 |
36 | batch_size, num_channels, num_frames, height, width = hidden_states.shape
37 | p = self.config.patch_size
38 |
39 | post_patch_height = height // p
40 | post_patch_width = width // p
41 |
42 | temb, encoder_hidden_states = self.time_embed(
43 | timestep,
44 | encoder_hidden_states,
45 | encoder_attention_mask,
46 | hidden_dtype=hidden_states.dtype,
47 | )
48 |
49 | hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
50 | hidden_states = self.patch_embed(hidden_states)
51 | hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
52 |
53 | image_rotary_emb = self.rope(
54 | self.pos_frequencies,
55 | num_frames,
56 | post_patch_height,
57 | post_patch_width,
58 | device=hidden_states.device,
59 | dtype=torch.float32,
60 | )
61 |
62 | if self.enable_teacache:
63 | inp = hidden_states.clone()
64 | temb_ = temb.clone()
65 | modulated_inp, gate_msa, scale_mlp, gate_mlp = self.transformer_blocks[0].norm1(inp, temb_)
66 | if self.cnt == 0 or self.cnt == self.num_steps-1:
67 | should_calc = True
68 | self.accumulated_rel_l1_distance = 0
69 | else:
70 | coefficients = [-3.51241319e+03, 8.11675948e+02, -6.09400215e+01, 2.42429681e+00, 3.05291719e-03]
71 | rescale_func = np.poly1d(coefficients)
72 | self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
73 | if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
74 | should_calc = False
75 | else:
76 | should_calc = True
77 | self.accumulated_rel_l1_distance = 0
78 | self.previous_modulated_input = modulated_inp
79 | self.cnt += 1
80 | if self.cnt == self.num_steps:
81 | self.cnt = 0
82 |
83 | if self.enable_teacache:
84 | if not should_calc:
85 | hidden_states += self.previous_residual
86 | else:
87 | ori_hidden_states = hidden_states.clone()
88 | for i, block in enumerate(self.transformer_blocks):
89 | if torch.is_grad_enabled() and self.gradient_checkpointing:
90 |
91 | def create_custom_forward(module):
92 | def custom_forward(*inputs):
93 | return module(*inputs)
94 |
95 | return custom_forward
96 |
97 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
98 | hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
99 | create_custom_forward(block),
100 | hidden_states,
101 | encoder_hidden_states,
102 | temb,
103 | encoder_attention_mask,
104 | image_rotary_emb,
105 | **ckpt_kwargs,
106 | )
107 | else:
108 | hidden_states, encoder_hidden_states = block(
109 | hidden_states=hidden_states,
110 | encoder_hidden_states=encoder_hidden_states,
111 | temb=temb,
112 | encoder_attention_mask=encoder_attention_mask,
113 | image_rotary_emb=image_rotary_emb,
114 | )
115 | hidden_states = self.norm_out(hidden_states, temb)
116 | self.previous_residual = hidden_states - ori_hidden_states
117 | else:
118 | for i, block in enumerate(self.transformer_blocks):
119 | if torch.is_grad_enabled() and self.gradient_checkpointing:
120 | def create_custom_forward(module):
121 | def custom_forward(*inputs):
122 | return module(*inputs)
123 |
124 | return custom_forward
125 |
126 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
127 | hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
128 | create_custom_forward(block),
129 | hidden_states,
130 | encoder_hidden_states,
131 | temb,
132 | encoder_attention_mask,
133 | image_rotary_emb,
134 | **ckpt_kwargs,
135 | )
136 | else:
137 | hidden_states, encoder_hidden_states = block(
138 | hidden_states=hidden_states,
139 | encoder_hidden_states=encoder_hidden_states,
140 | temb=temb,
141 | encoder_attention_mask=encoder_attention_mask,
142 | image_rotary_emb=image_rotary_emb,
143 | )
144 | hidden_states = self.norm_out(hidden_states, temb)
145 |
146 | hidden_states = self.proj_out(hidden_states)
147 |
148 | hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
149 | hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
150 | output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
151 |
152 | if USE_PEFT_BACKEND:
153 | # remove `lora_scale` from each PEFT layer
154 | unscale_lora_layers(self, lora_scale)
155 |
156 | if not return_dict:
157 | return (output,)
158 | return Transformer2DModelOutput(sample=output)
159 |
160 |
161 | MochiTransformer3DModel.forward = teacache_forward
162 | prompt = "A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl filled with lemons and sprigs of mint against a peach-colored background. The hand gently tosses the lemon up and catches it, showcasing its smooth texture. \
163 | A beige string bag sits beside the bowl, adding a rustic touch to the scene. \
164 | Additional lemons, one halved, are scattered around the base of the bowl. \
165 | The even lighting enhances the vibrant colors and creates a fresh, \
166 | inviting atmosphere."
167 | num_inference_steps = 64
168 | pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", force_zeros_for_empty_prompt=True)
169 |
170 | # TeaCache
171 | pipe.transformer.__class__.enable_teacache = True
172 | pipe.transformer.__class__.cnt = 0
173 | pipe.transformer.__class__.num_steps = num_inference_steps
174 | pipe.transformer.__class__.rel_l1_thresh = 0.09 # 0.06 for 1.5x speedup, 0.09 for 2.1x speedup
175 | pipe.transformer.__class__.accumulated_rel_l1_distance = 0
176 | pipe.transformer.__class__.previous_modulated_input = None
177 | pipe.transformer.__class__.previous_residual = None
178 |
179 | # Enable memory savings
180 | pipe.enable_vae_tiling()
181 | pipe.enable_model_cpu_offload()
182 |
183 | with torch.no_grad():
184 | prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
185 | pipe.encode_prompt(prompt=prompt)
186 | )
187 |
188 | with torch.autocast("cuda", torch.bfloat16):
189 | with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
190 | frames = pipe(
191 | prompt_embeds=prompt_embeds,
192 | prompt_attention_mask=prompt_attention_mask,
193 | negative_prompt_embeds=negative_prompt_embeds,
194 | negative_prompt_attention_mask=negative_prompt_attention_mask,
195 | guidance_scale=4.5,
196 | num_inference_steps=num_inference_steps,
197 | height=480,
198 | width=848,
199 | num_frames=163,
200 | generator=torch.Generator("cuda").manual_seed(0),
201 | output_type="latent",
202 | return_dict=False,
203 | )[0]
204 |
205 | video_processor = VideoProcessor(vae_scale_factor=8)
206 | has_latents_mean = hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None
207 | has_latents_std = hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None
208 | if has_latents_mean and has_latents_std:
209 | latents_mean = (
210 | torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
211 | )
212 | latents_std = (
213 | torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
214 | )
215 | frames = frames * latents_std / pipe.vae.config.scaling_factor + latents_mean
216 | else:
217 | frames = frames / pipe.vae.config.scaling_factor
218 |
219 | with torch.no_grad():
220 | video = pipe.vae.decode(frames.to(pipe.vae.dtype), return_dict=False)[0]
221 |
222 | video = video_processor.postprocess_video(video)[0]
223 | export_to_video(video, "teacache_mochi__{}.mp4".format(prompt[:50]), fps=30)
--------------------------------------------------------------------------------
/TeaCache4TangoFlux/README.md:
--------------------------------------------------------------------------------
1 |
2 | # TeaCache4TangoFlux
3 |
4 | [TeaCache](https://github.com/LiewFeng/TeaCache) can speed up [TangoFlux](https://github.com/declare-lab/TangoFlux) 2x without much audio quality degradation, in a training-free manner.
5 |
6 | ## 📈 Inference Latency Comparisons on a Single A800
7 |
8 |
9 | | TangoFlux | TeaCache (0.25) | TeaCache (0.4) |
10 | |:-------------------:|:----------------------------:|:--------------------:|
11 | | ~4.08 s | ~2.42 s | ~1.95 s |
12 |
13 | ## Installation
14 |
15 | ```shell
16 | pip install git+https://github.com/declare-lab/TangoFlux
17 | ```
18 |
19 | ## Usage
20 |
21 | You can modify the threshold in line 266 to obtain your desired trade-off between latency and audio quality. For single-GPU inference, you can use the following command:
22 |
23 | ```bash
24 | python teacache_tango_flux.py
25 | ```
26 |
27 | ## Citation
28 | If you find TeaCache useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.
29 |
30 | ```
31 | @article{liu2024timestep,
32 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
33 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
34 | journal={arXiv preprint arXiv:2411.19108},
35 | year={2024}
36 | }
37 | ```
38 |
39 | ## Acknowledgements
40 |
41 | We would like to thank the contributors to [TangoFlux](https://github.com/declare-lab/TangoFlux).
--------------------------------------------------------------------------------
/TeaCache4Wan2.1/README.md:
--------------------------------------------------------------------------------
1 |
2 | # TeaCache4Wan2.1
3 |
4 | [TeaCache](https://github.com/ali-vilab/TeaCache) can speed up [Wan2.1](https://github.com/Wan-Video/Wan2.1) 2x without much visual quality degradation, in a training-free manner. The following videos show the results generated by TeaCache-Wan2.1 with various teacache_thresh values; the corresponding values and latencies are listed in the tables below.
5 |
6 | https://github.com/user-attachments/assets/5ae5d6dd-bf87-4f8f-91b8-ccc5980c56ad
7 |
8 | https://github.com/user-attachments/assets/dfd047a9-e3ca-4a73-a282-4dadda8dbd43
9 |
10 | https://github.com/user-attachments/assets/7c20bd54-96a8-4bd7-b4fa-ea4c9da81562
11 |
12 | https://github.com/user-attachments/assets/72085f45-6b78-4fae-b58f-492360a6e55e
13 | ## 📈 Inference Latency Comparisons on a Single A800
14 |
15 |
16 | | Wan2.1 t2v 1.3B | TeaCache (0.05) | TeaCache (0.07) | TeaCache (0.08) |
17 | |:--------------------------:|:----------------------------:|:---------------------:|:---------------------:|
18 | | ~175 s | ~117 s | ~110 s | ~88 s |
19 |
20 | | Wan2.1 t2v 14B | TeaCache (0.14) | TeaCache (0.15) | TeaCache (0.2) |
21 | |:--------------------------:|:----------------------------:|:---------------------:|:---------------------:|
22 | | ~55 min | ~38 min | ~30 min | ~27 min |
23 |
24 | | Wan2.1 i2v 480P | TeaCache (0.13) | TeaCache (0.19) | TeaCache (0.26) |
25 | |:--------------------------:|:----------------------------:|:---------------------:|:---------------------:|
26 | | ~735 s | ~464 s | ~372 s | ~300 s |
27 |
28 | | Wan2.1 i2v 720P | TeaCache (0.18) | TeaCache (0.2) | TeaCache (0.3) |
29 | |:--------------------------:|:----------------------------:|:---------------------:|:---------------------:|
30 | | ~29 min | ~17 min | ~15 min | ~12 min |
31 |
32 | ## Usage
33 |
34 | Follow [Wan2.1](https://github.com/Wan-Video/Wan2.1) to clone the repo and finish the installation, then copy `teacache_generate.py` from this repo into the Wan2.1 repo.
35 |
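36 | For example, a minimal setup sketch (paths are placeholders; install dependencies as described in the Wan2.1 README):
37 |
38 | ```bash
39 | git clone https://github.com/Wan-Video/Wan2.1
40 | cd Wan2.1
41 | # copy the TeaCache generation script from this repo into the Wan2.1 repo
42 | cp /path/to/TeaCache/TeaCache4Wan2.1/teacache_generate.py .
43 | ```
44 |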
36 | For T2V with 1.3B model, you can use the following command:
37 |
38 | ```bash
39 | python teacache_generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." --base_seed 42 --offload_model True --t5_cpu --teacache_thresh 0.08
40 | ```
41 |
42 | For T2V with 14B model, you can use the following command:
43 |
44 | ```bash
45 | python teacache_generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." --base_seed 42 --offload_model True --t5_cpu --teacache_thresh 0.2
46 | ```
47 |
48 | For I2V with 480P resolution, you can use the following command:
49 |
50 | ```bash
51 | python teacache_generate.py --task i2v-14B --size 832*480 --ckpt_dir ./Wan2.1-I2V-14B-480P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --base_seed 42 --offload_model True --t5_cpu --teacache_thresh 0.26
52 | ```
53 |
54 | For I2V with 720P resolution, you can use the following command:
55 |
56 | ```bash
57 | python teacache_generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --base_seed 42 --offload_model True --t5_cpu --frame_num 61 --teacache_thresh 0.3
58 | ```
59 |
60 | ## Faster Video Generation Using the `use_ret_steps` Parameter
61 |
62 | Using retention steps (`--use_ret_steps`) results in faster generation and better generation quality (except for t2v-1.3B).
63 |
64 | https://github.com/user-attachments/assets/f241b5f5-1044-4223-b2a4-449dc6dc1ad7
65 |
66 | https://github.com/user-attachments/assets/01db60f9-4aaf-43c4-8f1b-6e050cfa1180
67 |
68 | https://github.com/user-attachments/assets/e03621f2-1085-4571-8eca-51889f47ce18
69 |
70 | https://github.com/user-attachments/assets/d1340197-20c1-4f9e-a780-31f789af0893
71 |
72 |
73 | | use_ret_steps | Wan2.1 t2v 1.3B (thresh) | Slow (thresh) | Fast (thresh) |
74 | |:--------------------------:|:----------------------------:|:---------------------:|:---------------------:|
75 | | False | ~97 s (0.00) | ~64 s (0.05) | ~49 s (0.08) |
76 | | True | ~97 s (0.00) | ~61 s (0.05) | ~41 s (0.10) |
77 |
78 | | use_ret_steps | Wan2.1 t2v 14B (thresh) | Slow (thresh) | Fast (thresh) |
79 | |:--------------------------:|:----------------------------:|:---------------------:|:---------------------:|
80 | | False | ~1829 s (0.00) | ~1234 s (0.14) | ~909 s (0.20) |
81 | | True | ~1829 s (0.00) | ~915 s (0.10) | ~578 s (0.20) |
82 |
83 | | use_ret_steps | Wan2.1 i2v 480p (thresh) | Slow (thresh) | Fast (thresh) |
84 | |:--------------------------:|:----------------------------:|:---------------------:|:---------------------:|
85 | | False | ~385 s (0.00) | ~241 s (0.13) | ~156 s (0.26) |
86 | | True | ~385 s (0.00) | ~212 s (0.20) | ~164 s (0.30) |
87 |
88 | | use_ret_steps | Wan2.1 i2v 720p (thresh) | Slow (thresh) | Fast (thresh) |
89 | |:--------------------------:|:----------------------------:|:---------------------:|:---------------------:|
90 | | False | ~903 s (0.00) | ~476 s (0.20) | ~363 s (0.30) |
91 | | True | ~903 s (0.00) | ~430 s (0.20) | ~340 s (0.30) |
92 |
93 |
94 | You can follow the previous video generation instructions and add the `--use_ret_steps` flag to speed up generation while keeping the results closer to the original [Wan2.1](https://github.com/Wan-Video/Wan2.1) outputs. Simply append `--use_ret_steps` to the original command and adjust `--teacache_thresh`; suitable values for each model and setting can be taken from the tables above.
95 |
96 | ### Example Command:
97 |
98 | ```bash
99 | python teacache_generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." --base_seed 42 --offload_model True --t5_cpu --teacache_thresh 0.3 --use_ret_steps
100 | ```
101 |
102 |
103 | ## Acknowledgements
104 |
105 | We would like to thank the contributors to [Wan2.1](https://github.com/Wan-Video/Wan2.1).
--------------------------------------------------------------------------------
/assets/TeaCache4FLUX.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/assets/TeaCache4FLUX.png
--------------------------------------------------------------------------------
/assets/TeaCache4HiDream-I1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/assets/TeaCache4HiDream-I1.png
--------------------------------------------------------------------------------
/assets/TeaCache4LuminaT2X.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/assets/TeaCache4LuminaT2X.png
--------------------------------------------------------------------------------
/assets/tisser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/assets/tisser.png
--------------------------------------------------------------------------------
/eval/teacache/README.md:
--------------------------------------------------------------------------------
1 | ## Installation
2 |
3 | Prerequisites:
4 |
5 | - Python >= 3.10
6 | - PyTorch >= 1.13 (we recommend a version >= 2.0)
7 | - CUDA >= 11.6
8 |
9 | We strongly recommend using Anaconda to create a new environment (Python >= 3.10) to run our examples:
10 |
11 | ```shell
12 | conda create -n teacache python=3.10 -y
13 | conda activate teacache
14 | ```
15 |
16 | Install TeaCache:
17 |
18 | ```shell
19 | git clone https://github.com/LiewFeng/TeaCache
20 | cd TeaCache
21 | pip install -e .
22 | ```
23 |
24 |
25 | ## Evaluation of TeaCache
26 |
27 | We first generate videos according to VBench's prompts.
28 |
29 | We then calculate VBench, PSNR, LPIPS, and SSIM scores based on the generated videos.
30 |
31 | 1. Generate videos
32 | ```
33 | cd eval/teacache
34 | python experiments/latte.py
35 | python experiments/opensora.py
36 | python experiments/open_sora_plan.py
37 | python experiments/cogvideox.py
38 | ```
39 |
40 | 2. Calculate VBench score
41 | ```
42 | # vbench is calculated independently
43 | # get scores for all metrics
44 | python vbench/run_vbench.py --video_path aaa --save_path bbb
45 | # calculate final score
46 | python vbench/cal_vbench.py --score_dir bbb
47 | ```
48 |
49 | 3. Calculate other metrics
50 | ```
51 | # these metrics are calculated compared with original model
52 | # the gt videos are the outputs of the original model
53 | # the generated videos are our method's results
54 | python common_metrics/eval.py --gt_video_dir aa --generated_video_dir bb
55 | ```
56 |
57 |
58 |
59 | ## Citation
60 | If you find TeaCache useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.
61 |
62 | ```
63 | @article{liu2024timestep,
64 | title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model},
65 | author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang},
66 | journal={arXiv preprint arXiv:2411.19108},
67 | year={2024}
68 | }
69 | ```
70 |
71 | ## Acknowledgements
72 | We would like to thank the contributors to [Open-Sora](https://github.com/hpcaitech/Open-Sora), [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan), [Latte](https://github.com/Vchitect/Latte), [CogVideoX](https://github.com/THUDM/CogVideo), and [VideoSys](https://github.com/NUS-HPC-AI-Lab/VideoSys).
73 |
--------------------------------------------------------------------------------
/eval/teacache/common_metrics/README.md:
--------------------------------------------------------------------------------
1 | Common metrics
2 |
3 | Includes LPIPS, PSNR, and SSIM.
4 |
5 | The code is adapted from [common_metrics_on_video_quality](https://github.com/JunyaoHu/common_metrics_on_video_quality).
6 |
7 |
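8 | A minimal usage sketch (directory paths are placeholders; `eval.py` compares generated videos against the original model's outputs):
9 |
10 | ```shell
11 | python eval.py --gt_video_dir /path/to/original_videos --generated_video_dir /path/to/teacache_videos
12 | ```
13 |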
--------------------------------------------------------------------------------
/eval/teacache/common_metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/eval/teacache/common_metrics/__init__.py
--------------------------------------------------------------------------------
/eval/teacache/common_metrics/batch_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import imageio
5 | import torch
6 | import torchvision.transforms.functional as F
7 | import tqdm
8 | from calculate_lpips import calculate_lpips
9 | from calculate_psnr import calculate_psnr
10 | from calculate_ssim import calculate_ssim
11 |
12 |
13 | def load_video(video_path):
14 | """
15 | Load a video from the given path and convert it to a PyTorch tensor.
16 | """
17 | # Read the video using imageio
18 | reader = imageio.get_reader(video_path, "ffmpeg")
19 |
20 | # Extract frames and convert to a list of tensors
21 | frames = []
22 | for frame in reader:
23 | # Convert the frame to a tensor and permute the dimensions to match (C, H, W)
24 | frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1)
25 | frames.append(frame_tensor)
26 |
27 | # Stack the list of tensors into a single tensor with shape (T, C, H, W)
28 | video_tensor = torch.stack(frames)
29 |
30 | return video_tensor
31 |
32 |
33 | def resize_video(video, target_height, target_width):
34 | resized_frames = []
35 | for frame in video:
36 | resized_frame = F.resize(frame, [target_height, target_width])
37 | resized_frames.append(resized_frame)
38 | return torch.stack(resized_frames)
39 |
40 |
41 | def resize_gt_video(gt_video, gen_video):
42 | gen_video_shape = gen_video.shape
43 | T_gen, _, H_gen, W_gen = gen_video_shape
44 | T_eval, _, H_eval, W_eval = gt_video.shape
45 |
46 | if T_eval < T_gen:
47 | raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).")
48 |
49 | if H_eval < H_gen or W_eval < W_gen:
50 | # Resize the video maintaining the aspect ratio
51 | resize_height = max(H_gen, int(H_gen * (H_eval / W_eval)))
52 | resize_width = max(W_gen, int(W_gen * (W_eval / H_eval)))
53 | gt_video = resize_video(gt_video, resize_height, resize_width)
54 | # Recalculate the dimensions
55 | T_eval, _, H_eval, W_eval = gt_video.shape
56 |
57 | # Center crop
58 | start_h = (H_eval - H_gen) // 2
59 | start_w = (W_eval - W_gen) // 2
60 | cropped_video = gt_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]
61 |
62 | return cropped_video
63 |
64 |
65 | def get_video_ids(gt_video_dirs, gen_video_dirs):
66 | video_ids = []
67 | for f in os.listdir(gt_video_dirs[0]):
68 | if f.endswith(f".mp4"):
69 | video_ids.append(f.replace(f".mp4", ""))
70 | video_ids.sort()
71 |
72 | for video_dir in gt_video_dirs + gen_video_dirs:
73 | tmp_video_ids = []
74 | for f in os.listdir(video_dir):
75 | if f.endswith(f".mp4"):
76 | tmp_video_ids.append(f.replace(f".mp4", ""))
77 | tmp_video_ids.sort()
78 | if tmp_video_ids != video_ids:
79 | raise ValueError(f"Video IDs in {video_dir} are different.")
80 | return video_ids
81 |
82 |
83 | def get_videos(video_ids, gt_video_dirs, gen_video_dirs):
84 | gt_videos = {}
85 | generated_videos = {}
86 |
87 | for gt_video_dir in gt_video_dirs:
88 | tmp_gt_videos_tensor = []
89 | for video_id in video_ids:
90 | gt_video = load_video(os.path.join(gt_video_dir, f"{video_id}.mp4"))
91 | tmp_gt_videos_tensor.append(gt_video)
92 | gt_videos[gt_video_dir] = tmp_gt_videos_tensor
93 |
94 | for generated_video_dir in gen_video_dirs:
95 | tmp_generated_videos_tensor = []
96 | for video_id in video_ids:
97 | generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.mp4"))
98 | tmp_generated_videos_tensor.append(generated_video)
99 | generated_videos[generated_video_dir] = tmp_generated_videos_tensor
100 |
101 | return gt_videos, generated_videos
102 |
103 |
104 | def print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs):
105 | out_str = ""
106 |
107 | for gt_video_dir in gt_video_dirs:
108 | for generated_video_dir in gen_video_dirs:
109 | if gt_video_dir == generated_video_dir:
110 | continue
111 | lpips = sum(lpips_results[gt_video_dir][generated_video_dir]) / len(
112 | lpips_results[gt_video_dir][generated_video_dir]
113 | )
114 | psnr = sum(psnr_results[gt_video_dir][generated_video_dir]) / len(
115 | psnr_results[gt_video_dir][generated_video_dir]
116 | )
117 | ssim = sum(ssim_results[gt_video_dir][generated_video_dir]) / len(
118 | ssim_results[gt_video_dir][generated_video_dir]
119 | )
120 | out_str += f"\ngt: {gt_video_dir} -> gen: {generated_video_dir}, lpips: {lpips:.4f}, psnr: {psnr:.4f}, ssim: {ssim:.4f}"
121 |
122 | return out_str
123 |
124 |
125 | def main(args):
126 | device = "cuda"
127 | gt_video_dirs = args.gt_video_dirs
128 | gen_video_dirs = args.gen_video_dirs
129 |
130 | video_ids = get_video_ids(gt_video_dirs, gen_video_dirs)
131 | print(f"Find {len(video_ids)} videos")
132 |
133 | prompt_interval = 1
134 | batch_size = 8
135 | calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True
136 |
137 | lpips_results = {}
138 | psnr_results = {}
139 | ssim_results = {}
140 | for gt_video_dir in gt_video_dirs:
141 | lpips_results[gt_video_dir] = {}
142 | psnr_results[gt_video_dir] = {}
143 | ssim_results[gt_video_dir] = {}
144 | for generated_video_dir in gen_video_dirs:
145 | lpips_results[gt_video_dir][generated_video_dir] = []
146 | psnr_results[gt_video_dir][generated_video_dir] = []
147 | ssim_results[gt_video_dir][generated_video_dir] = []
148 |
149 | total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0)
150 |
151 | for idx in tqdm.tqdm(range(total_len)):
152 | video_ids_batch = video_ids[idx * batch_size : (idx + 1) * batch_size]
153 | gt_videos, generated_videos = get_videos(video_ids_batch, gt_video_dirs, gen_video_dirs)
154 |
155 | for gt_video_dir, gt_videos_tensor in gt_videos.items():
156 | for generated_video_dir, generated_videos_tensor in generated_videos.items():
157 | if gt_video_dir == generated_video_dir:
158 | continue
159 |
160 | if not isinstance(gt_videos_tensor, torch.Tensor):
161 | for i in range(len(gt_videos_tensor)):
162 | gt_videos_tensor[i] = resize_gt_video(gt_videos_tensor[i], generated_videos_tensor[0])
163 | gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu()
164 |
165 | generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu()
166 |
167 | if calculate_lpips_flag:
168 | result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device)
169 | result = result["value"].values()
170 | result = float(sum(result) / len(result))
171 | lpips_results[gt_video_dir][generated_video_dir].append(result)
172 |
173 | if calculate_psnr_flag:
174 | result = calculate_psnr(gt_videos_tensor, generated_videos_tensor)
175 | result = result["value"].values()
176 | result = float(sum(result) / len(result))
177 | psnr_results[gt_video_dir][generated_video_dir].append(result)
178 |
179 | if calculate_ssim_flag:
180 | result = calculate_ssim(gt_videos_tensor, generated_videos_tensor)
181 | result = result["value"].values()
182 | result = float(sum(result) / len(result))
183 | ssim_results[gt_video_dir][generated_video_dir].append(result)
184 |
185 | if (idx + 1) % prompt_interval == 0:
186 | out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs)
187 | print(f"Processed {idx + 1} / {total_len} videos. {out_str}")
188 |
189 | out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs)
190 |
191 | # save
192 | with open(f"./batch_eval.txt", "w+") as f:
193 | f.write(out_str)
194 |
195 | print(f"Processed all videos. {out_str}")
196 |
197 |
198 | if __name__ == "__main__":
199 | parser = argparse.ArgumentParser()
200 | parser.add_argument("--gt_video_dirs", type=str, nargs="+")
201 | parser.add_argument("--gen_video_dirs", type=str, nargs="+")
202 |
203 | args = parser.parse_args()
204 |
205 | main(args)
206 |
--------------------------------------------------------------------------------
/eval/teacache/common_metrics/calculate_lpips.py:
--------------------------------------------------------------------------------
1 | import lpips
2 | import numpy as np
3 | import torch
4 |
5 | spatial = True # Return a spatial map of perceptual distance.
6 |
7 | # Linearly calibrated models (LPIPS)
8 | loss_fn = lpips.LPIPS(net="alex", spatial=spatial) # Can also set net = 'squeeze' or 'vgg'
9 | # loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg'
10 |
11 |
12 | def trans(x):
13 | # if greyscale images add channel
14 | if x.shape[-3] == 1:
15 | x = x.repeat(1, 1, 3, 1, 1)
16 |
17 | # value range [0, 1] -> [-1, 1]
18 | x = x * 2 - 1
19 |
20 | return x
21 |
22 |
23 | def calculate_lpips(videos1, videos2, device):
24 | # image should be RGB, IMPORTANT: normalized to [-1,1]
25 |
26 | assert videos1.shape == videos2.shape
27 |
28 | # videos [batch_size, timestamps, channel, h, w]
29 |
30 | # support grayscale input, if grayscale -> channel*3
31 | # value range [0, 1] -> [-1, 1]
32 | videos1 = trans(videos1)
33 | videos2 = trans(videos2)
34 |
35 | lpips_results = []
36 |
37 | for video_num in range(videos1.shape[0]):
38 | # get a video
39 | # video [timestamps, channel, h, w]
40 | video1 = videos1[video_num]
41 | video2 = videos2[video_num]
42 |
43 | lpips_results_of_a_video = []
44 | for clip_timestamp in range(len(video1)):
45 | # get a img
46 | # img [timestamps[x], channel, h, w]
47 | # img [channel, h, w] tensor
48 |
49 | img1 = video1[clip_timestamp].unsqueeze(0).to(device)
50 | img2 = video2[clip_timestamp].unsqueeze(0).to(device)
51 |
52 | loss_fn.to(device)
53 |
54 | # calculate lpips of a video
55 | lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist())
56 | lpips_results.append(lpips_results_of_a_video)
57 |
58 | lpips_results = np.array(lpips_results)
59 |
60 | lpips = {}
61 | lpips_std = {}
62 |
63 | for clip_timestamp in range(len(video1)):
64 | lpips[clip_timestamp] = np.mean(lpips_results[:, clip_timestamp])
65 | lpips_std[clip_timestamp] = np.std(lpips_results[:, clip_timestamp])
66 |
67 | result = {
68 | "value": lpips,
69 | "value_std": lpips_std,
70 | "video_setting": video1.shape,
71 | "video_setting_name": "time, channel, heigth, width",
72 | }
73 |
74 | return result
75 |
76 |
77 | # test code / using example
78 |
79 |
80 | def main():
81 | NUMBER_OF_VIDEOS = 8
82 | VIDEO_LENGTH = 50
83 | CHANNEL = 3
84 | SIZE = 64
85 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
86 | videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
87 | device = torch.device("cuda")
88 | # device = torch.device("cpu")
89 |
90 | import json
91 |
92 | result = calculate_lpips(videos1, videos2, device)
93 | print(json.dumps(result, indent=4))
94 |
95 |
96 | if __name__ == "__main__":
97 | main()
98 |
--------------------------------------------------------------------------------
/eval/teacache/common_metrics/calculate_psnr.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import torch
5 |
6 |
7 | def img_psnr(img1, img2):
8 | # [0,1]
9 | # compute mse
10 | # mse = np.mean((img1-img2)**2)
11 | mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2)
12 | # compute psnr
13 | if mse < 1e-10:
14 | return 100
15 | psnr = 20 * math.log10(1 / math.sqrt(mse))
16 | return psnr
17 |
18 |
19 | def trans(x):
20 | return x
21 |
22 |
23 | def calculate_psnr(videos1, videos2):
24 | # videos [batch_size, timestamps, channel, h, w]
25 |
26 | assert videos1.shape == videos2.shape
27 |
28 | videos1 = trans(videos1)
29 | videos2 = trans(videos2)
30 |
31 | psnr_results = []
32 |
33 | for video_num in range(videos1.shape[0]):
34 | # get a video
35 | # video [timestamps, channel, h, w]
36 | video1 = videos1[video_num]
37 | video2 = videos2[video_num]
38 |
39 | psnr_results_of_a_video = []
40 | for clip_timestamp in range(len(video1)):
41 | # get a img
42 | # img [timestamps[x], channel, h, w]
43 | # img [channel, h, w] numpy
44 |
45 | img1 = video1[clip_timestamp].numpy()
46 | img2 = video2[clip_timestamp].numpy()
47 |
48 | # calculate psnr of a video
49 | psnr_results_of_a_video.append(img_psnr(img1, img2))
50 |
51 | psnr_results.append(psnr_results_of_a_video)
52 |
53 | psnr_results = np.array(psnr_results)
54 |
55 | psnr = {}
56 | psnr_std = {}
57 |
58 | for clip_timestamp in range(len(video1)):
59 | psnr[clip_timestamp] = np.mean(psnr_results[:, clip_timestamp])
60 | psnr_std[clip_timestamp] = np.std(psnr_results[:, clip_timestamp])
61 |
62 | result = {
63 | "value": psnr,
64 | "value_std": psnr_std,
65 | "video_setting": video1.shape,
66 | "video_setting_name": "time, channel, heigth, width",
67 | }
68 |
69 | return result
70 |
71 |
72 | # test code / using example
73 |
74 |
75 | def main():
76 | NUMBER_OF_VIDEOS = 8
77 | VIDEO_LENGTH = 50
78 | CHANNEL = 3
79 | SIZE = 64
80 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
81 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
82 |
83 | import json
84 |
85 | result = calculate_psnr(videos1, videos2)
86 | print(json.dumps(result, indent=4))
87 |
88 |
89 | if __name__ == "__main__":
90 | main()
91 |
--------------------------------------------------------------------------------
/eval/teacache/common_metrics/calculate_ssim.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 |
5 |
6 | def ssim(img1, img2):
7 | C1 = 0.01**2
8 | C2 = 0.03**2
9 | img1 = img1.astype(np.float64)
10 | img2 = img2.astype(np.float64)
11 | kernel = cv2.getGaussianKernel(11, 1.5)
12 | window = np.outer(kernel, kernel.transpose())
13 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
14 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
15 | mu1_sq = mu1**2
16 | mu2_sq = mu2**2
17 | mu1_mu2 = mu1 * mu2
18 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
19 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
20 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
21 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
22 | return ssim_map.mean()
23 |
24 |
25 | def calculate_ssim_function(img1, img2):
26 | # [0,1]
27 | # ssim is the only metric extremely sensitive to gray being compared to b/w
28 | if not img1.shape == img2.shape:
29 | raise ValueError("Input images must have the same dimensions.")
30 | if img1.ndim == 2:
31 | return ssim(img1, img2)
32 | elif img1.ndim == 3:
33 | if img1.shape[0] == 3:
34 | ssims = []
35 | for i in range(3):
36 | ssims.append(ssim(img1[i], img2[i]))
37 | return np.array(ssims).mean()
38 | elif img1.shape[0] == 1:
39 | return ssim(np.squeeze(img1), np.squeeze(img2))
40 | else:
41 | raise ValueError("Wrong input image dimensions.")
42 |
43 |
44 | def trans(x):
45 | return x
46 |
47 |
48 | def calculate_ssim(videos1, videos2):
49 | # videos [batch_size, timestamps, channel, h, w]
50 |
51 | assert videos1.shape == videos2.shape
52 |
53 | videos1 = trans(videos1)
54 | videos2 = trans(videos2)
55 |
56 | ssim_results = []
57 |
58 | for video_num in range(videos1.shape[0]):
59 | # get a video
60 | # video [timestamps, channel, h, w]
61 | video1 = videos1[video_num]
62 | video2 = videos2[video_num]
63 |
64 | ssim_results_of_a_video = []
65 | for clip_timestamp in range(len(video1)):
66 | # get a img
67 | # img [timestamps[x], channel, h, w]
68 | # img [channel, h, w] numpy
69 |
70 | img1 = video1[clip_timestamp].numpy()
71 | img2 = video2[clip_timestamp].numpy()
72 |
73 | # calculate ssim of a video
74 | ssim_results_of_a_video.append(calculate_ssim_function(img1, img2))
75 |
76 | ssim_results.append(ssim_results_of_a_video)
77 |
78 | ssim_results = np.array(ssim_results)
79 |
80 | ssim = {}
81 | ssim_std = {}
82 |
83 | for clip_timestamp in range(len(video1)):
84 | ssim[clip_timestamp] = np.mean(ssim_results[:, clip_timestamp])
85 | ssim_std[clip_timestamp] = np.std(ssim_results[:, clip_timestamp])
86 |
87 | result = {
88 | "value": ssim,
89 | "value_std": ssim_std,
90 | "video_setting": video1.shape,
91 | "video_setting_name": "time, channel, heigth, width",
92 | }
93 |
94 | return result
95 |
96 |
97 | # test code / using example
98 |
99 |
100 | def main():
101 | NUMBER_OF_VIDEOS = 8
102 | VIDEO_LENGTH = 50
103 | CHANNEL = 3
104 | SIZE = 64
105 | videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
106 | videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
107 | torch.device("cuda")
108 |
109 | import json
110 |
111 | result = calculate_ssim(videos1, videos2)
112 | print(json.dumps(result, indent=4))
113 |
114 |
115 | if __name__ == "__main__":
116 | main()
117 |
--------------------------------------------------------------------------------
/eval/teacache/common_metrics/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import imageio
5 | import torch
6 | import torchvision.transforms.functional as F
7 | import tqdm
8 | from calculate_lpips import calculate_lpips
9 | from calculate_psnr import calculate_psnr
10 | from calculate_ssim import calculate_ssim
11 |
12 |
13 | def load_videos(directory, video_ids, file_extension):
14 | videos = []
15 | for video_id in video_ids:
16 | video_path = os.path.join(directory, f"{video_id}.{file_extension}")
17 | if os.path.exists(video_path):
18 | video = load_video(video_path) # Define load_video based on how videos are stored
19 | videos.append(video)
20 | else:
21 | raise ValueError(f"Video {video_id}.{file_extension} not found in {directory}")
22 | return videos
23 |
24 |
25 | def load_video(video_path):
26 | """
27 | Load a video from the given path and convert it to a PyTorch tensor.
28 | """
29 | # Read the video using imageio
30 | reader = imageio.get_reader(video_path, "ffmpeg")
31 |
32 | # Extract frames and convert to a list of tensors
33 | frames = []
34 | for frame in reader:
35 | # Convert the frame to a tensor and permute the dimensions to match (C, H, W)
36 | frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1)
37 | frames.append(frame_tensor)
38 |
39 | # Stack the list of tensors into a single tensor with shape (T, C, H, W)
40 | video_tensor = torch.stack(frames)
41 |
42 | return video_tensor
43 |
44 |
45 | def resize_video(video, target_height, target_width):
46 | resized_frames = []
47 | for frame in video:
48 | resized_frame = F.resize(frame, [target_height, target_width])
49 | resized_frames.append(resized_frame)
50 | return torch.stack(resized_frames)
51 |
52 |
53 | def preprocess_eval_video(eval_video, generated_video_shape):
54 | T_gen, _, H_gen, W_gen = generated_video_shape
55 | T_eval, _, H_eval, W_eval = eval_video.shape
56 |
57 | if T_eval < T_gen:
58 | raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).")
59 |
60 | if H_eval < H_gen or W_eval < W_gen:
61 | # Resize the video maintaining the aspect ratio
62 | resize_height = max(H_gen, int(H_gen * (H_eval / W_eval)))
63 | resize_width = max(W_gen, int(W_gen * (W_eval / H_eval)))
64 | eval_video = resize_video(eval_video, resize_height, resize_width)
65 | # Recalculate the dimensions
66 | T_eval, _, H_eval, W_eval = eval_video.shape
67 |
68 | # Center crop
69 | start_h = (H_eval - H_gen) // 2
70 | start_w = (W_eval - W_gen) // 2
71 | cropped_video = eval_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]
72 |
73 | return cropped_video
74 |
75 |
76 | def main(args):
77 | device = "cuda"
78 | gt_video_dir = args.gt_video_dir
79 | generated_video_dir = args.generated_video_dir
80 |
81 | video_ids = []
82 | file_extension = "mp4"
83 | for f in os.listdir(generated_video_dir):
84 | if f.endswith(f".{file_extension}"):
85 | video_ids.append(f.replace(f".{file_extension}", ""))
86 | if not video_ids:
87 | raise ValueError("No videos found in the generated video dataset. Exiting.")
88 |
89 | print(f"Find {len(video_ids)} videos")
90 | prompt_interval = 1
91 | batch_size = 16
92 | calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True
93 |
94 | lpips_results = []
95 | psnr_results = []
96 | ssim_results = []
97 |
98 | total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0)
99 |
100 | for idx in tqdm.tqdm(range(total_len)):
101 | gt_videos_tensor = []
102 | generated_videos_tensor = []
103 | for i in range(batch_size):
104 | video_idx = idx * batch_size + i
105 | if video_idx >= len(video_ids):
106 | break
107 | video_id = video_ids[video_idx]
108 | generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.{file_extension}"))
109 | generated_videos_tensor.append(generated_video)
110 | eval_video = load_video(os.path.join(gt_video_dir, f"{video_id}.{file_extension}"))
111 | gt_videos_tensor.append(eval_video)
112 | gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu()
113 | generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu()
114 |
115 | if calculate_lpips_flag:
116 | result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device)
117 | result = result["value"].values()
118 | result = sum(result) / len(result)
119 | lpips_results.append(result)
120 |
121 | if calculate_psnr_flag:
122 | result = calculate_psnr(gt_videos_tensor, generated_videos_tensor)
123 | result = result["value"].values()
124 | result = sum(result) / len(result)
125 | psnr_results.append(result)
126 |
127 | if calculate_ssim_flag:
128 | result = calculate_ssim(gt_videos_tensor, generated_videos_tensor)
129 | result = result["value"].values()
130 | result = sum(result) / len(result)
131 | ssim_results.append(result)
132 |
133 | if (idx + 1) % prompt_interval == 0:
134 | out_str = ""
135 | for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]):
136 | result = sum(results) / len(results)
137 | out_str += f"{name}: {result:.4f}, "
138 | print(f"Processed {idx + 1} videos. {out_str[:-2]}")
139 |
140 | out_str = ""
141 | for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]):
142 | result = sum(results) / len(results)
143 | out_str += f"{name}: {result:.4f}, "
144 | out_str = out_str[:-2]
145 |
146 | # save
147 | with open(f"./{os.path.basename(generated_video_dir)}.txt", "w+") as f:
148 | f.write(out_str)
149 |
150 | print(f"Processed all videos. {out_str}")
151 |
152 |
153 | if __name__ == "__main__":
154 | parser = argparse.ArgumentParser()
155 | parser.add_argument("--gt_video_dir", type=str)
156 | parser.add_argument("--generated_video_dir", type=str)
157 |
158 | args = parser.parse_args()
159 |
160 | main(args)
161 |
--------------------------------------------------------------------------------
/eval/teacache/experiments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/eval/teacache/experiments/__init__.py
--------------------------------------------------------------------------------
/eval/teacache/experiments/cogvideox.py:
--------------------------------------------------------------------------------
1 | from utils import generate_func, read_prompt_list
2 | from videosys import CogVideoXConfig, VideoSysEngine
3 | import torch
4 | import torch.nn.functional as F
5 | from einops import rearrange, repeat
6 | import numpy as np
7 | from typing import Any, Dict, Optional, Tuple, Union
8 | from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
9 | from videosys.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput
10 | from videosys.utils.utils import batch_func
11 | from functools import partial
12 | from diffusers.utils import is_torch_version
13 |
14 | def teacache_forward(
15 | self,
16 | hidden_states: torch.Tensor,
17 | encoder_hidden_states: torch.Tensor,
18 | timestep: Union[int, float, torch.LongTensor],
19 | timestep_cond: Optional[torch.Tensor] = None,
20 | image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
21 | return_dict: bool = True,
22 | all_timesteps=None
23 | ):
24 | if self.parallel_manager.cp_size > 1:
25 | (
26 | hidden_states,
27 | encoder_hidden_states,
28 | timestep,
29 | timestep_cond,
30 | image_rotary_emb,
31 | ) = batch_func(
32 | partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
33 | hidden_states,
34 | encoder_hidden_states,
35 | timestep,
36 | timestep_cond,
37 | image_rotary_emb,
38 | )
39 |
40 | batch_size, num_frames, channels, height, width = hidden_states.shape
41 |
42 | # 1. Time embedding
43 | timesteps = timestep
44 | org_timestep = timestep
45 | t_emb = self.time_proj(timesteps)
46 |
47 | # timesteps does not contain any weights and will always return f32 tensors
48 | # but time_embedding might actually be running in fp16. so we need to cast here.
49 | # there might be better ways to encapsulate this.
50 | t_emb = t_emb.to(dtype=hidden_states.dtype)
51 | emb = self.time_embedding(t_emb, timestep_cond)
52 |
53 | # 2. Patch embedding
54 | hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
55 |
56 | # 3. Position embedding
57 | text_seq_length = encoder_hidden_states.shape[1]
58 | if not self.config.use_rotary_positional_embeddings:
59 | seq_length = height * width * num_frames // (self.config.patch_size**2)
60 |
61 | pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
62 | hidden_states = hidden_states + pos_embeds
63 | hidden_states = self.embedding_dropout(hidden_states)
64 |
65 | encoder_hidden_states = hidden_states[:, :text_seq_length]
66 | hidden_states = hidden_states[:, text_seq_length:]
67 |
68 | if self.enable_teacache:
69 | if org_timestep[0] == all_timesteps[0] or org_timestep[0] == all_timesteps[-1]:
70 | should_calc = True
71 | self.accumulated_rel_l1_distance = 0
72 | else:
73 | if not self.config.use_rotary_positional_embeddings:
74 | # CogVideoX-2B
75 | coefficients = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03]
76 | else:
77 | # CogVideoX-5B
78 | coefficients = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02]
79 | rescale_func = np.poly1d(coefficients)
80 | self.accumulated_rel_l1_distance += rescale_func(((emb-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
81 | if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
82 | should_calc = False
83 | else:
84 | should_calc = True
85 | self.accumulated_rel_l1_distance = 0
86 | self.previous_modulated_input = emb
87 |
88 | if self.enable_teacache:
89 | if not should_calc:
90 | hidden_states += self.previous_residual
91 | encoder_hidden_states += self.previous_residual_encoder
92 | else:
93 | if self.parallel_manager.sp_size > 1:
94 | set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
95 | hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
96 | ori_hidden_states = hidden_states.clone()
97 | ori_encoder_hidden_states = encoder_hidden_states.clone()
98 | # 4. Transformer blocks
99 | for i, block in enumerate(self.transformer_blocks):
100 | if self.training and self.gradient_checkpointing:
101 |
102 | def create_custom_forward(module):
103 | def custom_forward(*inputs):
104 | return module(*inputs)
105 |
106 | return custom_forward
107 |
108 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
109 | hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
110 | create_custom_forward(block),
111 | hidden_states,
112 | encoder_hidden_states,
113 | emb,
114 | image_rotary_emb,
115 | **ckpt_kwargs,
116 | )
117 | else:
118 | hidden_states, encoder_hidden_states = block(
119 | hidden_states=hidden_states,
120 | encoder_hidden_states=encoder_hidden_states,
121 | temb=emb,
122 | image_rotary_emb=image_rotary_emb,
123 |                         timestep=None,  # "timesteps if False else None" always evaluates to None
124 | )
125 | self.previous_residual = hidden_states - ori_hidden_states
126 | self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states
127 | else:
128 | if self.parallel_manager.sp_size > 1:
129 | set_pad("pad", hidden_states.shape[1], self.parallel_manager.sp_group)
130 | hidden_states = split_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
131 |
132 | # 4. Transformer blocks
133 | for i, block in enumerate(self.transformer_blocks):
134 | if self.training and self.gradient_checkpointing:
135 |
136 | def create_custom_forward(module):
137 | def custom_forward(*inputs):
138 | return module(*inputs)
139 |
140 | return custom_forward
141 |
142 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
143 | hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
144 | create_custom_forward(block),
145 | hidden_states,
146 | encoder_hidden_states,
147 | emb,
148 | image_rotary_emb,
149 | **ckpt_kwargs,
150 | )
151 | else:
152 | hidden_states, encoder_hidden_states = block(
153 | hidden_states=hidden_states,
154 | encoder_hidden_states=encoder_hidden_states,
155 | temb=emb,
156 | image_rotary_emb=image_rotary_emb,
157 |                         timestep=None,  # "timesteps if False else None" always evaluates to None
158 | )
159 |
160 | if self.parallel_manager.sp_size > 1:
161 | if self.enable_teacache:
162 | if should_calc:
163 | hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
164 | self.previous_residual = gather_sequence(self.previous_residual, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
165 | else:
166 | hidden_states = gather_sequence(hidden_states, self.parallel_manager.sp_group, dim=1, pad=get_pad("pad"))
167 |
168 | if not self.config.use_rotary_positional_embeddings:
169 | # CogVideoX-2B
170 | hidden_states = self.norm_final(hidden_states)
171 | else:
172 | # CogVideoX-5B
173 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
174 | hidden_states = self.norm_final(hidden_states)
175 | hidden_states = hidden_states[:, text_seq_length:]
176 |
177 | # 5. Final block
178 | hidden_states = self.norm_out(hidden_states, temb=emb)
179 | hidden_states = self.proj_out(hidden_states)
180 |
181 | # 6. Unpatchify
182 | p = self.config.patch_size
183 | output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
184 | output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
185 |
186 | if self.parallel_manager.cp_size > 1:
187 | output = gather_sequence(output, self.parallel_manager.cp_group, dim=0)
188 |
189 | if not return_dict:
190 | return (output,)
191 | return Transformer2DModelOutput(sample=output)
192 |
193 |
194 | def eval_teacache_slow(prompt_list):
195 | config = CogVideoXConfig()
196 | engine = VideoSysEngine(config)
197 | engine.driver_worker.transformer.__class__.enable_teacache = True
198 | engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.1
199 | engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
200 | engine.driver_worker.transformer.__class__.previous_modulated_input = None
201 | engine.driver_worker.transformer.__class__.previous_residual = None
202 | engine.driver_worker.transformer.__class__.previous_residual_encoder = None
203 | engine.driver_worker.transformer.__class__.forward = teacache_forward
204 | generate_func(engine, prompt_list, "./samples/cogvideox_teacache_slow", loop=5)
205 |
206 | def eval_teacache_fast(prompt_list):
207 | config = CogVideoXConfig()
208 | engine = VideoSysEngine(config)
209 | engine.driver_worker.transformer.__class__.enable_teacache = True
210 | engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.2
211 | engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
212 | engine.driver_worker.transformer.__class__.previous_modulated_input = None
213 | engine.driver_worker.transformer.__class__.previous_residual = None
214 | engine.driver_worker.transformer.__class__.previous_residual_encoder = None
215 | engine.driver_worker.transformer.__class__.forward = teacache_forward
216 | generate_func(engine, prompt_list, "./samples/cogvideox_teacache_fast", loop=5)
217 |
218 |
219 | def eval_base(prompt_list):
220 | config = CogVideoXConfig()
221 | engine = VideoSysEngine(config)
222 | generate_func(engine, prompt_list, "./samples/cogvideox_base", loop=5)
223 |
224 |
225 | if __name__ == "__main__":
226 | prompt_list = read_prompt_list("vbench/VBench_full_info.json")
227 | eval_base(prompt_list)
228 | eval_teacache_slow(prompt_list)
229 | eval_teacache_fast(prompt_list)
230 |
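The caching rule in teacache_forward above reduces to: always recompute at the first and last timesteps; otherwise accumulate a polynomial-rescaled relative L1 change of the timestep embedding and recompute only once it reaches rel_l1_thresh. A standalone sketch of that rule (not part of the repository), reusing the CogVideoX-2B coefficients above on toy tensors:

import numpy as np
import torch

# CogVideoX-2B rescale polynomial copied from teacache_forward above.
COEFFS = [-3.10658903e+01, 2.54732368e+01, -5.92380459e+00, 1.75769064e+00, -3.61568434e-03]
rescale = np.poly1d(COEFFS)

def should_recompute(emb, prev_emb, acc, thresh, is_first_or_last):
    """Return (should_calc, new_accumulated_distance)."""
    if is_first_or_last:
        return True, 0.0
    rel_l1 = ((emb - prev_emb).abs().mean() / prev_emb.abs().mean()).item()
    acc += float(rescale(rel_l1))
    if acc < thresh:
        return False, acc  # reuse the cached residual
    return True, 0.0       # recompute and reset the accumulator

prev = torch.ones(4, 512)                    # toy "previous" embedding
cur = prev + 0.01 * torch.randn(4, 512)      # toy "current" embedding
print(should_recompute(cur, prev, acc=0.0, thresh=0.1, is_first_or_last=False))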
--------------------------------------------------------------------------------
/eval/teacache/experiments/opensora.py:
--------------------------------------------------------------------------------
1 | from utils import generate_func, read_prompt_list
2 | from videosys import OpenSoraConfig, VideoSysEngine
3 | import torch
4 | from einops import rearrange
5 | from videosys.models.transformers.open_sora_transformer_3d import t2i_modulate, auto_grad_checkpoint
6 | from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
7 | import numpy as np
8 | from videosys.utils.utils import batch_func
9 | from functools import partial
10 |
11 | def teacache_forward(
12 | self, x, timestep, all_timesteps, y, mask=None, x_mask=None, fps=None, height=None, width=None, **kwargs
13 | ):
14 | # === Split batch ===
15 | if self.parallel_manager.cp_size > 1:
16 | x, timestep, y, x_mask, mask = batch_func(
17 | partial(split_sequence, process_group=self.parallel_manager.cp_group, dim=0),
18 | x,
19 | timestep,
20 | y,
21 | x_mask,
22 | mask,
23 | )
24 |
25 | dtype = self.x_embedder.proj.weight.dtype
26 | B = x.size(0)
27 | x = x.to(dtype)
28 | timestep = timestep.to(dtype)
29 | y = y.to(dtype)
30 |
31 | # === get pos embed ===
32 | _, _, Tx, Hx, Wx = x.size()
33 | T, H, W = self.get_dynamic_size(x)
34 | S = H * W
35 | base_size = round(S**0.5)
36 | resolution_sq = (height[0].item() * width[0].item()) ** 0.5
37 | scale = resolution_sq / self.input_sq_size
38 | pos_emb = self.pos_embed(x, H, W, scale=scale, base_size=base_size)
39 |
40 | # === get timestep embed ===
41 | t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
42 | fps = self.fps_embedder(fps.unsqueeze(1), B)
43 | t = t + fps
44 | t_mlp = self.t_block(t)
45 | t0 = t0_mlp = None
46 | if x_mask is not None:
47 | t0_timestep = torch.zeros_like(timestep)
48 | t0 = self.t_embedder(t0_timestep, dtype=x.dtype)
49 | t0 = t0 + fps
50 | t0_mlp = self.t_block(t0)
51 |
52 | # === get y embed ===
53 | if self.config.skip_y_embedder:
54 | y_lens = mask
55 | if isinstance(y_lens, torch.Tensor):
56 | y_lens = y_lens.long().tolist()
57 | else:
58 | y, y_lens = self.encode_text(y, mask)
59 |
60 | # === get x embed ===
61 | x = self.x_embedder(x) # [B, N, C]
62 | x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
63 | x = x + pos_emb
64 |
65 | if self.enable_teacache:
66 | inp = x.clone()
67 | inp = rearrange(inp, "B T S C -> B (T S) C", T=T, S=S)
68 | B, N, C = inp.shape
69 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
70 | self.spatial_blocks[0].scale_shift_table[None] + t_mlp.reshape(B, 6, -1)
71 | ).chunk(6, dim=1)
72 | modulated_inp = t2i_modulate(self.spatial_blocks[0].norm1(inp), shift_msa, scale_msa)
73 | if timestep[0] == all_timesteps[0] or timestep[0] == all_timesteps[-1]:
74 | should_calc = True
75 | self.accumulated_rel_l1_distance = 0
76 | else:
77 | coefficients = [2.17546007e+02, -1.18329252e+02, 2.68662585e+01, -4.59364272e-02, 4.84426240e-02]
78 | rescale_func = np.poly1d(coefficients)
79 | self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
80 | if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
81 | should_calc = False
82 | else:
83 | should_calc = True
84 | self.accumulated_rel_l1_distance = 0
85 | self.previous_modulated_input = modulated_inp
86 |
87 | # === blocks ===
88 | if self.enable_teacache:
89 | if not should_calc:
90 | x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
91 | x += self.previous_residual
92 | else:
93 | # shard over the sequence dim if sp is enabled
94 | if self.parallel_manager.sp_size > 1:
95 | set_pad("temporal", T, self.parallel_manager.sp_group)
96 | set_pad("spatial", S, self.parallel_manager.sp_group)
97 | x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
98 | T = x.shape[1]
99 | x_mask_org = x_mask
100 | x_mask = split_sequence(
101 | x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
102 | )
103 | x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
104 | origin_x = x.clone().detach()
105 | for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
106 | x = auto_grad_checkpoint(
107 | spatial_block,
108 | x,
109 | y,
110 | t_mlp,
111 | y_lens,
112 | x_mask,
113 | t0_mlp,
114 | T,
115 | S,
116 | timestep,
117 | all_timesteps=all_timesteps,
118 | )
119 |
120 | x = auto_grad_checkpoint(
121 | temporal_block,
122 | x,
123 | y,
124 | t_mlp,
125 | y_lens,
126 | x_mask,
127 | t0_mlp,
128 | T,
129 | S,
130 | timestep,
131 | all_timesteps=all_timesteps,
132 | )
133 | self.previous_residual = x - origin_x
134 | else:
135 | # shard over the sequence dim if sp is enabled
136 | if self.parallel_manager.sp_size > 1:
137 | set_pad("temporal", T, self.parallel_manager.sp_group)
138 | set_pad("spatial", S, self.parallel_manager.sp_group)
139 | x = split_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
140 | T = x.shape[1]
141 | x_mask_org = x_mask
142 | x_mask = split_sequence(
143 | x_mask, self.parallel_manager.sp_group, dim=1, grad_scale="down", pad=get_pad("temporal")
144 | )
145 | x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
146 |
147 | for spatial_block, temporal_block in zip(self.spatial_blocks, self.temporal_blocks):
148 | x = auto_grad_checkpoint(
149 | spatial_block,
150 | x,
151 | y,
152 | t_mlp,
153 | y_lens,
154 | x_mask,
155 | t0_mlp,
156 | T,
157 | S,
158 | timestep,
159 | all_timesteps=all_timesteps,
160 | )
161 |
162 | x = auto_grad_checkpoint(
163 | temporal_block,
164 | x,
165 | y,
166 | t_mlp,
167 | y_lens,
168 | x_mask,
169 | t0_mlp,
170 | T,
171 | S,
172 | timestep,
173 | all_timesteps=all_timesteps,
174 | )
175 |
176 | if self.parallel_manager.sp_size > 1:
177 | if self.enable_teacache:
178 | if should_calc:
179 | x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
180 | self.previous_residual = rearrange(self.previous_residual, "B (T S) C -> B T S C", T=T, S=S)
181 | x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
182 | self.previous_residual = gather_sequence(self.previous_residual, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
183 | T, S = x.shape[1], x.shape[2]
184 | x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
185 | self.previous_residual = rearrange(self.previous_residual, "B T S C -> B (T S) C", T=T, S=S)
186 | x_mask = x_mask_org
187 | else:
188 | x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
189 | x = gather_sequence(x, self.parallel_manager.sp_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
190 | T, S = x.shape[1], x.shape[2]
191 | x = rearrange(x, "B T S C -> B (T S) C", T=T, S=S)
192 | x_mask = x_mask_org
193 | # === final layer ===
194 | x = self.final_layer(x, t, x_mask, t0, T, S)
195 | x = self.unpatchify(x, T, H, W, Tx, Hx, Wx)
196 |
197 | # cast to float32 for better accuracy
198 | x = x.to(torch.float32)
199 |
200 | # === Gather Output ===
201 | if self.parallel_manager.cp_size > 1:
202 | x = gather_sequence(x, self.parallel_manager.cp_group, dim=0)
203 |
204 | return x
205 |
206 | def eval_base(prompt_list):
207 | config = OpenSoraConfig()
208 | engine = VideoSysEngine(config)
209 | generate_func(engine, prompt_list, "./samples/opensora_base", loop=5)
210 |
211 | def eval_teacache_slow(prompt_list):
212 | config = OpenSoraConfig()
213 | engine = VideoSysEngine(config)
214 | engine.driver_worker.transformer.__class__.enable_teacache = True
215 | engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.1
216 | engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
217 | engine.driver_worker.transformer.__class__.previous_modulated_input = None
218 | engine.driver_worker.transformer.__class__.previous_residual = None
219 | engine.driver_worker.transformer.__class__.forward = teacache_forward
220 | generate_func(engine, prompt_list, "./samples/opensora_teacache_slow", loop=5)
221 |
222 | def eval_teacache_fast(prompt_list):
223 | config = OpenSoraConfig()
224 | engine = VideoSysEngine(config)
225 | engine.driver_worker.transformer.__class__.enable_teacache = True
226 | engine.driver_worker.transformer.__class__.rel_l1_thresh = 0.2
227 | engine.driver_worker.transformer.__class__.accumulated_rel_l1_distance = 0
228 | engine.driver_worker.transformer.__class__.previous_modulated_input = None
229 | engine.driver_worker.transformer.__class__.previous_residual = None
230 | engine.driver_worker.transformer.__class__.forward = teacache_forward
231 | generate_func(engine, prompt_list, "./samples/opensora_teacache_fast", loop=5)
232 |
233 |
234 | if __name__ == "__main__":
235 | prompt_list = read_prompt_list("vbench/VBench_full_info.json")
236 | eval_base(prompt_list)
237 | eval_teacache_slow(prompt_list)
238 | eval_teacache_fast(prompt_list)
239 |
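TeaCache is switched on in eval_teacache_slow / eval_teacache_fast by attaching state and a replacement forward directly to the transformer class. A minimal sketch of that pattern on a stand-in class (DummyTransformer and its trivial forward are illustrative only):

# Class-level patching, as done on engine.driver_worker.transformer.__class__ above.
class DummyTransformer:
    def forward(self, x):
        return x

def patched_forward(self, x):
    # stand-in for teacache_forward; the real one caches and reuses block residuals
    return x

def enable_teacache(transformer_cls, rel_l1_thresh=0.1):
    transformer_cls.enable_teacache = True
    transformer_cls.rel_l1_thresh = rel_l1_thresh
    transformer_cls.accumulated_rel_l1_distance = 0
    transformer_cls.previous_modulated_input = None
    transformer_cls.previous_residual = None
    transformer_cls.forward = patched_forward

enable_teacache(DummyTransformer, rel_l1_thresh=0.2)
assert DummyTransformer().enable_teacache is True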
--------------------------------------------------------------------------------
/eval/teacache/experiments/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import tqdm
5 |
6 | from videosys.utils.utils import set_seed
7 |
8 |
9 | def generate_func(pipeline, prompt_list, output_dir, loop: int = 5, kwargs: dict = {}):
10 | kwargs["verbose"] = False
11 | for prompt in tqdm.tqdm(prompt_list):
12 | for l in range(loop):
13 | video = pipeline.generate(prompt, seed=l, **kwargs).video[0]
14 | pipeline.save_video(video, os.path.join(output_dir, f"{prompt}-{l}.mp4"))
15 |
16 |
17 | def read_prompt_list(prompt_list_path):
18 | with open(prompt_list_path, "r") as f:
19 | prompt_list = json.load(f)
20 | prompt_list = [prompt["prompt_en"] for prompt in prompt_list]
21 | return prompt_list
22 |
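A small usage sketch for these helpers (not part of the repository). It assumes it is run from eval/teacache/experiments with the VBench prompt file in place; the engine construction follows the experiment scripts above, so the generation call is left commented.

from utils import generate_func, read_prompt_list

prompts = read_prompt_list("vbench/VBench_full_info.json")
print(f"loaded {len(prompts)} prompts")
# With an engine built as in cogvideox.py / opensora.py:
# generate_func(engine, prompts, "./samples/demo", loop=5)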
--------------------------------------------------------------------------------
/eval/teacache/vbench/cal_vbench.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | SEMANTIC_WEIGHT = 1
6 | QUALITY_WEIGHT = 4
7 |
8 | QUALITY_LIST = [
9 | "subject consistency",
10 | "background consistency",
11 | "temporal flickering",
12 | "motion smoothness",
13 | "aesthetic quality",
14 | "imaging quality",
15 | "dynamic degree",
16 | ]
17 |
18 | SEMANTIC_LIST = [
19 | "object class",
20 | "multiple objects",
21 | "human action",
22 | "color",
23 | "spatial relationship",
24 | "scene",
25 | "appearance style",
26 | "temporal style",
27 | "overall consistency",
28 | ]
29 |
30 | NORMALIZE_DIC = {
31 | "subject consistency": {"Min": 0.1462, "Max": 1.0},
32 | "background consistency": {"Min": 0.2615, "Max": 1.0},
33 | "temporal flickering": {"Min": 0.6293, "Max": 1.0},
34 | "motion smoothness": {"Min": 0.706, "Max": 0.9975},
35 | "dynamic degree": {"Min": 0.0, "Max": 1.0},
36 | "aesthetic quality": {"Min": 0.0, "Max": 1.0},
37 | "imaging quality": {"Min": 0.0, "Max": 1.0},
38 | "object class": {"Min": 0.0, "Max": 1.0},
39 | "multiple objects": {"Min": 0.0, "Max": 1.0},
40 | "human action": {"Min": 0.0, "Max": 1.0},
41 | "color": {"Min": 0.0, "Max": 1.0},
42 | "spatial relationship": {"Min": 0.0, "Max": 1.0},
43 | "scene": {"Min": 0.0, "Max": 0.8222},
44 | "appearance style": {"Min": 0.0009, "Max": 0.2855},
45 | "temporal style": {"Min": 0.0, "Max": 0.364},
46 | "overall consistency": {"Min": 0.0, "Max": 0.364},
47 | }
48 |
49 | DIM_WEIGHT = {
50 | "subject consistency": 1,
51 | "background consistency": 1,
52 | "temporal flickering": 1,
53 | "motion smoothness": 1,
54 | "aesthetic quality": 1,
55 | "imaging quality": 1,
56 | "dynamic degree": 0.5,
57 | "object class": 1,
58 | "multiple objects": 1,
59 | "human action": 1,
60 | "color": 1,
61 | "spatial relationship": 1,
62 | "scene": 1,
63 | "appearance style": 1,
64 | "temporal style": 1,
65 | "overall consistency": 1,
66 | }
67 |
68 | ordered_scaled_res = [
69 | "total score",
70 | "quality score",
71 | "semantic score",
72 | "subject consistency",
73 | "background consistency",
74 | "temporal flickering",
75 | "motion smoothness",
76 | "dynamic degree",
77 | "aesthetic quality",
78 | "imaging quality",
79 | "object class",
80 | "multiple objects",
81 | "human action",
82 | "color",
83 | "spatial relationship",
84 | "scene",
85 | "appearance style",
86 | "temporal style",
87 | "overall consistency",
88 | ]
89 |
90 |
91 | def parse_args():
92 | parser = argparse.ArgumentParser()
93 | parser.add_argument("--score_dir", required=True, type=str)
94 | args = parser.parse_args()
95 | return args
96 |
97 |
98 | if __name__ == "__main__":
99 | args = parse_args()
100 | res_postfix = "_eval_results.json"
101 | info_postfix = "_full_info.json"
102 | files = os.listdir(args.score_dir)
103 | res_files = [x for x in files if res_postfix in x]
104 | info_files = [x for x in files if info_postfix in x]
105 | assert len(res_files) == len(info_files), f"got {len(res_files)} res files, but {len(info_files)} info files"
106 |
107 | full_results = {}
108 | for res_file in res_files:
109 | # first check if results is normal
110 | info_file = res_file.split(res_postfix)[0] + info_postfix
111 | with open(os.path.join(args.score_dir, info_file), "r", encoding="utf-8") as f:
112 | info = json.load(f)
113 | assert len(info[0]["video_list"]) > 0, f"Error: {info_file} has 0 video list"
114 | # read results
115 | with open(os.path.join(args.score_dir, res_file), "r", encoding="utf-8") as f:
116 | data = json.load(f)
117 | for key, val in data.items():
118 | full_results[key] = format(val[0], ".4f")
119 |
120 | scaled_results = {}
121 | dims = set()
122 | for key, val in full_results.items():
123 | dim = key.replace("_", " ") if "_" in key else key
124 | scaled_score = (float(val) - NORMALIZE_DIC[dim]["Min"]) / (
125 | NORMALIZE_DIC[dim]["Max"] - NORMALIZE_DIC[dim]["Min"]
126 | )
127 | scaled_score *= DIM_WEIGHT[dim]
128 | scaled_results[dim] = scaled_score
129 | dims.add(dim)
130 |
131 | assert len(dims) == len(NORMALIZE_DIC), f"{set(NORMALIZE_DIC.keys())-dims} not calculated yet"
132 |
133 | quality_score = sum([scaled_results[i] for i in QUALITY_LIST]) / sum([DIM_WEIGHT[i] for i in QUALITY_LIST])
134 | semantic_score = sum([scaled_results[i] for i in SEMANTIC_LIST]) / sum([DIM_WEIGHT[i] for i in SEMANTIC_LIST])
135 | scaled_results["quality score"] = quality_score
136 | scaled_results["semantic score"] = semantic_score
137 | scaled_results["total score"] = (quality_score * QUALITY_WEIGHT + semantic_score * SEMANTIC_WEIGHT) / (
138 | QUALITY_WEIGHT + SEMANTIC_WEIGHT
139 | )
140 |
141 |     formatted_scaled_results = {"items": []}
142 |     for key in ordered_scaled_res:
143 |         formatted_score = format(scaled_results[key] * 100, ".2f") + "%"
144 |         formatted_scaled_results["items"].append({key: formatted_score})
145 |
146 | output_file_path = os.path.join(args.score_dir, "all_results.json")
147 | with open(output_file_path, "w") as outfile:
148 | json.dump(full_results, outfile, indent=4, sort_keys=True)
149 | print(f"results saved to: {output_file_path}")
150 |
151 | scaled_file_path = os.path.join(args.score_dir, "scaled_results.json")
152 | with open(scaled_file_path, "w") as outfile:
153 |         json.dump(formatted_scaled_results, outfile, indent=4, sort_keys=True)
154 | print(f"results saved to: {scaled_file_path}")
155 |
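The scoring above min-max normalizes each raw dimension with NORMALIZE_DIC, applies DIM_WEIGHT, and mixes the quality and semantic means 4:1. A compact arithmetic sketch with made-up raw scores for two dimensions:

# Sketch of the scaling above; the raw scores are hypothetical.
NORM = {"subject consistency": (0.1462, 1.0), "scene": (0.0, 0.8222)}
WEIGHT = {"subject consistency": 1, "scene": 1}

raw = {"subject consistency": 0.95, "scene": 0.50}
scaled = {
    dim: (val - NORM[dim][0]) / (NORM[dim][1] - NORM[dim][0]) * WEIGHT[dim]
    for dim, val in raw.items()
}
quality = scaled["subject consistency"] / WEIGHT["subject consistency"]
semantic = scaled["scene"] / WEIGHT["scene"]
total = (quality * 4 + semantic * 1) / (4 + 1)   # QUALITY_WEIGHT : SEMANTIC_WEIGHT = 4 : 1
print({k: round(v, 4) for k, v in scaled.items()}, round(total, 4))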
--------------------------------------------------------------------------------
/eval/teacache/vbench/run_vbench.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | from vbench import VBench
5 |
6 | full_info_path = "./vbench/VBench_full_info.json"
7 |
8 | dimensions = [
9 | "subject_consistency",
10 | "imaging_quality",
11 | "background_consistency",
12 | "motion_smoothness",
13 | "overall_consistency",
14 | "human_action",
15 | "multiple_objects",
16 | "spatial_relationship",
17 | "object_class",
18 | "color",
19 | "aesthetic_quality",
20 | "appearance_style",
21 | "temporal_flickering",
22 | "scene",
23 | "temporal_style",
24 | "dynamic_degree",
25 | ]
26 |
27 |
28 | def parse_args():
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument("--video_path", required=True, type=str)
31 | parser.add_argument("--save_path", required=True, type=str)
32 | args = parser.parse_args()
33 | return args
34 |
35 |
36 | if __name__ == "__main__":
37 | args = parse_args()
38 |
39 | kwargs = {}
40 | kwargs["imaging_quality_preprocessing_mode"] = "longer" # use VBench/evaluate.py default
41 |
42 | for dimension in dimensions:
43 | my_VBench = VBench(torch.device("cuda"), full_info_path, args.save_path)
44 | my_VBench.evaluate(
45 | videos_path=args.video_path,
46 | name=dimension,
47 | local=False,
48 | read_frame=False,
49 | dimension_list=[dimension],
50 | mode="vbench_standard",
51 | **kwargs,
52 | )
53 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate>0.17.0
2 | bs4
3 | click
4 | colossalai==0.4.0
5 | diffusers==0.30.0
6 | einops
7 | fabric
8 | ftfy
9 | imageio
10 | imageio-ffmpeg
11 | matplotlib
12 | ninja
13 | numpy<2.0.0
14 | omegaconf
15 | packaging
16 | psutil
17 | pydantic
18 | ray
19 | rich
20 | safetensors
21 | sentencepiece
22 | timm
23 | torch>=1.13
24 | tqdm
25 | peft==0.13.2
26 | transformers==4.39.3
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from setuptools import find_packages, setup
4 | from setuptools.command.develop import develop
5 | from setuptools.command.egg_info import egg_info
6 | from setuptools.command.install import install
7 |
8 |
9 | def fetch_requirements(path) -> List[str]:
10 | """
11 | This function reads the requirements file.
12 |
13 | Args:
14 | path (str): the path to the requirements file.
15 |
16 | Returns:
17 | The lines in the requirements file.
18 | """
19 | with open(path, "r") as fd:
20 | requirements = [r.strip() for r in fd.readlines()]
21 | # requirements.remove("colossalai")
22 | return requirements
23 |
24 |
25 | def fetch_readme() -> str:
26 | """
27 | This function reads the README.md file in the current directory.
28 |
29 | Returns:
30 | The lines in the README file.
31 | """
32 | with open("README.md", encoding="utf-8") as f:
33 | return f.read()
34 |
35 |
36 | def custom_install():
37 | return ["pip", "install", "colossalai", "--no-deps"]
38 |
39 |
40 | class CustomInstallCommand(install):
41 | def run(self):
42 | install.run(self)
43 | self.spawn(custom_install())
44 |
45 |
46 | class CustomDevelopCommand(develop):
47 | def run(self):
48 | develop.run(self)
49 | self.spawn(custom_install())
50 |
51 |
52 | class CustomEggInfoCommand(egg_info):
53 | def run(self):
54 | egg_info.run(self)
55 | self.spawn(custom_install())
56 |
57 |
58 | setup(
59 | name="videosys",
60 | version="2.0.0",
61 | packages=find_packages(
62 | exclude=(
63 | "videos",
64 | "tests",
65 | "figure",
66 | "*.egg-info",
67 | )
68 | ),
69 | description="TeaCache",
70 | long_description=fetch_readme(),
71 | long_description_content_type="text/markdown",
72 | license="Apache Software License 2.0",
73 | install_requires=fetch_requirements("requirements.txt"),
74 | python_requires=">=3.7",
75 | # cmdclass={
76 | # "install": CustomInstallCommand,
77 | # "develop": CustomDevelopCommand,
78 | # "egg_info": CustomEggInfoCommand,
79 | # },
80 | classifiers=[
81 | "Programming Language :: Python :: 3",
82 | "License :: OSI Approved :: Apache Software License",
83 | "Environment :: GPU :: NVIDIA CUDA",
84 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
85 | "Topic :: System :: Distributed Computing",
86 | ],
87 | )
88 |
--------------------------------------------------------------------------------
/videosys/__init__.py:
--------------------------------------------------------------------------------
1 | from .core.engine import VideoSysEngine
2 | from .core.parallel_mgr import initialize
3 | from .pipelines.cogvideox import CogVideoXConfig, CogVideoXPABConfig, CogVideoXPipeline
4 | from .pipelines.latte import LatteConfig, LattePABConfig, LattePipeline
5 | from .pipelines.open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline
6 | from .pipelines.open_sora_plan import (
7 | OpenSoraPlanConfig,
8 | OpenSoraPlanPipeline,
9 | OpenSoraPlanV110PABConfig,
10 | OpenSoraPlanV120PABConfig,
11 | )
12 | from .pipelines.vchitect import VchitectConfig, VchitectPABConfig, VchitectXLPipeline
13 |
14 | __all__ = [
15 | "initialize",
16 | "VideoSysEngine",
17 | "LattePipeline", "LatteConfig", "LattePABConfig",
18 | "OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanV110PABConfig", "OpenSoraPlanV120PABConfig",
19 | "OpenSoraPipeline", "OpenSoraConfig", "OpenSoraPABConfig",
20 | "CogVideoXPipeline", "CogVideoXConfig", "CogVideoXPABConfig",
21 | "VchitectXLPipeline", "VchitectConfig", "VchitectPABConfig"
22 | ] # fmt: skip
23 |
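A minimal end-to-end sketch using these exports, mirroring the experiment scripts earlier in this dump; a GPU and the corresponding model weights are assumed, and the prompt and output path are illustrative.

from videosys import OpenSoraConfig, VideoSysEngine

config = OpenSoraConfig()                  # default single-GPU config
engine = VideoSysEngine(config)
video = engine.generate("a breathtaking sunrise over snowy mountains", seed=0).video[0]
engine.save_video(video, "./outputs/sunrise.mp4")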
--------------------------------------------------------------------------------
/videosys/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/core/__init__.py
--------------------------------------------------------------------------------
/videosys/core/engine.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import partial
3 | from typing import Any, Optional
4 |
5 | import torch
6 | import torch.distributed as dist
7 |
8 | import videosys
9 |
10 | from .mp_utils import ProcessWorkerWrapper, ResultHandler, WorkerMonitor, get_distributed_init_method, get_open_port
11 |
12 |
13 | class VideoSysEngine:
14 | """
15 |     This engine is partly inspired by vLLM.
16 | """
17 |
18 | def __init__(self, config):
19 | self.config = config
20 | self.parallel_worker_tasks = None
21 | self._init_worker(config.pipeline_cls)
22 |
23 | def _init_worker(self, pipeline_cls):
24 | world_size = self.config.num_gpus
25 |
26 | # Disable torch async compiling which won't work with daemonic processes
27 | os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
28 |
29 | # Set OMP_NUM_THREADS to 1 if it is not set explicitly, avoids CPU
30 | # contention amongst the shards
31 | if "OMP_NUM_THREADS" not in os.environ:
32 | os.environ["OMP_NUM_THREADS"] = "1"
33 |
34 | # NOTE: The two following lines need adaption for multi-node
35 | assert world_size <= torch.cuda.device_count()
36 |
37 | # change addr for multi-node
38 | distributed_init_method = get_distributed_init_method("127.0.0.1", get_open_port())
39 |
40 | if world_size == 1:
41 | self.workers = []
42 | self.worker_monitor = None
43 | else:
44 | result_handler = ResultHandler()
45 | self.workers = [
46 | ProcessWorkerWrapper(
47 | result_handler,
48 | partial(
49 | self._create_pipeline,
50 | pipeline_cls=pipeline_cls,
51 | rank=rank,
52 | local_rank=rank,
53 | distributed_init_method=distributed_init_method,
54 | ),
55 | )
56 | for rank in range(1, world_size)
57 | ]
58 |
59 | self.worker_monitor = WorkerMonitor(self.workers, result_handler)
60 | result_handler.start()
61 | self.worker_monitor.start()
62 |
63 | self.driver_worker = self._create_pipeline(
64 | pipeline_cls=pipeline_cls, distributed_init_method=distributed_init_method
65 | )
66 |
67 | # TODO: add more options here for pipeline, or wrap all options into config
68 | def _create_pipeline(self, pipeline_cls, rank=0, local_rank=0, distributed_init_method=None):
69 | videosys.initialize(rank=rank, world_size=self.config.num_gpus, init_method=distributed_init_method)
70 |
71 | pipeline = pipeline_cls(self.config)
72 | return pipeline
73 |
74 | def _run_workers(
75 | self,
76 | method: str,
77 | *args,
78 | async_run_tensor_parallel_workers_only: bool = False,
79 | max_concurrent_workers: Optional[int] = None,
80 | **kwargs,
81 | ) -> Any:
82 | """Runs the given method on all workers."""
83 |
84 | # Start the workers first.
85 | worker_outputs = [worker.execute_method(method, *args, **kwargs) for worker in self.workers]
86 |
87 | if async_run_tensor_parallel_workers_only:
88 | # Just return futures
89 | return worker_outputs
90 |
91 | driver_worker_method = getattr(self.driver_worker, method)
92 | driver_worker_output = driver_worker_method(*args, **kwargs)
93 |
94 | # Get the results of the workers.
95 | return [driver_worker_output] + [output.get() for output in worker_outputs]
96 |
97 | def _driver_execute_model(self, *args, **kwargs):
98 | return self.driver_worker.generate(*args, **kwargs)
99 |
100 | def generate(self, *args, **kwargs):
101 | return self._run_workers("generate", *args, **kwargs)[0]
102 |
103 | def stop_remote_worker_execution_loop(self) -> None:
104 | if self.parallel_worker_tasks is None:
105 | return
106 |
107 | parallel_worker_tasks = self.parallel_worker_tasks
108 | self.parallel_worker_tasks = None
109 | # Ensure that workers exit model loop cleanly
110 | # (this will raise otherwise)
111 | self._wait_for_tasks_completion(parallel_worker_tasks)
112 |
113 | def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
114 |         """Wait for futures returned from _run_workers() with
115 |         async_run_tensor_parallel_workers_only to complete."""
116 | for result in parallel_worker_tasks:
117 | result.get()
118 |
119 | def save_video(self, video, output_path):
120 | return self.driver_worker.save_video(video, output_path)
121 |
122 | def shutdown(self):
123 | if (worker_monitor := getattr(self, "worker_monitor", None)) is not None:
124 | worker_monitor.close()
125 | dist.destroy_process_group()
126 |
127 | def __del__(self):
128 | self.shutdown()
129 |
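A toy sketch of the _run_workers contract above: the method is enqueued on every background worker, executed locally on the driver, and the combined list puts the driver's output first, which is why generate() returns index 0. The classes below are stand-ins, not VideoSys types.

class _FakeFuture:
    def __init__(self, value):
        self._value = value
    def get(self):
        return self._value

class _FakeWorker:
    def execute_method(self, method, *args, **kwargs):
        return _FakeFuture(f"worker output for {method}")

class _FakeDriver:
    def generate(self, prompt):
        return f"driver output for {prompt!r}"

workers, driver = [_FakeWorker(), _FakeWorker()], _FakeDriver()
futures = [w.execute_method("generate", "a cat") for w in workers]   # async on workers
outputs = [driver.generate("a cat")] + [f.get() for f in futures]    # driver runs inline
print(outputs[0])   # the engine returns only the driver's result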
--------------------------------------------------------------------------------
/videosys/core/mp_utils.py:
--------------------------------------------------------------------------------
1 | # adapted from vllm
2 | # https://github.com/vllm-project/vllm/blob/main/vllm/executor/multiproc_worker_utils.py
3 |
4 | import asyncio
5 | import multiprocessing
6 | import os
7 | import socket
8 | import sys
9 | import threading
10 | import traceback
11 | import uuid
12 | from dataclasses import dataclass
13 | from multiprocessing import Queue
14 | from multiprocessing.connection import wait
15 | from typing import Any, Callable, Dict, Generic, List, Optional, TextIO, TypeVar, Union
16 |
17 | from videosys.utils.logging import create_logger
18 |
19 | T = TypeVar("T")
20 | _TERMINATE = "TERMINATE" # sentinel
21 | # ANSI color codes
22 | CYAN = "\033[1;36m"
23 | RESET = "\033[0;0m"
24 | JOIN_TIMEOUT_S = 2
25 |
26 | mp_method = "spawn"  # fork can't work
27 | mp = multiprocessing.get_context(mp_method)
28 |
29 | logger = create_logger()
30 |
31 |
32 | def get_distributed_init_method(ip: str, port: int) -> str:
33 | # Brackets are not permitted in ipv4 addresses,
34 | # see https://github.com/python/cpython/issues/103848
35 | return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
36 |
37 |
38 | def get_open_port() -> int:
39 | # try ipv4
40 | try:
41 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
42 | s.bind(("", 0))
43 | return s.getsockname()[1]
44 | except OSError:
45 | # try ipv6
46 | with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
47 | s.bind(("", 0))
48 | return s.getsockname()[1]
49 |
50 |
51 | @dataclass
52 | class Result(Generic[T]):
53 | """Result of task dispatched to worker"""
54 |
55 | task_id: uuid.UUID
56 | value: Optional[T] = None
57 | exception: Optional[BaseException] = None
58 |
59 |
60 | class ResultFuture(threading.Event, Generic[T]):
61 | """Synchronous future for non-async case"""
62 |
63 | def __init__(self):
64 | super().__init__()
65 | self.result: Optional[Result[T]] = None
66 |
67 | def set_result(self, result: Result[T]):
68 | self.result = result
69 | self.set()
70 |
71 | def get(self) -> T:
72 | self.wait()
73 | assert self.result is not None
74 | if self.result.exception is not None:
75 | raise self.result.exception
76 | return self.result.value # type: ignore[return-value]
77 |
78 |
79 | def _set_future_result(future: Union[ResultFuture, asyncio.Future], result: Result):
80 | if isinstance(future, ResultFuture):
81 | future.set_result(result)
82 | return
83 | loop = future.get_loop()
84 | if not loop.is_closed():
85 | if result.exception is not None:
86 | loop.call_soon_threadsafe(future.set_exception, result.exception)
87 | else:
88 | loop.call_soon_threadsafe(future.set_result, result.value)
89 |
90 |
91 | class ResultHandler(threading.Thread):
92 | """Handle results from all workers (in background thread)"""
93 |
94 | def __init__(self) -> None:
95 | super().__init__(daemon=True)
96 | self.result_queue = mp.Queue()
97 | self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
98 |
99 | def run(self):
100 | for result in iter(self.result_queue.get, _TERMINATE):
101 | future = self.tasks.pop(result.task_id)
102 | _set_future_result(future, result)
103 | # Ensure that all waiters will receive an exception
104 | for task_id, future in self.tasks.items():
105 | _set_future_result(future, Result(task_id=task_id, exception=ChildProcessError("worker died")))
106 |
107 | def close(self):
108 | self.result_queue.put(_TERMINATE)
109 |
110 |
111 | class WorkerMonitor(threading.Thread):
112 | """Monitor worker status (in background thread)"""
113 |
114 | def __init__(self, workers: List["ProcessWorkerWrapper"], result_handler: ResultHandler):
115 | super().__init__(daemon=True)
116 | self.workers = workers
117 | self.result_handler = result_handler
118 | self._close = False
119 |
120 | def run(self) -> None:
121 | # Blocks until any worker exits
122 | dead_sentinels = wait([w.process.sentinel for w in self.workers])
123 | if not self._close:
124 | self._close = True
125 |
126 | # Kill / cleanup all workers
127 | for worker in self.workers:
128 | process = worker.process
129 | if process.sentinel in dead_sentinels:
130 | process.join(JOIN_TIMEOUT_S)
131 | if process.exitcode is not None and process.exitcode != 0:
132 | logger.error("Worker %s pid %s died, exit code: %s", process.name, process.pid, process.exitcode)
133 | # Cleanup any remaining workers
134 | logger.info("Killing local worker processes")
135 | for worker in self.workers:
136 | worker.kill_worker()
137 | # Must be done after worker task queues are all closed
138 | self.result_handler.close()
139 |
140 | for worker in self.workers:
141 | worker.process.join(JOIN_TIMEOUT_S)
142 |
143 | def close(self):
144 | if self._close:
145 | return
146 | self._close = True
147 | logger.info("Terminating local worker processes")
148 | for worker in self.workers:
149 | worker.terminate_worker()
150 | # Must be done after worker task queues are all closed
151 | self.result_handler.close()
152 |
153 |
154 | def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
155 | """Prepend each output line with process-specific prefix"""
156 |
157 | prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
158 | file_write = file.write
159 |
160 | def write_with_prefix(s: str):
161 | if not s:
162 | return
163 | if file.start_new_line: # type: ignore[attr-defined]
164 | file_write(prefix)
165 | idx = 0
166 | while (next_idx := s.find("\n", idx)) != -1:
167 | next_idx += 1
168 | file_write(s[idx:next_idx])
169 | if next_idx == len(s):
170 | file.start_new_line = True # type: ignore[attr-defined]
171 | return
172 | file_write(prefix)
173 | idx = next_idx
174 | file_write(s[idx:])
175 | file.start_new_line = False # type: ignore[attr-defined]
176 |
177 | file.start_new_line = True # type: ignore[attr-defined]
178 | file.write = write_with_prefix # type: ignore[method-assign]
179 |
180 |
181 | def _run_worker_process(
182 | worker_factory: Callable[[], Any],
183 | task_queue: Queue,
184 | result_queue: Queue,
185 | ) -> None:
186 | """Worker process event loop"""
187 |
188 | # Add process-specific prefix to stdout and stderr
189 | process_name = mp.current_process().name
190 | pid = os.getpid()
191 | _add_prefix(sys.stdout, process_name, pid)
192 | _add_prefix(sys.stderr, process_name, pid)
193 |
194 | # Initialize worker
195 | worker = worker_factory()
196 | del worker_factory
197 |
198 | # Accept tasks from the engine in task_queue
199 | # and return task output in result_queue
200 | logger.info("Worker ready; awaiting tasks")
201 | try:
202 | for items in iter(task_queue.get, _TERMINATE):
203 | output = None
204 | exception = None
205 | task_id, method, args, kwargs = items
206 | try:
207 | executor = getattr(worker, method)
208 | output = executor(*args, **kwargs)
209 | except BaseException as e:
210 | tb = traceback.format_exc()
211 | logger.error("Exception in worker %s while processing method %s: %s, %s", process_name, method, e, tb)
212 | exception = e
213 | result_queue.put(Result(task_id=task_id, value=output, exception=exception))
214 | except KeyboardInterrupt:
215 | pass
216 | except Exception:
217 | logger.exception("Worker failed")
218 |
219 | logger.info("Worker exiting")
220 |
221 |
222 | class ProcessWorkerWrapper:
223 | """Local process wrapper for handling single-node multi-GPU."""
224 |
225 | def __init__(self, result_handler: ResultHandler, worker_factory: Callable[[], Any]) -> None:
226 | self._task_queue = mp.Queue()
227 | self.result_queue = result_handler.result_queue
228 | self.tasks = result_handler.tasks
229 | self.process = mp.Process( # type: ignore[attr-defined]
230 | target=_run_worker_process,
231 | name="VideoSysWorkerProcess",
232 | kwargs=dict(
233 | worker_factory=worker_factory,
234 | task_queue=self._task_queue,
235 | result_queue=self.result_queue,
236 | ),
237 | daemon=True,
238 | )
239 |
240 | self.process.start()
241 |
242 | def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], method: str, args, kwargs):
243 | task_id = uuid.uuid4()
244 | self.tasks[task_id] = future
245 | try:
246 | self._task_queue.put((task_id, method, args, kwargs))
247 | except BaseException as e:
248 | del self.tasks[task_id]
249 | raise ChildProcessError("worker died") from e
250 |
251 | def execute_method(self, method: str, *args, **kwargs):
252 | future: ResultFuture = ResultFuture()
253 | self._enqueue_task(future, method, args, kwargs)
254 | return future
255 |
256 | async def execute_method_async(self, method: str, *args, **kwargs):
257 | future = asyncio.get_running_loop().create_future()
258 | self._enqueue_task(future, method, args, kwargs)
259 | return await future
260 |
261 | def terminate_worker(self):
262 | try:
263 | self._task_queue.put(_TERMINATE)
264 | except ValueError:
265 | self.process.kill()
266 | self._task_queue.close()
267 |
268 | def kill_worker(self):
269 | self._task_queue.close()
270 | self.process.kill()
271 |
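A minimal sketch of how VideoSysEngine wires these pieces together: one ResultHandler thread collects results, each ProcessWorkerWrapper owns a spawned process, and execute_method returns a future. EchoWorker is a stand-in; the factory must be picklable because the spawn start method is used.

from videosys.core.mp_utils import ProcessWorkerWrapper, ResultHandler, WorkerMonitor

class EchoWorker:
    def ping(self, msg):
        return f"pong: {msg}"

if __name__ == "__main__":
    handler = ResultHandler()
    worker = ProcessWorkerWrapper(handler, EchoWorker)   # the class itself is the factory
    monitor = WorkerMonitor([worker], handler)
    handler.start()
    monitor.start()
    future = worker.execute_method("ping", "hello")
    print(future.get())                                  # -> pong: hello
    monitor.close()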
--------------------------------------------------------------------------------
/videosys/core/pab_mgr.py:
--------------------------------------------------------------------------------
1 | from videosys.utils.logging import logger
2 |
3 | PAB_MANAGER = None
4 |
5 |
6 | class PABConfig:
7 | def __init__(
8 | self,
9 | cross_broadcast: bool = False,
10 | cross_threshold: list = None,
11 | cross_range: int = None,
12 | spatial_broadcast: bool = False,
13 | spatial_threshold: list = None,
14 | spatial_range: int = None,
15 | temporal_broadcast: bool = False,
16 | temporal_threshold: list = None,
17 | temporal_range: int = None,
18 | mlp_broadcast: bool = False,
19 | mlp_spatial_broadcast_config: dict = None,
20 | mlp_temporal_broadcast_config: dict = None,
21 | ):
22 | self.steps = None
23 |
24 | self.cross_broadcast = cross_broadcast
25 | self.cross_threshold = cross_threshold
26 | self.cross_range = cross_range
27 |
28 | self.spatial_broadcast = spatial_broadcast
29 | self.spatial_threshold = spatial_threshold
30 | self.spatial_range = spatial_range
31 |
32 | self.temporal_broadcast = temporal_broadcast
33 | self.temporal_threshold = temporal_threshold
34 | self.temporal_range = temporal_range
35 |
36 | self.mlp_broadcast = mlp_broadcast
37 | self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
38 | self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config
39 | self.mlp_temporal_outputs = {}
40 | self.mlp_spatial_outputs = {}
41 |
42 |
43 | class PABManager:
44 | def __init__(self, config: PABConfig):
45 | self.config: PABConfig = config
46 |
47 |         init_prompt = "Init Pyramid Attention Broadcast."
48 | init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}."
49 | init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}."
50 | init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}."
51 | init_prompt += f" mlp broadcast: {config.mlp_broadcast}."
52 | logger.info(init_prompt)
53 |
54 | def if_broadcast_cross(self, timestep: int, count: int):
55 | if (
56 | self.config.cross_broadcast
57 | and (timestep is not None)
58 | and (count % self.config.cross_range != 0)
59 | and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
60 | ):
61 | flag = True
62 | else:
63 | flag = False
64 | count = (count + 1) % self.config.steps
65 | return flag, count
66 |
67 | def if_broadcast_temporal(self, timestep: int, count: int):
68 | if (
69 | self.config.temporal_broadcast
70 | and (timestep is not None)
71 | and (count % self.config.temporal_range != 0)
72 | and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
73 | ):
74 | flag = True
75 | else:
76 | flag = False
77 | count = (count + 1) % self.config.steps
78 | return flag, count
79 |
80 | def if_broadcast_spatial(self, timestep: int, count: int):
81 | if (
82 | self.config.spatial_broadcast
83 | and (timestep is not None)
84 | and (count % self.config.spatial_range != 0)
85 | and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
86 | ):
87 | flag = True
88 | else:
89 | flag = False
90 | count = (count + 1) % self.config.steps
91 | return flag, count
92 |
93 | @staticmethod
94 | def _is_t_in_skip_config(all_timesteps, timestep, config):
95 | is_t_in_skip_config = False
96 | skip_range = None
97 | for key in config:
98 | if key not in all_timesteps:
99 | continue
100 | index = all_timesteps.index(key)
101 | skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])]
102 | if timestep in skip_range:
103 | is_t_in_skip_config = True
104 | skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]]
105 | break
106 | return is_t_in_skip_config, skip_range
107 |
108 | def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
109 | if not self.config.mlp_broadcast:
110 | return False, None, False, None
111 |
112 | if is_temporal:
113 | cur_config = self.config.mlp_temporal_broadcast_config
114 | else:
115 | cur_config = self.config.mlp_spatial_broadcast_config
116 |
117 | is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
118 | next_flag = False
119 | if (
120 | self.config.mlp_broadcast
121 | and (timestep is not None)
122 | and (timestep in cur_config)
123 | and (block_idx in cur_config[timestep]["block"])
124 | ):
125 | flag = False
126 | next_flag = True
127 | count = count + 1
128 | elif (
129 | self.config.mlp_broadcast
130 | and (timestep is not None)
131 | and (is_t_in_skip_config)
132 | and (block_idx in cur_config[skip_range[0]]["block"])
133 | ):
134 | flag = True
135 | count = 0
136 | else:
137 | flag = False
138 |
139 | return flag, count, next_flag, skip_range
140 |
141 | def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
142 | if is_temporal:
143 | self.config.mlp_temporal_outputs[(timestep, block_idx)] = ff_output
144 | else:
145 | self.config.mlp_spatial_outputs[(timestep, block_idx)] = ff_output
146 |
147 | def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
148 | skip_start_t = skip_range[0]
149 | if is_temporal:
150 | skip_output = (
151 | self.config.mlp_temporal_outputs.get((skip_start_t, block_idx), None)
152 | if self.config.mlp_temporal_outputs is not None
153 | else None
154 | )
155 | else:
156 | skip_output = (
157 | self.config.mlp_spatial_outputs.get((skip_start_t, block_idx), None)
158 | if self.config.mlp_spatial_outputs is not None
159 | else None
160 | )
161 |
162 | if skip_output is not None:
163 | if timestep == skip_range[-1]:
164 | # TODO: save memory
165 | if is_temporal:
166 | del self.config.mlp_temporal_outputs[(skip_start_t, block_idx)]
167 | else:
168 | del self.config.mlp_spatial_outputs[(skip_start_t, block_idx)]
169 | else:
170 | raise ValueError(
171 | f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
172 | )
173 |
174 | return skip_output
175 |
176 | def get_spatial_mlp_outputs(self):
177 | return self.config.mlp_spatial_outputs
178 |
179 | def get_temporal_mlp_outputs(self):
180 | return self.config.mlp_temporal_outputs
181 |
182 |
183 | def set_pab_manager(config: PABConfig):
184 | global PAB_MANAGER
185 | PAB_MANAGER = PABManager(config)
186 |
187 |
188 | def enable_pab():
189 | if PAB_MANAGER is None:
190 | return False
191 | return (
192 | PAB_MANAGER.config.cross_broadcast
193 | or PAB_MANAGER.config.spatial_broadcast
194 | or PAB_MANAGER.config.temporal_broadcast
195 | )
196 |
197 |
198 | def update_steps(steps: int):
199 | if PAB_MANAGER is not None:
200 | PAB_MANAGER.config.steps = steps
201 |
202 |
203 | def if_broadcast_cross(timestep: int, count: int):
204 | if not enable_pab():
205 | return False, count
206 | return PAB_MANAGER.if_broadcast_cross(timestep, count)
207 |
208 |
209 | def if_broadcast_temporal(timestep: int, count: int):
210 | if not enable_pab():
211 | return False, count
212 | return PAB_MANAGER.if_broadcast_temporal(timestep, count)
213 |
214 |
215 | def if_broadcast_spatial(timestep: int, count: int):
216 | if not enable_pab():
217 | return False, count
218 | return PAB_MANAGER.if_broadcast_spatial(timestep, count)
219 |
220 |
221 | def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
222 | if not enable_pab():
223 | return False, count
224 | return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal)
225 |
226 |
227 | def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False):
228 | return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal)
229 |
230 |
231 | def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
232 | return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)
233 |
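A small sketch of driving these helpers: build a PABConfig, register it with set_pab_manager, set the step count, then query if_broadcast_spatial per call. The thresholds and timesteps below are illustrative, not tuned values.

from videosys.core.pab_mgr import PABConfig, set_pab_manager, update_steps, if_broadcast_spatial

config = PABConfig(
    spatial_broadcast=True,
    spatial_threshold=[100, 800],   # broadcast only for 100 < timestep < 800
    spatial_range=2,                # recompute on every 2nd call
)
set_pab_manager(config)
update_steps(30)                    # total number of denoising steps

count = 0
for timestep in (900, 500, 450, 50):
    flag, count = if_broadcast_spatial(timestep, count)
    print(timestep, "broadcast" if flag else "compute")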
--------------------------------------------------------------------------------
/videosys/core/parallel_mgr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from colossalai.cluster.process_group_mesh import ProcessGroupMesh
4 | from torch.distributed import ProcessGroup
5 |
6 | from videosys.utils.logging import init_dist_logger, logger
7 |
8 |
9 | class ParallelManager(ProcessGroupMesh):
10 | def __init__(self, dp_size, cp_size, sp_size):
11 | super().__init__(dp_size, cp_size, sp_size)
12 | dp_axis, cp_axis, sp_axis = 0, 1, 2
13 |
14 | self.dp_size = dp_size
15 | self.dp_group: ProcessGroup = self.get_group_along_axis(dp_axis)
16 | self.dp_rank = dist.get_rank(self.dp_group)
17 |
18 | self.cp_size = cp_size
19 | if cp_size > 1:
20 | self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis)
21 | self.cp_rank = dist.get_rank(self.cp_group)
22 | else:
23 | self.cp_group = None
24 | self.cp_rank = None
25 |
26 | self.sp_size = sp_size
27 | if sp_size > 1:
28 | self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis)
29 | self.sp_rank = dist.get_rank(self.sp_group)
30 | else:
31 | self.sp_group = None
32 | self.sp_rank = None
33 |
34 | logger.info(f"Init parallel manager with dp_size: {dp_size}, cp_size: {cp_size}, sp_size: {sp_size}")
35 |
36 |
37 | def initialize(
38 | rank=0,
39 | world_size=1,
40 | init_method=None,
41 | ):
42 | if not dist.is_initialized():
43 | try:
44 | dist.destroy_process_group()
45 | except Exception:
46 | pass
47 | dist.init_process_group(backend="nccl", init_method=init_method, world_size=world_size, rank=rank)
48 | torch.cuda.set_device(rank)
49 | init_dist_logger()
50 | torch.backends.cuda.matmul.allow_tf32 = True
51 | torch.backends.cudnn.allow_tf32 = True
52 |
--------------------------------------------------------------------------------
/videosys/core/pipeline.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from abc import abstractmethod
3 | from dataclasses import dataclass
4 |
5 | import torch
6 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
7 | from diffusers.utils import BaseOutput
8 |
9 |
10 | class VideoSysPipeline(DiffusionPipeline):
11 | def __init__(self):
12 | super().__init__()
13 |
14 | @staticmethod
15 | def set_eval_and_device(device: torch.device, *modules):
16 | modules = list(modules)
17 | for i in range(len(modules)):
18 | modules[i] = modules[i].eval()
19 | modules[i] = modules[i].to(device)
20 |
21 | @abstractmethod
22 | def generate(self, *args, **kwargs):
23 | pass
24 |
25 | def __call__(self, *args, **kwargs):
26 | """
27 | In diffusers, it is a convention to call the pipeline object.
28 |         But in VideoSys, we use the generate method for a clearer interface.
29 |         This __call__ wrapper forwards to generate so diffusers-style usage still works.
30 | """
31 | return self.generate(*args, **kwargs)
32 |
33 | @classmethod
34 | def _get_signature_keys(cls, obj):
35 | parameters = inspect.signature(obj.__init__).parameters
36 | required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
37 | optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
38 | expected_modules = set(required_parameters.keys()) - {"self"}
39 | # modify: remove the config module from the expected modules
40 | expected_modules = expected_modules - {"config"}
41 |
42 | optional_names = list(optional_parameters)
43 | for name in optional_names:
44 | if name in cls._optional_components:
45 | expected_modules.add(name)
46 | optional_parameters.remove(name)
47 |
48 | return expected_modules, optional_parameters
49 |
50 |
51 | @dataclass
52 | class VideoSysPipelineOutput(BaseOutput):
53 | video: torch.Tensor
54 |
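Concrete pipelines subclass VideoSysPipeline and implement generate; calling the pipeline object then routes through __call__. A toy sketch (ToyPipeline and its placeholder output are illustrative only):

import torch
from videosys.core.pipeline import VideoSysPipeline, VideoSysPipelineOutput

class ToyPipeline(VideoSysPipeline):
    def generate(self, prompt: str, num_frames: int = 4):
        video = torch.zeros(num_frames, 3, 8, 8)   # placeholder frames
        return VideoSysPipelineOutput(video=video)

pipe = ToyPipeline()
out = pipe("a toy prompt")   # __call__ forwards to generate()
print(out.video.shape)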
--------------------------------------------------------------------------------
/videosys/core/shardformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/core/shardformer/__init__.py
--------------------------------------------------------------------------------
/videosys/core/shardformer/t5/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/core/shardformer/t5/__init__.py
--------------------------------------------------------------------------------
/videosys/core/shardformer/t5/modeling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class T5LayerNorm(nn.Module):
6 | def __init__(self, hidden_size, eps=1e-6):
7 | """
8 | Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
9 | """
10 | super().__init__()
11 | self.weight = nn.Parameter(torch.ones(hidden_size))
12 | self.variance_epsilon = eps
13 |
14 | def forward(self, hidden_states):
15 |         # T5 uses a layer_norm that only scales and does not shift, also known as Root Mean
16 |         # Square Layer Normalization (https://arxiv.org/abs/1910.07467): the variance is computed
17 |         # without subtracting the mean and there is no bias. Additionally, we make sure that the
18 |         # accumulation for half-precision inputs is done in fp32.
19 |
20 | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
21 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
22 |
23 | # convert into half-precision if necessary
24 | if self.weight.dtype in [torch.float16, torch.bfloat16]:
25 | hidden_states = hidden_states.to(self.weight.dtype)
26 |
27 | return self.weight * hidden_states
28 |
29 | @staticmethod
30 | def from_native_module(module, *args, **kwargs):
31 | assert module.__class__.__name__ == "FusedRMSNorm", (
32 | "Recovering T5LayerNorm requires the original layer to be apex's Fused RMS Norm."
33 | "Apex's fused norm is automatically used by Hugging Face Transformers https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L265C5-L265C48"
34 | )
35 |
36 | layer_norm = T5LayerNorm(module.normalized_shape, eps=module.eps)
37 | layer_norm.weight.data.copy_(module.weight.data)
38 | layer_norm = layer_norm.to(module.weight.device)
39 | return layer_norm
40 |
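A quick numerical sketch (illustrative, not part of the file) of what this RMS-style norm computes: scale by the reciprocal root-mean-square over the last dimension, with no mean subtraction and no bias.

    import torch
    from videosys.core.shardformer.t5.modeling import T5LayerNorm

    norm = T5LayerNorm(hidden_size=8, eps=1e-6)
    x = torch.randn(2, 4, 8)

    # reference: x / sqrt(mean(x^2) + eps), then an elementwise learned scale
    rms_inv = x.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
    expected = norm.weight * (x * rms_inv)
    assert torch.allclose(norm(x), expected, atol=1e-6)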
--------------------------------------------------------------------------------
/videosys/core/shardformer/t5/policy.py:
--------------------------------------------------------------------------------
1 | from colossalai.shardformer.modeling.jit import get_jit_fused_dropout_add_func
2 | from colossalai.shardformer.modeling.t5 import get_jit_fused_T5_layer_ff_forward, get_T5_layer_self_attention_forward
3 | from colossalai.shardformer.policies.base_policy import Policy, SubModuleReplacementDescription
4 |
5 |
6 | class T5EncoderPolicy(Policy):
7 | def config_sanity_check(self):
8 | assert not self.shard_config.enable_tensor_parallelism
9 | assert not self.shard_config.enable_flash_attention
10 |
11 | def preprocess(self):
12 | return self.model
13 |
14 | def module_policy(self):
15 | from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerSelfAttention, T5Stack
16 |
17 | policy = {}
18 |
19 | # check whether apex is installed
20 | try:
21 | from apex.normalization import FusedRMSNorm # noqa
22 | from videosys.core.shardformer.t5.modeling import T5LayerNorm
23 |
24 |             # convert the HF model's apex FusedRMSNorm layers back to T5LayerNorm, which is faster
25 | self.append_or_create_submodule_replacement(
26 | description=SubModuleReplacementDescription(
27 | suffix="layer_norm",
28 | target_module=T5LayerNorm,
29 | ),
30 | policy=policy,
31 | target_key=T5LayerFF,
32 | )
33 | self.append_or_create_submodule_replacement(
34 | description=SubModuleReplacementDescription(suffix="layer_norm", target_module=T5LayerNorm),
35 | policy=policy,
36 | target_key=T5LayerSelfAttention,
37 | )
38 | self.append_or_create_submodule_replacement(
39 | description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=T5LayerNorm),
40 | policy=policy,
41 | target_key=T5Stack,
42 | )
43 | except (ImportError, ModuleNotFoundError):
44 | pass
45 |
46 | # use jit operator
47 | if self.shard_config.enable_jit_fused:
48 | self.append_or_create_method_replacement(
49 | description={
50 | "forward": get_jit_fused_T5_layer_ff_forward(),
51 | "dropout_add": get_jit_fused_dropout_add_func(),
52 | },
53 | policy=policy,
54 | target_key=T5LayerFF,
55 | )
56 | self.append_or_create_method_replacement(
57 | description={
58 | "forward": get_T5_layer_self_attention_forward(),
59 | "dropout_add": get_jit_fused_dropout_add_func(),
60 | },
61 | policy=policy,
62 | target_key=T5LayerSelfAttention,
63 | )
64 |
65 | return policy
66 |
67 | def postprocess(self):
68 | return self.model
69 |
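For context, a rough sketch of where such a policy is consumed. This assumes ColossalAI's ShardFormer entry point (ShardConfig / ShardFormer.optimize) and an example T5 checkpoint; both are assumptions to verify against the installed ColossalAI and transformers versions, not the project's actual wiring.

    # sketch only: the ShardFormer API names below are assumptions about ColossalAI's interface
    from colossalai.shardformer import ShardConfig, ShardFormer
    from transformers import T5EncoderModel

    from videosys.core.shardformer.t5.policy import T5EncoderPolicy

    t5 = T5EncoderModel.from_pretrained("google/t5-v1_1-small")  # example checkpoint
    shard_config = ShardConfig(enable_tensor_parallelism=False, enable_jit_fused=True)
    shard_former = ShardFormer(shard_config=shard_config)
    # applying the policy swaps apex FusedRMSNorm back to T5LayerNorm (when apex is installed)
    # and installs the jit-fused forward / dropout_add functions
    t5, _ = shard_former.optimize(t5, policy=T5EncoderPolicy())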
--------------------------------------------------------------------------------
/videosys/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/models/__init__.py
--------------------------------------------------------------------------------
/videosys/models/autoencoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/models/autoencoders/__init__.py
--------------------------------------------------------------------------------
/videosys/models/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/models/modules/__init__.py
--------------------------------------------------------------------------------
/videosys/models/modules/activations.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | approx_gelu = lambda: nn.GELU(approximate="tanh")
4 |
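The factory returns a fresh nn.GELU(approximate="tanh") on every call, so each block gets its own activation module. A one-line check (illustrative):

    import torch
    from videosys.models.modules.activations import approx_gelu

    act = approx_gelu()          # new tanh-approximated GELU instance
    y = act(torch.randn(2, 16))  # output shape matches the input shape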
--------------------------------------------------------------------------------
/videosys/models/modules/downsampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class CogVideoXDownsample3D(nn.Module):
7 |     # Todo: Wait for paper release.
8 |     r"""
9 |     A 3D downsampling layer used in [CogVideoX]() by Tsinghua University & ZhipuAI.
10 |
11 | Args:
12 | in_channels (`int`):
13 | Number of channels in the input image.
14 | out_channels (`int`):
15 | Number of channels produced by the convolution.
16 | kernel_size (`int`, defaults to `3`):
17 | Size of the convolving kernel.
18 | stride (`int`, defaults to `2`):
19 | Stride of the convolution.
20 | padding (`int`, defaults to `0`):
21 | Padding added to all four sides of the input.
22 | compress_time (`bool`, defaults to `False`):
23 | Whether or not to compress the time dimension.
24 | """
25 |
26 | def __init__(
27 | self,
28 | in_channels: int,
29 | out_channels: int,
30 | kernel_size: int = 3,
31 | stride: int = 2,
32 | padding: int = 0,
33 | compress_time: bool = False,
34 | ):
35 | super().__init__()
36 |
37 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
38 | self.compress_time = compress_time
39 |
40 | def forward(self, x: torch.Tensor) -> torch.Tensor:
41 | if self.compress_time:
42 | batch_size, channels, frames, height, width = x.shape
43 |
44 | # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
45 | x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
46 |
47 | if x.shape[-1] % 2 == 1:
48 | x_first, x_rest = x[..., 0], x[..., 1:]
49 | if x_rest.shape[-1] > 0:
50 | # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
51 | x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
52 |
53 | x = torch.cat([x_first[..., None], x_rest], dim=-1)
54 | # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
55 | x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
56 | else:
57 | # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
58 | x = F.avg_pool1d(x, kernel_size=2, stride=2)
59 | # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
60 | x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
61 |
62 | # Pad the tensor
63 | pad = (0, 1, 0, 1)
64 | x = F.pad(x, pad, mode="constant", value=0)
65 | batch_size, channels, frames, height, width = x.shape
66 | # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
67 | x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
68 | x = self.conv(x)
69 | # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
70 | x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
71 | return x
72 |
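A shape-only sketch (illustrative) of the odd-frame path: with compress_time=True the first frame is kept, the remaining frames are average-pooled in pairs, and the 3x3 stride-2 conv halves height and width.

    import torch
    from videosys.models.modules.downsampling import CogVideoXDownsample3D

    down = CogVideoXDownsample3D(in_channels=16, out_channels=32, compress_time=True)
    x = torch.randn(1, 16, 9, 32, 32)  # (B, C, T, H, W) with odd T
    y = down(x)
    print(y.shape)                     # torch.Size([1, 32, 5, 16, 16]): T -> 1 + (T - 1) // 2, H and W halved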
--------------------------------------------------------------------------------
/videosys/models/modules/normalization.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class LlamaRMSNorm(nn.Module):
9 | def __init__(self, hidden_size, eps=1e-6):
10 | """
11 | LlamaRMSNorm is equivalent to T5LayerNorm
12 | """
13 | super().__init__()
14 | self.weight = nn.Parameter(torch.ones(hidden_size))
15 | self.variance_epsilon = eps
16 |
17 | def forward(self, hidden_states):
18 | input_dtype = hidden_states.dtype
19 | hidden_states = hidden_states.to(torch.float32)
20 | variance = hidden_states.pow(2).mean(-1, keepdim=True)
21 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
22 | return self.weight * hidden_states.to(input_dtype)
23 |
24 |
25 | class CogVideoXLayerNormZero(nn.Module):
26 | def __init__(
27 | self,
28 | conditioning_dim: int,
29 | embedding_dim: int,
30 | elementwise_affine: bool = True,
31 | eps: float = 1e-5,
32 | bias: bool = True,
33 | ) -> None:
34 | super().__init__()
35 |
36 | self.silu = nn.SiLU()
37 | self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
38 | self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
39 |
40 | def forward(
41 | self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
42 | ) -> Tuple[torch.Tensor, torch.Tensor]:
43 | shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
44 | hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
45 | encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
46 | return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
47 |
48 |
49 | class AdaLayerNorm(nn.Module):
50 | r"""
51 | Norm layer modified to incorporate timestep embeddings.
52 |
53 | Parameters:
54 | embedding_dim (`int`): The size of each embedding vector.
55 | num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
56 | output_dim (`int`, *optional*):
57 |         norm_elementwise_affine (`bool`, defaults to `False`):
58 |         norm_eps (`float`, defaults to `1e-5`):
59 | chunk_dim (`int`, defaults to `0`):
60 | """
61 |
62 | def __init__(
63 | self,
64 | embedding_dim: int,
65 | num_embeddings: Optional[int] = None,
66 | output_dim: Optional[int] = None,
67 | norm_elementwise_affine: bool = False,
68 | norm_eps: float = 1e-5,
69 | chunk_dim: int = 0,
70 | ):
71 | super().__init__()
72 |
73 | self.chunk_dim = chunk_dim
74 | output_dim = output_dim or embedding_dim * 2
75 |
76 | if num_embeddings is not None:
77 | self.emb = nn.Embedding(num_embeddings, embedding_dim)
78 | else:
79 | self.emb = None
80 |
81 | self.silu = nn.SiLU()
82 | self.linear = nn.Linear(embedding_dim, output_dim)
83 | self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
84 |
85 | def forward(
86 | self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
87 | ) -> torch.Tensor:
88 | if self.emb is not None:
89 | temb = self.emb(timestep)
90 |
91 | temb = self.linear(self.silu(temb))
92 |
93 | if self.chunk_dim == 1:
94 |             # Note: the order is "shift, scale" here but "scale, shift" in the other branch.
95 |             # This branch is specific to CogVideoX for now.
96 | shift, scale = temb.chunk(2, dim=1)
97 | shift = shift[:, None, :]
98 | scale = scale[:, None, :]
99 | else:
100 | scale, shift = temb.chunk(2, dim=0)
101 |
102 | x = self.norm(x) * (1 + scale) + shift
103 | return x
104 |
105 |
106 | class VchitectSpatialNorm(nn.Module):
107 | """
108 | Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
109 |
110 | Args:
111 | f_channels (`int`):
112 | The number of channels for input to group normalization layer, and output of the spatial norm layer.
113 | zq_channels (`int`):
114 | The number of channels for the quantized vector as described in the paper.
115 | """
116 |
117 | def __init__(
118 | self,
119 | f_channels: int,
120 | zq_channels: int,
121 | ):
122 | super().__init__()
123 | self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
124 | self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
125 | self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
126 |
127 | def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
128 | f_size = f.shape[-2:]
129 | zq = F.interpolate(zq, size=f_size, mode="nearest")
130 | norm_f = self.norm_layer(f)
131 | new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
132 | return new_f
133 |
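Two small sketches (illustrative) for the norms above: LlamaRMSNorm written out by hand, and the shapes CogVideoXLayerNormZero expects for video tokens, text tokens, and the timestep embedding.

    import torch
    from videosys.models.modules.normalization import CogVideoXLayerNormZero, LlamaRMSNorm

    # LlamaRMSNorm: scale by 1/rms over the last dim, no mean subtraction, no bias
    x = torch.randn(2, 5, 8)
    norm = LlamaRMSNorm(hidden_size=8)
    manual = norm.weight * (x * x.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt())
    assert torch.allclose(norm(x), manual, atol=1e-6)

    # CogVideoXLayerNormZero: one linear on the timestep embedding yields shift/scale/gate
    # for both the video tokens and the text tokens
    ln0 = CogVideoXLayerNormZero(conditioning_dim=32, embedding_dim=8)
    hidden = torch.randn(2, 10, 8)       # video tokens
    encoder = torch.randn(2, 7, 8)       # text tokens
    temb = torch.randn(2, 32)            # timestep embedding
    h, e, gate, enc_gate = ln0(hidden, encoder, temb)
    print(h.shape, e.shape, gate.shape)  # (2, 10, 8) (2, 7, 8) (2, 1, 8)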
--------------------------------------------------------------------------------
/videosys/models/modules/upsampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class CogVideoXUpsample3D(nn.Module):
7 | r"""
8 |     A 3D upsampling layer used in CogVideoX by Tsinghua University & ZhipuAI. # Todo: Wait for paper release.
9 |
10 | Args:
11 | in_channels (`int`):
12 | Number of channels in the input image.
13 | out_channels (`int`):
14 | Number of channels produced by the convolution.
15 | kernel_size (`int`, defaults to `3`):
16 | Size of the convolving kernel.
17 | stride (`int`, defaults to `1`):
18 | Stride of the convolution.
19 | padding (`int`, defaults to `1`):
20 | Padding added to all four sides of the input.
21 | compress_time (`bool`, defaults to `False`):
22 | Whether or not to compress the time dimension.
23 | """
24 |
25 | def __init__(
26 | self,
27 | in_channels: int,
28 | out_channels: int,
29 | kernel_size: int = 3,
30 | stride: int = 1,
31 | padding: int = 1,
32 | compress_time: bool = False,
33 | ) -> None:
34 | super().__init__()
35 |
36 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
37 | self.compress_time = compress_time
38 |
39 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
40 | if self.compress_time:
41 | if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
42 | # split first frame
43 | x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
44 |
45 | x_first = F.interpolate(x_first, scale_factor=2.0)
46 | x_rest = F.interpolate(x_rest, scale_factor=2.0)
47 | x_first = x_first[:, :, None, :, :]
48 | inputs = torch.cat([x_first, x_rest], dim=2)
49 | elif inputs.shape[2] > 1:
50 | inputs = F.interpolate(inputs, scale_factor=2.0)
51 | else:
52 | inputs = inputs.squeeze(2)
53 | inputs = F.interpolate(inputs, scale_factor=2.0)
54 | inputs = inputs[:, :, None, :, :]
55 | else:
56 | # only interpolate 2D
57 | b, c, t, h, w = inputs.shape
58 | inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
59 | inputs = F.interpolate(inputs, scale_factor=2.0)
60 | inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
61 |
62 | b, c, t, h, w = inputs.shape
63 | inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
64 | inputs = self.conv(inputs)
65 | inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
66 |
67 | return inputs
68 |
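A shape-only sketch (illustrative) of the plain spatial path: with compress_time=False every frame is nearest-neighbour upsampled by 2x and then passed through the 3x3 conv, so only H and W change.

    import torch
    from videosys.models.modules.upsampling import CogVideoXUpsample3D

    up = CogVideoXUpsample3D(in_channels=32, out_channels=16, compress_time=False)
    x = torch.randn(1, 32, 4, 30, 45)  # (B, C, T, H, W)
    y = up(x)
    print(y.shape)                     # torch.Size([1, 16, 4, 60, 90])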
--------------------------------------------------------------------------------
/videosys/models/transformers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/models/transformers/__init__.py
--------------------------------------------------------------------------------
/videosys/pipelines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/pipelines/__init__.py
--------------------------------------------------------------------------------
/videosys/pipelines/cogvideox/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_cogvideox import CogVideoXConfig, CogVideoXPABConfig, CogVideoXPipeline
2 |
3 | __all__ = ["CogVideoXConfig", "CogVideoXPipeline", "CogVideoXPABConfig"]
4 |
--------------------------------------------------------------------------------
/videosys/pipelines/latte/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_latte import LatteConfig, LattePABConfig, LattePipeline
2 |
3 | __all__ = ["LatteConfig", "LattePipeline", "LattePABConfig"]
4 |
--------------------------------------------------------------------------------
/videosys/pipelines/open_sora/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline
2 |
3 | __all__ = ["OpenSoraConfig", "OpenSoraPipeline", "OpenSoraPABConfig"]
4 |
--------------------------------------------------------------------------------
/videosys/pipelines/open_sora_plan/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_open_sora_plan import (
2 | OpenSoraPlanConfig,
3 | OpenSoraPlanPipeline,
4 | OpenSoraPlanV110PABConfig,
5 | OpenSoraPlanV120PABConfig,
6 | )
7 |
8 | __all__ = ["OpenSoraPlanConfig", "OpenSoraPlanPipeline", "OpenSoraPlanV110PABConfig", "OpenSoraPlanV120PABConfig"]
9 |
--------------------------------------------------------------------------------
/videosys/pipelines/vchitect/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_vchitect import VchitectConfig, VchitectPABConfig, VchitectXLPipeline
2 |
3 | __all__ = ["VchitectXLPipeline", "VchitectConfig", "VchitectPABConfig"]
4 |
--------------------------------------------------------------------------------
/videosys/schedulers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/schedulers/__init__.py
--------------------------------------------------------------------------------
/videosys/schedulers/scheduling_rflow_open_sora.py:
--------------------------------------------------------------------------------
1 | # Adapted from OpenSora
2 |
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | # --------------------------------------------------------
6 | # References:
7 | # OpenSora: https://github.com/hpcaitech/Open-Sora
8 | # --------------------------------------------------------
9 |
10 | import torch
11 | import torch.distributed as dist
12 | from einops import rearrange
13 | from torch.distributions import LogisticNormal
14 | from tqdm import tqdm
15 |
16 |
17 | def _extract_into_tensor(arr, timesteps, broadcast_shape):
18 | """
19 | Extract values from a 1-D numpy array for a batch of indices.
20 | :param arr: the 1-D numpy array.
21 | :param timesteps: a tensor of indices into the array to extract.
22 | :param broadcast_shape: a larger shape of K dimensions with the batch
23 | dimension equal to the length of timesteps.
24 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
25 | """
26 | res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
27 | while len(res.shape) < len(broadcast_shape):
28 | res = res[..., None]
29 | return res + torch.zeros(broadcast_shape, device=timesteps.device)
30 |
31 |
32 | def mean_flat(tensor: torch.Tensor, mask=None):
33 | """
34 | Take the mean over all non-batch dimensions.
35 | """
36 | if mask is None:
37 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
38 | else:
39 | assert tensor.dim() == 5
40 | assert tensor.shape[2] == mask.shape[1]
41 | tensor = rearrange(tensor, "b c t h w -> b t (c h w)")
42 | denom = mask.sum(dim=1) * tensor.shape[-1]
43 | loss = (tensor * mask.unsqueeze(2)).sum(dim=1).sum(dim=1) / denom
44 | return loss
45 |
46 |
47 | def timestep_transform(
48 | t,
49 | model_kwargs,
50 | base_resolution=512 * 512,
51 | base_num_frames=1,
52 | scale=1.0,
53 | num_timesteps=1,
54 | ):
55 | t = t / num_timesteps
56 | resolution = model_kwargs["height"] * model_kwargs["width"]
57 | ratio_space = (resolution / base_resolution).sqrt()
58 | # NOTE: currently, we do not take fps into account
59 |     # NOTE: temporal_reduction is hardcoded; it should equal the temporal reduction factor of the VAE
60 | if model_kwargs["num_frames"][0] == 1:
61 | num_frames = torch.ones_like(model_kwargs["num_frames"])
62 | else:
63 | num_frames = model_kwargs["num_frames"] // 17 * 5
64 | ratio_time = (num_frames / base_num_frames).sqrt()
65 |
66 | ratio = ratio_space * ratio_time * scale
67 | new_t = ratio * t / (1 + (ratio - 1) * t)
68 |
69 | new_t = new_t * num_timesteps
70 | return new_t
71 |
72 |
73 | class RFlowScheduler:
74 | def __init__(
75 | self,
76 | num_timesteps=1000,
77 | num_sampling_steps=10,
78 | use_discrete_timesteps=False,
79 | sample_method="uniform",
80 | loc=0.0,
81 | scale=1.0,
82 | use_timestep_transform=False,
83 | transform_scale=1.0,
84 | ):
85 | self.num_timesteps = num_timesteps
86 | self.num_sampling_steps = num_sampling_steps
87 | self.use_discrete_timesteps = use_discrete_timesteps
88 |
89 | # sample method
90 | assert sample_method in ["uniform", "logit-normal"]
91 | assert (
92 | sample_method == "uniform" or not use_discrete_timesteps
93 | ), "Only uniform sampling is supported for discrete timesteps"
94 | self.sample_method = sample_method
95 | if sample_method == "logit-normal":
96 | self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale]))
97 | self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device)
98 |
99 | # timestep transform
100 | self.use_timestep_transform = use_timestep_transform
101 | self.transform_scale = transform_scale
102 |
103 | def training_losses(self, model, x_start, model_kwargs=None, noise=None, mask=None, weights=None, t=None):
104 | """
105 | Compute training losses for a single timestep.
106 | Arguments format copied from opensora/schedulers/iddpm/gaussian_diffusion.py/training_losses
107 |         Note: t is an int tensor and should be rescaled from [0, num_timesteps-1] to [1, 0]
108 | """
109 | if t is None:
110 | if self.use_discrete_timesteps:
111 | t = torch.randint(0, self.num_timesteps, (x_start.shape[0],), device=x_start.device)
112 | elif self.sample_method == "uniform":
113 | t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_timesteps
114 | elif self.sample_method == "logit-normal":
115 | t = self.sample_t(x_start) * self.num_timesteps
116 |
117 | if self.use_timestep_transform:
118 | t = timestep_transform(t, model_kwargs, scale=self.transform_scale, num_timesteps=self.num_timesteps)
119 |
120 | if model_kwargs is None:
121 | model_kwargs = {}
122 | if noise is None:
123 | noise = torch.randn_like(x_start)
124 | assert noise.shape == x_start.shape
125 |
126 | x_t = self.add_noise(x_start, noise, t)
127 | if mask is not None:
128 | t0 = torch.zeros_like(t)
129 | x_t0 = self.add_noise(x_start, noise, t0)
130 | x_t = torch.where(mask[:, None, :, None, None], x_t, x_t0)
131 |
132 | terms = {}
133 | model_output = model(x_t, t, **model_kwargs)
134 | velocity_pred = model_output.chunk(2, dim=1)[0]
135 | if weights is None:
136 | loss = mean_flat((velocity_pred - (x_start - noise)).pow(2), mask=mask)
137 | else:
138 | weight = _extract_into_tensor(weights, t, x_start.shape)
139 | loss = mean_flat(weight * (velocity_pred - (x_start - noise)).pow(2), mask=mask)
140 | terms["loss"] = loss
141 |
142 | return terms
143 |
144 | def add_noise(
145 | self,
146 | original_samples: torch.FloatTensor,
147 | noise: torch.FloatTensor,
148 | timesteps: torch.IntTensor,
149 | ) -> torch.FloatTensor:
150 | """
151 | compatible with diffusers add_noise()
152 | """
153 | timepoints = timesteps.float() / self.num_timesteps
154 | timepoints = 1 - timepoints # [1,1/1000]
155 |
156 |         # timepoints: (bsz,), noise: (bsz, 4, frames, h, w)
157 |         # expand timepoints to the noise shape
158 | timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
159 | timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4])
160 |
161 | return timepoints * original_samples + (1 - timepoints) * noise
162 |
163 |
164 | class RFLOW:
165 | def __init__(
166 | self,
167 | num_sampling_steps=10,
168 | num_timesteps=1000,
169 | cfg_scale=4.0,
170 | use_discrete_timesteps=False,
171 | use_timestep_transform=False,
172 | **kwargs,
173 | ):
174 | self.num_sampling_steps = num_sampling_steps
175 | self.num_timesteps = num_timesteps
176 | self.cfg_scale = cfg_scale
177 | self.use_discrete_timesteps = use_discrete_timesteps
178 | self.use_timestep_transform = use_timestep_transform
179 |
180 | self.scheduler = RFlowScheduler(
181 | num_timesteps=num_timesteps,
182 | num_sampling_steps=num_sampling_steps,
183 | use_discrete_timesteps=use_discrete_timesteps,
184 | use_timestep_transform=use_timestep_transform,
185 | **kwargs,
186 | )
187 |
188 | def sample(
189 | self,
190 | model,
191 | z,
192 | model_args,
193 | y_null,
194 | device,
195 | mask=None,
196 | guidance_scale=None,
197 | progress=True,
198 | verbose=False,
199 | ):
200 |         # if no guidance scale is provided, fall back to the default set at scheduler initialization
201 | if guidance_scale is None:
202 | guidance_scale = self.cfg_scale
203 |
204 | # text encoding
205 | model_args["y"] = torch.cat([model_args["y"], y_null], 0)
206 |
207 | # prepare timesteps
208 | timesteps = [(1.0 - i / self.num_sampling_steps) * self.num_timesteps for i in range(self.num_sampling_steps)]
209 | if self.use_discrete_timesteps:
210 | timesteps = [int(round(t)) for t in timesteps]
211 | timesteps = [torch.tensor([t] * z.shape[0], device=device) for t in timesteps]
212 | if self.use_timestep_transform:
213 | timesteps = [timestep_transform(t, model_args, num_timesteps=self.num_timesteps) for t in timesteps]
214 |
215 | if mask is not None:
216 | noise_added = torch.zeros_like(mask, dtype=torch.bool)
217 | noise_added = noise_added | (mask == 1)
218 |
219 | progress_wrap = tqdm if progress and dist.get_rank() == 0 else (lambda x: x)
220 |
221 | dtype = model.x_embedder.proj.weight.dtype
222 | all_timesteps = [int(t.to(dtype).item()) for t in timesteps]
223 | for i, t in progress_wrap(list(enumerate(timesteps))):
224 | # mask for adding noise
225 | if mask is not None:
226 | mask_t = mask * self.num_timesteps
227 | x0 = z.clone()
228 | x_noise = self.scheduler.add_noise(x0, torch.randn_like(x0), t)
229 |
230 | mask_t_upper = mask_t >= t.unsqueeze(1)
231 | model_args["x_mask"] = mask_t_upper.repeat(2, 1)
232 | mask_add_noise = mask_t_upper & ~noise_added
233 |
234 | z = torch.where(mask_add_noise[:, None, :, None, None], x_noise, x0)
235 | noise_added = mask_t_upper
236 |
237 | # classifier-free guidance
238 | z_in = torch.cat([z, z], 0)
239 | t = torch.cat([t, t], 0)
240 |
241 | # pred = model(z_in, t, **model_args).chunk(2, dim=1)[0]
242 | output = model(z_in, t, all_timesteps, **model_args)
243 |
244 | pred = output.chunk(2, dim=1)[0]
245 | pred_cond, pred_uncond = pred.chunk(2, dim=0)
246 | v_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
247 |
248 | # update z
249 | dt = timesteps[i] - timesteps[i + 1] if i < len(timesteps) - 1 else timesteps[i]
250 | dt = dt / self.num_timesteps
251 | z = z + v_pred * dt[:, None, None, None, None]
252 |
253 | if mask is not None:
254 | z = torch.where(mask_t_upper[:, None, :, None, None], z, x0)
255 |
256 | return z
257 |
258 | def training_losses(self, model, x_start, model_kwargs=None, noise=None, mask=None, weights=None, t=None):
259 | return self.scheduler.training_losses(model, x_start, model_kwargs, noise, mask, weights, t)
260 |
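The scheduler follows the rectified-flow convention: add_noise draws a straight line between data and noise, and each sampling step moves z along the predicted velocity by dt. A small sketch (illustrative) of the interpolation:

    import torch
    from videosys.schedulers.scheduling_rflow_open_sora import RFlowScheduler

    sched = RFlowScheduler(num_timesteps=1000, num_sampling_steps=10)
    x0 = torch.randn(2, 4, 8, 16, 16)  # clean latents (B, C, T, H, W)
    noise = torch.randn_like(x0)

    t = torch.full((2,), 250)          # timepoint = 1 - 250/1000 = 0.75
    x_t = sched.add_noise(x0, noise, t)
    expected = 0.75 * x0 + 0.25 * noise
    assert torch.allclose(x_t, expected, atol=1e-6)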
--------------------------------------------------------------------------------
/videosys/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/TeaCache/3dd7c3ffa2bb7487498f2e2b0898e0a9b9be51ac/videosys/utils/__init__.py
--------------------------------------------------------------------------------
/videosys/utils/logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch.distributed as dist
4 | from rich.logging import RichHandler
5 |
6 |
7 | def create_logger():
8 | """
9 |     Create a basic logger; handlers are attached later by init_dist_logger().
10 | """
11 | logger = logging.getLogger(__name__)
12 | return logger
13 |
14 |
15 | def init_dist_logger():
16 | """
17 |     Configure the global logger for distributed runs: rank 0 logs to the console via RichHandler; other ranks get a no-op NullHandler.
18 | """
19 | global logger
20 | if dist.get_rank() == 0:
21 | logger = logging.getLogger(__name__)
22 | handler = RichHandler(show_path=False, markup=True, rich_tracebacks=True)
23 | formatter = logging.Formatter("VideoSys - %(levelname)s: %(message)s")
24 | handler.setFormatter(formatter)
25 | logger.addHandler(handler)
26 | logger.setLevel(logging.INFO)
27 | else: # dummy logger (does nothing)
28 | logger = logging.getLogger(__name__)
29 | logger.addHandler(logging.NullHandler())
30 |
31 |
32 | logger = create_logger()
33 |
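Typical use (illustrative): call init_dist_logger() once after torch.distributed is initialized, then import the module-level logger elsewhere.

    from videosys.utils.logging import init_dist_logger, logger

    # after dist.init_process_group(...) has been called:
    init_dist_logger()                  # rank 0 gets a RichHandler; other ranks a NullHandler
    logger.info("engine initialized")   # emitted on rank 0 only (other ranks stay at the default WARNING level)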
--------------------------------------------------------------------------------
/videosys/utils/test.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch
4 |
5 |
6 | def empty_cache(func):
7 | @functools.wraps(func)
8 | def wrapper(*args, **kwargs):
9 | torch.cuda.empty_cache()
10 | return func(*args, **kwargs)
11 |
12 | return wrapper
13 |
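Usage sketch (illustrative): the decorator frees cached CUDA blocks before the wrapped function runs, which helps keep benchmark runs independent.

    from videosys.utils.test import empty_cache

    @empty_cache
    def run_case():
        # torch.cuda.empty_cache() has already been called at this point
        ...

    run_case()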
--------------------------------------------------------------------------------
/videosys/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import imageio
5 | import numpy as np
6 | import torch
7 | import torch.distributed as dist
8 | from omegaconf import DictConfig, ListConfig, OmegaConf
9 |
10 |
11 | def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
12 | """
13 | Set requires_grad flag for all parameters in a model.
14 | """
15 | for p in model.parameters():
16 | p.requires_grad = flag
17 |
18 |
19 | def set_seed(seed, dp_rank=None):
20 | if seed == -1:
21 | seed = random.randint(0, 1000000)
22 |
23 | if dp_rank is not None:
24 | seed = torch.tensor(seed, dtype=torch.int64).cuda()
25 | if dist.get_world_size() > 1:
26 | dist.broadcast(seed, 0)
27 | seed = seed + dp_rank
28 |
29 | seed = int(seed)
30 | random.seed(seed)
31 | os.environ["PYTHONHASHSEED"] = str(seed)
32 | np.random.seed(seed)
33 | torch.manual_seed(seed)
34 | torch.cuda.manual_seed(seed)
35 |
36 |
37 | def str_to_dtype(x: str):
38 | if x == "fp32":
39 | return torch.float32
40 | elif x == "fp16":
41 | return torch.float16
42 | elif x == "bf16":
43 | return torch.bfloat16
44 | else:
45 | raise RuntimeError(f"Only fp32, fp16 and bf16 are supported, but got {x}")
46 |
47 |
48 | def batch_func(func, *args):
49 | """
50 |     Apply a function to every argument that is a tensor with batch size 2; other arguments pass through unchanged.
51 | """
52 | batch = []
53 | for arg in args:
54 | if isinstance(arg, torch.Tensor) and arg.shape[0] == 2:
55 | batch.append(func(arg))
56 | else:
57 | batch.append(arg)
58 |
59 | return batch
60 |
61 |
62 | def merge_args(args1, args2):
63 | """
64 |     Merge an OmegaConf config (args2) into an argparse Namespace (args1).
65 | """
66 | if args2 is None:
67 | return args1
68 |
69 | for k in args2._content.keys():
70 | if k in args1.__dict__:
71 | v = getattr(args2, k)
72 | if isinstance(v, ListConfig) or isinstance(v, DictConfig):
73 | v = OmegaConf.to_object(v)
74 | setattr(args1, k, v)
75 | else:
76 | raise RuntimeError(f"Unknown argument {k}")
77 |
78 | return args1
79 |
80 |
81 | def all_exists(paths):
82 | return all(os.path.exists(path) for path in paths)
83 |
84 |
85 | def save_video(video, output_path, fps):
86 | """
87 | Save a video to disk.
88 | """
89 | if dist.is_initialized() and dist.get_rank() != 0:
90 | return
91 | os.makedirs(os.path.dirname(output_path), exist_ok=True)
92 | imageio.mimwrite(output_path, video, fps=fps)
93 |
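A couple of illustrative calls for the helpers above (the seed and dtype strings are arbitrary examples):

    import torch
    from videosys.utils.utils import set_seed, str_to_dtype

    set_seed(42)  # seeds python, numpy and torch (CPU and CUDA)
    assert str_to_dtype("bf16") == torch.bfloat16
    assert str_to_dtype("fp16") == torch.float16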
--------------------------------------------------------------------------------