├── .gitignore ├── README.md ├── assets ├── draft-attention.png └── video │ ├── demo-bluedress-dense.gif │ ├── demo-bluedress-sp0.9-ours.gif │ ├── demo-bluedress-sp0.9-svg.gif │ ├── demo-building-dense.gif │ ├── demo-building-sp0.9-ours.gif │ ├── demo-building-sp0.9-svg.gif │ ├── demo-hunyuan_custom-768p-dense.gif │ ├── demo-hunyuan_custom-768p-sp0.9.gif │ ├── demo-pisa-dense.gif │ ├── demo-pisa-sp0.9-ours.gif │ └── demo-pisa-sp0.9-svg.gif ├── draft_attention.py ├── draft_attention_classifier_free_guidance.py ├── hunyuan ├── hyvideo │ ├── __init__.py │ ├── config.py │ ├── constants.py │ ├── diffusion │ │ ├── __init__.py │ │ ├── pipelines │ │ │ ├── __init__.py │ │ │ └── pipeline_hunyuan_video.py │ │ └── schedulers │ │ │ ├── __init__.py │ │ │ └── scheduling_flow_match_discrete.py │ ├── inference.py │ ├── modules │ │ ├── __init__.py │ │ ├── activation_layers.py │ │ ├── attenion.py │ │ ├── embed_layers.py │ │ ├── fp8_optimization.py │ │ ├── mlp_layers.py │ │ ├── models.py │ │ ├── modulate_layers.py │ │ ├── norm_layers.py │ │ ├── posemb_layers.py │ │ └── token_refiner.py │ ├── prompt_rewrite.py │ ├── text_encoder │ │ └── __init__.py │ ├── utils │ │ ├── __init__.py │ │ ├── data_utils.py │ │ ├── file_utils.py │ │ ├── helpers.py │ │ └── preprocess_text_encoder_tokenizer_utils.py │ └── vae │ │ ├── __init__.py │ │ ├── autoencoder_kl_causal_3d.py │ │ ├── unet_causal_3d_blocks.py │ │ └── vae.py ├── run-single-sample_video-fp8.sh ├── run-single-sample_video.sh └── sample_video.py ├── hunyuan_custom ├── assets │ ├── images │ │ ├── method.png │ │ ├── poodle.png │ │ ├── seg_boy.png │ │ ├── seg_man_01.png │ │ ├── seg_man_02.png │ │ ├── seg_man_03.png │ │ ├── seg_poodle.png │ │ ├── seg_woman_01.png │ │ ├── seg_woman_02.png │ │ └── seg_woman_03.png │ ├── material │ │ ├── application.png │ │ ├── logo.png │ │ ├── method.png │ │ └── teaser.png │ ├── meta_files.list │ ├── meta_files │ │ └── poodle.json │ └── videos │ │ ├── seg_man_01.mp4 │ │ ├── seg_man_02.mp4 │ │ ├── seg_woman_01.mp4 │ │ └── seg_woman_03.mp4 ├── hymm_gradio │ ├── flask_ref2v.py │ ├── gradio_ref2v.py │ └── tool_for_end2end.py ├── hymm_sp │ ├── __init__.py │ ├── config.py │ ├── constants.py │ ├── data_kits │ │ ├── data_tools.py │ │ └── video_dataset.py │ ├── diffusion │ │ ├── __init__.py │ │ ├── pipelines │ │ │ ├── __init__.py │ │ │ └── pipeline_hunyuan_video_custom.py │ │ └── schedulers │ │ │ ├── __init__.py │ │ │ └── scheduling_flow_match_discrete.py │ ├── helpers.py │ ├── inference.py │ ├── modules │ │ ├── __init__.py │ │ ├── activation_layers.py │ │ ├── attn_layers.py │ │ ├── draft_attention_classifier_free_guidance.py │ │ ├── embed_layers.py │ │ ├── fp8_optimization.py │ │ ├── mlp_layers.py │ │ ├── models.py │ │ ├── modulate_layers.py │ │ ├── norm_layers.py │ │ ├── parallel_states.py │ │ ├── posemb_layers.py │ │ └── token_refiner.py │ ├── sample_batch.py │ ├── sample_gpu_poor.py │ ├── sample_inference.py │ ├── text_encoder │ │ └── __init__.py │ └── vae │ │ ├── __init__.py │ │ ├── autoencoder_kl_causal_3d.py │ │ ├── unet_causal_3d_blocks.py │ │ └── vae.py ├── models │ └── README.md └── run-single-video-8xA100.sh └── wan ├── generate.py ├── run-single-inference.sh └── wan ├── __init__.py ├── configs ├── __init__.py ├── shared_config.py ├── wan_i2v_14B.py ├── wan_t2v_14B.py └── wan_t2v_1_3B.py ├── distributed ├── __init__.py ├── fsdp.py └── xdit_context_parallel.py ├── first_last_frame2video.py ├── image2video.py ├── modules ├── __init__.py ├── attention.py ├── clip.py ├── model.py ├── t5.py ├── tokenizers.py ├── vae.py └── xlm_roberta.py ├── 
text2video.py └── utils ├── __init__.py ├── fm_solvers.py ├── fm_solvers_unipc.py ├── prompt_extend.py ├── qwen_vl_utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | ### PythonVanilla template 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | z-helloworld/ 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .nox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | *.py,cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | cover/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # pyenv 56 | # For a library or package, you might want to ignore these files since the code is 57 | # intended to run in multiple environments; otherwise, check them in: 58 | # .python-version 59 | 60 | # pipenv 61 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 62 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 63 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 64 | # install all needed dependencies. 65 | #Pipfile.lock 66 | 67 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 68 | __pypackages__/ 69 | 70 | 71 | ### CircuitPython template 72 | .Trashes 73 | .metadata_never_index 74 | .fseventsd/ 75 | boot_out.txt 76 | 77 | ### Python template 78 | # Byte-compiled / optimized / DLL files 79 | __pycache__/ 80 | *.py[cod] 81 | *$py.class 82 | 83 | # C extensions 84 | *.so 85 | 86 | # Distribution / packaging 87 | .Python 88 | build/ 89 | develop-eggs/ 90 | dist/ 91 | downloads/ 92 | eggs/ 93 | .eggs/ 94 | lib/ 95 | lib64/ 96 | parts/ 97 | sdist/ 98 | var/ 99 | wheels/ 100 | share/python-wheels/ 101 | *.egg-info/ 102 | .installed.cfg 103 | *.egg 104 | MANIFEST 105 | 106 | # PyInstaller 107 | # Usually these files are written by a python script from a template 108 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
109 | *.manifest 110 | *.spec 111 | 112 | # Installer logs 113 | pip-log.txt 114 | pip-delete-this-directory.txt 115 | 116 | # Unit test / coverage reports 117 | htmlcov/ 118 | .tox/ 119 | .nox/ 120 | .coverage 121 | .coverage.* 122 | .cache 123 | nosetests.xml 124 | coverage.xml 125 | *.cover 126 | *.py,cover 127 | .hypothesis/ 128 | .pytest_cache/ 129 | cover/ 130 | 131 | # Translations 132 | *.mo 133 | *.pot 134 | 135 | # Django stuff: 136 | *.log 137 | local_settings.py 138 | db.sqlite3 139 | db.sqlite3-journal 140 | 141 | # Flask stuff: 142 | instance/ 143 | .webassets-cache 144 | 145 | # Scrapy stuff: 146 | .scrapy 147 | 148 | # Sphinx documentation 149 | docs/_build/ 150 | 151 | # PyBuilder 152 | .pybuilder/ 153 | target/ 154 | 155 | # Jupyter Notebook 156 | .ipynb_checkpoints 157 | 158 | # IPython 159 | profile_default/ 160 | ipython_config.py 161 | 162 | # pyenv 163 | # For a library or package, you might want to ignore these files since the code is 164 | # intended to run in multiple environments; otherwise, check them in: 165 | # .python-version 166 | 167 | # pipenv 168 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 169 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 170 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 171 | # install all needed dependencies. 172 | #Pipfile.lock 173 | 174 | # poetry 175 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 176 | # This is especially recommended for binary packages to ensure reproducibility, and is more 177 | # commonly ignored for libraries. 178 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 179 | #poetry.lock 180 | 181 | # pdm 182 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 183 | #pdm.lock 184 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 185 | # in version control. 186 | # https://pdm.fming.dev/#use-with-ide 187 | .pdm.toml 188 | 189 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 190 | __pypackages__/ 191 | 192 | # Celery stuff 193 | celerybeat-schedule 194 | celerybeat.pid 195 | 196 | # SageMath parsed files 197 | *.sage.py 198 | 199 | # Environments 200 | .env 201 | .venv 202 | env/ 203 | venv/ 204 | ENV/ 205 | env.bak/ 206 | venv.bak/ 207 | 208 | # Spyder project settings 209 | .spyderproject 210 | .spyproject 211 | 212 | # Rope project settings 213 | .ropeproject 214 | 215 | # mkdocs documentation 216 | /site 217 | 218 | # mypy 219 | .mypy_cache/ 220 | .dmypy.json 221 | dmypy.json 222 | 223 | # Pyre type checker 224 | .pyre/ 225 | 226 | # pytype static type analyzer 227 | .pytype/ 228 | 229 | # Cython debug symbols 230 | cython_debug/ 231 | 232 | # PyCharm 233 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 234 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 235 | # and can be added to the global gitignore or merged into this file. For a more nuclear 236 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
237 | #.idea/ 238 | 239 | /z-helloworld/ 240 | /z-helloworld/ 241 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

# Draft Attention

This repository provides an overview of all resources for the paper
["DraftAttention: Fast Video Diffusion via Low-Resolution Attention Guidance"](https://arxiv.org/abs/2505.14708).

Draft Attention is a plug-and-play acceleration method for video diffusion transformers.

Draft Attention reshapes the long query and key sequences into frame-wise feature maps and applies 2D average pooling to downsample them.

Draft Attention uses the resulting low-resolution attention map as guidance for the sparse attention computed over the full-length sequence.

Draft Attention introduces minimal overhead by compressing the number of tokens by 128x or more.
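Conceptually, the draft stage works as sketched below. This is a minimal, illustrative sketch rather than the repository's implementation: the function name `draft_block_mask`, the tensor layout, and the top-k block selection are our own assumptions; the repository implements this logic inside the `Draft_Attention` module described later and executes the selected blocks with a block-sparse attention kernel.

```python
import torch
import torch.nn.functional as F


def draft_block_mask(q, k, num_frames, latent_h, latent_w,
                     pool_h=8, pool_w=16, sparsity_ratio=0.9):
    """Estimate which attention blocks to keep by attending at low resolution.

    q, k: [batch, heads, num_frames * latent_h * latent_w, head_dim]
    Returns a boolean block-level mask; True marks blocks that the
    full-length sparse attention should compute.
    """
    b, h, n, d = q.shape
    assert n == num_frames * latent_h * latent_w

    def pool(x):
        # Reshape the token sequence into per-frame 2D feature maps ...
        x = x.reshape(b * h * num_frames, latent_h, latent_w, d).permute(0, 3, 1, 2)
        # ... and downsample each map with 2D average pooling
        # (with the defaults, 8 x 16 = 128x fewer tokens).
        x = F.avg_pool2d(x, kernel_size=(pool_h, pool_w))
        return x.permute(0, 2, 3, 1).reshape(b, h, -1, d)

    q_low, k_low = pool(q), pool(k)
    # Low-resolution "draft" attention map; each entry scores one block
    # of the full-resolution attention map.
    scores = torch.softmax(q_low @ k_low.transpose(-1, -2) / d ** 0.5, dim=-1)

    # Keep only the highest-scoring blocks per query row, matching the target sparsity.
    keep = max(1, int(scores.shape[-1] * (1.0 - sparsity_ratio)))
    top_idx = scores.topk(keep, dim=-1).indices
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask.scatter_(-1, top_idx, torch.ones_like(top_idx, dtype=torch.bool))
    return mask
```

With the default `pool_h=8` and `pool_w=16`, each pooled token summarizes 128 full-resolution tokens, which is where the 128x-or-larger compression above comes from.
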
## 🔥 News
- [2025/05] We support [HunyuanCustom](https://github.com/Tencent/HunyuanCustom) with classifier-free guidance.


## 🎥 Demo

### Hunyuan

*(Demo GIFs, left to right: Dense Attention · Sparse Video Generation (SVG) · Draft Attention (Ours))*

Prompt:
"The banks of the Thames, as the camera moves vertically from low to high."

*(Demo GIFs, left to right: Dense Attention · Sparse Video Generation (SVG) · Draft Attention (Ours))*

Prompt:
"On the green grass, the white-walled Leaning Tower of Pisa stands tall. The camera moves vertically from top to bottom during filming."

*(Demo GIFs, left to right: Dense Attention · Sparse Video Generation (SVG) · Draft Attention (Ours))*

Prompt:
"A blue long dress fell from the balcony clothes rack and dropped into the water on the ground."

Prompts are all from the Penguin Video Benchmark.

Videos are generated with 90% sparsity and seed 42, using the Hunyuan model at 768p on an A100 GPU.

### HunyuanCustom

*(Demo GIFs, left to right: Input Image · Dense Attention · Draft Attention (Ours))*

Prompt:
"Realistic, High-quality. A woman is drinking coffee at a café."

Videos are generated with seed 42 at 768p resolution on 8xA100 GPUs, with either dense attention or 90% sparse attention.


## 🚀 Quick Start

### Model Preparation
Please follow the environment setup instructions and download the checkpoints from [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [Wan2.1](https://github.com/Wan-Video/Wan2.1), and [HunyuanCustom](https://github.com/Tencent/HunyuanCustom).

### Sparse Attention
We mainly adopt [block sparse attention](https://github.com/mit-han-lab/Block-Sparse-Attention) as the sparse attention backend for draft attention.

### Video Generation
Simply run video generation with the scripts in `hunyuan/`, `wan/`, or `hunyuan_custom/`.

Evaluation results in the paper are mainly obtained with [VBench](https://github.com/Vchitect/VBench) on the [Penguin Video Benchmark](https://github.com/Tencent/HunyuanVideo/blob/main/assets/PenguinVideoBenchmark.csv) using HunyuanVideo and Wan2.1.

### Use for Your Own
You can use draft attention much like flash attention through the `Draft_Attention` class defined in `draft_attention.py` or `draft_attention_classifier_free_guidance.py`.

Here is an example for the Hunyuan model:
```python3
from draft_attention import Draft_Attention

draft_attention = Draft_Attention(
    pool_h=8,            # 2D average-pooling kernel height applied to each frame's feature map
    pool_w=16,           # 2D average-pooling kernel width
    latent_h=48,         # latent feature-map height per frame
    latent_w=80,         # latent feature-map width per frame
    visual_len=126_720,  # total number of visual tokens (frames x latent_h x latent_w)
    text_len=256,        # number of text tokens in the sequence
    sparsity_ratio=0.9,  # fraction of attention blocks skipped at full resolution
)

# Same calling convention as the flash-attention interface used in HunyuanVideo.
x = draft_attention(
    q,
    k,
    v,
    attn_mask=attn_mask,
    causal=causal,
    drop_rate=drop_rate,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_kv=cu_seqlens_kv,
    max_seqlen_q=max_seqlen_q,
    max_seqlen_kv=max_seqlen_kv,
    batch_size=batch_size,
)
```

## ✏️ TODO
- [ ] Support any-resolution video generation with padding.
- [ ] Support reordering of further block-sparse grouping for faster hardware execution.

## 📑 Acknowledgement
This work is mainly contributed by [Xuan](https://shawnricecake.github.io) and [Chenxia](https://cxhan.com/).
178 | 179 | 180 | ## 🔗 BibTeX 181 | If you find Draft Attention is interesting, please cite through BibTeX: 182 | ```bibtex 183 | @article{shen2025draft, 184 | title={DraftAttention: Fast Video Diffusion via Low-Resolution Attention Guidance}, 185 | author={Shen, Xuan and Han, Chenxia and Zhou, Yufa and Xie, Yanyue and Gong, Yifan and Wang, Quanyi and Wang, Yiwei and Wang, Yanzhi and Zhao, Pu and Gu, Jiuxiang}, 186 | journal={arXiv preprint arXiv:2505.14708}, 187 | year={2025} 188 | } 189 | ``` 190 | -------------------------------------------------------------------------------- /assets/draft-attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/draft-attention.png -------------------------------------------------------------------------------- /assets/video/demo-bluedress-dense.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-bluedress-dense.gif -------------------------------------------------------------------------------- /assets/video/demo-bluedress-sp0.9-ours.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-bluedress-sp0.9-ours.gif -------------------------------------------------------------------------------- /assets/video/demo-bluedress-sp0.9-svg.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-bluedress-sp0.9-svg.gif -------------------------------------------------------------------------------- /assets/video/demo-building-dense.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-building-dense.gif -------------------------------------------------------------------------------- /assets/video/demo-building-sp0.9-ours.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-building-sp0.9-ours.gif -------------------------------------------------------------------------------- /assets/video/demo-building-sp0.9-svg.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-building-sp0.9-svg.gif -------------------------------------------------------------------------------- /assets/video/demo-hunyuan_custom-768p-dense.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-hunyuan_custom-768p-dense.gif -------------------------------------------------------------------------------- /assets/video/demo-hunyuan_custom-768p-sp0.9.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-hunyuan_custom-768p-sp0.9.gif -------------------------------------------------------------------------------- /assets/video/demo-pisa-dense.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-pisa-dense.gif -------------------------------------------------------------------------------- /assets/video/demo-pisa-sp0.9-ours.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-pisa-sp0.9-ours.gif -------------------------------------------------------------------------------- /assets/video/demo-pisa-sp0.9-svg.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/assets/video/demo-pisa-sp0.9-svg.gif -------------------------------------------------------------------------------- /hunyuan/hyvideo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan/hyvideo/__init__.py -------------------------------------------------------------------------------- /hunyuan/hyvideo/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | __all__ = [ 5 | "C_SCALE", 6 | "PROMPT_TEMPLATE", 7 | "MODEL_BASE", 8 | "PRECISIONS", 9 | "NORMALIZATION_TYPE", 10 | "ACTIVATION_TYPE", 11 | "VAE_PATH", 12 | "TEXT_ENCODER_PATH", 13 | "TOKENIZER_PATH", 14 | "TEXT_PROJECTION", 15 | "DATA_TYPE", 16 | "NEGATIVE_PROMPT", 17 | ] 18 | 19 | PRECISION_TO_TYPE = { 20 | 'fp32': torch.float32, 21 | 'fp16': torch.float16, 22 | 'bf16': torch.bfloat16, 23 | } 24 | 25 | # =================== Constant Values ===================== 26 | # Computation scale factor, 1P = 1_000_000_000_000_000. Tensorboard will display the value in PetaFLOPS to avoid 27 | # overflow error when tensorboard logging values. 28 | C_SCALE = 1_000_000_000_000_000 29 | 30 | # When using decoder-only models, we must provide a prompt template to instruct the text encoder 31 | # on how to generate the text. 32 | # -------------------------------------------------------------------- 33 | PROMPT_TEMPLATE_ENCODE = ( 34 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, " 35 | "quantity, text, spatial relationships of the objects and background:<|eot_id|>" 36 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" 37 | ) 38 | PROMPT_TEMPLATE_ENCODE_VIDEO = ( 39 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " 40 | "1. The main content and theme of the video." 41 | "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." 42 | "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." 43 | "4. background environment, light, style and atmosphere." 44 | "5. 
camera angles, movements, and transitions used in the video:<|eot_id|>" 45 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" 46 | ) 47 | 48 | NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion" 49 | 50 | PROMPT_TEMPLATE = { 51 | "dit-llm-encode": { 52 | "template": PROMPT_TEMPLATE_ENCODE, 53 | "crop_start": 36, 54 | }, 55 | "dit-llm-encode-video": { 56 | "template": PROMPT_TEMPLATE_ENCODE_VIDEO, 57 | "crop_start": 95, 58 | }, 59 | } 60 | 61 | # ======================= Model ====================== 62 | PRECISIONS = {"fp32", "fp16", "bf16"} 63 | NORMALIZATION_TYPE = {"layer", "rms"} 64 | ACTIVATION_TYPE = {"relu", "silu", "gelu", "gelu_tanh"} 65 | 66 | # =================== Model Path ===================== 67 | MODEL_BASE = os.getenv("MODEL_BASE", "./ckpts") 68 | 69 | # =================== Data ======================= 70 | DATA_TYPE = {"image", "video", "image_video"} 71 | 72 | # 3D VAE 73 | VAE_PATH = {"884-16c-hy": f"{MODEL_BASE}/hunyuan-video-t2v-720p/vae"} 74 | 75 | # Text Encoder 76 | TEXT_ENCODER_PATH = { 77 | "clipL": f"{MODEL_BASE}/text_encoder_2", 78 | "llm": f"{MODEL_BASE}/text_encoder", 79 | } 80 | 81 | # Tokenizer 82 | TOKENIZER_PATH = { 83 | "clipL": f"{MODEL_BASE}/text_encoder_2", 84 | "llm": f"{MODEL_BASE}/text_encoder", 85 | } 86 | 87 | TEXT_PROJECTION = { 88 | "linear", # Default, an nn.Linear() layer 89 | "single_refiner", # Single TokenRefiner. Refer to LI-DiT 90 | } 91 | -------------------------------------------------------------------------------- /hunyuan/hyvideo/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipelines import HunyuanVideoPipeline 2 | from .schedulers import FlowMatchDiscreteScheduler 3 | -------------------------------------------------------------------------------- /hunyuan/hyvideo/diffusion/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_hunyuan_video import HunyuanVideoPipeline 2 | -------------------------------------------------------------------------------- /hunyuan/hyvideo/diffusion/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler 2 | -------------------------------------------------------------------------------- /hunyuan/hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | # 16 | # Modified from diffusers==0.29.2 17 | # 18 | # ============================================================================== 19 | 20 | from dataclasses import dataclass 21 | from typing import Optional, Tuple, Union 22 | 23 | import numpy as np 24 | import torch 25 | 26 | from diffusers.configuration_utils import ConfigMixin, register_to_config 27 | from diffusers.utils import BaseOutput, logging 28 | from diffusers.schedulers.scheduling_utils import SchedulerMixin 29 | 30 | 31 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 32 | 33 | 34 | @dataclass 35 | class FlowMatchDiscreteSchedulerOutput(BaseOutput): 36 | """ 37 | Output class for the scheduler's `step` function output. 38 | 39 | Args: 40 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 41 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the 42 | denoising loop. 43 | """ 44 | 45 | prev_sample: torch.FloatTensor 46 | 47 | 48 | class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin): 49 | """ 50 | Euler scheduler. 51 | 52 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic 53 | methods the library implements for all schedulers such as loading and saving. 54 | 55 | Args: 56 | num_train_timesteps (`int`, defaults to 1000): 57 | The number of diffusion steps to train the model. 58 | timestep_spacing (`str`, defaults to `"linspace"`): 59 | The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and 60 | Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. 61 | shift (`float`, defaults to 1.0): 62 | The shift value for the timestep schedule. 63 | reverse (`bool`, defaults to `True`): 64 | Whether to reverse the timestep schedule. 65 | """ 66 | 67 | _compatibles = [] 68 | order = 1 69 | 70 | @register_to_config 71 | def __init__( 72 | self, 73 | num_train_timesteps: int = 1000, 74 | shift: float = 1.0, 75 | reverse: bool = True, 76 | solver: str = "euler", 77 | n_tokens: Optional[int] = None, 78 | ): 79 | sigmas = torch.linspace(1, 0, num_train_timesteps + 1) 80 | 81 | if not reverse: 82 | sigmas = sigmas.flip(0) 83 | 84 | self.sigmas = sigmas 85 | # the value fed to model 86 | self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32) 87 | 88 | self._step_index = None 89 | self._begin_index = None 90 | 91 | self.supported_solver = ["euler"] 92 | if solver not in self.supported_solver: 93 | raise ValueError( 94 | f"Solver {solver} not supported. Supported solvers: {self.supported_solver}" 95 | ) 96 | 97 | @property 98 | def step_index(self): 99 | """ 100 | The index counter for current timestep. It will increase 1 after each scheduler step. 101 | """ 102 | return self._step_index 103 | 104 | @property 105 | def begin_index(self): 106 | """ 107 | The index for the first timestep. It should be set from pipeline with `set_begin_index` method. 108 | """ 109 | return self._begin_index 110 | 111 | # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index 112 | def set_begin_index(self, begin_index: int = 0): 113 | """ 114 | Sets the begin index for the scheduler. This function should be run from pipeline before the inference. 
115 | 116 | Args: 117 | begin_index (`int`): 118 | The begin index for the scheduler. 119 | """ 120 | self._begin_index = begin_index 121 | 122 | def _sigma_to_t(self, sigma): 123 | return sigma * self.config.num_train_timesteps 124 | 125 | def set_timesteps( 126 | self, 127 | num_inference_steps: int, 128 | device: Union[str, torch.device] = None, 129 | n_tokens: int = None, 130 | ): 131 | """ 132 | Sets the discrete timesteps used for the diffusion chain (to be run before inference). 133 | 134 | Args: 135 | num_inference_steps (`int`): 136 | The number of diffusion steps used when generating samples with a pre-trained model. 137 | device (`str` or `torch.device`, *optional*): 138 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 139 | n_tokens (`int`, *optional*): 140 | Number of tokens in the input sequence. 141 | """ 142 | self.num_inference_steps = num_inference_steps 143 | 144 | sigmas = torch.linspace(1, 0, num_inference_steps + 1) 145 | sigmas = self.sd3_time_shift(sigmas) 146 | 147 | if not self.config.reverse: 148 | sigmas = 1 - sigmas 149 | 150 | self.sigmas = sigmas 151 | self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to( 152 | dtype=torch.float32, device=device 153 | ) 154 | 155 | # Reset step index 156 | self._step_index = None 157 | 158 | def index_for_timestep(self, timestep, schedule_timesteps=None): 159 | if schedule_timesteps is None: 160 | schedule_timesteps = self.timesteps 161 | 162 | indices = (schedule_timesteps == timestep).nonzero() 163 | 164 | # The sigma index that is taken for the **very** first `step` 165 | # is always the second index (or the last index if there is only 1) 166 | # This way we can ensure we don't accidentally skip a sigma in 167 | # case we start in the middle of the denoising schedule (e.g. for image-to-image) 168 | pos = 1 if len(indices) > 1 else 0 169 | 170 | return indices[pos].item() 171 | 172 | def _init_step_index(self, timestep): 173 | if self.begin_index is None: 174 | if isinstance(timestep, torch.Tensor): 175 | timestep = timestep.to(self.timesteps.device) 176 | self._step_index = self.index_for_timestep(timestep) 177 | else: 178 | self._step_index = self._begin_index 179 | 180 | def scale_model_input( 181 | self, sample: torch.Tensor, timestep: Optional[int] = None 182 | ) -> torch.Tensor: 183 | return sample 184 | 185 | def sd3_time_shift(self, t: torch.Tensor): 186 | return (self.config.shift * t) / (1 + (self.config.shift - 1) * t) 187 | 188 | def step( 189 | self, 190 | model_output: torch.FloatTensor, 191 | timestep: Union[float, torch.FloatTensor], 192 | sample: torch.FloatTensor, 193 | return_dict: bool = True, 194 | ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]: 195 | """ 196 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 197 | process from the learned model outputs (most often the predicted noise). 198 | 199 | Args: 200 | model_output (`torch.FloatTensor`): 201 | The direct output from learned diffusion model. 202 | timestep (`float`): 203 | The current discrete timestep in the diffusion chain. 204 | sample (`torch.FloatTensor`): 205 | A current instance of a sample created by the diffusion process. 206 | generator (`torch.Generator`, *optional*): 207 | A random number generator. 208 | n_tokens (`int`, *optional*): 209 | Number of tokens in the input sequence. 
210 | return_dict (`bool`): 211 | Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or 212 | tuple. 213 | 214 | Returns: 215 | [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: 216 | If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is 217 | returned, otherwise a tuple is returned where the first element is the sample tensor. 218 | """ 219 | 220 | if ( 221 | isinstance(timestep, int) 222 | or isinstance(timestep, torch.IntTensor) 223 | or isinstance(timestep, torch.LongTensor) 224 | ): 225 | raise ValueError( 226 | ( 227 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 228 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 229 | " one of the `scheduler.timesteps` as a timestep." 230 | ), 231 | ) 232 | 233 | if self.step_index is None: 234 | self._init_step_index(timestep) 235 | 236 | # Upcast to avoid precision issues when computing prev_sample 237 | sample = sample.to(torch.float32) 238 | 239 | dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index] 240 | 241 | if self.config.solver == "euler": 242 | prev_sample = sample + model_output.to(torch.float32) * dt 243 | else: 244 | raise ValueError( 245 | f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}" 246 | ) 247 | 248 | # upon completion increase step index by one 249 | self._step_index += 1 250 | 251 | if not return_dict: 252 | return (prev_sample,) 253 | 254 | return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample) 255 | 256 | def __len__(self): 257 | return self.config.num_train_timesteps 258 | -------------------------------------------------------------------------------- /hunyuan/hyvideo/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG 2 | 3 | 4 | def load_model(args, in_channels, out_channels, factor_kwargs): 5 | """load hunyuan video model 6 | 7 | Args: 8 | args (dict): model args 9 | in_channels (int): input channels number 10 | out_channels (int): output channels number 11 | factor_kwargs (dict): factor kwargs 12 | 13 | Returns: 14 | model (nn.Module): The hunyuan video model 15 | """ 16 | if args.model in HUNYUAN_VIDEO_CONFIG.keys(): 17 | model = HYVideoDiffusionTransformer( 18 | args, 19 | in_channels=in_channels, 20 | out_channels=out_channels, 21 | **HUNYUAN_VIDEO_CONFIG[args.model], 22 | **factor_kwargs, 23 | ) 24 | return model 25 | else: 26 | raise NotImplementedError() 27 | -------------------------------------------------------------------------------- /hunyuan/hyvideo/modules/activation_layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def get_activation_layer(act_type): 5 | """get activation layer 6 | 7 | Args: 8 | act_type (str): the activation type 9 | 10 | Returns: 11 | torch.nn.functional: the activation layer 12 | """ 13 | if act_type == "gelu": 14 | return lambda: nn.GELU() 15 | elif act_type == "gelu_tanh": 16 | # Approximate `tanh` requires torch >= 1.13 17 | return lambda: nn.GELU(approximate="tanh") 18 | elif act_type == "relu": 19 | return nn.ReLU 20 | elif act_type == "silu": 21 | return nn.SiLU 22 | else: 23 | raise ValueError(f"Unknown activation type: {act_type}") 24 | -------------------------------------------------------------------------------- 
/hunyuan/hyvideo/modules/embed_layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from einops import rearrange, repeat 5 | 6 | from ..utils.helpers import to_2tuple 7 | 8 | 9 | class PatchEmbed(nn.Module): 10 | """2D Image to Patch Embedding 11 | 12 | Image to Patch Embedding using Conv2d 13 | 14 | A convolution based approach to patchifying a 2D image w/ embedding projection. 15 | 16 | Based on the impl in https://github.com/google-research/vision_transformer 17 | 18 | Hacked together by / Copyright 2020 Ross Wightman 19 | 20 | Remove the _assert function in forward function to be compatible with multi-resolution images. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | patch_size=16, 26 | in_chans=3, 27 | embed_dim=768, 28 | norm_layer=None, 29 | flatten=True, 30 | bias=True, 31 | dtype=None, 32 | device=None, 33 | ): 34 | factory_kwargs = {"dtype": dtype, "device": device} 35 | super().__init__() 36 | patch_size = to_2tuple(patch_size) 37 | self.patch_size = patch_size 38 | self.flatten = flatten 39 | 40 | self.proj = nn.Conv3d( 41 | in_chans, 42 | embed_dim, 43 | kernel_size=patch_size, 44 | stride=patch_size, 45 | bias=bias, 46 | **factory_kwargs 47 | ) 48 | nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1)) 49 | if bias: 50 | nn.init.zeros_(self.proj.bias) 51 | 52 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 53 | 54 | def forward(self, x): 55 | x = self.proj(x) 56 | if self.flatten: 57 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 58 | x = self.norm(x) 59 | return x 60 | 61 | 62 | class TextProjection(nn.Module): 63 | """ 64 | Projects text embeddings. Also handles dropout for classifier-free guidance. 65 | 66 | Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py 67 | """ 68 | 69 | def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): 70 | factory_kwargs = {"dtype": dtype, "device": device} 71 | super().__init__() 72 | self.linear_1 = nn.Linear( 73 | in_features=in_channels, 74 | out_features=hidden_size, 75 | bias=True, 76 | **factory_kwargs 77 | ) 78 | self.act_1 = act_layer() 79 | self.linear_2 = nn.Linear( 80 | in_features=hidden_size, 81 | out_features=hidden_size, 82 | bias=True, 83 | **factory_kwargs 84 | ) 85 | 86 | def forward(self, caption): 87 | hidden_states = self.linear_1(caption) 88 | hidden_states = self.act_1(hidden_states) 89 | hidden_states = self.linear_2(hidden_states) 90 | return hidden_states 91 | 92 | 93 | def timestep_embedding(t, dim, max_period=10000): 94 | """ 95 | Create sinusoidal timestep embeddings. 96 | 97 | Args: 98 | t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. 99 | dim (int): the dimension of the output. 100 | max_period (int): controls the minimum frequency of the embeddings. 101 | 102 | Returns: 103 | embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. 104 | 105 | .. 
ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 106 | """ 107 | half = dim // 2 108 | freqs = torch.exp( 109 | -math.log(max_period) 110 | * torch.arange(start=0, end=half, dtype=torch.float32) 111 | / half 112 | ).to(device=t.device) 113 | args = t[:, None].float() * freqs[None] 114 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 115 | if dim % 2: 116 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 117 | return embedding 118 | 119 | 120 | class TimestepEmbedder(nn.Module): 121 | """ 122 | Embeds scalar timesteps into vector representations. 123 | """ 124 | 125 | def __init__( 126 | self, 127 | hidden_size, 128 | act_layer, 129 | frequency_embedding_size=256, 130 | max_period=10000, 131 | out_size=None, 132 | dtype=None, 133 | device=None, 134 | ): 135 | factory_kwargs = {"dtype": dtype, "device": device} 136 | super().__init__() 137 | self.frequency_embedding_size = frequency_embedding_size 138 | self.max_period = max_period 139 | if out_size is None: 140 | out_size = hidden_size 141 | 142 | self.mlp = nn.Sequential( 143 | nn.Linear( 144 | frequency_embedding_size, hidden_size, bias=True, **factory_kwargs 145 | ), 146 | act_layer(), 147 | nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), 148 | ) 149 | nn.init.normal_(self.mlp[0].weight, std=0.02) 150 | nn.init.normal_(self.mlp[2].weight, std=0.02) 151 | 152 | def forward(self, t): 153 | t_freq = timestep_embedding( 154 | t, self.frequency_embedding_size, self.max_period 155 | ).type(self.mlp[0].weight.dtype) 156 | t_emb = self.mlp(t_freq) 157 | return t_emb 158 | -------------------------------------------------------------------------------- /hunyuan/hyvideo/modules/fp8_optimization.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1): 8 | _bits = torch.tensor(bits) 9 | _mantissa_bit = torch.tensor(mantissa_bit) 10 | _sign_bits = torch.tensor(sign_bits) 11 | M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits) 12 | E = _bits - _sign_bits - M 13 | bias = 2 ** (E - 1) - 1 14 | mantissa = 1 15 | for i in range(mantissa_bit - 1): 16 | mantissa += 1 / (2 ** (i+1)) 17 | maxval = mantissa * 2 ** (2**E - 1 - bias) 18 | return maxval 19 | 20 | def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1): 21 | """ 22 | Default is E4M3. 
23 | """ 24 | bits = torch.tensor(bits) 25 | mantissa_bit = torch.tensor(mantissa_bit) 26 | sign_bits = torch.tensor(sign_bits) 27 | M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits) 28 | E = bits - sign_bits - M 29 | bias = 2 ** (E - 1) - 1 30 | mantissa = 1 31 | for i in range(mantissa_bit - 1): 32 | mantissa += 1 / (2 ** (i+1)) 33 | maxval = mantissa * 2 ** (2**E - 1 - bias) 34 | minval = - maxval 35 | minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval) 36 | input_clamp = torch.min(torch.max(x, minval), maxval) 37 | log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0) 38 | log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype)) 39 | # dequant 40 | qdq_out = torch.round(input_clamp / log_scales) * log_scales 41 | return qdq_out, log_scales 42 | 43 | def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1): 44 | for i in range(len(x.shape) - 1): 45 | scale = scale.unsqueeze(-1) 46 | new_x = x / scale 47 | quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits) 48 | return quant_dequant_x, scale, log_scales 49 | 50 | def fp8_activation_dequant(qdq_out, scale, dtype): 51 | qdq_out = qdq_out.type(dtype) 52 | quant_dequant_x = qdq_out * scale.to(dtype) 53 | return quant_dequant_x 54 | 55 | def fp8_linear_forward(cls, original_dtype, input): 56 | weight_dtype = cls.weight.dtype 57 | ##### 58 | if cls.weight.dtype != torch.float8_e4m3fn: 59 | maxval = get_fp_maxval() 60 | scale = torch.max(torch.abs(cls.weight.flatten())) / maxval 61 | linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale) 62 | linear_weight = linear_weight.to(torch.float8_e4m3fn) 63 | weight_dtype = linear_weight.dtype 64 | else: 65 | scale = cls.fp8_scale.to(cls.weight.device) 66 | linear_weight = cls.weight 67 | ##### 68 | 69 | if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0: 70 | if True or len(input.shape) == 3: 71 | cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype) 72 | if cls.bias != None: 73 | output = F.linear(input, cls_dequant, cls.bias) 74 | else: 75 | output = F.linear(input, cls_dequant) 76 | return output 77 | else: 78 | return cls.original_forward(input.to(original_dtype)) 79 | else: 80 | return cls.original_forward(input) 81 | 82 | def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={}): 83 | setattr(module, "fp8_matmul_enabled", True) 84 | 85 | # loading fp8 mapping file 86 | fp8_map_path = dit_weight_path.replace('.pt', '_map.pt') 87 | if os.path.exists(fp8_map_path): 88 | fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage) 89 | else: 90 | raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.") 91 | 92 | fp8_layers = [] 93 | for key, layer in module.named_modules(): 94 | if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key): 95 | fp8_layers.append(key) 96 | original_forward = layer.forward 97 | layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn)) 98 | setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype)) 99 | setattr(layer, "original_forward", original_forward) 100 | setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input)) 101 | 102 | 103 | -------------------------------------------------------------------------------- /hunyuan/hyvideo/modules/mlp_layers.py: -------------------------------------------------------------------------------- 1 | # Modified from 
timm library: 2 | # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13 3 | 4 | from functools import partial 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from .modulate_layers import modulate 10 | from ..utils.helpers import to_2tuple 11 | 12 | 13 | class MLP(nn.Module): 14 | """MLP as used in Vision Transformer, MLP-Mixer and related networks""" 15 | 16 | def __init__( 17 | self, 18 | in_channels, 19 | hidden_channels=None, 20 | out_features=None, 21 | act_layer=nn.GELU, 22 | norm_layer=None, 23 | bias=True, 24 | drop=0.0, 25 | use_conv=False, 26 | device=None, 27 | dtype=None, 28 | ): 29 | factory_kwargs = {"device": device, "dtype": dtype} 30 | super().__init__() 31 | out_features = out_features or in_channels 32 | hidden_channels = hidden_channels or in_channels 33 | bias = to_2tuple(bias) 34 | drop_probs = to_2tuple(drop) 35 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear 36 | 37 | self.fc1 = linear_layer( 38 | in_channels, hidden_channels, bias=bias[0], **factory_kwargs 39 | ) 40 | self.act = act_layer() 41 | self.drop1 = nn.Dropout(drop_probs[0]) 42 | self.norm = ( 43 | norm_layer(hidden_channels, **factory_kwargs) 44 | if norm_layer is not None 45 | else nn.Identity() 46 | ) 47 | self.fc2 = linear_layer( 48 | hidden_channels, out_features, bias=bias[1], **factory_kwargs 49 | ) 50 | self.drop2 = nn.Dropout(drop_probs[1]) 51 | 52 | def forward(self, x): 53 | x = self.fc1(x) 54 | x = self.act(x) 55 | x = self.drop1(x) 56 | x = self.norm(x) 57 | x = self.fc2(x) 58 | x = self.drop2(x) 59 | return x 60 | 61 | 62 | # 63 | class MLPEmbedder(nn.Module): 64 | """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py""" 65 | def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None): 66 | factory_kwargs = {"device": device, "dtype": dtype} 67 | super().__init__() 68 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs) 69 | self.silu = nn.SiLU() 70 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs) 71 | 72 | def forward(self, x: torch.Tensor) -> torch.Tensor: 73 | return self.out_layer(self.silu(self.in_layer(x))) 74 | 75 | 76 | class FinalLayer(nn.Module): 77 | """The final layer of DiT.""" 78 | 79 | def __init__( 80 | self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None 81 | ): 82 | factory_kwargs = {"device": device, "dtype": dtype} 83 | super().__init__() 84 | 85 | # Just use LayerNorm for the final layer 86 | self.norm_final = nn.LayerNorm( 87 | hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs 88 | ) 89 | if isinstance(patch_size, int): 90 | self.linear = nn.Linear( 91 | hidden_size, 92 | patch_size * patch_size * out_channels, 93 | bias=True, 94 | **factory_kwargs 95 | ) 96 | else: 97 | self.linear = nn.Linear( 98 | hidden_size, 99 | patch_size[0] * patch_size[1] * patch_size[2] * out_channels, 100 | bias=True, 101 | ) 102 | nn.init.zeros_(self.linear.weight) 103 | nn.init.zeros_(self.linear.bias) 104 | 105 | # Here we don't distinguish between the modulate types. Just use the simple one. 
106 | self.adaLN_modulation = nn.Sequential( 107 | act_layer(), 108 | nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), 109 | ) 110 | # Zero-initialize the modulation 111 | nn.init.zeros_(self.adaLN_modulation[1].weight) 112 | nn.init.zeros_(self.adaLN_modulation[1].bias) 113 | 114 | def forward(self, x, c): 115 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 116 | x = modulate(self.norm_final(x), shift=shift, scale=scale) 117 | x = self.linear(x) 118 | return x 119 | -------------------------------------------------------------------------------- /hunyuan/hyvideo/modules/modulate_layers.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class ModulateDiT(nn.Module): 8 | """Modulation layer for DiT.""" 9 | def __init__( 10 | self, 11 | hidden_size: int, 12 | factor: int, 13 | act_layer: Callable, 14 | dtype=None, 15 | device=None, 16 | ): 17 | factory_kwargs = {"dtype": dtype, "device": device} 18 | super().__init__() 19 | self.act = act_layer() 20 | self.linear = nn.Linear( 21 | hidden_size, factor * hidden_size, bias=True, **factory_kwargs 22 | ) 23 | # Zero-initialize the modulation 24 | nn.init.zeros_(self.linear.weight) 25 | nn.init.zeros_(self.linear.bias) 26 | 27 | def forward(self, x: torch.Tensor) -> torch.Tensor: 28 | return self.linear(self.act(x)) 29 | 30 | 31 | def modulate(x, shift=None, scale=None): 32 | """modulate by shift and scale 33 | 34 | Args: 35 | x (torch.Tensor): input tensor. 36 | shift (torch.Tensor, optional): shift tensor. Defaults to None. 37 | scale (torch.Tensor, optional): scale tensor. Defaults to None. 38 | 39 | Returns: 40 | torch.Tensor: the output tensor after modulate. 41 | """ 42 | if scale is None and shift is None: 43 | return x 44 | elif shift is None: 45 | return x * (1 + scale.unsqueeze(1)) 46 | elif scale is None: 47 | return x + shift.unsqueeze(1) 48 | else: 49 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 50 | 51 | 52 | def apply_gate(x, gate=None, tanh=False): 53 | """AI is creating summary for apply_gate 54 | 55 | Args: 56 | x (torch.Tensor): input tensor. 57 | gate (torch.Tensor, optional): gate tensor. Defaults to None. 58 | tanh (bool, optional): whether to use tanh function. Defaults to False. 59 | 60 | Returns: 61 | torch.Tensor: the output tensor after apply gate. 62 | """ 63 | if gate is None: 64 | return x 65 | if tanh: 66 | return x * gate.unsqueeze(1).tanh() 67 | else: 68 | return x * gate.unsqueeze(1) 69 | 70 | 71 | def ckpt_wrapper(module): 72 | def ckpt_forward(*inputs): 73 | outputs = module(*inputs) 74 | return outputs 75 | 76 | return ckpt_forward 77 | -------------------------------------------------------------------------------- /hunyuan/hyvideo/modules/norm_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class RMSNorm(nn.Module): 6 | def __init__( 7 | self, 8 | dim: int, 9 | elementwise_affine=True, 10 | eps: float = 1e-6, 11 | device=None, 12 | dtype=None, 13 | ): 14 | """ 15 | Initialize the RMSNorm normalization layer. 16 | 17 | Args: 18 | dim (int): The dimension of the input tensor. 19 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 20 | 21 | Attributes: 22 | eps (float): A small value added to the denominator for numerical stability. 23 | weight (nn.Parameter): Learnable scaling parameter. 
24 | 25 | """ 26 | factory_kwargs = {"device": device, "dtype": dtype} 27 | super().__init__() 28 | self.eps = eps 29 | if elementwise_affine: 30 | self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) 31 | 32 | def _norm(self, x): 33 | """ 34 | Apply the RMSNorm normalization to the input tensor. 35 | 36 | Args: 37 | x (torch.Tensor): The input tensor. 38 | 39 | Returns: 40 | torch.Tensor: The normalized tensor. 41 | 42 | """ 43 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 44 | 45 | def forward(self, x): 46 | """ 47 | Forward pass through the RMSNorm layer. 48 | 49 | Args: 50 | x (torch.Tensor): The input tensor. 51 | 52 | Returns: 53 | torch.Tensor: The output tensor after applying RMSNorm. 54 | 55 | """ 56 | output = self._norm(x.float()).type_as(x) 57 | if hasattr(self, "weight"): 58 | output = output * self.weight 59 | return output 60 | 61 | 62 | def get_norm_layer(norm_layer): 63 | """ 64 | Get the normalization layer. 65 | 66 | Args: 67 | norm_layer (str): The type of normalization layer. 68 | 69 | Returns: 70 | norm_layer (nn.Module): The normalization layer. 71 | """ 72 | if norm_layer == "layer": 73 | return nn.LayerNorm 74 | elif norm_layer == "rms": 75 | return RMSNorm 76 | else: 77 | raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") 78 | -------------------------------------------------------------------------------- /hunyuan/hyvideo/modules/token_refiner.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from einops import rearrange 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .activation_layers import get_activation_layer 8 | from .attenion import attention 9 | from .norm_layers import get_norm_layer 10 | from .embed_layers import TimestepEmbedder, TextProjection 11 | from .attenion import attention 12 | from .mlp_layers import MLP 13 | from .modulate_layers import modulate, apply_gate 14 | 15 | 16 | class IndividualTokenRefinerBlock(nn.Module): 17 | def __init__( 18 | self, 19 | hidden_size, 20 | heads_num, 21 | mlp_width_ratio: str = 4.0, 22 | mlp_drop_rate: float = 0.0, 23 | act_type: str = "silu", 24 | qk_norm: bool = False, 25 | qk_norm_type: str = "layer", 26 | qkv_bias: bool = True, 27 | dtype: Optional[torch.dtype] = None, 28 | device: Optional[torch.device] = None, 29 | ): 30 | factory_kwargs = {"device": device, "dtype": dtype} 31 | super().__init__() 32 | self.heads_num = heads_num 33 | head_dim = hidden_size // heads_num 34 | mlp_hidden_dim = int(hidden_size * mlp_width_ratio) 35 | 36 | self.norm1 = nn.LayerNorm( 37 | hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs 38 | ) 39 | self.self_attn_qkv = nn.Linear( 40 | hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs 41 | ) 42 | qk_norm_layer = get_norm_layer(qk_norm_type) 43 | self.self_attn_q_norm = ( 44 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) 45 | if qk_norm 46 | else nn.Identity() 47 | ) 48 | self.self_attn_k_norm = ( 49 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) 50 | if qk_norm 51 | else nn.Identity() 52 | ) 53 | self.self_attn_proj = nn.Linear( 54 | hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs 55 | ) 56 | 57 | self.norm2 = nn.LayerNorm( 58 | hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs 59 | ) 60 | act_layer = get_activation_layer(act_type) 61 | self.mlp = MLP( 62 | in_channels=hidden_size, 63 | hidden_channels=mlp_hidden_dim, 64 | 
act_layer=act_layer, 65 | drop=mlp_drop_rate, 66 | **factory_kwargs, 67 | ) 68 | 69 | self.adaLN_modulation = nn.Sequential( 70 | act_layer(), 71 | nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), 72 | ) 73 | # Zero-initialize the modulation 74 | nn.init.zeros_(self.adaLN_modulation[1].weight) 75 | nn.init.zeros_(self.adaLN_modulation[1].bias) 76 | 77 | def forward( 78 | self, 79 | x: torch.Tensor, 80 | c: torch.Tensor, # timestep_aware_representations + context_aware_representations 81 | attn_mask: torch.Tensor = None, 82 | ): 83 | gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) 84 | 85 | norm_x = self.norm1(x) 86 | qkv = self.self_attn_qkv(norm_x) 87 | q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) 88 | # Apply QK-Norm if needed 89 | q = self.self_attn_q_norm(q).to(v) 90 | k = self.self_attn_k_norm(k).to(v) 91 | 92 | # Self-Attention 93 | attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) 94 | 95 | x = x + apply_gate(self.self_attn_proj(attn), gate_msa) 96 | 97 | # FFN Layer 98 | x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) 99 | 100 | return x 101 | 102 | 103 | class IndividualTokenRefiner(nn.Module): 104 | def __init__( 105 | self, 106 | hidden_size, 107 | heads_num, 108 | depth, 109 | mlp_width_ratio: float = 4.0, 110 | mlp_drop_rate: float = 0.0, 111 | act_type: str = "silu", 112 | qk_norm: bool = False, 113 | qk_norm_type: str = "layer", 114 | qkv_bias: bool = True, 115 | dtype: Optional[torch.dtype] = None, 116 | device: Optional[torch.device] = None, 117 | ): 118 | factory_kwargs = {"device": device, "dtype": dtype} 119 | super().__init__() 120 | self.blocks = nn.ModuleList( 121 | [ 122 | IndividualTokenRefinerBlock( 123 | hidden_size=hidden_size, 124 | heads_num=heads_num, 125 | mlp_width_ratio=mlp_width_ratio, 126 | mlp_drop_rate=mlp_drop_rate, 127 | act_type=act_type, 128 | qk_norm=qk_norm, 129 | qk_norm_type=qk_norm_type, 130 | qkv_bias=qkv_bias, 131 | **factory_kwargs, 132 | ) 133 | for _ in range(depth) 134 | ] 135 | ) 136 | 137 | def forward( 138 | self, 139 | x: torch.Tensor, 140 | c: torch.LongTensor, 141 | mask: Optional[torch.Tensor] = None, 142 | ): 143 | self_attn_mask = None 144 | if mask is not None: 145 | batch_size = mask.shape[0] 146 | seq_len = mask.shape[1] 147 | mask = mask.to(x.device) 148 | # batch_size x 1 x seq_len x seq_len 149 | self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat( 150 | 1, 1, seq_len, 1 151 | ) 152 | # batch_size x 1 x seq_len x seq_len 153 | self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) 154 | # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num 155 | self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() 156 | # avoids self-attention weight being NaN for padding tokens 157 | self_attn_mask[:, :, :, 0] = True 158 | 159 | for block in self.blocks: 160 | x = block(x, c, self_attn_mask) 161 | return x 162 | 163 | 164 | class SingleTokenRefiner(nn.Module): 165 | """ 166 | A single token refiner block for llm text embedding refine. 
167 | """ 168 | def __init__( 169 | self, 170 | in_channels, 171 | hidden_size, 172 | heads_num, 173 | depth, 174 | mlp_width_ratio: float = 4.0, 175 | mlp_drop_rate: float = 0.0, 176 | act_type: str = "silu", 177 | qk_norm: bool = False, 178 | qk_norm_type: str = "layer", 179 | qkv_bias: bool = True, 180 | attn_mode: str = "torch", 181 | dtype: Optional[torch.dtype] = None, 182 | device: Optional[torch.device] = None, 183 | ): 184 | factory_kwargs = {"device": device, "dtype": dtype} 185 | super().__init__() 186 | self.attn_mode = attn_mode 187 | assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner." 188 | 189 | self.input_embedder = nn.Linear( 190 | in_channels, hidden_size, bias=True, **factory_kwargs 191 | ) 192 | 193 | act_layer = get_activation_layer(act_type) 194 | # Build timestep embedding layer 195 | self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs) 196 | # Build context embedding layer 197 | self.c_embedder = TextProjection( 198 | in_channels, hidden_size, act_layer, **factory_kwargs 199 | ) 200 | 201 | self.individual_token_refiner = IndividualTokenRefiner( 202 | hidden_size=hidden_size, 203 | heads_num=heads_num, 204 | depth=depth, 205 | mlp_width_ratio=mlp_width_ratio, 206 | mlp_drop_rate=mlp_drop_rate, 207 | act_type=act_type, 208 | qk_norm=qk_norm, 209 | qk_norm_type=qk_norm_type, 210 | qkv_bias=qkv_bias, 211 | **factory_kwargs, 212 | ) 213 | 214 | def forward( 215 | self, 216 | x: torch.Tensor, 217 | t: torch.LongTensor, 218 | mask: Optional[torch.LongTensor] = None, 219 | ): 220 | timestep_aware_representations = self.t_embedder(t) 221 | 222 | if mask is None: 223 | context_aware_representations = x.mean(dim=1) 224 | else: 225 | mask_float = mask.float().unsqueeze(-1) # [b, s1, 1] 226 | context_aware_representations = (x * mask_float).sum( 227 | dim=1 228 | ) / mask_float.sum(dim=1) 229 | context_aware_representations = self.c_embedder(context_aware_representations) 230 | c = timestep_aware_representations + context_aware_representations 231 | 232 | x = self.input_embedder(x) 233 | 234 | x = self.individual_token_refiner(x, c, mask) 235 | 236 | return x 237 | -------------------------------------------------------------------------------- /hunyuan/hyvideo/prompt_rewrite.py: -------------------------------------------------------------------------------- 1 | normal_mode_prompt = """Normal mode - Video Recaption Task: 2 | 3 | You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description. 4 | 5 | 0. Preserve ALL information, including style words and technical terms. 6 | 7 | 1. If the input is in Chinese, translate the entire description to English. 8 | 9 | 2. If the input is just one or two words describing an object or person, provide a brief, simple description focusing on basic visual characteristics. Limit the description to 1-2 short sentences. 10 | 11 | 3. If the input does not include style, lighting, atmosphere, you can make reasonable associations. 12 | 13 | 4. Output ALL must be in English. 14 | 15 | Given Input: 16 | input: "{input}" 17 | """ 18 | 19 | 20 | master_mode_prompt = """Master mode - Video Recaption Task: 21 | 22 | You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description. 23 | 24 | 0. Preserve ALL information, including style words and technical terms. 25 | 26 | 1. If the input is in Chinese, translate the entire description to English. 27 | 28 | 2. 
To generate high-quality visual scenes with aesthetic appeal, it is necessary to carefully depict each visual element to create a unique aesthetic. 29 | 30 | 3. If the input does not include style, lighting, atmosphere, you can make reasonable associations. 31 | 32 | 4. Output ALL must be in English. 33 | 34 | Given Input: 35 | input: "{input}" 36 | """ 37 | 38 | def get_rewrite_prompt(ori_prompt, mode="Normal"): 39 | if mode == "Normal": 40 | prompt = normal_mode_prompt.format(input=ori_prompt) 41 | elif mode == "Master": 42 | prompt = master_mode_prompt.format(input=ori_prompt) 43 | else: 44 | raise Exception("Only supports Normal and Master", mode) 45 | return prompt 46 | 47 | ori_prompt = "一只小狗在草地上奔跑。" 48 | normal_prompt = get_rewrite_prompt(ori_prompt, mode="Normal") 49 | master_prompt = get_rewrite_prompt(ori_prompt, mode="Master") 50 | 51 | # Then you can use the normal_prompt or master_prompt to access the hunyuan-large rewrite model to get the final prompt. 52 | -------------------------------------------------------------------------------- /hunyuan/hyvideo/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan/hyvideo/utils/__init__.py -------------------------------------------------------------------------------- /hunyuan/hyvideo/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | 5 | def align_to(value, alignment): 6 | """align height, width according to alignment 7 | 8 | Args: 9 | value (int): height or width 10 | alignment (int): target alignment factor 11 | 12 | Returns: 13 | int: the aligned value 14 | """ 15 | return int(math.ceil(value / alignment) * alignment) 16 | -------------------------------------------------------------------------------- /hunyuan/hyvideo/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from einops import rearrange 4 | 5 | import torch 6 | import torchvision 7 | import numpy as np 8 | import imageio 9 | 10 | CODE_SUFFIXES = { 11 | ".py", # Python codes 12 | ".sh", # Shell scripts 13 | ".yaml", 14 | ".yml", # Configuration files 15 | } 16 | 17 | 18 | def safe_dir(path): 19 | """ 20 | Create a directory (or the parent directory of a file) if it does not exist. 21 | 22 | Args: 23 | path (str or Path): Path to the directory. 24 | 25 | Returns: 26 | path (Path): Path object of the directory. 27 | """ 28 | path = Path(path) 29 | path.mkdir(exist_ok=True, parents=True) 30 | return path 31 | 32 | 33 | def safe_file(path): 34 | """ 35 | Create the parent directory of a file if it does not exist. 36 | 37 | Args: 38 | path (str or Path): Path to the file. 39 | 40 | Returns: 41 | path (Path): Path object of the file. 42 | """ 43 | path = Path(path) 44 | path.parent.mkdir(exist_ok=True, parents=True) 45 | return path 46 | 47 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24): 48 | """save videos by video tensor 49 | copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61 50 | 51 | Args: 52 | videos (torch.Tensor): video tensor predicted by the model 53 | path (str): path to save video 54 | rescale (bool, optional): rescale the video tensor from [-1, 1] to [0, 1]. Defaults to False.
55 | n_rows (int, optional): Defaults to 1. 56 | fps (int, optional): video save fps. Defaults to 8. 57 | """ 58 | videos = rearrange(videos, "b c t h w -> t b c h w") 59 | outputs = [] 60 | for x in videos: 61 | x = torchvision.utils.make_grid(x, nrow=n_rows) 62 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 63 | if rescale: 64 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 65 | x = torch.clamp(x, 0, 1) 66 | x = (x * 255).numpy().astype(np.uint8) 67 | outputs.append(x) 68 | 69 | os.makedirs(os.path.dirname(path), exist_ok=True) 70 | imageio.mimsave(path, outputs, fps=fps) 71 | -------------------------------------------------------------------------------- /hunyuan/hyvideo/utils/helpers.py: -------------------------------------------------------------------------------- 1 | import collections.abc 2 | 3 | from itertools import repeat 4 | 5 | 6 | def _ntuple(n): 7 | def parse(x): 8 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 9 | x = tuple(x) 10 | if len(x) == 1: 11 | x = tuple(repeat(x[0], n)) 12 | return x 13 | return tuple(repeat(x, n)) 14 | return parse 15 | 16 | 17 | to_1tuple = _ntuple(1) 18 | to_2tuple = _ntuple(2) 19 | to_3tuple = _ntuple(3) 20 | to_4tuple = _ntuple(4) 21 | 22 | 23 | def as_tuple(x): 24 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 25 | return tuple(x) 26 | if x is None or isinstance(x, (int, float, str)): 27 | return (x,) 28 | else: 29 | raise ValueError(f"Unknown type {type(x)}") 30 | 31 | 32 | def as_list_of_2tuple(x): 33 | x = as_tuple(x) 34 | if len(x) == 1: 35 | x = (x[0], x[0]) 36 | assert len(x) % 2 == 0, f"Expect even length, got {len(x)}." 37 | lst = [] 38 | for i in range(0, len(x), 2): 39 | lst.append((x[i], x[i + 1])) 40 | return lst 41 | -------------------------------------------------------------------------------- /hunyuan/hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from transformers import ( 4 | AutoProcessor, 5 | LlavaForConditionalGeneration, 6 | ) 7 | 8 | 9 | def preprocess_text_encoder_tokenizer(args): 10 | 11 | processor = AutoProcessor.from_pretrained(args.input_dir) 12 | model = LlavaForConditionalGeneration.from_pretrained( 13 | args.input_dir, 14 | torch_dtype=torch.float16, 15 | low_cpu_mem_usage=True, 16 | ).to(0) 17 | 18 | model.language_model.save_pretrained( 19 | f"{args.output_dir}" 20 | ) 21 | processor.tokenizer.save_pretrained( 22 | f"{args.output_dir}" 23 | ) 24 | 25 | if __name__ == "__main__": 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument( 29 | "--input_dir", 30 | type=str, 31 | required=True, 32 | help="The path to the llava-llama-3-8b-v1_1-transformers.", 33 | ) 34 | parser.add_argument( 35 | "--output_dir", 36 | type=str, 37 | default="", 38 | help="The output path of the llava-llama-3-8b-text-encoder-tokenizer." 
39 | "if '', the parent dir of output will be the same as input dir.", 40 | ) 41 | args = parser.parse_args() 42 | 43 | if len(args.output_dir) == 0: 44 | args.output_dir = "/".join(args.input_dir.split("/")[:-1]) 45 | 46 | preprocess_text_encoder_tokenizer(args) 47 | -------------------------------------------------------------------------------- /hunyuan/hyvideo/vae/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | 5 | from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D 6 | from ..constants import VAE_PATH, PRECISION_TO_TYPE 7 | 8 | def load_vae(vae_type: str="884-16c-hy", 9 | vae_precision: str=None, 10 | sample_size: tuple=None, 11 | vae_path: str=None, 12 | logger=None, 13 | device=None 14 | ): 15 | """the fucntion to load the 3D VAE model 16 | 17 | Args: 18 | vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy". 19 | vae_precision (str, optional): the precision to load vae. Defaults to None. 20 | sample_size (tuple, optional): the tiling size. Defaults to None. 21 | vae_path (str, optional): the path to vae. Defaults to None. 22 | logger (_type_, optional): logger. Defaults to None. 23 | device (_type_, optional): device to load vae. Defaults to None. 24 | """ 25 | if vae_path is None: 26 | vae_path = VAE_PATH[vae_type] 27 | 28 | if logger is not None: 29 | logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}") 30 | config = AutoencoderKLCausal3D.load_config(vae_path) 31 | if sample_size: 32 | vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size) 33 | else: 34 | vae = AutoencoderKLCausal3D.from_config(config) 35 | 36 | vae_ckpt = Path(vae_path) / "pytorch_model.pt" 37 | assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}" 38 | 39 | ckpt = torch.load(vae_ckpt, map_location=vae.device) 40 | if "state_dict" in ckpt: 41 | ckpt = ckpt["state_dict"] 42 | if any(k.startswith("vae.") for k in ckpt.keys()): 43 | ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")} 44 | vae.load_state_dict(ckpt) 45 | 46 | spatial_compression_ratio = vae.config.spatial_compression_ratio 47 | time_compression_ratio = vae.config.time_compression_ratio 48 | 49 | if vae_precision is not None: 50 | vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision]) 51 | 52 | vae.requires_grad_(False) 53 | 54 | if logger is not None: 55 | logger.info(f"VAE to dtype: {vae.dtype}") 56 | 57 | if device is not None: 58 | vae = vae.to(device) 59 | 60 | vae.eval() 61 | 62 | return vae, vae_path, spatial_compression_ratio, time_compression_ratio 63 | -------------------------------------------------------------------------------- /hunyuan/run-single-sample_video-fp8.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | export MODEL_BASE="/mnt/localssd/hunyuan-video/ckpts/" 4 | #dit_weight="/mnt/localssd/hunyuan-video/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt" 5 | seed=42 6 | 7 | dit_weight="/mnt/localssd/hunyuan-video/fp8/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt" 8 | 9 | python3 sample_video.py \ 10 | --model-base $MODEL_BASE \ 11 | --dit-weight $dit_weight \ 12 | --seed $seed \ 13 | --video-size 768 1280 \ 14 | --video-length 129 \ 15 | --infer-steps 50 \ 16 | --prompt "A cat walks on the grass, realistic style." 
\ 17 | --flow-reverse \ 18 | --use-cpu-offload \ 19 | --use-fp8 \ 20 | --save-path ./z-results 21 | -------------------------------------------------------------------------------- /hunyuan/run-single-sample_video.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | export MODEL_BASE="/mnt/localssd/hunyuan-video/ckpts/" 4 | dit_weight="/mnt/localssd/hunyuan-video/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt" 5 | seed=42 6 | 7 | python3 sample_video.py \ 8 | --model-base $MODEL_BASE \ 9 | --dit-weight $dit_weight \ 10 | --seed $seed \ 11 | --video-size 768 1280 \ 12 | --video-length 129 \ 13 | --infer-steps 50 \ 14 | --prompt "A cat walks on the grass, realistic style." \ 15 | --flow-reverse \ 16 | --use-cpu-offload \ 17 | --save-path ./z-results 18 | -------------------------------------------------------------------------------- /hunyuan/sample_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from pathlib import Path 4 | from loguru import logger 5 | from datetime import datetime 6 | 7 | from hyvideo.utils.file_utils import save_videos_grid 8 | from hyvideo.config import parse_args 9 | from hyvideo.inference import HunyuanVideoSampler 10 | 11 | 12 | def main(): 13 | args = parse_args() 14 | print(args) 15 | models_root_path = Path(args.model_base) 16 | if not models_root_path.exists(): 17 | raise ValueError(f"`models_root` not exists: {models_root_path}") 18 | 19 | # Create save folder to save the samples 20 | save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}' 21 | if not os.path.exists(save_path): 22 | os.makedirs(save_path, exist_ok=True) 23 | 24 | # Load models 25 | hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args) 26 | 27 | # Get the updated args 28 | args = hunyuan_video_sampler.args 29 | 30 | # Start sampling 31 | # TODO: batch inference check 32 | outputs = hunyuan_video_sampler.predict( 33 | prompt=args.prompt, 34 | height=args.video_size[0], 35 | width=args.video_size[1], 36 | video_length=args.video_length, 37 | seed=args.seed, 38 | negative_prompt=args.neg_prompt, 39 | infer_steps=args.infer_steps, 40 | guidance_scale=args.cfg_scale, 41 | num_videos_per_prompt=1,#args.num_videos, 42 | flow_shift=args.flow_shift, 43 | batch_size=args.batch_size, 44 | embedded_guidance_scale=args.embedded_cfg_scale 45 | ) 46 | samples = outputs['samples'] 47 | 48 | # Save samples 49 | if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0: 50 | for i, sample in enumerate(samples): 51 | sample = samples[i].unsqueeze(0) 52 | time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S") 53 | cur_save_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4" 54 | save_videos_grid(sample, cur_save_path, fps=24) 55 | logger.info(f'Sample save to: {cur_save_path}') 56 | 57 | if __name__ == "__main__": 58 | main() 59 | -------------------------------------------------------------------------------- /hunyuan_custom/assets/images/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/method.png -------------------------------------------------------------------------------- 
/hunyuan_custom/assets/images/poodle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/poodle.png -------------------------------------------------------------------------------- /hunyuan_custom/assets/images/seg_boy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/seg_boy.png -------------------------------------------------------------------------------- /hunyuan_custom/assets/images/seg_man_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/seg_man_01.png -------------------------------------------------------------------------------- /hunyuan_custom/assets/images/seg_man_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/seg_man_02.png -------------------------------------------------------------------------------- /hunyuan_custom/assets/images/seg_man_03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/seg_man_03.png -------------------------------------------------------------------------------- /hunyuan_custom/assets/images/seg_poodle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/seg_poodle.png -------------------------------------------------------------------------------- /hunyuan_custom/assets/images/seg_woman_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/seg_woman_01.png -------------------------------------------------------------------------------- /hunyuan_custom/assets/images/seg_woman_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/seg_woman_02.png -------------------------------------------------------------------------------- /hunyuan_custom/assets/images/seg_woman_03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/images/seg_woman_03.png -------------------------------------------------------------------------------- /hunyuan_custom/assets/material/application.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/material/application.png 
-------------------------------------------------------------------------------- /hunyuan_custom/assets/material/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/material/logo.png -------------------------------------------------------------------------------- /hunyuan_custom/assets/material/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/material/method.png -------------------------------------------------------------------------------- /hunyuan_custom/assets/material/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/material/teaser.png -------------------------------------------------------------------------------- /hunyuan_custom/assets/meta_files.list: -------------------------------------------------------------------------------- 1 | assets/meta_files/poodle.json -------------------------------------------------------------------------------- /hunyuan_custom/assets/meta_files/poodle.json: -------------------------------------------------------------------------------- 1 | { 2 | "item_image_path": "assets/images/poodle.png", 3 | "item_prompt": "dog", 4 | "seed": 1124, 5 | "prompt": "A dog is chasing a cat in the park.", 6 | "negative_prompt": "", 7 | "chinese": "一只狗在公园里追一只猫。", 8 | "seg_item_image_path": "assets/images/seg_poodle.png" 9 | } 10 | -------------------------------------------------------------------------------- /hunyuan_custom/assets/videos/seg_man_01.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/videos/seg_man_01.mp4 -------------------------------------------------------------------------------- /hunyuan_custom/assets/videos/seg_man_02.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/videos/seg_man_02.mp4 -------------------------------------------------------------------------------- /hunyuan_custom/assets/videos/seg_woman_01.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/videos/seg_woman_01.mp4 -------------------------------------------------------------------------------- /hunyuan_custom/assets/videos/seg_woman_03.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/assets/videos/seg_woman_03.mp4 -------------------------------------------------------------------------------- /hunyuan_custom/hymm_gradio/gradio_ref2v.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import json 5 | import datetime 6 | import requests 7 | import gradio as gr 8 | 
from tool_for_end2end import * 9 | 10 | os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" 11 | DATADIR = './temp' 12 | _HEADER_ = ''' 13 |
14 |

Tencet HunyuanvideoCustom Demo

15 |
16 | 17 | ''' 18 | # flask url 19 | URL = "http://127.0.0.1:8080/predict2" 20 | 21 | def post_and_get(width, height, num_steps, num_frames, guidance, flow_shift, seed, prompt, id_image, neg_prompt, template_prompt, template_neg_prompt, name): 22 | now = datetime.datetime.now().isoformat() 23 | imgdir = os.path.join(DATADIR, 'reference') 24 | videodir = os.path.join(DATADIR, 'video') 25 | imgfile = os.path.join(imgdir, now + '.png') 26 | output_video_path = os.path.join(videodir, now + '.mp4') 27 | 28 | os.makedirs(imgdir, exist_ok=True) 29 | os.makedirs(videodir, exist_ok=True) 30 | cv2.imwrite(imgfile, id_image[:,:,::-1]) 31 | 32 | proxies = { 33 | "http": None, 34 | "https": None, 35 | } 36 | 37 | files = { 38 | "trace_id": "abcd", 39 | "image_path": imgfile, 40 | "prompt": prompt, 41 | "negative_prompt": neg_prompt, 42 | "template_prompt": template_prompt, 43 | "template_neg_prompt": template_neg_prompt, 44 | "height": height, 45 | "width": width, 46 | "frames": num_frames, 47 | "cfg": guidance, 48 | "steps": num_steps, 49 | "seed": int(seed), 50 | "name": name, 51 | "shift": flow_shift, 52 | "save_fps": 25, 53 | } 54 | r = requests.get(URL, data = json.dumps(files), proxies=proxies) 55 | ret_dict = json.loads(r.text) 56 | video_buffer = ret_dict['content'][0]['buffer'] 57 | save_video_base64_to_local(video_path=None, base64_buffer=video_buffer, 58 | output_video_path=output_video_path) 59 | print('='*50) 60 | return output_video_path 61 | 62 | def create_demo(): 63 | 64 | with gr.Blocks() as demo: 65 | gr.Markdown(_HEADER_) 66 | with gr.Tab('单主体一致性'): 67 | with gr.Row(): 68 | with gr.Column(scale=1): 69 | with gr.Group(): 70 | prompt = gr.Textbox(label="Prompt", value="a man is riding a bicycle on the street.") 71 | neg_prompt = gr.Textbox(label="Negative Prompt", value="") 72 | id_image = gr.Image(label="Input reference image", height=480) 73 | 74 | with gr.Column(scale=2): 75 | with gr.Group(): 76 | output_image = gr.Video(label="Generated Video") 77 | 78 | with gr.Row(): 79 | with gr.Column(scale=2): 80 | with gr.Accordion("Options for generate video", open=False): 81 | with gr.Row(): 82 | width = gr.Slider(256, 1536, 1280, step=16, label="Width") 83 | height = gr.Slider(256, 1536, 720, step=16, label="Height") 84 | with gr.Row(): 85 | num_steps = gr.Slider(1, 100, 30, step=5, label="Number of steps") 86 | flow_shift = gr.Slider(1.0, 15.0, 13, step=1, label="Flow Shift") 87 | with gr.Row(): 88 | num_frames = gr.Slider(1, 129, 129, step=4, label="Number of frames") 89 | guidance = gr.Slider(1.0, 10.0, 7.5, step=0.5, label="Guidance") 90 | seed = gr.Textbox(1024, label="Seed (-1 for random)") 91 | with gr.Row(): 92 | template_prompt = gr.Textbox(label="Template Prompt", value="Realistic, High-quality. ") 93 | template_neg_prompt = gr.Textbox(label="Template Negative Prompt", value="Aerial view, aerial view, " \ 94 | "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, " \ 95 | "distortion, blurring, text, subtitles, static, picture, black border. 
") 96 | name = gr.Textbox(label="Object Name", value="object ") 97 | with gr.Column(scale=1): 98 | generate_btn = gr.Button("Generate") 99 | 100 | generate_btn.click(fn=post_and_get, 101 | inputs=[width, height, num_steps, num_frames, guidance, flow_shift, seed, prompt, id_image, neg_prompt, template_prompt, template_neg_prompt, name], 102 | outputs=[output_image], 103 | ) 104 | 105 | quick_prompts = [[x] for x in glob.glob('./assets/images/*.png')] 106 | example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Other object', samples_per_page=1000, components=[id_image]) 107 | example_quick_prompts.click(lambda x: x[0], inputs=example_quick_prompts, outputs=id_image, show_progress=False, queue=False) 108 | with gr.Row(), gr.Column(): 109 | gr.Markdown("## Examples") 110 | example_inps = [ 111 | [ 112 | 'A woman is drinking coffee at a café.', 113 | './assets/images/seg_woman_01.png', 114 | 1280, 720, 30, 129, 7.5, 13, 1024, 115 | "assets/videos/seg_woman_01.mp4" 116 | ], 117 | [ 118 | 'In a cubicle of an office building, a woman focuses intently on the computer screen, typing rapidly on the keyboard, surrounded by piles of documents.', 119 | './assets/images/seg_woman_03.png', 120 | 1280, 720, 30, 129, 7.5, 13, 1025, 121 | "./assets/videos/seg_woman_03.mp4" 122 | ], 123 | [ 124 | 'A man walks across an ancient stone bridge holding an umbrella, raindrops tapping against it.', 125 | './assets/images/seg_man_01.png', 126 | 1280, 720, 30, 129, 7.5, 13, 1025, 127 | "./assets/videos/seg_man_01.mp4" 128 | ], 129 | [ 130 | 'During a train journey, a man admires the changing scenery through the window.', 131 | './assets/images/seg_man_02.png', 132 | 1280, 720, 30, 129, 7.5, 13, 1026, 133 | "./assets/videos/seg_man_02.mp4" 134 | ] 135 | ] 136 | gr.Examples(examples=example_inps, inputs=[prompt, id_image, width, height, num_steps, num_frames, guidance, flow_shift, seed, output_image],) 137 | return demo 138 | 139 | if __name__ == "__main__": 140 | allowed_paths = ['/'] 141 | demo = create_demo() 142 | demo.launch(server_name='0.0.0.0', server_port=80, share=True, allowed_paths=allowed_paths) 143 | -------------------------------------------------------------------------------- /hunyuan_custom/hymm_gradio/tool_for_end2end.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import uuid 4 | import base64 5 | import imageio 6 | import torch 7 | import torchvision 8 | from PIL import Image 9 | import numpy as np 10 | from copy import deepcopy 11 | from einops import rearrange 12 | 13 | TEMP_DIR = "./temp" 14 | if not os.path.exists(TEMP_DIR): 15 | os.makedirs(TEMP_DIR, exist_ok=True) 16 | 17 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, quality=8): 18 | videos = rearrange(videos, "b c t h w -> t b c h w") 19 | outputs = [] 20 | for x in videos: 21 | x = torchvision.utils.make_grid(x, nrow=n_rows) 22 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 23 | if rescale: 24 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 25 | x = torch.clamp(x,0,1) 26 | x = (x * 255).numpy().astype(np.uint8) 27 | outputs.append(x) 28 | 29 | os.makedirs(os.path.dirname(path), exist_ok=True) 30 | imageio.mimsave(path, outputs, fps=fps, quality=quality) 31 | 32 | def encode_image_to_base64(image_path): 33 | try: 34 | with open(image_path, 'rb') as image_file: 35 | image_data = image_file.read() 36 | encoded_data = base64.b64encode(image_data).decode('utf-8') 37 | print(f"Image file '{image_path}' has been successfully encoded 
to Base64.") 38 | return encoded_data 39 | 40 | except Exception as e: 41 | print(f"Error encoding image: {e}") 42 | return None 43 | 44 | def encode_video_to_base64(video_path): 45 | try: 46 | with open(video_path, 'rb') as video_file: 47 | video_data = video_file.read() 48 | encoded_data = base64.b64encode(video_data).decode('utf-8') 49 | print(f"Video file '{video_path}' has been successfully encoded to Base64.") 50 | return encoded_data 51 | 52 | except Exception as e: 53 | print(f"Error encoding video: {e}") 54 | return None 55 | 56 | def encode_wav_to_base64(wav_path): 57 | try: 58 | with open(wav_path, 'rb') as audio_file: 59 | audio_data = audio_file.read() 60 | encoded_data = base64.b64encode(audio_data).decode('utf-8') 61 | print(f"Audio file '{wav_path}' has been successfully encoded to Base64.") 62 | return encoded_data 63 | 64 | except Exception as e: 65 | print(f"Error encoding audio: {e}") 66 | return None 67 | 68 | def encode_pkl_to_base64(pkl_path): 69 | try: 70 | with open(pkl_path, 'rb') as pkl_file: 71 | pkl_data = pkl_file.read() 72 | 73 | encoded_data = base64.b64encode(pkl_data).decode('utf-8') 74 | 75 | print(f"Pickle file '{pkl_path}' has been successfully encoded to Base64.") 76 | return encoded_data 77 | 78 | except Exception as e: 79 | print(f"Error encoding pickle: {e}") 80 | return None 81 | 82 | def decode_base64_to_image(base64_buffer_str): 83 | try: 84 | image_data = base64.b64decode(base64_buffer_str) 85 | image = Image.open(io.BytesIO(image_data)) 86 | image_array = np.array(image) 87 | print(f"Image Base64 string has beed succesfully decoded to image.") 88 | return image_array 89 | except Exception as e: 90 | print(f"Error encdecodingoding image: {e}") 91 | return None 92 | 93 | def decode_base64_to_video(base64_buffer_str): 94 | try: 95 | video_data = base64.b64decode(base64_buffer_str) 96 | video_bytes = io.BytesIO(video_data) 97 | video_bytes.seek(0) 98 | video_reader = imageio.get_reader(video_bytes, 'ffmpeg') 99 | video_frames = [frame for frame in video_reader] 100 | return video_frames 101 | except Exception as e: 102 | print(f"Error decoding video: {e}") 103 | return None 104 | 105 | def save_image_base64_to_local(image_path=None, base64_buffer=None): 106 | if image_path is not None and base64_buffer is None: 107 | image_buffer_base64 = encode_image_to_base64(image_path) 108 | elif image_path is None and base64_buffer is not None: 109 | image_buffer_base64 = deepcopy(base64_buffer) 110 | else: 111 | print("Please pass either 'image_path' or 'base64_buffer'") 112 | return None 113 | 114 | if image_buffer_base64 is not None: 115 | image_data = base64.b64decode(image_buffer_base64) 116 | uuid_string = str(uuid.uuid4()) 117 | temp_image_path = f'{TEMP_DIR}/{uuid_string}.png' 118 | with open(temp_image_path, 'wb') as image_file: 119 | image_file.write(image_data) 120 | return temp_image_path 121 | else: 122 | return None 123 | 124 | def save_video_base64_to_local(video_path=None, base64_buffer=None, output_video_path=None): 125 | if video_path is not None and base64_buffer is None: 126 | video_buffer_base64 = encode_video_to_base64(video_path) 127 | elif video_path is None and base64_buffer is not None: 128 | video_buffer_base64 = deepcopy(base64_buffer) 129 | else: 130 | print("Please pass either 'video_path' or 'base64_buffer'") 131 | return None 132 | 133 | if video_buffer_base64 is not None: 134 | video_data = base64.b64decode(video_buffer_base64) 135 | if output_video_path is None: 136 | uuid_string = str(uuid.uuid4()) 137 | temp_video_path = 
f'{TEMP_DIR}/{uuid_string}.mp4' 138 | else: 139 | temp_video_path = output_video_path 140 | with open(temp_video_path, 'wb') as video_file: 141 | video_file.write(video_data) 142 | return temp_video_path 143 | else: 144 | return None 145 | 146 | def save_audio_base64_to_local(audio_path=None, base64_buffer=None): 147 | if audio_path is not None and base64_buffer is None: 148 | audio_buffer_base64 = encode_wav_to_base64(audio_path) 149 | elif audio_path is None and base64_buffer is not None: 150 | audio_buffer_base64 = deepcopy(base64_buffer) 151 | else: 152 | print("Please pass either 'audio_path' or 'base64_buffer'") 153 | return None 154 | 155 | if audio_buffer_base64 is not None: 156 | audio_data = base64.b64decode(audio_buffer_base64) 157 | uuid_string = str(uuid.uuid4()) 158 | temp_audio_path = f'{TEMP_DIR}/{uuid_string}.wav' 159 | with open(temp_audio_path, 'wb') as audio_file: 160 | audio_file.write(audio_data) 161 | return temp_audio_path 162 | else: 163 | return None 164 | 165 | def save_pkl_base64_to_local(pkl_path=None, base64_buffer=None): 166 | if pkl_path is not None and base64_buffer is None: 167 | pkl_buffer_base64 = encode_pkl_to_base64(pkl_path) 168 | elif pkl_path is None and base64_buffer is not None: 169 | pkl_buffer_base64 = deepcopy(base64_buffer) 170 | else: 171 | print("Please pass either 'pkl_path' or 'base64_buffer'") 172 | return None 173 | 174 | if pkl_buffer_base64 is not None: 175 | pkl_data = base64.b64decode(pkl_buffer_base64) 176 | uuid_string = str(uuid.uuid4()) 177 | temp_pkl_path = f'{TEMP_DIR}/{uuid_string}.pkl' 178 | with open(temp_pkl_path, 'wb') as pkl_file: 179 | pkl_file.write(pkl_data) 180 | return temp_pkl_path 181 | else: 182 | return None 183 | 184 | def remove_temp_fles(input_dict): 185 | for key, val in input_dict.items(): 186 | if "_path" in key and val is not None and os.path.exists(val): 187 | os.remove(val) 188 | print(f"Remove temporary {key} from {val}") 189 | 190 | def process_output_dict(output_dict): 191 | if output_dict["rank"] == 0: 192 | uuid_string = str(uuid.uuid4()) 193 | temp_video_path = f'{TEMP_DIR}/{uuid_string}.mp4' 194 | save_videos_grid(output_dict["video"], temp_video_path, fps=25) 195 | save_path = temp_video_path 196 | 197 | video_base64_buffer = encode_video_to_base64(save_path) 198 | 199 | encoded_output_dict = { 200 | "errCode": output_dict["err_code"], 201 | "content": [ 202 | { 203 | "buffer": video_base64_buffer 204 | }, 205 | ], 206 | "info":output_dict["err_msg"], 207 | "trace_id": output_dict["trace_id"], 208 | } 209 | 210 | else: 211 | uuid_string = str(uuid.uuid4()) 212 | temp_video_path = f'{TEMP_DIR}/{uuid_string}.mp4' 213 | 214 | try: 215 | video_base64_buffer = encode_video_to_base64(temp_video_path) 216 | except: 217 | video_base64_buffer = None 218 | 219 | encoded_output_dict = { 220 | "errCode": output_dict["err_code"], 221 | "content": [ 222 | { 223 | "buffer": video_base64_buffer 224 | }, 225 | ], 226 | "info":output_dict["err_msg"], 227 | "trace_id": output_dict["trace_id"], 228 | } 229 | 230 | return encoded_output_dict 231 | -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/hunyuan_custom/hymm_sp/__init__.py -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/config.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | from hymm_sp.constants import * 3 | import re 4 | import collections.abc 5 | 6 | def as_tuple(x): 7 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 8 | return tuple(x) 9 | if x is None or isinstance(x, (int, float, str)): 10 | return (x,) 11 | else: 12 | raise ValueError(f"Unknown type {type(x)}") 13 | 14 | def parse_args(namespace=None): 15 | parser = argparse.ArgumentParser(description="Hunyuan Multimodal training/inference script") 16 | parser = add_extra_args(parser) 17 | args = parser.parse_args(namespace=namespace) 18 | args = sanity_check_args(args) 19 | return args 20 | 21 | def add_extra_args(parser: argparse.ArgumentParser): 22 | parser = add_network_args(parser) 23 | parser = add_extra_models_args(parser) 24 | parser = add_denoise_schedule_args(parser) 25 | parser = add_evaluation_args(parser) 26 | return parser 27 | 28 | def add_network_args(parser: argparse.ArgumentParser): 29 | group = parser.add_argument_group(title="Network") 30 | group.add_argument("--model", type=str, default="HYVideo-T/2", 31 | help="Model architecture to use. It it also used to determine the experiment directory.") 32 | group.add_argument("--latent-channels", type=str, default=None, 33 | help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, " 34 | "it still needs to match the latent channels of the VAE model.") 35 | group.add_argument("--rope-theta", type=int, default=256, help="Theta used in RoPE.") 36 | return parser 37 | 38 | def add_extra_models_args(parser: argparse.ArgumentParser): 39 | group = parser.add_argument_group(title="Extra Models (VAE, Text Encoder, Tokenizer)") 40 | 41 | # VAE 42 | group.add_argument("--vae", type=str, default="884-16c-hy0801", help="Name of the VAE model.") 43 | group.add_argument("--vae-precision", type=str, default="fp16", 44 | help="Precision mode for the VAE model.") 45 | group.add_argument("--vae-tiling", action="store_true", default=True, help="Enable tiling for the VAE model.") 46 | group.add_argument("--text-encoder", type=str, default="llava-llama-3-8b", choices=list(TEXT_ENCODER_PATH), 47 | help="Name of the text encoder model.") 48 | group.add_argument("--text-encoder-precision", type=str, default="fp16", choices=PRECISIONS, 49 | help="Precision mode for the text encoder model.") 50 | group.add_argument("--text-states-dim", type=int, default=4096, help="Dimension of the text encoder hidden states.") 51 | group.add_argument("--text-len", type=int, default=256, help="Maximum length of the text input.") 52 | group.add_argument("--tokenizer", type=str, default="llava-llama-3-8b", choices=list(TOKENIZER_PATH), 53 | help="Name of the tokenizer model.") 54 | group.add_argument("--text-encoder-infer-mode", type=str, default="encoder", choices=["encoder", "decoder"], 55 | help="Inference mode for the text encoder model. It should match the text encoder type. 
T5 and " 56 | "CLIP can only work in 'encoder' mode, while Llava/GLM can work in both modes.") 57 | group.add_argument("--prompt-template-video", type=str, default='li-dit-encode-video', choices=PROMPT_TEMPLATE, 58 | help="Video prompt template for the decoder-only text encoder model.") 59 | group.add_argument("--hidden-state-skip-layer", type=int, default=2, 60 | help="Skip layer for hidden states.") 61 | group.add_argument("--apply-final-norm", action="store_true", 62 | help="Apply final normalization to the used text encoder hidden states.") 63 | 64 | # - CLIP 65 | group.add_argument("--text-encoder-2", type=str, default='clipL', choices=list(TEXT_ENCODER_PATH), 66 | help="Name of the second text encoder model.") 67 | group.add_argument("--text-encoder-precision-2", type=str, default="fp16", choices=PRECISIONS, 68 | help="Precision mode for the second text encoder model.") 69 | group.add_argument("--text-states-dim-2", type=int, default=768, 70 | help="Dimension of the second text encoder hidden states.") 71 | group.add_argument("--tokenizer-2", type=str, default='clipL', choices=list(TOKENIZER_PATH), 72 | help="Name of the second tokenizer model.") 73 | group.add_argument("--text-len-2", type=int, default=77, help="Maximum length of the second text input.") 74 | group.set_defaults(use_attention_mask=True) 75 | group.add_argument("--text-projection", type=str, default="single_refiner", choices=TEXT_PROJECTION, 76 | help="A projection layer for bridging the text encoder hidden states and the diffusion model " 77 | "conditions.") 78 | return parser 79 | 80 | 81 | def add_denoise_schedule_args(parser: argparse.ArgumentParser): 82 | group = parser.add_argument_group(title="Denoise schedule") 83 | group.add_argument("--flow-shift-eval-video", type=float, default=None, help="Shift factor for flow matching schedulers when using video data.") 84 | group.add_argument("--flow-reverse", action="store_true", default=True, help="If reverse, learning/sampling from t=1 -> t=0.") 85 | group.add_argument("--flow-solver", type=str, default="euler", help="Solver for flow matching.") 86 | group.add_argument("--use-linear-quadratic-schedule", action="store_true", help="Use linear quadratic schedule for flow matching." 87 | "Follow MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)") 88 | group.add_argument("--linear-schedule-end", type=int, default=25, help="End step for linear quadratic schedule for flow matching.") 89 | return parser 90 | 91 | def add_evaluation_args(parser: argparse.ArgumentParser): 92 | group = parser.add_argument_group(title="Validation Loss Evaluation") 93 | parser.add_argument("--precision", type=str, default="bf16", choices=PRECISIONS, 94 | help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.") 95 | parser.add_argument("--reproduce", action="store_true", 96 | help="Enable reproducibility by setting random seeds and deterministic algorithms.") 97 | parser.add_argument("--ckpt", type=str, help="Path to the checkpoint to evaluate.") 98 | parser.add_argument("--load-key", type=str, default="module", choices=["module", "ema"], 99 | help="Key to load the model states. 
'module' for the main model, 'ema' for the EMA model.") 100 | parser.add_argument("--cpu-offload", action="store_true", help="Use CPU offload for the model load.") 101 | group.add_argument( "--use-fp8", action="store_true", help="Enable use fp8 for inference acceleration.") 102 | group.add_argument("--video-size", type=int, nargs='+', default=512, 103 | help="Video size for training. If a single value is provided, it will be used for both width " 104 | "and height. If two values are provided, they will be used for width and height " 105 | "respectively.") 106 | group.add_argument("--sample-n-frames", type=int, default=1, 107 | help="How many frames to sample from a video. if using 3d vae, the number should be 4n+1") 108 | group.add_argument("--infer-steps", type=int, default=100, help="Number of denoising steps for inference.") 109 | group.add_argument("--val-disable-autocast", action="store_true", 110 | help="Disable autocast for denoising loop and vae decoding in pipeline sampling.") 111 | group.add_argument("--num-images", type=int, default=1, help="Number of images to generate for each prompt.") 112 | group.add_argument("--seed", type=int, default=1024, help="Seed for evaluation.") 113 | group.add_argument("--save-path-suffix", type=str, default="", help="Suffix for the directory of saved samples.") 114 | group.add_argument("--pos-prompt", type=str, default='', help="Prompt for sampling during evaluation.") 115 | group.add_argument("--neg-prompt", type=str, default='', help="Negative prompt for sampling during evaluation.") 116 | group.add_argument("--add-pos-prompt", type=str, default='', help="Addition prompt for sampling during evaluation.") 117 | group.add_argument("--add-neg-prompt", type=str, default='', help="Addition negative prompt for sampling during evaluation.") 118 | group.add_argument("--pad-face-size", type=float, default=0.7, help="Pad bbox for face align.") 119 | group.add_argument("--image-path", type=str, default="", help="") 120 | group.add_argument("--save-path", type=str, default=None, help="Path to save the generated samples.") 121 | group.add_argument("--input", type=str, default=None, help="test data.") 122 | group.add_argument("--item-name", type=str, default=None, help="") 123 | group.add_argument("--cfg-scale", type=float, default=7.5, help="Classifier free guidance scale.") 124 | group.add_argument("--ip-cfg-scale", type=float, default=0, help="Classifier free guidance scale.") 125 | group.add_argument("--use-deepcache", type=int, default=1) 126 | return parser 127 | 128 | def sanity_check_args(args): 129 | # VAE channels 130 | vae_pattern = r"\d{2,3}-\d{1,2}c-\w+" 131 | if not re.match(vae_pattern, args.vae): 132 | raise ValueError( 133 | f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'." 134 | ) 135 | vae_channels = int(args.vae.split("-")[1][:-1]) 136 | if args.latent_channels is None: 137 | args.latent_channels = vae_channels 138 | if vae_channels != args.latent_channels: 139 | raise ValueError( 140 | f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels})." 
141 | ) 142 | return args 143 | -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | __all__ = [ 5 | "PROMPT_TEMPLATE", "MODEL_BASE", "PRECISION_TO_TYPE", 6 | "PRECISIONS", "VAE_PATH", "TEXT_ENCODER_PATH", "TOKENIZER_PATH", 7 | "TEXT_PROJECTION", 8 | ] 9 | 10 | # =================== Constant Values ===================== 11 | 12 | PRECISION_TO_TYPE = { 13 | 'fp32': torch.float32, 14 | 'fp16': torch.float16, 15 | 'bf16': torch.bfloat16, 16 | } 17 | 18 | PROMPT_TEMPLATE_ENCODE_VIDEO = ( 19 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " 20 | "1. The main content and theme of the video." 21 | "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." 22 | "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." 23 | "4. background environment, light, style and atmosphere." 24 | "5. camera angles, movements, and transitions used in the video:<|eot_id|>" 25 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" 26 | ) 27 | 28 | PROMPT_TEMPLATE = { 29 | "li-dit-encode-video": {"template": PROMPT_TEMPLATE_ENCODE_VIDEO, "crop_start": 95}, 30 | } 31 | 32 | # ======================= Model ====================== 33 | PRECISIONS = {"fp32", "fp16", "bf16"} 34 | 35 | # =================== Model Path ===================== 36 | MODEL_BASE = os.getenv("MODEL_BASE") 37 | 38 | # 3D VAE 39 | VAE_PATH = { 40 | "884-16c-hy0801": f"{MODEL_BASE}/vae_3d/hyvae_v1_0801", 41 | } 42 | 43 | # Text Encoder 44 | TEXT_ENCODER_PATH = { 45 | "clipL": f"{MODEL_BASE}/openai_clip-vit-large-patch14", 46 | "llava-llama-3-8b": f"{MODEL_BASE}/llava-llama-3-8b-v1_1", 47 | } 48 | 49 | # Tokenizer 50 | TOKENIZER_PATH = { 51 | "clipL": f"{MODEL_BASE}/openai_clip-vit-large-patch14", 52 | "llava-llama-3-8b": f"{MODEL_BASE}/llava-llama-3-8b-v1_1", 53 | } 54 | 55 | TEXT_PROJECTION = { 56 | "linear", # Default, an nn.Linear() layer 57 | "single_refiner", # Single TokenRefiner. 
Refer to LI-DiT 58 | } -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/data_kits/data_tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import imageio 6 | import torchvision 7 | from einops import rearrange 8 | 9 | 10 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, quality=8): 11 | videos = rearrange(videos, "b c t h w -> t b c h w") 12 | outputs = [] 13 | for x in videos: 14 | x = torchvision.utils.make_grid(x, nrow=n_rows) 15 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 16 | if rescale: 17 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 18 | x = torch.clamp(x,0,1) 19 | x = (x * 255).numpy().astype(np.uint8) 20 | outputs.append(x) 21 | 22 | os.makedirs(os.path.dirname(path), exist_ok=True) 23 | imageio.mimsave(path, outputs, fps=fps, quality=quality) 24 | 25 | def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1): 26 | crop_h, crop_w = crop_img.shape[:2] 27 | target_w, target_h = size 28 | scale_h, scale_w = target_h / crop_h, target_w / crop_w 29 | if scale_w > scale_h: 30 | resize_h = int(target_h*resize_ratio) 31 | resize_w = int(crop_w / crop_h * resize_h) 32 | else: 33 | resize_w = int(target_w*resize_ratio) 34 | resize_h = int(crop_h / crop_w * resize_w) 35 | crop_img = cv2.resize(crop_img, (resize_w, resize_h)) 36 | pad_left = (target_w - resize_w) // 2 37 | pad_top = (target_h - resize_h) // 2 38 | pad_right = target_w - resize_w - pad_left 39 | pad_bottom = target_h - resize_h - pad_top 40 | crop_img = cv2.copyMakeBorder(crop_img, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=color) 41 | return crop_img -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/data_kits/video_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import json 5 | import numpy as np 6 | from PIL import Image 7 | import torchvision.transforms as transforms 8 | from hymm_sp.data_kits.data_tools import * 9 | 10 | 11 | class DataPreprocess(object): 12 | def __init__(self): 13 | self.llava_size = (336, 336) 14 | self.llava_transform = transforms.Compose( 15 | [ 16 | transforms.Resize(self.llava_size, interpolation=transforms.InterpolationMode.BILINEAR), 17 | transforms.ToTensor(), 18 | transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)), 19 | ] 20 | ) 21 | 22 | def get_batch(self, image_path, size): 23 | try: 24 | image = cv2.imread(image_path) 25 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 26 | except: 27 | image = Image.open(image_path).convert('RGB') 28 | llava_item_image = pad_image(image.copy(), self.llava_size) 29 | uncond_llava_item_image = np.ones_like(llava_item_image) * 255 30 | cat_item_image = pad_image(image.copy(), size) 31 | 32 | llava_item_tensor = self.llava_transform(Image.fromarray(llava_item_image.astype(np.uint8))) 33 | uncond_llava_item_tensor = self.llava_transform(Image.fromarray(uncond_llava_item_image)) 34 | cat_item_tensor = torch.from_numpy(cat_item_image.copy()).permute((2, 0, 1)) / 255.0 35 | batch = { 36 | "pixel_value_llava": llava_item_tensor.unsqueeze(0), 37 | "uncond_pixel_value_llava": uncond_llava_item_tensor.unsqueeze(0), 38 | 'pixel_value_ref': cat_item_tensor.unsqueeze(0), 39 | } 40 | return batch 41 | 42 | 43 | class JsonDataset(object): 44 | 
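    # Dataset wrapper for inference inputs: `args.input` is either a single image/.json path
    # or a `.list` file of such paths; each item returns the padded reference-image tensors,
    # prompts and seed consumed by the custom-video pipeline. A minimal usage sketch (the
    # loader settings here are illustrative, not the repo's defaults):
    #     dataset = JsonDataset(args)                  # args from hymm_sp.config.parse_args()
    #     loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    #     batch = next(iter(loader))                   # keys: "pixel_value_llava", "pixel_value_ref", "prompt", ...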
def __init__(self, args): 45 | self.args = args 46 | self.data_list = args.input 47 | self.pad_color = (255, 255, 255) 48 | self.llava_size = (336, 336) 49 | self.ref_size = (args.video_size[1], args.video_size[0]) 50 | if self.data_list.endswith('.list'): 51 | self.data_paths = [line.strip() for line in open(self.data_list, 'r')] if self.data_list is not None else [] 52 | else: 53 | self.data_paths = [self.data_list] 54 | self.llava_transform = transforms.Compose( 55 | [ 56 | transforms.Resize(self.llava_size, interpolation=transforms.InterpolationMode.BILINEAR), 57 | transforms.ToTensor(), 58 | transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)), 59 | ] 60 | ) 61 | 62 | def __len__(self): 63 | return len(self.data_paths) 64 | 65 | def read_image(self, image_path): 66 | if isinstance(image_path, dict): 67 | image_path = image_path['seg_item_image_path'] 68 | 69 | try: 70 | face_image_masked = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) 71 | except: 72 | face_image_masked = Image.open(image_path).convert('RGB') 73 | 74 | cat_face_image = pad_image(face_image_masked.copy(), self.ref_size) 75 | llava_face_image = pad_image(face_image_masked.copy(), self.llava_size) 76 | return llava_face_image, cat_face_image 77 | 78 | def __getitem__(self, idx): 79 | data_path = self.data_paths[idx] 80 | data_name = os.path.basename(os.path.splitext(data_path)[0]) 81 | if data_path.endswith('.json'): 82 | data = json.load(open(data_path, 'r')) 83 | llava_item_image, cat_item_image = self.read_image(data) 84 | item_prompt = data['item_prompt'] 85 | seed = data['seed'] 86 | prompt = data['prompt'] 87 | if 'negative_prompt' in data: 88 | negative_prompt = data['negative_prompt'] 89 | else: 90 | negative_prompt = '' 91 | else: 92 | llava_item_image, cat_item_image = self.read_image(data_path) 93 | item_prompt = 'object' 94 | seed = self.args.seed 95 | prompt = self.args.pos_prompt 96 | negative_prompt = self.args.neg_prompt 97 | 98 | llava_item_tensor = self.llava_transform(Image.fromarray(llava_item_image.astype(np.uint8))) 99 | cat_item_tensor = torch.from_numpy(cat_item_image.copy()).permute((2, 0, 1)) / 255.0 100 | 101 | uncond_llava_item_image = np.ones_like(llava_item_image) * 255 102 | uncond_llava_item_tensor = self.llava_transform(Image.fromarray(uncond_llava_item_image)) 103 | # print(llava_item_tensor.shape, cat_item_tensor.shape) 104 | # raise ValueError 105 | batch = { 106 | "pixel_value_llava": llava_item_tensor, 107 | "uncond_pixel_value_llava": uncond_llava_item_tensor, 108 | "pixel_value_ref": cat_item_tensor, 109 | "prompt": prompt, 110 | "negative_prompt": negative_prompt, 111 | "seed": seed, 112 | "name": item_prompt, 113 | 'data_name': data_name 114 | } 115 | return batch 116 | -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipelines import HunyuanVideoCustomPipeline 2 | from .schedulers import FlowMatchDiscreteScheduler 3 | 4 | def load_diffusion_pipeline(args, rank, vae, text_encoder, text_encoder_2, model, scheduler=None, 5 | device=None, progress_bar_config=None): 6 | """ Load the denoising scheduler for inference. 
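    More precisely, this assembles the full HunyuanVideoCustomPipeline around the given VAE,
    text encoders and transformer; when `scheduler` is None, a FlowMatchDiscreteScheduler is
    built from the flow-matching CLI args (flow_shift_eval_video, flow_reverse, flow_solver).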
""" 7 | if scheduler is None: 8 | scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift_eval_video, reverse=args.flow_reverse, solver=args.flow_solver, ) 9 | 10 | # Only enable progress bar for rank 0 11 | progress_bar_config = progress_bar_config or {'leave': True, 'disable': rank != 0} 12 | 13 | pipeline = HunyuanVideoCustomPipeline(vae=vae, 14 | text_encoder=text_encoder, 15 | text_encoder_2=text_encoder_2, 16 | transformer=model, 17 | scheduler=scheduler, 18 | # safety_checker=None, 19 | # feature_extractor=None, 20 | # requires_safety_checker=False, 21 | progress_bar_config=progress_bar_config, 22 | args=args, 23 | ) 24 | if args.cpu_offload: # avoid oom 25 | pass 26 | else: 27 | pipeline = pipeline.to(device) 28 | 29 | return pipeline 30 | -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/diffusion/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_hunyuan_video_custom import HunyuanVideoCustomPipeline 2 | -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/diffusion/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/diffusion/schedulers/scheduling_flow_match_discrete.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | # 16 | # Modified from diffusers==0.29.2 17 | # 18 | # ============================================================================== 19 | 20 | from dataclasses import dataclass 21 | from typing import Optional, Tuple, Union 22 | 23 | import torch 24 | 25 | from diffusers.configuration_utils import ConfigMixin, register_to_config 26 | from diffusers.utils import BaseOutput, logging 27 | from diffusers.schedulers.scheduling_utils import SchedulerMixin 28 | 29 | 30 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 31 | 32 | 33 | @dataclass 34 | class FlowMatchDiscreteSchedulerOutput(BaseOutput): 35 | """ 36 | Output class for the scheduler's `step` function output. 37 | 38 | Args: 39 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 40 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the 41 | denoising loop. 42 | """ 43 | 44 | prev_sample: torch.FloatTensor 45 | 46 | 47 | class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin): 48 | """ 49 | Euler scheduler. 
50 | 51 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic 52 | methods the library implements for all schedulers such as loading and saving. 53 | 54 | Args: 55 | num_train_timesteps (`int`, defaults to 1000): 56 | The number of diffusion steps to train the model. 57 | timestep_spacing (`str`, defaults to `"linspace"`): 58 | The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and 59 | Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. 60 | shift (`float`, defaults to 1.0): 61 | The shift value for the timestep schedule. 62 | reverse (`bool`, defaults to `True`): 63 | Whether to reverse the timestep schedule. 64 | """ 65 | 66 | _compatibles = [] 67 | order = 1 68 | 69 | @register_to_config 70 | def __init__( 71 | self, 72 | num_train_timesteps: int = 1000, 73 | shift: float = 1.0, 74 | reverse: bool = True, 75 | solver: str = "euler", 76 | n_tokens: Optional[int] = None, 77 | ): 78 | sigmas = torch.linspace(1, 0, num_train_timesteps + 1) 79 | 80 | if not reverse: 81 | sigmas = sigmas.flip(0) 82 | 83 | self.sigmas = sigmas 84 | # the value fed to model 85 | self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32) 86 | 87 | self._step_index = None 88 | self._begin_index = None 89 | 90 | self.supported_solver = ["euler"] 91 | if solver not in self.supported_solver: 92 | raise ValueError(f"Solver {solver} not supported. Supported solvers: {self.supported_solver}") 93 | 94 | @property 95 | def step_index(self): 96 | """ 97 | The index counter for current timestep. It will increase 1 after each scheduler step. 98 | """ 99 | return self._step_index 100 | 101 | @property 102 | def begin_index(self): 103 | """ 104 | The index for the first timestep. It should be set from pipeline with `set_begin_index` method. 105 | """ 106 | return self._begin_index 107 | 108 | # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index 109 | def set_begin_index(self, begin_index: int = 0): 110 | """ 111 | Sets the begin index for the scheduler. This function should be run from pipeline before the inference. 112 | 113 | Args: 114 | begin_index (`int`): 115 | The begin index for the scheduler. 116 | """ 117 | self._begin_index = begin_index 118 | 119 | def _sigma_to_t(self, sigma): 120 | return sigma * self.config.num_train_timesteps 121 | 122 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, 123 | n_tokens: int = None): 124 | """ 125 | Sets the discrete timesteps used for the diffusion chain (to be run before inference). 126 | 127 | Args: 128 | num_inference_steps (`int`): 129 | The number of diffusion steps used when generating samples with a pre-trained model. 130 | device (`str` or `torch.device`, *optional*): 131 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 132 | n_tokens (`int`, *optional*): 133 | Number of tokens in the input sequence. 
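# Illustrative sketch of what set_timesteps builds below: a uniform sigma grid
# from 1 to 0, warped by sd3_time_shift so that more inference steps are spent
# at high noise levels, then scaled to timesteps = sigma * num_train_timesteps.
# The shift value 7.0 is only an example, not a repository default.
import torch

def shifted(sigma, shift=7.0):
    return (shift * sigma) / (1 + (shift - 1) * sigma)

sigmas = shifted(torch.linspace(1, 0, 6))      # grid for 5 inference steps
print(sigmas)              # tensor([1.0000, 0.9655, 0.9130, 0.8235, 0.6364, 0.0000])
print(sigmas[:-1] * 1000)  # the timestep values actually fed to the model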
134 | """ 135 | self.num_inference_steps = num_inference_steps 136 | 137 | sigmas = torch.linspace(1, 0, num_inference_steps + 1) 138 | sigmas = self.sd3_time_shift(sigmas) 139 | 140 | if not self.config.reverse: 141 | sigmas = 1 - sigmas 142 | 143 | self.sigmas = sigmas 144 | self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device) 145 | 146 | # Reset step index 147 | self._step_index = None 148 | 149 | def index_for_timestep(self, timestep, schedule_timesteps=None): 150 | if schedule_timesteps is None: 151 | schedule_timesteps = self.timesteps 152 | 153 | indices = (schedule_timesteps == timestep).nonzero() 154 | 155 | # The sigma index that is taken for the **very** first `step` 156 | # is always the second index (or the last index if there is only 1) 157 | # This way we can ensure we don't accidentally skip a sigma in 158 | # case we start in the middle of the denoising schedule (e.g. for image-to-image) 159 | pos = 1 if len(indices) > 1 else 0 160 | 161 | return indices[pos].item() 162 | 163 | def _init_step_index(self, timestep): 164 | if self.begin_index is None: 165 | if isinstance(timestep, torch.Tensor): 166 | timestep = timestep.to(self.timesteps.device) 167 | self._step_index = self.index_for_timestep(timestep) 168 | else: 169 | self._step_index = self._begin_index 170 | 171 | def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: 172 | return sample 173 | 174 | def sd3_time_shift(self, t: torch.Tensor): 175 | return (self.config.shift * t) / (1 + (self.config.shift - 1) * t) 176 | 177 | def step( 178 | self, 179 | model_output: torch.FloatTensor, 180 | timestep: Union[float, torch.FloatTensor], 181 | sample: torch.FloatTensor, 182 | return_dict: bool = True, 183 | ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]: 184 | """ 185 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 186 | process from the learned model outputs (most often the predicted noise). 187 | 188 | Args: 189 | model_output (`torch.FloatTensor`): 190 | The direct output from learned diffusion model. 191 | timestep (`float`): 192 | The current discrete timestep in the diffusion chain. 193 | sample (`torch.FloatTensor`): 194 | A current instance of a sample created by the diffusion process. 195 | return_dict (`bool`): 196 | Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or 197 | tuple. 198 | 199 | Returns: 200 | [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: 201 | If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is 202 | returned, otherwise a tuple is returned where the first element is the sample tensor. 203 | """ 204 | 205 | if ( 206 | isinstance(timestep, int) 207 | or isinstance(timestep, torch.IntTensor) 208 | or isinstance(timestep, torch.LongTensor) 209 | ): 210 | raise ValueError( 211 | ( 212 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 213 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 214 | " one of the `scheduler.timesteps` as a timestep." 
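# Hypothetical sampling loop showing the intended calling pattern: iterate over
# scheduler.timesteps and pass those values to step(); plain integers (e.g.
# from enumerate) are rejected by the check above. The latent shape and the
# stand-in `model` are purely illustrative.
import torch

scheduler = FlowMatchDiscreteScheduler(shift=7.0)
scheduler.set_timesteps(num_inference_steps=30, device="cpu")

sample = torch.randn(1, 16, 9, 40, 64)          # assumed (B, C, T, H, W) latent
model = lambda x, t: torch.zeros_like(x)        # stand-in for the video transformer
for t in scheduler.timesteps:
    sample = scheduler.step(model(sample, t), t, sample).prev_sample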
215 | ), 216 | ) 217 | 218 | if self.step_index is None: 219 | self._init_step_index(timestep) 220 | 221 | # Upcast to avoid precision issues when computing prev_sample 222 | sample = sample.to(torch.float32) 223 | 224 | dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index] 225 | 226 | if self.config.solver == "euler": 227 | prev_sample = sample + model_output.float() * dt 228 | else: 229 | raise ValueError(f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}") 230 | 231 | # upon completion increase step index by one 232 | self._step_index += 1 233 | 234 | if not return_dict: 235 | return (prev_sample,) 236 | 237 | return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample) 238 | 239 | def __len__(self): 240 | return self.config.num_train_timesteps 241 | -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, List 3 | from hymm_sp.modules.posemb_layers import get_1d_rotary_pos_embed, get_meshgrid_nd 4 | 5 | from itertools import repeat 6 | import collections.abc 7 | 8 | 9 | def _ntuple(n): 10 | def parse(x): 11 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 12 | x = tuple(x) 13 | if len(x) == 1: 14 | x = tuple(repeat(x[0], n)) 15 | return x 16 | return tuple(repeat(x, n)) 17 | return parse 18 | 19 | to_1tuple = _ntuple(1) 20 | to_2tuple = _ntuple(2) 21 | to_3tuple = _ntuple(3) 22 | to_4tuple = _ntuple(4) 23 | 24 | def get_rope_freq_from_size(latents_size, ndim, target_ndim, args, 25 | rope_theta_rescale_factor: Union[float, List[float]]=1.0, 26 | rope_interpolation_factor: Union[float, List[float]]=1.0, 27 | concat_dict={}): 28 | 29 | if isinstance(args.patch_size, int): 30 | assert all(s % args.patch_size == 0 for s in latents_size), \ 31 | f"Latent size(last {ndim} dimensions) should be divisible by patch size({args.patch_size}), " \ 32 | f"but got {latents_size}." 33 | rope_sizes = [s // args.patch_size for s in latents_size] 34 | elif isinstance(args.patch_size, list): 35 | assert all(s % args.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \ 36 | f"Latent size(last {ndim} dimensions) should be divisible by patch size({args.patch_size}), " \ 37 | f"but got {latents_size}." 
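# Quick sanity check of the _ntuple helpers defined above: scalars are repeated
# to the requested arity, a length-1 sequence is broadcast, and longer
# sequences pass through unchanged. get_rope_freq_from_size relies on the same
# per-axis convention when dividing latent sizes by the patch size.
assert to_2tuple(16) == (16, 16)
assert to_3tuple((2,)) == (2, 2, 2)
assert to_2tuple((3, 5)) == (3, 5)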
38 | rope_sizes = [s // args.patch_size[idx] for idx, s in enumerate(latents_size)] 39 | 40 | if len(rope_sizes) != target_ndim: 41 | rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis 42 | head_dim = args.hidden_size // args.num_heads 43 | rope_dim_list = args.rope_dim_list 44 | if rope_dim_list is None: 45 | rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] 46 | assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer" 47 | freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list, 48 | rope_sizes, 49 | theta=args.rope_theta, 50 | use_real=True, 51 | theta_rescale_factor=rope_theta_rescale_factor, 52 | interpolation_factor=rope_interpolation_factor, 53 | concat_dict=concat_dict) 54 | return freqs_cos, freqs_sin 55 | 56 | def get_nd_rotary_pos_embed_new(rope_dim_list, start, *args, theta=10000., use_real=False, 57 | theta_rescale_factor: Union[float, List[float]]=1.0, 58 | interpolation_factor: Union[float, List[float]]=1.0, 59 | concat_dict={} 60 | ): 61 | 62 | grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] 63 | if len(concat_dict)<1: 64 | pass 65 | else: 66 | if concat_dict['mode']=='timecat': 67 | bias = grid[:,:1].clone() 68 | bias[0] = concat_dict['bias']*torch.ones_like(bias[0]) 69 | grid = torch.cat([bias, grid], dim=1) 70 | 71 | elif concat_dict['mode']=='timecat-w': 72 | bias = grid[:,:1].clone() 73 | bias[0] = concat_dict['bias']*torch.ones_like(bias[0]) 74 | bias[2] += start[-1] ## ref https://github.com/Yuanshi9815/OminiControl/blob/main/src/generate.py#L178 75 | grid = torch.cat([bias, grid], dim=1) 76 | if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): 77 | theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) 78 | elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: 79 | theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list) 80 | assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)" 81 | 82 | if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float): 83 | interpolation_factor = [interpolation_factor] * len(rope_dim_list) 84 | elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: 85 | interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list) 86 | assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)" 87 | 88 | # use 1/ndim of dimensions to encode grid_axis 89 | embs = [] 90 | for i in range(len(rope_dim_list)): 91 | emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=use_real, 92 | theta_rescale_factor=theta_rescale_factor[i], 93 | interpolation_factor=interpolation_factor[i]) # 2 x [WHD, rope_dim_list[i]] 94 | 95 | embs.append(emb) 96 | 97 | if use_real: 98 | cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) 99 | sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) 100 | return cos, sin 101 | else: 102 | emb = torch.cat(embs, dim=1) # (WHD, D/2) 103 | return emb -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pathlib import Path 3 | from loguru import logger 4 | from hymm_sp.constants import PROMPT_TEMPLATE, PRECISION_TO_TYPE 5 | 
from hymm_sp.vae import load_vae 6 | from hymm_sp.modules import load_model 7 | from hymm_sp.text_encoder import TextEncoder 8 | import torch.distributed 9 | from hymm_sp.modules.parallel_states import ( 10 | nccl_info, 11 | ) 12 | from hymm_sp.modules.fp8_optimization import convert_fp8_linear 13 | 14 | 15 | class Inference(object): 16 | def __init__(self, 17 | args, 18 | vae, 19 | vae_kwargs, 20 | text_encoder, 21 | model, 22 | text_encoder_2=None, 23 | pipeline=None, 24 | cpu_offload=False, 25 | device=None, 26 | logger=None): 27 | self.vae = vae 28 | self.vae_kwargs = vae_kwargs 29 | 30 | self.text_encoder = text_encoder 31 | self.text_encoder_2 = text_encoder_2 32 | 33 | self.model = model 34 | self.pipeline = pipeline 35 | self.cpu_offload = cpu_offload 36 | 37 | self.args = args 38 | self.device = device if device is not None else "cuda" if torch.cuda.is_available() else "cpu" 39 | if nccl_info.sp_size > 1: 40 | self.device = torch.device(f"cuda:{torch.distributed.get_rank()}") 41 | 42 | self.logger = logger 43 | 44 | @classmethod 45 | def from_pretrained(cls, 46 | pretrained_model_path, 47 | args, 48 | device=None, 49 | **kwargs): 50 | """ 51 | Initialize the Inference pipeline. 52 | 53 | Args: 54 | pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints. 55 | device (int): The device for inference. Default is 0. 56 | logger (logging.Logger): The logger for the inference pipeline. Default is None. 57 | """ 58 | # ======================================================================== 59 | logger.info(f"Got text-to-video model root path: {pretrained_model_path}") 60 | 61 | # ======================== Get the args path ============================= 62 | 63 | # Set device and disable gradient 64 | if device is None: 65 | device = "cuda" if torch.cuda.is_available() else "cpu" 66 | torch.set_grad_enabled(False) 67 | logger.info("Building model...") 68 | factor_kwargs = {'device': 'cpu' if args.cpu_offload else device, 'dtype': PRECISION_TO_TYPE[args.precision]} 69 | in_channels = args.latent_channels 70 | out_channels = args.latent_channels 71 | print("="*25, f"build model", "="*25) 72 | model = load_model( 73 | args, 74 | in_channels=in_channels, 75 | out_channels=out_channels, 76 | factor_kwargs=factor_kwargs 77 | ) 78 | if args.use_fp8: 79 | convert_fp8_linear(model, pretrained_model_path, original_dtype=PRECISION_TO_TYPE[args.precision]) 80 | if args.cpu_offload: 81 | print(f'='*20, f'load transformer to cpu') 82 | model = model.to('cpu') 83 | torch.cuda.empty_cache() 84 | else: 85 | model = model.to(device) 86 | model = Inference.load_state_dict(args, model, pretrained_model_path) 87 | model.eval() 88 | 89 | # ============================= Build extra models ======================== 90 | # VAE 91 | print("="*25, f"load vae", "="*25) 92 | vae, _, s_ratio, t_ratio = load_vae(args.vae, args.vae_precision, logger=logger, device='cpu' if args.cpu_offload else device) 93 | vae_kwargs = {'s_ratio': s_ratio, 't_ratio': t_ratio} 94 | 95 | # Text encoder 96 | if args.prompt_template_video is not None: 97 | crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0) 98 | else: 99 | crop_start = 0 100 | max_length = args.text_len + crop_start 101 | 102 | # prompt_template_video 103 | prompt_template_video = PROMPT_TEMPLATE[args.prompt_template_video] if args.prompt_template_video is not None else None 104 | print("="*25, f"load llava", "="*25) 105 | text_encoder = TextEncoder(text_encoder_type = args.text_encoder, 106 | 
max_length = max_length, 107 | text_encoder_precision = args.text_encoder_precision, 108 | tokenizer_type = args.tokenizer, 109 | use_attention_mask = args.use_attention_mask, 110 | prompt_template_video = prompt_template_video, 111 | hidden_state_skip_layer = args.hidden_state_skip_layer, 112 | apply_final_norm = args.apply_final_norm, 113 | reproduce = args.reproduce, 114 | logger = logger, 115 | device = 'cpu' if args.cpu_offload else device , 116 | ) 117 | text_encoder_2 = None 118 | if args.text_encoder_2 is not None: 119 | text_encoder_2 = TextEncoder(text_encoder_type=args.text_encoder_2, 120 | max_length=args.text_len_2, 121 | text_encoder_precision=args.text_encoder_precision_2, 122 | tokenizer_type=args.tokenizer_2, 123 | use_attention_mask=args.use_attention_mask, 124 | reproduce=args.reproduce, 125 | logger=logger, 126 | device='cpu' if args.cpu_offload else device , # if not args.use_cpu_offload else 'cpu' 127 | ) 128 | 129 | return cls(args=args, 130 | vae=vae, 131 | vae_kwargs=vae_kwargs, 132 | text_encoder=text_encoder, 133 | model=model, 134 | text_encoder_2=text_encoder_2, 135 | device=device, 136 | logger=logger) 137 | 138 | @staticmethod 139 | def load_state_dict(args, model, ckpt_path): 140 | load_key = args.load_key 141 | ckpt_path = Path(ckpt_path) 142 | if ckpt_path.is_dir(): 143 | ckpt_path = next(ckpt_path.glob("*_model_states.pt")) 144 | state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage) 145 | if load_key in state_dict: 146 | state_dict = state_dict[load_key] 147 | elif load_key == ".": 148 | pass 149 | else: 150 | raise KeyError(f"Key '{load_key}' not found in the checkpoint. Existed keys: {state_dict.keys()}") 151 | model.load_state_dict(state_dict, strict=False) 152 | return model 153 | 154 | def get_exp_dir_and_ckpt_id(self): 155 | if self.ckpt is None: 156 | raise ValueError("The checkpoint path is not provided.") 157 | 158 | ckpt = Path(self.ckpt) 159 | if ckpt.parents[1].name == "checkpoints": 160 | # It should be a standard checkpoint path. We use the parent directory as the default save directory. 161 | exp_dir = ckpt.parents[2] 162 | else: 163 | raise ValueError(f"We cannot infer the experiment directory from the checkpoint path: {ckpt}. " 164 | f"It seems that the checkpoint path is not standard. 
Please explicitly provide the " 165 | f"save path by --save-path.") 166 | return exp_dir, ckpt.parent.name 167 | 168 | @staticmethod 169 | def parse_size(size): 170 | if isinstance(size, int): 171 | size = [size] 172 | if not isinstance(size, (list, tuple)): 173 | raise ValueError(f"Size must be an integer or (height, width), got {size}.") 174 | if len(size) == 1: 175 | size = [size[0], size[0]] 176 | if len(size) != 2: 177 | raise ValueError(f"Size must be an integer or (height, width), got {size}.") 178 | return size 179 | -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG 2 | 3 | def load_model(args, in_channels, out_channels, factor_kwargs): 4 | model = HYVideoDiffusionTransformer( 5 | args, 6 | in_channels=in_channels, 7 | out_channels=out_channels, 8 | **HUNYUAN_VIDEO_CONFIG[args.model], 9 | **factor_kwargs, 10 | ) 11 | return model 12 | -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/modules/activation_layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def get_activation_layer(act_type): 5 | """get activation layer 6 | 7 | Args: 8 | act_type (str): the activation type 9 | 10 | Returns: 11 | torch.nn.functional: the activation layer 12 | """ 13 | if act_type == "gelu": 14 | return lambda: nn.GELU() 15 | elif act_type == "gelu_tanh": 16 | # Approximate `tanh` requires torch >= 1.13 17 | return lambda: nn.GELU(approximate="tanh") 18 | elif act_type == "relu": 19 | return nn.ReLU 20 | elif act_type == "silu": 21 | return nn.SiLU 22 | else: 23 | raise ValueError(f"Unknown activation type: {act_type}") -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/modules/embed_layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from hymm_sp.helpers import to_2tuple 5 | 6 | 7 | class PatchEmbed(nn.Module): 8 | """ 2D Image to Patch Embedding 9 | 10 | Image to Patch Embedding using Conv2d 11 | 12 | A convolution based approach to patchifying a 2D image w/ embedding projection. 13 | 14 | Based on the impl in https://github.com/google-research/vision_transformer 15 | 16 | Hacked together by / Copyright 2020 Ross Wightman 17 | 18 | Remove the _assert function in forward function to be compatible with multi-resolution images. 
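# Shape sketch for the patch embedding implemented below: a Conv3d whose
# kernel_size equals its stride turns a (B, C, T, H, W) latent into
# (B, T/pt * H/ph * W/pw, embed_dim) tokens. The sizes here are assumptions
# chosen only for illustration.
import torch

patch = PatchEmbed(patch_size=(1, 2, 2), in_chans=16, embed_dim=3072)
x = torch.randn(1, 16, 9, 40, 64)     # hypothetical (B, C, T, H, W) latent
tokens = patch(x)
print(tokens.shape)                   # torch.Size([1, 5760, 3072]) == (1, 9*20*32, 3072)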
19 | """ 20 | def __init__( 21 | self, 22 | patch_size=16, 23 | in_chans=3, 24 | embed_dim=768, 25 | norm_layer=None, 26 | flatten=True, 27 | bias=True, 28 | dtype=None, 29 | device=None 30 | ): 31 | factory_kwargs = {'dtype': dtype, 'device': device} 32 | super().__init__() 33 | patch_size = to_2tuple(patch_size) 34 | self.patch_size = patch_size 35 | self.flatten = flatten 36 | 37 | self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, 38 | **factory_kwargs) 39 | nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1)) 40 | if bias: 41 | nn.init.zeros_(self.proj.bias) 42 | 43 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 44 | 45 | def forward(self, x): 46 | x = self.proj(x) 47 | if self.flatten: 48 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 49 | x = self.norm(x) 50 | return x 51 | 52 | 53 | class TextProjection(nn.Module): 54 | """ 55 | Projects text embeddings. Also handles dropout for classifier-free guidance. 56 | 57 | Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py 58 | """ 59 | 60 | def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): 61 | factory_kwargs = {'dtype': dtype, 'device': device} 62 | super().__init__() 63 | self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs) 64 | self.act_1 = act_layer() 65 | self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs) 66 | 67 | def forward(self, caption): 68 | hidden_states = self.linear_1(caption) 69 | hidden_states = self.act_1(hidden_states) 70 | hidden_states = self.linear_2(hidden_states) 71 | return hidden_states 72 | 73 | 74 | def timestep_embedding(t, dim, max_period=10000): 75 | """ 76 | Create sinusoidal timestep embeddings. 77 | 78 | Args: 79 | t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. 80 | dim (int): the dimension of the output. 81 | max_period (int): controls the minimum frequency of the embeddings. 82 | 83 | Returns: 84 | embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. 85 | 86 | .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 87 | """ 88 | half = dim // 2 89 | freqs = torch.exp( 90 | -math.log(max_period) 91 | * torch.arange(start=0, end=half, dtype=torch.float32) 92 | / half 93 | ).to(device=t.device) 94 | args = t[:, None].float() * freqs[None] 95 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 96 | if dim % 2: 97 | embedding = torch.cat( 98 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 99 | ) 100 | return embedding 101 | 102 | 103 | class TimestepEmbedder(nn.Module): 104 | """ 105 | Embeds scalar timesteps into vector representations. 
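# Small demo of the sinusoidal helper above: every (possibly fractional)
# timestep maps to a deterministic vector of cos/sin features at geometrically
# spaced frequencies; TimestepEmbedder below then projects these features
# through a two-layer MLP.
import torch

t = torch.tensor([0.0, 500.0, 999.0])
emb = timestep_embedding(t, dim=256)
print(emb.shape)     # torch.Size([3, 256])
print(emb[0, :4])    # cos terms at t=0 are all ones: tensor([1., 1., 1., 1.])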
106 | """ 107 | def __init__(self, 108 | hidden_size, 109 | act_layer, 110 | frequency_embedding_size=256, 111 | max_period=10000, 112 | out_size=None, 113 | dtype=None, 114 | device=None 115 | ): 116 | factory_kwargs = {'dtype': dtype, 'device': device} 117 | super().__init__() 118 | self.frequency_embedding_size = frequency_embedding_size 119 | self.max_period = max_period 120 | if out_size is None: 121 | out_size = hidden_size 122 | 123 | self.mlp = nn.Sequential( 124 | nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs), 125 | act_layer(), 126 | nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), 127 | ) 128 | nn.init.normal_(self.mlp[0].weight, std=0.02) 129 | nn.init.normal_(self.mlp[2].weight, std=0.02) 130 | 131 | def forward(self, t): 132 | t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) 133 | t_emb = self.mlp(t_freq) 134 | return t_emb -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/modules/fp8_optimization.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1): 8 | _bits = torch.tensor(bits) 9 | _mantissa_bit = torch.tensor(mantissa_bit) 10 | _sign_bits = torch.tensor(sign_bits) 11 | M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits) 12 | E = _bits - _sign_bits - M 13 | bias = 2 ** (E - 1) - 1 14 | mantissa = 1 15 | for i in range(mantissa_bit - 1): 16 | mantissa += 1 / (2 ** (i+1)) 17 | maxval = mantissa * 2 ** (2**E - 1 - bias) 18 | return maxval 19 | 20 | def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1): 21 | """ 22 | Default is E4M3. 
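# Sanity check for the E4M3 layout assumed above (1 sign bit, 4 exponent bits,
# 3 mantissa bits): the largest finite magnitude is 1.75 * 2**8 = 448, which
# matches torch's native float8_e4m3fn type.
import torch

print(1.75 * 2 ** 8)                            # 448.0
print(torch.finfo(torch.float8_e4m3fn).max)     # 448.0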
23 | """ 24 | bits = torch.tensor(bits) 25 | mantissa_bit = torch.tensor(mantissa_bit) 26 | sign_bits = torch.tensor(sign_bits) 27 | M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits) 28 | E = bits - sign_bits - M 29 | bias = 2 ** (E - 1) - 1 30 | mantissa = 1 31 | for i in range(mantissa_bit - 1): 32 | mantissa += 1 / (2 ** (i+1)) 33 | maxval = mantissa * 2 ** (2**E - 1 - bias) 34 | minval = - maxval 35 | minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval) 36 | input_clamp = torch.min(torch.max(x, minval), maxval) 37 | log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0) 38 | log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype)) 39 | # dequant 40 | qdq_out = torch.round(input_clamp / log_scales) * log_scales 41 | return qdq_out, log_scales 42 | 43 | def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1): 44 | for i in range(len(x.shape) - 1): 45 | scale = scale.unsqueeze(-1) 46 | new_x = x / scale 47 | quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits) 48 | return quant_dequant_x, scale, log_scales 49 | 50 | def fp8_activation_dequant(qdq_out, scale, dtype): 51 | qdq_out = qdq_out.type(dtype) 52 | quant_dequant_x = qdq_out * scale.to(dtype) 53 | return quant_dequant_x 54 | 55 | def fp8_linear_forward(cls, original_dtype, input): 56 | weight_dtype = cls.weight.dtype 57 | ##### 58 | if cls.weight.dtype != torch.float8_e4m3fn: 59 | maxval = get_fp_maxval() 60 | scale = torch.max(torch.abs(cls.weight.flatten())) / maxval 61 | linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale) 62 | linear_weight = linear_weight.to(torch.float8_e4m3fn) 63 | weight_dtype = linear_weight.dtype 64 | else: 65 | scale = cls.fp8_scale.to(cls.weight.device) 66 | linear_weight = cls.weight 67 | ##### 68 | 69 | if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0: 70 | if True or len(input.shape) == 3: 71 | cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype) 72 | if cls.bias != None: 73 | output = F.linear(input, cls_dequant, cls.bias) 74 | else: 75 | output = F.linear(input, cls_dequant) 76 | return output 77 | else: 78 | return cls.original_forward(input.to(original_dtype)) 79 | else: 80 | return cls.original_forward(input) 81 | 82 | def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={}): 83 | setattr(module, "fp8_matmul_enabled", True) 84 | 85 | # loading fp8 mapping file 86 | fp8_map_path = dit_weight_path.replace('.pt', '_map.pt') 87 | if os.path.exists(fp8_map_path): 88 | fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)['module'] 89 | else: 90 | raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.") 91 | 92 | fp8_layers = [] 93 | for key, layer in module.named_modules(): 94 | if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key): 95 | fp8_layers.append(key) 96 | original_forward = layer.forward 97 | layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn)) 98 | setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype)) 99 | setattr(layer, "original_forward", original_forward) 100 | setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input)) -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/modules/mlp_layers.py: -------------------------------------------------------------------------------- 1 | # Modified from 
timm library: 2 | # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13 3 | 4 | from functools import partial 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from .modulate_layers import modulate 10 | from hymm_sp.helpers import to_2tuple 11 | 12 | 13 | class MLP(nn.Module): 14 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 15 | """ 16 | def __init__(self, 17 | in_channels, 18 | hidden_channels=None, 19 | out_features=None, 20 | act_layer=nn.GELU, 21 | norm_layer=None, 22 | bias=True, 23 | drop=0., 24 | use_conv=False, 25 | device=None, 26 | dtype=None 27 | ): 28 | factory_kwargs = {'device': device, 'dtype': dtype} 29 | super().__init__() 30 | out_features = out_features or in_channels 31 | hidden_channels = hidden_channels or in_channels 32 | bias = to_2tuple(bias) 33 | drop_probs = to_2tuple(drop) 34 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear 35 | 36 | self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs) 37 | self.act = act_layer() 38 | self.drop1 = nn.Dropout(drop_probs[0]) 39 | self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity() 40 | self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs) 41 | self.drop2 = nn.Dropout(drop_probs[1]) 42 | 43 | def forward(self, x): 44 | x = self.fc1(x) 45 | x = self.act(x) 46 | x = self.drop1(x) 47 | x = self.norm(x) 48 | x = self.fc2(x) 49 | x = self.drop2(x) 50 | return x 51 | 52 | 53 | class MLPEmbedder(nn.Module): 54 | """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py""" 55 | def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None): 56 | factory_kwargs = {'device': device, 'dtype': dtype} 57 | super().__init__() 58 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs) 59 | self.silu = nn.SiLU() 60 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs) 61 | 62 | def forward(self, x: torch.Tensor) -> torch.Tensor: 63 | return self.out_layer(self.silu(self.in_layer(x))) 64 | 65 | 66 | class FinalLayer(nn.Module): 67 | """The final layer of DiT.""" 68 | 69 | def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None): 70 | factory_kwargs = {'device': device, 'dtype': dtype} 71 | super().__init__() 72 | 73 | # Just use LayerNorm for the final layer 74 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) 75 | if isinstance(patch_size, int): 76 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, **factory_kwargs) 77 | else: 78 | self.linear = nn.Linear(hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=True) 79 | nn.init.zeros_(self.linear.weight) 80 | nn.init.zeros_(self.linear.bias) 81 | 82 | # Here we don't distinguish between the modulate types. Just use the simple one. 
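# Because the output projection above and the adaLN modulation defined just
# below are both zero-initialized, a freshly constructed FinalLayer maps any
# input to zeros. The sizes here are hypothetical, for illustration only.
import torch
import torch.nn as nn

layer = FinalLayer(hidden_size=3072, patch_size=(1, 2, 2), out_channels=16, act_layer=nn.SiLU)
x = torch.randn(1, 5760, 3072)            # (B, N, hidden_size) tokens
c = torch.randn(1, 3072)                  # conditioning vector
print(float(layer(x, c).abs().max()))     # 0.0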
83 | self.adaLN_modulation = nn.Sequential( 84 | act_layer(), 85 | nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs) 86 | ) 87 | # Zero-initialize the modulation 88 | nn.init.zeros_(self.adaLN_modulation[1].weight) 89 | nn.init.zeros_(self.adaLN_modulation[1].bias) 90 | 91 | def forward(self, x, c): 92 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 93 | x = modulate(self.norm_final(x), shift=shift, scale=scale) 94 | x = self.linear(x) 95 | return x 96 | -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/modules/modulate_layers.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class ModulateDiT(nn.Module): 8 | """Modulation layer for DiT.""" 9 | def __init__( 10 | self, 11 | hidden_size: int, 12 | factor: int, 13 | act_layer: Callable, 14 | dtype=None, 15 | device=None, 16 | ): 17 | factory_kwargs = {"dtype": dtype, "device": device} 18 | super().__init__() 19 | self.act = act_layer() 20 | self.linear = nn.Linear( 21 | hidden_size, factor * hidden_size, bias=True, **factory_kwargs 22 | ) 23 | # Zero-initialize the modulation 24 | nn.init.zeros_(self.linear.weight) 25 | nn.init.zeros_(self.linear.bias) 26 | 27 | def forward(self, x: torch.Tensor) -> torch.Tensor: 28 | return self.linear(self.act(x)) 29 | 30 | 31 | def modulate(x, shift=None, scale=None): 32 | """modulate by shift and scale 33 | 34 | Args: 35 | x (torch.Tensor): input tensor. 36 | shift (torch.Tensor, optional): shift tensor. Defaults to None. 37 | scale (torch.Tensor, optional): scale tensor. Defaults to None. 38 | 39 | Returns: 40 | torch.Tensor: the output tensor after modulate. 41 | """ 42 | if scale is None and shift is None: 43 | return x 44 | elif shift is None: 45 | return x * (1 + scale.unsqueeze(1)) 46 | elif scale is None: 47 | return x + shift.unsqueeze(1) 48 | else: 49 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 50 | 51 | 52 | def apply_gate(x, gate=None, tanh=False): 53 | """AI is creating summary for apply_gate 54 | 55 | Args: 56 | x (torch.Tensor): input tensor. 57 | gate (torch.Tensor, optional): gate tensor. Defaults to None. 58 | tanh (bool, optional): whether to use tanh function. Defaults to False. 59 | 60 | Returns: 61 | torch.Tensor: the output tensor after apply gate. 62 | """ 63 | if gate is None: 64 | return x 65 | if tanh: 66 | return x * gate.unsqueeze(1).tanh() 67 | else: 68 | return x * gate.unsqueeze(1) 69 | 70 | 71 | def ckpt_wrapper(module): 72 | def ckpt_forward(*inputs): 73 | outputs = module(*inputs) 74 | return outputs 75 | 76 | return ckpt_forward -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/modules/norm_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class RMSNorm(nn.Module): 6 | def __init__( 7 | self, 8 | dim: int, 9 | elementwise_affine=True, 10 | eps: float = 1e-6, 11 | device=None, 12 | dtype=None, 13 | ): 14 | """ 15 | Initialize the RMSNorm normalization layer. 16 | 17 | Args: 18 | dim (int): The dimension of the input tensor. 19 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 20 | 21 | Attributes: 22 | eps (float): A small value added to the denominator for numerical stability. 23 | weight (nn.Parameter): Learnable scaling parameter. 
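# Numeric sketch of the normalization implemented below: each token is scaled
# by the reciprocal of its root-mean-square, so the output has (approximately)
# unit RMS before the learnable per-channel weight is applied.
import torch

norm = RMSNorm(dim=4)
y = norm(torch.tensor([[1.0, 2.0, 3.0, 4.0]]))
print(float(y.pow(2).mean(-1).sqrt()))    # ~1.0 (exact up to eps)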
24 | 25 | """ 26 | factory_kwargs = {"device": device, "dtype": dtype} 27 | super().__init__() 28 | self.eps = eps 29 | if elementwise_affine: 30 | self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) 31 | 32 | def _norm(self, x): 33 | """ 34 | Apply the RMSNorm normalization to the input tensor. 35 | 36 | Args: 37 | x (torch.Tensor): The input tensor. 38 | 39 | Returns: 40 | torch.Tensor: The normalized tensor. 41 | 42 | """ 43 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 44 | 45 | def forward(self, x): 46 | """ 47 | Forward pass through the RMSNorm layer. 48 | 49 | Args: 50 | x (torch.Tensor): The input tensor. 51 | 52 | Returns: 53 | torch.Tensor: The output tensor after applying RMSNorm. 54 | 55 | """ 56 | output = self._norm(x.float()).type_as(x) 57 | if hasattr(self, "weight"): 58 | output = output * self.weight 59 | return output 60 | 61 | 62 | def get_norm_layer(norm_layer): 63 | """ 64 | Get the normalization layer. 65 | 66 | Args: 67 | norm_layer (str): The type of normalization layer. 68 | 69 | Returns: 70 | norm_layer (nn.Module): The normalization layer. 71 | """ 72 | if norm_layer == "layer": 73 | return nn.LayerNorm 74 | elif norm_layer == "rms": 75 | return RMSNorm 76 | else: 77 | raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/modules/posemb_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, Tuple, List 3 | 4 | 5 | def _to_tuple(x, dim=2): 6 | if isinstance(x, int): 7 | return (x,) * dim 8 | elif len(x) == dim: 9 | return x 10 | else: 11 | raise ValueError(f"Expected length {dim} or int, but got {x}") 12 | 13 | 14 | def get_meshgrid_nd(start, *args, dim=2): 15 | """ 16 | Get n-D meshgrid with start, stop and num. 17 | 18 | Args: 19 | start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, 20 | step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num 21 | should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in 22 | n-tuples. 23 | *args: See above. 24 | dim (int): Dimension of the meshgrid. Defaults to 2. 25 | 26 | Returns: 27 | grid (np.ndarray): [dim, ...] 
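# Example of the call pattern documented above: passing only a size tuple
# builds an n-D grid starting at 0, with the leading axis indexing the
# dimension. Despite the np.ndarray note, the implementation below returns a
# torch.Tensor.
g = get_meshgrid_nd((4, 6), dim=2)    # `start` is interpreted as the grid size
print(g.shape)                        # torch.Size([2, 4, 6])
print(g[0, :, 0])                     # tensor([0., 1., 2., 3.])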
28 | """ 29 | if len(args) == 0: 30 | # start is grid_size 31 | num = _to_tuple(start, dim=dim) 32 | start = (0,) * dim 33 | stop = num 34 | elif len(args) == 1: 35 | # start is start, args[0] is stop, step is 1 36 | start = _to_tuple(start, dim=dim) 37 | stop = _to_tuple(args[0], dim=dim) 38 | num = [stop[i] - start[i] for i in range(dim)] 39 | elif len(args) == 2: 40 | # start is start, args[0] is stop, args[1] is num 41 | start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 42 | stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 43 | num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 44 | else: 45 | raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") 46 | 47 | # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) 48 | axis_grid = [] 49 | for i in range(dim): 50 | a, b, n = start[i], stop[i], num[i] 51 | g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] 52 | axis_grid.append(g) 53 | grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] 54 | grid = torch.stack(grid, dim=0) # [dim, W, H, D] 55 | 56 | return grid 57 | 58 | 59 | ################################################################################# 60 | # Rotary Positional Embedding Functions # 61 | ################################################################################# 62 | # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80 63 | 64 | def get_nd_rotary_pos_embed(rope_dim_list, start, *args, theta=10000., use_real=False, 65 | theta_rescale_factor: Union[float, List[float]]=1.0, 66 | interpolation_factor: Union[float, List[float]]=1.0): 67 | """ 68 | This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. 69 | 70 | Args: 71 | rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. 72 | sum(rope_dim_list) should equal to head_dim of attention layer. 73 | start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, 74 | args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. 75 | *args: See above. 76 | theta (float): Scaling factor for frequency computation. Defaults to 10000.0. 77 | use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. 78 | Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real 79 | part and an imaginary part separately. 80 | theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. 
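# Shape sketch for the n-D RoPE builder below: with head_dim split as
# 16 + 24 + 24 over an assumed (T, H, W) = (4, 8, 8) token grid and
# use_real=True, it returns cos/sin tables of shape (T*H*W, head_dim).
cos, sin = get_nd_rotary_pos_embed([16, 24, 24], (4, 8, 8), use_real=True)
print(cos.shape, sin.shape)    # torch.Size([256, 64]) torch.Size([256, 64])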
81 | 82 | Returns: 83 | pos_embed (torch.Tensor): [HW, D/2] 84 | """ 85 | 86 | grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] 87 | 88 | if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): 89 | theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) 90 | elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: 91 | theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list) 92 | assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)" 93 | 94 | if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float): 95 | interpolation_factor = [interpolation_factor] * len(rope_dim_list) 96 | elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: 97 | interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list) 98 | assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)" 99 | 100 | # use 1/ndim of dimensions to encode grid_axis 101 | embs = [] 102 | for i in range(len(rope_dim_list)): 103 | emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=use_real, 104 | theta_rescale_factor=theta_rescale_factor[i], 105 | interpolation_factor=interpolation_factor[i]) # 2 x [WHD, rope_dim_list[i]] 106 | embs.append(emb) 107 | 108 | if use_real: 109 | cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) 110 | sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) 111 | return cos, sin 112 | else: 113 | emb = torch.cat(embs, dim=1) # (WHD, D/2) 114 | return emb 115 | 116 | 117 | def get_1d_rotary_pos_embed(dim: int, 118 | pos: Union[torch.FloatTensor, int], 119 | theta: float = 10000.0, 120 | use_real: bool = False, 121 | theta_rescale_factor: float = 1.0, 122 | interpolation_factor: float = 1.0, 123 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 124 | """ 125 | Precompute the frequency tensor for complex exponential (cis) with given dimensions. 126 | (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) 127 | 128 | This function calculates a frequency tensor with complex exponential using the given dimension 'dim' 129 | and the end index 'end'. The 'theta' parameter scales the frequencies. 130 | The returned tensor contains complex values in complex64 data type. 131 | 132 | Args: 133 | dim (int): Dimension of the frequency tensor. 134 | pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar 135 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. 136 | use_real (bool, optional): If True, return real part and imaginary part separately. 137 | Otherwise, return complex numbers. 138 | theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. 139 | 140 | Returns: 141 | freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] 142 | freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. 
[S, D] 143 | """ 144 | if isinstance(pos, int): 145 | pos = torch.arange(pos).float() 146 | 147 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning 148 | # has some connection to NTK literature 149 | if theta_rescale_factor != 1.0: 150 | theta *= theta_rescale_factor ** (dim / (dim - 2)) 151 | 152 | freqs = 1.0 / ( 153 | theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) 154 | ) # [D/2] 155 | freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2] 156 | if use_real: 157 | freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] 158 | freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] 159 | return freqs_cos, freqs_sin 160 | else: 161 | freqs_cis = torch.polar( 162 | torch.ones_like(freqs), freqs 163 | ) # complex64 # [S, D/2] 164 | return freqs_cis 165 | -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/modules/token_refiner.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from einops import rearrange 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .activation_layers import get_activation_layer 8 | from .attn_layers import attention 9 | from .norm_layers import get_norm_layer 10 | from .embed_layers import TimestepEmbedder, TextProjection 11 | from .attn_layers import attention 12 | from .mlp_layers import MLP 13 | from .modulate_layers import apply_gate 14 | 15 | 16 | class IndividualTokenRefinerBlock(nn.Module): 17 | def __init__( 18 | self, 19 | hidden_size, 20 | num_heads, 21 | mlp_ratio: str = 4.0, 22 | mlp_drop_rate: float = 0.0, 23 | act_type: str = "silu", 24 | qk_norm: bool = False, 25 | qk_norm_type: str = "layer", 26 | qkv_bias: bool = True, 27 | dtype: Optional[torch.dtype] = None, 28 | device: Optional[torch.device] = None, 29 | ): 30 | factory_kwargs = {'device': device, 'dtype': dtype} 31 | super().__init__() 32 | self.num_heads = num_heads 33 | head_dim = hidden_size // num_heads 34 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 35 | 36 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs) 37 | self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs) 38 | qk_norm_layer = get_norm_layer(qk_norm_type) 39 | self.self_attn_q_norm = ( 40 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) 41 | if qk_norm 42 | else nn.Identity() 43 | ) 44 | self.self_attn_k_norm = ( 45 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) 46 | if qk_norm 47 | else nn.Identity() 48 | ) 49 | self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) 50 | 51 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs) 52 | act_layer = get_activation_layer(act_type) 53 | self.mlp = MLP( 54 | in_channels=hidden_size, 55 | hidden_channels=mlp_hidden_dim, 56 | act_layer=act_layer, 57 | drop=mlp_drop_rate, 58 | **factory_kwargs, 59 | ) 60 | 61 | self.adaLN_modulation = nn.Sequential( 62 | act_layer(), 63 | nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs) 64 | ) 65 | # Zero-initialize the modulation 66 | nn.init.zeros_(self.adaLN_modulation[1].weight) 67 | nn.init.zeros_(self.adaLN_modulation[1].bias) 68 | 69 | def forward( 70 | self, 71 | x: torch.Tensor, 72 | c: torch.Tensor, # timestep_aware_representations + context_aware_representations 73 | attn_mask: 
torch.Tensor = None, 74 | ): 75 | gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) 76 | 77 | norm_x = self.norm1(x) 78 | qkv = self.self_attn_qkv(norm_x) 79 | q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads) 80 | # Apply QK-Norm if needed 81 | q = self.self_attn_q_norm(q).to(v) 82 | k = self.self_attn_k_norm(k).to(v) 83 | 84 | # Self-Attention 85 | attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) 86 | 87 | x = x + apply_gate(self.self_attn_proj(attn), gate_msa) 88 | 89 | # FFN Layer 90 | x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) 91 | 92 | return x 93 | 94 | 95 | class IndividualTokenRefiner(nn.Module): 96 | def __init__( 97 | self, 98 | hidden_size, 99 | num_heads, 100 | depth, 101 | mlp_ratio: float = 4.0, 102 | mlp_drop_rate: float = 0.0, 103 | act_type: str = "silu", 104 | qk_norm: bool = False, 105 | qk_norm_type: str = "layer", 106 | qkv_bias: bool = True, 107 | dtype: Optional[torch.dtype] = None, 108 | device: Optional[torch.device] = None, 109 | ): 110 | factory_kwargs = {'device': device, 'dtype': dtype} 111 | super().__init__() 112 | self.blocks = nn.ModuleList([ 113 | IndividualTokenRefinerBlock( 114 | hidden_size=hidden_size, 115 | num_heads=num_heads, 116 | mlp_ratio=mlp_ratio, 117 | mlp_drop_rate=mlp_drop_rate, 118 | act_type=act_type, 119 | qk_norm=qk_norm, 120 | qk_norm_type=qk_norm_type, 121 | qkv_bias=qkv_bias, 122 | **factory_kwargs, 123 | ) for _ in range(depth) 124 | ]) 125 | 126 | def forward( 127 | self, 128 | x: torch.Tensor, 129 | c: torch.LongTensor, 130 | mask: Optional[torch.Tensor] = None, 131 | ): 132 | self_attn_mask = None 133 | if mask is not None: 134 | batch_size = mask.shape[0] 135 | seq_len = mask.shape[1] 136 | mask = mask.to(x.device) 137 | # batch_size x 1 x seq_len x seq_len 138 | self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) 139 | # batch_size x 1 x seq_len x seq_len 140 | self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) 141 | # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads 142 | self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() 143 | # avoids self-attention weight being NaN for padding tokens 144 | self_attn_mask[:, :, :, 0] = True 145 | 146 | for block in self.blocks: 147 | x = block(x, c, self_attn_mask) 148 | return x 149 | 150 | 151 | class SingleTokenRefiner(nn.Module): 152 | def __init__( 153 | self, 154 | in_channels, 155 | hidden_size, 156 | num_heads, 157 | depth, 158 | mlp_ratio: float = 4.0, 159 | mlp_drop_rate: float = 0.0, 160 | act_type: str = "silu", 161 | qk_norm: bool = False, 162 | qk_norm_type: str = "layer", 163 | qkv_bias: bool = True, 164 | dtype: Optional[torch.dtype] = None, 165 | device: Optional[torch.device] = None, 166 | ): 167 | factory_kwargs = {'device': device, 'dtype': dtype} 168 | super().__init__() 169 | 170 | self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs) 171 | 172 | act_layer = get_activation_layer(act_type) 173 | # Build timestep embedding layer 174 | self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs) 175 | # Build context embedding layer 176 | self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs) 177 | 178 | self.individual_token_refiner = IndividualTokenRefiner( 179 | hidden_size=hidden_size, 180 | num_heads=num_heads, 181 | depth=depth, 182 | mlp_ratio=mlp_ratio, 183 | mlp_drop_rate=mlp_drop_rate, 184 | act_type=act_type, 185 | qk_norm=qk_norm, 186 | 
qk_norm_type=qk_norm_type, 187 | qkv_bias=qkv_bias, 188 | **factory_kwargs 189 | ) 190 | 191 | def forward( 192 | self, 193 | x: torch.Tensor, 194 | t: torch.LongTensor, 195 | mask: Optional[torch.LongTensor] = None, 196 | ): 197 | timestep_aware_representations = self.t_embedder(t) 198 | 199 | if mask is None: 200 | context_aware_representations = x.mean(dim=1) 201 | else: 202 | mask_float = mask.float().unsqueeze(-1) # [b, s1, 1] 203 | context_aware_representations = ( 204 | (x * mask_float).sum(dim=1) / mask_float.sum(dim=1) 205 | ) 206 | context_aware_representations = self.c_embedder(context_aware_representations) 207 | c = timestep_aware_representations + context_aware_representations 208 | 209 | x = self.input_embedder(x) 210 | 211 | x = self.individual_token_refiner(x, c, mask) 212 | 213 | return x -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/sample_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from loguru import logger 4 | import torch 5 | from einops import rearrange 6 | import torch.distributed 7 | from torch.utils.data.distributed import DistributedSampler 8 | from torch.utils.data import DataLoader 9 | from hymm_sp.config import parse_args 10 | from hymm_sp.sample_inference import HunyuanVideoSampler 11 | from hymm_sp.data_kits.video_dataset import JsonDataset 12 | from hymm_sp.data_kits.data_tools import save_videos_grid 13 | from hymm_sp.modules.parallel_states import ( 14 | initialize_distributed, 15 | nccl_info, 16 | ) 17 | 18 | def main(): 19 | args = parse_args() 20 | models_root_path = Path(args.ckpt) 21 | print("*"*20) 22 | initialize_distributed(args.seed) 23 | if not models_root_path.exists(): 24 | raise ValueError(f"`models_root` not exists: {models_root_path}") 25 | print("+"*20) 26 | # Create save folder to save the samples 27 | save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}' 28 | if not os.path.exists(args.save_path): 29 | os.makedirs(save_path, exist_ok=True) 30 | 31 | # Load models 32 | rank = 0 33 | vae_dtype = torch.float16 34 | device = torch.device("cuda") 35 | if nccl_info.sp_size > 1: 36 | device = torch.device(f"cuda:{torch.distributed.get_rank()}") 37 | rank = torch.distributed.get_rank() 38 | 39 | hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(args.ckpt, args=args, device=device) 40 | # Get the updated args 41 | args = hunyuan_video_sampler.args 42 | 43 | json_dataset = JsonDataset(args) 44 | sampler = DistributedSampler(json_dataset, num_replicas=1, rank=0, shuffle=False, drop_last=False) 45 | json_loader = DataLoader(json_dataset, batch_size=1, shuffle=False, sampler=sampler, drop_last=False) 46 | for batch_index, batch in enumerate(json_loader, start=1): 47 | pixel_value_llava = batch['pixel_value_llava'].to(device) 48 | pixel_value_ref = batch['pixel_value_ref'].to(device) 49 | uncond_pixel_value_llava = batch['uncond_pixel_value_llava'] 50 | prompt = batch['prompt'][0] 51 | negative_prompt = batch['negative_prompt'][0] 52 | name = batch['name'][0] 53 | save_name = batch['data_name'][0] 54 | seed = batch['seed'] 55 | pixel_value_ref = pixel_value_ref * 2 - 1. 
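# The reference image enters the causal 3D VAE as a one-frame clip: the
# rearrange below this point only inserts the singleton time axis that the 3D
# encoder expects. The sizes here are assumptions for illustration.
import torch
from einops import rearrange

ref = torch.rand(1, 3, 720, 1280) * 2 - 1.0          # (b, c, h, w), scaled to [-1, 1]
ref_for_vae = rearrange(ref, "b c h w -> b c 1 h w")
print(ref_for_vae.shape)                             # torch.Size([1, 3, 1, 720, 1280])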
56 | pixel_value_ref_for_vae = rearrange(pixel_value_ref,"b c h w -> b c 1 h w") 57 | with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_dtype != torch.float32): 58 | ref_latents = hunyuan_video_sampler.vae.encode(pixel_value_ref_for_vae.clone()).latent_dist.sample() 59 | uncond_ref_latents = hunyuan_video_sampler.vae.encode(torch.ones_like(pixel_value_ref_for_vae)).latent_dist.sample() 60 | ref_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor) 61 | uncond_ref_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor) 62 | 63 | prompt = args.add_pos_prompt + prompt 64 | negative_prompt = args.add_neg_prompt + negative_prompt 65 | outputs = hunyuan_video_sampler.predict( 66 | prompt=prompt, 67 | name=name, 68 | size=args.video_size, 69 | seed=seed, 70 | pixel_value_llava=pixel_value_llava, 71 | uncond_pixel_value_llava=uncond_pixel_value_llava, 72 | ref_latents=ref_latents, 73 | uncond_ref_latents=uncond_ref_latents, 74 | video_length=args.sample_n_frames, 75 | guidance_scale=args.cfg_scale, 76 | num_images_per_prompt=args.num_images, 77 | negative_prompt=negative_prompt, 78 | infer_steps=args.infer_steps, 79 | flow_shift=args.flow_shift_eval_video, 80 | use_linear_quadratic_schedule=args.use_linear_quadratic_schedule, 81 | linear_schedule_end=args.linear_schedule_end, 82 | use_deepcache=args.use_deepcache, 83 | ) 84 | 85 | if rank == 0: 86 | samples = outputs['samples'] 87 | for i, sample in enumerate(samples): 88 | sample = samples[i].unsqueeze(0) 89 | out_path = f"{save_path}/{save_name}.mp4" 90 | save_videos_grid(sample, out_path, fps=25) 91 | logger.info(f'Sample save to: {out_path}') 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/sample_gpu_poor.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from loguru import logger 4 | import torch 5 | from einops import rearrange 6 | import torch.distributed 7 | from torch.utils.data.distributed import DistributedSampler 8 | from torch.utils.data import DataLoader 9 | from hymm_sp.config import parse_args 10 | from hymm_sp.sample_inference import HunyuanVideoSampler 11 | from hymm_sp.data_kits.video_dataset import JsonDataset 12 | from hymm_sp.data_kits.data_tools import save_videos_grid 13 | 14 | 15 | def main(): 16 | args = parse_args() 17 | models_root_path = Path(args.ckpt) 18 | 19 | if not models_root_path.exists(): 20 | raise ValueError(f"`models_root` not exists: {models_root_path}") 21 | 22 | # Create save folder to save the samples 23 | save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}' 24 | if not os.path.exists(args.save_path): 25 | os.makedirs(save_path, exist_ok=True) 26 | 27 | # Load models 28 | rank = 0 29 | vae_dtype = torch.float16 30 | device = torch.device("cuda") 31 | 32 | hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(args.ckpt, args=args, device=device) 33 | # Get the updated args 34 | args = hunyuan_video_sampler.args 35 | if args.cpu_offload: 36 | from diffusers.hooks import apply_group_offloading 37 | onload_device = torch.device("cuda") 38 | apply_group_offloading(hunyuan_video_sampler.pipeline.transformer, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=1) 39 | 40 | json_dataset = JsonDataset(args) 41 | sampler = DistributedSampler(json_dataset, 
num_replicas=1, rank=0, shuffle=False, drop_last=False) 42 | json_loader = DataLoader(json_dataset, batch_size=1, shuffle=False, sampler=sampler, drop_last=False) 43 | for batch_index, batch in enumerate(json_loader, start=1): 44 | pixel_value_llava = batch['pixel_value_llava'].to(device) 45 | pixel_value_ref = batch['pixel_value_ref'].to(device) 46 | uncond_pixel_value_llava = batch['uncond_pixel_value_llava'] 47 | prompt = batch['prompt'][0] 48 | negative_prompt = batch['negative_prompt'][0] 49 | name = batch['name'][0] 50 | save_name = batch['data_name'][0] 51 | seed = batch['seed'] 52 | pixel_value_ref = pixel_value_ref * 2 - 1. 53 | pixel_value_ref_for_vae = rearrange(pixel_value_ref,"b c h w -> b c 1 h w") 54 | with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_dtype != torch.float32): 55 | if args.cpu_offload: 56 | hunyuan_video_sampler.vae.to('cuda') 57 | ref_latents = hunyuan_video_sampler.vae.encode(pixel_value_ref_for_vae.clone()).latent_dist.sample() 58 | uncond_ref_latents = hunyuan_video_sampler.vae.encode(torch.ones_like(pixel_value_ref_for_vae)).latent_dist.sample() 59 | ref_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor) 60 | uncond_ref_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor) 61 | if args.cpu_offload: 62 | hunyuan_video_sampler.vae.to('cpu') 63 | torch.cuda.empty_cache() 64 | 65 | prompt = args.add_pos_prompt + prompt 66 | negative_prompt = args.add_neg_prompt + negative_prompt 67 | outputs = hunyuan_video_sampler.predict( 68 | prompt=prompt, 69 | name=name, 70 | size=args.video_size, 71 | seed=seed, 72 | pixel_value_llava=pixel_value_llava, 73 | uncond_pixel_value_llava=uncond_pixel_value_llava, 74 | ref_latents=ref_latents, 75 | uncond_ref_latents=uncond_ref_latents, 76 | video_length=args.sample_n_frames, 77 | guidance_scale=args.cfg_scale, 78 | num_images_per_prompt=args.num_images, 79 | negative_prompt=negative_prompt, 80 | infer_steps=args.infer_steps, 81 | flow_shift=args.flow_shift_eval_video, 82 | use_linear_quadratic_schedule=args.use_linear_quadratic_schedule, 83 | linear_schedule_end=args.linear_schedule_end, 84 | use_deepcache=args.use_deepcache, 85 | cpu_offload=args.cpu_offload, 86 | ) 87 | 88 | if rank == 0: 89 | samples = outputs['samples'] 90 | for i, sample in enumerate(samples): 91 | sample = samples[i].unsqueeze(0) 92 | out_path = f"{save_path}/{save_name}_4090.mp4" 93 | save_videos_grid(sample, out_path, fps=25) 94 | logger.info(f'Sample saved to: {out_path}') 95 | 96 | 97 | if __name__ == "__main__": 98 | main() 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /hunyuan_custom/hymm_sp/vae/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pathlib import Path 3 | from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D 4 | from ..constants import VAE_PATH, PRECISION_TO_TYPE 5 | 6 | def load_vae(vae_type, 7 | vae_precision=None, 8 | sample_size=None, 9 | vae_path=None, 10 | logger=None, 11 | device=None 12 | ): 13 | if vae_path is None: 14 | vae_path = VAE_PATH[vae_type] 15 | vae_compress_spec, _, _ = vae_type.split("-") 16 | length = len(vae_compress_spec) 17 | if length == 3: 18 | if logger is not None: 19 | logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}") 20 | config = AutoencoderKLCausal3D.load_config(vae_path) 21 | if sample_size: 22 | vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size) 23 | else: 24 | vae = 
AutoencoderKLCausal3D.from_config(config) 25 | ckpt = torch.load(Path(vae_path) / "pytorch_model.pt", map_location=vae.device) 26 | if "state_dict" in ckpt: 27 | ckpt = ckpt["state_dict"] 28 | vae_ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")} 29 | vae.load_state_dict(vae_ckpt) 30 | 31 | spatial_compression_ratio = vae.config.spatial_compression_ratio 32 | time_compression_ratio = vae.config.time_compression_ratio 33 | else: 34 | raise ValueError(f"Invalid VAE model: {vae_type}. Must be 3D VAE in the format of '???-*'.") 35 | 36 | if vae_precision is not None: 37 | vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision]) 38 | 39 | vae.requires_grad_(False) 40 | 41 | if logger is not None: 42 | logger.info(f"VAE to dtype: {vae.dtype}") 43 | 44 | if device is not None: 45 | vae = vae.to(device) 46 | 47 | # Set vae to eval mode, even though its dropout rate is 0. 48 | vae.eval() 49 | 50 | return vae, vae_path, spatial_compression_ratio, time_compression_ratio 51 | -------------------------------------------------------------------------------- /hunyuan_custom/models/README.md: -------------------------------------------------------------------------------- 1 | # Download Pretrained Models 2 | 3 | All models are stored in `HunyuanCustom/models` by default, and the file structure is as follows: 4 | ```shell 5 | HunyuanCustom 6 | ├──models 7 | │ ├──README.md 8 | │ ├──hunyuancustom_720P 9 | │ │ ├──mp_rank_00_model_states.pt 10 | │ │ ├──mp_rank_00_model_states_fp8.pt 11 | │ │ ├──mp_rank_00_model_states_fp8_map.pt 12 | │ ├──vae_3d 13 | │ ├──openai_clip-vit-large-patch14 14 | │ ├──llava-llama-3-8b-v1_1 15 | ├──... 16 | ``` 17 | 18 | ## Download HunyuanCustom model 19 | To download the HunyuanCustom model, first install the huggingface-cli. (Detailed instructions are available [here](https://huggingface.co/docs/huggingface_hub/guides/cli).) 20 | 21 | ```shell 22 | python -m pip install "huggingface_hub[cli]" 23 | ``` 24 | 25 | Then download the model using the following commands: 26 | 27 | ```shell 28 | # Switch to the directory named 'HunyuanCustom' 29 | cd HunyuanCustom 30 | # Use the huggingface-cli tool to download the HunyuanCustom model into the HunyuanCustom/models dir. 31 | # The download time may vary from 10 minutes to 1 hour depending on network conditions. 32 | huggingface-cli download tencent/HunyuanCustom --local-dir ./ 33 | ``` 34 | -------------------------------------------------------------------------------- /hunyuan_custom/run-single-video-8xA100.sh: -------------------------------------------------------------------------------- 1 | 2 | export MODEL_BASE="./models" 3 | export PYTHONPATH=./ 4 | 5 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_sp/sample_batch.py \ 6 | --input './assets/images/seg_woman_01.png' \ 7 | --pos-prompt "Realistic, High-quality. A woman is drinking coffee at a café." \ 8 | --neg-prompt "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border." 
\ 9 | --ckpt ${MODEL_BASE}"/hunyuancustom_720P/mp_rank_00_model_states.pt" \ 10 | --video-size 768 1280 \ 11 | --seed 1024 \ 12 | --sample-n-frames 129 \ 13 | --infer-steps 30 \ 14 | --flow-shift-eval-video 13.0 \ 15 | --save-path './results/sp_768p' 16 | 17 | -------------------------------------------------------------------------------- /wan/run-single-inference.sh: -------------------------------------------------------------------------------- 1 | 2 | ckpt_dir="/mnt/localssd/wan/Wan2.1-T2V-14B" 3 | 4 | 5 | export CUDA_VISIBLE_DEVICES=0 6 | 7 | # --size 1280*768 8 | # --size 768*512 9 | 10 | python3 -u generate.py \ 11 | --task t2v-14B \ 12 | --size 768*512 \ 13 | --ckpt_dir $ckpt_dir \ 14 | --prompt "A giant panda is walking." 15 | 16 | # --prompt "warm colors dominate the room, with a focus on the tabby cat sitting contently in the center. the scene captures the fluffy orange tabby cat wearing a tiny virtual reality headset. the setting is a cozy living room, adorned with soft, warm lighting and a modern aesthetic. a plush sofa is visible in the background, along with a few lush potted plants, adding a touch of greenery. the cat's tail flicks curiously, as if engaging with an unseen virtual environment. its paws swipe at the air, indicating a playful and inquisitive nature, as it delves into the digital realm. the atmosphere is both whimsical and futuristic, highlighting the blend of analog and digital experiences." 17 | 18 | # --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." 19 | 20 | -------------------------------------------------------------------------------- /wan/wan/__init__.py: -------------------------------------------------------------------------------- 1 | from . import configs, distributed, modules 2 | from .image2video import WanI2V 3 | from .text2video import WanT2V 4 | from .first_last_frame2video import WanFLF2V 5 | -------------------------------------------------------------------------------- /wan/wan/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
2 | import copy 3 | import os 4 | 5 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 6 | 7 | from .wan_i2v_14B import i2v_14B 8 | from .wan_t2v_1_3B import t2v_1_3B 9 | from .wan_t2v_14B import t2v_14B 10 | 11 | # the config of t2i_14B is the same as t2v_14B 12 | t2i_14B = copy.deepcopy(t2v_14B) 13 | t2i_14B.__name__ = 'Config: Wan T2I 14B' 14 | 15 | # the config of flf2v_14B is the same as i2v_14B 16 | flf2v_14B = copy.deepcopy(i2v_14B) 17 | flf2v_14B.__name__ = 'Config: Wan FLF2V 14B' 18 | flf2v_14B.sample_neg_prompt = "镜头切换," + flf2v_14B.sample_neg_prompt 19 | 20 | WAN_CONFIGS = { 21 | 't2v-14B': t2v_14B, 22 | 't2v-1.3B': t2v_1_3B, 23 | 'i2v-14B': i2v_14B, 24 | 't2i-14B': t2i_14B, 25 | 'flf2v-14B': flf2v_14B 26 | } 27 | 28 | SIZE_CONFIGS = { 29 | '720*1280': (720, 1280), 30 | '1280*720': (1280, 720), 31 | '480*832': (480, 832), 32 | '832*480': (832, 480), 33 | '1024*1024': (1024, 1024), 34 | '768*512': (768, 512), # xuan: add 35 | '1280*768': (1280, 768), # xuan: add 36 | } 37 | 38 | MAX_AREA_CONFIGS = { 39 | '720*1280': 720 * 1280, 40 | '1280*720': 1280 * 720, 41 | '480*832': 480 * 832, 42 | '832*480': 832 * 480, 43 | } 44 | 45 | SUPPORTED_SIZES = { 46 | 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 47 | 't2v-1.3B': ('480*832', '832*480'), 48 | 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 49 | 'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 50 | 't2i-14B': tuple(SIZE_CONFIGS.keys()), 51 | } 52 | -------------------------------------------------------------------------------- /wan/wan/configs/shared_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import torch 3 | from easydict import EasyDict 4 | 5 | #------------------------ Wan shared config ------------------------# 6 | wan_shared_cfg = EasyDict() 7 | 8 | # t5 9 | wan_shared_cfg.t5_model = 'umt5_xxl' 10 | wan_shared_cfg.t5_dtype = torch.bfloat16 11 | wan_shared_cfg.text_len = 512 12 | 13 | # transformer 14 | wan_shared_cfg.param_dtype = torch.bfloat16 15 | 16 | # inference 17 | wan_shared_cfg.num_train_timesteps = 1000 18 | wan_shared_cfg.sample_fps = 16 19 | wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走' 20 | -------------------------------------------------------------------------------- /wan/wan/configs/wan_i2v_14B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
2 | import torch 3 | from easydict import EasyDict 4 | 5 | from .shared_config import wan_shared_cfg 6 | 7 | #------------------------ Wan I2V 14B ------------------------# 8 | 9 | i2v_14B = EasyDict(__name__='Config: Wan I2V 14B') 10 | i2v_14B.update(wan_shared_cfg) 11 | i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt 12 | 13 | i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 14 | i2v_14B.t5_tokenizer = 'google/umt5-xxl' 15 | 16 | # clip 17 | i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14' 18 | i2v_14B.clip_dtype = torch.float16 19 | i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth' 20 | i2v_14B.clip_tokenizer = 'xlm-roberta-large' 21 | 22 | # vae 23 | i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' 24 | i2v_14B.vae_stride = (4, 8, 8) 25 | 26 | # transformer 27 | i2v_14B.patch_size = (1, 2, 2) 28 | i2v_14B.dim = 5120 29 | i2v_14B.ffn_dim = 13824 30 | i2v_14B.freq_dim = 256 31 | i2v_14B.num_heads = 40 32 | i2v_14B.num_layers = 40 33 | i2v_14B.window_size = (-1, -1) 34 | i2v_14B.qk_norm = True 35 | i2v_14B.cross_attn_norm = True 36 | i2v_14B.eps = 1e-6 37 | -------------------------------------------------------------------------------- /wan/wan/configs/wan_t2v_14B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | from easydict import EasyDict 3 | 4 | from .shared_config import wan_shared_cfg 5 | 6 | #------------------------ Wan T2V 14B ------------------------# 7 | 8 | t2v_14B = EasyDict(__name__='Config: Wan T2V 14B') 9 | t2v_14B.update(wan_shared_cfg) 10 | 11 | # t5 12 | t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 13 | t2v_14B.t5_tokenizer = 'google/umt5-xxl' 14 | 15 | # vae 16 | t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' 17 | t2v_14B.vae_stride = (4, 8, 8) 18 | 19 | # transformer 20 | t2v_14B.patch_size = (1, 2, 2) 21 | t2v_14B.dim = 5120 22 | t2v_14B.ffn_dim = 13824 23 | t2v_14B.freq_dim = 256 24 | t2v_14B.num_heads = 40 25 | t2v_14B.num_layers = 40 26 | t2v_14B.window_size = (-1, -1) 27 | t2v_14B.qk_norm = True 28 | t2v_14B.cross_attn_norm = True 29 | t2v_14B.eps = 1e-6 30 | -------------------------------------------------------------------------------- /wan/wan/configs/wan_t2v_1_3B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
2 | from easydict import EasyDict 3 | 4 | from .shared_config import wan_shared_cfg 5 | 6 | #------------------------ Wan T2V 1.3B ------------------------# 7 | 8 | t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B') 9 | t2v_1_3B.update(wan_shared_cfg) 10 | 11 | # t5 12 | t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 13 | t2v_1_3B.t5_tokenizer = 'google/umt5-xxl' 14 | 15 | # vae 16 | t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth' 17 | t2v_1_3B.vae_stride = (4, 8, 8) 18 | 19 | # transformer 20 | t2v_1_3B.patch_size = (1, 2, 2) 21 | t2v_1_3B.dim = 1536 22 | t2v_1_3B.ffn_dim = 8960 23 | t2v_1_3B.freq_dim = 256 24 | t2v_1_3B.num_heads = 12 25 | t2v_1_3B.num_layers = 30 26 | t2v_1_3B.window_size = (-1, -1) 27 | t2v_1_3B.qk_norm = True 28 | t2v_1_3B.cross_attn_norm = True 29 | t2v_1_3B.eps = 1e-6 30 | -------------------------------------------------------------------------------- /wan/wan/distributed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shawnricecake/draft-attention/f3c81d58304e144305cf06b0fa801e82088f89a0/wan/wan/distributed/__init__.py -------------------------------------------------------------------------------- /wan/wan/distributed/fsdp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import gc 3 | from functools import partial 4 | 5 | import torch 6 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 7 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy 8 | from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy 9 | from torch.distributed.utils import _free_storage 10 | 11 | def shard_model( 12 | model, 13 | device_id, 14 | param_dtype=torch.bfloat16, 15 | reduce_dtype=torch.float32, 16 | buffer_dtype=torch.float32, 17 | process_group=None, 18 | sharding_strategy=ShardingStrategy.FULL_SHARD, 19 | sync_module_states=True, 20 | ): 21 | model = FSDP( 22 | module=model, 23 | process_group=process_group, 24 | sharding_strategy=sharding_strategy, 25 | auto_wrap_policy=partial( 26 | lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks), 27 | mixed_precision=MixedPrecision( 28 | param_dtype=param_dtype, 29 | reduce_dtype=reduce_dtype, 30 | buffer_dtype=buffer_dtype), 31 | device_id=device_id, 32 | sync_module_states=sync_module_states) 33 | return model 34 | 35 | def free_model(model): 36 | for m in model.modules(): 37 | if isinstance(m, FSDP): 38 | _free_storage(m._handle.flat_param.data) 39 | del model 40 | gc.collect() 41 | torch.cuda.empty_cache() 42 | -------------------------------------------------------------------------------- /wan/wan/distributed/xdit_context_parallel.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
2 | import torch 3 | import torch.cuda.amp as amp 4 | from xfuser.core.distributed import (get_sequence_parallel_rank, 5 | get_sequence_parallel_world_size, 6 | get_sp_group) 7 | from xfuser.core.long_ctx_attention import xFuserLongContextAttention 8 | 9 | from ..modules.model import sinusoidal_embedding_1d 10 | 11 | 12 | def pad_freqs(original_tensor, target_len): 13 | seq_len, s1, s2 = original_tensor.shape 14 | pad_size = target_len - seq_len 15 | padding_tensor = torch.ones( 16 | pad_size, 17 | s1, 18 | s2, 19 | dtype=original_tensor.dtype, 20 | device=original_tensor.device) 21 | padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) 22 | return padded_tensor 23 | 24 | 25 | @amp.autocast(enabled=False) 26 | def rope_apply(x, grid_sizes, freqs): 27 | """ 28 | x: [B, L, N, C]. 29 | grid_sizes: [B, 3]. 30 | freqs: [M, C // 2]. 31 | """ 32 | s, n, c = x.size(1), x.size(2), x.size(3) // 2 33 | # split freqs 34 | freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) 35 | 36 | # loop over samples 37 | output = [] 38 | for i, (f, h, w) in enumerate(grid_sizes.tolist()): 39 | seq_len = f * h * w 40 | 41 | # precompute multipliers 42 | x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape( 43 | s, n, -1, 2)) 44 | freqs_i = torch.cat([ 45 | freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), 46 | freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), 47 | freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) 48 | ], 49 | dim=-1).reshape(seq_len, 1, -1) 50 | 51 | # apply rotary embedding 52 | sp_size = get_sequence_parallel_world_size() 53 | sp_rank = get_sequence_parallel_rank() 54 | freqs_i = pad_freqs(freqs_i, s * sp_size) 55 | s_per_rank = s 56 | freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * 57 | s_per_rank), :, :] 58 | x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2) 59 | x_i = torch.cat([x_i, x[i, s:]]) 60 | 61 | # append to collection 62 | output.append(x_i) 63 | return torch.stack(output).float() 64 | 65 | 66 | def usp_dit_forward( 67 | self, 68 | x, 69 | t, 70 | context, 71 | seq_len, 72 | clip_fea=None, 73 | y=None, 74 | ): 75 | """ 76 | x: A list of videos each with shape [C, T, H, W]. 77 | t: [B]. 78 | context: A list of text embeddings each with shape [L, C]. 
79 | """ 80 | if self.model_type == 'i2v': 81 | assert clip_fea is not None and y is not None 82 | # params 83 | device = self.patch_embedding.weight.device 84 | if self.freqs.device != device: 85 | self.freqs = self.freqs.to(device) 86 | 87 | if y is not None: 88 | x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] 89 | 90 | # embeddings 91 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x] 92 | grid_sizes = torch.stack( 93 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) 94 | x = [u.flatten(2).transpose(1, 2) for u in x] 95 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) 96 | assert seq_lens.max() <= seq_len 97 | x = torch.cat([ 98 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) 99 | for u in x 100 | ]) 101 | 102 | # time embeddings 103 | with amp.autocast(dtype=torch.float32): 104 | e = self.time_embedding( 105 | sinusoidal_embedding_1d(self.freq_dim, t).float()) 106 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) 107 | assert e.dtype == torch.float32 and e0.dtype == torch.float32 108 | 109 | # context 110 | context_lens = None 111 | context = self.text_embedding( 112 | torch.stack([ 113 | torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) 114 | for u in context 115 | ])) 116 | 117 | if clip_fea is not None: 118 | context_clip = self.img_emb(clip_fea) # bs x 257 x dim 119 | context = torch.concat([context_clip, context], dim=1) 120 | 121 | # arguments 122 | kwargs = dict( 123 | e=e0, 124 | seq_lens=seq_lens, 125 | grid_sizes=grid_sizes, 126 | freqs=self.freqs, 127 | context=context, 128 | context_lens=context_lens) 129 | 130 | # Context Parallel 131 | x = torch.chunk( 132 | x, get_sequence_parallel_world_size(), 133 | dim=1)[get_sequence_parallel_rank()] 134 | 135 | for block in self.blocks: 136 | x = block(x, **kwargs) 137 | 138 | # head 139 | x = self.head(x, e) 140 | 141 | # Context Parallel 142 | x = get_sp_group().all_gather(x, dim=1) 143 | 144 | # unpatchify 145 | x = self.unpatchify(x, grid_sizes) 146 | return [u.float() for u in x] 147 | 148 | 149 | def usp_attn_forward(self, 150 | x, 151 | seq_lens, 152 | grid_sizes, 153 | freqs, 154 | dtype=torch.bfloat16): 155 | b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim 156 | half_dtypes = (torch.float16, torch.bfloat16) 157 | 158 | def half(x): 159 | return x if x.dtype in half_dtypes else x.to(dtype) 160 | 161 | # query, key, value function 162 | def qkv_fn(x): 163 | q = self.norm_q(self.q(x)).view(b, s, n, d) 164 | k = self.norm_k(self.k(x)).view(b, s, n, d) 165 | v = self.v(x).view(b, s, n, d) 166 | return q, k, v 167 | 168 | q, k, v = qkv_fn(x) 169 | q = rope_apply(q, grid_sizes, freqs) 170 | k = rope_apply(k, grid_sizes, freqs) 171 | 172 | # TODO: We should use unpaded q,k,v for attention. 173 | # k_lens = seq_lens // get_sequence_parallel_world_size() 174 | # if k_lens is not None: 175 | # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0) 176 | # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0) 177 | # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0) 178 | 179 | x = xFuserLongContextAttention()( 180 | None, 181 | query=half(q), 182 | key=half(k), 183 | value=half(v), 184 | window_size=self.window_size) 185 | 186 | # TODO: padding after attention. 
187 | # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1) 188 | 189 | # output 190 | x = x.flatten(2) 191 | x = self.o(x) 192 | return x 193 | -------------------------------------------------------------------------------- /wan/wan/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import flash_attention 2 | from .model import WanModel 3 | from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model 4 | from .tokenizers import HuggingfaceTokenizer 5 | from .vae import WanVAE 6 | 7 | __all__ = [ 8 | 'WanVAE', 9 | 'WanModel', 10 | 'T5Model', 11 | 'T5Encoder', 12 | 'T5Decoder', 13 | 'T5EncoderModel', 14 | 'HuggingfaceTokenizer', 15 | 'flash_attention', 16 | ] 17 | -------------------------------------------------------------------------------- /wan/wan/modules/tokenizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import html 3 | import string 4 | 5 | import ftfy 6 | import regex as re 7 | from transformers import AutoTokenizer 8 | 9 | __all__ = ['HuggingfaceTokenizer'] 10 | 11 | 12 | def basic_clean(text): 13 | text = ftfy.fix_text(text) 14 | text = html.unescape(html.unescape(text)) 15 | return text.strip() 16 | 17 | 18 | def whitespace_clean(text): 19 | text = re.sub(r'\s+', ' ', text) 20 | text = text.strip() 21 | return text 22 | 23 | 24 | def canonicalize(text, keep_punctuation_exact_string=None): 25 | text = text.replace('_', ' ') 26 | if keep_punctuation_exact_string: 27 | text = keep_punctuation_exact_string.join( 28 | part.translate(str.maketrans('', '', string.punctuation)) 29 | for part in text.split(keep_punctuation_exact_string)) 30 | else: 31 | text = text.translate(str.maketrans('', '', string.punctuation)) 32 | text = text.lower() 33 | text = re.sub(r'\s+', ' ', text) 34 | return text.strip() 35 | 36 | 37 | class HuggingfaceTokenizer: 38 | 39 | def __init__(self, name, seq_len=None, clean=None, **kwargs): 40 | assert clean in (None, 'whitespace', 'lower', 'canonicalize') 41 | self.name = name 42 | self.seq_len = seq_len 43 | self.clean = clean 44 | 45 | # init tokenizer 46 | self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) 47 | self.vocab_size = self.tokenizer.vocab_size 48 | 49 | def __call__(self, sequence, **kwargs): 50 | return_mask = kwargs.pop('return_mask', False) 51 | 52 | # arguments 53 | _kwargs = {'return_tensors': 'pt'} 54 | if self.seq_len is not None: 55 | _kwargs.update({ 56 | 'padding': 'max_length', 57 | 'truncation': True, 58 | 'max_length': self.seq_len 59 | }) 60 | _kwargs.update(**kwargs) 61 | 62 | # tokenization 63 | if isinstance(sequence, str): 64 | sequence = [sequence] 65 | if self.clean: 66 | sequence = [self._clean(u) for u in sequence] 67 | ids = self.tokenizer(sequence, **_kwargs) 68 | 69 | # output 70 | if return_mask: 71 | return ids.input_ids, ids.attention_mask 72 | else: 73 | return ids.input_ids 74 | 75 | def _clean(self, text): 76 | if self.clean == 'whitespace': 77 | text = whitespace_clean(basic_clean(text)) 78 | elif self.clean == 'lower': 79 | text = whitespace_clean(basic_clean(text)).lower() 80 | elif self.clean == 'canonicalize': 81 | text = canonicalize(basic_clean(text)) 82 | return text 83 | -------------------------------------------------------------------------------- /wan/wan/modules/xlm_roberta.py: -------------------------------------------------------------------------------- 1 | # Modified from 
transformers.models.xlm_roberta.modeling_xlm_roberta 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | __all__ = ['XLMRoberta', 'xlm_roberta_large'] 8 | 9 | 10 | class SelfAttention(nn.Module): 11 | 12 | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5): 13 | assert dim % num_heads == 0 14 | super().__init__() 15 | self.dim = dim 16 | self.num_heads = num_heads 17 | self.head_dim = dim // num_heads 18 | self.eps = eps 19 | 20 | # layers 21 | self.q = nn.Linear(dim, dim) 22 | self.k = nn.Linear(dim, dim) 23 | self.v = nn.Linear(dim, dim) 24 | self.o = nn.Linear(dim, dim) 25 | self.dropout = nn.Dropout(dropout) 26 | 27 | def forward(self, x, mask): 28 | """ 29 | x: [B, L, C]. 30 | """ 31 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim 32 | 33 | # compute query, key, value 34 | q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3) 35 | k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3) 36 | v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3) 37 | 38 | # compute attention 39 | p = self.dropout.p if self.training else 0.0 40 | x = F.scaled_dot_product_attention(q, k, v, mask, p) 41 | x = x.permute(0, 2, 1, 3).reshape(b, s, c) 42 | 43 | # output 44 | x = self.o(x) 45 | x = self.dropout(x) 46 | return x 47 | 48 | 49 | class AttentionBlock(nn.Module): 50 | 51 | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): 52 | super().__init__() 53 | self.dim = dim 54 | self.num_heads = num_heads 55 | self.post_norm = post_norm 56 | self.eps = eps 57 | 58 | # layers 59 | self.attn = SelfAttention(dim, num_heads, dropout, eps) 60 | self.norm1 = nn.LayerNorm(dim, eps=eps) 61 | self.ffn = nn.Sequential( 62 | nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), 63 | nn.Dropout(dropout)) 64 | self.norm2 = nn.LayerNorm(dim, eps=eps) 65 | 66 | def forward(self, x, mask): 67 | if self.post_norm: 68 | x = self.norm1(x + self.attn(x, mask)) 69 | x = self.norm2(x + self.ffn(x)) 70 | else: 71 | x = x + self.attn(self.norm1(x), mask) 72 | x = x + self.ffn(self.norm2(x)) 73 | return x 74 | 75 | 76 | class XLMRoberta(nn.Module): 77 | """ 78 | XLMRobertaModel with no pooler and no LM head. 79 | """ 80 | 81 | def __init__(self, 82 | vocab_size=250002, 83 | max_seq_len=514, 84 | type_size=1, 85 | pad_id=1, 86 | dim=1024, 87 | num_heads=16, 88 | num_layers=24, 89 | post_norm=True, 90 | dropout=0.1, 91 | eps=1e-5): 92 | super().__init__() 93 | self.vocab_size = vocab_size 94 | self.max_seq_len = max_seq_len 95 | self.type_size = type_size 96 | self.pad_id = pad_id 97 | self.dim = dim 98 | self.num_heads = num_heads 99 | self.num_layers = num_layers 100 | self.post_norm = post_norm 101 | self.eps = eps 102 | 103 | # embeddings 104 | self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id) 105 | self.type_embedding = nn.Embedding(type_size, dim) 106 | self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id) 107 | self.dropout = nn.Dropout(dropout) 108 | 109 | # blocks 110 | self.blocks = nn.ModuleList([ 111 | AttentionBlock(dim, num_heads, post_norm, dropout, eps) 112 | for _ in range(num_layers) 113 | ]) 114 | 115 | # norm layer 116 | self.norm = nn.LayerNorm(dim, eps=eps) 117 | 118 | def forward(self, ids): 119 | """ 120 | ids: [B, L] of torch.LongTensor. 
121 | """ 122 | b, s = ids.shape 123 | mask = ids.ne(self.pad_id).long() 124 | 125 | # embeddings 126 | x = self.token_embedding(ids) + \ 127 | self.type_embedding(torch.zeros_like(ids)) + \ 128 | self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask) 129 | if self.post_norm: 130 | x = self.norm(x) 131 | x = self.dropout(x) 132 | 133 | # blocks 134 | mask = torch.where( 135 | mask.view(b, 1, 1, s).gt(0), 0.0, 136 | torch.finfo(x.dtype).min) 137 | for block in self.blocks: 138 | x = block(x, mask) 139 | 140 | # output 141 | if not self.post_norm: 142 | x = self.norm(x) 143 | return x 144 | 145 | 146 | def xlm_roberta_large(pretrained=False, 147 | return_tokenizer=False, 148 | device='cpu', 149 | **kwargs): 150 | """ 151 | XLMRobertaLarge adapted from Huggingface. 152 | """ 153 | # params 154 | cfg = dict( 155 | vocab_size=250002, 156 | max_seq_len=514, 157 | type_size=1, 158 | pad_id=1, 159 | dim=1024, 160 | num_heads=16, 161 | num_layers=24, 162 | post_norm=True, 163 | dropout=0.1, 164 | eps=1e-5) 165 | cfg.update(**kwargs) 166 | 167 | # init a model on device 168 | with torch.device(device): 169 | model = XLMRoberta(**cfg) 170 | return model 171 | -------------------------------------------------------------------------------- /wan/wan/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, 2 | retrieve_timesteps) 3 | from .fm_solvers_unipc import FlowUniPCMultistepScheduler 4 | 5 | __all__ = [ 6 | 'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps', 7 | 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler' 8 | ] 9 | -------------------------------------------------------------------------------- /wan/wan/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import argparse 3 | import binascii 4 | import os 5 | import os.path as osp 6 | 7 | import imageio 8 | import torch 9 | import torchvision 10 | 11 | __all__ = ['cache_video', 'cache_image', 'str2bool'] 12 | 13 | 14 | def rand_name(length=8, suffix=''): 15 | name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') 16 | if suffix: 17 | if not suffix.startswith('.'): 18 | suffix = '.' 
+ suffix 19 | name += suffix 20 | return name 21 | 22 | 23 | def cache_video(tensor, 24 | save_file=None, 25 | fps=30, 26 | suffix='.mp4', 27 | nrow=8, 28 | normalize=True, 29 | value_range=(-1, 1), 30 | retry=5): 31 | # cache file 32 | cache_file = osp.join('/tmp', rand_name( 33 | suffix=suffix)) if save_file is None else save_file 34 | 35 | # save to cache 36 | error = None 37 | for _ in range(retry): 38 | try: 39 | # preprocess 40 | tensor = tensor.clamp(min(value_range), max(value_range)) 41 | tensor = torch.stack([ 42 | torchvision.utils.make_grid( 43 | u, nrow=nrow, normalize=normalize, value_range=value_range) 44 | for u in tensor.unbind(2) 45 | ], 46 | dim=1).permute(1, 2, 3, 0) 47 | tensor = (tensor * 255).type(torch.uint8).cpu() 48 | 49 | # write video 50 | writer = imageio.get_writer( 51 | cache_file, fps=fps, codec='libx264', quality=8) 52 | for frame in tensor.numpy(): 53 | writer.append_data(frame) 54 | writer.close() 55 | return cache_file 56 | except Exception as e: 57 | error = e 58 | continue 59 | else: 60 | print(f'cache_video failed, error: {error}', flush=True) 61 | return None 62 | 63 | 64 | def cache_image(tensor, 65 | save_file, 66 | nrow=8, 67 | normalize=True, 68 | value_range=(-1, 1), 69 | retry=5): 70 | # cache file 71 | suffix = osp.splitext(save_file)[1] 72 | if suffix.lower() not in [ 73 | '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp' 74 | ]: 75 | suffix = '.png' 76 | 77 | # save to cache 78 | error = None 79 | for _ in range(retry): 80 | try: 81 | tensor = tensor.clamp(min(value_range), max(value_range)) 82 | torchvision.utils.save_image( 83 | tensor, 84 | save_file, 85 | nrow=nrow, 86 | normalize=normalize, 87 | value_range=value_range) 88 | return save_file 89 | except Exception as e: 90 | error = e 91 | continue 92 | 93 | 94 | def str2bool(v): 95 | """ 96 | Convert a string to a boolean. 97 | 98 | Supported true values: 'yes', 'true', 't', 'y', '1' 99 | Supported false values: 'no', 'false', 'f', 'n', '0' 100 | 101 | Args: 102 | v (str): String to convert. 103 | 104 | Returns: 105 | bool: Converted boolean value. 106 | 107 | Raises: 108 | argparse.ArgumentTypeError: If the value cannot be converted to boolean. 109 | """ 110 | if isinstance(v, bool): 111 | return v 112 | v_lower = v.lower() 113 | if v_lower in ('yes', 'true', 't', 'y', '1'): 114 | return True 115 | elif v_lower in ('no', 'false', 'f', 'n', '0'): 116 | return False 117 | else: 118 | raise argparse.ArgumentTypeError('Boolean value expected (True/False)') 119 | --------------------------------------------------------------------------------