├── .gitignore
├── README.md
├── assets
│   ├── draft-attention.png
│   └── video
│       ├── demo-bluedress-dense.gif
│       ├── demo-bluedress-sp0.9-ours.gif
│       ├── demo-bluedress-sp0.9-svg.gif
│       ├── demo-building-dense.gif
│       ├── demo-building-sp0.9-ours.gif
│       ├── demo-building-sp0.9-svg.gif
│       ├── demo-hunyuan_custom-768p-dense.gif
│       ├── demo-hunyuan_custom-768p-sp0.9.gif
│       ├── demo-pisa-dense.gif
│       ├── demo-pisa-sp0.9-ours.gif
│       └── demo-pisa-sp0.9-svg.gif
├── draft_attention.py
├── draft_attention_classifier_free_guidance.py
├── hunyuan
│   ├── hyvideo
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── constants.py
│   │   ├── diffusion
│   │   │   ├── __init__.py
│   │   │   ├── pipelines
│   │   │   │   ├── __init__.py
│   │   │   │   └── pipeline_hunyuan_video.py
│   │   │   └── schedulers
│   │   │       ├── __init__.py
│   │   │       └── scheduling_flow_match_discrete.py
│   │   ├── inference.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── activation_layers.py
│   │   │   ├── attenion.py
│   │   │   ├── embed_layers.py
│   │   │   ├── fp8_optimization.py
│   │   │   ├── mlp_layers.py
│   │   │   ├── models.py
│   │   │   ├── modulate_layers.py
│   │   │   ├── norm_layers.py
│   │   │   ├── posemb_layers.py
│   │   │   └── token_refiner.py
│   │   ├── prompt_rewrite.py
│   │   ├── text_encoder
│   │   │   └── __init__.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── data_utils.py
│   │   │   ├── file_utils.py
│   │   │   ├── helpers.py
│   │   │   └── preprocess_text_encoder_tokenizer_utils.py
│   │   └── vae
│   │       ├── __init__.py
│   │       ├── autoencoder_kl_causal_3d.py
│   │       ├── unet_causal_3d_blocks.py
│   │       └── vae.py
│   ├── run-single-sample_video-fp8.sh
│   ├── run-single-sample_video.sh
│   └── sample_video.py
├── hunyuan_custom
│   ├── assets
│   │   ├── images
│   │   │   ├── method.png
│   │   │   ├── poodle.png
│   │   │   ├── seg_boy.png
│   │   │   ├── seg_man_01.png
│   │   │   ├── seg_man_02.png
│   │   │   ├── seg_man_03.png
│   │   │   ├── seg_poodle.png
│   │   │   ├── seg_woman_01.png
│   │   │   ├── seg_woman_02.png
│   │   │   └── seg_woman_03.png
│   │   ├── material
│   │   │   ├── application.png
│   │   │   ├── logo.png
│   │   │   ├── method.png
│   │   │   └── teaser.png
│   │   ├── meta_files.list
│   │   ├── meta_files
│   │   │   └── poodle.json
│   │   └── videos
│   │       ├── seg_man_01.mp4
│   │       ├── seg_man_02.mp4
│   │       ├── seg_woman_01.mp4
│   │       └── seg_woman_03.mp4
│   ├── hymm_gradio
│   │   ├── flask_ref2v.py
│   │   ├── gradio_ref2v.py
│   │   └── tool_for_end2end.py
│   ├── hymm_sp
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── constants.py
│   │   ├── data_kits
│   │   │   ├── data_tools.py
│   │   │   └── video_dataset.py
│   │   ├── diffusion
│   │   │   ├── __init__.py
│   │   │   ├── pipelines
│   │   │   │   ├── __init__.py
│   │   │   │   └── pipeline_hunyuan_video_custom.py
│   │   │   └── schedulers
│   │   │       ├── __init__.py
│   │   │       └── scheduling_flow_match_discrete.py
│   │   ├── helpers.py
│   │   ├── inference.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── activation_layers.py
│   │   │   ├── attn_layers.py
│   │   │   ├── draft_attention_classifier_free_guidance.py
│   │   │   ├── embed_layers.py
│   │   │   ├── fp8_optimization.py
│   │   │   ├── mlp_layers.py
│   │   │   ├── models.py
│   │   │   ├── modulate_layers.py
│   │   │   ├── norm_layers.py
│   │   │   ├── parallel_states.py
│   │   │   ├── posemb_layers.py
│   │   │   └── token_refiner.py
│   │   ├── sample_batch.py
│   │   ├── sample_gpu_poor.py
│   │   ├── sample_inference.py
│   │   ├── text_encoder
│   │   │   └── __init__.py
│   │   └── vae
│   │       ├── __init__.py
│   │       ├── autoencoder_kl_causal_3d.py
│   │       ├── unet_causal_3d_blocks.py
│   │       └── vae.py
│   ├── models
│   │   └── README.md
│   └── run-single-video-8xA100.sh
└── wan
    ├── generate.py
    ├── run-single-inference.sh
    └── wan
        ├── __init__.py
        ├── configs
        │   ├── __init__.py
        │   ├── shared_config.py
        │   ├── wan_i2v_14B.py
        │   ├── wan_t2v_14B.py
        │   └── wan_t2v_1_3B.py
        ├── distributed
        │   ├── __init__.py
        │   ├── fsdp.py
        │   └── xdit_context_parallel.py
        ├── first_last_frame2video.py
        ├── image2video.py
        ├── modules
        │   ├── __init__.py
        │   ├── attention.py
        │   ├── clip.py
        │   ├── model.py
        │   ├── t5.py
        │   ├── tokenizers.py
        │   ├── vae.py
        │   └── xlm_roberta.py
        ├── text2video.py
        └── utils
            ├── __init__.py
            ├── fm_solvers.py
            ├── fm_solvers_unipc.py
            ├── prompt_extend.py
            ├── qwen_vl_utils.py
            └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ### PythonVanilla template
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | z-helloworld/
31 |
32 | # Installer logs
33 | pip-log.txt
34 | pip-delete-this-directory.txt
35 |
36 | # Unit test / coverage reports
37 | htmlcov/
38 | .tox/
39 | .nox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *.cover
46 | *.py,cover
47 | .hypothesis/
48 | .pytest_cache/
49 | cover/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # pyenv
56 | # For a library or package, you might want to ignore these files since the code is
57 | # intended to run in multiple environments; otherwise, check them in:
58 | # .python-version
59 |
60 | # pipenv
61 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
62 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
63 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
64 | # install all needed dependencies.
65 | #Pipfile.lock
66 |
67 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
68 | __pypackages__/
69 |
70 |
71 | ### CircuitPython template
72 | .Trashes
73 | .metadata_never_index
74 | .fseventsd/
75 | boot_out.txt
76 |
77 | ### Python template
78 | # Byte-compiled / optimized / DLL files
79 | __pycache__/
80 | *.py[cod]
81 | *$py.class
82 |
83 | # C extensions
84 | *.so
85 |
86 | # Distribution / packaging
87 | .Python
88 | build/
89 | develop-eggs/
90 | dist/
91 | downloads/
92 | eggs/
93 | .eggs/
94 | lib/
95 | lib64/
96 | parts/
97 | sdist/
98 | var/
99 | wheels/
100 | share/python-wheels/
101 | *.egg-info/
102 | .installed.cfg
103 | *.egg
104 | MANIFEST
105 |
106 | # PyInstaller
107 | # Usually these files are written by a python script from a template
108 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
109 | *.manifest
110 | *.spec
111 |
112 | # Installer logs
113 | pip-log.txt
114 | pip-delete-this-directory.txt
115 |
116 | # Unit test / coverage reports
117 | htmlcov/
118 | .tox/
119 | .nox/
120 | .coverage
121 | .coverage.*
122 | .cache
123 | nosetests.xml
124 | coverage.xml
125 | *.cover
126 | *.py,cover
127 | .hypothesis/
128 | .pytest_cache/
129 | cover/
130 |
131 | # Translations
132 | *.mo
133 | *.pot
134 |
135 | # Django stuff:
136 | *.log
137 | local_settings.py
138 | db.sqlite3
139 | db.sqlite3-journal
140 |
141 | # Flask stuff:
142 | instance/
143 | .webassets-cache
144 |
145 | # Scrapy stuff:
146 | .scrapy
147 |
148 | # Sphinx documentation
149 | docs/_build/
150 |
151 | # PyBuilder
152 | .pybuilder/
153 | target/
154 |
155 | # Jupyter Notebook
156 | .ipynb_checkpoints
157 |
158 | # IPython
159 | profile_default/
160 | ipython_config.py
161 |
162 | # pyenv
163 | # For a library or package, you might want to ignore these files since the code is
164 | # intended to run in multiple environments; otherwise, check them in:
165 | # .python-version
166 |
167 | # pipenv
168 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
169 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
170 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
171 | # install all needed dependencies.
172 | #Pipfile.lock
173 |
174 | # poetry
175 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
176 | # This is especially recommended for binary packages to ensure reproducibility, and is more
177 | # commonly ignored for libraries.
178 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
179 | #poetry.lock
180 |
181 | # pdm
182 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
183 | #pdm.lock
184 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
185 | # in version control.
186 | # https://pdm.fming.dev/#use-with-ide
187 | .pdm.toml
188 |
189 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
190 | __pypackages__/
191 |
192 | # Celery stuff
193 | celerybeat-schedule
194 | celerybeat.pid
195 |
196 | # SageMath parsed files
197 | *.sage.py
198 |
199 | # Environments
200 | .env
201 | .venv
202 | env/
203 | venv/
204 | ENV/
205 | env.bak/
206 | venv.bak/
207 |
208 | # Spyder project settings
209 | .spyderproject
210 | .spyproject
211 |
212 | # Rope project settings
213 | .ropeproject
214 |
215 | # mkdocs documentation
216 | /site
217 |
218 | # mypy
219 | .mypy_cache/
220 | .dmypy.json
221 | dmypy.json
222 |
223 | # Pyre type checker
224 | .pyre/
225 |
226 | # pytype static type analyzer
227 | .pytype/
228 |
229 | # Cython debug symbols
230 | cython_debug/
231 |
232 | # PyCharm
233 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
234 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
235 | # and can be added to the global gitignore or merged into this file. For a more nuclear
236 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
237 | #.idea/
238 |
239 | /z-helloworld/
240 | /z-helloworld/
241 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | # Draft Attention
7 |
8 | This repository provides an overview of all resources for the paper
9 | ["DraftAttention: Fast Video Diffusion via Low-Resolution Attention Guidance"](https://arxiv.org/abs/2505.14708).
10 |
11 |
12 | Draft Attention is a plug-and-play acceleration method for video diffusion transformers.
13 |
14 | Draft Attention reshapes the long query and key sequences into frame-wise feature maps and applies 2D average pooling to downsample them.
15 |
16 | The resulting low-resolution (draft) attention map serves as a reference for the sparse attention over the full-length sequence.
17 |
18 | Draft Attention introduces minimal overhead by compressing the number of tokens by 128x or more.
19 |
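
Conceptually, the pooling and block-selection steps look like the sketch below. This is a minimal illustration only, assuming a hypothetical `draft_block_mask` helper, a frame-major token layout, and a simple top-k selection rule; it is not the implementation shipped in `draft_attention.py`.

```python3
# Minimal sketch of the idea, NOT the implementation in `draft_attention.py`:
# shapes, names, and the block-selection rule below are illustrative assumptions.
import torch
import torch.nn.functional as F

def draft_block_mask(q, k, num_frames, latent_h, latent_w, pool_h, pool_w, sparsity):
    """q, k: [batch, heads, num_frames * latent_h * latent_w, head_dim] visual tokens."""
    b, h, _, d = q.shape

    def pool(x):
        # View tokens as per-frame 2D feature maps and downsample with average pooling.
        x = x.reshape(b * h * num_frames, latent_h, latent_w, d).permute(0, 3, 1, 2)
        x = F.avg_pool2d(x, kernel_size=(pool_h, pool_w))
        return x.flatten(2).permute(0, 2, 1).reshape(b, h, -1, d)

    q_low, k_low = pool(q), pool(k)                                    # ~128x fewer tokens for 8x16 pooling
    scores = torch.einsum("bhqd,bhkd->bhqk", q_low, k_low) / d ** 0.5  # low-resolution (draft) attention
    num_keep = max(1, int(scores.shape[-1] * (1.0 - sparsity)))        # key blocks kept per query block
    keep = scores.topk(num_keep, dim=-1).indices
    mask = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, keep, True)
    return mask  # [batch, heads, num_q_blocks, num_k_blocks], guides full-resolution sparse attention
```

With `pool_h=8` and `pool_w=16`, each pooled token summarizes 128 original tokens, which is where the 128x compression figure above comes from.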
20 |
21 | ## 🔥 News
22 | - [2025/05] We support [HunyuanCustom](https://github.com/Tencent/HunyuanCustom) with classifier-free guidance.
23 |
24 |
25 |
26 | ## 🎥 Demo
27 |
28 | ### Hunyuan
29 |
45 |
46 | Prompt:
47 | "The banks of the Thames, as the camera moves vertically from low to high."
48 |
49 |
50 |
51 |
52 |
68 |
69 | Prompt:
70 | "On the green grass, the white-walled Leaning Tower of Pisa stands tall. The camera moves vertically from top to bottom during filming."
71 |
72 |
73 |
74 |
75 |
91 |
92 | Prompt:
93 | "A blue long dress fell from the balcony clothes rack and dropped into the water on the ground."
94 |
95 |
96 | Prompts are all from the Penguin Video Benchmark.
97 |
98 | Videos are generated with 90% sparsity and seed 42, using the Hunyuan model at 768p on an A100 GPU.
99 |
100 | ### HunyuanCustom
101 |
102 |
103 |
104 |
105 | 
106 | Input Image
107 | |
108 |
109 | 
110 | Dense Attention
111 | |
112 |
113 | 
114 | Draft Attention (Ours)
115 | |
116 |
117 |
118 |
119 | Prompt:
120 | "Realistic, High-quality. A woman is drinking coffee at a café."
121 |
122 |
123 | Videos are generated with seed 42 at 768p resolution on 8x A100 GPUs, using either dense attention or 90% sparse attention.
124 |
125 |
126 |
127 | ## 🚀 Quick Start
128 |
129 | ### Model Preparation
130 | Please follow the environment setup instructions and download the checkpoints from [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [Wan2.1](https://github.com/Wan-Video/Wan2.1), and [HunyuanCustom](https://github.com/Tencent/HunyuanCustom).
131 |
132 | ### Sparse Attention
133 | We mainly adopt [Block-Sparse-Attention](https://github.com/mit-han-lab/Block-Sparse-Attention) to execute the sparse attention guided by draft attention.
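
For reference, the sketch below shows how such a block mask could gate attention using plain PyTorch `scaled_dot_product_attention`; the actual speedup comes from the Block-Sparse-Attention kernel, which consumes the block mask directly and skips the pruned blocks. The function name and block-size arguments here are illustrative assumptions.

```python3
# Illustrative only: expand a pooled block mask to full token resolution and run
# dense PyTorch SDPA with it. A real block-sparse kernel skips masked blocks instead.
import torch.nn.functional as F

def attention_with_block_mask(q, k, v, block_mask, block_q, block_k):
    # q, k, v: [batch, heads, seq, head_dim]; sequence lengths are assumed to be
    # multiples of the block sizes, so each mask entry covers block_q x block_k tokens.
    full_mask = block_mask.repeat_interleave(block_q, dim=-2).repeat_interleave(block_k, dim=-1)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=full_mask)  # True = attend
```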
134 |
135 | ### Video Generation
136 | Simply run video generation with the scripts in `hunyuan/`, `wan/`, or `hunyuan_custom/`.
137 |
138 | Evaluation results in the paper are mainly obtained with [VBench](https://github.com/Vchitect/VBench) on the [Penguin Video Benchmark](https://github.com/Tencent/HunyuanVideo/blob/main/assets/PenguinVideoBenchmark.csv) using HunyuanVideo and Wan2.1.
139 |
140 | ### Use for Your Own
141 | You can use draft attention much like flash attention through `Draft_Attention`, defined in `draft_attention.py` or `draft_attention_classifier_free_guidance.py`.
142 |
143 | Here is an example for the Hunyuan model:
144 | ```python3
145 | from draft_attention import Draft_Attention
146 |
147 | draft_attention = Draft_Attention(
148 |     pool_h=8,
149 |     pool_w=16,
150 |     latent_h=48,
151 |     latent_w=80,
152 |     visual_len=126_720,
153 |     text_len=256,
154 |     sparsity_ratio=0.9,
155 | )
156 |
157 | x = draft_attention(
158 |     q,
159 |     k,
160 |     v,
161 |     attn_mask=attn_mask,
162 |     causal=causal,
163 |     drop_rate=drop_rate,
164 |     cu_seqlens_q=cu_seqlens_q,
165 |     cu_seqlens_kv=cu_seqlens_kv,
166 |     max_seqlen_q=max_seqlen_q,
167 |     max_seqlen_kv=max_seqlen_kv,
168 |     batch_size=batch_size,
169 | )
170 | ```
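
In this example, `pool_h * pool_w = 8 * 16 = 128` gives the 128x token compression mentioned above, and `visual_len = 126_720` corresponds to 33 latent frames of 48 x 80 tokens (48 * 80 * 33 = 126,720), matching `latent_h` and `latent_w`.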
171 |
172 | ## ✏️ TODO
173 | - [ ] Support any-resolution video generation with padding.
174 | - [ ] Support reordering for further block-sparse grouping and faster hardware execution.
175 |
176 | ## 📑 Acknowledgement
177 | This work is mainly contributed by [Xuan](https://shawnricecake.github.io) and [Chenxia](https://cxhan.com/).
178 |
179 |
180 | ## 🔗 BibTeX
181 | If you find Draft Attention interesting, please cite it using the following BibTeX:
182 | ```bibtex
183 | @article{shen2025draft,
184 | title={DraftAttention: Fast Video Diffusion via Low-Resolution Attention Guidance},
185 | author={Shen, Xuan and Han, Chenxia and Zhou, Yufa and Xie, Yanyue and Gong, Yifan and Wang, Quanyi and Wang, Yiwei and Wang, Yanzhi and Zhao, Pu and Gu, Jiuxiang},
186 | journal={arXiv preprint arXiv:2505.14708},
187 | year={2025}
188 | }
189 | ```
190 |
--------------------------------------------------------------------------------
/assets/draft-attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/draft-attention.png
--------------------------------------------------------------------------------
/assets/video/demo-bluedress-dense.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-bluedress-dense.gif
--------------------------------------------------------------------------------
/assets/video/demo-bluedress-sp0.9-ours.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-bluedress-sp0.9-ours.gif
--------------------------------------------------------------------------------
/assets/video/demo-bluedress-sp0.9-svg.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-bluedress-sp0.9-svg.gif
--------------------------------------------------------------------------------
/assets/video/demo-building-dense.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-building-dense.gif
--------------------------------------------------------------------------------
/assets/video/demo-building-sp0.9-ours.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-building-sp0.9-ours.gif
--------------------------------------------------------------------------------
/assets/video/demo-building-sp0.9-svg.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-building-sp0.9-svg.gif
--------------------------------------------------------------------------------
/assets/video/demo-hunyuan_custom-768p-dense.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-hunyuan_custom-768p-dense.gif
--------------------------------------------------------------------------------
/assets/video/demo-hunyuan_custom-768p-sp0.9.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-hunyuan_custom-768p-sp0.9.gif
--------------------------------------------------------------------------------
/assets/video/demo-pisa-dense.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-pisa-dense.gif
--------------------------------------------------------------------------------
/assets/video/demo-pisa-sp0.9-ours.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-pisa-sp0.9-ours.gif
--------------------------------------------------------------------------------
/assets/video/demo-pisa-sp0.9-svg.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-pisa-sp0.9-svg.gif
--------------------------------------------------------------------------------
/hunyuan/hyvideo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan/hyvideo/__init__.py
--------------------------------------------------------------------------------
/hunyuan/hyvideo/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | __all__ = [
5 | "C_SCALE",
6 | "PROMPT_TEMPLATE",
7 | "MODEL_BASE",
8 | "PRECISIONS",
9 | "NORMALIZATION_TYPE",
10 | "ACTIVATION_TYPE",
11 | "VAE_PATH",
12 | "TEXT_ENCODER_PATH",
13 | "TOKENIZER_PATH",
14 | "TEXT_PROJECTION",
15 | "DATA_TYPE",
16 | "NEGATIVE_PROMPT",
17 | ]
18 |
19 | PRECISION_TO_TYPE = {
20 | 'fp32': torch.float32,
21 | 'fp16': torch.float16,
22 | 'bf16': torch.bfloat16,
23 | }
24 |
25 | # =================== Constant Values =====================
26 | # Computation scale factor, 1P = 1_000_000_000_000_000. Tensorboard will display the value in PetaFLOPS to avoid
27 | # overflow error when tensorboard logging values.
28 | C_SCALE = 1_000_000_000_000_000
29 |
30 | # When using decoder-only models, we must provide a prompt template to instruct the text encoder
31 | # on how to generate the text.
32 | # --------------------------------------------------------------------
33 | PROMPT_TEMPLATE_ENCODE = (
34 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
35 | "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
36 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
37 | )
38 | PROMPT_TEMPLATE_ENCODE_VIDEO = (
39 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
40 | "1. The main content and theme of the video."
41 | "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
42 | "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
43 | "4. background environment, light, style and atmosphere."
44 | "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
45 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
46 | )
47 |
48 | NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
49 |
50 | PROMPT_TEMPLATE = {
51 | "dit-llm-encode": {
52 | "template": PROMPT_TEMPLATE_ENCODE,
53 | "crop_start": 36,
54 | },
55 | "dit-llm-encode-video": {
56 | "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
57 | "crop_start": 95,
58 | },
59 | }
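# Illustrative note (not part of the original file): the template is filled with the
# user prompt via str.format, and `crop_start` is the number of leading template tokens
# dropped from the text-encoder output so only prompt-conditioned states remain, e.g.
#   PROMPT_TEMPLATE["dit-llm-encode-video"]["template"].format("A cat runs on the grass.")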
60 |
61 | # ======================= Model ======================
62 | PRECISIONS = {"fp32", "fp16", "bf16"}
63 | NORMALIZATION_TYPE = {"layer", "rms"}
64 | ACTIVATION_TYPE = {"relu", "silu", "gelu", "gelu_tanh"}
65 |
66 | # =================== Model Path =====================
67 | MODEL_BASE = os.getenv("MODEL_BASE", "./ckpts")
68 |
69 | # =================== Data =======================
70 | DATA_TYPE = {"image", "video", "image_video"}
71 |
72 | # 3D VAE
73 | VAE_PATH = {"884-16c-hy": f"{MODEL_BASE}/hunyuan-video-t2v-720p/vae"}
74 |
75 | # Text Encoder
76 | TEXT_ENCODER_PATH = {
77 | "clipL": f"{MODEL_BASE}/text_encoder_2",
78 | "llm": f"{MODEL_BASE}/text_encoder",
79 | }
80 |
81 | # Tokenizer
82 | TOKENIZER_PATH = {
83 | "clipL": f"{MODEL_BASE}/text_encoder_2",
84 | "llm": f"{MODEL_BASE}/text_encoder",
85 | }
86 |
87 | TEXT_PROJECTION = {
88 | "linear", # Default, an nn.Linear() layer
89 | "single_refiner", # Single TokenRefiner. Refer to LI-DiT
90 | }
91 |
--------------------------------------------------------------------------------
/hunyuan/hyvideo/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipelines import HunyuanVideoPipeline
2 | from .schedulers import FlowMatchDiscreteScheduler
3 |
--------------------------------------------------------------------------------
/hunyuan/hyvideo/diffusion/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_hunyuan_video import HunyuanVideoPipeline
2 |
--------------------------------------------------------------------------------
/hunyuan/hyvideo/diffusion/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
2 |
--------------------------------------------------------------------------------
/hunyuan/hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | #
16 | # Modified from diffusers==0.29.2
17 | #
18 | # ==============================================================================
19 |
20 | from dataclasses import dataclass
21 | from typing import Optional, Tuple, Union
22 |
23 | import numpy as np
24 | import torch
25 |
26 | from diffusers.configuration_utils import ConfigMixin, register_to_config
27 | from diffusers.utils import BaseOutput, logging
28 | from diffusers.schedulers.scheduling_utils import SchedulerMixin
29 |
30 |
31 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32 |
33 |
34 | @dataclass
35 | class FlowMatchDiscreteSchedulerOutput(BaseOutput):
36 | """
37 | Output class for the scheduler's `step` function output.
38 |
39 | Args:
40 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
42 | denoising loop.
43 | """
44 |
45 | prev_sample: torch.FloatTensor
46 |
47 |
48 | class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
49 | """
50 | Euler scheduler.
51 |
52 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
53 | methods the library implements for all schedulers such as loading and saving.
54 |
55 | Args:
56 | num_train_timesteps (`int`, defaults to 1000):
57 | The number of diffusion steps to train the model.
58 | timestep_spacing (`str`, defaults to `"linspace"`):
59 | The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
60 | Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
61 | shift (`float`, defaults to 1.0):
62 | The shift value for the timestep schedule.
63 | reverse (`bool`, defaults to `True`):
64 | Whether to reverse the timestep schedule.
65 | """
66 |
67 | _compatibles = []
68 | order = 1
69 |
70 | @register_to_config
71 | def __init__(
72 | self,
73 | num_train_timesteps: int = 1000,
74 | shift: float = 1.0,
75 | reverse: bool = True,
76 | solver: str = "euler",
77 | n_tokens: Optional[int] = None,
78 | ):
79 | sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
80 |
81 | if not reverse:
82 | sigmas = sigmas.flip(0)
83 |
84 | self.sigmas = sigmas
85 | # the value fed to model
86 | self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
87 |
88 | self._step_index = None
89 | self._begin_index = None
90 |
91 | self.supported_solver = ["euler"]
92 | if solver not in self.supported_solver:
93 | raise ValueError(
94 | f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
95 | )
96 |
97 | @property
98 | def step_index(self):
99 | """
100 | The index counter for current timestep. It will increase 1 after each scheduler step.
101 | """
102 | return self._step_index
103 |
104 | @property
105 | def begin_index(self):
106 | """
107 | The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
108 | """
109 | return self._begin_index
110 |
111 | # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
112 | def set_begin_index(self, begin_index: int = 0):
113 | """
114 | Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
115 |
116 | Args:
117 | begin_index (`int`):
118 | The begin index for the scheduler.
119 | """
120 | self._begin_index = begin_index
121 |
122 | def _sigma_to_t(self, sigma):
123 | return sigma * self.config.num_train_timesteps
124 |
125 | def set_timesteps(
126 | self,
127 | num_inference_steps: int,
128 | device: Union[str, torch.device] = None,
129 | n_tokens: int = None,
130 | ):
131 | """
132 | Sets the discrete timesteps used for the diffusion chain (to be run before inference).
133 |
134 | Args:
135 | num_inference_steps (`int`):
136 | The number of diffusion steps used when generating samples with a pre-trained model.
137 | device (`str` or `torch.device`, *optional*):
138 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
139 | n_tokens (`int`, *optional*):
140 | Number of tokens in the input sequence.
141 | """
142 | self.num_inference_steps = num_inference_steps
143 |
144 | sigmas = torch.linspace(1, 0, num_inference_steps + 1)
145 | sigmas = self.sd3_time_shift(sigmas)
146 |
147 | if not self.config.reverse:
148 | sigmas = 1 - sigmas
149 |
150 | self.sigmas = sigmas
151 | self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
152 | dtype=torch.float32, device=device
153 | )
154 |
155 | # Reset step index
156 | self._step_index = None
157 |
158 | def index_for_timestep(self, timestep, schedule_timesteps=None):
159 | if schedule_timesteps is None:
160 | schedule_timesteps = self.timesteps
161 |
162 | indices = (schedule_timesteps == timestep).nonzero()
163 |
164 | # The sigma index that is taken for the **very** first `step`
165 | # is always the second index (or the last index if there is only 1)
166 | # This way we can ensure we don't accidentally skip a sigma in
167 | # case we start in the middle of the denoising schedule (e.g. for image-to-image)
168 | pos = 1 if len(indices) > 1 else 0
169 |
170 | return indices[pos].item()
171 |
172 | def _init_step_index(self, timestep):
173 | if self.begin_index is None:
174 | if isinstance(timestep, torch.Tensor):
175 | timestep = timestep.to(self.timesteps.device)
176 | self._step_index = self.index_for_timestep(timestep)
177 | else:
178 | self._step_index = self._begin_index
179 |
180 | def scale_model_input(
181 | self, sample: torch.Tensor, timestep: Optional[int] = None
182 | ) -> torch.Tensor:
183 | return sample
184 |
185 | def sd3_time_shift(self, t: torch.Tensor):
186 | return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
187 |
188 | def step(
189 | self,
190 | model_output: torch.FloatTensor,
191 | timestep: Union[float, torch.FloatTensor],
192 | sample: torch.FloatTensor,
193 | return_dict: bool = True,
194 | ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
195 | """
196 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
197 | process from the learned model outputs (most often the predicted noise).
198 |
199 | Args:
200 | model_output (`torch.FloatTensor`):
201 | The direct output from learned diffusion model.
202 | timestep (`float`):
203 | The current discrete timestep in the diffusion chain.
204 | sample (`torch.FloatTensor`):
205 | A current instance of a sample created by the diffusion process.
206 | generator (`torch.Generator`, *optional*):
207 | A random number generator.
208 | n_tokens (`int`, *optional*):
209 | Number of tokens in the input sequence.
210 | return_dict (`bool`):
211 | Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
212 | tuple.
213 |
214 | Returns:
215 | [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
216 | If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
217 | returned, otherwise a tuple is returned where the first element is the sample tensor.
218 | """
219 |
220 | if (
221 | isinstance(timestep, int)
222 | or isinstance(timestep, torch.IntTensor)
223 | or isinstance(timestep, torch.LongTensor)
224 | ):
225 | raise ValueError(
226 | (
227 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
228 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
229 | " one of the `scheduler.timesteps` as a timestep."
230 | ),
231 | )
232 |
233 | if self.step_index is None:
234 | self._init_step_index(timestep)
235 |
236 | # Upcast to avoid precision issues when computing prev_sample
237 | sample = sample.to(torch.float32)
238 |
239 | dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
240 |
241 | if self.config.solver == "euler":
242 | prev_sample = sample + model_output.to(torch.float32) * dt
243 | else:
244 | raise ValueError(
245 | f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
246 | )
247 |
248 | # upon completion increase step index by one
249 | self._step_index += 1
250 |
251 | if not return_dict:
252 | return (prev_sample,)
253 |
254 | return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
255 |
256 | def __len__(self):
257 | return self.config.num_train_timesteps
258 |
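# Illustrative usage sketch (not part of the original file); `model` and `latents`
# are hypothetical stand-ins for the denoiser and its latent input:
#
#   scheduler = FlowMatchDiscreteScheduler(shift=7.0)
#   scheduler.set_timesteps(num_inference_steps=50, device="cuda")
#   for t in scheduler.timesteps:
#       noise_pred = model(latents, t)
#       latents = scheduler.step(noise_pred, t, latents).prev_sample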
--------------------------------------------------------------------------------
/hunyuan/hyvideo/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG
2 |
3 |
4 | def load_model(args, in_channels, out_channels, factor_kwargs):
5 | """load hunyuan video model
6 |
7 | Args:
8 | args (dict): model args
9 | in_channels (int): input channels number
10 | out_channels (int): output channels number
11 | factor_kwargs (dict): factor kwargs
12 |
13 | Returns:
14 | model (nn.Module): The hunyuan video model
15 | """
16 | if args.model in HUNYUAN_VIDEO_CONFIG.keys():
17 | model = HYVideoDiffusionTransformer(
18 | args,
19 | in_channels=in_channels,
20 | out_channels=out_channels,
21 | **HUNYUAN_VIDEO_CONFIG[args.model],
22 | **factor_kwargs,
23 | )
24 | return model
25 | else:
26 | raise NotImplementedError()
27 |
--------------------------------------------------------------------------------
/hunyuan/hyvideo/modules/activation_layers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def get_activation_layer(act_type):
5 | """get activation layer
6 |
7 | Args:
8 | act_type (str): the activation type
9 |
10 | Returns:
11 | torch.nn.functional: the activation layer
12 | """
13 | if act_type == "gelu":
14 | return lambda: nn.GELU()
15 | elif act_type == "gelu_tanh":
16 | # Approximate `tanh` requires torch >= 1.13
17 | return lambda: nn.GELU(approximate="tanh")
18 | elif act_type == "relu":
19 | return nn.ReLU
20 | elif act_type == "silu":
21 | return nn.SiLU
22 | else:
23 | raise ValueError(f"Unknown activation type: {act_type}")
24 |
--------------------------------------------------------------------------------
/hunyuan/hyvideo/modules/embed_layers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from einops import rearrange, repeat
5 |
6 | from ..utils.helpers import to_2tuple
7 |
8 |
9 | class PatchEmbed(nn.Module):
10 | """2D Image to Patch Embedding
11 |
12 | Image to Patch Embedding using Conv2d
13 |
14 | A convolution based approach to patchifying a 2D image w/ embedding projection.
15 |
16 | Based on the impl in https://github.com/google-research/vision_transformer
17 |
18 | Hacked together by / Copyright 2020 Ross Wightman
19 |
20 | Remove the _assert function in forward function to be compatible with multi-resolution images.
21 | """
22 |
23 | def __init__(
24 | self,
25 | patch_size=16,
26 | in_chans=3,
27 | embed_dim=768,
28 | norm_layer=None,
29 | flatten=True,
30 | bias=True,
31 | dtype=None,
32 | device=None,
33 | ):
34 | factory_kwargs = {"dtype": dtype, "device": device}
35 | super().__init__()
36 | patch_size = to_2tuple(patch_size)
37 | self.patch_size = patch_size
38 | self.flatten = flatten
39 |
40 | self.proj = nn.Conv3d(
41 | in_chans,
42 | embed_dim,
43 | kernel_size=patch_size,
44 | stride=patch_size,
45 | bias=bias,
46 | **factory_kwargs
47 | )
48 | nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
49 | if bias:
50 | nn.init.zeros_(self.proj.bias)
51 |
52 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
53 |
54 | def forward(self, x):
55 | x = self.proj(x)
56 | if self.flatten:
57 | x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
58 | x = self.norm(x)
59 | return x
60 |
61 |
62 | class TextProjection(nn.Module):
63 | """
64 | Projects text embeddings. Also handles dropout for classifier-free guidance.
65 |
66 | Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
67 | """
68 |
69 | def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
70 | factory_kwargs = {"dtype": dtype, "device": device}
71 | super().__init__()
72 | self.linear_1 = nn.Linear(
73 | in_features=in_channels,
74 | out_features=hidden_size,
75 | bias=True,
76 | **factory_kwargs
77 | )
78 | self.act_1 = act_layer()
79 | self.linear_2 = nn.Linear(
80 | in_features=hidden_size,
81 | out_features=hidden_size,
82 | bias=True,
83 | **factory_kwargs
84 | )
85 |
86 | def forward(self, caption):
87 | hidden_states = self.linear_1(caption)
88 | hidden_states = self.act_1(hidden_states)
89 | hidden_states = self.linear_2(hidden_states)
90 | return hidden_states
91 |
92 |
93 | def timestep_embedding(t, dim, max_period=10000):
94 | """
95 | Create sinusoidal timestep embeddings.
96 |
97 | Args:
98 | t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
99 | dim (int): the dimension of the output.
100 | max_period (int): controls the minimum frequency of the embeddings.
101 |
102 | Returns:
103 | embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
104 |
105 | .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
106 | """
107 | half = dim // 2
108 | freqs = torch.exp(
109 | -math.log(max_period)
110 | * torch.arange(start=0, end=half, dtype=torch.float32)
111 | / half
112 | ).to(device=t.device)
113 | args = t[:, None].float() * freqs[None]
114 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
115 | if dim % 2:
116 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
117 | return embedding
118 |
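# Example (illustrative): timestep_embedding(torch.tensor([0.0, 500.0, 999.0]), dim=256)
# returns a (3, 256) tensor whose first half along the last dim holds cosine features
# and whose second half holds sine features.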
119 |
120 | class TimestepEmbedder(nn.Module):
121 | """
122 | Embeds scalar timesteps into vector representations.
123 | """
124 |
125 | def __init__(
126 | self,
127 | hidden_size,
128 | act_layer,
129 | frequency_embedding_size=256,
130 | max_period=10000,
131 | out_size=None,
132 | dtype=None,
133 | device=None,
134 | ):
135 | factory_kwargs = {"dtype": dtype, "device": device}
136 | super().__init__()
137 | self.frequency_embedding_size = frequency_embedding_size
138 | self.max_period = max_period
139 | if out_size is None:
140 | out_size = hidden_size
141 |
142 | self.mlp = nn.Sequential(
143 | nn.Linear(
144 | frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
145 | ),
146 | act_layer(),
147 | nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
148 | )
149 | nn.init.normal_(self.mlp[0].weight, std=0.02)
150 | nn.init.normal_(self.mlp[2].weight, std=0.02)
151 |
152 | def forward(self, t):
153 | t_freq = timestep_embedding(
154 | t, self.frequency_embedding_size, self.max_period
155 | ).type(self.mlp[0].weight.dtype)
156 | t_emb = self.mlp(t_freq)
157 | return t_emb
158 |
--------------------------------------------------------------------------------
/hunyuan/hyvideo/modules/fp8_optimization.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 | def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
8 | _bits = torch.tensor(bits)
9 | _mantissa_bit = torch.tensor(mantissa_bit)
10 | _sign_bits = torch.tensor(sign_bits)
11 | M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
12 | E = _bits - _sign_bits - M
13 | bias = 2 ** (E - 1) - 1
14 | mantissa = 1
15 | for i in range(mantissa_bit - 1):
16 | mantissa += 1 / (2 ** (i+1))
17 | maxval = mantissa * 2 ** (2**E - 1 - bias)
18 | return maxval
19 |
20 | def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
21 | """
22 | Default is E4M3.
23 | """
24 | bits = torch.tensor(bits)
25 | mantissa_bit = torch.tensor(mantissa_bit)
26 | sign_bits = torch.tensor(sign_bits)
27 | M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
28 | E = bits - sign_bits - M
29 | bias = 2 ** (E - 1) - 1
30 | mantissa = 1
31 | for i in range(mantissa_bit - 1):
32 | mantissa += 1 / (2 ** (i+1))
33 | maxval = mantissa * 2 ** (2**E - 1 - bias)
34 | minval = - maxval
35 | minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval)
36 | input_clamp = torch.min(torch.max(x, minval), maxval)
37 | log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
38 | log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
39 | # dequant
40 | qdq_out = torch.round(input_clamp / log_scales) * log_scales
41 | return qdq_out, log_scales
42 |
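# Example (illustrative): qdq, scales = quantize_to_fp8(torch.randn(4, 4)) returns the
# fake-quantized (quantize-then-dequantize) tensor in the input dtype, emulating
# float8 E4M3 with the default bits / mantissa / sign settings, plus per-element scales.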
43 | def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
44 | for i in range(len(x.shape) - 1):
45 | scale = scale.unsqueeze(-1)
46 | new_x = x / scale
47 | quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
48 | return quant_dequant_x, scale, log_scales
49 |
50 | def fp8_activation_dequant(qdq_out, scale, dtype):
51 | qdq_out = qdq_out.type(dtype)
52 | quant_dequant_x = qdq_out * scale.to(dtype)
53 | return quant_dequant_x
54 |
55 | def fp8_linear_forward(cls, original_dtype, input):
56 | weight_dtype = cls.weight.dtype
57 | #####
58 | if cls.weight.dtype != torch.float8_e4m3fn:
59 | maxval = get_fp_maxval()
60 | scale = torch.max(torch.abs(cls.weight.flatten())) / maxval
61 | linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale)
62 | linear_weight = linear_weight.to(torch.float8_e4m3fn)
63 | weight_dtype = linear_weight.dtype
64 | else:
65 | scale = cls.fp8_scale.to(cls.weight.device)
66 | linear_weight = cls.weight
67 | #####
68 |
69 | if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0:
70 | if True or len(input.shape) == 3:
71 | cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype)
72 | if cls.bias != None:
73 | output = F.linear(input, cls_dequant, cls.bias)
74 | else:
75 | output = F.linear(input, cls_dequant)
76 | return output
77 | else:
78 | return cls.original_forward(input.to(original_dtype))
79 | else:
80 | return cls.original_forward(input)
81 |
82 | def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={}):
83 | setattr(module, "fp8_matmul_enabled", True)
84 |
85 | # loading fp8 mapping file
86 | fp8_map_path = dit_weight_path.replace('.pt', '_map.pt')
87 | if os.path.exists(fp8_map_path):
88 | fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)
89 | else:
90 | raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.")
91 |
92 | fp8_layers = []
93 | for key, layer in module.named_modules():
94 | if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
95 | fp8_layers.append(key)
96 | original_forward = layer.forward
97 | layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
98 | setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype))
99 | setattr(layer, "original_forward", original_forward)
100 | setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))
101 |
102 |
103 |
--------------------------------------------------------------------------------
/hunyuan/hyvideo/modules/mlp_layers.py:
--------------------------------------------------------------------------------
1 | # Modified from timm library:
2 | # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
3 |
4 | from functools import partial
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | from .modulate_layers import modulate
10 | from ..utils.helpers import to_2tuple
11 |
12 |
13 | class MLP(nn.Module):
14 | """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
15 |
16 | def __init__(
17 | self,
18 | in_channels,
19 | hidden_channels=None,
20 | out_features=None,
21 | act_layer=nn.GELU,
22 | norm_layer=None,
23 | bias=True,
24 | drop=0.0,
25 | use_conv=False,
26 | device=None,
27 | dtype=None,
28 | ):
29 | factory_kwargs = {"device": device, "dtype": dtype}
30 | super().__init__()
31 | out_features = out_features or in_channels
32 | hidden_channels = hidden_channels or in_channels
33 | bias = to_2tuple(bias)
34 | drop_probs = to_2tuple(drop)
35 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
36 |
37 | self.fc1 = linear_layer(
38 | in_channels, hidden_channels, bias=bias[0], **factory_kwargs
39 | )
40 | self.act = act_layer()
41 | self.drop1 = nn.Dropout(drop_probs[0])
42 | self.norm = (
43 | norm_layer(hidden_channels, **factory_kwargs)
44 | if norm_layer is not None
45 | else nn.Identity()
46 | )
47 | self.fc2 = linear_layer(
48 | hidden_channels, out_features, bias=bias[1], **factory_kwargs
49 | )
50 | self.drop2 = nn.Dropout(drop_probs[1])
51 |
52 | def forward(self, x):
53 | x = self.fc1(x)
54 | x = self.act(x)
55 | x = self.drop1(x)
56 | x = self.norm(x)
57 | x = self.fc2(x)
58 | x = self.drop2(x)
59 | return x
60 |
61 |
62 | #
63 | class MLPEmbedder(nn.Module):
64 | """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
65 | def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
66 | factory_kwargs = {"device": device, "dtype": dtype}
67 | super().__init__()
68 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
69 | self.silu = nn.SiLU()
70 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
71 |
72 | def forward(self, x: torch.Tensor) -> torch.Tensor:
73 | return self.out_layer(self.silu(self.in_layer(x)))
74 |
75 |
76 | class FinalLayer(nn.Module):
77 | """The final layer of DiT."""
78 |
79 | def __init__(
80 | self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
81 | ):
82 | factory_kwargs = {"device": device, "dtype": dtype}
83 | super().__init__()
84 |
85 | # Just use LayerNorm for the final layer
86 | self.norm_final = nn.LayerNorm(
87 | hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
88 | )
89 | if isinstance(patch_size, int):
90 | self.linear = nn.Linear(
91 | hidden_size,
92 | patch_size * patch_size * out_channels,
93 | bias=True,
94 | **factory_kwargs
95 | )
96 | else:
97 | self.linear = nn.Linear(
98 | hidden_size,
99 | patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
100 | bias=True,
101 | )
102 | nn.init.zeros_(self.linear.weight)
103 | nn.init.zeros_(self.linear.bias)
104 |
105 | # Here we don't distinguish between the modulate types. Just use the simple one.
106 | self.adaLN_modulation = nn.Sequential(
107 | act_layer(),
108 | nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
109 | )
110 | # Zero-initialize the modulation
111 | nn.init.zeros_(self.adaLN_modulation[1].weight)
112 | nn.init.zeros_(self.adaLN_modulation[1].bias)
113 |
114 | def forward(self, x, c):
115 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
116 | x = modulate(self.norm_final(x), shift=shift, scale=scale)
117 | x = self.linear(x)
118 | return x
119 |
--------------------------------------------------------------------------------
/hunyuan/hyvideo/modules/modulate_layers.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class ModulateDiT(nn.Module):
8 | """Modulation layer for DiT."""
9 | def __init__(
10 | self,
11 | hidden_size: int,
12 | factor: int,
13 | act_layer: Callable,
14 | dtype=None,
15 | device=None,
16 | ):
17 | factory_kwargs = {"dtype": dtype, "device": device}
18 | super().__init__()
19 | self.act = act_layer()
20 | self.linear = nn.Linear(
21 | hidden_size, factor * hidden_size, bias=True, **factory_kwargs
22 | )
23 | # Zero-initialize the modulation
24 | nn.init.zeros_(self.linear.weight)
25 | nn.init.zeros_(self.linear.bias)
26 |
27 | def forward(self, x: torch.Tensor) -> torch.Tensor:
28 | return self.linear(self.act(x))
29 |
30 |
31 | def modulate(x, shift=None, scale=None):
32 | """modulate by shift and scale
33 |
34 | Args:
35 | x (torch.Tensor): input tensor.
36 | shift (torch.Tensor, optional): shift tensor. Defaults to None.
37 | scale (torch.Tensor, optional): scale tensor. Defaults to None.
38 |
39 | Returns:
40 | torch.Tensor: the output tensor after modulate.
41 | """
42 | if scale is None and shift is None:
43 | return x
44 | elif shift is None:
45 | return x * (1 + scale.unsqueeze(1))
46 | elif scale is None:
47 | return x + shift.unsqueeze(1)
48 | else:
49 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
50 |
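# Example (illustrative): with x of shape (B, L, C) and shift/scale of shape (B, C),
# modulate(x, shift, scale) broadcasts over the sequence dimension and computes
# x * (1 + scale[:, None, :]) + shift[:, None, :].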
51 |
52 | def apply_gate(x, gate=None, tanh=False):
53 | """AI is creating summary for apply_gate
54 |
55 | Args:
56 | x (torch.Tensor): input tensor.
57 | gate (torch.Tensor, optional): gate tensor. Defaults to None.
58 | tanh (bool, optional): whether to use tanh function. Defaults to False.
59 |
60 | Returns:
61 | torch.Tensor: the output tensor after apply gate.
62 | """
63 | if gate is None:
64 | return x
65 | if tanh:
66 | return x * gate.unsqueeze(1).tanh()
67 | else:
68 | return x * gate.unsqueeze(1)
69 |
70 |
71 | def ckpt_wrapper(module):
72 | def ckpt_forward(*inputs):
73 | outputs = module(*inputs)
74 | return outputs
75 |
76 | return ckpt_forward
77 |
--------------------------------------------------------------------------------
/hunyuan/hyvideo/modules/norm_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class RMSNorm(nn.Module):
6 | def __init__(
7 | self,
8 | dim: int,
9 | elementwise_affine=True,
10 | eps: float = 1e-6,
11 | device=None,
12 | dtype=None,
13 | ):
14 | """
15 | Initialize the RMSNorm normalization layer.
16 |
17 | Args:
18 | dim (int): The dimension of the input tensor.
19 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
20 |
21 | Attributes:
22 | eps (float): A small value added to the denominator for numerical stability.
23 | weight (nn.Parameter): Learnable scaling parameter.
24 |
25 | """
26 | factory_kwargs = {"device": device, "dtype": dtype}
27 | super().__init__()
28 | self.eps = eps
29 | if elementwise_affine:
30 | self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
31 |
32 | def _norm(self, x):
33 | """
34 | Apply the RMSNorm normalization to the input tensor.
35 |
36 | Args:
37 | x (torch.Tensor): The input tensor.
38 |
39 | Returns:
40 | torch.Tensor: The normalized tensor.
41 |
42 | """
43 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
44 |
45 | def forward(self, x):
46 | """
47 | Forward pass through the RMSNorm layer.
48 |
49 | Args:
50 | x (torch.Tensor): The input tensor.
51 |
52 | Returns:
53 | torch.Tensor: The output tensor after applying RMSNorm.
54 |
55 | """
56 | output = self._norm(x.float()).type_as(x)
57 | if hasattr(self, "weight"):
58 | output = output * self.weight
59 | return output
60 |
61 |
62 | def get_norm_layer(norm_layer):
63 | """
64 | Get the normalization layer.
65 |
66 | Args:
67 | norm_layer (str): The type of normalization layer.
68 |
69 | Returns:
70 | norm_layer (nn.Module): The normalization layer.
71 | """
72 | if norm_layer == "layer":
73 | return nn.LayerNorm
74 | elif norm_layer == "rms":
75 | return RMSNorm
76 | else:
77 | raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
78 |
--------------------------------------------------------------------------------
/hunyuan/hyvideo/modules/token_refiner.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from einops import rearrange
4 | import torch
5 | import torch.nn as nn
6 |
7 | from .activation_layers import get_activation_layer
8 | from .attenion import attention
9 | from .norm_layers import get_norm_layer
10 | from .embed_layers import TimestepEmbedder, TextProjection
11 | from .attenion import attention
12 | from .mlp_layers import MLP
13 | from .modulate_layers import modulate, apply_gate
14 |
15 |
16 | class IndividualTokenRefinerBlock(nn.Module):
17 | def __init__(
18 | self,
19 | hidden_size,
20 | heads_num,
21 | mlp_width_ratio: float = 4.0,
22 | mlp_drop_rate: float = 0.0,
23 | act_type: str = "silu",
24 | qk_norm: bool = False,
25 | qk_norm_type: str = "layer",
26 | qkv_bias: bool = True,
27 | dtype: Optional[torch.dtype] = None,
28 | device: Optional[torch.device] = None,
29 | ):
30 | factory_kwargs = {"device": device, "dtype": dtype}
31 | super().__init__()
32 | self.heads_num = heads_num
33 | head_dim = hidden_size // heads_num
34 | mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
35 |
36 | self.norm1 = nn.LayerNorm(
37 | hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
38 | )
39 | self.self_attn_qkv = nn.Linear(
40 | hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
41 | )
42 | qk_norm_layer = get_norm_layer(qk_norm_type)
43 | self.self_attn_q_norm = (
44 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
45 | if qk_norm
46 | else nn.Identity()
47 | )
48 | self.self_attn_k_norm = (
49 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
50 | if qk_norm
51 | else nn.Identity()
52 | )
53 | self.self_attn_proj = nn.Linear(
54 | hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
55 | )
56 |
57 | self.norm2 = nn.LayerNorm(
58 | hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
59 | )
60 | act_layer = get_activation_layer(act_type)
61 | self.mlp = MLP(
62 | in_channels=hidden_size,
63 | hidden_channels=mlp_hidden_dim,
64 | act_layer=act_layer,
65 | drop=mlp_drop_rate,
66 | **factory_kwargs,
67 | )
68 |
69 | self.adaLN_modulation = nn.Sequential(
70 | act_layer(),
71 | nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
72 | )
73 | # Zero-initialize the modulation
74 | nn.init.zeros_(self.adaLN_modulation[1].weight)
75 | nn.init.zeros_(self.adaLN_modulation[1].bias)
76 |
77 | def forward(
78 | self,
79 | x: torch.Tensor,
80 | c: torch.Tensor, # timestep_aware_representations + context_aware_representations
81 | attn_mask: torch.Tensor = None,
82 | ):
83 | gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
84 |
85 | norm_x = self.norm1(x)
86 | qkv = self.self_attn_qkv(norm_x)
87 | q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
88 | # Apply QK-Norm if needed
89 | q = self.self_attn_q_norm(q).to(v)
90 | k = self.self_attn_k_norm(k).to(v)
91 |
92 | # Self-Attention
93 | attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
94 |
95 | x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
96 |
97 | # FFN Layer
98 | x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
99 |
100 | return x
101 |
102 |
103 | class IndividualTokenRefiner(nn.Module):
104 | def __init__(
105 | self,
106 | hidden_size,
107 | heads_num,
108 | depth,
109 | mlp_width_ratio: float = 4.0,
110 | mlp_drop_rate: float = 0.0,
111 | act_type: str = "silu",
112 | qk_norm: bool = False,
113 | qk_norm_type: str = "layer",
114 | qkv_bias: bool = True,
115 | dtype: Optional[torch.dtype] = None,
116 | device: Optional[torch.device] = None,
117 | ):
118 | factory_kwargs = {"device": device, "dtype": dtype}
119 | super().__init__()
120 | self.blocks = nn.ModuleList(
121 | [
122 | IndividualTokenRefinerBlock(
123 | hidden_size=hidden_size,
124 | heads_num=heads_num,
125 | mlp_width_ratio=mlp_width_ratio,
126 | mlp_drop_rate=mlp_drop_rate,
127 | act_type=act_type,
128 | qk_norm=qk_norm,
129 | qk_norm_type=qk_norm_type,
130 | qkv_bias=qkv_bias,
131 | **factory_kwargs,
132 | )
133 | for _ in range(depth)
134 | ]
135 | )
136 |
137 | def forward(
138 | self,
139 | x: torch.Tensor,
140 | c: torch.LongTensor,
141 | mask: Optional[torch.Tensor] = None,
142 | ):
143 | self_attn_mask = None
144 | if mask is not None:
145 | batch_size = mask.shape[0]
146 | seq_len = mask.shape[1]
147 | mask = mask.to(x.device)
148 | # batch_size x 1 x seq_len x seq_len
149 | self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
150 | 1, 1, seq_len, 1
151 | )
152 | # batch_size x 1 x seq_len x seq_len
153 | self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
154 | # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
155 | self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
156 | # avoids self-attention weight being NaN for padding tokens
157 | self_attn_mask[:, :, :, 0] = True
158 |
159 | for block in self.blocks:
160 | x = block(x, c, self_attn_mask)
161 | return x
162 |
163 |
164 | class SingleTokenRefiner(nn.Module):
165 | """
166 | A single token refiner block for llm text embedding refine.
167 | """
168 | def __init__(
169 | self,
170 | in_channels,
171 | hidden_size,
172 | heads_num,
173 | depth,
174 | mlp_width_ratio: float = 4.0,
175 | mlp_drop_rate: float = 0.0,
176 | act_type: str = "silu",
177 | qk_norm: bool = False,
178 | qk_norm_type: str = "layer",
179 | qkv_bias: bool = True,
180 | attn_mode: str = "torch",
181 | dtype: Optional[torch.dtype] = None,
182 | device: Optional[torch.device] = None,
183 | ):
184 | factory_kwargs = {"device": device, "dtype": dtype}
185 | super().__init__()
186 | self.attn_mode = attn_mode
187 |         assert self.attn_mode == "torch", "Only the 'torch' attention mode is supported for the token refiner."
188 |
189 | self.input_embedder = nn.Linear(
190 | in_channels, hidden_size, bias=True, **factory_kwargs
191 | )
192 |
193 | act_layer = get_activation_layer(act_type)
194 | # Build timestep embedding layer
195 | self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
196 | # Build context embedding layer
197 | self.c_embedder = TextProjection(
198 | in_channels, hidden_size, act_layer, **factory_kwargs
199 | )
200 |
201 | self.individual_token_refiner = IndividualTokenRefiner(
202 | hidden_size=hidden_size,
203 | heads_num=heads_num,
204 | depth=depth,
205 | mlp_width_ratio=mlp_width_ratio,
206 | mlp_drop_rate=mlp_drop_rate,
207 | act_type=act_type,
208 | qk_norm=qk_norm,
209 | qk_norm_type=qk_norm_type,
210 | qkv_bias=qkv_bias,
211 | **factory_kwargs,
212 | )
213 |
214 | def forward(
215 | self,
216 | x: torch.Tensor,
217 | t: torch.LongTensor,
218 | mask: Optional[torch.LongTensor] = None,
219 | ):
220 | timestep_aware_representations = self.t_embedder(t)
221 |
222 | if mask is None:
223 | context_aware_representations = x.mean(dim=1)
224 | else:
225 | mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
226 | context_aware_representations = (x * mask_float).sum(
227 | dim=1
228 | ) / mask_float.sum(dim=1)
229 | context_aware_representations = self.c_embedder(context_aware_representations)
230 | c = timestep_aware_representations + context_aware_representations
231 |
232 | x = self.input_embedder(x)
233 |
234 | x = self.individual_token_refiner(x, c, mask)
235 |
236 | return x
237 |
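A minimal, self-contained sketch of the padding-mask construction used in IndividualTokenRefiner.forward above; the toy batch size and sequence length are illustrative only.

import torch

# Toy per-token padding mask: 1 = real token, 0 = padding.
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]], dtype=torch.bool)
batch_size, seq_len = mask.shape

# Same pairwise construction as in IndividualTokenRefiner.forward.
m1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)  # which keys are valid
m2 = m1.transpose(2, 3)                                             # which queries are valid
self_attn_mask = (m1 & m2).bool()  # [B, 1, L, L]; the singleton dim broadcasts over heads
self_attn_mask[:, :, :, 0] = True  # keep one valid key per query so softmax never sees an all-False row
print(self_attn_mask[1, 0].int())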
--------------------------------------------------------------------------------
/hunyuan/hyvideo/prompt_rewrite.py:
--------------------------------------------------------------------------------
1 | normal_mode_prompt = """Normal mode - Video Recaption Task:
2 |
3 | You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description.
4 |
5 | 0. Preserve ALL information, including style words and technical terms.
6 |
7 | 1. If the input is in Chinese, translate the entire description to English.
8 |
9 | 2. If the input is just one or two words describing an object or person, provide a brief, simple description focusing on basic visual characteristics. Limit the description to 1-2 short sentences.
10 |
11 | 3. If the input does not include style, lighting, atmosphere, you can make reasonable associations.
12 |
13 | 4. Output ALL must be in English.
14 |
15 | Given Input:
16 | input: "{input}"
17 | """
18 |
19 |
20 | master_mode_prompt = """Master mode - Video Recaption Task:
21 |
22 | You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description.
23 |
24 | 0. Preserve ALL information, including style words and technical terms.
25 |
26 | 1. If the input is in Chinese, translate the entire description to English.
27 |
28 | 2. To generate high-quality visual scenes with aesthetic appeal, it is necessary to carefully depict each visual element to create a unique aesthetic.
29 |
30 | 3. If the input does not include style, lighting, atmosphere, you can make reasonable associations.
31 |
32 | 4. Output ALL must be in English.
33 |
34 | Given Input:
35 | input: "{input}"
36 | """
37 |
38 | def get_rewrite_prompt(ori_prompt, mode="Normal"):
39 | if mode == "Normal":
40 | prompt = normal_mode_prompt.format(input=ori_prompt)
41 | elif mode == "Master":
42 | prompt = master_mode_prompt.format(input=ori_prompt)
43 | else:
44 |         raise Exception("Only supports Normal and Master modes", mode)
45 | return prompt
46 |
47 | ori_prompt = "一只小狗在草地上奔跑。"
48 | normal_prompt = get_rewrite_prompt(ori_prompt, mode="Normal")
49 | master_prompt = get_rewrite_prompt(ori_prompt, mode="Master")
50 |
51 | # Then you can use the normal_prompt or master_prompt to access the hunyuan-large rewrite model to get the final prompt.
52 |
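A minimal sketch of the final step described in the comment above; `your_rewrite_client` is a placeholder for whatever interface reaches the hunyuan-large rewrite model and is not part of this repository.

caption = "A dog is running on the grass."
for mode in ("Normal", "Master"):
    rewrite_request = get_rewrite_prompt(caption, mode=mode)
    # final_prompt = your_rewrite_client(rewrite_request)  # hypothetical client call
    print(mode, rewrite_request[:60], "...")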
--------------------------------------------------------------------------------
/hunyuan/hyvideo/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan/hyvideo/utils/__init__.py
--------------------------------------------------------------------------------
/hunyuan/hyvideo/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 |
4 |
5 | def align_to(value, alignment):
6 |     """align height, width according to alignment
7 |
8 | Args:
9 | value (int): height or width
10 | alignment (int): target alignment factor
11 |
12 | Returns:
13 | int: the aligned value
14 | """
15 | return int(math.ceil(value / alignment) * alignment)
16 |
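A minimal usage sketch for align_to; the alignment factor of 16 is illustrative.

print(align_to(720, 16), align_to(1282, 16))  # 720 1296: both snapped up to multiples of 16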
--------------------------------------------------------------------------------
/hunyuan/hyvideo/utils/file_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from einops import rearrange
4 |
5 | import torch
6 | import torchvision
7 | import numpy as np
8 | import imageio
9 |
10 | CODE_SUFFIXES = {
11 |     ".py",  # Python code
12 | ".sh", # Shell scripts
13 | ".yaml",
14 | ".yml", # Configuration files
15 | }
16 |
17 |
18 | def safe_dir(path):
19 | """
20 | Create a directory (or the parent directory of a file) if it does not exist.
21 |
22 | Args:
23 | path (str or Path): Path to the directory.
24 |
25 | Returns:
26 | path (Path): Path object of the directory.
27 | """
28 | path = Path(path)
29 | path.mkdir(exist_ok=True, parents=True)
30 | return path
31 |
32 |
33 | def safe_file(path):
34 | """
35 | Create the parent directory of a file if it does not exist.
36 |
37 | Args:
38 | path (str or Path): Path to the file.
39 |
40 | Returns:
41 | path (Path): Path object of the file.
42 | """
43 | path = Path(path)
44 | path.parent.mkdir(exist_ok=True, parents=True)
45 | return path
46 |
47 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24):
48 |     """Save a grid of videos from a video tensor.
49 |     Copied from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61
50 |
51 | Args:
52 | videos (torch.Tensor): video tensor predicted by the model
53 | path (str): path to save video
54 |         rescale (bool, optional): rescale the video tensor from [-1, 1] to [0, 1]. Defaults to False.
55 |         n_rows (int, optional): number of videos per row in the grid. Defaults to 1.
56 |         fps (int, optional): video save fps. Defaults to 24.
57 | """
58 | videos = rearrange(videos, "b c t h w -> t b c h w")
59 | outputs = []
60 | for x in videos:
61 | x = torchvision.utils.make_grid(x, nrow=n_rows)
62 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
63 | if rescale:
64 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
65 | x = torch.clamp(x, 0, 1)
66 | x = (x * 255).numpy().astype(np.uint8)
67 | outputs.append(x)
68 |
69 | os.makedirs(os.path.dirname(path), exist_ok=True)
70 | imageio.mimsave(path, outputs, fps=fps)
71 |
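A minimal usage sketch for save_videos_grid, assuming an imageio backend with mp4 support is installed; the tensor is random data in the (batch, channels, frames, height, width) layout and [0, 1] value range the helper expects when rescale=False.

import torch

dummy = torch.rand(1, 3, 16, 64, 64)  # (b, c, t, h, w), values already in [0, 1]
save_videos_grid(dummy, "./z-results/dummy.mp4", rescale=False, n_rows=1, fps=24)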
--------------------------------------------------------------------------------
/hunyuan/hyvideo/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import collections.abc
2 |
3 | from itertools import repeat
4 |
5 |
6 | def _ntuple(n):
7 | def parse(x):
8 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
9 | x = tuple(x)
10 | if len(x) == 1:
11 | x = tuple(repeat(x[0], n))
12 | return x
13 | return tuple(repeat(x, n))
14 | return parse
15 |
16 |
17 | to_1tuple = _ntuple(1)
18 | to_2tuple = _ntuple(2)
19 | to_3tuple = _ntuple(3)
20 | to_4tuple = _ntuple(4)
21 |
22 |
23 | def as_tuple(x):
24 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
25 | return tuple(x)
26 | if x is None or isinstance(x, (int, float, str)):
27 | return (x,)
28 | else:
29 | raise ValueError(f"Unknown type {type(x)}")
30 |
31 |
32 | def as_list_of_2tuple(x):
33 | x = as_tuple(x)
34 | if len(x) == 1:
35 | x = (x[0], x[0])
36 | assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
37 | lst = []
38 | for i in range(0, len(x), 2):
39 | lst.append((x[i], x[i + 1]))
40 | return lst
41 |
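A minimal usage sketch for the tuple helpers above.

print(to_2tuple(3))                     # (3, 3)
print(to_3tuple((4, 8, 8)))             # (4, 8, 8) -- already the right length, returned unchanged
print(as_list_of_2tuple(5))             # [(5, 5)]
print(as_list_of_2tuple((2, 4, 6, 8)))  # [(2, 4), (6, 8)]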
--------------------------------------------------------------------------------
/hunyuan/hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from transformers import (
4 | AutoProcessor,
5 | LlavaForConditionalGeneration,
6 | )
7 |
8 |
9 | def preprocess_text_encoder_tokenizer(args):
10 |
11 | processor = AutoProcessor.from_pretrained(args.input_dir)
12 | model = LlavaForConditionalGeneration.from_pretrained(
13 | args.input_dir,
14 | torch_dtype=torch.float16,
15 | low_cpu_mem_usage=True,
16 | ).to(0)
17 |
18 | model.language_model.save_pretrained(
19 | f"{args.output_dir}"
20 | )
21 | processor.tokenizer.save_pretrained(
22 | f"{args.output_dir}"
23 | )
24 |
25 | if __name__ == "__main__":
26 |
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument(
29 | "--input_dir",
30 | type=str,
31 | required=True,
32 | help="The path to the llava-llama-3-8b-v1_1-transformers.",
33 | )
34 | parser.add_argument(
35 | "--output_dir",
36 | type=str,
37 | default="",
38 |         help="The output path of the llava-llama-3-8b-text-encoder-tokenizer. "
39 |         "If empty, the parent directory of the input dir is used as the output dir.",
40 | )
41 | args = parser.parse_args()
42 |
43 | if len(args.output_dir) == 0:
44 | args.output_dir = "/".join(args.input_dir.split("/")[:-1])
45 |
46 | preprocess_text_encoder_tokenizer(args)
47 |
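A minimal sketch of calling the preprocessing step from Python instead of the command line; both paths are placeholders, and a CUDA device is required because the model is moved to GPU 0.

from types import SimpleNamespace

args = SimpleNamespace(
    input_dir="/path/to/llava-llama-3-8b-v1_1-transformers",  # placeholder path
    output_dir="/path/to/text_encoder",                       # placeholder path
)
preprocess_text_encoder_tokenizer(args)  # writes the language model and tokenizer to output_dir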
--------------------------------------------------------------------------------
/hunyuan/hyvideo/vae/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 |
5 | from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
6 | from ..constants import VAE_PATH, PRECISION_TO_TYPE
7 |
8 | def load_vae(vae_type: str="884-16c-hy",
9 | vae_precision: str=None,
10 | sample_size: tuple=None,
11 | vae_path: str=None,
12 | logger=None,
13 | device=None
14 | ):
15 |     """The function to load the 3D VAE model
16 |
17 | Args:
18 | vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy".
19 | vae_precision (str, optional): the precision to load vae. Defaults to None.
20 | sample_size (tuple, optional): the tiling size. Defaults to None.
21 | vae_path (str, optional): the path to vae. Defaults to None.
22 | logger (_type_, optional): logger. Defaults to None.
23 | device (_type_, optional): device to load vae. Defaults to None.
24 | """
25 | if vae_path is None:
26 | vae_path = VAE_PATH[vae_type]
27 |
28 | if logger is not None:
29 | logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")
30 | config = AutoencoderKLCausal3D.load_config(vae_path)
31 | if sample_size:
32 | vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
33 | else:
34 | vae = AutoencoderKLCausal3D.from_config(config)
35 |
36 | vae_ckpt = Path(vae_path) / "pytorch_model.pt"
37 | assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}"
38 |
39 | ckpt = torch.load(vae_ckpt, map_location=vae.device)
40 | if "state_dict" in ckpt:
41 | ckpt = ckpt["state_dict"]
42 | if any(k.startswith("vae.") for k in ckpt.keys()):
43 | ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
44 | vae.load_state_dict(ckpt)
45 |
46 | spatial_compression_ratio = vae.config.spatial_compression_ratio
47 | time_compression_ratio = vae.config.time_compression_ratio
48 |
49 | if vae_precision is not None:
50 | vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision])
51 |
52 | vae.requires_grad_(False)
53 |
54 | if logger is not None:
55 | logger.info(f"VAE to dtype: {vae.dtype}")
56 |
57 | if device is not None:
58 | vae = vae.to(device)
59 |
60 | vae.eval()
61 |
62 | return vae, vae_path, spatial_compression_ratio, time_compression_ratio
63 |
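A minimal usage sketch for load_vae, assuming MODEL_BASE is configured so that VAE_PATH["884-16c-hy"] resolves to a local checkpoint directory; the precision and device are illustrative.

import torch

vae, vae_path, spatial_ratio, time_ratio = load_vae(
    vae_type="884-16c-hy",
    vae_precision="fp16",
    device=torch.device("cuda"),
)
print(vae_path, spatial_ratio, time_ratio)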
--------------------------------------------------------------------------------
/hunyuan/run-single-sample_video-fp8.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | export MODEL_BASE="/mnt/localssd/hunyuan-video/ckpts/"
4 | #dit_weight="/mnt/localssd/hunyuan-video/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt"
5 | seed=42
6 |
7 | dit_weight="/mnt/localssd/hunyuan-video/fp8/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt"
8 |
9 | python3 sample_video.py \
10 | --model-base $MODEL_BASE \
11 | --dit-weight $dit_weight \
12 | --seed $seed \
13 | --video-size 768 1280 \
14 | --video-length 129 \
15 | --infer-steps 50 \
16 | --prompt "A cat walks on the grass, realistic style." \
17 | --flow-reverse \
18 | --use-cpu-offload \
19 | --use-fp8 \
20 | --save-path ./z-results
21 |
--------------------------------------------------------------------------------
/hunyuan/run-single-sample_video.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | export MODEL_BASE="/mnt/localssd/hunyuan-video/ckpts/"
4 | dit_weight="/mnt/localssd/hunyuan-video/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt"
5 | seed=42
6 |
7 | python3 sample_video.py \
8 | --model-base $MODEL_BASE \
9 | --dit-weight $dit_weight \
10 | --seed $seed \
11 | --video-size 768 1280 \
12 | --video-length 129 \
13 | --infer-steps 50 \
14 | --prompt "A cat walks on the grass, realistic style." \
15 | --flow-reverse \
16 | --use-cpu-offload \
17 | --save-path ./z-results
18 |
--------------------------------------------------------------------------------
/hunyuan/sample_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from pathlib import Path
4 | from loguru import logger
5 | from datetime import datetime
6 |
7 | from hyvideo.utils.file_utils import save_videos_grid
8 | from hyvideo.config import parse_args
9 | from hyvideo.inference import HunyuanVideoSampler
10 |
11 |
12 | def main():
13 | args = parse_args()
14 | print(args)
15 | models_root_path = Path(args.model_base)
16 | if not models_root_path.exists():
17 |         raise ValueError(f"`models_root` does not exist: {models_root_path}")
18 |
19 | # Create save folder to save the samples
20 | save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}'
21 | if not os.path.exists(save_path):
22 | os.makedirs(save_path, exist_ok=True)
23 |
24 | # Load models
25 | hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
26 |
27 | # Get the updated args
28 | args = hunyuan_video_sampler.args
29 |
30 | # Start sampling
31 | # TODO: batch inference check
32 | outputs = hunyuan_video_sampler.predict(
33 | prompt=args.prompt,
34 | height=args.video_size[0],
35 | width=args.video_size[1],
36 | video_length=args.video_length,
37 | seed=args.seed,
38 | negative_prompt=args.neg_prompt,
39 | infer_steps=args.infer_steps,
40 | guidance_scale=args.cfg_scale,
41 | num_videos_per_prompt=1,#args.num_videos,
42 | flow_shift=args.flow_shift,
43 | batch_size=args.batch_size,
44 | embedded_guidance_scale=args.embedded_cfg_scale
45 | )
46 | samples = outputs['samples']
47 |
48 | # Save samples
49 | if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0:
50 | for i, sample in enumerate(samples):
51 | sample = samples[i].unsqueeze(0)
52 | time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
53 | cur_save_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4"
54 | save_videos_grid(sample, cur_save_path, fps=24)
55 |             logger.info(f'Sample saved to: {cur_save_path}')
56 |
57 | if __name__ == "__main__":
58 | main()
59 |
--------------------------------------------------------------------------------
/hunyuan_custom/assets/images/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/method.png
--------------------------------------------------------------------------------
/hunyuan_custom/assets/images/poodle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/poodle.png
--------------------------------------------------------------------------------
/hunyuan_custom/assets/images/seg_boy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/seg_boy.png
--------------------------------------------------------------------------------
/hunyuan_custom/assets/images/seg_man_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/seg_man_01.png
--------------------------------------------------------------------------------
/hunyuan_custom/assets/images/seg_man_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/seg_man_02.png
--------------------------------------------------------------------------------
/hunyuan_custom/assets/images/seg_man_03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/seg_man_03.png
--------------------------------------------------------------------------------
/hunyuan_custom/assets/images/seg_poodle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/seg_poodle.png
--------------------------------------------------------------------------------
/hunyuan_custom/assets/images/seg_woman_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/seg_woman_01.png
--------------------------------------------------------------------------------
/hunyuan_custom/assets/images/seg_woman_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/seg_woman_02.png
--------------------------------------------------------------------------------
/hunyuan_custom/assets/images/seg_woman_03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/seg_woman_03.png
--------------------------------------------------------------------------------
/hunyuan_custom/assets/material/application.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/material/application.png
--------------------------------------------------------------------------------
/hunyuan_custom/assets/material/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/material/logo.png
--------------------------------------------------------------------------------
/hunyuan_custom/assets/material/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/material/method.png
--------------------------------------------------------------------------------
/hunyuan_custom/assets/material/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/material/teaser.png
--------------------------------------------------------------------------------
/hunyuan_custom/assets/meta_files.list:
--------------------------------------------------------------------------------
1 | assets/meta_files/poodle.json
--------------------------------------------------------------------------------
/hunyuan_custom/assets/meta_files/poodle.json:
--------------------------------------------------------------------------------
1 | {
2 | "item_image_path": "assets/images/poodle.png",
3 | "item_prompt": "dog",
4 | "seed": 1124,
5 | "prompt": "A dog is chasing a cat in the park.",
6 | "negative_prompt": "",
7 | "chinese": "一只狗在公园里追一只猫。",
8 | "seg_item_image_path": "assets/images/seg_poodle.png"
9 | }
10 |
--------------------------------------------------------------------------------
/hunyuan_custom/assets/videos/seg_man_01.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/videos/seg_man_01.mp4
--------------------------------------------------------------------------------
/hunyuan_custom/assets/videos/seg_man_02.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/videos/seg_man_02.mp4
--------------------------------------------------------------------------------
/hunyuan_custom/assets/videos/seg_woman_01.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/videos/seg_woman_01.mp4
--------------------------------------------------------------------------------
/hunyuan_custom/assets/videos/seg_woman_03.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/videos/seg_woman_03.mp4
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_gradio/gradio_ref2v.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import glob
4 | import json
5 | import datetime
6 | import requests
7 | import gradio as gr
8 | from tool_for_end2end import *
9 |
10 | os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
11 | DATADIR = './temp'
12 | _HEADER_ = '''
13 |
14 | Tencent HunyuanVideoCustom Demo
15 |
16 |
17 | '''
18 | # flask url
19 | URL = "http://127.0.0.1:8080/predict2"
20 |
21 | def post_and_get(width, height, num_steps, num_frames, guidance, flow_shift, seed, prompt, id_image, neg_prompt, template_prompt, template_neg_prompt, name):
22 | now = datetime.datetime.now().isoformat()
23 | imgdir = os.path.join(DATADIR, 'reference')
24 | videodir = os.path.join(DATADIR, 'video')
25 | imgfile = os.path.join(imgdir, now + '.png')
26 | output_video_path = os.path.join(videodir, now + '.mp4')
27 |
28 | os.makedirs(imgdir, exist_ok=True)
29 | os.makedirs(videodir, exist_ok=True)
30 | cv2.imwrite(imgfile, id_image[:,:,::-1])
31 |
32 | proxies = {
33 | "http": None,
34 | "https": None,
35 | }
36 |
37 | files = {
38 | "trace_id": "abcd",
39 | "image_path": imgfile,
40 | "prompt": prompt,
41 | "negative_prompt": neg_prompt,
42 | "template_prompt": template_prompt,
43 | "template_neg_prompt": template_neg_prompt,
44 | "height": height,
45 | "width": width,
46 | "frames": num_frames,
47 | "cfg": guidance,
48 | "steps": num_steps,
49 | "seed": int(seed),
50 | "name": name,
51 | "shift": flow_shift,
52 | "save_fps": 25,
53 | }
54 | r = requests.get(URL, data = json.dumps(files), proxies=proxies)
55 | ret_dict = json.loads(r.text)
56 | video_buffer = ret_dict['content'][0]['buffer']
57 | save_video_base64_to_local(video_path=None, base64_buffer=video_buffer,
58 | output_video_path=output_video_path)
59 | print('='*50)
60 | return output_video_path
61 |
62 | def create_demo():
63 |
64 | with gr.Blocks() as demo:
65 | gr.Markdown(_HEADER_)
66 | with gr.Tab('单主体一致性'):
67 | with gr.Row():
68 | with gr.Column(scale=1):
69 | with gr.Group():
70 | prompt = gr.Textbox(label="Prompt", value="a man is riding a bicycle on the street.")
71 | neg_prompt = gr.Textbox(label="Negative Prompt", value="")
72 | id_image = gr.Image(label="Input reference image", height=480)
73 |
74 | with gr.Column(scale=2):
75 | with gr.Group():
76 | output_image = gr.Video(label="Generated Video")
77 |
78 | with gr.Row():
79 | with gr.Column(scale=2):
80 |                 with gr.Accordion("Options for video generation", open=False):
81 | with gr.Row():
82 | width = gr.Slider(256, 1536, 1280, step=16, label="Width")
83 | height = gr.Slider(256, 1536, 720, step=16, label="Height")
84 | with gr.Row():
85 | num_steps = gr.Slider(1, 100, 30, step=5, label="Number of steps")
86 | flow_shift = gr.Slider(1.0, 15.0, 13, step=1, label="Flow Shift")
87 | with gr.Row():
88 | num_frames = gr.Slider(1, 129, 129, step=4, label="Number of frames")
89 | guidance = gr.Slider(1.0, 10.0, 7.5, step=0.5, label="Guidance")
90 | seed = gr.Textbox(1024, label="Seed (-1 for random)")
91 | with gr.Row():
92 | template_prompt = gr.Textbox(label="Template Prompt", value="Realistic, High-quality. ")
93 | template_neg_prompt = gr.Textbox(label="Template Negative Prompt", value="Aerial view, aerial view, " \
94 | "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, " \
95 | "distortion, blurring, text, subtitles, static, picture, black border. ")
96 | name = gr.Textbox(label="Object Name", value="object ")
97 | with gr.Column(scale=1):
98 | generate_btn = gr.Button("Generate")
99 |
100 | generate_btn.click(fn=post_and_get,
101 | inputs=[width, height, num_steps, num_frames, guidance, flow_shift, seed, prompt, id_image, neg_prompt, template_prompt, template_neg_prompt, name],
102 | outputs=[output_image],
103 | )
104 |
105 | quick_prompts = [[x] for x in glob.glob('./assets/images/*.png')]
106 | example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Other object', samples_per_page=1000, components=[id_image])
107 | example_quick_prompts.click(lambda x: x[0], inputs=example_quick_prompts, outputs=id_image, show_progress=False, queue=False)
108 | with gr.Row(), gr.Column():
109 | gr.Markdown("## Examples")
110 | example_inps = [
111 | [
112 | 'A woman is drinking coffee at a café.',
113 | './assets/images/seg_woman_01.png',
114 | 1280, 720, 30, 129, 7.5, 13, 1024,
115 | "assets/videos/seg_woman_01.mp4"
116 | ],
117 | [
118 | 'In a cubicle of an office building, a woman focuses intently on the computer screen, typing rapidly on the keyboard, surrounded by piles of documents.',
119 | './assets/images/seg_woman_03.png',
120 | 1280, 720, 30, 129, 7.5, 13, 1025,
121 | "./assets/videos/seg_woman_03.mp4"
122 | ],
123 | [
124 | 'A man walks across an ancient stone bridge holding an umbrella, raindrops tapping against it.',
125 | './assets/images/seg_man_01.png',
126 | 1280, 720, 30, 129, 7.5, 13, 1025,
127 | "./assets/videos/seg_man_01.mp4"
128 | ],
129 | [
130 | 'During a train journey, a man admires the changing scenery through the window.',
131 | './assets/images/seg_man_02.png',
132 | 1280, 720, 30, 129, 7.5, 13, 1026,
133 | "./assets/videos/seg_man_02.mp4"
134 | ]
135 | ]
136 | gr.Examples(examples=example_inps, inputs=[prompt, id_image, width, height, num_steps, num_frames, guidance, flow_shift, seed, output_image],)
137 | return demo
138 |
139 | if __name__ == "__main__":
140 | allowed_paths = ['/']
141 | demo = create_demo()
142 | demo.launch(server_name='0.0.0.0', server_port=80, share=True, allowed_paths=allowed_paths)
143 |
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_gradio/tool_for_end2end.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | import uuid
4 | import base64
5 | import imageio
6 | import torch
7 | import torchvision
8 | from PIL import Image
9 | import numpy as np
10 | from copy import deepcopy
11 | from einops import rearrange
12 |
13 | TEMP_DIR = "./temp"
14 | if not os.path.exists(TEMP_DIR):
15 | os.makedirs(TEMP_DIR, exist_ok=True)
16 |
17 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, quality=8):
18 | videos = rearrange(videos, "b c t h w -> t b c h w")
19 | outputs = []
20 | for x in videos:
21 | x = torchvision.utils.make_grid(x, nrow=n_rows)
22 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
23 | if rescale:
24 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
25 | x = torch.clamp(x,0,1)
26 | x = (x * 255).numpy().astype(np.uint8)
27 | outputs.append(x)
28 |
29 | os.makedirs(os.path.dirname(path), exist_ok=True)
30 | imageio.mimsave(path, outputs, fps=fps, quality=quality)
31 |
32 | def encode_image_to_base64(image_path):
33 | try:
34 | with open(image_path, 'rb') as image_file:
35 | image_data = image_file.read()
36 | encoded_data = base64.b64encode(image_data).decode('utf-8')
37 | print(f"Image file '{image_path}' has been successfully encoded to Base64.")
38 | return encoded_data
39 |
40 | except Exception as e:
41 | print(f"Error encoding image: {e}")
42 | return None
43 |
44 | def encode_video_to_base64(video_path):
45 | try:
46 | with open(video_path, 'rb') as video_file:
47 | video_data = video_file.read()
48 | encoded_data = base64.b64encode(video_data).decode('utf-8')
49 | print(f"Video file '{video_path}' has been successfully encoded to Base64.")
50 | return encoded_data
51 |
52 | except Exception as e:
53 | print(f"Error encoding video: {e}")
54 | return None
55 |
56 | def encode_wav_to_base64(wav_path):
57 | try:
58 | with open(wav_path, 'rb') as audio_file:
59 | audio_data = audio_file.read()
60 | encoded_data = base64.b64encode(audio_data).decode('utf-8')
61 | print(f"Audio file '{wav_path}' has been successfully encoded to Base64.")
62 | return encoded_data
63 |
64 | except Exception as e:
65 | print(f"Error encoding audio: {e}")
66 | return None
67 |
68 | def encode_pkl_to_base64(pkl_path):
69 | try:
70 | with open(pkl_path, 'rb') as pkl_file:
71 | pkl_data = pkl_file.read()
72 |
73 | encoded_data = base64.b64encode(pkl_data).decode('utf-8')
74 |
75 | print(f"Pickle file '{pkl_path}' has been successfully encoded to Base64.")
76 | return encoded_data
77 |
78 | except Exception as e:
79 | print(f"Error encoding pickle: {e}")
80 | return None
81 |
82 | def decode_base64_to_image(base64_buffer_str):
83 | try:
84 | image_data = base64.b64decode(base64_buffer_str)
85 | image = Image.open(io.BytesIO(image_data))
86 | image_array = np.array(image)
87 |         print(f"Image Base64 string has been successfully decoded to image.")
88 | return image_array
89 | except Exception as e:
90 |         print(f"Error decoding image: {e}")
91 | return None
92 |
93 | def decode_base64_to_video(base64_buffer_str):
94 | try:
95 | video_data = base64.b64decode(base64_buffer_str)
96 | video_bytes = io.BytesIO(video_data)
97 | video_bytes.seek(0)
98 | video_reader = imageio.get_reader(video_bytes, 'ffmpeg')
99 | video_frames = [frame for frame in video_reader]
100 | return video_frames
101 | except Exception as e:
102 | print(f"Error decoding video: {e}")
103 | return None
104 |
105 | def save_image_base64_to_local(image_path=None, base64_buffer=None):
106 | if image_path is not None and base64_buffer is None:
107 | image_buffer_base64 = encode_image_to_base64(image_path)
108 | elif image_path is None and base64_buffer is not None:
109 | image_buffer_base64 = deepcopy(base64_buffer)
110 | else:
111 | print("Please pass either 'image_path' or 'base64_buffer'")
112 | return None
113 |
114 | if image_buffer_base64 is not None:
115 | image_data = base64.b64decode(image_buffer_base64)
116 | uuid_string = str(uuid.uuid4())
117 | temp_image_path = f'{TEMP_DIR}/{uuid_string}.png'
118 | with open(temp_image_path, 'wb') as image_file:
119 | image_file.write(image_data)
120 | return temp_image_path
121 | else:
122 | return None
123 |
124 | def save_video_base64_to_local(video_path=None, base64_buffer=None, output_video_path=None):
125 | if video_path is not None and base64_buffer is None:
126 | video_buffer_base64 = encode_video_to_base64(video_path)
127 | elif video_path is None and base64_buffer is not None:
128 | video_buffer_base64 = deepcopy(base64_buffer)
129 | else:
130 | print("Please pass either 'video_path' or 'base64_buffer'")
131 | return None
132 |
133 | if video_buffer_base64 is not None:
134 | video_data = base64.b64decode(video_buffer_base64)
135 | if output_video_path is None:
136 | uuid_string = str(uuid.uuid4())
137 | temp_video_path = f'{TEMP_DIR}/{uuid_string}.mp4'
138 | else:
139 | temp_video_path = output_video_path
140 | with open(temp_video_path, 'wb') as video_file:
141 | video_file.write(video_data)
142 | return temp_video_path
143 | else:
144 | return None
145 |
146 | def save_audio_base64_to_local(audio_path=None, base64_buffer=None):
147 | if audio_path is not None and base64_buffer is None:
148 | audio_buffer_base64 = encode_wav_to_base64(audio_path)
149 | elif audio_path is None and base64_buffer is not None:
150 | audio_buffer_base64 = deepcopy(base64_buffer)
151 | else:
152 | print("Please pass either 'audio_path' or 'base64_buffer'")
153 | return None
154 |
155 | if audio_buffer_base64 is not None:
156 | audio_data = base64.b64decode(audio_buffer_base64)
157 | uuid_string = str(uuid.uuid4())
158 | temp_audio_path = f'{TEMP_DIR}/{uuid_string}.wav'
159 | with open(temp_audio_path, 'wb') as audio_file:
160 | audio_file.write(audio_data)
161 | return temp_audio_path
162 | else:
163 | return None
164 |
165 | def save_pkl_base64_to_local(pkl_path=None, base64_buffer=None):
166 | if pkl_path is not None and base64_buffer is None:
167 | pkl_buffer_base64 = encode_pkl_to_base64(pkl_path)
168 | elif pkl_path is None and base64_buffer is not None:
169 | pkl_buffer_base64 = deepcopy(base64_buffer)
170 | else:
171 | print("Please pass either 'pkl_path' or 'base64_buffer'")
172 | return None
173 |
174 | if pkl_buffer_base64 is not None:
175 | pkl_data = base64.b64decode(pkl_buffer_base64)
176 | uuid_string = str(uuid.uuid4())
177 | temp_pkl_path = f'{TEMP_DIR}/{uuid_string}.pkl'
178 | with open(temp_pkl_path, 'wb') as pkl_file:
179 | pkl_file.write(pkl_data)
180 | return temp_pkl_path
181 | else:
182 | return None
183 |
184 | def remove_temp_fles(input_dict):
185 | for key, val in input_dict.items():
186 | if "_path" in key and val is not None and os.path.exists(val):
187 | os.remove(val)
188 | print(f"Remove temporary {key} from {val}")
189 |
190 | def process_output_dict(output_dict):
191 | if output_dict["rank"] == 0:
192 | uuid_string = str(uuid.uuid4())
193 | temp_video_path = f'{TEMP_DIR}/{uuid_string}.mp4'
194 | save_videos_grid(output_dict["video"], temp_video_path, fps=25)
195 | save_path = temp_video_path
196 |
197 | video_base64_buffer = encode_video_to_base64(save_path)
198 |
199 | encoded_output_dict = {
200 | "errCode": output_dict["err_code"],
201 | "content": [
202 | {
203 | "buffer": video_base64_buffer
204 | },
205 | ],
206 | "info":output_dict["err_msg"],
207 | "trace_id": output_dict["trace_id"],
208 | }
209 |
210 | else:
211 | uuid_string = str(uuid.uuid4())
212 | temp_video_path = f'{TEMP_DIR}/{uuid_string}.mp4'
213 |
214 | try:
215 | video_base64_buffer = encode_video_to_base64(temp_video_path)
216 | except:
217 | video_base64_buffer = None
218 |
219 | encoded_output_dict = {
220 | "errCode": output_dict["err_code"],
221 | "content": [
222 | {
223 | "buffer": video_base64_buffer
224 | },
225 | ],
226 | "info":output_dict["err_msg"],
227 | "trace_id": output_dict["trace_id"],
228 | }
229 |
230 | return encoded_output_dict
231 |
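A minimal round-trip sketch for the base64 helpers above; "example.png" is a placeholder for any local image, and the restored copy lands under TEMP_DIR with a fresh UUID name.

buffer64 = encode_image_to_base64("example.png")  # placeholder path
if buffer64 is not None:
    restored_path = save_image_base64_to_local(base64_buffer=buffer64)
    print("Restored copy written to:", restored_path)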
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/hymm_sp/__init__.py
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from hymm_sp.constants import *
3 | import re
4 | import collections.abc
5 |
6 | def as_tuple(x):
7 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
8 | return tuple(x)
9 | if x is None or isinstance(x, (int, float, str)):
10 | return (x,)
11 | else:
12 | raise ValueError(f"Unknown type {type(x)}")
13 |
14 | def parse_args(namespace=None):
15 | parser = argparse.ArgumentParser(description="Hunyuan Multimodal training/inference script")
16 | parser = add_extra_args(parser)
17 | args = parser.parse_args(namespace=namespace)
18 | args = sanity_check_args(args)
19 | return args
20 |
21 | def add_extra_args(parser: argparse.ArgumentParser):
22 | parser = add_network_args(parser)
23 | parser = add_extra_models_args(parser)
24 | parser = add_denoise_schedule_args(parser)
25 | parser = add_evaluation_args(parser)
26 | return parser
27 |
28 | def add_network_args(parser: argparse.ArgumentParser):
29 | group = parser.add_argument_group(title="Network")
30 | group.add_argument("--model", type=str, default="HYVideo-T/2",
31 |                        help="Model architecture to use. It is also used to determine the experiment directory.")
32 | group.add_argument("--latent-channels", type=str, default=None,
33 | help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, "
34 | "it still needs to match the latent channels of the VAE model.")
35 | group.add_argument("--rope-theta", type=int, default=256, help="Theta used in RoPE.")
36 | return parser
37 |
38 | def add_extra_models_args(parser: argparse.ArgumentParser):
39 | group = parser.add_argument_group(title="Extra Models (VAE, Text Encoder, Tokenizer)")
40 |
41 | # VAE
42 | group.add_argument("--vae", type=str, default="884-16c-hy0801", help="Name of the VAE model.")
43 | group.add_argument("--vae-precision", type=str, default="fp16",
44 | help="Precision mode for the VAE model.")
45 | group.add_argument("--vae-tiling", action="store_true", default=True, help="Enable tiling for the VAE model.")
46 | group.add_argument("--text-encoder", type=str, default="llava-llama-3-8b", choices=list(TEXT_ENCODER_PATH),
47 | help="Name of the text encoder model.")
48 | group.add_argument("--text-encoder-precision", type=str, default="fp16", choices=PRECISIONS,
49 | help="Precision mode for the text encoder model.")
50 | group.add_argument("--text-states-dim", type=int, default=4096, help="Dimension of the text encoder hidden states.")
51 | group.add_argument("--text-len", type=int, default=256, help="Maximum length of the text input.")
52 | group.add_argument("--tokenizer", type=str, default="llava-llama-3-8b", choices=list(TOKENIZER_PATH),
53 | help="Name of the tokenizer model.")
54 | group.add_argument("--text-encoder-infer-mode", type=str, default="encoder", choices=["encoder", "decoder"],
55 | help="Inference mode for the text encoder model. It should match the text encoder type. T5 and "
56 | "CLIP can only work in 'encoder' mode, while Llava/GLM can work in both modes.")
57 | group.add_argument("--prompt-template-video", type=str, default='li-dit-encode-video', choices=PROMPT_TEMPLATE,
58 | help="Video prompt template for the decoder-only text encoder model.")
59 | group.add_argument("--hidden-state-skip-layer", type=int, default=2,
60 | help="Skip layer for hidden states.")
61 | group.add_argument("--apply-final-norm", action="store_true",
62 | help="Apply final normalization to the used text encoder hidden states.")
63 |
64 | # - CLIP
65 | group.add_argument("--text-encoder-2", type=str, default='clipL', choices=list(TEXT_ENCODER_PATH),
66 | help="Name of the second text encoder model.")
67 | group.add_argument("--text-encoder-precision-2", type=str, default="fp16", choices=PRECISIONS,
68 | help="Precision mode for the second text encoder model.")
69 | group.add_argument("--text-states-dim-2", type=int, default=768,
70 | help="Dimension of the second text encoder hidden states.")
71 | group.add_argument("--tokenizer-2", type=str, default='clipL', choices=list(TOKENIZER_PATH),
72 | help="Name of the second tokenizer model.")
73 | group.add_argument("--text-len-2", type=int, default=77, help="Maximum length of the second text input.")
74 | group.set_defaults(use_attention_mask=True)
75 | group.add_argument("--text-projection", type=str, default="single_refiner", choices=TEXT_PROJECTION,
76 | help="A projection layer for bridging the text encoder hidden states and the diffusion model "
77 | "conditions.")
78 | return parser
79 |
80 |
81 | def add_denoise_schedule_args(parser: argparse.ArgumentParser):
82 | group = parser.add_argument_group(title="Denoise schedule")
83 | group.add_argument("--flow-shift-eval-video", type=float, default=None, help="Shift factor for flow matching schedulers when using video data.")
84 | group.add_argument("--flow-reverse", action="store_true", default=True, help="If reverse, learning/sampling from t=1 -> t=0.")
85 | group.add_argument("--flow-solver", type=str, default="euler", help="Solver for flow matching.")
86 |     group.add_argument("--use-linear-quadratic-schedule", action="store_true", help="Use linear quadratic schedule for flow matching. "
87 | "Follow MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)")
88 | group.add_argument("--linear-schedule-end", type=int, default=25, help="End step for linear quadratic schedule for flow matching.")
89 | return parser
90 |
91 | def add_evaluation_args(parser: argparse.ArgumentParser):
92 | group = parser.add_argument_group(title="Validation Loss Evaluation")
93 | parser.add_argument("--precision", type=str, default="bf16", choices=PRECISIONS,
94 | help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.")
95 | parser.add_argument("--reproduce", action="store_true",
96 | help="Enable reproducibility by setting random seeds and deterministic algorithms.")
97 | parser.add_argument("--ckpt", type=str, help="Path to the checkpoint to evaluate.")
98 | parser.add_argument("--load-key", type=str, default="module", choices=["module", "ema"],
99 | help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.")
100 | parser.add_argument("--cpu-offload", action="store_true", help="Use CPU offload for the model load.")
101 | group.add_argument( "--use-fp8", action="store_true", help="Enable use fp8 for inference acceleration.")
102 | group.add_argument("--video-size", type=int, nargs='+', default=512,
103 | help="Video size for training. If a single value is provided, it will be used for both width "
104 | "and height. If two values are provided, they will be used for width and height "
105 | "respectively.")
106 | group.add_argument("--sample-n-frames", type=int, default=1,
107 |                        help="How many frames to sample from a video. If using the 3D VAE, the number should be 4n+1.")
108 | group.add_argument("--infer-steps", type=int, default=100, help="Number of denoising steps for inference.")
109 | group.add_argument("--val-disable-autocast", action="store_true",
110 | help="Disable autocast for denoising loop and vae decoding in pipeline sampling.")
111 | group.add_argument("--num-images", type=int, default=1, help="Number of images to generate for each prompt.")
112 | group.add_argument("--seed", type=int, default=1024, help="Seed for evaluation.")
113 | group.add_argument("--save-path-suffix", type=str, default="", help="Suffix for the directory of saved samples.")
114 | group.add_argument("--pos-prompt", type=str, default='', help="Prompt for sampling during evaluation.")
115 | group.add_argument("--neg-prompt", type=str, default='', help="Negative prompt for sampling during evaluation.")
116 |     group.add_argument("--add-pos-prompt", type=str, default='', help="Additional prompt for sampling during evaluation.")
117 |     group.add_argument("--add-neg-prompt", type=str, default='', help="Additional negative prompt for sampling during evaluation.")
118 | group.add_argument("--pad-face-size", type=float, default=0.7, help="Pad bbox for face align.")
119 |     group.add_argument("--image-path", type=str, default="", help="Path to the input reference image.")
120 | group.add_argument("--save-path", type=str, default=None, help="Path to save the generated samples.")
121 | group.add_argument("--input", type=str, default=None, help="test data.")
122 | group.add_argument("--item-name", type=str, default=None, help="")
123 | group.add_argument("--cfg-scale", type=float, default=7.5, help="Classifier free guidance scale.")
124 | group.add_argument("--ip-cfg-scale", type=float, default=0, help="Classifier free guidance scale.")
125 | group.add_argument("--use-deepcache", type=int, default=1)
126 | return parser
127 |
128 | def sanity_check_args(args):
129 | # VAE channels
130 | vae_pattern = r"\d{2,3}-\d{1,2}c-\w+"
131 | if not re.match(vae_pattern, args.vae):
132 | raise ValueError(
133 | f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'."
134 | )
135 | vae_channels = int(args.vae.split("-")[1][:-1])
136 | if args.latent_channels is None:
137 | args.latent_channels = vae_channels
138 | if vae_channels != args.latent_channels:
139 | raise ValueError(
140 | f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels})."
141 | )
142 | return args
143 |
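A minimal sketch of driving parse_args programmatically; the argument values are illustrative, and overriding sys.argv is only for demonstration (a launch script would pass these flags on the command line).

import sys

sys.argv = [
    "sample_gpu_poor.py",
    "--video-size", "720", "1280",
    "--cfg-scale", "7.5",
    "--seed", "1024",
]
args = parse_args()
print(args.video_size, args.latent_channels)  # latent channels are inferred from the VAE name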
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | __all__ = [
5 | "PROMPT_TEMPLATE", "MODEL_BASE", "PRECISION_TO_TYPE",
6 | "PRECISIONS", "VAE_PATH", "TEXT_ENCODER_PATH", "TOKENIZER_PATH",
7 | "TEXT_PROJECTION",
8 | ]
9 |
10 | # =================== Constant Values =====================
11 |
12 | PRECISION_TO_TYPE = {
13 | 'fp32': torch.float32,
14 | 'fp16': torch.float16,
15 | 'bf16': torch.bfloat16,
16 | }
17 |
18 | PROMPT_TEMPLATE_ENCODE_VIDEO = (
19 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
20 | "1. The main content and theme of the video."
21 | "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
22 | "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
23 | "4. background environment, light, style and atmosphere."
24 | "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
25 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
26 | )
27 |
28 | PROMPT_TEMPLATE = {
29 | "li-dit-encode-video": {"template": PROMPT_TEMPLATE_ENCODE_VIDEO, "crop_start": 95},
30 | }
31 |
32 | # ======================= Model ======================
33 | PRECISIONS = {"fp32", "fp16", "bf16"}
34 |
35 | # =================== Model Path =====================
36 | MODEL_BASE = os.getenv("MODEL_BASE")
37 |
38 | # 3D VAE
39 | VAE_PATH = {
40 | "884-16c-hy0801": f"{MODEL_BASE}/vae_3d/hyvae_v1_0801",
41 | }
42 |
43 | # Text Encoder
44 | TEXT_ENCODER_PATH = {
45 | "clipL": f"{MODEL_BASE}/openai_clip-vit-large-patch14",
46 | "llava-llama-3-8b": f"{MODEL_BASE}/llava-llama-3-8b-v1_1",
47 | }
48 |
49 | # Tokenizer
50 | TOKENIZER_PATH = {
51 | "clipL": f"{MODEL_BASE}/openai_clip-vit-large-patch14",
52 | "llava-llama-3-8b": f"{MODEL_BASE}/llava-llama-3-8b-v1_1",
53 | }
54 |
55 | TEXT_PROJECTION = {
56 | "linear", # Default, an nn.Linear() layer
57 | "single_refiner", # Single TokenRefiner. Refer to LI-DiT
58 | }
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/data_kits/data_tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import numpy as np
5 | import imageio
6 | import torchvision
7 | from einops import rearrange
8 |
9 |
10 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, quality=8):
11 | videos = rearrange(videos, "b c t h w -> t b c h w")
12 | outputs = []
13 | for x in videos:
14 | x = torchvision.utils.make_grid(x, nrow=n_rows)
15 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
16 | if rescale:
17 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
18 | x = torch.clamp(x,0,1)
19 | x = (x * 255).numpy().astype(np.uint8)
20 | outputs.append(x)
21 |
22 | os.makedirs(os.path.dirname(path), exist_ok=True)
23 | imageio.mimsave(path, outputs, fps=fps, quality=quality)
24 |
25 | def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1):
26 | crop_h, crop_w = crop_img.shape[:2]
27 | target_w, target_h = size
28 | scale_h, scale_w = target_h / crop_h, target_w / crop_w
29 | if scale_w > scale_h:
30 | resize_h = int(target_h*resize_ratio)
31 | resize_w = int(crop_w / crop_h * resize_h)
32 | else:
33 | resize_w = int(target_w*resize_ratio)
34 | resize_h = int(crop_h / crop_w * resize_w)
35 | crop_img = cv2.resize(crop_img, (resize_w, resize_h))
36 | pad_left = (target_w - resize_w) // 2
37 | pad_top = (target_h - resize_h) // 2
38 | pad_right = target_w - resize_w - pad_left
39 | pad_bottom = target_h - resize_h - pad_top
40 | crop_img = cv2.copyMakeBorder(crop_img, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=color)
41 | return crop_img
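A minimal usage sketch for pad_image; the frame is synthetic and the target size is illustrative (note that size is ordered (target_w, target_h)).

import numpy as np

frame = np.zeros((480, 640, 3), dtype=np.uint8)   # H x W x C input
padded = pad_image(frame, size=(1280, 720))       # letter-boxed onto a white 1280x720 canvas
print(padded.shape)                               # (720, 1280, 3)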
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/data_kits/video_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import json
5 | import numpy as np
6 | from PIL import Image
7 | import torchvision.transforms as transforms
8 | from hymm_sp.data_kits.data_tools import *
9 |
10 |
11 | class DataPreprocess(object):
12 | def __init__(self):
13 | self.llava_size = (336, 336)
14 | self.llava_transform = transforms.Compose(
15 | [
16 | transforms.Resize(self.llava_size, interpolation=transforms.InterpolationMode.BILINEAR),
17 | transforms.ToTensor(),
18 | transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
19 | ]
20 | )
21 |
22 | def get_batch(self, image_path, size):
23 | try:
24 | image = cv2.imread(image_path)
25 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
26 | except:
27 |             image = np.array(Image.open(image_path).convert('RGB'))  # fall back to PIL; convert to ndarray so pad_image gets an array
28 | llava_item_image = pad_image(image.copy(), self.llava_size)
29 | uncond_llava_item_image = np.ones_like(llava_item_image) * 255
30 | cat_item_image = pad_image(image.copy(), size)
31 |
32 | llava_item_tensor = self.llava_transform(Image.fromarray(llava_item_image.astype(np.uint8)))
33 | uncond_llava_item_tensor = self.llava_transform(Image.fromarray(uncond_llava_item_image))
34 | cat_item_tensor = torch.from_numpy(cat_item_image.copy()).permute((2, 0, 1)) / 255.0
35 | batch = {
36 | "pixel_value_llava": llava_item_tensor.unsqueeze(0),
37 | "uncond_pixel_value_llava": uncond_llava_item_tensor.unsqueeze(0),
38 | 'pixel_value_ref': cat_item_tensor.unsqueeze(0),
39 | }
40 | return batch
41 |
42 |
43 | class JsonDataset(object):
44 | def __init__(self, args):
45 | self.args = args
46 | self.data_list = args.input
47 | self.pad_color = (255, 255, 255)
48 | self.llava_size = (336, 336)
49 | self.ref_size = (args.video_size[1], args.video_size[0])
50 | if self.data_list.endswith('.list'):
51 | self.data_paths = [line.strip() for line in open(self.data_list, 'r')] if self.data_list is not None else []
52 | else:
53 | self.data_paths = [self.data_list]
54 | self.llava_transform = transforms.Compose(
55 | [
56 | transforms.Resize(self.llava_size, interpolation=transforms.InterpolationMode.BILINEAR),
57 | transforms.ToTensor(),
58 | transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
59 | ]
60 | )
61 |
62 | def __len__(self):
63 | return len(self.data_paths)
64 |
65 | def read_image(self, image_path):
66 | if isinstance(image_path, dict):
67 | image_path = image_path['seg_item_image_path']
68 |
69 | try:
70 | face_image_masked = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
71 | except:
72 |             face_image_masked = np.array(Image.open(image_path).convert('RGB'))  # convert to ndarray for pad_image
73 |
74 | cat_face_image = pad_image(face_image_masked.copy(), self.ref_size)
75 | llava_face_image = pad_image(face_image_masked.copy(), self.llava_size)
76 | return llava_face_image, cat_face_image
77 |
78 | def __getitem__(self, idx):
79 | data_path = self.data_paths[idx]
80 | data_name = os.path.basename(os.path.splitext(data_path)[0])
81 | if data_path.endswith('.json'):
82 | data = json.load(open(data_path, 'r'))
83 | llava_item_image, cat_item_image = self.read_image(data)
84 | item_prompt = data['item_prompt']
85 | seed = data['seed']
86 | prompt = data['prompt']
87 | if 'negative_prompt' in data:
88 | negative_prompt = data['negative_prompt']
89 | else:
90 | negative_prompt = ''
91 | else:
92 | llava_item_image, cat_item_image = self.read_image(data_path)
93 | item_prompt = 'object'
94 | seed = self.args.seed
95 | prompt = self.args.pos_prompt
96 | negative_prompt = self.args.neg_prompt
97 |
98 | llava_item_tensor = self.llava_transform(Image.fromarray(llava_item_image.astype(np.uint8)))
99 | cat_item_tensor = torch.from_numpy(cat_item_image.copy()).permute((2, 0, 1)) / 255.0
100 |
101 | uncond_llava_item_image = np.ones_like(llava_item_image) * 255
102 | uncond_llava_item_tensor = self.llava_transform(Image.fromarray(uncond_llava_item_image))
103 | # print(llava_item_tensor.shape, cat_item_tensor.shape)
104 | # raise ValueError
105 | batch = {
106 | "pixel_value_llava": llava_item_tensor,
107 | "uncond_pixel_value_llava": uncond_llava_item_tensor,
108 | "pixel_value_ref": cat_item_tensor,
109 | "prompt": prompt,
110 | "negative_prompt": negative_prompt,
111 | "seed": seed,
112 | "name": item_prompt,
113 | 'data_name': data_name
114 | }
115 | return batch
116 |
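A minimal usage sketch for DataPreprocess, using the sample reference image shipped under assets/ (any RGB image path works); the target size is illustrative.

pre = DataPreprocess()
batch = pre.get_batch("assets/images/poodle.png", size=(1280, 720))
for key, value in batch.items():
    print(key, tuple(value.shape))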
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipelines import HunyuanVideoCustomPipeline
2 | from .schedulers import FlowMatchDiscreteScheduler
3 |
4 | def load_diffusion_pipeline(args, rank, vae, text_encoder, text_encoder_2, model, scheduler=None,
5 | device=None, progress_bar_config=None):
6 |     """ Load the HunyuanVideoCustom denoising pipeline (with its scheduler) for inference. """
7 | if scheduler is None:
8 | scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift_eval_video, reverse=args.flow_reverse, solver=args.flow_solver, )
9 |
10 | # Only enable progress bar for rank 0
11 | progress_bar_config = progress_bar_config or {'leave': True, 'disable': rank != 0}
12 |
13 | pipeline = HunyuanVideoCustomPipeline(vae=vae,
14 | text_encoder=text_encoder,
15 | text_encoder_2=text_encoder_2,
16 | transformer=model,
17 | scheduler=scheduler,
18 | # safety_checker=None,
19 | # feature_extractor=None,
20 | # requires_safety_checker=False,
21 | progress_bar_config=progress_bar_config,
22 | args=args,
23 | )
24 |     if not args.cpu_offload:  # keep the pipeline on CPU when offloading to avoid OOM
25 |         pipeline = pipeline.to(device)
28 |
29 | return pipeline
30 |
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/diffusion/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_hunyuan_video_custom import HunyuanVideoCustomPipeline
2 |
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/diffusion/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/diffusion/schedulers/scheduling_flow_match_discrete.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | #
16 | # Modified from diffusers==0.29.2
17 | #
18 | # ==============================================================================
19 |
20 | from dataclasses import dataclass
21 | from typing import Optional, Tuple, Union
22 |
23 | import torch
24 |
25 | from diffusers.configuration_utils import ConfigMixin, register_to_config
26 | from diffusers.utils import BaseOutput, logging
27 | from diffusers.schedulers.scheduling_utils import SchedulerMixin
28 |
29 |
30 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31 |
32 |
33 | @dataclass
34 | class FlowMatchDiscreteSchedulerOutput(BaseOutput):
35 | """
36 | Output class for the scheduler's `step` function output.
37 |
38 | Args:
39 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
40 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
41 | denoising loop.
42 | """
43 |
44 | prev_sample: torch.FloatTensor
45 |
46 |
47 | class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
48 | """
49 | Euler scheduler.
50 |
51 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
52 | methods the library implements for all schedulers such as loading and saving.
53 |
54 | Args:
55 | num_train_timesteps (`int`, defaults to 1000):
56 | The number of diffusion steps to train the model.
57 | shift (`float`, defaults to 1.0):
58 | The shift value for the timestep schedule.
59 | reverse (`bool`, defaults to `True`):
60 | Whether to reverse the timestep schedule.
61 | solver (`str`, defaults to `"euler"`):
62 | The ODE solver to use; only `"euler"` is supported.
63 | n_tokens (`int`, *optional*): Number of tokens in the input sequence (unused; kept for config compatibility).
64 | """
65 |
66 | _compatibles = []
67 | order = 1
68 |
69 | @register_to_config
70 | def __init__(
71 | self,
72 | num_train_timesteps: int = 1000,
73 | shift: float = 1.0,
74 | reverse: bool = True,
75 | solver: str = "euler",
76 | n_tokens: Optional[int] = None,
77 | ):
78 | sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
79 |
80 | if not reverse:
81 | sigmas = sigmas.flip(0)
82 |
83 | self.sigmas = sigmas
84 | # the value fed to model
85 | self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
86 |
87 | self._step_index = None
88 | self._begin_index = None
89 |
90 | self.supported_solver = ["euler"]
91 | if solver not in self.supported_solver:
92 | raise ValueError(f"Solver {solver} not supported. Supported solvers: {self.supported_solver}")
93 |
94 | @property
95 | def step_index(self):
96 | """
97 | The index counter for the current timestep. It increases by 1 after each scheduler step.
98 | """
99 | return self._step_index
100 |
101 | @property
102 | def begin_index(self):
103 | """
104 | The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
105 | """
106 | return self._begin_index
107 |
108 | # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
109 | def set_begin_index(self, begin_index: int = 0):
110 | """
111 | Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
112 |
113 | Args:
114 | begin_index (`int`):
115 | The begin index for the scheduler.
116 | """
117 | self._begin_index = begin_index
118 |
119 | def _sigma_to_t(self, sigma):
120 | return sigma * self.config.num_train_timesteps
121 |
122 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None,
123 | n_tokens: int = None):
124 | """
125 | Sets the discrete timesteps used for the diffusion chain (to be run before inference).
126 |
127 | Args:
128 | num_inference_steps (`int`):
129 | The number of diffusion steps used when generating samples with a pre-trained model.
130 | device (`str` or `torch.device`, *optional*):
131 | The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
132 | n_tokens (`int`, *optional*):
133 | Number of tokens in the input sequence.
134 | """
135 | self.num_inference_steps = num_inference_steps
136 |
137 | sigmas = torch.linspace(1, 0, num_inference_steps + 1)
138 | sigmas = self.sd3_time_shift(sigmas)
139 |
140 | if not self.config.reverse:
141 | sigmas = 1 - sigmas
142 |
143 | self.sigmas = sigmas
144 | self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device)
145 |
146 | # Reset step index
147 | self._step_index = None
148 |
149 | def index_for_timestep(self, timestep, schedule_timesteps=None):
150 | if schedule_timesteps is None:
151 | schedule_timesteps = self.timesteps
152 |
153 | indices = (schedule_timesteps == timestep).nonzero()
154 |
155 | # The sigma index that is taken for the **very** first `step`
156 | # is always the second index (or the last index if there is only 1)
157 | # This way we can ensure we don't accidentally skip a sigma in
158 | # case we start in the middle of the denoising schedule (e.g. for image-to-image)
159 | pos = 1 if len(indices) > 1 else 0
160 |
161 | return indices[pos].item()
162 |
163 | def _init_step_index(self, timestep):
164 | if self.begin_index is None:
165 | if isinstance(timestep, torch.Tensor):
166 | timestep = timestep.to(self.timesteps.device)
167 | self._step_index = self.index_for_timestep(timestep)
168 | else:
169 | self._step_index = self._begin_index
170 |
171 | def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
172 | return sample
173 |
174 | def sd3_time_shift(self, t: torch.Tensor):
175 | return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
176 |
177 | def step(
178 | self,
179 | model_output: torch.FloatTensor,
180 | timestep: Union[float, torch.FloatTensor],
181 | sample: torch.FloatTensor,
182 | return_dict: bool = True,
183 | ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
184 | """
185 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
186 | process from the learned model outputs (most often the predicted noise).
187 |
188 | Args:
189 | model_output (`torch.FloatTensor`):
190 | The direct output from learned diffusion model.
191 | timestep (`float`):
192 | The current discrete timestep in the diffusion chain.
193 | sample (`torch.FloatTensor`):
194 | A current instance of a sample created by the diffusion process.
195 | return_dict (`bool`):
196 | Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or a plain
197 | tuple.
198 |
199 | Returns:
200 | [`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
201 | If `return_dict` is `True`, a [`FlowMatchDiscreteSchedulerOutput`] is
202 | returned; otherwise a tuple is returned where the first element is the sample tensor.
203 | """
204 |
205 | if (
206 | isinstance(timestep, int)
207 | or isinstance(timestep, torch.IntTensor)
208 | or isinstance(timestep, torch.LongTensor)
209 | ):
210 | raise ValueError(
211 | (
212 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
213 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
214 | " one of the `scheduler.timesteps` as a timestep."
215 | ),
216 | )
217 |
218 | if self.step_index is None:
219 | self._init_step_index(timestep)
220 |
221 | # Upcast to avoid precision issues when computing prev_sample
222 | sample = sample.to(torch.float32)
223 |
224 | dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
225 |
226 | if self.config.solver == "euler":
227 | prev_sample = sample + model_output.float() * dt
228 | else:
229 | raise ValueError(f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}")
230 |
231 | # upon completion increase step index by one
232 | self._step_index += 1
233 |
234 | if not return_dict:
235 | return (prev_sample,)
236 |
237 | return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
238 |
239 | def __len__(self):
240 | return self.config.num_train_timesteps
241 |
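
# A minimal sketch of how this scheduler drives a sampling loop. The zero "model output"
# is a stand-in for the diffusion transformer's prediction, and the latent shape and
# shift value are illustrative only; the production loop lives in the pipeline module.
if __name__ == "__main__":
    scheduler = FlowMatchDiscreteScheduler(shift=7.0, reverse=True, solver="euler")
    scheduler.set_timesteps(num_inference_steps=30, device="cpu")

    sample = torch.randn(1, 16, 8, 32, 32)       # noisy latents: [B, C, T, H, W]
    for t in scheduler.timesteps:
        model_output = torch.zeros_like(sample)  # placeholder for the denoiser's prediction
        sample = scheduler.step(model_output, t, sample).prev_sample
    print(sample.shape)                          # shape is unchanged after 30 Euler steps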
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Union, List
3 | from hymm_sp.modules.posemb_layers import get_1d_rotary_pos_embed, get_meshgrid_nd
4 |
5 | from itertools import repeat
6 | import collections.abc
7 |
8 |
9 | def _ntuple(n):
10 | def parse(x):
11 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
12 | x = tuple(x)
13 | if len(x) == 1:
14 | x = tuple(repeat(x[0], n))
15 | return x
16 | return tuple(repeat(x, n))
17 | return parse
18 |
19 | to_1tuple = _ntuple(1)
20 | to_2tuple = _ntuple(2)
21 | to_3tuple = _ntuple(3)
22 | to_4tuple = _ntuple(4)
23 |
24 | def get_rope_freq_from_size(latents_size, ndim, target_ndim, args,
25 | rope_theta_rescale_factor: Union[float, List[float]]=1.0,
26 | rope_interpolation_factor: Union[float, List[float]]=1.0,
27 | concat_dict={}):
28 |
29 | if isinstance(args.patch_size, int):
30 | assert all(s % args.patch_size == 0 for s in latents_size), \
31 | f"Latent size(last {ndim} dimensions) should be divisible by patch size({args.patch_size}), " \
32 | f"but got {latents_size}."
33 | rope_sizes = [s // args.patch_size for s in latents_size]
34 | elif isinstance(args.patch_size, list):
35 | assert all(s % args.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
36 | f"Latent size(last {ndim} dimensions) should be divisible by patch size({args.patch_size}), " \
37 | f"but got {latents_size}."
38 | rope_sizes = [s // args.patch_size[idx] for idx, s in enumerate(latents_size)]
39 |
40 | if len(rope_sizes) != target_ndim:
41 | rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
42 | head_dim = args.hidden_size // args.num_heads
43 | rope_dim_list = args.rope_dim_list
44 | if rope_dim_list is None:
45 | rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
46 | assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
47 | freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
48 | rope_sizes,
49 | theta=args.rope_theta,
50 | use_real=True,
51 | theta_rescale_factor=rope_theta_rescale_factor,
52 | interpolation_factor=rope_interpolation_factor,
53 | concat_dict=concat_dict)
54 | return freqs_cos, freqs_sin
55 |
56 | def get_nd_rotary_pos_embed_new(rope_dim_list, start, *args, theta=10000., use_real=False,
57 | theta_rescale_factor: Union[float, List[float]]=1.0,
58 | interpolation_factor: Union[float, List[float]]=1.0,
59 | concat_dict={}
60 | ):
61 |
62 | grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
63 | if len(concat_dict)<1:
64 | pass
65 | else:
66 | if concat_dict['mode']=='timecat':
67 | bias = grid[:,:1].clone()
68 | bias[0] = concat_dict['bias']*torch.ones_like(bias[0])
69 | grid = torch.cat([bias, grid], dim=1)
70 |
71 | elif concat_dict['mode']=='timecat-w':
72 | bias = grid[:,:1].clone()
73 | bias[0] = concat_dict['bias']*torch.ones_like(bias[0])
74 | bias[2] += start[-1] ## ref https://github.com/Yuanshi9815/OminiControl/blob/main/src/generate.py#L178
75 | grid = torch.cat([bias, grid], dim=1)
76 | if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
77 | theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
78 | elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
79 | theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
80 | assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
81 |
82 | if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
83 | interpolation_factor = [interpolation_factor] * len(rope_dim_list)
84 | elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
85 | interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
86 | assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)"
87 |
88 | # use 1/ndim of dimensions to encode grid_axis
89 | embs = []
90 | for i in range(len(rope_dim_list)):
91 | emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=use_real,
92 | theta_rescale_factor=theta_rescale_factor[i],
93 | interpolation_factor=interpolation_factor[i]) # 2 x [WHD, rope_dim_list[i]]
94 |
95 | embs.append(emb)
96 |
97 | if use_real:
98 | cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
99 | sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
100 | return cos, sin
101 | else:
102 | emb = torch.cat(embs, dim=1) # (WHD, D/2)
103 | return emb
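
# A hedged sketch of get_nd_rotary_pos_embed_new with the 'timecat' mode: an extra leading
# time slice is prepended to a (T, H, W) = (4, 8, 8) token grid before the RoPE tables are
# built (the OminiControl-style trick referenced in the comment above). The per-axis dims,
# the bias value, and the default theta are illustrative only.
if __name__ == "__main__":
    cos, sin = get_nd_rotary_pos_embed_new(
        [16, 56, 56],            # per-axis rope dims; their sum should equal the attention head_dim (128)
        (4, 8, 8),               # (T, H, W) token grid
        use_real=True,
        concat_dict={'mode': 'timecat', 'bias': 0},
    )
    print(cos.shape, sin.shape)  # torch.Size([320, 128]) each: (1 + 4) * 8 * 8 positions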
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pathlib import Path
3 | from loguru import logger
4 | from hymm_sp.constants import PROMPT_TEMPLATE, PRECISION_TO_TYPE
5 | from hymm_sp.vae import load_vae
6 | from hymm_sp.modules import load_model
7 | from hymm_sp.text_encoder import TextEncoder
8 | import torch.distributed
9 | from hymm_sp.modules.parallel_states import (
10 | nccl_info,
11 | )
12 | from hymm_sp.modules.fp8_optimization import convert_fp8_linear
13 |
14 |
15 | class Inference(object):
16 | def __init__(self,
17 | args,
18 | vae,
19 | vae_kwargs,
20 | text_encoder,
21 | model,
22 | text_encoder_2=None,
23 | pipeline=None,
24 | cpu_offload=False,
25 | device=None,
26 | logger=None):
27 | self.vae = vae
28 | self.vae_kwargs = vae_kwargs
29 |
30 | self.text_encoder = text_encoder
31 | self.text_encoder_2 = text_encoder_2
32 |
33 | self.model = model
34 | self.pipeline = pipeline
35 | self.cpu_offload = cpu_offload
36 |
37 | self.args = args
38 | self.device = device if device is not None else "cuda" if torch.cuda.is_available() else "cpu"
39 | if nccl_info.sp_size > 1:
40 | self.device = torch.device(f"cuda:{torch.distributed.get_rank()}")
41 |
42 | self.logger = logger
43 |
44 | @classmethod
45 | def from_pretrained(cls,
46 | pretrained_model_path,
47 | args,
48 | device=None,
49 | **kwargs):
50 | """
51 | Initialize the Inference pipeline.
52 |
53 | Args:
54 | pretrained_model_path (str or pathlib.Path): The model root path, containing the transformer, text encoder and VAE checkpoints.
55 | args: The parsed command-line arguments that configure the model and inference.
56 | device (str or torch.device, optional): The device for inference. Defaults to CUDA if available, otherwise CPU.
57 | """
58 | # ========================================================================
59 | logger.info(f"Got text-to-video model root path: {pretrained_model_path}")
60 |
61 | # ======================== Get the args path =============================
62 |
63 | # Set device and disable gradient
64 | if device is None:
65 | device = "cuda" if torch.cuda.is_available() else "cpu"
66 | torch.set_grad_enabled(False)
67 | logger.info("Building model...")
68 | factor_kwargs = {'device': 'cpu' if args.cpu_offload else device, 'dtype': PRECISION_TO_TYPE[args.precision]}
69 | in_channels = args.latent_channels
70 | out_channels = args.latent_channels
71 | print("="*25, f"build model", "="*25)
72 | model = load_model(
73 | args,
74 | in_channels=in_channels,
75 | out_channels=out_channels,
76 | factor_kwargs=factor_kwargs
77 | )
78 | if args.use_fp8:
79 | convert_fp8_linear(model, pretrained_model_path, original_dtype=PRECISION_TO_TYPE[args.precision])
80 | if args.cpu_offload:
81 | print(f'='*20, f'load transformer to cpu')
82 | model = model.to('cpu')
83 | torch.cuda.empty_cache()
84 | else:
85 | model = model.to(device)
86 | model = Inference.load_state_dict(args, model, pretrained_model_path)
87 | model.eval()
88 |
89 | # ============================= Build extra models ========================
90 | # VAE
91 | print("="*25, f"load vae", "="*25)
92 | vae, _, s_ratio, t_ratio = load_vae(args.vae, args.vae_precision, logger=logger, device='cpu' if args.cpu_offload else device)
93 | vae_kwargs = {'s_ratio': s_ratio, 't_ratio': t_ratio}
94 |
95 | # Text encoder
96 | if args.prompt_template_video is not None:
97 | crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
98 | else:
99 | crop_start = 0
100 | max_length = args.text_len + crop_start
101 |
102 | # prompt_template_video
103 | prompt_template_video = PROMPT_TEMPLATE[args.prompt_template_video] if args.prompt_template_video is not None else None
104 | print("="*25, f"load llava", "="*25)
105 | text_encoder = TextEncoder(text_encoder_type = args.text_encoder,
106 | max_length = max_length,
107 | text_encoder_precision = args.text_encoder_precision,
108 | tokenizer_type = args.tokenizer,
109 | use_attention_mask = args.use_attention_mask,
110 | prompt_template_video = prompt_template_video,
111 | hidden_state_skip_layer = args.hidden_state_skip_layer,
112 | apply_final_norm = args.apply_final_norm,
113 | reproduce = args.reproduce,
114 | logger = logger,
115 | device = 'cpu' if args.cpu_offload else device ,
116 | )
117 | text_encoder_2 = None
118 | if args.text_encoder_2 is not None:
119 | text_encoder_2 = TextEncoder(text_encoder_type=args.text_encoder_2,
120 | max_length=args.text_len_2,
121 | text_encoder_precision=args.text_encoder_precision_2,
122 | tokenizer_type=args.tokenizer_2,
123 | use_attention_mask=args.use_attention_mask,
124 | reproduce=args.reproduce,
125 | logger=logger,
126 | device='cpu' if args.cpu_offload else device , # if not args.use_cpu_offload else 'cpu'
127 | )
128 |
129 | return cls(args=args,
130 | vae=vae,
131 | vae_kwargs=vae_kwargs,
132 | text_encoder=text_encoder,
133 | model=model,
134 | text_encoder_2=text_encoder_2,
135 | device=device,
136 | logger=logger)
137 |
138 | @staticmethod
139 | def load_state_dict(args, model, ckpt_path):
140 | load_key = args.load_key
141 | ckpt_path = Path(ckpt_path)
142 | if ckpt_path.is_dir():
143 | ckpt_path = next(ckpt_path.glob("*_model_states.pt"))
144 | state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
145 | if load_key in state_dict:
146 | state_dict = state_dict[load_key]
147 | elif load_key == ".":
148 | pass
149 | else:
150 | raise KeyError(f"Key '{load_key}' not found in the checkpoint. Existing keys: {state_dict.keys()}")
151 | model.load_state_dict(state_dict, strict=False)
152 | return model
153 |
154 | def get_exp_dir_and_ckpt_id(self):
155 | if self.ckpt is None:
156 | raise ValueError("The checkpoint path is not provided.")
157 |
158 | ckpt = Path(self.ckpt)
159 | if ckpt.parents[1].name == "checkpoints":
160 | # It should be a standard checkpoint path. We use the parent directory as the default save directory.
161 | exp_dir = ckpt.parents[2]
162 | else:
163 | raise ValueError(f"We cannot infer the experiment directory from the checkpoint path: {ckpt}. "
164 | f"It seems that the checkpoint path is not standard. Please explicitly provide the "
165 | f"save path by --save-path.")
166 | return exp_dir, ckpt.parent.name
167 |
168 | @staticmethod
169 | def parse_size(size):
170 | if isinstance(size, int):
171 | size = [size]
172 | if not isinstance(size, (list, tuple)):
173 | raise ValueError(f"Size must be an integer or (height, width), got {size}.")
174 | if len(size) == 1:
175 | size = [size[0], size[0]]
176 | if len(size) != 2:
177 | raise ValueError(f"Size must be an integer or (height, width), got {size}.")
178 | return size
179 |
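
# A small sketch of the parse_size helper, which normalizes an int or a (height, width)
# pair into a two-element size.
if __name__ == "__main__":
    print(Inference.parse_size(720))          # [720, 720]
    print(Inference.parse_size([544, 960]))   # [544, 960]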
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG
2 |
3 | def load_model(args, in_channels, out_channels, factor_kwargs):
4 | model = HYVideoDiffusionTransformer(
5 | args,
6 | in_channels=in_channels,
7 | out_channels=out_channels,
8 | **HUNYUAN_VIDEO_CONFIG[args.model],
9 | **factor_kwargs,
10 | )
11 | return model
12 |
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/modules/activation_layers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def get_activation_layer(act_type):
5 | """get activation layer
6 |
7 | Args:
8 | act_type (str): the activation type
9 |
10 | Returns:
11 | Callable: a zero-argument factory that instantiates the activation layer.
12 | """
13 | if act_type == "gelu":
14 | return lambda: nn.GELU()
15 | elif act_type == "gelu_tanh":
16 | # Approximate `tanh` requires torch >= 1.13
17 | return lambda: nn.GELU(approximate="tanh")
18 | elif act_type == "relu":
19 | return nn.ReLU
20 | elif act_type == "silu":
21 | return nn.SiLU
22 | else:
23 | raise ValueError(f"Unknown activation type: {act_type}")
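
# A quick sketch: the factory pattern keeps construction lazy, so callers can instantiate
# the activation inside their own module definitions. Shapes are illustrative.
if __name__ == "__main__":
    import torch
    act = get_activation_layer("gelu_tanh")()  # call the factory to build the module
    print(act(torch.randn(2, 4)).shape)        # torch.Size([2, 4])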
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/modules/embed_layers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from hymm_sp.helpers import to_2tuple
5 |
6 |
7 | class PatchEmbed(nn.Module):
8 | """ 2D Image to Patch Embedding
9 |
10 | Image to Patch Embedding using Conv2d
11 |
12 | A convolution based approach to patchifying a 2D image w/ embedding projection.
13 |
14 | Based on the impl in https://github.com/google-research/vision_transformer
15 |
16 | Hacked together by / Copyright 2020 Ross Wightman
17 |
18 | Remove the _assert function in forward function to be compatible with multi-resolution images.
19 | """
20 | def __init__(
21 | self,
22 | patch_size=16,
23 | in_chans=3,
24 | embed_dim=768,
25 | norm_layer=None,
26 | flatten=True,
27 | bias=True,
28 | dtype=None,
29 | device=None
30 | ):
31 | factory_kwargs = {'dtype': dtype, 'device': device}
32 | super().__init__()
33 | patch_size = to_2tuple(patch_size)
34 | self.patch_size = patch_size
35 | self.flatten = flatten
36 |
37 | self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias,
38 | **factory_kwargs)
39 | nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
40 | if bias:
41 | nn.init.zeros_(self.proj.bias)
42 |
43 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
44 |
45 | def forward(self, x):
46 | x = self.proj(x)
47 | if self.flatten:
48 | x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
49 | x = self.norm(x)
50 | return x
51 |
52 |
53 | class TextProjection(nn.Module):
54 | """
55 | Projects text embeddings. Also handles dropout for classifier-free guidance.
56 |
57 | Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
58 | """
59 |
60 | def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
61 | factory_kwargs = {'dtype': dtype, 'device': device}
62 | super().__init__()
63 | self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
64 | self.act_1 = act_layer()
65 | self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
66 |
67 | def forward(self, caption):
68 | hidden_states = self.linear_1(caption)
69 | hidden_states = self.act_1(hidden_states)
70 | hidden_states = self.linear_2(hidden_states)
71 | return hidden_states
72 |
73 |
74 | def timestep_embedding(t, dim, max_period=10000):
75 | """
76 | Create sinusoidal timestep embeddings.
77 |
78 | Args:
79 | t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
80 | dim (int): the dimension of the output.
81 | max_period (int): controls the minimum frequency of the embeddings.
82 |
83 | Returns:
84 | embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
85 |
86 | .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
87 | """
88 | half = dim // 2
89 | freqs = torch.exp(
90 | -math.log(max_period)
91 | * torch.arange(start=0, end=half, dtype=torch.float32)
92 | / half
93 | ).to(device=t.device)
94 | args = t[:, None].float() * freqs[None]
95 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
96 | if dim % 2:
97 | embedding = torch.cat(
98 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
99 | )
100 | return embedding
101 |
102 |
103 | class TimestepEmbedder(nn.Module):
104 | """
105 | Embeds scalar timesteps into vector representations.
106 | """
107 | def __init__(self,
108 | hidden_size,
109 | act_layer,
110 | frequency_embedding_size=256,
111 | max_period=10000,
112 | out_size=None,
113 | dtype=None,
114 | device=None
115 | ):
116 | factory_kwargs = {'dtype': dtype, 'device': device}
117 | super().__init__()
118 | self.frequency_embedding_size = frequency_embedding_size
119 | self.max_period = max_period
120 | if out_size is None:
121 | out_size = hidden_size
122 |
123 | self.mlp = nn.Sequential(
124 | nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
125 | act_layer(),
126 | nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
127 | )
128 | nn.init.normal_(self.mlp[0].weight, std=0.02)
129 | nn.init.normal_(self.mlp[2].weight, std=0.02)
130 |
131 | def forward(self, t):
132 | t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
133 | t_emb = self.mlp(t_freq)
134 | return t_emb
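
# A brief sketch of the timestep embedding path: raw (possibly fractional) timesteps become
# sinusoidal features, then an MLP maps them to the model width. Sizes are illustrative.
if __name__ == "__main__":
    t = torch.tensor([0.0, 250.5, 999.0])
    print(timestep_embedding(t, dim=256).shape)           # torch.Size([3, 256])
    embedder = TimestepEmbedder(hidden_size=512, act_layer=nn.SiLU)
    print(embedder(t).shape)                              # torch.Size([3, 512])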
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/modules/fp8_optimization.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 | def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
8 | _bits = torch.tensor(bits)
9 | _mantissa_bit = torch.tensor(mantissa_bit)
10 | _sign_bits = torch.tensor(sign_bits)
11 | M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
12 | E = _bits - _sign_bits - M
13 | bias = 2 ** (E - 1) - 1
14 | mantissa = 1
15 | for i in range(mantissa_bit - 1):
16 | mantissa += 1 / (2 ** (i+1))
17 | maxval = mantissa * 2 ** (2**E - 1 - bias)
18 | return maxval
19 |
20 | def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
21 | """
22 | Default is E4M3.
23 | """
24 | bits = torch.tensor(bits)
25 | mantissa_bit = torch.tensor(mantissa_bit)
26 | sign_bits = torch.tensor(sign_bits)
27 | M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
28 | E = bits - sign_bits - M
29 | bias = 2 ** (E - 1) - 1
30 | mantissa = 1
31 | for i in range(mantissa_bit - 1):
32 | mantissa += 1 / (2 ** (i+1))
33 | maxval = mantissa * 2 ** (2**E - 1 - bias)
34 | # signed formats use a symmetric range; unsigned formats clamp at zero
35 | minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval)
36 | input_clamp = torch.min(torch.max(x, minval), maxval)
37 | log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
38 | log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
39 | # dequant
40 | qdq_out = torch.round(input_clamp / log_scales) * log_scales
41 | return qdq_out, log_scales
42 |
43 | def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
44 | for i in range(len(x.shape) - 1):
45 | scale = scale.unsqueeze(-1)
46 | new_x = x / scale
47 | quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
48 | return quant_dequant_x, scale, log_scales
49 |
50 | def fp8_activation_dequant(qdq_out, scale, dtype):
51 | qdq_out = qdq_out.type(dtype)
52 | quant_dequant_x = qdq_out * scale.to(dtype)
53 | return quant_dequant_x
54 |
55 | def fp8_linear_forward(cls, original_dtype, input):
56 | weight_dtype = cls.weight.dtype
57 | #####
58 | if cls.weight.dtype != torch.float8_e4m3fn:
59 | maxval = get_fp_maxval()
60 | scale = torch.max(torch.abs(cls.weight.flatten())) / maxval
61 | linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale)
62 | linear_weight = linear_weight.to(torch.float8_e4m3fn)
63 | weight_dtype = linear_weight.dtype
64 | else:
65 | scale = cls.fp8_scale.to(cls.weight.device)
66 | linear_weight = cls.weight
67 | #####
68 |
69 | if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0:
70 | if True or len(input.shape) == 3:
71 | cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype)
72 | if cls.bias is not None:
73 | output = F.linear(input, cls_dequant, cls.bias)
74 | else:
75 | output = F.linear(input, cls_dequant)
76 | return output
77 | else:
78 | return cls.original_forward(input.to(original_dtype))
79 | else:
80 | return cls.original_forward(input)
81 |
82 | def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={}):
83 | setattr(module, "fp8_matmul_enabled", True)
84 |
85 | # loading fp8 mapping file
86 | fp8_map_path = dit_weight_path.replace('.pt', '_map.pt')
87 | if os.path.exists(fp8_map_path):
88 | fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)['module']
89 | else:
90 | raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.")
91 |
92 | fp8_layers = []
93 | for key, layer in module.named_modules():
94 | if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
95 | fp8_layers.append(key)
96 | original_forward = layer.forward
97 | layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
98 | setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype))
99 | setattr(layer, "original_forward", original_forward)
100 | setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))
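
# A hedged sketch of the simulated E4M3 quantize-dequantize used above, applied to a random
# weight tensor. convert_fp8_linear itself additionally needs the '*_map.pt' scale file that
# ships with the fp8 checkpoint, so only the tensor-level helpers are exercised here.
if __name__ == "__main__":
    w = torch.randn(128, 128)
    maxval = get_fp_maxval()                      # max representable E4M3 magnitude (448)
    scale = torch.max(torch.abs(w)) / maxval
    w_qdq, scale, log_scales = fp8_tensor_quant(w, scale)
    print((w - w_qdq * scale).abs().max())        # max absolute quantization error (small relative to w)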
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/modules/mlp_layers.py:
--------------------------------------------------------------------------------
1 | # Modified from timm library:
2 | # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
3 |
4 | from functools import partial
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | from .modulate_layers import modulate
10 | from hymm_sp.helpers import to_2tuple
11 |
12 |
13 | class MLP(nn.Module):
14 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks
15 | """
16 | def __init__(self,
17 | in_channels,
18 | hidden_channels=None,
19 | out_features=None,
20 | act_layer=nn.GELU,
21 | norm_layer=None,
22 | bias=True,
23 | drop=0.,
24 | use_conv=False,
25 | device=None,
26 | dtype=None
27 | ):
28 | factory_kwargs = {'device': device, 'dtype': dtype}
29 | super().__init__()
30 | out_features = out_features or in_channels
31 | hidden_channels = hidden_channels or in_channels
32 | bias = to_2tuple(bias)
33 | drop_probs = to_2tuple(drop)
34 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
35 |
36 | self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs)
37 | self.act = act_layer()
38 | self.drop1 = nn.Dropout(drop_probs[0])
39 | self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity()
40 | self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs)
41 | self.drop2 = nn.Dropout(drop_probs[1])
42 |
43 | def forward(self, x):
44 | x = self.fc1(x)
45 | x = self.act(x)
46 | x = self.drop1(x)
47 | x = self.norm(x)
48 | x = self.fc2(x)
49 | x = self.drop2(x)
50 | return x
51 |
52 |
53 | class MLPEmbedder(nn.Module):
54 | """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
55 | def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
56 | factory_kwargs = {'device': device, 'dtype': dtype}
57 | super().__init__()
58 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
59 | self.silu = nn.SiLU()
60 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
61 |
62 | def forward(self, x: torch.Tensor) -> torch.Tensor:
63 | return self.out_layer(self.silu(self.in_layer(x)))
64 |
65 |
66 | class FinalLayer(nn.Module):
67 | """The final layer of DiT."""
68 |
69 | def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None):
70 | factory_kwargs = {'device': device, 'dtype': dtype}
71 | super().__init__()
72 |
73 | # Just use LayerNorm for the final layer
74 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
75 | if isinstance(patch_size, int):
76 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, **factory_kwargs)
77 | else:
78 | self.linear = nn.Linear(hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=True, **factory_kwargs)
79 | nn.init.zeros_(self.linear.weight)
80 | nn.init.zeros_(self.linear.bias)
81 |
82 | # Here we don't distinguish between the modulate types. Just use the simple one.
83 | self.adaLN_modulation = nn.Sequential(
84 | act_layer(),
85 | nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
86 | )
87 | # Zero-initialize the modulation
88 | nn.init.zeros_(self.adaLN_modulation[1].weight)
89 | nn.init.zeros_(self.adaLN_modulation[1].bias)
90 |
91 | def forward(self, x, c):
92 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
93 | x = modulate(self.norm_final(x), shift=shift, scale=scale)
94 | x = self.linear(x)
95 | return x
96 |
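
# A short sketch of the two MLP variants above with illustrative sizes.
if __name__ == "__main__":
    x = torch.randn(2, 16, 128)                           # [batch, tokens, channels]
    mlp = MLP(in_channels=128, hidden_channels=512)
    print(mlp(x).shape)                                   # torch.Size([2, 16, 128])
    vec = MLPEmbedder(in_dim=256, hidden_dim=1024)
    print(vec(torch.randn(2, 256)).shape)                 # torch.Size([2, 1024])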
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/modules/modulate_layers.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class ModulateDiT(nn.Module):
8 | """Modulation layer for DiT."""
9 | def __init__(
10 | self,
11 | hidden_size: int,
12 | factor: int,
13 | act_layer: Callable,
14 | dtype=None,
15 | device=None,
16 | ):
17 | factory_kwargs = {"dtype": dtype, "device": device}
18 | super().__init__()
19 | self.act = act_layer()
20 | self.linear = nn.Linear(
21 | hidden_size, factor * hidden_size, bias=True, **factory_kwargs
22 | )
23 | # Zero-initialize the modulation
24 | nn.init.zeros_(self.linear.weight)
25 | nn.init.zeros_(self.linear.bias)
26 |
27 | def forward(self, x: torch.Tensor) -> torch.Tensor:
28 | return self.linear(self.act(x))
29 |
30 |
31 | def modulate(x, shift=None, scale=None):
32 | """modulate by shift and scale
33 |
34 | Args:
35 | x (torch.Tensor): input tensor.
36 | shift (torch.Tensor, optional): shift tensor. Defaults to None.
37 | scale (torch.Tensor, optional): scale tensor. Defaults to None.
38 |
39 | Returns:
40 | torch.Tensor: the output tensor after modulation.
41 | """
42 | if scale is None and shift is None:
43 | return x
44 | elif shift is None:
45 | return x * (1 + scale.unsqueeze(1))
46 | elif scale is None:
47 | return x + shift.unsqueeze(1)
48 | else:
49 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
50 |
51 |
52 | def apply_gate(x, gate=None, tanh=False):
53 | """AI is creating summary for apply_gate
54 |
55 | Args:
56 | x (torch.Tensor): input tensor.
57 | gate (torch.Tensor, optional): gate tensor. Defaults to None.
58 | tanh (bool, optional): whether to use tanh function. Defaults to False.
59 |
60 | Returns:
61 | torch.Tensor: the output tensor after applying the gate.
62 | """
63 | if gate is None:
64 | return x
65 | if tanh:
66 | return x * gate.unsqueeze(1).tanh()
67 | else:
68 | return x * gate.unsqueeze(1)
69 |
70 |
71 | def ckpt_wrapper(module):
72 | def ckpt_forward(*inputs):
73 | outputs = module(*inputs)
74 | return outputs
75 |
76 | return ckpt_forward
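
# A small sketch of adaLN-style modulation and gating on a token sequence; shapes are illustrative.
if __name__ == "__main__":
    x = torch.randn(2, 64, 128)                 # [batch, tokens, hidden]
    shift = torch.zeros(2, 128)
    scale = 0.1 * torch.ones(2, 128)
    gate = torch.ones(2, 128)
    y = modulate(x, shift=shift, scale=scale)   # x * (1 + scale) + shift, broadcast over tokens
    y = apply_gate(y, gate=gate, tanh=True)     # y * tanh(gate), again broadcast over tokens
    print(y.shape)                              # torch.Size([2, 64, 128])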
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/modules/norm_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class RMSNorm(nn.Module):
6 | def __init__(
7 | self,
8 | dim: int,
9 | elementwise_affine=True,
10 | eps: float = 1e-6,
11 | device=None,
12 | dtype=None,
13 | ):
14 | """
15 | Initialize the RMSNorm normalization layer.
16 |
17 | Args:
18 | dim (int): The dimension of the input tensor.
19 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
20 |
21 | Attributes:
22 | eps (float): A small value added to the denominator for numerical stability.
23 | weight (nn.Parameter): Learnable scaling parameter.
24 |
25 | """
26 | factory_kwargs = {"device": device, "dtype": dtype}
27 | super().__init__()
28 | self.eps = eps
29 | if elementwise_affine:
30 | self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
31 |
32 | def _norm(self, x):
33 | """
34 | Apply the RMSNorm normalization to the input tensor.
35 |
36 | Args:
37 | x (torch.Tensor): The input tensor.
38 |
39 | Returns:
40 | torch.Tensor: The normalized tensor.
41 |
42 | """
43 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
44 |
45 | def forward(self, x):
46 | """
47 | Forward pass through the RMSNorm layer.
48 |
49 | Args:
50 | x (torch.Tensor): The input tensor.
51 |
52 | Returns:
53 | torch.Tensor: The output tensor after applying RMSNorm.
54 |
55 | """
56 | output = self._norm(x.float()).type_as(x)
57 | if hasattr(self, "weight"):
58 | output = output * self.weight
59 | return output
60 |
61 |
62 | def get_norm_layer(norm_layer):
63 | """
64 | Get the normalization layer.
65 |
66 | Args:
67 | norm_layer (str): The type of normalization layer.
68 |
69 | Returns:
70 | norm_layer (nn.Module): The normalization layer.
71 | """
72 | if norm_layer == "layer":
73 | return nn.LayerNorm
74 | elif norm_layer == "rms":
75 | return RMSNorm
76 | else:
77 | raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
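
# A quick sketch of RMSNorm and get_norm_layer; both normalize over the last dimension.
if __name__ == "__main__":
    x = torch.randn(2, 16, 64)
    print(RMSNorm(dim=64)(x).shape)                         # torch.Size([2, 16, 64])
    norm_cls = get_norm_layer("rms")                        # returns the RMSNorm class itself
    print(norm_cls(64, elementwise_affine=False)(x).shape)  # same shape, no learnable scale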
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/modules/posemb_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Union, Tuple, List
3 |
4 |
5 | def _to_tuple(x, dim=2):
6 | if isinstance(x, int):
7 | return (x,) * dim
8 | elif len(x) == dim:
9 | return x
10 | else:
11 | raise ValueError(f"Expected length {dim} or int, but got {x}")
12 |
13 |
14 | def get_meshgrid_nd(start, *args, dim=2):
15 | """
16 | Get n-D meshgrid with start, stop and num.
17 |
18 | Args:
19 | start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
20 | step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
21 | should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
22 | n-tuples.
23 | *args: See above.
24 | dim (int): Dimension of the meshgrid. Defaults to 2.
25 |
26 | Returns:
27 | grid (torch.Tensor): [dim, ...]
28 | """
29 | if len(args) == 0:
30 | # start is grid_size
31 | num = _to_tuple(start, dim=dim)
32 | start = (0,) * dim
33 | stop = num
34 | elif len(args) == 1:
35 | # start is start, args[0] is stop, step is 1
36 | start = _to_tuple(start, dim=dim)
37 | stop = _to_tuple(args[0], dim=dim)
38 | num = [stop[i] - start[i] for i in range(dim)]
39 | elif len(args) == 2:
40 | # start is start, args[0] is stop, args[1] is num
41 | start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
42 | stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
43 | num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
44 | else:
45 | raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
46 |
47 | # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
48 | axis_grid = []
49 | for i in range(dim):
50 | a, b, n = start[i], stop[i], num[i]
51 | g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
52 | axis_grid.append(g)
53 | grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
54 | grid = torch.stack(grid, dim=0) # [dim, W, H, D]
55 |
56 | return grid
57 |
58 |
59 | #################################################################################
60 | # Rotary Positional Embedding Functions #
61 | #################################################################################
62 | # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
63 |
64 | def get_nd_rotary_pos_embed(rope_dim_list, start, *args, theta=10000., use_real=False,
65 | theta_rescale_factor: Union[float, List[float]]=1.0,
66 | interpolation_factor: Union[float, List[float]]=1.0):
67 | """
68 | This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
69 |
70 | Args:
71 | rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
72 | sum(rope_dim_list) should equal to head_dim of attention layer.
73 | start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
74 | args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
75 | *args: See above.
76 | theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
77 | use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
78 | Some libraries such as TensorRT do not support the complex64 data type, so it is useful to provide the
79 | real part and the imaginary part separately.
80 | theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
81 |
82 | Returns:
83 | pos_embed (torch.Tensor): [HW, D/2]
84 | """
85 |
86 | grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
87 |
88 | if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
89 | theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
90 | elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
91 | theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
92 | assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
93 |
94 | if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
95 | interpolation_factor = [interpolation_factor] * len(rope_dim_list)
96 | elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
97 | interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
98 | assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)"
99 |
100 | # use 1/ndim of dimensions to encode grid_axis
101 | embs = []
102 | for i in range(len(rope_dim_list)):
103 | emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=use_real,
104 | theta_rescale_factor=theta_rescale_factor[i],
105 | interpolation_factor=interpolation_factor[i]) # 2 x [WHD, rope_dim_list[i]]
106 | embs.append(emb)
107 |
108 | if use_real:
109 | cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
110 | sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
111 | return cos, sin
112 | else:
113 | emb = torch.cat(embs, dim=1) # (WHD, D/2)
114 | return emb
115 |
116 |
117 | def get_1d_rotary_pos_embed(dim: int,
118 | pos: Union[torch.FloatTensor, int],
119 | theta: float = 10000.0,
120 | use_real: bool = False,
121 | theta_rescale_factor: float = 1.0,
122 | interpolation_factor: float = 1.0,
123 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
124 | """
125 | Precompute the frequency tensor for complex exponential (cis) with given dimensions.
126 | (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
127 |
128 | This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
129 | and the end index 'end'. The 'theta' parameter scales the frequencies.
130 | The returned tensor contains complex values in complex64 data type.
131 |
132 | Args:
133 | dim (int): Dimension of the frequency tensor.
134 | pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
135 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
136 | use_real (bool, optional): If True, return real part and imaginary part separately.
137 | Otherwise, return complex numbers.
138 | theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
139 |
140 | Returns:
141 | freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
142 | freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
143 | """
144 | if isinstance(pos, int):
145 | pos = torch.arange(pos).float()
146 |
147 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
148 | # has some connection to NTK literature
149 | if theta_rescale_factor != 1.0:
150 | theta *= theta_rescale_factor ** (dim / (dim - 2))
151 |
152 | freqs = 1.0 / (
153 | theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
154 | ) # [D/2]
155 | freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
156 | if use_real:
157 | freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
158 | freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
159 | return freqs_cos, freqs_sin
160 | else:
161 | freqs_cis = torch.polar(
162 | torch.ones_like(freqs), freqs
163 | ) # complex64 # [S, D/2]
164 | return freqs_cis
165 |
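
# A hedged sketch of the RoPE helpers with illustrative sizes: a 1-D table for 128 positions,
# and a 3-D table for a (T, H, W) = (4, 8, 8) latent grid with head_dim = 128.
if __name__ == "__main__":
    cos1d, sin1d = get_1d_rotary_pos_embed(dim=64, pos=128, use_real=True)
    print(cos1d.shape, sin1d.shape)      # torch.Size([128, 64]) each
    cos3d, sin3d = get_nd_rotary_pos_embed([16, 56, 56], (4, 8, 8), use_real=True)
    print(cos3d.shape, sin3d.shape)      # torch.Size([256, 128]) each: 4 * 8 * 8 positions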
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/modules/token_refiner.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from einops import rearrange
4 | import torch
5 | import torch.nn as nn
6 |
7 | from .activation_layers import get_activation_layer
8 | from .attn_layers import attention
9 | from .norm_layers import get_norm_layer
10 | from .embed_layers import TimestepEmbedder, TextProjection
11 |
12 | from .mlp_layers import MLP
13 | from .modulate_layers import apply_gate
14 |
15 |
16 | class IndividualTokenRefinerBlock(nn.Module):
17 | def __init__(
18 | self,
19 | hidden_size,
20 | num_heads,
21 | mlp_ratio: float = 4.0,
22 | mlp_drop_rate: float = 0.0,
23 | act_type: str = "silu",
24 | qk_norm: bool = False,
25 | qk_norm_type: str = "layer",
26 | qkv_bias: bool = True,
27 | dtype: Optional[torch.dtype] = None,
28 | device: Optional[torch.device] = None,
29 | ):
30 | factory_kwargs = {'device': device, 'dtype': dtype}
31 | super().__init__()
32 | self.num_heads = num_heads
33 | head_dim = hidden_size // num_heads
34 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
35 |
36 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
37 | self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
38 | qk_norm_layer = get_norm_layer(qk_norm_type)
39 | self.self_attn_q_norm = (
40 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
41 | if qk_norm
42 | else nn.Identity()
43 | )
44 | self.self_attn_k_norm = (
45 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
46 | if qk_norm
47 | else nn.Identity()
48 | )
49 | self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
50 |
51 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
52 | act_layer = get_activation_layer(act_type)
53 | self.mlp = MLP(
54 | in_channels=hidden_size,
55 | hidden_channels=mlp_hidden_dim,
56 | act_layer=act_layer,
57 | drop=mlp_drop_rate,
58 | **factory_kwargs,
59 | )
60 |
61 | self.adaLN_modulation = nn.Sequential(
62 | act_layer(),
63 | nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
64 | )
65 | # Zero-initialize the modulation
66 | nn.init.zeros_(self.adaLN_modulation[1].weight)
67 | nn.init.zeros_(self.adaLN_modulation[1].bias)
68 |
69 | def forward(
70 | self,
71 | x: torch.Tensor,
72 | c: torch.Tensor, # timestep_aware_representations + context_aware_representations
73 | attn_mask: torch.Tensor = None,
74 | ):
75 | gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
76 |
77 | norm_x = self.norm1(x)
78 | qkv = self.self_attn_qkv(norm_x)
79 | q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
80 | # Apply QK-Norm if needed
81 | q = self.self_attn_q_norm(q).to(v)
82 | k = self.self_attn_k_norm(k).to(v)
83 |
84 | # Self-Attention
85 | attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
86 |
87 | x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
88 |
89 | # FFN Layer
90 | x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
91 |
92 | return x
93 |
94 |
95 | class IndividualTokenRefiner(nn.Module):
96 | def __init__(
97 | self,
98 | hidden_size,
99 | num_heads,
100 | depth,
101 | mlp_ratio: float = 4.0,
102 | mlp_drop_rate: float = 0.0,
103 | act_type: str = "silu",
104 | qk_norm: bool = False,
105 | qk_norm_type: str = "layer",
106 | qkv_bias: bool = True,
107 | dtype: Optional[torch.dtype] = None,
108 | device: Optional[torch.device] = None,
109 | ):
110 | factory_kwargs = {'device': device, 'dtype': dtype}
111 | super().__init__()
112 | self.blocks = nn.ModuleList([
113 | IndividualTokenRefinerBlock(
114 | hidden_size=hidden_size,
115 | num_heads=num_heads,
116 | mlp_ratio=mlp_ratio,
117 | mlp_drop_rate=mlp_drop_rate,
118 | act_type=act_type,
119 | qk_norm=qk_norm,
120 | qk_norm_type=qk_norm_type,
121 | qkv_bias=qkv_bias,
122 | **factory_kwargs,
123 | ) for _ in range(depth)
124 | ])
125 |
126 | def forward(
127 | self,
128 | x: torch.Tensor,
129 | c: torch.LongTensor,
130 | mask: Optional[torch.Tensor] = None,
131 | ):
132 | self_attn_mask = None
133 | if mask is not None:
134 | batch_size = mask.shape[0]
135 | seq_len = mask.shape[1]
136 | mask = mask.to(x.device)
137 | # batch_size x 1 x seq_len x seq_len
138 | self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
139 | # batch_size x 1 x seq_len x seq_len
140 | self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
141 | # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads
142 | self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
143 | # avoids self-attention weight being NaN for padding tokens
144 | self_attn_mask[:, :, :, 0] = True
145 |
146 | for block in self.blocks:
147 | x = block(x, c, self_attn_mask)
148 | return x
149 |
150 |
151 | class SingleTokenRefiner(nn.Module):
152 | def __init__(
153 | self,
154 | in_channels,
155 | hidden_size,
156 | num_heads,
157 | depth,
158 | mlp_ratio: float = 4.0,
159 | mlp_drop_rate: float = 0.0,
160 | act_type: str = "silu",
161 | qk_norm: bool = False,
162 | qk_norm_type: str = "layer",
163 | qkv_bias: bool = True,
164 | dtype: Optional[torch.dtype] = None,
165 | device: Optional[torch.device] = None,
166 | ):
167 | factory_kwargs = {'device': device, 'dtype': dtype}
168 | super().__init__()
169 |
170 | self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs)
171 |
172 | act_layer = get_activation_layer(act_type)
173 | # Build timestep embedding layer
174 | self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
175 | # Build context embedding layer
176 | self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs)
177 |
178 | self.individual_token_refiner = IndividualTokenRefiner(
179 | hidden_size=hidden_size,
180 | num_heads=num_heads,
181 | depth=depth,
182 | mlp_ratio=mlp_ratio,
183 | mlp_drop_rate=mlp_drop_rate,
184 | act_type=act_type,
185 | qk_norm=qk_norm,
186 | qk_norm_type=qk_norm_type,
187 | qkv_bias=qkv_bias,
188 | **factory_kwargs
189 | )
190 |
191 | def forward(
192 | self,
193 | x: torch.Tensor,
194 | t: torch.LongTensor,
195 | mask: Optional[torch.LongTensor] = None,
196 | ):
197 | timestep_aware_representations = self.t_embedder(t)
198 |
199 | if mask is None:
200 | context_aware_representations = x.mean(dim=1)
201 | else:
202 | mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
203 | context_aware_representations = (
204 | (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
205 | )
206 | context_aware_representations = self.c_embedder(context_aware_representations)
207 | c = timestep_aware_representations + context_aware_representations
208 |
209 | x = self.input_embedder(x)
210 |
211 | x = self.individual_token_refiner(x, c, mask)
212 |
213 | return x
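
# A hedged end-to-end sketch of the token refiner on dummy text features. The sizes are
# illustrative (not the production config), and attn_layers.attention with mode="torch"
# must be importable from this package for the forward pass to run.
if __name__ == "__main__":
    refiner = SingleTokenRefiner(in_channels=512, hidden_size=256, num_heads=4, depth=2)
    x = torch.randn(1, 77, 512)                    # [batch, text tokens, feature dim]
    t = torch.tensor([500.0])                      # diffusion timestep
    mask = torch.ones(1, 77, dtype=torch.long)     # all tokens valid
    print(refiner(x, t, mask).shape)               # torch.Size([1, 77, 256])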
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/sample_batch.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from loguru import logger
4 | import torch
5 | from einops import rearrange
6 | import torch.distributed
7 | from torch.utils.data.distributed import DistributedSampler
8 | from torch.utils.data import DataLoader
9 | from hymm_sp.config import parse_args
10 | from hymm_sp.sample_inference import HunyuanVideoSampler
11 | from hymm_sp.data_kits.video_dataset import JsonDataset
12 | from hymm_sp.data_kits.data_tools import save_videos_grid
13 | from hymm_sp.modules.parallel_states import (
14 | initialize_distributed,
15 | nccl_info,
16 | )
17 |
18 | def main():
19 | args = parse_args()
20 | models_root_path = Path(args.ckpt)
21 | print("*"*20)
22 | initialize_distributed(args.seed)
23 | if not models_root_path.exists():
24 | raise ValueError(f"`models_root` does not exist: {models_root_path}")
25 | print("+"*20)
26 | # Create save folder to save the samples
27 | save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}'
28 | if not os.path.exists(save_path):
29 | os.makedirs(save_path, exist_ok=True)
30 |
31 | # Load models
32 | rank = 0
33 | vae_dtype = torch.float16
34 | device = torch.device("cuda")
35 | if nccl_info.sp_size > 1:
36 | device = torch.device(f"cuda:{torch.distributed.get_rank()}")
37 | rank = torch.distributed.get_rank()
38 |
39 | hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(args.ckpt, args=args, device=device)
40 | # Get the updated args
41 | args = hunyuan_video_sampler.args
42 |
43 | json_dataset = JsonDataset(args)
44 | sampler = DistributedSampler(json_dataset, num_replicas=1, rank=0, shuffle=False, drop_last=False)
45 | json_loader = DataLoader(json_dataset, batch_size=1, shuffle=False, sampler=sampler, drop_last=False)
46 | for batch_index, batch in enumerate(json_loader, start=1):
47 | pixel_value_llava = batch['pixel_value_llava'].to(device)
48 | pixel_value_ref = batch['pixel_value_ref'].to(device)
49 | uncond_pixel_value_llava = batch['uncond_pixel_value_llava']
50 | prompt = batch['prompt'][0]
51 | negative_prompt = batch['negative_prompt'][0]
52 | name = batch['name'][0]
53 | save_name = batch['data_name'][0]
54 | seed = batch['seed']
55 | pixel_value_ref = pixel_value_ref * 2 - 1.
56 | pixel_value_ref_for_vae = rearrange(pixel_value_ref,"b c h w -> b c 1 h w")
57 | with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_dtype != torch.float32):
58 | ref_latents = hunyuan_video_sampler.vae.encode(pixel_value_ref_for_vae.clone()).latent_dist.sample()
59 | uncond_ref_latents = hunyuan_video_sampler.vae.encode(torch.ones_like(pixel_value_ref_for_vae)).latent_dist.sample()
60 | ref_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor)
61 | uncond_ref_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor)
62 |
63 | prompt = args.add_pos_prompt + prompt
64 | negative_prompt = args.add_neg_prompt + negative_prompt
65 | outputs = hunyuan_video_sampler.predict(
66 | prompt=prompt,
67 | name=name,
68 | size=args.video_size,
69 | seed=seed,
70 | pixel_value_llava=pixel_value_llava,
71 | uncond_pixel_value_llava=uncond_pixel_value_llava,
72 | ref_latents=ref_latents,
73 | uncond_ref_latents=uncond_ref_latents,
74 | video_length=args.sample_n_frames,
75 | guidance_scale=args.cfg_scale,
76 | num_images_per_prompt=args.num_images,
77 | negative_prompt=negative_prompt,
78 | infer_steps=args.infer_steps,
79 | flow_shift=args.flow_shift_eval_video,
80 | use_linear_quadratic_schedule=args.use_linear_quadratic_schedule,
81 | linear_schedule_end=args.linear_schedule_end,
82 | use_deepcache=args.use_deepcache,
83 | )
84 |
85 | if rank == 0:
86 | samples = outputs['samples']
87 |             for sample in samples:
88 |                 sample = sample.unsqueeze(0)
89 |                 out_path = f"{save_path}/{save_name}.mp4"
90 |                 save_videos_grid(sample, out_path, fps=25)
91 |                 logger.info(f'Sample saved to: {out_path}')
92 |
93 |
94 | if __name__ == "__main__":
95 | main()
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
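
The per-batch preprocessing above (rescale the reference image to [-1, 1], add a singleton time axis, VAE-encode both the real and an all-ones "unconditional" reference, then multiply by the VAE scaling factor) is the piece most likely to be reused elsewhere. A minimal sketch of that step in isolation; the helper name `encode_reference` is ours, and it assumes a diffusers-style VAE with an `encode(...).latent_dist` interface and a `config.scaling_factor`, as used above:

```python
import torch
from einops import rearrange

@torch.no_grad()
def encode_reference(vae, pixel_value_ref, dtype=torch.float16):
    """Encode a [b, c, h, w] reference image in [0, 1] into single-frame video latents."""
    pixel_value_ref = pixel_value_ref * 2 - 1.0                       # [0, 1] -> [-1, 1]
    ref_for_vae = rearrange(pixel_value_ref, "b c h w -> b c 1 h w")  # add a time axis
    with torch.autocast(device_type="cuda", dtype=dtype, enabled=dtype != torch.float32):
        ref_latents = vae.encode(ref_for_vae).latent_dist.sample()
        uncond_ref_latents = vae.encode(torch.ones_like(ref_for_vae)).latent_dist.sample()
    scale = vae.config.scaling_factor
    return ref_latents * scale, uncond_ref_latents * scale
```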
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/sample_gpu_poor.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from loguru import logger
4 | import torch
5 | from einops import rearrange
6 | import torch.distributed
7 | from torch.utils.data.distributed import DistributedSampler
8 | from torch.utils.data import DataLoader
9 | from hymm_sp.config import parse_args
10 | from hymm_sp.sample_inference import HunyuanVideoSampler
11 | from hymm_sp.data_kits.video_dataset import JsonDataset
12 | from hymm_sp.data_kits.data_tools import save_videos_grid
13 |
14 |
15 | def main():
16 | args = parse_args()
17 | models_root_path = Path(args.ckpt)
18 |
19 | if not models_root_path.exists():
20 |         raise ValueError(f"`models_root` does not exist: {models_root_path}")
21 |
22 | # Create save folder to save the samples
23 | save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}'
24 |     if not os.path.exists(save_path):
25 |         os.makedirs(save_path, exist_ok=True)
26 |
27 | # Load models
28 | rank = 0
29 | vae_dtype = torch.float16
30 | device = torch.device("cuda")
31 |
32 | hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(args.ckpt, args=args, device=device)
33 | # Get the updated args
34 | args = hunyuan_video_sampler.args
35 | if args.cpu_offload:
36 | from diffusers.hooks import apply_group_offloading
37 | onload_device = torch.device("cuda")
38 | apply_group_offloading(hunyuan_video_sampler.pipeline.transformer, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=1)
39 |
40 | json_dataset = JsonDataset(args)
41 | sampler = DistributedSampler(json_dataset, num_replicas=1, rank=0, shuffle=False, drop_last=False)
42 | json_loader = DataLoader(json_dataset, batch_size=1, shuffle=False, sampler=sampler, drop_last=False)
43 | for batch_index, batch in enumerate(json_loader, start=1):
44 | pixel_value_llava = batch['pixel_value_llava'].to(device)
45 | pixel_value_ref = batch['pixel_value_ref'].to(device)
46 | uncond_pixel_value_llava = batch['uncond_pixel_value_llava']
47 | prompt = batch['prompt'][0]
48 | negative_prompt = batch['negative_prompt'][0]
49 | name = batch['name'][0]
50 | save_name = batch['data_name'][0]
51 | seed = batch['seed']
52 | pixel_value_ref = pixel_value_ref * 2 - 1.
53 | pixel_value_ref_for_vae = rearrange(pixel_value_ref,"b c h w -> b c 1 h w")
54 | with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_dtype != torch.float32):
55 | if args.cpu_offload:
56 | hunyuan_video_sampler.vae.to('cuda')
57 | ref_latents = hunyuan_video_sampler.vae.encode(pixel_value_ref_for_vae.clone()).latent_dist.sample()
58 | uncond_ref_latents = hunyuan_video_sampler.vae.encode(torch.ones_like(pixel_value_ref_for_vae)).latent_dist.sample()
59 | ref_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor)
60 | uncond_ref_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor)
61 | if args.cpu_offload:
62 | hunyuan_video_sampler.vae.to('cpu')
63 | torch.cuda.empty_cache()
64 |
65 | prompt = args.add_pos_prompt + prompt
66 | negative_prompt = args.add_neg_prompt + negative_prompt
67 | outputs = hunyuan_video_sampler.predict(
68 | prompt=prompt,
69 | name=name,
70 | size=args.video_size,
71 | seed=seed,
72 | pixel_value_llava=pixel_value_llava,
73 | uncond_pixel_value_llava=uncond_pixel_value_llava,
74 | ref_latents=ref_latents,
75 | uncond_ref_latents=uncond_ref_latents,
76 | video_length=args.sample_n_frames,
77 | guidance_scale=args.cfg_scale,
78 | num_images_per_prompt=args.num_images,
79 | negative_prompt=negative_prompt,
80 | infer_steps=args.infer_steps,
81 | flow_shift=args.flow_shift_eval_video,
82 | use_linear_quadratic_schedule=args.use_linear_quadratic_schedule,
83 | linear_schedule_end=args.linear_schedule_end,
84 | use_deepcache=args.use_deepcache,
85 | cpu_offload=args.cpu_offload,
86 | )
87 |
88 | if rank == 0:
89 | samples = outputs['samples']
90 |             for sample in samples:
91 |                 sample = sample.unsqueeze(0)
92 |                 out_path = f"{save_path}/{save_name}_4090.mp4"
93 |                 save_videos_grid(sample, out_path, fps=25)
94 |                 logger.info(f'Sample saved to: {out_path}')
95 |
96 |
97 | if __name__ == "__main__":
98 | main()
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
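
The only difference from `sample_batch.py` is the low-VRAM path: the transformer blocks are wrapped with diffusers' group offloading, and the VAE is moved onto the GPU only around its `encode` calls. A minimal sketch of that pattern on its own, assuming a diffusers release that ships `diffusers.hooks.apply_group_offloading` (as the script above uses) and an arbitrary block-structured `transformer` module:

```python
import torch
from diffusers.hooks import apply_group_offloading

def enable_low_vram(transformer, vae):
    """Keep only one group of transformer blocks resident on the GPU at a time."""
    apply_group_offloading(
        transformer,
        onload_device=torch.device("cuda"),
        offload_type="block_level",     # swap whole blocks in and out
        num_blocks_per_group=1,         # smallest groups -> lowest peak memory
    )
    # The VAE is small enough to move manually around its encode/decode calls.
    vae.to("cpu")
    return transformer, vae
```

With `num_blocks_per_group=1` the resident weights of a single block bound the transformer's VRAM footprint, at the cost of extra host-device transfers on every denoising step.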
--------------------------------------------------------------------------------
/hunyuan_custom/hymm_sp/vae/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pathlib import Path
3 | from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
4 | from ..constants import VAE_PATH, PRECISION_TO_TYPE
5 |
6 | def load_vae(vae_type,
7 | vae_precision=None,
8 | sample_size=None,
9 | vae_path=None,
10 | logger=None,
11 | device=None
12 | ):
13 | if vae_path is None:
14 | vae_path = VAE_PATH[vae_type]
15 | vae_compress_spec, _, _ = vae_type.split("-")
16 | length = len(vae_compress_spec)
17 | if length == 3:
18 | if logger is not None:
19 | logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")
20 | config = AutoencoderKLCausal3D.load_config(vae_path)
21 | if sample_size:
22 | vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
23 | else:
24 | vae = AutoencoderKLCausal3D.from_config(config)
25 | ckpt = torch.load(Path(vae_path) / "pytorch_model.pt", map_location=vae.device)
26 | if "state_dict" in ckpt:
27 | ckpt = ckpt["state_dict"]
28 | vae_ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
29 | vae.load_state_dict(vae_ckpt)
30 |
31 | spatial_compression_ratio = vae.config.spatial_compression_ratio
32 | time_compression_ratio = vae.config.time_compression_ratio
33 | else:
34 |         raise ValueError(f"Invalid VAE model: {vae_type}. Must be a 3D VAE whose type string is in the format '???-*'.")
35 |
36 | if vae_precision is not None:
37 | vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision])
38 |
39 | vae.requires_grad_(False)
40 |
41 | if logger is not None:
42 | logger.info(f"VAE to dtype: {vae.dtype}")
43 |
44 | if device is not None:
45 | vae = vae.to(device)
46 |
47 |     # Set the VAE to eval mode, even though its dropout rate is 0.
48 | vae.eval()
49 |
50 | return vae, vae_path, spatial_compression_ratio, time_compression_ratio
51 |
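
A hedged usage sketch of `load_vae`. The `"884-16c-hy"` type string and the `"fp16"` precision key are assumptions based on the constants this family of repos usually ships (`VAE_PATH`, `PRECISION_TO_TYPE`); substitute whatever keys your checkout actually defines:

```python
import torch
from loguru import logger
from hymm_sp.vae import load_vae

# "884-16c-hy" encodes a 3D VAE with 4x8x8 time/height/width compression;
# treat the exact string as an assumption, not a guarantee.
vae, vae_path, s_ratio, t_ratio = load_vae(
    vae_type="884-16c-hy",
    vae_precision="fp16",
    logger=logger,
    device=torch.device("cuda"),
)
print(f"spatial compression x{s_ratio}, temporal compression x{t_ratio}")
```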
--------------------------------------------------------------------------------
/hunyuan_custom/models/README.md:
--------------------------------------------------------------------------------
1 | # Download Pretrained Models
2 |
3 | All models are stored in `HunyuanCustom/models` by default, and the file structure is as follows:
4 | ```shell
5 | HunyuanCustom
6 | ├──models
7 | │ ├──README.md
8 | │ ├──hunyuancustom_720P
9 | │ │ ├──mp_rank_00_model_states.pt
10 | │ │ ├──mp_rank_00_model_states_fp8.pt
11 | │ │ ├──mp_rank_00_model_states_fp8_map.pt
12 | │ ├──vae_3d
13 | │ ├──openai_clip-vit-large-patch14
14 | │ ├──llava-llama-3-8b-v1_1
15 | ├──...
16 | ```
17 |
18 | ## Download HunyuanCustom model
19 | To download the HunyuanCustom model, first install the huggingface-cli. (Detailed instructions are available [here](https://huggingface.co/docs/huggingface_hub/guides/cli).)
20 |
21 | ```shell
22 | python -m pip install "huggingface_hub[cli]"
23 | ```
24 |
25 | Then download the model using the following commands:
26 |
27 | ```shell
28 | # Switch to the directory named 'HunyuanCustom'
29 | cd HunyuanCustom
30 | # Use the huggingface-cli tool to download the HunyuanCustom model into the HunyuanCustom/models directory.
31 | # The download time may vary from 10 minutes to 1 hour depending on network conditions.
32 | huggingface-cli download tencent/HunyuanCustom --local-dir ./
33 | ```
34 |
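
If you prefer to script the download rather than use the CLI, the same pull can be done with `huggingface_hub.snapshot_download`; this is an equivalent sketch, not part of the official instructions:

```python
from huggingface_hub import snapshot_download

# Mirrors `huggingface-cli download tencent/HunyuanCustom --local-dir ./`,
# run from inside the HunyuanCustom directory.
snapshot_download(repo_id="tencent/HunyuanCustom", local_dir="./")
```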
--------------------------------------------------------------------------------
/hunyuan_custom/run-single-video-8xA100.sh:
--------------------------------------------------------------------------------
1 |
2 | export MODEL_BASE="./models"
3 | export PYTHONPATH=./
4 |
5 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_sp/sample_batch.py \
6 | --input './assets/images/seg_woman_01.png' \
7 | --pos-prompt "Realistic, High-quality. A woman is drinking coffee at a café." \
8 | --neg-prompt "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border." \
9 | --ckpt ${MODEL_BASE}"/hunyuancustom_720P/mp_rank_00_model_states.pt" \
10 | --video-size 768 1280 \
11 | --seed 1024 \
12 | --sample-n-frames 129 \
13 | --infer-steps 30 \
14 | --flow-shift-eval-video 13.0 \
15 | --save-path './results/sp_768p'
16 |
17 |
--------------------------------------------------------------------------------
/wan/run-single-inference.sh:
--------------------------------------------------------------------------------
1 |
2 | ckpt_dir="/mnt/localssd/wan/Wan2.1-T2V-14B"
3 |
4 |
5 | export CUDA_VISIBLE_DEVICES=0
6 |
7 | # --size 1280*768
8 | # --size 768*512
9 |
10 | python3 -u generate.py \
11 | --task t2v-14B \
12 | --size 768*512 \
13 | --ckpt_dir $ckpt_dir \
14 | --prompt "A giant panda is walking."
15 |
16 | # --prompt "warm colors dominate the room, with a focus on the tabby cat sitting contently in the center. the scene captures the fluffy orange tabby cat wearing a tiny virtual reality headset. the setting is a cozy living room, adorned with soft, warm lighting and a modern aesthetic. a plush sofa is visible in the background, along with a few lush potted plants, adding a touch of greenery. the cat's tail flicks curiously, as if engaging with an unseen virtual environment. its paws swipe at the air, indicating a playful and inquisitive nature, as it delves into the digital realm. the atmosphere is both whimsical and futuristic, highlighting the blend of analog and digital experiences."
17 |
18 | # --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
19 |
20 |
--------------------------------------------------------------------------------
/wan/wan/__init__.py:
--------------------------------------------------------------------------------
1 | from . import configs, distributed, modules
2 | from .image2video import WanI2V
3 | from .text2video import WanT2V
4 | from .first_last_frame2video import WanFLF2V
5 |
--------------------------------------------------------------------------------
/wan/wan/configs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import copy
3 | import os
4 |
5 | os.environ['TOKENIZERS_PARALLELISM'] = 'false'
6 |
7 | from .wan_i2v_14B import i2v_14B
8 | from .wan_t2v_1_3B import t2v_1_3B
9 | from .wan_t2v_14B import t2v_14B
10 |
11 | # the config of t2i_14B is the same as t2v_14B
12 | t2i_14B = copy.deepcopy(t2v_14B)
13 | t2i_14B.__name__ = 'Config: Wan T2I 14B'
14 |
15 | # the config of flf2v_14B is the same as i2v_14B
16 | flf2v_14B = copy.deepcopy(i2v_14B)
17 | flf2v_14B.__name__ = 'Config: Wan FLF2V 14B'
18 | flf2v_14B.sample_neg_prompt = "镜头切换," + flf2v_14B.sample_neg_prompt  # "镜头切换" = "camera cut"
19 |
20 | WAN_CONFIGS = {
21 | 't2v-14B': t2v_14B,
22 | 't2v-1.3B': t2v_1_3B,
23 | 'i2v-14B': i2v_14B,
24 | 't2i-14B': t2i_14B,
25 | 'flf2v-14B': flf2v_14B
26 | }
27 |
28 | SIZE_CONFIGS = {
29 | '720*1280': (720, 1280),
30 | '1280*720': (1280, 720),
31 | '480*832': (480, 832),
32 | '832*480': (832, 480),
33 | '1024*1024': (1024, 1024),
34 | '768*512': (768, 512), # xuan: add
35 | '1280*768': (1280, 768), # xuan: add
36 | }
37 |
38 | MAX_AREA_CONFIGS = {
39 | '720*1280': 720 * 1280,
40 | '1280*720': 1280 * 720,
41 | '480*832': 480 * 832,
42 | '832*480': 832 * 480,
43 | }
44 |
45 | SUPPORTED_SIZES = {
46 | 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
47 | 't2v-1.3B': ('480*832', '832*480'),
48 | 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
49 | 'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
50 | 't2i-14B': tuple(SIZE_CONFIGS.keys()),
51 | }
52 |
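
These tables are meant to be consumed by `generate.py`: the task string selects a config and the size string resolves to a numeric pair. A short sketch of the lookups; the warning branch is our own illustration of the fact that the two sizes added above are not listed in `SUPPORTED_SIZES`:

```python
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES

task, size = "t2v-14B", "768*512"        # "768*512" is one of the sizes added above

cfg = WAN_CONFIGS[task]                  # EasyDict carrying the t5/vae/transformer settings
dims = SIZE_CONFIGS[size]                # "A*B" string -> numeric (A, B) pair
if size not in SUPPORTED_SIZES[task]:
    # The added sizes live only in SIZE_CONFIGS, so a caller that enforces
    # SUPPORTED_SIZES must be relaxed to accept them.
    print(f"{size} is outside SUPPORTED_SIZES[{task!r}]")
print(cfg.__name__, dims, cfg.param_dtype)
```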
--------------------------------------------------------------------------------
/wan/wan/configs/shared_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import torch
3 | from easydict import EasyDict
4 |
5 | #------------------------ Wan shared config ------------------------#
6 | wan_shared_cfg = EasyDict()
7 |
8 | # t5
9 | wan_shared_cfg.t5_model = 'umt5_xxl'
10 | wan_shared_cfg.t5_dtype = torch.bfloat16
11 | wan_shared_cfg.text_len = 512
12 |
13 | # transformer
14 | wan_shared_cfg.param_dtype = torch.bfloat16
15 |
16 | # inference
17 | wan_shared_cfg.num_train_timesteps = 1000
18 | wan_shared_cfg.sample_fps = 16
19 | wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'  # English gloss: garish colors, overexposed, static, blurry details, subtitles, style, artwork, painting, still frame, motionless, overall gray cast, worst quality, low quality, JPEG artifacts, ugly, mutilated, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, frozen frame, cluttered background, three legs, crowded background, walking backwards
20 |
--------------------------------------------------------------------------------
/wan/wan/configs/wan_i2v_14B.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import torch
3 | from easydict import EasyDict
4 |
5 | from .shared_config import wan_shared_cfg
6 |
7 | #------------------------ Wan I2V 14B ------------------------#
8 |
9 | i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
10 | i2v_14B.update(wan_shared_cfg)
11 | i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt  # "镜头晃动" = "camera shake"
12 |
13 | i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
14 | i2v_14B.t5_tokenizer = 'google/umt5-xxl'
15 |
16 | # clip
17 | i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
18 | i2v_14B.clip_dtype = torch.float16
19 | i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
20 | i2v_14B.clip_tokenizer = 'xlm-roberta-large'
21 |
22 | # vae
23 | i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
24 | i2v_14B.vae_stride = (4, 8, 8)
25 |
26 | # transformer
27 | i2v_14B.patch_size = (1, 2, 2)
28 | i2v_14B.dim = 5120
29 | i2v_14B.ffn_dim = 13824
30 | i2v_14B.freq_dim = 256
31 | i2v_14B.num_heads = 40
32 | i2v_14B.num_layers = 40
33 | i2v_14B.window_size = (-1, -1)
34 | i2v_14B.qk_norm = True
35 | i2v_14B.cross_attn_norm = True
36 | i2v_14B.eps = 1e-6
37 |
--------------------------------------------------------------------------------
/wan/wan/configs/wan_t2v_14B.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | from easydict import EasyDict
3 |
4 | from .shared_config import wan_shared_cfg
5 |
6 | #------------------------ Wan T2V 14B ------------------------#
7 |
8 | t2v_14B = EasyDict(__name__='Config: Wan T2V 14B')
9 | t2v_14B.update(wan_shared_cfg)
10 |
11 | # t5
12 | t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13 | t2v_14B.t5_tokenizer = 'google/umt5-xxl'
14 |
15 | # vae
16 | t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
17 | t2v_14B.vae_stride = (4, 8, 8)
18 |
19 | # transformer
20 | t2v_14B.patch_size = (1, 2, 2)
21 | t2v_14B.dim = 5120
22 | t2v_14B.ffn_dim = 13824
23 | t2v_14B.freq_dim = 256
24 | t2v_14B.num_heads = 40
25 | t2v_14B.num_layers = 40
26 | t2v_14B.window_size = (-1, -1)
27 | t2v_14B.qk_norm = True
28 | t2v_14B.cross_attn_norm = True
29 | t2v_14B.eps = 1e-6
30 |
--------------------------------------------------------------------------------
/wan/wan/configs/wan_t2v_1_3B.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | from easydict import EasyDict
3 |
4 | from .shared_config import wan_shared_cfg
5 |
6 | #------------------------ Wan T2V 1.3B ------------------------#
7 |
8 | t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B')
9 | t2v_1_3B.update(wan_shared_cfg)
10 |
11 | # t5
12 | t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13 | t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
14 |
15 | # vae
16 | t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth'
17 | t2v_1_3B.vae_stride = (4, 8, 8)
18 |
19 | # transformer
20 | t2v_1_3B.patch_size = (1, 2, 2)
21 | t2v_1_3B.dim = 1536
22 | t2v_1_3B.ffn_dim = 8960
23 | t2v_1_3B.freq_dim = 256
24 | t2v_1_3B.num_heads = 12
25 | t2v_1_3B.num_layers = 30
26 | t2v_1_3B.window_size = (-1, -1)
27 | t2v_1_3B.qk_norm = True
28 | t2v_1_3B.cross_attn_norm = True
29 | t2v_1_3B.eps = 1e-6
30 |
--------------------------------------------------------------------------------
/wan/wan/distributed/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/wan/wan/distributed/__init__.py
--------------------------------------------------------------------------------
/wan/wan/distributed/fsdp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import gc
3 | from functools import partial
4 |
5 | import torch
6 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
7 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
8 | from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
9 | from torch.distributed.utils import _free_storage
10 |
11 | def shard_model(
12 | model,
13 | device_id,
14 | param_dtype=torch.bfloat16,
15 | reduce_dtype=torch.float32,
16 | buffer_dtype=torch.float32,
17 | process_group=None,
18 | sharding_strategy=ShardingStrategy.FULL_SHARD,
19 | sync_module_states=True,
20 | ):
21 | model = FSDP(
22 | module=model,
23 | process_group=process_group,
24 | sharding_strategy=sharding_strategy,
25 | auto_wrap_policy=partial(
26 | lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
27 | mixed_precision=MixedPrecision(
28 | param_dtype=param_dtype,
29 | reduce_dtype=reduce_dtype,
30 | buffer_dtype=buffer_dtype),
31 | device_id=device_id,
32 | sync_module_states=sync_module_states)
33 | return model
34 |
35 | def free_model(model):
36 | for m in model.modules():
37 | if isinstance(m, FSDP):
38 | _free_storage(m._handle.flat_param.data)
39 | del model
40 | gc.collect()
41 | torch.cuda.empty_cache()
42 |
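
A hedged sketch of how `shard_model` would be driven once `torchrun` has set up the distributed environment; `make_model` is a placeholder for whatever builds a module exposing a `.blocks` list (as `WanModel` does), which the auto-wrap policy above relies on:

```python
import torch
import torch.distributed as dist
from functools import partial
from wan.distributed.fsdp import shard_model

def build_sharded_model(make_model):
    """make_model() must return a module with a .blocks list (as WanModel has)."""
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")   # expects torchrun-provided env vars
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    model = make_model()                          # e.g. a freshly constructed WanModel
    shard_fn = partial(shard_model, device_id=rank)
    return shard_fn(model)
```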
--------------------------------------------------------------------------------
/wan/wan/distributed/xdit_context_parallel.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import torch
3 | import torch.cuda.amp as amp
4 | from xfuser.core.distributed import (get_sequence_parallel_rank,
5 | get_sequence_parallel_world_size,
6 | get_sp_group)
7 | from xfuser.core.long_ctx_attention import xFuserLongContextAttention
8 |
9 | from ..modules.model import sinusoidal_embedding_1d
10 |
11 |
12 | def pad_freqs(original_tensor, target_len):
13 | seq_len, s1, s2 = original_tensor.shape
14 | pad_size = target_len - seq_len
15 | padding_tensor = torch.ones(
16 | pad_size,
17 | s1,
18 | s2,
19 | dtype=original_tensor.dtype,
20 | device=original_tensor.device)
21 | padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
22 | return padded_tensor
23 |
24 |
25 | @amp.autocast(enabled=False)
26 | def rope_apply(x, grid_sizes, freqs):
27 | """
28 | x: [B, L, N, C].
29 | grid_sizes: [B, 3].
30 | freqs: [M, C // 2].
31 | """
32 | s, n, c = x.size(1), x.size(2), x.size(3) // 2
33 | # split freqs
34 | freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
35 |
36 | # loop over samples
37 | output = []
38 | for i, (f, h, w) in enumerate(grid_sizes.tolist()):
39 | seq_len = f * h * w
40 |
41 | # precompute multipliers
42 | x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
43 | s, n, -1, 2))
44 | freqs_i = torch.cat([
45 | freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
46 | freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
47 | freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
48 | ],
49 | dim=-1).reshape(seq_len, 1, -1)
50 |
51 | # apply rotary embedding
52 | sp_size = get_sequence_parallel_world_size()
53 | sp_rank = get_sequence_parallel_rank()
54 | freqs_i = pad_freqs(freqs_i, s * sp_size)
55 | s_per_rank = s
56 | freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
57 | s_per_rank), :, :]
58 | x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
59 | x_i = torch.cat([x_i, x[i, s:]])
60 |
61 | # append to collection
62 | output.append(x_i)
63 | return torch.stack(output).float()
64 |
65 |
66 | def usp_dit_forward(
67 | self,
68 | x,
69 | t,
70 | context,
71 | seq_len,
72 | clip_fea=None,
73 | y=None,
74 | ):
75 | """
76 | x: A list of videos each with shape [C, T, H, W].
77 | t: [B].
78 | context: A list of text embeddings each with shape [L, C].
79 | """
80 | if self.model_type == 'i2v':
81 | assert clip_fea is not None and y is not None
82 | # params
83 | device = self.patch_embedding.weight.device
84 | if self.freqs.device != device:
85 | self.freqs = self.freqs.to(device)
86 |
87 | if y is not None:
88 | x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
89 |
90 | # embeddings
91 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
92 | grid_sizes = torch.stack(
93 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
94 | x = [u.flatten(2).transpose(1, 2) for u in x]
95 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
96 | assert seq_lens.max() <= seq_len
97 | x = torch.cat([
98 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
99 | for u in x
100 | ])
101 |
102 | # time embeddings
103 | with amp.autocast(dtype=torch.float32):
104 | e = self.time_embedding(
105 | sinusoidal_embedding_1d(self.freq_dim, t).float())
106 | e0 = self.time_projection(e).unflatten(1, (6, self.dim))
107 | assert e.dtype == torch.float32 and e0.dtype == torch.float32
108 |
109 | # context
110 | context_lens = None
111 | context = self.text_embedding(
112 | torch.stack([
113 | torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
114 | for u in context
115 | ]))
116 |
117 | if clip_fea is not None:
118 | context_clip = self.img_emb(clip_fea) # bs x 257 x dim
119 | context = torch.concat([context_clip, context], dim=1)
120 |
121 | # arguments
122 | kwargs = dict(
123 | e=e0,
124 | seq_lens=seq_lens,
125 | grid_sizes=grid_sizes,
126 | freqs=self.freqs,
127 | context=context,
128 | context_lens=context_lens)
129 |
130 | # Context Parallel
131 | x = torch.chunk(
132 | x, get_sequence_parallel_world_size(),
133 | dim=1)[get_sequence_parallel_rank()]
134 |
135 | for block in self.blocks:
136 | x = block(x, **kwargs)
137 |
138 | # head
139 | x = self.head(x, e)
140 |
141 | # Context Parallel
142 | x = get_sp_group().all_gather(x, dim=1)
143 |
144 | # unpatchify
145 | x = self.unpatchify(x, grid_sizes)
146 | return [u.float() for u in x]
147 |
148 |
149 | def usp_attn_forward(self,
150 | x,
151 | seq_lens,
152 | grid_sizes,
153 | freqs,
154 | dtype=torch.bfloat16):
155 | b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
156 | half_dtypes = (torch.float16, torch.bfloat16)
157 |
158 | def half(x):
159 | return x if x.dtype in half_dtypes else x.to(dtype)
160 |
161 | # query, key, value function
162 | def qkv_fn(x):
163 | q = self.norm_q(self.q(x)).view(b, s, n, d)
164 | k = self.norm_k(self.k(x)).view(b, s, n, d)
165 | v = self.v(x).view(b, s, n, d)
166 | return q, k, v
167 |
168 | q, k, v = qkv_fn(x)
169 | q = rope_apply(q, grid_sizes, freqs)
170 | k = rope_apply(k, grid_sizes, freqs)
171 |
172 |     # TODO: We should use unpadded q, k, v for attention.
173 | # k_lens = seq_lens // get_sequence_parallel_world_size()
174 | # if k_lens is not None:
175 | # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
176 | # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
177 | # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
178 |
179 | x = xFuserLongContextAttention()(
180 | None,
181 | query=half(q),
182 | key=half(k),
183 | value=half(v),
184 | window_size=self.window_size)
185 |
186 | # TODO: padding after attention.
187 | # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
188 |
189 | # output
190 | x = x.flatten(2)
191 | x = self.o(x)
192 | return x
193 |
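
These two functions are written as unbound replacements for `WanModel.forward` and the self-attention forward, so enabling sequence parallelism amounts to monkey-patching them onto an existing model. A minimal sketch of that binding, assuming xfuser is installed and its sequence-parallel state has already been initialized by the caller, and that the model exposes `.blocks[i].self_attn` as `WanModel` does:

```python
import types
from wan.distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward

def enable_context_parallel(model):
    """Rebind the forward passes so attention runs under xFuser's long-context attention."""
    for block in model.blocks:
        block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
    model.forward = types.MethodType(usp_dit_forward, model)
    return model
```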
--------------------------------------------------------------------------------
/wan/wan/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .attention import flash_attention
2 | from .model import WanModel
3 | from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
4 | from .tokenizers import HuggingfaceTokenizer
5 | from .vae import WanVAE
6 |
7 | __all__ = [
8 | 'WanVAE',
9 | 'WanModel',
10 | 'T5Model',
11 | 'T5Encoder',
12 | 'T5Decoder',
13 | 'T5EncoderModel',
14 | 'HuggingfaceTokenizer',
15 | 'flash_attention',
16 | ]
17 |
--------------------------------------------------------------------------------
/wan/wan/modules/tokenizers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import html
3 | import string
4 |
5 | import ftfy
6 | import regex as re
7 | from transformers import AutoTokenizer
8 |
9 | __all__ = ['HuggingfaceTokenizer']
10 |
11 |
12 | def basic_clean(text):
13 | text = ftfy.fix_text(text)
14 | text = html.unescape(html.unescape(text))
15 | return text.strip()
16 |
17 |
18 | def whitespace_clean(text):
19 | text = re.sub(r'\s+', ' ', text)
20 | text = text.strip()
21 | return text
22 |
23 |
24 | def canonicalize(text, keep_punctuation_exact_string=None):
25 | text = text.replace('_', ' ')
26 | if keep_punctuation_exact_string:
27 | text = keep_punctuation_exact_string.join(
28 | part.translate(str.maketrans('', '', string.punctuation))
29 | for part in text.split(keep_punctuation_exact_string))
30 | else:
31 | text = text.translate(str.maketrans('', '', string.punctuation))
32 | text = text.lower()
33 | text = re.sub(r'\s+', ' ', text)
34 | return text.strip()
35 |
36 |
37 | class HuggingfaceTokenizer:
38 |
39 | def __init__(self, name, seq_len=None, clean=None, **kwargs):
40 | assert clean in (None, 'whitespace', 'lower', 'canonicalize')
41 | self.name = name
42 | self.seq_len = seq_len
43 | self.clean = clean
44 |
45 | # init tokenizer
46 | self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
47 | self.vocab_size = self.tokenizer.vocab_size
48 |
49 | def __call__(self, sequence, **kwargs):
50 | return_mask = kwargs.pop('return_mask', False)
51 |
52 | # arguments
53 | _kwargs = {'return_tensors': 'pt'}
54 | if self.seq_len is not None:
55 | _kwargs.update({
56 | 'padding': 'max_length',
57 | 'truncation': True,
58 | 'max_length': self.seq_len
59 | })
60 | _kwargs.update(**kwargs)
61 |
62 | # tokenization
63 | if isinstance(sequence, str):
64 | sequence = [sequence]
65 | if self.clean:
66 | sequence = [self._clean(u) for u in sequence]
67 | ids = self.tokenizer(sequence, **_kwargs)
68 |
69 | # output
70 | if return_mask:
71 | return ids.input_ids, ids.attention_mask
72 | else:
73 | return ids.input_ids
74 |
75 | def _clean(self, text):
76 | if self.clean == 'whitespace':
77 | text = whitespace_clean(basic_clean(text))
78 | elif self.clean == 'lower':
79 | text = whitespace_clean(basic_clean(text)).lower()
80 | elif self.clean == 'canonicalize':
81 | text = canonicalize(basic_clean(text))
82 | return text
83 |
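
A brief usage sketch of `HuggingfaceTokenizer`. The tokenizer name matches the `t5_tokenizer` entry in the configs above; fetching it requires network access or a local Hugging Face cache, and the prompt string is just an example:

```python
from wan.modules.tokenizers import HuggingfaceTokenizer

tokenizer = HuggingfaceTokenizer(
    name="google/umt5-xxl",      # same name the Wan configs point at
    seq_len=512,                 # pad/truncate to the shared text_len
    clean="whitespace",
)
ids, mask = tokenizer("A giant panda is walking.", return_mask=True)
print(ids.shape, mask.shape)     # both [1, 512]
```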
--------------------------------------------------------------------------------
/wan/wan/modules/xlm_roberta.py:
--------------------------------------------------------------------------------
1 | # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | __all__ = ['XLMRoberta', 'xlm_roberta_large']
8 |
9 |
10 | class SelfAttention(nn.Module):
11 |
12 | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
13 | assert dim % num_heads == 0
14 | super().__init__()
15 | self.dim = dim
16 | self.num_heads = num_heads
17 | self.head_dim = dim // num_heads
18 | self.eps = eps
19 |
20 | # layers
21 | self.q = nn.Linear(dim, dim)
22 | self.k = nn.Linear(dim, dim)
23 | self.v = nn.Linear(dim, dim)
24 | self.o = nn.Linear(dim, dim)
25 | self.dropout = nn.Dropout(dropout)
26 |
27 | def forward(self, x, mask):
28 | """
29 | x: [B, L, C].
30 | """
31 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
32 |
33 | # compute query, key, value
34 | q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
35 | k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
36 | v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
37 |
38 | # compute attention
39 | p = self.dropout.p if self.training else 0.0
40 | x = F.scaled_dot_product_attention(q, k, v, mask, p)
41 | x = x.permute(0, 2, 1, 3).reshape(b, s, c)
42 |
43 | # output
44 | x = self.o(x)
45 | x = self.dropout(x)
46 | return x
47 |
48 |
49 | class AttentionBlock(nn.Module):
50 |
51 | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
52 | super().__init__()
53 | self.dim = dim
54 | self.num_heads = num_heads
55 | self.post_norm = post_norm
56 | self.eps = eps
57 |
58 | # layers
59 | self.attn = SelfAttention(dim, num_heads, dropout, eps)
60 | self.norm1 = nn.LayerNorm(dim, eps=eps)
61 | self.ffn = nn.Sequential(
62 | nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
63 | nn.Dropout(dropout))
64 | self.norm2 = nn.LayerNorm(dim, eps=eps)
65 |
66 | def forward(self, x, mask):
67 | if self.post_norm:
68 | x = self.norm1(x + self.attn(x, mask))
69 | x = self.norm2(x + self.ffn(x))
70 | else:
71 | x = x + self.attn(self.norm1(x), mask)
72 | x = x + self.ffn(self.norm2(x))
73 | return x
74 |
75 |
76 | class XLMRoberta(nn.Module):
77 | """
78 | XLMRobertaModel with no pooler and no LM head.
79 | """
80 |
81 | def __init__(self,
82 | vocab_size=250002,
83 | max_seq_len=514,
84 | type_size=1,
85 | pad_id=1,
86 | dim=1024,
87 | num_heads=16,
88 | num_layers=24,
89 | post_norm=True,
90 | dropout=0.1,
91 | eps=1e-5):
92 | super().__init__()
93 | self.vocab_size = vocab_size
94 | self.max_seq_len = max_seq_len
95 | self.type_size = type_size
96 | self.pad_id = pad_id
97 | self.dim = dim
98 | self.num_heads = num_heads
99 | self.num_layers = num_layers
100 | self.post_norm = post_norm
101 | self.eps = eps
102 |
103 | # embeddings
104 | self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
105 | self.type_embedding = nn.Embedding(type_size, dim)
106 | self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
107 | self.dropout = nn.Dropout(dropout)
108 |
109 | # blocks
110 | self.blocks = nn.ModuleList([
111 | AttentionBlock(dim, num_heads, post_norm, dropout, eps)
112 | for _ in range(num_layers)
113 | ])
114 |
115 | # norm layer
116 | self.norm = nn.LayerNorm(dim, eps=eps)
117 |
118 | def forward(self, ids):
119 | """
120 | ids: [B, L] of torch.LongTensor.
121 | """
122 | b, s = ids.shape
123 | mask = ids.ne(self.pad_id).long()
124 |
125 | # embeddings
126 | x = self.token_embedding(ids) + \
127 | self.type_embedding(torch.zeros_like(ids)) + \
128 | self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
129 | if self.post_norm:
130 | x = self.norm(x)
131 | x = self.dropout(x)
132 |
133 | # blocks
134 | mask = torch.where(
135 | mask.view(b, 1, 1, s).gt(0), 0.0,
136 | torch.finfo(x.dtype).min)
137 | for block in self.blocks:
138 | x = block(x, mask)
139 |
140 | # output
141 | if not self.post_norm:
142 | x = self.norm(x)
143 | return x
144 |
145 |
146 | def xlm_roberta_large(pretrained=False,
147 | return_tokenizer=False,
148 | device='cpu',
149 | **kwargs):
150 | """
151 | XLMRobertaLarge adapted from Huggingface.
152 | """
153 | # params
154 | cfg = dict(
155 | vocab_size=250002,
156 | max_seq_len=514,
157 | type_size=1,
158 | pad_id=1,
159 | dim=1024,
160 | num_heads=16,
161 | num_layers=24,
162 | post_norm=True,
163 | dropout=0.1,
164 | eps=1e-5)
165 | cfg.update(**kwargs)
166 |
167 | # init a model on device
168 | with torch.device(device):
169 | model = XLMRoberta(**cfg)
170 | return model
171 |
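
A quick smoke-test sketch for the module above, using random token ids with the model's own `pad_id`. In the copy shown here, `pretrained` and `return_tokenizer` are accepted but not consumed, so no checkpoint is loaded and the outputs are from random weights:

```python
import torch
from wan.modules.xlm_roberta import xlm_roberta_large

model = xlm_roberta_large()              # randomly initialized in this copy
model.eval()

ids = torch.randint(5, model.vocab_size, (2, 16))   # hypothetical token ids
ids[:, -4:] = model.pad_id                          # pretend the tail is padding
with torch.no_grad():
    features = model(ids)
print(features.shape)                    # [2, 16, 1024]
```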
--------------------------------------------------------------------------------
/wan/wan/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
2 | retrieve_timesteps)
3 | from .fm_solvers_unipc import FlowUniPCMultistepScheduler
4 |
5 | __all__ = [
6 |     'get_sampling_sigmas', 'retrieve_timesteps',
7 | 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
8 | ]
9 |
--------------------------------------------------------------------------------
/wan/wan/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import argparse
3 | import binascii
4 | import os
5 | import os.path as osp
6 |
7 | import imageio
8 | import torch
9 | import torchvision
10 |
11 | __all__ = ['cache_video', 'cache_image', 'str2bool']
12 |
13 |
14 | def rand_name(length=8, suffix=''):
15 | name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
16 | if suffix:
17 | if not suffix.startswith('.'):
18 | suffix = '.' + suffix
19 | name += suffix
20 | return name
21 |
22 |
23 | def cache_video(tensor,
24 | save_file=None,
25 | fps=30,
26 | suffix='.mp4',
27 | nrow=8,
28 | normalize=True,
29 | value_range=(-1, 1),
30 | retry=5):
31 | # cache file
32 | cache_file = osp.join('/tmp', rand_name(
33 | suffix=suffix)) if save_file is None else save_file
34 |
35 | # save to cache
36 | error = None
37 | for _ in range(retry):
38 | try:
39 | # preprocess
40 | tensor = tensor.clamp(min(value_range), max(value_range))
41 | tensor = torch.stack([
42 | torchvision.utils.make_grid(
43 | u, nrow=nrow, normalize=normalize, value_range=value_range)
44 | for u in tensor.unbind(2)
45 | ],
46 | dim=1).permute(1, 2, 3, 0)
47 | tensor = (tensor * 255).type(torch.uint8).cpu()
48 |
49 | # write video
50 | writer = imageio.get_writer(
51 | cache_file, fps=fps, codec='libx264', quality=8)
52 | for frame in tensor.numpy():
53 | writer.append_data(frame)
54 | writer.close()
55 | return cache_file
56 | except Exception as e:
57 | error = e
58 | continue
59 |     else:  # for-else: runs only if every retry attempt raised
60 | print(f'cache_video failed, error: {error}', flush=True)
61 | return None
62 |
63 |
64 | def cache_image(tensor,
65 | save_file,
66 | nrow=8,
67 | normalize=True,
68 | value_range=(-1, 1),
69 | retry=5):
70 | # cache file
71 | suffix = osp.splitext(save_file)[1]
72 | if suffix.lower() not in [
73 | '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp'
74 | ]:
75 | suffix = '.png'
76 |
77 | # save to cache
78 | error = None
79 | for _ in range(retry):
80 | try:
81 | tensor = tensor.clamp(min(value_range), max(value_range))
82 | torchvision.utils.save_image(
83 | tensor,
84 | save_file,
85 | nrow=nrow,
86 | normalize=normalize,
87 | value_range=value_range)
88 | return save_file
89 | except Exception as e:
90 | error = e
91 | continue
92 |
93 |
94 | def str2bool(v):
95 | """
96 | Convert a string to a boolean.
97 |
98 | Supported true values: 'yes', 'true', 't', 'y', '1'
99 | Supported false values: 'no', 'false', 'f', 'n', '0'
100 |
101 | Args:
102 | v (str): String to convert.
103 |
104 | Returns:
105 | bool: Converted boolean value.
106 |
107 | Raises:
108 | argparse.ArgumentTypeError: If the value cannot be converted to boolean.
109 | """
110 | if isinstance(v, bool):
111 | return v
112 | v_lower = v.lower()
113 | if v_lower in ('yes', 'true', 't', 'y', '1'):
114 | return True
115 | elif v_lower in ('no', 'false', 'f', 'n', '0'):
116 | return False
117 | else:
118 | raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
119 |
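
A short sketch tying the helpers together: `str2bool` plugs into argparse as a `type`, and `cache_video` expects a [B, C, T, H, W] tensor in the configured value range (writing MP4 needs imageio's ffmpeg backend). The flag name and the random tensor are placeholders for real arguments and real decoder output:

```python
import argparse
import torch
from wan.utils.utils import cache_video, str2bool

parser = argparse.ArgumentParser()
parser.add_argument("--offload_model", type=str2bool, default=False)
args = parser.parse_args([])                      # e.g. pass ["--offload_model", "true"]

fake_video = torch.rand(1, 3, 8, 64, 64) * 2 - 1  # [B, C, T, H, W] in (-1, 1)
path = cache_video(fake_video, save_file="demo.mp4", fps=16)
print(args.offload_model, path)
```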
--------------------------------------------------------------------------------