├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── attn_processor.py ├── config.py ├── configs ├── accelerate │ ├── uncompiled_8.yaml │ └── uncompiled_8_ds.yaml ├── config_defaults.yaml └── deepspeed │ └── zero_stage2_config.json ├── ema.py ├── hv_control_lora_inference.py ├── hv_train_control_lora.py ├── mds_dataloaders.py ├── notebooks └── 000-inference.ipynb ├── optim.py ├── pyproject.toml ├── setup_env.sh └── tools ├── hv_cache_dataset.py ├── hv_precompute_latents.py └── hv_precompute_latents_dist.py /.gitignore: -------------------------------------------------------------------------------- 1 | */output/* 2 | */wandb/* 3 | */outputs/* 4 | */input/* 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 115 | .pdm.toml 116 | .pdm-python 117 | .pdm-build/ 118 | 119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 167 | #.idea/ 168 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.4.0 6 | hooks: 7 | - id: end-of-file-fixer 8 | - id: trailing-whitespace 9 | - id: check-docstring-first 10 | - id: check-toml 11 | - id: check-case-conflict 12 | 13 | - repo: https://github.com/asottile/pyupgrade 14 | rev: v3.3.1 15 | hooks: 16 | - id: pyupgrade 17 | args: [--py38-plus] 18 | name: Upgrade code 19 | 20 | - repo: https://github.com/PyCQA/isort 21 | rev: 5.12.0 22 | hooks: 23 | - id: isort 24 | name: imports 25 | require_serial: false 26 | args: [--settings-path=pyproject.toml] 27 | 28 | - repo: https://github.com/psf/black 29 | rev: 23.1.0 30 | hooks: 31 | - id: black 32 | name: Format code 33 | args: [--config=pyproject.toml] 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **HunyuanVideo Keyframe Control Lora** 2 | 3 | --- 4 | 5 | This repo contains PyTorch model definitions, pre-trained weights and inference/sampling code for our experiments on HunyuanVideo Keyframe Control Lora. 6 | 7 | ## Contents 8 | 9 | - [🔥🔥🔥 News!!](#-news) 10 | - [Abstract](#-abstract) 11 | - [Demo](#-demo) 12 | - [📜 Recommeded Settings](#-recommeded-settings) 13 | - [🛠️ Dependencies and Installation](#-dependencies-and-installation) 14 | - [🚀 Inference](#-inference) 15 | - [🚀 Training](#-training) 16 | 17 | --- 18 | 19 | ## 🔥🔥🔥 News!! 20 | 21 | - 27 February 2025: We release the training code of HunyuanVideo Keyframe Control Lora and [Blog](https://insiders.dashtoon.com/introducing-hunyuan-keyframe-lora-open-source-keyframe-based-video-generation/). 22 | - 24 February 2025: We release the inference code and model weights of HunyuanVideo Keyframe Control Lora . [Download](https://huggingface.co/dashtoon/hunyuan-video-keyframe-control-lora/tree/main). 23 | 24 | ## Abstract 25 | 26 | HunyuanVideo Keyframe Control Lora is an adapter for HunyuanVideo T2V model for keyframe-based video generation. ​Our architecture builds upon existing models, introducing key enhancements to optimize keyframe-based video generation:​ 27 | 28 | - We modify the input patch embedding projection layer to effectively incorporate keyframe information. By adjusting the convolutional input parameters, we enable the model to process image inputs within the Diffusion Transformer (DiT) framework.​ 29 | - We apply Low-Rank Adaptation (LoRA) across all linear layers and the convolutional input layer. This approach facilitates efficient fine-tuning by introducing low-rank matrices that approximate the weight updates, thereby preserving the base model's foundational capabilities while reducing the number of trainable parameters. 30 | - The model is conditioned on user-defined keyframes, allowing precise control over the generated video's start and end frames. This conditioning ensures that the generated content aligns seamlessly with the specified keyframes, enhancing the coherence and narrative flow of the video.​ 31 | 32 | ## 🎥 Demo 33 | 34 | **Click on the first column images to view the generated videos** 35 | 36 | | Generated Video | Image 1 | Image 2 | 37 | | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------- | 38 | | Click to view video | Image 1 | Image 2 | 39 | | Click to view video | Image 1 | Image 2 | 40 | | Click to view video | Image 1 | Image 2 | 41 | | Click to view video | Image 1 | Image 2 | 42 | 43 | ## 📜 Recommeded Settings 44 | 45 | 1. The model works best on human subjects. Single subject images work slightly better. 46 | 2. It is recommended to use the following image generation resolutions `720x1280`, `544x960`, `1280x720`, `960x544`. 47 | 3. It is recommended to set frames from 33 upto 97. Can go upto 121 frames as well (but not tested much). 48 | 4. Prompting helps a lot but works even without. The prompt can be as simple as just the name of the object you want to generate or can be detailed. 49 | 5. `num_inference_steps` is recommended to be 50, but for fast results you can use 30 as well. Anything less than 30 is not recommended. 50 | 51 | ## 🛠️ Dependencies and Installation 52 | 53 | Begin by cloning the repository: 54 | 55 | ```shell 56 | git clone https://github.com/dashtoon/hunyuan-video-keyframe-control-lora.git 57 | cd hunyuan-video-keyframe-control-lora 58 | ``` 59 | 60 | ### Installation Guide for Linux 61 | 62 | We recommend CUDA versions 12.4 63 | 64 | Conda's installation instructions are available [here](https://docs.anaconda.com/free/miniconda/index.html). 65 | 66 | ```shell 67 | bash setup_env.sh 68 | ``` 69 | 70 | ## 🚀 Inference 71 | 72 | The model weights can be downloaded from [Huggingface](https://huggingface.co/dashtoon/hunyuan-video-keyframe-control-lora) 73 | 74 | You can run inference using the provided script. The script uses `flash_attn` but can also be modified to use `sage_attn`. Running the below command will output a video that is saved in `output.mp4` 75 | 76 | - An NVIDIA GPU with CUDA support is required. 77 | - The model is tested on a single 80G GPU. 78 | - **Minimum**: The minimum GPU memory required is ~60GB for 720px1280px129f and ~45G for 544px960px129f. 79 | - **Recommended**: We recommend using a GPU with 80GB of memory for better generation quality. 80 | - Tested operating system: Linux 81 | 82 | ```shell 83 | export BASE_MODEL="hunyuanvideo-community/HunyuanVideo" 84 | export LORA_PATH="" 85 | export IMAGE_1="" 86 | export IMAGE_2="" 87 | export PROMPT="" 88 | export HEIGHT=960 89 | export WIDTH=544 90 | export n_FRAMES=33 91 | 92 | python hv_control_lora_inference.py \ 93 | --model $BASE_MODEL \ 94 | --lora $LORA_PATH \ 95 | --frame1 $IMAGE_1 --frame2 $IMAGE_2 --prompt "$PROMPT" --frames $n_FRAMES \ 96 | --height $HEIGHT --width $WIDTH \ 97 | --steps 50 \ 98 | --guidance 6.0 \ 99 | --seed 123143153 \ 100 | --output output.mp4 101 | ``` 102 | 103 | ## 🚀 Training 104 | 105 | ### Dataset Preparation 106 | 107 | It is recommended to have atleast 1 GPU with 80GB of VRAM. We use mosaic-ml streaming for caching our data. We expect our original data in the following format. Running the tree command, you should see: 108 | 109 | ``` 110 | dataset 111 | ├── metadata.csv 112 | ├── videos 113 | ├── 00000.mp4 114 | ├── 00001.mp4 115 | ├── ... 116 | ``` 117 | 118 | The csv can contain any number of columns, but due to limited support at the moment, we only make use of prompt and video columns. The CSV should look like this: 119 | 120 | ``` 121 | caption,video_file,other_column1,other_column2 122 | A black and white animated sequence featuring a rabbit, named Rabbity Ribfried, and an anthropomorphic goat in a musical, playful environment, showcasing their evolving interaction.,videos/00000.mp4,...,... 123 | ``` 124 | 125 | For the above format you would run the following command for starting to cache the dataset: 126 | 127 | ```shell 128 | python tools/hv_cache_dataset.py \ 129 | --csv "dataset/metadata.csv" \ 130 | --base_dir "dataset" \ 131 | --video_column video_file \ 132 | --caption_column "caption" \ 133 | --output_dir "dataset/mds_cache" \ 134 | --bucket_reso \ 135 | "1280x720x33" "1280x720x65" "1280x720x97" "960x544x33" "960x544x65" "960x544x97" \ 136 | "720x1280x33" "720x1280x65" "720x1280x97" "544x960x33" "544x960x65" "544x960x97" \ 137 | --min_bucket_count 100 \ 138 | --head_frame 0 139 | ``` 140 | 141 | - `bucket_reso` : this specifies the bucket resolutions to train on in the format of WxHxF. 142 | - `head_frame`: the intial frame from where to start extracting from a video 143 | 144 | **NOTE:** It is recommened to first convert your video into separate scenes and ensure there is continuity between scenes. [This](https://github.com/aigc-apps/EasyAnimate/tree/main/easyanimate/video_caption) is a good starting point for video dataset preparation. 145 | 146 | The next commanded will start caching the LLM embeds and the VAE states. 147 | 148 | ```shell 149 | NUM_GPUS=8 150 | MIXED_PRECISION="bf16" 151 | accelerate launch --num_processes=$NUM_GPUS --mixed_precision=$MIXED_PRECISION --main_process_port=12345 \ 152 | tools/hv_precompute_latents_dist.py \ 153 | --pretrained_model_name_or_path="hunyuanvideo-community/HunyuanVideo" \ 154 | --mds_data_path "dataset/mds_cache" \ 155 | --output_dir "dataset/mds_cache_latents" \ 156 | --recursive 157 | ``` 158 | 159 | Now you need to add the path to all the mds latent folders in `./configs/config_defaults.yaml` config file under `data.local` as a list. The latent_cache should be stored unfer `--output_dir` folder as `1280x720x33_00` folders. Where `1280` is the width of the video, `720` is the height of the video and `33` is the framerate of the video and `00` is the gpu id. 160 | Now we are ready to start training! 161 | 162 | ### Starting a traning run 163 | 164 | ```shell 165 | NUM_GPUS=8 166 | MIXED_PRECISION="bf16" 167 | EXPERIMENT_NAME="my_first_run" 168 | OUTPUT_DIR="outputs/" 169 | CONFIG_PATH="./configs/config_defaults.yaml" 170 | NUM_EPOCHS=1 171 | 172 | accelerate launch --num_processes=$NUM_GPUS --mixed_precision=$MIXED_PRECISION --main_process_port=12345 \ 173 | hv_train_control_lora.py \ 174 | --config_path $CONFIG_PATH \ 175 | --experiment.run_id=$EXPERIMENT_NAME \ 176 | --experiment.output_dirpath=$OUTPUT_DIR \ 177 | --network.train_norm_layers=False \ 178 | --network.lora_dropout=0.05 \ 179 | --hparams.ema.use_ema=False \ 180 | --hparams.num_train_epochs=1 181 | ``` 182 | 183 | ## Acknowledgements 184 | 185 | - We would like to thank the contributors to the [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [FLUX](https://github.com/black-forest-labs/flux), [Llama](https://github.com/meta-llama/llama), [LLaVA](https://github.com/haotian-liu/LLaVA), [Xtuner](https://github.com/InternLM/xtuner), [diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) repositories, for their open research and exploration. 186 | 187 | - We build on top of a body of great open-source libraries: transformers, accelerate, peft, diffusers, bitsandbytes, torchao, deepspeed, mosaicml-streaming -- to name a few. 188 | 189 | ## Star History 190 | 191 | [![Star History Chart](https://api.star-history.com/svg?repos=dashtoon/hunyuan-video-keyframe-control-lora&type=Timeline)](https://www.star-history.com/#dashtoon/hunyuan-video-keyframe-control-lora&Timeline) 192 | -------------------------------------------------------------------------------- /attn_processor.py: -------------------------------------------------------------------------------- 1 | import flash_attn 2 | import torch 3 | from accelerate.logging import get_logger 4 | from diffusers.models.embeddings import apply_rotary_emb 5 | from flash_attn.flash_attn_interface import _flash_attn_forward, flash_attn_varlen_func 6 | 7 | logger = get_logger(__name__) 8 | 9 | 10 | def get_cu_seqlens(attention_mask): 11 | """Calculate cu_seqlens_q, cu_seqlens_kv 12 | 13 | Args: 14 | attention_mask (torch.Tensor): boolean attention mask of shape: [B, 1, 1, N] 15 | 16 | Returns: 17 | torch.Tensor: the calculated cu_seqlens for flash attention 18 | """ 19 | batch_size = attention_mask.shape[0] 20 | text_len = attention_mask.sum(dim=-1, dtype=torch.int) 21 | max_len = attention_mask.shape[-1] 22 | 23 | cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda") 24 | 25 | for i in range(batch_size): 26 | s = text_len[i] 27 | s1 = i * max_len + s 28 | s2 = (i + 1) * max_len 29 | cu_seqlens[2 * i + 1] = s1 30 | cu_seqlens[2 * i + 2] = s2 31 | 32 | return cu_seqlens 33 | 34 | 35 | class HunyuanVideoFlashAttnProcessor: 36 | def __init__(self): 37 | pass 38 | 39 | def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None): 40 | if attn.add_q_proj is None and encoder_hidden_states is not None: 41 | hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) 42 | 43 | query = attn.to_q(hidden_states) 44 | key = attn.to_k(hidden_states) 45 | value = attn.to_v(hidden_states) 46 | 47 | query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) 48 | key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) 49 | value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) 50 | 51 | if attn.norm_q is not None: 52 | query = attn.norm_q(query) 53 | if attn.norm_k is not None: 54 | key = attn.norm_k(key) 55 | 56 | if image_rotary_emb is not None: 57 | if attn.add_q_proj is None and encoder_hidden_states is not None: 58 | query = torch.cat( 59 | [ 60 | apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), 61 | query[:, :, -encoder_hidden_states.shape[1] :], 62 | ], 63 | dim=2, 64 | ) 65 | key = torch.cat( 66 | [ 67 | apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), 68 | key[:, :, -encoder_hidden_states.shape[1] :], 69 | ], 70 | dim=2, 71 | ) 72 | else: 73 | query = apply_rotary_emb(query, image_rotary_emb) 74 | key = apply_rotary_emb(key, image_rotary_emb) 75 | 76 | batch_size = hidden_states.shape[0] 77 | img_seq_len = hidden_states.shape[1] 78 | txt_seq_len = 0 79 | 80 | if attn.add_q_proj is not None and encoder_hidden_states is not None: 81 | encoder_query = attn.add_q_proj(encoder_hidden_states) 82 | encoder_key = attn.add_k_proj(encoder_hidden_states) 83 | encoder_value = attn.add_v_proj(encoder_hidden_states) 84 | 85 | encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2) 86 | encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2) 87 | encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) 88 | 89 | if attn.norm_added_q is not None: 90 | encoder_query = attn.norm_added_q(encoder_query) 91 | if attn.norm_added_k is not None: 92 | encoder_key = attn.norm_added_k(encoder_key) 93 | 94 | query = torch.cat([query, encoder_query], dim=2) 95 | key = torch.cat([key, encoder_key], dim=2) 96 | value = torch.cat([value, encoder_value], dim=2) 97 | 98 | txt_seq_len = encoder_hidden_states.shape[1] 99 | 100 | max_seqlen_q = max_seqlen_kv = img_seq_len + txt_seq_len 101 | cu_seqlens_q = cu_seqlens_kv = get_cu_seqlens(attention_mask) 102 | 103 | query = query.transpose(1, 2).reshape(-1, query.shape[1], query.shape[3]) 104 | key = key.transpose(1, 2).reshape(-1, key.shape[1], key.shape[3]) 105 | value = value.transpose(1, 2).reshape(-1, value.shape[1], value.shape[3]) 106 | hidden_states = flash_attn_varlen_func( 107 | query, 108 | key, 109 | value, 110 | cu_seqlens_q, 111 | cu_seqlens_kv, 112 | max_seqlen_q, 113 | max_seqlen_kv, 114 | softmax_scale=None, 115 | dropout_p=0.0, 116 | causal=False, 117 | ) 118 | 119 | hidden_states = hidden_states.reshape(batch_size, max_seqlen_q, -1) 120 | hidden_states = hidden_states.to(query.dtype) 121 | 122 | if encoder_hidden_states is not None: 123 | hidden_states, encoder_hidden_states = ( 124 | hidden_states[:, : -encoder_hidden_states.shape[1]], 125 | hidden_states[:, -encoder_hidden_states.shape[1] :], 126 | ) 127 | 128 | if getattr(attn, "to_out", None) is not None: 129 | hidden_states = attn.to_out[0](hidden_states) 130 | hidden_states = attn.to_out[1](hidden_states) 131 | 132 | if getattr(attn, "to_add_out", None) is not None: 133 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states) 134 | 135 | return hidden_states, encoder_hidden_states 136 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import typing 3 | import uuid 4 | from dataclasses import asdict, dataclass 5 | 6 | from peft import LoraConfig 7 | from pyrallis import field 8 | 9 | 10 | @dataclass 11 | class BaseConfig: 12 | def get(self, attribute_name, default=None): 13 | return getattr(self, attribute_name, default) 14 | 15 | def pop(self, attribute_name, default=None): 16 | if hasattr(self, attribute_name): 17 | value = getattr(self, attribute_name) 18 | delattr(self, attribute_name) 19 | return value 20 | else: 21 | return default 22 | 23 | def __str__(self): 24 | return json.dumps(asdict(self), indent=4) 25 | 26 | 27 | @dataclass 28 | class DataLoaderKwargs(BaseConfig): 29 | """Configuration for data loading parameters""" 30 | 31 | drop_last: bool = field(default=True) # Whether to drop the last incomplete batch 32 | num_workers: int = field(default=8) # Number of worker processes for data loading 33 | persistent_workers: bool = field(default=True) # Keep worker processes alive between epochs 34 | pin_memory: bool = field(default=True) # Pin memory for faster data transfer to GPU 35 | prefetch_factor: int = field(default=2) # Number of batches to prefetch per worker 36 | 37 | 38 | @dataclass 39 | class StreamingKwargs(BaseConfig): 40 | """Configuration for data streaming parameters""" 41 | 42 | cache_limit: str = field(default="5tb") # Maximum cache size limit 43 | download_timeout: int = field(default=12000) # Timeout in seconds for downloads 44 | num_canonical_nodes: typing.Optional[int] = field(default=None) # Number of canonical nodes to use 45 | shuffle: bool = field(default=True) # Whether to shuffle the data 46 | batching_method: str = field(default="per_stream") # Method used for batching data 47 | 48 | 49 | @dataclass 50 | class DataConfig(BaseConfig): 51 | """Configuration for data sources and processing""" 52 | 53 | remote: typing.Optional[typing.List[typing.Optional[str]]] = field(default=None) # Remote data source paths 54 | local: typing.Optional[typing.List[typing.Optional[str]]] = field(default=None) # Local data source paths 55 | batch_size: int = field(default=1) # Training batch size 56 | video_key: str = field(default="video") # Key for video data in dataset 57 | caption_key: str = field(default="caption") # Key for caption data in dataset 58 | latents_key: str = field(default="latents") # Key for latents in dataset 59 | prompt_embeds_key: str = field(default="prompt_embeds") # Key for prompt embeddings 60 | latents_cond_key: str = field(default="latents_cond") # Key for conditional latents 61 | prompt_attention_mask_key: str = field(default="prompt_attention_mask") # Key for prompt attention mask 62 | pooled_prompt_embeds_key: str = field(default="pooled_prompt_embeds") # Key for pooled prompt embeddings 63 | repeat: typing.Optional[typing.List] = field(default=None, is_mutable=True) # Number of times to repeat dataset 64 | choose: typing.Optional[typing.List] = field(default=None, is_mutable=True) # Indices to choose from dataset 65 | streaming_kwargs: StreamingKwargs = field( 66 | default_factory=StreamingKwargs, is_mutable=True 67 | ) # Streaming configuration 68 | dataloader_kwargs: DataLoaderKwargs = field( 69 | default_factory=DataLoaderKwargs, is_mutable=True 70 | ) # DataLoader configuration 71 | 72 | 73 | @dataclass 74 | class PretrainedModelConfig(BaseConfig): 75 | """Configuration for pretrained model loading""" 76 | 77 | pretrained_model_name_or_path: str = "hunyuanvideo-community/HunyuanVideo" # Path or name of pretrained model 78 | revision: typing.Optional[str] = field(default=None) # Specific model revision to use 79 | variant: typing.Optional[str] = field(default=None) # Specific model variant to use 80 | 81 | 82 | @dataclass 83 | class NetworkConfig(BaseConfig): 84 | """Configuration for network architecture""" 85 | 86 | lora_rank: int = field(default=16) # Rank for LoRA adaptation 87 | lora_alpha: int = field(default=16) # Alpha scaling for LoRA 88 | target_modules: typing.Optional[typing.List[str]] = field(default=None, is_mutable=True) # Target modules for LoRA 89 | lora_dropout: float = field(default=0.0) # Dropout probability for LoRA layers 90 | train_norm_layers: bool = field(default=False) # Whether to train normalization layers 91 | init_lora_weights: typing.Union[bool, str] = field( 92 | default=True 93 | ) # typing.Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "loftq"] 94 | lora_layers: typing.Optional[str] = field(default=None) 95 | 96 | 97 | @dataclass 98 | class ExperimentConfig(BaseConfig): 99 | """Configuration for experiment tracking""" 100 | 101 | output_dirpath: str = field(default="./outputs") # Directory path for outputs 102 | random_seed: int = field(default=42) # Random seed for reproducibility 103 | run_id: str = field(default_factory=lambda: uuid.uuid4().hex[:8], is_mutable=True) # Unique run identifier 104 | name: str = field(default="whatever") # Name of the experiment 105 | ic_debug: bool = field(default=False) # Whether to use ic for debugging 106 | 107 | 108 | @dataclass 109 | class FlowMatchConfig(BaseConfig): 110 | """Configuration for flow matching parameters""" 111 | 112 | discrete_flow_shift: float = field(default=7.0) # Shift for discrete flow 113 | timestep_sampling: str = field(default="sigma") # "uniform", "sigmoid", "logit_normal" 114 | weighting_scheme: str = field(default="none") # "sigma_sqrt", "cosmap", "none" 115 | sigmoid_scale: float = field(default=1.0) # Scale for sigmoid function 116 | logit_mean: float = field(default=0.0) # Mean for logit distribution 117 | logit_std: float = field(default=1.0) # Standard deviation for logit distribution 118 | 119 | 120 | @dataclass 121 | class EmaConfig(BaseConfig): 122 | """Ema configuration""" 123 | 124 | use_ema: bool = field(default=False) 125 | ema_decay: float = field(default=0.99) 126 | ema_foreach_disable: bool = field(default=False) 127 | ema_device: str = field(default="accelerator") # | typing.Literal["accelerator", "cpu"] 128 | ema_cpu_only: bool = field(default=False) 129 | ema_update_interval: typing.Optional[int] = field(default=None) 130 | 131 | 132 | @dataclass 133 | class TrainingHyperParametersConfig(BaseConfig): 134 | """Configuration for training hyperparameters""" 135 | 136 | mixed_precision: str = field(default="bf16") # Mixed precision training type 137 | gradient_checkpointing: bool = field(default=True) # Whether to use gradient checkpointing 138 | gradient_accumulation_steps: int = field(default=1) # Number of gradient accumulation steps 139 | learning_rate: float = field(default=1e-04) # Learning rate for training 140 | optimizer_type: str = field(default="torch.optim.AdamW") # Type of optimizer to use 141 | optimizer_args: typing.List[str] = field(default=[], is_mutable=True) # Additional optimizer arguments 142 | max_grad_norm: int = field(default=1.0) # Maximum gradient norm for clipping 143 | grad_clip_method: str = field(default="norm") 144 | lr_scheduler: str = field(default="constant") # Learning rate scheduler type 145 | lr_warmup_steps: int = field(default=0) # Number of warmup steps 146 | lr_scheduler_num_cycles: int = field(default=1) # Number of scheduler cycles 147 | lr_scheduler_power: float = field(default=0.9) # Power for scheduler 148 | guidance_scale: int = field(default=1.0) # Scale for guidance 149 | flow_match: FlowMatchConfig = field(default_factory=FlowMatchConfig, is_mutable=True) # Flow matching configuration 150 | num_train_epochs: typing.Optional[int] = field(default=None) # Number of training epochs 151 | max_train_steps: typing.Optional[int] = field(default=None) # Maximum number of training steps 152 | caption_dropout_p: float = field(default=0.0) # Dropout probability for captions 153 | ema: EmaConfig = field(default_factory=EmaConfig, is_mutable=True) # EMA configuration 154 | gradient_precision: str = field( 155 | default="accelerator" 156 | ) # gradient precision from LLAMA paper | typing.Literal["accelerator", "fp32"] 157 | 158 | 159 | @dataclass 160 | class CheckpointConfig(BaseConfig): 161 | """Configuration for model checkpointing""" 162 | 163 | save_every_n_steps: typing.Optional[int] = field(default=None) # Save checkpoint every N steps 164 | save_last_n_steps: typing.Optional[int] = field(default=None) # Keep last N checkpoints 165 | resume_from_checkpoint: typing.Optional[str] = field(default=None) # Path to checkpoint to resume from 166 | 167 | 168 | @dataclass 169 | class TorchDDPKwargs(BaseConfig): 170 | """Configuration for torch distributed parameters""" 171 | 172 | backend: str = field(default="nccl") 173 | find_unused_parameters: bool = field(default=False) 174 | gradient_as_bucket_view: bool = field(default=False) 175 | static_graph: bool = field(default=False) 176 | 177 | 178 | @dataclass 179 | class Config(BaseConfig): 180 | """Main configuration class combining all sub-configurations""" 181 | 182 | experiment: ExperimentConfig = field(default_factory=ExperimentConfig, is_mutable=True) # Experiment configuration 183 | data: DataConfig = field(default_factory=DataConfig, is_mutable=True) # Data configuration 184 | model: PretrainedModelConfig = field(default_factory=PretrainedModelConfig, is_mutable=True) # Model configuration 185 | network: NetworkConfig = field(default_factory=NetworkConfig, is_mutable=True) # Network configuration 186 | hparams: TrainingHyperParametersConfig = field( 187 | default_factory=TrainingHyperParametersConfig, is_mutable=True 188 | ) # Training hyperparameters 189 | checkpointing: CheckpointConfig = field( 190 | default_factory=CheckpointConfig, is_mutable=True 191 | ) # Checkpointing configuration 192 | ddp_kwargs: TorchDDPKwargs = field(default_factory=TorchDDPKwargs, is_mutable=True) 193 | 194 | 195 | if __name__ == "__main__": 196 | import pyrallis 197 | 198 | cfg = pyrallis.parse(config_class=Config) 199 | print(f"Training {cfg}") 200 | -------------------------------------------------------------------------------- /configs/accelerate/uncompiled_8.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: "no" 5 | enable_cpu_affinity: false 6 | gpu_ids: all 7 | machine_rank: 0 8 | main_training_function: main 9 | mixed_precision: bf16 10 | num_machines: 1 11 | num_processes: 8 12 | rdzv_backend: static 13 | same_network: true 14 | tpu_env: [] 15 | tpu_use_cluster: false 16 | tpu_use_sudo: false 17 | use_cpu: false 18 | -------------------------------------------------------------------------------- /configs/accelerate/uncompiled_8_ds.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_config_file: ./configs/deepspeed/zero_stage2_config.json 5 | zero3_init_flag: false 6 | distributed_type: DEEPSPEED 7 | downcast_bf16: "no" 8 | enable_cpu_affinity: false 9 | machine_rank: 0 10 | main_training_function: main 11 | num_machines: 1 12 | num_processes: 8 13 | rdzv_backend: static 14 | same_network: true 15 | tpu_env: [] 16 | tpu_use_cluster: false 17 | tpu_use_sudo: false 18 | use_cpu: false 19 | -------------------------------------------------------------------------------- /configs/config_defaults.yaml: -------------------------------------------------------------------------------- 1 | experiment: 2 | output_dirpath: "./i2v_outputs" 3 | random_seed: 4201 4 | name: "i2v" 5 | run_id: "my_first_experiment" 6 | 7 | model: 8 | pretrained_model_name_or_path: "hunyuanvideo-community/HunyuanVideo" 9 | revision: null 10 | variant: null 11 | 12 | data: 13 | local: 14 | batch_size: 1 15 | video_key: "video" 16 | caption_key: "caption" 17 | latents_key: "latents" 18 | prompt_embeds_key: "prompt_embeds" 19 | latents_cond_key: "latents_cond" 20 | prompt_attention_mask_key: "prompt_attention_mask" 21 | pooled_prompt_embeds_key: "pooled_prompt_embeds" 22 | 23 | dataloader_kwargs: 24 | drop_last: false 25 | num_workers: 8 26 | persistent_workers: true 27 | pin_memory: true 28 | prefetch_factor: 2 29 | 30 | streaming_kwargs: 31 | shuffle: true 32 | batching_method: "per_stream" 33 | num_canonical_nodes: 8 34 | 35 | network: 36 | lora_rank: 128 37 | lora_alpha: 128 38 | lora_dropout: 0.00 39 | init_lora_weights: true 40 | train_norm_layers: false 41 | lora_layers: "all-linear" 42 | 43 | hparams: 44 | caption_dropout_p: 0.00 45 | mixed_precision: "bf16" 46 | gradient_checkpointing: true 47 | gradient_accumulation_steps: 1 48 | learning_rate: 2e-4 49 | optimizer_type: "optimi-stableadamw" 50 | optimizer_args: ["weight_decay=1e-2", "eps=1e-8", "betas=(0.9, 0.95)"] 51 | max_grad_norm: 0.0 52 | lr_scheduler: "constant_with_warmup" 53 | lr_warmup_steps: 100 54 | guidance_scale: 1.0 55 | num_train_epochs: 1 56 | max_train_steps: null 57 | gradient_precision: "accelerator" 58 | 59 | flow_match: 60 | weighting_scheme: "none" 61 | timestep_sampling: "logit_normal" 62 | sigmoid_scale: 1.0 63 | logit_mean: 0.0 64 | logit_std: 1.0 65 | discrete_flow_shift: 7.0 66 | 67 | ema: 68 | use_ema: true 69 | ema_device: "accelerator" 70 | ema_decay: 0.99 71 | 72 | checkpointing: 73 | save_every_n_steps: 500 74 | save_last_n_steps: 1500 75 | resume_from_checkpoint: null 76 | -------------------------------------------------------------------------------- /configs/deepspeed/zero_stage2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 2, 4 | "contiguous_gradients": true, 5 | "overlap_comm": true, 6 | "reduce_scatter": true, 7 | "reduce_bucket_size": 5e8, 8 | "allgather_bucket_size": 5e8 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "gradient_accumulation_steps": 1.0, 14 | "train_micro_batch_size_per_gpu": 1.0, 15 | "gradient_clipping": 1.0, 16 | "wall_clock_breakdown": false, 17 | "zero_allow_untested_optimizer": true 18 | } 19 | -------------------------------------------------------------------------------- /ema.py: -------------------------------------------------------------------------------- 1 | # copy-paste from: https://github.com/bghira/SimpleTuner/blob/main/helpers/training/ema.py 2 | import contextlib 3 | import copy 4 | import logging 5 | import os 6 | from typing import Any, Dict, Iterable, Optional, Union 7 | 8 | import torch 9 | import transformers 10 | from accelerate.logging import get_logger 11 | from diffusers.utils import is_transformers_available 12 | from diffusers.utils.deprecation_utils import deprecate 13 | 14 | logger = get_logger(__name__) 15 | 16 | 17 | def should_update_ema(args, step): 18 | if args.ema_update_interval is None: 19 | # If the EMA update interval is not set, always update the EMA. 20 | return True 21 | else: 22 | should_update = step % args.ema_update_interval == 0 23 | if should_update: 24 | logger.debug("Updating EMA weights...") 25 | return should_update 26 | 27 | 28 | class EMAModel: 29 | """ 30 | Exponential Moving Average of models weights 31 | """ 32 | 33 | def __init__( 34 | self, 35 | args, 36 | accelerator, 37 | parameters: Iterable[torch.nn.Parameter], 38 | decay: float = 0.9999, 39 | min_decay: float = 0.0, 40 | update_after_step: int = 0, 41 | use_ema_warmup: bool = False, 42 | inv_gamma: Union[float, int] = 1.0, 43 | power: Union[float, int] = 2 / 3, 44 | foreach: bool = True, 45 | model_cls: Optional[Any] = None, 46 | model_config: Dict[str, Any] = None, 47 | **kwargs, 48 | ): 49 | """ 50 | Args: 51 | parameters (Iterable[torch.nn.Parameter]): The parameters to track. 52 | decay (float): The decay factor for the exponential moving average. 53 | min_decay (float): The minimum decay factor for the exponential moving average. 54 | update_after_step (int): The number of steps to wait before starting to update the EMA weights. 55 | use_ema_warmup (bool): Whether to use EMA warmup. 56 | inv_gamma (float): 57 | Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. 58 | power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. 59 | foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster. 60 | device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA 61 | weights will be stored on CPU. 62 | 63 | @crowsonkb's notes on EMA Warmup: 64 | If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan 65 | to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), 66 | gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 67 | at 215.4k steps). 68 | """ 69 | 70 | if isinstance(parameters, torch.nn.Module): 71 | deprecation_message = ( 72 | "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. " 73 | "Please pass the parameters of the module instead." 74 | ) 75 | deprecate( 76 | "passing a `torch.nn.Module` to `ExponentialMovingAverage`", 77 | "1.0.0", 78 | deprecation_message, 79 | standard_warn=False, 80 | ) 81 | parameters = parameters.parameters() 82 | 83 | # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility 84 | use_ema_warmup = True 85 | 86 | if kwargs.get("max_value", None) is not None: 87 | deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead." 88 | deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False) 89 | decay = kwargs["max_value"] 90 | 91 | if kwargs.get("min_value", None) is not None: 92 | deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead." 93 | deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False) 94 | min_decay = kwargs["min_value"] 95 | 96 | parameters = list(parameters) 97 | self.shadow_params = [p.clone().detach() for p in parameters] 98 | 99 | if kwargs.get("device", None) is not None: 100 | deprecation_message = "The `device` argument is deprecated. Please use `to` instead." 101 | deprecate("device", "1.0.0", deprecation_message, standard_warn=False) 102 | self.to(device=kwargs["device"]) 103 | 104 | self.temp_stored_params = None 105 | 106 | self.decay = decay 107 | self.min_decay = min_decay 108 | self.update_after_step = update_after_step 109 | self.use_ema_warmup = use_ema_warmup 110 | self.inv_gamma = inv_gamma 111 | self.power = power 112 | self.optimization_step = 0 113 | self.cur_decay_value = None # set in `step()` 114 | self.foreach = foreach 115 | 116 | self.model_cls = model_cls 117 | self.model_config = model_config 118 | self.args = args 119 | self.accelerator = accelerator 120 | self.training = True # To emulate nn.Module's training mode 121 | 122 | def save_state_dict(self, path: str) -> None: 123 | """ 124 | Save the EMA model's state directly to a file. 125 | 126 | Args: 127 | path (str): The file path where the EMA state will be saved. 128 | """ 129 | # if the folder containing the path does not exist, create it 130 | os.makedirs(os.path.dirname(path), exist_ok=True) 131 | # grab state dict 132 | state_dict = self.state_dict() 133 | # save it using torch.save 134 | torch.save(state_dict, path) 135 | logger.info(f"EMA model state saved to {path}") 136 | 137 | def load_state_dict(self, path: str) -> None: 138 | """ 139 | Load the EMA model's state from a file and apply it to this instance. 140 | 141 | Args: 142 | path (str): The file path from where the EMA state will be loaded. 143 | """ 144 | state_dict = torch.load(path, map_location="cpu", weights_only=True) 145 | 146 | # Load metadata 147 | self.decay = state_dict.get("decay", self.decay) 148 | self.min_decay = state_dict.get("min_decay", self.min_decay) 149 | self.optimization_step = state_dict.get("optimization_step", self.optimization_step) 150 | self.update_after_step = state_dict.get("update_after_step", self.update_after_step) 151 | self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) 152 | self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) 153 | self.power = state_dict.get("power", self.power) 154 | 155 | # Load shadow parameters 156 | shadow_params = [] 157 | idx = 0 158 | while f"shadow_params.{idx}" in state_dict: 159 | shadow_params.append(state_dict[f"shadow_params.{idx}"]) 160 | idx += 1 161 | 162 | if len(shadow_params) != len(self.shadow_params): 163 | raise ValueError( 164 | f"Mismatch in number of shadow parameters: expected {len(self.shadow_params)}, " 165 | f"but found {len(shadow_params)} in the state dict." 166 | ) 167 | 168 | for current_param, loaded_param in zip(self.shadow_params, shadow_params): 169 | current_param.data.copy_(loaded_param.data) 170 | 171 | logger.info(f"EMA model state loaded from {path}") 172 | 173 | @classmethod 174 | def from_pretrained(cls, path, model_cls) -> "EMAModel": 175 | _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True) 176 | model = model_cls.from_pretrained(path) 177 | 178 | ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config) 179 | 180 | ema_model.load_state_dict(ema_kwargs) 181 | return ema_model 182 | 183 | def save_pretrained(self, path, max_shard_size: str = "10GB"): 184 | if self.model_cls is None: 185 | raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.") 186 | 187 | if self.model_config is None: 188 | raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.") 189 | 190 | model = self.model_cls.from_config(self.model_config) 191 | state_dict = self.state_dict(exclude_params=True) 192 | state_dict.pop("shadow_params", None) 193 | 194 | model.register_to_config(**state_dict) 195 | self.copy_to(model.parameters()) 196 | model.save_pretrained(path, max_shard_size=max_shard_size) 197 | 198 | def get_decay(self, optimization_step: int = None) -> float: 199 | """ 200 | Compute the decay factor for the exponential moving average. 201 | """ 202 | if optimization_step is None: 203 | optimization_step = self.optimization_step 204 | 205 | step = max(0, optimization_step - self.update_after_step - 1) 206 | 207 | if step <= 0: 208 | return 0.0 209 | 210 | if self.use_ema_warmup: 211 | cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power 212 | else: 213 | cur_decay_value = (1 + step) / (10 + step) 214 | 215 | cur_decay_value = min(cur_decay_value, self.decay) 216 | # make sure decay is not smaller than min_decay 217 | cur_decay_value = max(cur_decay_value, self.min_decay) 218 | return cur_decay_value 219 | 220 | @torch.no_grad() 221 | def step(self, parameters: Iterable[torch.nn.Parameter], global_step: int = None): 222 | if not should_update_ema(self.args, global_step): 223 | return 224 | 225 | if self.args.ema_device == "cpu" and not self.args.ema_cpu_only: 226 | # Move EMA to accelerator for faster update. 227 | self.to(device=self.accelerator.device, non_blocking=True) 228 | if isinstance(parameters, torch.nn.Module): 229 | deprecation_message = ( 230 | "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. " 231 | "Please pass the parameters of the module instead." 232 | ) 233 | deprecate( 234 | "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`", 235 | "1.0.0", 236 | deprecation_message, 237 | standard_warn=False, 238 | ) 239 | parameters = parameters.parameters() 240 | 241 | parameters = list(parameters) 242 | 243 | if global_step is not None: 244 | # When we're updating the EMA periodically, we can't trust the counter. 245 | self.optimization_step = global_step 246 | else: 247 | self.optimization_step += 1 248 | 249 | # Compute the decay factor for the exponential moving average. 250 | decay = self.get_decay(self.optimization_step) 251 | self.cur_decay_value = decay 252 | one_minus_decay = 1 - decay 253 | 254 | context_manager = contextlib.nullcontext 255 | if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled(): 256 | import deepspeed 257 | 258 | if self.foreach: 259 | if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled(): 260 | context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None) 261 | 262 | with context_manager(): 263 | params_grad = [param for param in parameters if param.requires_grad] 264 | s_params_grad = [ 265 | s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad 266 | ] 267 | 268 | if len(params_grad) < len(parameters): 269 | torch._foreach_copy_( 270 | [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad], 271 | [param for param in parameters if not param.requires_grad], 272 | non_blocking=True, 273 | ) 274 | 275 | torch._foreach_sub_( 276 | s_params_grad, 277 | torch._foreach_sub(s_params_grad, params_grad), 278 | alpha=one_minus_decay, 279 | ) 280 | 281 | else: 282 | for s_param, param in zip(self.shadow_params, parameters): 283 | if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled(): 284 | context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None) 285 | 286 | with context_manager(): 287 | if param.requires_grad: 288 | s_param.sub_(one_minus_decay * (s_param - param.to(s_param.device))) 289 | else: 290 | s_param.copy_(param) 291 | if self.args.ema_device == "cpu" and not self.args.ema_cpu_only: 292 | # Move back to CPU for safe-keeping. 293 | self.to(device=self.args.ema_device, non_blocking=True) 294 | 295 | def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: 296 | """ 297 | Copy current averaged parameters into given collection of parameters. 298 | 299 | Args: 300 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 301 | updated with the stored moving averages. If `None`, the parameters with which this 302 | `ExponentialMovingAverage` was initialized will be used. 303 | """ 304 | parameters = list(parameters) 305 | if self.foreach: 306 | torch._foreach_copy_( 307 | [param.data for param in parameters], 308 | [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)], 309 | ) 310 | else: 311 | for s_param, param in zip(self.shadow_params, parameters): 312 | param.data.copy_(s_param.to(param.device).data) 313 | 314 | def pin_memory(self) -> None: 315 | r""" 316 | Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for 317 | offloading EMA params to the host. 318 | """ 319 | if torch.backends.mps.is_available(): 320 | logger.warning("Apple silicon does not support pinned memory. Skipping.") 321 | return 322 | 323 | if self.args.ema_cpu_only: 324 | return 325 | 326 | # This probably won't work, but we'll do it anyway. 327 | self.shadow_params = [p.pin_memory() for p in self.shadow_params] 328 | 329 | def to(self, *args, **kwargs): 330 | for param in self.shadow_params: 331 | param.data = param.data.to(*args, **kwargs) 332 | return self 333 | 334 | def cuda(self, device=None): 335 | return self.to(device="cuda" if device is None else f"cuda:{device}") 336 | 337 | def cpu(self): 338 | return self.to(device="cpu") 339 | 340 | def state_dict(self, destination=None, prefix="", keep_vars=False, exclude_params: bool = False): 341 | r""" 342 | Returns a dictionary containing a whole state of the EMA model. 343 | """ 344 | state_dict = { 345 | "decay": self.decay, 346 | "min_decay": self.min_decay, 347 | "optimization_step": self.optimization_step, 348 | "update_after_step": self.update_after_step, 349 | "use_ema_warmup": self.use_ema_warmup, 350 | "inv_gamma": self.inv_gamma, 351 | "power": self.power, 352 | } 353 | if exclude_params: 354 | return state_dict 355 | for idx, param in enumerate(self.shadow_params): 356 | state_dict[f"{prefix}shadow_params.{idx}"] = param if keep_vars else param.detach() 357 | return state_dict 358 | 359 | def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: 360 | r""" 361 | Save the current parameters for restoring later. 362 | """ 363 | self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] 364 | 365 | def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: 366 | r""" 367 | Restore the parameters stored with the `store` method. 368 | """ 369 | if self.temp_stored_params is None: 370 | raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`") 371 | if self.foreach: 372 | torch._foreach_copy_( 373 | [param.data for param in parameters], 374 | [c_param.data for c_param in self.temp_stored_params], 375 | ) 376 | else: 377 | for c_param, param in zip(self.temp_stored_params, parameters): 378 | param.data.copy_(c_param.data) 379 | 380 | # Better memory-wise. 381 | self.temp_stored_params = None 382 | 383 | def parameter_count(self) -> int: 384 | return sum(p.numel() for p in self.shadow_params) 385 | 386 | # Implementing nn.Module methods to emulate its behavior 387 | 388 | def named_children(self): 389 | # No child modules 390 | return iter([]) 391 | 392 | def children(self): 393 | return iter([]) 394 | 395 | def modules(self): 396 | yield self 397 | 398 | def named_modules(self, memo=None, prefix=""): 399 | yield prefix, self 400 | 401 | def parameters(self, recurse=True): 402 | return iter(self.shadow_params) 403 | 404 | def named_parameters(self, prefix="", recurse=True): 405 | for i, param in enumerate(self.shadow_params): 406 | name = f"{prefix}shadow_params.{i}" 407 | yield name, param 408 | 409 | def buffers(self, recurse=True): 410 | return iter([]) 411 | 412 | def named_buffers(self, prefix="", recurse=True): 413 | return iter([]) 414 | 415 | def train(self, mode=True): 416 | self.training = mode 417 | return self 418 | 419 | def eval(self): 420 | return self.train(False) 421 | 422 | def zero_grad(self): 423 | # No gradients to zero in EMA model 424 | pass 425 | -------------------------------------------------------------------------------- /hv_control_lora_inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import os 4 | import time 5 | import uuid 6 | from datetime import datetime 7 | from typing import Any, Callable, Dict, List, Optional, Union 8 | 9 | import av 10 | import cv2 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | import torchvision.transforms.v2 as transforms 15 | from diffusers import FlowMatchEulerDiscreteScheduler, HunyuanVideoPipeline 16 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 17 | from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel 18 | from diffusers.models.attention import Attention 19 | from diffusers.models.embeddings import apply_rotary_emb 20 | from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoPatchEmbed 21 | from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE, retrieve_timesteps 22 | from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput 23 | from PIL import Image 24 | 25 | # Try to import flash attention 26 | try: 27 | import flash_attn 28 | from flash_attn.flash_attn_interface import _flash_attn_forward, flash_attn_varlen_func 29 | except ImportError: 30 | flash_attn, _flash_attn_forward, flash_attn_varlen_func = None, None, None 31 | 32 | try: 33 | from sageattention import sageattn, sageattn_varlen 34 | except ImportError: 35 | sageattn, sageattn_varlen = None, None 36 | 37 | 38 | def get_cu_seqlens(attention_mask): 39 | """Calculate cu_seqlens_q, cu_seqlens_kv using attention_mask""" 40 | batch_size = attention_mask.shape[0] 41 | text_len = attention_mask.sum(dim=-1, dtype=torch.int) 42 | max_len = attention_mask.shape[-1] 43 | 44 | cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda") 45 | 46 | for i in range(batch_size): 47 | s = text_len[i] 48 | s1 = i * max_len + s 49 | s2 = (i + 1) * max_len 50 | cu_seqlens[2 * i + 1] = s1 51 | cu_seqlens[2 * i + 2] = s2 52 | 53 | return cu_seqlens 54 | 55 | 56 | class HunyuanVideoFlashAttnProcessor: 57 | def __init__(self, use_flash_attn=True, use_sageattn=False): 58 | self.use_flash_attn = use_flash_attn 59 | self.use_sageattn = use_sageattn 60 | if self.use_flash_attn: 61 | assert flash_attn is not None, "Flash attention not available" 62 | if self.use_sageattn: 63 | assert sageattn is not None, "Sage attention not available" 64 | 65 | def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None): 66 | if attn.add_q_proj is None and encoder_hidden_states is not None: 67 | hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) 68 | 69 | query = attn.to_q(hidden_states) 70 | key = attn.to_k(hidden_states) 71 | value = attn.to_v(hidden_states) 72 | 73 | query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) 74 | key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) 75 | value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) 76 | 77 | if attn.norm_q is not None: 78 | query = attn.norm_q(query) 79 | if attn.norm_k is not None: 80 | key = attn.norm_k(key) 81 | 82 | if image_rotary_emb is not None: 83 | if attn.add_q_proj is None and encoder_hidden_states is not None: 84 | query = torch.cat( 85 | [ 86 | apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), 87 | query[:, :, -encoder_hidden_states.shape[1] :], 88 | ], 89 | dim=2, 90 | ) 91 | key = torch.cat( 92 | [ 93 | apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), 94 | key[:, :, -encoder_hidden_states.shape[1] :], 95 | ], 96 | dim=2, 97 | ) 98 | else: 99 | query = apply_rotary_emb(query, image_rotary_emb) 100 | key = apply_rotary_emb(key, image_rotary_emb) 101 | 102 | batch_size = hidden_states.shape[0] 103 | img_seq_len = hidden_states.shape[1] 104 | txt_seq_len = 0 105 | 106 | if attn.add_q_proj is not None and encoder_hidden_states is not None: 107 | encoder_query = attn.add_q_proj(encoder_hidden_states) 108 | encoder_key = attn.add_k_proj(encoder_hidden_states) 109 | encoder_value = attn.add_v_proj(encoder_hidden_states) 110 | 111 | encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2) 112 | encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2) 113 | encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) 114 | 115 | if attn.norm_added_q is not None: 116 | encoder_query = attn.norm_added_q(encoder_query) 117 | if attn.norm_added_k is not None: 118 | encoder_key = attn.norm_added_k(encoder_key) 119 | 120 | query = torch.cat([query, encoder_query], dim=2) 121 | key = torch.cat([key, encoder_key], dim=2) 122 | value = torch.cat([value, encoder_value], dim=2) 123 | 124 | txt_seq_len = encoder_hidden_states.shape[1] 125 | 126 | max_seqlen_q = max_seqlen_kv = img_seq_len + txt_seq_len 127 | cu_seqlens_q = cu_seqlens_kv = get_cu_seqlens(attention_mask) 128 | 129 | query = query.transpose(1, 2).reshape(-1, query.shape[1], query.shape[3]) 130 | key = key.transpose(1, 2).reshape(-1, key.shape[1], key.shape[3]) 131 | value = value.transpose(1, 2).reshape(-1, value.shape[1], value.shape[3]) 132 | 133 | if self.use_flash_attn: 134 | hidden_states = flash_attn_varlen_func( 135 | query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv 136 | ) 137 | elif self.use_sageattn: 138 | hidden_states = sageattn_varlen(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) 139 | else: 140 | raise NotImplementedError("Please set use_flash_attn=True or use_sageattn=True") 141 | 142 | hidden_states = hidden_states.reshape(batch_size, max_seqlen_q, -1) 143 | hidden_states = hidden_states.to(query.dtype) 144 | 145 | if encoder_hidden_states is not None: 146 | hidden_states, encoder_hidden_states = ( 147 | hidden_states[:, : -encoder_hidden_states.shape[1]], 148 | hidden_states[:, -encoder_hidden_states.shape[1] :], 149 | ) 150 | 151 | if getattr(attn, "to_out", None) is not None: 152 | hidden_states = attn.to_out[0](hidden_states) 153 | hidden_states = attn.to_out[1](hidden_states) 154 | 155 | if getattr(attn, "to_add_out", None) is not None: 156 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states) 157 | 158 | return hidden_states, encoder_hidden_states 159 | 160 | 161 | def save_video(video, output_path="output.mp4"): 162 | """Save frames as a video file""" 163 | width, height = video[0].size 164 | 165 | container = av.open(output_path, mode="w") 166 | 167 | # Create video stream 168 | codec = "libx264" 169 | pixel_format = "yuv420p" 170 | stream = container.add_stream(codec, rate=24) 171 | stream.width = width 172 | stream.height = height 173 | stream.pix_fmt = pixel_format 174 | stream.bit_rate = 4000000 # 4Mbit/s 175 | 176 | for frame_array in video: 177 | frame = av.VideoFrame.from_image(frame_array) 178 | packets = stream.encode(frame) 179 | for packet in packets: 180 | container.mux(packet) 181 | 182 | # Flush remaining packets 183 | for packet in stream.encode(): 184 | container.mux(packet) 185 | 186 | container.close() 187 | 188 | 189 | def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray: 190 | """Resize the image to the bucket resolution.""" 191 | is_pil_image = isinstance(image, Image.Image) 192 | if is_pil_image: 193 | image_width, image_height = image.size 194 | else: 195 | image_height, image_width = image.shape[:2] 196 | 197 | if bucket_reso == (image_width, image_height): 198 | return np.array(image) if is_pil_image else image 199 | 200 | bucket_width, bucket_height = bucket_reso 201 | 202 | scale_width = bucket_width / image_width 203 | scale_height = bucket_height / image_height 204 | scale = max(scale_width, scale_height) 205 | image_width = int(image_width * scale + 0.5) 206 | image_height = int(image_height * scale + 0.5) 207 | 208 | if scale > 1: 209 | image = Image.fromarray(image) if not is_pil_image else image 210 | image = image.resize((image_width, image_height), Image.LANCZOS) 211 | image = np.array(image) 212 | else: 213 | image = np.array(image) if is_pil_image else image 214 | image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA) 215 | 216 | # Crop the image to the bucket resolution 217 | crop_left = (image_width - bucket_width) // 2 218 | crop_top = (image_height - bucket_height) // 2 219 | image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width] 220 | 221 | return image 222 | 223 | 224 | @torch.inference_mode() 225 | def call_pipe( 226 | pipe, 227 | prompt: Union[str, List[str]] = None, 228 | prompt_2: Union[str, List[str]] = None, 229 | height: int = 720, 230 | width: int = 1280, 231 | num_frames: int = 129, 232 | num_inference_steps: int = 50, 233 | sigmas: List[float] = None, 234 | guidance_scale: float = 6.0, 235 | num_videos_per_prompt: Optional[int] = 1, 236 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 237 | latents: Optional[torch.Tensor] = None, 238 | prompt_embeds: Optional[torch.Tensor] = None, 239 | pooled_prompt_embeds: Optional[torch.Tensor] = None, 240 | prompt_attention_mask: Optional[torch.Tensor] = None, 241 | output_type: Optional[str] = "pil", 242 | return_dict: bool = True, 243 | attention_kwargs: Optional[Dict[str, Any]] = None, 244 | callback_on_step_end: Optional[ 245 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] 246 | ] = None, 247 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 248 | prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, 249 | max_sequence_length: int = 256, 250 | image_latents: Optional[torch.Tensor] = None, 251 | ): 252 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): 253 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs 254 | 255 | # 1. Check inputs 256 | pipe.check_inputs( 257 | prompt, 258 | prompt_2, 259 | height, 260 | width, 261 | prompt_embeds, 262 | callback_on_step_end_tensor_inputs, 263 | prompt_template, 264 | ) 265 | 266 | pipe._guidance_scale = guidance_scale 267 | pipe._attention_kwargs = attention_kwargs 268 | pipe._current_timestep = None 269 | pipe._interrupt = False 270 | 271 | device = pipe._execution_device 272 | 273 | # 2. Define call parameters 274 | if prompt is not None and isinstance(prompt, str): 275 | batch_size = 1 276 | elif prompt is not None and isinstance(prompt, list): 277 | batch_size = len(prompt) 278 | else: 279 | batch_size = prompt_embeds.shape[0] 280 | 281 | # 3. Encode input prompt 282 | prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt( 283 | prompt=prompt, 284 | prompt_2=prompt_2, 285 | prompt_template=prompt_template, 286 | num_videos_per_prompt=num_videos_per_prompt, 287 | prompt_embeds=prompt_embeds, 288 | pooled_prompt_embeds=pooled_prompt_embeds, 289 | prompt_attention_mask=prompt_attention_mask, 290 | device=device, 291 | max_sequence_length=max_sequence_length, 292 | ) 293 | 294 | transformer_dtype = pipe.transformer.dtype 295 | prompt_embeds = prompt_embeds.to(transformer_dtype) 296 | prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) 297 | if pooled_prompt_embeds is not None: 298 | pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) 299 | 300 | # 4. Prepare timesteps 301 | sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas 302 | timesteps, num_inference_steps = retrieve_timesteps( 303 | pipe.scheduler, 304 | num_inference_steps, 305 | device, 306 | sigmas=sigmas, 307 | ) 308 | 309 | # 5. Prepare latent variables 310 | num_channels_latents = pipe.transformer.config.in_channels 311 | num_latent_frames = (num_frames - 1) // pipe.vae_scale_factor_temporal + 1 312 | latents = pipe.prepare_latents( 313 | batch_size * num_videos_per_prompt, 314 | num_channels_latents, 315 | height, 316 | width, 317 | num_latent_frames, 318 | torch.float32, 319 | device, 320 | generator, 321 | latents, 322 | ) 323 | 324 | # 6. Prepare guidance condition 325 | guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 326 | 327 | # 7. Denoising loop 328 | num_warmup_steps = len(timesteps) - num_inference_steps * pipe.scheduler.order 329 | pipe._num_timesteps = len(timesteps) 330 | 331 | with pipe.progress_bar(total=num_inference_steps) as progress_bar: 332 | for i, t in enumerate(timesteps): 333 | if pipe.interrupt: 334 | continue 335 | 336 | pipe._current_timestep = t 337 | latent_model_input = latents.to(transformer_dtype) 338 | timestep = t.expand(latents.shape[0]).to(latents.dtype) 339 | 340 | noise_pred = pipe.transformer( 341 | hidden_states=torch.cat([latent_model_input, image_latents], dim=1), 342 | timestep=timestep, 343 | encoder_hidden_states=prompt_embeds, 344 | encoder_attention_mask=prompt_attention_mask, 345 | pooled_projections=pooled_prompt_embeds, 346 | guidance=guidance, 347 | attention_kwargs=attention_kwargs, 348 | return_dict=False, 349 | )[0] 350 | 351 | # Compute the previous noisy sample x_t -> x_t-1 352 | latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 353 | 354 | if callback_on_step_end is not None: 355 | callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs} 356 | callback_outputs = callback_on_step_end(pipe, i, t, callback_kwargs) 357 | latents = callback_outputs.pop("latents", latents) 358 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 359 | 360 | # Update progress bar 361 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0): 362 | progress_bar.update() 363 | 364 | pipe._current_timestep = None 365 | 366 | if output_type != "latent": 367 | latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor 368 | video = pipe.vae.decode(latents, return_dict=False)[0] 369 | video = pipe.video_processor.postprocess_video(video, output_type=output_type) 370 | else: 371 | video = latents 372 | 373 | # Offload all models 374 | pipe.maybe_free_model_hooks() 375 | 376 | return (video,) if not return_dict else HunyuanVideoPipelineOutput(frames=video) 377 | 378 | 379 | def setup_pipeline(model_path, lora_path=None): 380 | """Set up the HunyuanVideo pipeline with optional LoRA weights""" 381 | pipe = HunyuanVideoPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16) 382 | pipe.to("cuda") 383 | pipe.vae.enable_slicing() 384 | pipe.vae.enable_tiling() 385 | 386 | # Apply flash attention to all transformer blocks 387 | for block in pipe.transformer.transformer_blocks + pipe.transformer.single_transformer_blocks: 388 | block.attn.processor = HunyuanVideoFlashAttnProcessor(use_flash_attn=True, use_sageattn=False) 389 | 390 | # Modify input channels for conditioning 391 | with torch.no_grad(): 392 | initial_input_channels = pipe.transformer.config.in_channels 393 | new_img_in = HunyuanVideoPatchEmbed( 394 | patch_size=( 395 | pipe.transformer.config.patch_size_t, 396 | pipe.transformer.config.patch_size, 397 | pipe.transformer.config.patch_size, 398 | ), 399 | in_chans=pipe.transformer.config.in_channels * 2, 400 | embed_dim=pipe.transformer.config.num_attention_heads * pipe.transformer.config.attention_head_dim, 401 | ) 402 | new_img_in = new_img_in.to(pipe.device, dtype=pipe.dtype) 403 | new_img_in.proj.weight.zero_() 404 | new_img_in.proj.weight[:, :initial_input_channels].copy_(pipe.transformer.x_embedder.proj.weight) 405 | 406 | if pipe.transformer.x_embedder.proj.bias is not None: 407 | new_img_in.proj.bias.copy_(pipe.transformer.x_embedder.proj.bias) 408 | 409 | pipe.transformer.x_embedder = new_img_in 410 | pipe.transformer.x_embedder.requires_grad_(False) 411 | 412 | # Load LoRA weights if provided 413 | if lora_path and os.path.exists(lora_path): 414 | lora_state_dict = pipe.lora_state_dict(lora_path) 415 | 416 | # Load transformer LoRA weights 417 | transformer_lora_state_dict = { 418 | k.replace("transformer.", ""): v 419 | for k, v in lora_state_dict.items() 420 | if k.startswith("transformer.") and "lora" in k 421 | } 422 | pipe.load_lora_into_transformer( 423 | transformer_lora_state_dict, transformer=pipe.transformer, adapter_name="i2v", _pipeline=pipe 424 | ) 425 | pipe.set_adapters(["i2v"], adapter_weights=[1.0]) 426 | pipe.fuse_lora(components=["transformer"], lora_scale=1.0, adapter_names=["i2v"]) 427 | pipe.unload_lora_weights() 428 | 429 | # Load norm layers if present 430 | NORM_LAYER_PREFIXES = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] 431 | transformer_norm_layers_state_dict = { 432 | k.replace("transformer.", ""): v 433 | for k, v in lora_state_dict.items() 434 | if k.startswith("transformer.") and any(norm_k in k for norm_k in NORM_LAYER_PREFIXES) 435 | } 436 | 437 | if transformer_norm_layers_state_dict: 438 | print("[INFO] Loading normalization layers from state dict...") 439 | transformer_state_dict = pipe.transformer.state_dict() 440 | transformer_keys = set(transformer_state_dict.keys()) 441 | state_dict_keys = set(transformer_norm_layers_state_dict.keys()) 442 | extra_keys = list(state_dict_keys - transformer_keys) 443 | 444 | if extra_keys: 445 | print(f"[WARNING] Ignoring unsupported keys: {extra_keys}") 446 | for key in extra_keys: 447 | transformer_norm_layers_state_dict.pop(key) 448 | 449 | pipe.transformer.load_state_dict(transformer_norm_layers_state_dict, strict=False) 450 | else: 451 | print("[INFO] No normalization layers found in state dict") 452 | 453 | return pipe 454 | 455 | 456 | def prepare_conditioning(pipe, frame1_path, frame2_path, n_frames, height, width): 457 | """Prepare conditioning frames for the model""" 458 | video_transforms = transforms.Compose( 459 | [ 460 | transforms.Lambda(lambda x: x / 255.0), 461 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 462 | ] 463 | ) 464 | 465 | # Load and resize conditioning frames 466 | cond_frame1 = Image.open(frame1_path).convert("RGB") 467 | cond_frame2 = Image.open(frame2_path).convert("RGB") 468 | 469 | cond_frame1 = resize_image_to_bucket(cond_frame1, bucket_reso=(width, height)) 470 | cond_frame2 = resize_image_to_bucket(cond_frame2, bucket_reso=(width, height)) 471 | 472 | # Create conditioning video tensor 473 | cond_video = np.zeros(shape=(n_frames, height, width, 3)) 474 | cond_video[0], cond_video[-1] = np.array(cond_frame1), np.array(cond_frame2) 475 | 476 | cond_video = torch.from_numpy(cond_video.copy()).permute(0, 3, 1, 2) 477 | cond_video = torch.stack([video_transforms(x) for x in cond_video], dim=0).unsqueeze(0) 478 | 479 | # Encode to latent space 480 | with torch.inference_mode(): 481 | image_or_video = cond_video.to(device="cuda", dtype=pipe.dtype) 482 | image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] 483 | cond_latents = pipe.vae.encode(image_or_video).latent_dist.sample() 484 | cond_latents = cond_latents * pipe.vae.config.scaling_factor 485 | cond_latents = cond_latents.to(dtype=pipe.dtype) 486 | 487 | return cond_latents 488 | 489 | 490 | def main(): 491 | parser = argparse.ArgumentParser(description="Run HunyuanVideo inference with control frames") 492 | parser.add_argument( 493 | "--model", type=str, default="hunyuanvideo-community/HunyuanVideo", help="Path to HunyuanVideo model" 494 | ) 495 | parser.add_argument("--lora", type=str, required=True, help="Path to LoRA weights for image-to-video control") 496 | parser.add_argument("--frame1", type=str, required=True, help="Path to first control frame") 497 | parser.add_argument("--frame2", type=str, required=True, help="Path to second control frame") 498 | parser.add_argument("--prompt", type=str, required=True, help="Text prompt for video generation") 499 | parser.add_argument("--height", type=int, default=720, help="Output video height") 500 | parser.add_argument("--width", type=int, default=1280, help="Output video width") 501 | parser.add_argument("--frames", type=int, default=77, help="Number of frames to generate") 502 | parser.add_argument("--steps", type=int, default=50, help="Number of inference steps") 503 | parser.add_argument("--guidance", type=float, default=6.0, help="Guidance scale") 504 | parser.add_argument("--seed", type=int, default=None, help="Random seed for generation") 505 | parser.add_argument("--output", type=str, default=None, help="Output video path (default: auto-generated)") 506 | 507 | args = parser.parse_args() 508 | 509 | # Validate inputs 510 | for path in [args.frame1, args.frame2]: 511 | if not os.path.exists(path): 512 | raise FileNotFoundError(f"Input frame not found: {path}") 513 | 514 | if args.lora and not os.path.exists(args.lora): 515 | raise FileNotFoundError(f"LoRA weights not found: {args.lora}") 516 | 517 | # Set random seed 518 | seed = args.seed if args.seed is not None else int(time.time()) % 10000 519 | generator = torch.Generator(device="cuda").manual_seed(seed) 520 | 521 | # Setup pipeline 522 | print(f"Loading model from {args.model}") 523 | pipe = setup_pipeline(args.model, args.lora) 524 | 525 | # Prepare conditioning 526 | print("Preparing conditioning frames...") 527 | cond_latents = prepare_conditioning(pipe, args.frame1, args.frame2, args.frames, args.height, args.width) 528 | 529 | # Generate video 530 | print(f"Generating video with prompt: '{args.prompt}'") 531 | video = call_pipe( 532 | pipe, 533 | prompt=args.prompt, 534 | num_frames=args.frames, 535 | num_inference_steps=args.steps, 536 | image_latents=cond_latents, 537 | width=args.width, 538 | height=args.height, 539 | guidance_scale=args.guidance, 540 | generator=generator, 541 | ).frames[0] 542 | 543 | # Save output 544 | if args.output: 545 | output_path = args.output 546 | else: 547 | time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d_%H%M%S") 548 | output_path = f"hv-CL-{args.height}x{args.width}x{args.frames}-{time_flag}.mp4" 549 | 550 | print(f"Saving video to {output_path}") 551 | save_video(video, output_path) 552 | print("Done!") 553 | 554 | 555 | if __name__ == "__main__": 556 | main() 557 | -------------------------------------------------------------------------------- /hv_train_control_lora.py: -------------------------------------------------------------------------------- 1 | from icecream import ic, install 2 | 3 | install() 4 | ic.configureOutput(includeContext=True) 5 | 6 | import ast 7 | import gc 8 | import importlib 9 | import json 10 | import logging 11 | import math 12 | import os 13 | import random 14 | import shutil 15 | import sys 16 | import time 17 | import typing 18 | import warnings 19 | from copy import deepcopy 20 | from dataclasses import asdict 21 | from datetime import datetime, timedelta 22 | from pathlib import Path 23 | from pprint import pformat 24 | 25 | import diffusers 26 | import numpy as np 27 | import pyrallis 28 | import torch 29 | import torch.optim.adamw 30 | import transformers 31 | from accelerate import Accelerator, DistributedDataParallelKwargs, DistributedType, InitProcessGroupKwargs 32 | from accelerate.logging import get_logger 33 | from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs, ProjectConfiguration, set_seed 34 | from diffusers import FlowMatchEulerDiscreteScheduler, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel 35 | from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoPatchEmbed 36 | from diffusers.optimization import get_scheduler 37 | from diffusers.training_utils import ( 38 | cast_training_params, 39 | compute_density_for_timestep_sampling, 40 | compute_loss_weighting_for_sd3, 41 | ) 42 | from diffusers.utils.state_dict_utils import convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft 43 | from diffusers.utils.torch_utils import is_compiled_module 44 | from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict 45 | from tqdm.auto import tqdm 46 | 47 | from attn_processor import HunyuanVideoFlashAttnProcessor # isort: skip 48 | from config import Config # isort:skip 49 | from mds_dataloaders import build_mds_dataloader # isort: skip 50 | from optim import get_optimizer, max_gradient # isort: skip 51 | from ema import EMAModel # isort: skip 52 | 53 | NORM_LAYER_PREFIXES = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] 54 | 55 | logging.basicConfig( 56 | level=logging.INFO, 57 | format="[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d:%(funcName)s] %(message)s", 58 | force=True, 59 | handlers=[logging.StreamHandler()], 60 | ) 61 | warnings.filterwarnings("ignore") # ignore warning 62 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 63 | 64 | if torch.cuda.is_available(): 65 | torch.backends.cuda.matmul.allow_tf32 = True 66 | torch.backends.cudnn.allow_tf32 = True 67 | 68 | logger = get_logger(__name__) 69 | 70 | 71 | def bytes_to_gigabytes(x: int) -> float: 72 | if x is not None: 73 | return x / 1024**3 74 | 75 | 76 | def free_memory() -> None: 77 | if torch.cuda.is_available(): 78 | gc.collect() 79 | torch.cuda.empty_cache() 80 | torch.cuda.ipc_collect() 81 | 82 | 83 | def get_memory_statistics(precision: int = 3) -> typing.Dict[str, typing.Any]: 84 | memory_allocated = None 85 | memory_reserved = None 86 | max_memory_allocated = None 87 | max_memory_reserved = None 88 | 89 | if torch.cuda.is_available(): 90 | device = torch.cuda.current_device() 91 | memory_allocated = torch.cuda.memory_allocated(device) 92 | memory_reserved = torch.cuda.memory_reserved(device) 93 | max_memory_allocated = torch.cuda.max_memory_allocated(device) 94 | max_memory_reserved = torch.cuda.max_memory_reserved(device) 95 | 96 | elif torch.backends.mps.is_available(): 97 | memory_allocated = torch.mps.current_allocated_memory() 98 | 99 | else: 100 | logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.") 101 | 102 | return { 103 | "memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision), 104 | "memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision), 105 | "max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision), 106 | "max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision), 107 | } 108 | 109 | 110 | def get_nb_trainable_parameters(mod: torch.nn.Module): 111 | r""" 112 | Returns the number of trainable parameters and number of all parameters in the model. 113 | """ 114 | # note: same as PeftModel.get_nb_trainable_parameters 115 | trainable_params = 0 116 | all_param = 0 117 | for _, param in mod.named_parameters(): 118 | num_params = param.numel() 119 | # if using DS Zero 3 and the weights are initialized empty 120 | if num_params == 0 and hasattr(param, "ds_numel"): 121 | num_params = param.ds_numel 122 | 123 | # Due to the design of 4bit linear layers from bitsandbytes 124 | # one needs to multiply the number of parameters by 2 to get 125 | # the correct number of parameters 126 | if param.__class__.__name__ == "Params4bit": 127 | num_params = num_params * 2 128 | 129 | all_param += num_params 130 | if param.requires_grad: 131 | trainable_params += num_params 132 | 133 | return trainable_params, all_param 134 | 135 | 136 | def get_noisy_model_input_and_timesteps( 137 | cfg: Config, latents, noise, noise_scheduler, device, weight_dtype, scheduler_sigmas, generator=None 138 | ): 139 | batch_size = latents.size(0) 140 | if cfg.hparams.flow_match.timestep_sampling == "uniform": 141 | sigmas = torch.rand((batch_size,), device=device, generator=generator) 142 | elif cfg.hparams.flow_match.timestep_sampling == "sigmoid": 143 | # imported from cloneofsimo's minRF trainer: https://github.com/cloneofsimo/minRF 144 | # also used by: https://github.com/XLabs-AI/x-flux/tree/main 145 | # and: https://github.com/kohya-ss/sd-scripts/commit/8a0f12dde812994ec3facdcdb7c08b362dbceb0f 146 | sigmas = torch.sigmoid( 147 | cfg.hparams.flow_match.sigmoid_scale * torch.randn((batch_size,), device=device, generator=generator) 148 | ) 149 | elif cfg.hparams.flow_match.timestep_sampling == "logit_normal": 150 | sigmas = torch.normal( 151 | cfg.hparams.flow_match.logit_mean, 152 | cfg.hparams.flow_match.logit_std, 153 | size=(batch_size,), 154 | device=device, 155 | generator=generator, 156 | ) 157 | sigmas = torch.sigmoid(cfg.hparams.flow_match.sigmoid_scale * sigmas) 158 | 159 | if cfg.hparams.flow_match.discrete_flow_shift is not None and cfg.hparams.flow_match.discrete_flow_shift > 0: 160 | sigmas = (sigmas * cfg.hparams.flow_match.discrete_flow_shift) / ( 161 | 1 + (cfg.hparams.flow_match.discrete_flow_shift - 1) * sigmas 162 | ) 163 | 164 | timesteps, sigmas = (sigmas * 1000.0).long(), sigmas.view(-1, 1, 1, 1, 1) 165 | noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise 166 | return noisy_model_input, timesteps 167 | 168 | 169 | @pyrallis.wrap() 170 | def main(cfg: Config): 171 | if cfg.experiment.ic_debug: 172 | ic.enable() 173 | else: 174 | ic.disable() 175 | 176 | # ====================================================== 177 | # 1. configs & runtime variables 178 | # ====================================================== 179 | output_dirpath = Path(cfg.experiment.output_dirpath) / cfg.experiment.run_id 180 | logging_dirpath = output_dirpath / "logs" 181 | 182 | accelerator_project_config = ProjectConfiguration(project_dir=output_dirpath, logging_dir=logging_dirpath) 183 | ddp_kwargs = DistributedDataParallelKwargs( 184 | find_unused_parameters=cfg.ddp_kwargs.find_unused_parameters, 185 | gradient_as_bucket_view=cfg.ddp_kwargs.gradient_as_bucket_view, 186 | static_graph=cfg.ddp_kwargs.static_graph, 187 | ) 188 | init_kwargs = InitProcessGroupKwargs(backend=cfg.ddp_kwargs.backend, timeout=timedelta(seconds=5400)) 189 | 190 | accelerator = Accelerator( 191 | gradient_accumulation_steps=cfg.hparams.gradient_accumulation_steps, 192 | mixed_precision=cfg.hparams.mixed_precision, 193 | log_with=None, 194 | project_config=accelerator_project_config, 195 | kwargs_handlers=[ddp_kwargs, init_kwargs], 196 | ) 197 | 198 | print(accelerator.state) 199 | 200 | accelerator.print("\nENVIRONMENT\n") 201 | accelerator.print(f" Python .......................... {sys.version}") 202 | accelerator.print(f" torch.__version__ ............... {torch.__version__}") 203 | accelerator.print(f" torch.version.cuda .............. {torch.version.cuda}") 204 | accelerator.print(f" torch.backends.cudnn.version() .. {torch.backends.cudnn.version()}\n") 205 | accelerator.print("\n") 206 | accelerator.print(f">> Run ID : {cfg.experiment.run_id!r}") 207 | 208 | if accelerator.is_main_process: 209 | transformers.utils.logging.set_verbosity_warning() 210 | diffusers.utils.logging.set_verbosity_info() 211 | else: 212 | transformers.utils.logging.set_verbosity_error() 213 | diffusers.utils.logging.set_verbosity_error() 214 | 215 | if cfg.experiment.random_seed is not None: 216 | set_seed(cfg.experiment.random_seed) 217 | 218 | if accelerator.num_processes > 1: 219 | logger.info("DDP VARS: ") 220 | logger.info(f" WORLD_SIZE: {os.getenv('WORLD_SIZE', 'N/A')}") 221 | logger.info(f" LOCAL_WORLD_SIZE: {os.getenv('LOCAL_WORLD_SIZE', 'N/A')}") 222 | logger.info(f" RANK: {os.getenv('RANK', 'N/A')}") 223 | logger.info(f" MASTER_ADDR: {os.getenv('MASTER_ADDR', 'N/A')}") 224 | logger.info(f" MASTER_PORT: {os.getenv('MASTER_PORT', 'N/A')}") 225 | 226 | if accelerator.is_main_process: 227 | output_dirpath.mkdir(parents=True, exist_ok=True) 228 | if not accelerator.is_main_process: 229 | ic.disable() 230 | 231 | if accelerator.is_main_process: 232 | logger.info(f"Saving config to {output_dirpath / 'config.yaml'}") 233 | yaml_cfg = pyrallis.dump(cfg) 234 | with open(output_dirpath / "config.yaml", "w") as f: 235 | f.write(yaml_cfg) 236 | 237 | logger.info(f"config = \n{pyrallis.dump(cfg)}") 238 | 239 | # ====================================================== 240 | # 2. build model 241 | # ====================================================== 242 | 243 | noise_scheduler = FlowMatchEulerDiscreteScheduler(shift=cfg.hparams.flow_match.discrete_flow_shift) 244 | 245 | load_dtype = torch.bfloat16 246 | logger.info(f"Load transformer model from {cfg.model.pretrained_model_name_or_path!r}") 247 | t0 = time.time() 248 | transformer = HunyuanVideoTransformer3DModel.from_pretrained( 249 | cfg.model.pretrained_model_name_or_path, 250 | subfolder="transformer", 251 | torch_dtype=load_dtype, 252 | revision=cfg.model.revision, 253 | variant=cfg.model.variant, 254 | ) 255 | logger.info(f"Loaded transformer model from {cfg.model.pretrained_model_name_or_path!r} in {time.time() - t0:.2f}s") 256 | 257 | for blk in transformer.transformer_blocks + transformer.single_transformer_blocks: 258 | blk.attn.processor = HunyuanVideoFlashAttnProcessor() 259 | logger.info("transformer patch with flash_attn done ok!") 260 | 261 | with torch.no_grad(): 262 | logger.info("expand transformer x_embedder input channels") 263 | t0 = time.time() 264 | initial_input_channels = transformer.config.in_channels 265 | new_img_in = HunyuanVideoPatchEmbed( 266 | patch_size=(transformer.config.patch_size_t, transformer.config.patch_size, transformer.config.patch_size), 267 | in_chans=transformer.config.in_channels * 2, 268 | embed_dim=transformer.config.num_attention_heads * transformer.config.attention_head_dim, 269 | ) 270 | new_img_in.proj.weight.zero_() 271 | new_img_in.proj.weight[:, :initial_input_channels].copy_(transformer.x_embedder.proj.weight) 272 | if transformer.x_embedder.proj.bias is not None: 273 | new_img_in.proj.bias.copy_(transformer.x_embedder.proj.bias) 274 | transformer.x_embedder = new_img_in 275 | assert torch.all(transformer.x_embedder.proj.weight[:, initial_input_channels:] == 0) 276 | transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels) 277 | logger.info(f"expanded transformer x_embedder input channels in {time.time() - t0:.2f}s") 278 | accelerator.wait_for_everyone() 279 | 280 | weight_dtype = torch.float32 281 | if accelerator.state.deepspeed_plugin: 282 | # DeepSpeed is handling precision, use what's in the DeepSpeed config 283 | if ( 284 | "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config 285 | and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] 286 | ): 287 | weight_dtype = torch.float16 288 | if ( 289 | "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config 290 | and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] 291 | ): 292 | weight_dtype = torch.bfloat16 293 | else: 294 | if accelerator.mixed_precision == "fp16": 295 | weight_dtype = torch.float16 296 | elif accelerator.mixed_precision == "bf16": 297 | weight_dtype = torch.bfloat16 298 | logger.info(f"configured weight dtype: {weight_dtype!r}") 299 | 300 | if cfg.hparams.gradient_checkpointing: 301 | transformer.enable_gradient_checkpointing() 302 | 303 | transformer.requires_grad_(False) 304 | 305 | transformer.to(accelerator.device, dtype=weight_dtype) 306 | logger.info(f"dit dtype: {next(transformer.parameters()).dtype!r}") 307 | 308 | if cfg.network.lora_layers is not None: 309 | if cfg.network.lora_layers != "all-linear": 310 | target_modules = [layer.strip() for layer in cfg.network.lora_layers.split(",")] 311 | if "x_embedder" not in target_modules: 312 | target_modules.append("x_embedder.proj") 313 | elif cfg.network.lora_layers == "all-linear": 314 | target_modules = set() 315 | for name, module in transformer.named_modules(): 316 | if isinstance(module, torch.nn.Linear): 317 | target_modules.add(name) 318 | target_modules = list(target_modules) 319 | if "x_embedder" not in target_modules: 320 | target_modules.append("x_embedder.proj") 321 | target_modules = [t for t in target_modules if "norm" not in t] 322 | else: 323 | assert cfg.network.target_modules is not None, "either `lora_layers` or `target_modules` must be specified" 324 | target_modules = cfg.network.target_modules 325 | 326 | logger.info(f"using LoRA traning mode: ") 327 | logger.info(f"rank .......................................... {cfg.network.lora_rank!r}") 328 | logger.info(f"alpha ......................................... {cfg.network.lora_alpha!r}") 329 | logger.info(f"target_modules ................................ {json.dumps(target_modules, indent=4)}") 330 | 331 | transformer_lora_config = LoraConfig( 332 | r=cfg.network.lora_rank, 333 | lora_alpha=cfg.network.lora_alpha, 334 | lora_dropout=cfg.network.lora_dropout, 335 | target_modules=target_modules, 336 | init_lora_weights=cfg.network.init_lora_weights, 337 | ) 338 | transformer.add_adapter(transformer_lora_config) 339 | accelerator.wait_for_everyone() 340 | 341 | trainable_params, all_param = get_nb_trainable_parameters(transformer) 342 | logger.info( 343 | f"trainable params: {trainable_params:,d} || " 344 | f"all params: {all_param:,d} || " 345 | f"trainable%: {100 * trainable_params / all_param:.4f}" 346 | ) 347 | 348 | if cfg.network.train_norm_layers: 349 | train_norm_layers = [] 350 | logger.info(f"train norm layers, setting requires_grad to True for layers matching {NORM_LAYER_PREFIXES!r}") 351 | for name, param in transformer.named_parameters(): 352 | if any(k in name for k in NORM_LAYER_PREFIXES): 353 | param.requires_grad_(True) 354 | train_norm_layers.append(name) 355 | logger.info(f"train norm layers ............................. {json.dumps(train_norm_layers, indent=4)}") 356 | 357 | if cfg.hparams.mixed_precision == "fp16": 358 | logger.warning("full fp16 training is unstable, casting params to fp32") 359 | cast_training_params([transformer]) 360 | 361 | if cfg.hparams.gradient_checkpointing: 362 | transformer.enable_gradient_checkpointing() 363 | 364 | transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) 365 | 366 | ema_model = None 367 | if accelerator.is_main_process and cfg.hparams.ema.use_ema: 368 | logger.info("Using EMA. Creating EMAModel.") 369 | ema_model_cls, ema_model_config = transformer.__class__, transformer.config 370 | ema_model = EMAModel( 371 | cfg.hparams.ema, 372 | accelerator, 373 | parameters=transformer_lora_parameters, 374 | model_cls=ema_model_cls, 375 | model_config=ema_model_config, 376 | decay=cfg.hparams.ema.ema_decay, 377 | foreach=not cfg.hparams.ema.ema_foreach_disable, 378 | ) 379 | logger.info(f"EMA model creation completed with {ema_model.parameter_count():,} parameters") 380 | 381 | accelerator.wait_for_everyone() 382 | 383 | use_deepspeed_optimizer = ( 384 | accelerator.state.deepspeed_plugin is not None 385 | and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config 386 | ) 387 | use_deepspeed_scheduler = ( 388 | accelerator.state.deepspeed_plugin is not None 389 | and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config 390 | ) 391 | 392 | optimizer = get_optimizer( 393 | transformer_lora_parameters, 394 | optimizer_name=cfg.hparams.optimizer_type, 395 | learning_rate=cfg.hparams.learning_rate, 396 | optimizer_args_str=cfg.hparams.optimizer_args, 397 | use_deepspeed=use_deepspeed_optimizer, 398 | ) 399 | 400 | # ====================================================== 401 | # 3. build dataset and dataloaders 402 | # ====================================================== 403 | 404 | train_dataloader = build_mds_dataloader( 405 | remote=cfg.data.remote, 406 | local=cfg.data.local, 407 | batch_size=cfg.data.batch_size, 408 | video_key=cfg.data.video_key, 409 | caption_key=cfg.data.caption_key, 410 | latents_key=cfg.data.latents_key, 411 | latents_cond_key=cfg.data.latents_cond_key, 412 | prompt_embeds_key=cfg.data.prompt_embeds_key, 413 | prompt_attention_mask_key=cfg.data.prompt_attention_mask_key, 414 | pooled_prompt_embeds_key=cfg.data.pooled_prompt_embeds_key, 415 | streaming_kwargs=asdict(cfg.data.streaming_kwargs), 416 | dataloader_kwargs=asdict(cfg.data.dataloader_kwargs), 417 | latent_dtype=weight_dtype, 418 | ) 419 | 420 | # ======================================================= 421 | # 4. distributed training preparation with accelerator 422 | # ======================================================= 423 | 424 | def unwrap_model(model): 425 | model = accelerator.unwrap_model(model) 426 | model = model._orig_mod if is_compiled_module(model) else model 427 | return model 428 | 429 | def save_model_hook(models, weights, output_dir): 430 | if not accelerator.is_main_process: 431 | return 432 | 433 | if cfg.hparams.ema.use_ema and ema_model is not None: 434 | primary_model = unwrap_model(transformer) 435 | ema_model_path = os.path.join(output_dir, "ema_model.pt") 436 | logger.info(f"Saving EMA model state to {ema_model_path!r}") 437 | try: 438 | ema_model.save_state_dict(ema_model_path) 439 | except Exception as e: 440 | logger.error(f"Error saving EMA model: {e!r}") 441 | 442 | # we'll temporarily overwrite the LoRA parameters with the EMA parameters to save it. 443 | logger.info("Saving EMA model to disk.") 444 | trainable_parameters = [p for p in primary_model.parameters() if p.requires_grad] 445 | ema_model.store(trainable_parameters) 446 | ema_model.copy_to(trainable_parameters) 447 | transformer_lora_layers = get_peft_model_state_dict(primary_model) 448 | HunyuanVideoPipeline.save_lora_weights( 449 | os.path.join(output_dir, "ema"), 450 | transformer_lora_layers=transformer_lora_layers, 451 | weight_name=f"{cfg.experiment.name}.sft", 452 | ) 453 | ema_model.restore(trainable_parameters) 454 | 455 | transformer_lora_layers_to_save = None 456 | for model in models: 457 | if isinstance(unwrap_model(model), type(unwrap_model(transformer))): 458 | model = unwrap_model(model) 459 | transformer_lora_layers_to_save = get_peft_model_state_dict(model) 460 | 461 | if cfg.network.train_norm_layers: 462 | transformer_norm_layers_to_save = { 463 | f"transformer.{name}": param 464 | for name, param in model.named_parameters() 465 | if any(k in name for k in NORM_LAYER_PREFIXES) 466 | } 467 | transformer_lora_layers_to_save = { 468 | **transformer_lora_layers_to_save, 469 | **transformer_norm_layers_to_save, 470 | } 471 | else: 472 | raise ValueError(f"unexpected save model: {model.__class__}") 473 | 474 | if weights: 475 | weights.pop() 476 | 477 | HunyuanVideoPipeline.save_lora_weights( 478 | output_dir, 479 | transformer_lora_layers=transformer_lora_layers_to_save, 480 | weight_name=f"{cfg.experiment.name}.sft", 481 | ) 482 | 483 | if hasattr(train_dataloader, "state_dict"): 484 | torch.save(train_dataloader.state_dict(), os.path.join(output_dir, "train_dataloader_state.pt")) 485 | 486 | def load_model_hook(models, input_dir): 487 | if hasattr(train_dataloader, "load_state_dict"): 488 | logger.info(f"Loading train dataloader state from Path: {input_dir!r}") 489 | train_dataloader.load_state_dict(torch.load(os.path.join(input_dir, "train_dataloader_state.pt"))) 490 | 491 | if cfg.hparams.ema.use_ema and ema_model is not None: 492 | logger.info(f"Loading EMA model from Path: {input_dir!r}") 493 | try: 494 | ema_model.load_state_dict(os.path.join(input_dir, "ema_model.pt")) 495 | except Exception as e: 496 | logger.error(f"Could not load EMA model: {e!r}") 497 | 498 | transformer_ = None 499 | if not accelerator.distributed_type == DistributedType.DEEPSPEED: 500 | while len(models) > 0: 501 | model = models.pop() 502 | 503 | if isinstance(model, type(unwrap_model(transformer))): 504 | transformer_ = model 505 | else: 506 | raise ValueError(f"unexpected save model: {model.__class__}") 507 | else: 508 | transformer_ = HunyuanVideoTransformer3DModel.from_pretrained( 509 | cfg.model.pretrained_model_name_or_path, subfolder="transformer" 510 | ) 511 | transformer_.to(accelerator.device, weight_dtype) 512 | 513 | with torch.no_grad(): 514 | initial_input_channels = transformer.config.in_channels 515 | new_img_in = HunyuanVideoPatchEmbed( 516 | patch_size=( 517 | transformer.config.patch_size_t, 518 | transformer.config.patch_size, 519 | transformer.config.patch_size, 520 | ), 521 | in_chans=transformer.config.in_channels * 2, 522 | embed_dim=transformer.config.num_attention_heads 523 | * transformer.config.num_attention_heads.attention_head_dim, 524 | ) 525 | new_img_in.proj.weight.zero_() 526 | new_img_in.proj.weight[:, :initial_input_channels].copy_(transformer.x_embedder.proj.weight) 527 | if transformer.x_embedder.proj.bias is not None: 528 | new_img_in.proj.bias.copy_(transformer.x_embedder.proj.bias) 529 | transformer.x_embedder = new_img_in 530 | transformer.register_to_config( 531 | in_channels=initial_input_channels * 2, out_channels=initial_input_channels 532 | ) 533 | 534 | transformer_.add_adapter(transformer_lora_config) 535 | 536 | lora_weight_name = os.path.join(input_dir, f"{cfg.experiment.name}.sft") 537 | logger.info(f"Loading LoRA weights from Path: {lora_weight_name!r}") 538 | lora_state_dict = HunyuanVideoPipeline.lora_state_dict(lora_weight_name) 539 | transformer_lora_state_dict = { 540 | f'{k.replace("transformer.", "")}': v 541 | for k, v in lora_state_dict.items() 542 | if k.startswith("transformer.") and "lora" in k 543 | } 544 | # transformer_lora_state_dict = convert_unet_state_dict_to_peft(transformer_lora_state_dict) 545 | incompatible_keys = set_peft_model_state_dict(transformer_, transformer_lora_state_dict, adapter_name="default") 546 | 547 | if incompatible_keys is not None: 548 | # check only for unexpected keys 549 | unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) 550 | if unexpected_keys: 551 | logger.warning( 552 | f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " 553 | f" {unexpected_keys}. " 554 | ) 555 | 556 | if cfg.network.train_norm_layers: 557 | transformer_norm_layers_state_dict = { 558 | f'{k.replace("transformer.", "")}': v 559 | for k, v in lora_state_dict.items() 560 | if k.startswith("transformer.") and any(norm_k in k for norm_k in NORM_LAYER_PREFIXES) 561 | } 562 | for key in list(transformer_norm_layers_state_dict.keys()): 563 | if key.split(".")[0] == "transformer": 564 | transformer_norm_layers_state_dict[ 565 | key[len(f"transformer.") :] 566 | ] = transformer_norm_layers_state_dict.pop(key) 567 | 568 | transformer_state_dict = transformer.state_dict() 569 | transformer_keys = set(transformer_state_dict.keys()) 570 | state_dict_keys = set(transformer_norm_layers_state_dict.keys()) 571 | extra_keys = list(state_dict_keys - transformer_keys) 572 | 573 | if extra_keys: 574 | logger.warning( 575 | f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}." 576 | ) 577 | 578 | for key in extra_keys: 579 | transformer_norm_layers_state_dict.pop(key) 580 | 581 | # We can't load with strict=True because the current state_dict does not contain all the transformer keys 582 | incompatible_keys = transformer.load_state_dict(transformer_norm_layers_state_dict, strict=False) 583 | unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) 584 | 585 | # We shouldn't expect to see the supported norm keys here being present in the unexpected keys. 586 | if unexpected_keys: 587 | if any(norm_key in k for k in unexpected_keys for norm_key in NORM_LAYER_PREFIXES): 588 | raise ValueError( 589 | f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer." 590 | ) 591 | 592 | if cfg.hparams.mixed_precision == "fp16": 593 | cast_training_params([transformer_]) 594 | 595 | logger.info(f"Completed loading checkpoint from Path: {input_dir!r}") 596 | 597 | accelerator.register_save_state_pre_hook(save_model_hook) 598 | accelerator.register_load_state_pre_hook(load_model_hook) 599 | 600 | # The epoch_size attribute of StreamingDataset is the number of samples per epoch of training. 601 | # The __len__() method returns the epoch_size divided by the number of devices – it is the number of samples seen per device, per epoch. 602 | # The size() method returns the number of unique samples in the underlying dataset. 603 | # Due to upsampling/downsampling, size() may not be the same as epoch_size. 604 | if cfg.hparams.max_train_steps is None: 605 | len_train_dataloader_after_sharding = len(train_dataloader) 606 | num_update_steps_per_epoch = math.ceil( 607 | len_train_dataloader_after_sharding / cfg.hparams.gradient_accumulation_steps 608 | ) 609 | num_training_steps_for_scheduler = ( 610 | cfg.hparams.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes 611 | ) 612 | else: 613 | num_training_steps_for_scheduler = cfg.hparams.max_train_steps * accelerator.num_processes 614 | 615 | if use_deepspeed_scheduler: 616 | from accelerate.utils import DummyScheduler 617 | 618 | lr_scheduler = DummyScheduler( 619 | name=cfg.hparams.lr_scheduler, 620 | optimizer=optimizer, 621 | total_num_steps=num_training_steps_for_scheduler, 622 | num_warmup_steps=cfg.hparams.lr_warmup_steps * accelerator.num_processes, 623 | ) 624 | else: 625 | lr_scheduler = get_scheduler( 626 | name=cfg.hparams.lr_scheduler, 627 | optimizer=optimizer, 628 | num_training_steps=num_training_steps_for_scheduler, 629 | num_warmup_steps=cfg.hparams.lr_warmup_steps * accelerator.num_processes, 630 | num_cycles=cfg.hparams.lr_scheduler_num_cycles, 631 | power=cfg.hparams.lr_scheduler_power, 632 | ) 633 | 634 | # not need to wrap dataloader because mosaicml-streaming handles it internally 635 | # the config should be passed via deepspeed json 636 | if accelerator.state.deepspeed_plugin is not None: 637 | d = transformer.config.num_attention_heads * transformer.config.attention_head_dim 638 | accelerator.state.deepspeed_plugin.deepspeed_config["zero_optimization"]["reduce_bucket_size"] = d 639 | accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = cfg.data.batch_size 640 | accelerator.state.deepspeed_plugin.deepspeed_config[ 641 | "gradient_accumulation_steps" 642 | ] = cfg.hparams.gradient_accumulation_steps 643 | 644 | # Prepare everything with our `accelerator`. 645 | # passing dataloader is required to resolve deepspeed 'auto' params, but we do 646 | transformer, optimizer, lr_scheduler = accelerator.prepare(transformer, optimizer, lr_scheduler) 647 | 648 | if cfg.hparams.ema.use_ema and ema_model is not None: 649 | if cfg.hparams.ema.ema_device == "accelerator": 650 | logger.info("Moving EMA model weights to accelerator...") 651 | 652 | ema_model.to((accelerator.device if cfg.hparams.ema.ema_device == "accelerator" else "cpu"), dtype=weight_dtype) 653 | 654 | if cfg.hparams.ema.ema_device == "cpu" and not cfg.hparams.ema.ema_cpu_only: 655 | logger.info("Pinning EMA model weights to CPU...") 656 | try: 657 | ema_model.pin_memory() 658 | except Exception as e: 659 | logger.error(f"Failed to pin EMA model to CPU: {e}") 660 | 661 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.hparams.gradient_accumulation_steps) 662 | if cfg.hparams.max_train_steps is None: 663 | cfg.hparams.max_train_steps = cfg.hparams.num_train_epochs * num_update_steps_per_epoch 664 | if num_training_steps_for_scheduler != cfg.hparams.max_train_steps * accelerator.num_processes: 665 | logger.warning( 666 | f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " 667 | f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " 668 | f"This inconsistency may result in the learning rate scheduler not functioning properly." 669 | ) 670 | 671 | # Afterwards we recalculate our number of training epochs 672 | cfg.hparams.num_train_epochs = math.ceil(cfg.hparams.max_train_steps / num_update_steps_per_epoch) 673 | 674 | total_batch_size = cfg.data.batch_size * accelerator.num_processes * cfg.hparams.gradient_accumulation_steps 675 | num_trainable_parameters = sum(p.numel() for p in transformer_lora_parameters) 676 | 677 | # fmt: off 678 | logger.info("***** Running training *****") 679 | logger.info(f" Num trainable parameters ........................................... {num_trainable_parameters}") 680 | logger.info(f" Num examples ....................................................... {train_dataloader.dataset.size}") 681 | logger.info(f" Num batches each epoch ............................................. {len(train_dataloader)}") 682 | logger.info(f" Num epochs ......................................................... {cfg.hparams.num_train_epochs}") 683 | logger.info(f" Instantaneous batch size per device ................................ {cfg.data.batch_size}") 684 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) ... {total_batch_size}") 685 | logger.info(f" Gradient accumulation steps ........................................ {cfg.hparams.gradient_accumulation_steps}") 686 | logger.info(f" Total optimization steps ........................................... {cfg.hparams.max_train_steps}") 687 | # fmt: on 688 | 689 | global_step, first_epoch = 0, 0 690 | 691 | if not cfg.checkpointing.resume_from_checkpoint: 692 | initial_global_step = 0 693 | else: 694 | if cfg.checkpointing.resume_from_checkpoint != "latest": 695 | path = cfg.checkpointing.resume_from_checkpoint 696 | else: 697 | # Get the mos recent checkpoint 698 | dirs = os.listdir(cfg.experiment.output_dir) 699 | dirs = [d for d in dirs if d.startswith("checkpoint")] 700 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 701 | path = dirs[-1] if len(dirs) > 0 else None 702 | path = os.path.join(cfg.experiment.output_dir, path) 703 | 704 | if path is None: 705 | accelerator.print( 706 | f"Checkpoint {cfg.checkpointing.resume_from_checkpoint!r} does not exist. Starting a new training run." 707 | ) 708 | cfg.checkpointing.resume_from_checkpoint = None 709 | initial_global_step = 0 710 | else: 711 | accelerator.print(f"Resuming from checkpoint {path!r}") 712 | accelerator.load_state(path) 713 | 714 | global_step = int(path.split("checkpoint-step")[1]) 715 | initial_global_step = global_step 716 | first_epoch = global_step // num_update_steps_per_epoch 717 | 718 | logger.info(f"Override: global_step={initial_global_step} | first_epoch={first_epoch}") 719 | 720 | memory_statistics = get_memory_statistics() 721 | logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}") 722 | 723 | # ======================================================= 724 | # 5. training loop 725 | # ======================================================= 726 | accelerator.wait_for_everyone() 727 | progress_bar = tqdm( 728 | range(0, cfg.hparams.max_train_steps), 729 | initial=initial_global_step, 730 | desc="Steps", 731 | dynamic_ncols=True, 732 | disable=not accelerator.is_main_process, 733 | ) 734 | 735 | model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config 736 | generator = torch.Generator(device=accelerator.device) 737 | scheduler_sigmas = noise_scheduler.sigmas.clone().to(device=accelerator.device, dtype=weight_dtype) 738 | if cfg.experiment.random_seed is not None: 739 | generator = generator.manual_seed(cfg.experiment.random_seed) 740 | 741 | for epoch in range(first_epoch, cfg.hparams.num_train_epochs): 742 | logger.info(f"epoch {epoch+1}/{ cfg.hparams.num_train_epochs}") 743 | transformer.train() 744 | 745 | for step, batch in enumerate(train_dataloader): 746 | models_to_accumulate = [transformer] 747 | 748 | with accelerator.accumulate(models_to_accumulate): 749 | latents, cond_latents = batch[cfg.data.latents_key], batch[cfg.data.latents_cond_key] 750 | 751 | latents = latents.to(accelerator.device, dtype=weight_dtype).contiguous() 752 | cond_latents = cond_latents.to(accelerator.device, dtype=weight_dtype).contiguous() 753 | 754 | batch_size = latents.size(0) 755 | 756 | prompt_embeds = batch[cfg.data.prompt_embeds_key].to(accelerator.device, dtype=weight_dtype) 757 | pooled_prompt_embeds = batch[cfg.data.pooled_prompt_embeds_key].to( 758 | accelerator.device, dtype=weight_dtype 759 | ) 760 | prompt_attention_mask = batch[cfg.data.prompt_attention_mask_key].to( 761 | accelerator.device, dtype=torch.bool 762 | ) 763 | 764 | if random.random() < cfg.hparams.caption_dropout_p: 765 | prompt_embeds.fill_(0) 766 | pooled_prompt_embeds.fill_(0) 767 | prompt_attention_mask.fill_(False) 768 | 769 | noise = torch.randn(latents.shape, device=accelerator.device, dtype=weight_dtype, generator=generator) 770 | 771 | noisy_model_input, timesteps = get_noisy_model_input_and_timesteps( 772 | cfg=cfg, 773 | latents=latents, 774 | noise=noise, 775 | noise_scheduler=noise_scheduler, 776 | device=accelerator.device, 777 | weight_dtype=weight_dtype, 778 | generator=generator, 779 | scheduler_sigmas=scheduler_sigmas, 780 | ) 781 | noisy_model_input = noisy_model_input.to(weight_dtype) 782 | 783 | weighting = compute_loss_weighting_for_sd3( 784 | cfg.hparams.flow_match.weighting_scheme, sigmas=scheduler_sigmas 785 | ) 786 | while len(weighting.shape) < latents.ndim: 787 | weighting = weighting.unsqueeze(-1) 788 | guidance_vec = ( 789 | torch.full((batch_size,), float(cfg.hparams.guidance_scale), device=accelerator.device) * 1000.0 790 | ) 791 | 792 | ic(noisy_model_input.shape, cond_latents.shape) 793 | ic( 794 | step, 795 | guidance_vec, 796 | weighting, 797 | timesteps, 798 | prompt_embeds.shape, 799 | prompt_attention_mask.shape, 800 | pooled_prompt_embeds.shape, 801 | ) 802 | 803 | denoised_latents = transformer( 804 | hidden_states=torch.cat([noisy_model_input, cond_latents], dim=1), 805 | timestep=timesteps, 806 | encoder_hidden_states=prompt_embeds, 807 | pooled_projections=pooled_prompt_embeds, 808 | encoder_attention_mask=prompt_attention_mask, 809 | guidance=guidance_vec, 810 | return_dict=False, 811 | )[0] 812 | 813 | target = noise - latents 814 | loss = torch.nn.functional.mse_loss(denoised_latents.to(weight_dtype), target, reduction="none") 815 | 816 | if weighting is not None: 817 | loss = loss * weighting 818 | 819 | loss = loss.mean() 820 | 821 | assert torch.isnan(loss) == False, "NaN loss detected" 822 | 823 | accelerator.backward(loss) 824 | 825 | if cfg.hparams.gradient_precision == "fp32": 826 | for param in transformer_lora_parameters: 827 | if param.grad is not None: 828 | param.grad.data = param.grad.data.to(torch.float32) 829 | 830 | grad_norm = max_gradient(transformer_lora_parameters) 831 | if accelerator.sync_gradients: 832 | if accelerator.distributed_type == DistributedType.DEEPSPEED: 833 | grad_norm = transformer.get_global_grad_norm() 834 | 835 | elif cfg.hparams.max_grad_norm > 0: 836 | if cfg.hparams.grad_clip_method == "norm": 837 | grad_norm = accelerator.clip_grad_norm_( 838 | transformer_lora_parameters, cfg.hparams.max_grad_norm 839 | ) 840 | elif cfg.hparams.grad_clip_method == "value": 841 | grad_norm = accelerator.clip_grad_value_( 842 | transformer_lora_parameters, cfg.hparams.max_grad_norm 843 | ) 844 | 845 | if torch.is_tensor(grad_norm): 846 | grad_norm = grad_norm.item() 847 | 848 | optimizer.step() 849 | optimizer.zero_grad(set_to_none=True) 850 | lr_scheduler.step() 851 | 852 | if accelerator.sync_gradients: 853 | progress_bar.update(1) 854 | global_step += 1 855 | 856 | if cfg.hparams.ema.use_ema and ema_model is not None: 857 | ema_model.step(parameters=transformer_lora_parameters, global_step=global_step) 858 | 859 | if accelerator.is_main_process: 860 | if global_step % cfg.checkpointing.save_every_n_steps == 0: 861 | save_path = os.path.join(output_dirpath, f"checkpoint-step{global_step:08d}") 862 | accelerator.save_state(save_path) 863 | logger.info(f"Saved state to {save_path!r}") 864 | 865 | if cfg.checkpointing.save_last_n_steps is not None: 866 | remove_step_no = global_step - cfg.checkpointing.save_last_n_steps - 1 867 | remove_step_no = remove_step_no - (remove_step_no % cfg.checkpointing.save_every_n_steps) 868 | if remove_step_no < 0: 869 | remove_step_no = None 870 | if remove_step_no is not None: 871 | remove_ckpt_name = os.path.join(output_dirpath, f"checkpoint-step{remove_step_no:08d}") 872 | if os.path.exists(remove_ckpt_name): 873 | logger.info(f"removing old checkpoint: {remove_ckpt_name!r}") 874 | shutil.rmtree(remove_ckpt_name) 875 | 876 | logs = {} 877 | logs["loss"] = accelerator.reduce(loss.detach().clone(), reduction="mean").item() 878 | logs["grad_norm"] = grad_norm 879 | logs["lr"] = lr_scheduler.get_last_lr()[0] 880 | if ema_model is not None: 881 | logs["ema_decay"] = ema_model.get_decay() 882 | progress_bar.set_postfix(**logs) 883 | accelerator.log(logs, step=global_step) 884 | 885 | if global_step >= cfg.hparams.max_train_steps: 886 | logger.info(f"max training steps={cfg.hparams.max_train_steps!r} reached.") 887 | break 888 | 889 | memory_statistics = get_memory_statistics() 890 | logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}") 891 | 892 | accelerator.wait_for_everyone() 893 | if accelerator.is_main_process: 894 | if cfg.hparams.ema.use_ema and ema_model is not None: 895 | ema_model.copy_to(transformer_lora_parameters) 896 | 897 | transformer = unwrap_model(transformer) 898 | transformer_lora_layers = get_peft_model_state_dict(transformer) 899 | HunyuanVideoPipeline.save_lora_weights( 900 | output_dirpath, 901 | transformer_lora_layers=transformer_lora_layers, 902 | safe_serialization=True, 903 | weight_name=f"{cfg.experiment.name}.sft", 904 | ) 905 | accelerator.wait_for_everyone() 906 | 907 | memory_statistics = get_memory_statistics() 908 | logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}") 909 | accelerator.end_training() 910 | 911 | 912 | if __name__ == "__main__": 913 | main() 914 | -------------------------------------------------------------------------------- /mds_dataloaders.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from io import BytesIO 3 | from pathlib import Path 4 | from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union 5 | 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from streaming import Stream, StreamingDataLoader, StreamingDataset 10 | from torch.utils.data import DataLoader 11 | 12 | 13 | def make_streams(remote, local=None, proportion=None, repeat=None, choose=None): 14 | """Helper function to create a list of Stream objects from a set of remotes and stream weights. 15 | 16 | Args: 17 | remote (Union[str, Sequence[str]]): The remote path or paths to stream from. 18 | local (Union[str, Sequence[str]], optional): The local path or paths to cache the data. If not provided, the 19 | default local path is used. Default: ``None``. 20 | proportion (list, optional): Specifies how to sample this Stream relative to other Streams. Default: ``None``. 21 | repeat (list, optional): Specifies the degree to which a Stream is upsampled or downsampled. Default: ``None``. 22 | choose (list, optional): Specifies the number of samples to choose from a Stream. Default: ``None``. 23 | 24 | Returns: 25 | List[Stream]: A list of Stream objects. 26 | """ 27 | remote, local = _make_remote_and_local_sequences(remote, local) 28 | proportion, repeat, choose = _make_weighting_sequences(remote, proportion, repeat, choose) 29 | 30 | streams = [] 31 | for i, (r, l) in enumerate(zip(remote, local)): 32 | streams.append(Stream(remote=r, local=l, proportion=proportion[i], repeat=repeat[i], choose=choose[i])) 33 | return streams 34 | 35 | 36 | def _make_remote_and_local_sequences(remote, local=None): 37 | if isinstance(remote, str): 38 | remote = [remote] 39 | if isinstance(local, str): 40 | local = [local] 41 | if not local: 42 | local = [_make_default_local_path(r) for r in remote] 43 | 44 | if isinstance(remote, Sequence) and isinstance(local, Sequence): 45 | if len(remote) != len(local): 46 | ValueError( 47 | f"remote and local Sequences must be the same length, got lengths {len(remote)} and {len(local)}" 48 | ) 49 | else: 50 | ValueError(f"remote and local must be both Strings or Sequences, got types {type(remote)} and {type(local)}.") 51 | return remote, local 52 | 53 | 54 | def _make_default_local_path(remote_path): 55 | return str(Path(*["/tmp"] + list(Path(remote_path).parts[1:]))) 56 | 57 | 58 | def _make_weighting_sequences(remote, proportion=None, repeat=None, choose=None): 59 | weights = {"proportion": proportion, "repeat": repeat, "choose": choose} 60 | for name, weight in weights.items(): 61 | if weight is not None and len(remote) != len(weight): 62 | ValueError(f"{name} must be the same length as remote, got lengths {len(remote)} and {len(weight)}") 63 | proportion = weights["proportion"] if weights["proportion"] is not None else [None] * len(remote) 64 | repeat = weights["repeat"] if weights["repeat"] is not None else [None] * len(remote) 65 | choose = weights["choose"] if weights["choose"] is not None else [None] * len(remote) 66 | return proportion, repeat, choose 67 | 68 | 69 | class StreamingVideoCaptionLatentsDataset(StreamingDataset): 70 | def __init__( 71 | self, 72 | streams: Sequence[Stream], 73 | video_key: str = "video", 74 | caption_key: str = "caption", 75 | latents_key: str = "latents", 76 | latents_cond_key: str = "latents_cond", 77 | prompt_embeds_key: str = "prompt_embeds", 78 | prompt_attention_mask_key: str = "prompt_attention_mask", 79 | pooled_prompt_embeds_key: str = "pooled_prompt_embeds", 80 | latent_dtype: torch.dtype = torch.bfloat16, 81 | batch_size: int = None, 82 | **streaming_kwargs, 83 | ): 84 | streaming_kwargs.setdefault("shuffle_block_size", 1 << 18) 85 | streaming_kwargs.setdefault("shuffle_algo", "py1s") 86 | super().__init__(streams=streams, batch_size=batch_size, **streaming_kwargs) 87 | 88 | self.video_key = video_key 89 | self.caption_key = caption_key 90 | self.latents_key = latents_key 91 | self.prompt_embeds_key = prompt_embeds_key 92 | self.latents_cond_key = latents_cond_key 93 | self.prompt_attention_mask_key = prompt_attention_mask_key 94 | self.pooled_prompt_embeds_key = pooled_prompt_embeds_key 95 | self.latent_dtype = latent_dtype 96 | 97 | def __getitem__(self, index): 98 | sample = super().__getitem__(index) 99 | 100 | out = {} 101 | 102 | latents = torch.from_numpy(sample[self.latents_key].copy()).to(dtype=self.latent_dtype) 103 | latents_cond = torch.from_numpy(sample[self.latents_cond_key].copy()).to(dtype=self.latent_dtype) 104 | 105 | prompt_embeds = torch.from_numpy(sample[self.prompt_embeds_key].copy()).to(dtype=self.latent_dtype) 106 | pooled_prompt_embeds = torch.from_numpy(sample[self.pooled_prompt_embeds_key].copy()).to( 107 | dtype=self.latent_dtype 108 | ) 109 | prompt_attention_mask = torch.from_numpy(sample[self.prompt_attention_mask_key].copy()).to(dtype=torch.bool) 110 | 111 | out[self.latents_key] = latents 112 | out[self.latents_cond_key] = latents_cond 113 | out[self.prompt_embeds_key] = prompt_embeds 114 | out[self.pooled_prompt_embeds_key] = pooled_prompt_embeds 115 | out[self.prompt_attention_mask_key] = prompt_attention_mask 116 | 117 | return out 118 | 119 | 120 | def build_mds_dataloader( 121 | remote: Union[str, List], 122 | local: Union[str, List], 123 | batch_size: int, 124 | video_key: str = "video", 125 | caption_key: str = "caption", 126 | latents_key: str = "latents", 127 | latents_cond_key: str = "latents_cond", 128 | prompt_embeds_key: str = "prompt_embeds", 129 | prompt_attention_mask_key: str = "prompt_attention_mask", 130 | pooled_prompt_embeds_key: str = "pooled_prompt_embeds", 131 | latent_dtype: torch.dtype = torch.bfloat16, 132 | proportion: Optional[list] = None, 133 | repeat: Optional[list] = None, 134 | choose: Optional[list] = None, 135 | streaming_kwargs: Optional[Dict] = None, 136 | dataloader_kwargs: Optional[Dict] = None, 137 | ): 138 | if streaming_kwargs is None: 139 | streaming_kwargs = {} 140 | if dataloader_kwargs is None: 141 | dataloader_kwargs = {} 142 | 143 | # streams = make_streams(remote, local=local, proportion=proportion, repeat=repeat, choose=choose) 144 | if isinstance(local, str): 145 | local = [local] 146 | streams = [Stream(local=l) for l in local] 147 | 148 | dataset = StreamingVideoCaptionLatentsDataset( 149 | streams=streams, 150 | video_key=video_key, 151 | caption_key=caption_key, 152 | latents_key=latents_key, 153 | latents_cond_key=latents_cond_key, 154 | prompt_embeds_key=prompt_embeds_key, 155 | prompt_attention_mask_key=prompt_attention_mask_key, 156 | pooled_prompt_embeds_key=pooled_prompt_embeds_key, 157 | latent_dtype=latent_dtype, 158 | batch_size=batch_size, 159 | **streaming_kwargs, 160 | ) 161 | 162 | dataloader = StreamingDataLoader( 163 | dataset=dataset, 164 | batch_size=batch_size, 165 | sampler=None, 166 | **dataloader_kwargs, 167 | ) 168 | 169 | return dataloader 170 | -------------------------------------------------------------------------------- /notebooks/000-inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%env CUDA_VISIBLE_DEVICES=0\n", 10 | "%env HF_HUB_ENABLE_HF_TRANSFER=1" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import os\n", 20 | "import sys\n", 21 | "import time\n", 22 | "import uuid\n", 23 | "from datetime import datetime\n", 24 | "from typing import *\n", 25 | "\n", 26 | "import av\n", 27 | "import cv2\n", 28 | "import diffusers\n", 29 | "import numpy as np\n", 30 | "import safetensors.torch\n", 31 | "import torch\n", 32 | "import torch.nn.functional as F\n", 33 | "import torchvision.transforms.v2 as transforms\n", 34 | "from diffusers import FlowMatchEulerDiscreteScheduler, HunyuanVideoPipeline\n", 35 | "from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\n", 36 | "from diffusers.loaders import HunyuanVideoLoraLoaderMixin\n", 37 | "from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel\n", 38 | "from diffusers.models.attention import Attention\n", 39 | "from diffusers.models.embeddings import apply_rotary_emb\n", 40 | "from diffusers.models.transformers.transformer_hunyuan_video import *\n", 41 | "from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoPatchEmbed, HunyuanVideoTransformer3DModel\n", 42 | "from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE, retrieve_timesteps\n", 43 | "from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput\n", 44 | "from diffusers.pipelines.pipeline_utils import DiffusionPipeline\n", 45 | "from diffusers.schedulers import FlowMatchEulerDiscreteScheduler\n", 46 | "from diffusers.utils import export_to_video, is_torch_xla_available, logging, replace_example_docstring\n", 47 | "from diffusers.utils.state_dict_utils import convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft\n", 48 | "from diffusers.utils.torch_utils import randn_tensor\n", 49 | "from diffusers.video_processor import VideoProcessor\n", 50 | "from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict\n", 51 | "from PIL import Image" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "print(\"\\nENVIRONMENT\\n\")\n", 61 | "print(f\" Python {sys.version}\\n\")\n", 62 | "print(f\" torch.__version__ = {torch.__version__}\")\n", 63 | "print(f\" torch.version.cuda = {torch.version.cuda}\")\n", 64 | "print(f\" torch.backends.cudnn.version() = {torch.backends.cudnn.version()}\")\n", 65 | "print(f\" diffusers.__version__ = {diffusers.__version__}\")" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "try:\n", 75 | " import flash_attn\n", 76 | " from flash_attn.flash_attn_interface import _flash_attn_forward, flash_attn_varlen_func\n", 77 | "except ImportError:\n", 78 | " flash_attn, _flash_attn_forward, flash_attn_varlen_func = None, None, None\n", 79 | "\n", 80 | "try:\n", 81 | " from sageattention import sageattn, sageattn_varlen\n", 82 | "except ImportError:\n", 83 | " sageattn, sageattn_varlen = None, None\n", 84 | "\n", 85 | "\n", 86 | "def get_cu_seqlens(attention_mask):\n", 87 | " \"\"\"Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len\n", 88 | "\n", 89 | " Args:\n", 90 | " text_mask (torch.Tensor): the mask of text\n", 91 | " img_len (int): the length of image\n", 92 | "\n", 93 | " Returns:\n", 94 | " torch.Tensor: the calculated cu_seqlens for flash attention\n", 95 | " \"\"\"\n", 96 | " batch_size = attention_mask.shape[0]\n", 97 | " text_len = attention_mask.sum(dim=-1, dtype=torch.int) # .flatten()\n", 98 | " max_len = attention_mask.shape[-1]\n", 99 | "\n", 100 | " cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=\"cuda\")\n", 101 | "\n", 102 | " for i in range(batch_size):\n", 103 | " s = text_len[i]\n", 104 | " s1 = i * max_len + s\n", 105 | " s2 = (i + 1) * max_len\n", 106 | " cu_seqlens[2 * i + 1] = s1\n", 107 | " cu_seqlens[2 * i + 2] = s2\n", 108 | "\n", 109 | " return cu_seqlens\n", 110 | "\n", 111 | "\n", 112 | "class HunyuanVideoFlashAttnProcessor:\n", 113 | " def __init__(self, use_flash_attn=True, use_sageattn=False):\n", 114 | " self.use_flash_attn = use_flash_attn\n", 115 | " self.use_sageattn = use_sageattn\n", 116 | " if self.use_flash_attn:\n", 117 | " assert flash_attn is not None\n", 118 | " if self.use_sageattn:\n", 119 | " assert sageattn is not None\n", 120 | "\n", 121 | " def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):\n", 122 | " if attn.add_q_proj is None and encoder_hidden_states is not None:\n", 123 | " hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)\n", 124 | "\n", 125 | " query = attn.to_q(hidden_states)\n", 126 | " key = attn.to_k(hidden_states)\n", 127 | " value = attn.to_v(hidden_states)\n", 128 | "\n", 129 | " query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)\n", 130 | " key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)\n", 131 | " value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)\n", 132 | "\n", 133 | " if attn.norm_q is not None:\n", 134 | " query = attn.norm_q(query)\n", 135 | " if attn.norm_k is not None:\n", 136 | " key = attn.norm_k(key)\n", 137 | "\n", 138 | " if image_rotary_emb is not None:\n", 139 | " if attn.add_q_proj is None and encoder_hidden_states is not None:\n", 140 | " query = torch.cat([apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), query[:, :, -encoder_hidden_states.shape[1] :]], dim=2)\n", 141 | " key = torch.cat([apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), key[:, :, -encoder_hidden_states.shape[1] :]], dim=2)\n", 142 | " else:\n", 143 | " query = apply_rotary_emb(query, image_rotary_emb)\n", 144 | " key = apply_rotary_emb(key, image_rotary_emb)\n", 145 | "\n", 146 | " batch_size = hidden_states.shape[0]\n", 147 | " img_seq_len = hidden_states.shape[1]\n", 148 | " txt_seq_len = 0\n", 149 | "\n", 150 | " if attn.add_q_proj is not None and encoder_hidden_states is not None:\n", 151 | " encoder_query = attn.add_q_proj(encoder_hidden_states)\n", 152 | " encoder_key = attn.add_k_proj(encoder_hidden_states)\n", 153 | " encoder_value = attn.add_v_proj(encoder_hidden_states)\n", 154 | "\n", 155 | " encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)\n", 156 | " encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)\n", 157 | " encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)\n", 158 | "\n", 159 | " if attn.norm_added_q is not None:\n", 160 | " encoder_query = attn.norm_added_q(encoder_query)\n", 161 | " if attn.norm_added_k is not None:\n", 162 | " encoder_key = attn.norm_added_k(encoder_key)\n", 163 | "\n", 164 | " query = torch.cat([query, encoder_query], dim=2)\n", 165 | " key = torch.cat([key, encoder_key], dim=2)\n", 166 | " value = torch.cat([value, encoder_value], dim=2)\n", 167 | "\n", 168 | " txt_seq_len = encoder_hidden_states.shape[1]\n", 169 | "\n", 170 | " max_seqlen_q = max_seqlen_kv = img_seq_len + txt_seq_len\n", 171 | " cu_seqlens_q = cu_seqlens_kv = get_cu_seqlens(attention_mask)\n", 172 | "\n", 173 | " query = query.transpose(1, 2).reshape(-1, query.shape[1], query.shape[3])\n", 174 | " key = key.transpose(1, 2).reshape(-1, key.shape[1], key.shape[3])\n", 175 | " value = value.transpose(1, 2).reshape(-1, value.shape[1], value.shape[3])\n", 176 | "\n", 177 | " if self.use_flash_attn:\n", 178 | " hidden_states = flash_attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)\n", 179 | " elif self.use_sageattn:\n", 180 | " hidden_states = sageattn_varlen(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)\n", 181 | " else:\n", 182 | " raise NotImplementedError(\"Please set use_flash_attn=True or use_sageattn=True\")\n", 183 | "\n", 184 | " hidden_states = hidden_states.reshape(batch_size, max_seqlen_q, -1)\n", 185 | " hidden_states = hidden_states.to(query.dtype)\n", 186 | "\n", 187 | " if encoder_hidden_states is not None:\n", 188 | " hidden_states, encoder_hidden_states = (hidden_states[:, : -encoder_hidden_states.shape[1]], hidden_states[:, -encoder_hidden_states.shape[1] :])\n", 189 | "\n", 190 | " if getattr(attn, \"to_out\", None) is not None:\n", 191 | " hidden_states = attn.to_out[0](hidden_states)\n", 192 | " hidden_states = attn.to_out[1](hidden_states)\n", 193 | "\n", 194 | " if getattr(attn, \"to_add_out\", None) is not None:\n", 195 | " encoder_hidden_states = attn.to_add_out(encoder_hidden_states)\n", 196 | "\n", 197 | " return hidden_states, encoder_hidden_states" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "def save_video(video, output_path=\"output.mp4\"):\n", 207 | " width, height = video[0].size\n", 208 | "\n", 209 | " container = av.open(output_path, mode=\"w\")\n", 210 | "\n", 211 | " # create video stream\n", 212 | " codec = \"libx264\"\n", 213 | " pixel_format = \"yuv420p\"\n", 214 | " stream = container.add_stream(codec, rate=24)\n", 215 | " stream.width = width\n", 216 | " stream.height = height\n", 217 | " stream.pix_fmt = pixel_format\n", 218 | " stream.bit_rate = 4000000 # 4Mbit/s\n", 219 | "\n", 220 | " for frame_array in video:\n", 221 | " frame = av.VideoFrame.from_image(frame_array)\n", 222 | " packets = stream.encode(frame)\n", 223 | " for packet in packets:\n", 224 | " container.mux(packet)\n", 225 | "\n", 226 | " for packet in stream.encode():\n", 227 | " container.mux(packet)\n", 228 | "\n", 229 | " container.close()" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "pipe = HunyuanVideoPipeline.from_pretrained(\"hunyuanvideo-community/HunyuanVideo\", torch_dtype=torch.bfloat16)\n", 239 | "pipe.to(\"cuda\")\n", 240 | "pipe.vae.enable_slicing()\n", 241 | "pipe.vae.enable_tiling()\n", 242 | "\n", 243 | "for block in pipe.transformer.transformer_blocks + pipe.transformer.single_transformer_blocks:\n", 244 | " block.attn.processor = HunyuanVideoFlashAttnProcessor(use_flash_attn=True, use_sageattn=False)" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "with torch.no_grad():\n", 254 | " initial_input_channels = pipe.transformer.config.in_channels\n", 255 | " new_img_in = HunyuanVideoPatchEmbed(\n", 256 | " patch_size=(pipe.transformer.config.patch_size_t, pipe.transformer.config.patch_size, pipe.transformer.config.patch_size),\n", 257 | " in_chans=pipe.transformer.config.in_channels * 2,\n", 258 | " embed_dim=pipe.transformer.config.num_attention_heads * pipe.transformer.config.attention_head_dim,\n", 259 | " )\n", 260 | " new_img_in = new_img_in.to(pipe.device, dtype=pipe.dtype)\n", 261 | " new_img_in.proj.weight.zero_()\n", 262 | " new_img_in.proj.weight[:, :initial_input_channels].copy_(pipe.transformer.x_embedder.proj.weight)\n", 263 | "\n", 264 | " if pipe.transformer.x_embedder.proj.bias is not None:\n", 265 | " new_img_in.proj.bias.copy_(pipe.transformer.x_embedder.proj.bias)\n", 266 | "\n", 267 | " pipe.transformer.x_embedder = new_img_in\n", 268 | " pipe.transformer.x_embedder.requires_grad_(False)" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": null, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "LORA_PATH = \"/mnt/data/ayushman/projects/output/i2v_outputs/exp006-resume-from-9500/i2v.sft\"\n", 278 | "\n", 279 | "# ------ load lora\n", 280 | "lora_state_dict = pipe.lora_state_dict(LORA_PATH)\n", 281 | "transformer_lora_state_dict = {f'{k.replace(\"transformer.\", \"\")}': v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\") and \"lora\" in k}\n", 282 | "pipe.load_lora_into_transformer(transformer_lora_state_dict, transformer=pipe.transformer, adapter_name=\"i2v\", _pipeline=pipe)\n", 283 | "pipe.set_adapters([\"i2v\"], adapter_weights=[1.0])\n", 284 | "pipe.fuse_lora(components=[\"transformer\"], lora_scale=1.0, adapter_names=[\"i2v\"])\n", 285 | "pipe.unload_lora_weights()\n", 286 | "\n", 287 | "# -------- load norm layers\n", 288 | "NORM_LAYER_PREFIXES = [\"norm_q\", \"norm_k\", \"norm_added_q\", \"norm_added_k\"]\n", 289 | "transformer_norm_layers_state_dict = {\n", 290 | " f'{k.replace(\"transformer.\", \"\")}': v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\") and any(norm_k in k for norm_k in NORM_LAYER_PREFIXES)\n", 291 | "}\n", 292 | "if len(transformer_norm_layers_state_dict) == 0:\n", 293 | " print(\"[INFO] No normalization layers found in state dict. Skipping loading normalization layers....\")\n", 294 | "else:\n", 295 | " print(\"[INFO] Extracting normalization layers from state dict and loading them into the transformer....\")\n", 296 | " for key in list(transformer_norm_layers_state_dict.keys()):\n", 297 | " if key.split(\".\")[0] == \"transformer\":\n", 298 | " transformer_norm_layers_state_dict[key[len(f\"transformer.\") :]] = transformer_norm_layers_state_dict.pop(key)\n", 299 | " transformer_state_dict = pipe.transformer.state_dict()\n", 300 | " transformer_keys = set(transformer_state_dict.keys())\n", 301 | " state_dict_keys = set(transformer_norm_layers_state_dict.keys())\n", 302 | " extra_keys = list(state_dict_keys - transformer_keys)\n", 303 | " if extra_keys:\n", 304 | " print(f\"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\\n{extra_keys}.\")\n", 305 | " for key in extra_keys:\n", 306 | " transformer_norm_layers_state_dict.pop(key)\n", 307 | " incompatible_keys = pipe.transformer.load_state_dict(transformer_norm_layers_state_dict, strict=False)\n", 308 | " unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n", 309 | " if any(norm_key in k for k in unexpected_keys for norm_key in NORM_LAYER_PREFIXES):\n", 310 | " print(f\"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer.\")" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": null, 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [ 319 | "video_transforms = transforms.Compose(\n", 320 | " [\n", 321 | " transforms.Lambda(lambda x: x / 255.0),\n", 322 | " transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),\n", 323 | " ]\n", 324 | ")\n", 325 | "\n", 326 | "\n", 327 | "def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:\n", 328 | " \"\"\"\n", 329 | " Resize the image to the bucket resolution.\n", 330 | " \"\"\"\n", 331 | " is_pil_image = isinstance(image, Image.Image)\n", 332 | " if is_pil_image:\n", 333 | " image_width, image_height = image.size\n", 334 | " else:\n", 335 | " image_height, image_width = image.shape[:2]\n", 336 | "\n", 337 | " if bucket_reso == (image_width, image_height):\n", 338 | " return np.array(image) if is_pil_image else image\n", 339 | "\n", 340 | " bucket_width, bucket_height = bucket_reso\n", 341 | "\n", 342 | " scale_width = bucket_width / image_width\n", 343 | " scale_height = bucket_height / image_height\n", 344 | " scale = max(scale_width, scale_height)\n", 345 | " image_width = int(image_width * scale + 0.5)\n", 346 | " image_height = int(image_height * scale + 0.5)\n", 347 | "\n", 348 | " if scale > 1:\n", 349 | " image = Image.fromarray(image) if not is_pil_image else image\n", 350 | " image = image.resize((image_width, image_height), Image.LANCZOS)\n", 351 | " image = np.array(image)\n", 352 | " else:\n", 353 | " image = np.array(image) if is_pil_image else image\n", 354 | " image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)\n", 355 | "\n", 356 | " # crop the image to the bucket resolution\n", 357 | " crop_left = (image_width - bucket_width) // 2\n", 358 | " crop_top = (image_height - bucket_height) // 2\n", 359 | " image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]\n", 360 | "\n", 361 | " return image" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": null, 367 | "metadata": {}, 368 | "outputs": [], 369 | "source": [ 370 | "# n_frames, height, width = 77, 960, 544\n", 371 | "# prompt = 'a woman moving her head'\n", 372 | "# p1 = \"/mnt/data/ayushman/projects/input/img_samples/dvqdv.png\"\n", 373 | "# p2 = \"/mnt/data/ayushman/projects/input/img_samples/F11024.png\"\n", 374 | "\n", 375 | "n_frames, height, width = 77, 720, 1280\n", 376 | "prompt = \"A man and a woman sit at a cozy café table in the warm midday light. She wears a vibrant green dress, and he wears a simple black T-shirt. Their table is adorned with fresh salads, small plates of appetizers, and a sparkling beverage in elegant glassware. They chat casually, occasionally glancing at their phones or laughing together. Large windows in the background reveal a bustling street, where people pass by under soft lanterns. The camera begins with a medium-wide shot, capturing the relaxed ambiance and lively chatter of nearby patrons. It then transitions to a closer view, highlighting the woman's bright smile and the man's attentive expression. Their voices blend with the gentle hum of the café, as they enjoy a pleasant lunch and each other's company.\"\n", 377 | "p1 = \"/mnt/data/ayushman/projects/input/img_samples/009-F1.png\"\n", 378 | "p2 = \"/mnt/data/ayushman/projects/input/img_samples/009-F2.png\"\n", 379 | "\n", 380 | "\n", 381 | "# n_frames, height, width = 73, 960, 544\n", 382 | "# prompt = 'a man'\n", 383 | "# p1 = \"/mnt/data/ayushman/projects/input/img_samples/3.3.PNG\"\n", 384 | "# p2 = \"/mnt/data/ayushman/projects/input/img_samples/4.1 (1).PNG\"\n", 385 | "\n", 386 | "\n", 387 | "# n_frames, height, width = 73, 544, 960\n", 388 | "# prompt = 'realistic style, a man walking out of a cave'\n", 389 | "# p1 = \"/mnt/data/ayushman/projects/input/img_samples/3015505-hd_1920_1080_24fps-Scene-001-01.jpg\"\n", 390 | "# p2 = \"/mnt/data/ayushman/projects/input/img_samples/3015505-hd_1920_1080_24fps-Scene-001-02.jpg\"\n", 391 | "\n", 392 | "# n_frames, height, width = 73, 544, 960\n", 393 | "# prompt = 'a whimsical milk cartoon dancing'\n", 394 | "# p2 = \"/mnt/data/ayushman/projects/input/img_samples/158387-816637360_small-Scene-001-03.jpg\"\n", 395 | "# p1 = \"/mnt/data/ayushman/projects/input/img_samples/158387-816637360_small-Scene-001-01.jpg\"\n", 396 | "\n", 397 | "# n_frames, height, width = 73, 544, 960\n", 398 | "# prompt = 'a whimsical milk cartoon dancing'\n", 399 | "# p2 = \"/mnt/data/ayushman/projects/input/img_samples/158387-816637360_small-Scene-001-03.jpg\"\n", 400 | "# p1 = \"/mnt/data/ayushman/projects/input/img_samples/158387-816637360_small-Scene-001-01.jpg\"\n", 401 | "\n", 402 | "# n_frames, height, width = 73, 960, 544\n", 403 | "# prompt = \"a woman\"\n", 404 | "# p1 = \"/mnt/data/ayushman/projects/input/img_samples/F14866.png\"\n", 405 | "# p2 = \"/mnt/data/ayushman/projects/input/img_samples/F14920.png\"\n", 406 | "\n", 407 | "assert os.path.exists(p1)\n", 408 | "assert os.path.exists(p2)\n", 409 | "\n", 410 | "cond_frame1 = Image.open(p1).convert(\"RGB\")\n", 411 | "cond_frame2 = Image.open(p2).convert(\"RGB\")\n", 412 | "# show_images([cond_frame1, cond_frame2], titles=[\"cond_frame1\", \"cond_frame2\"], imsize=4)\n", 413 | "# plt.show()" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "metadata": {}, 420 | "outputs": [], 421 | "source": [ 422 | "cond_frame1 = resize_image_to_bucket(cond_frame1, bucket_reso=(width, height))\n", 423 | "cond_frame2 = resize_image_to_bucket(cond_frame2, bucket_reso=(width, height))\n", 424 | "\n", 425 | "cond_video = np.zeros(shape=(n_frames, height, width, 3))\n", 426 | "cond_video[0], cond_video[-1] = np.array(cond_frame1), np.array(cond_frame2)\n", 427 | "\n", 428 | "cond_video = torch.from_numpy(cond_video.copy()).permute(0, 3, 1, 2)\n", 429 | "cond_video = torch.stack([video_transforms(x) for x in cond_video], dim=0).unsqueeze(0)\n", 430 | "\n", 431 | "with torch.inference_mode():\n", 432 | " image_or_video = cond_video.to(device=\"cuda\", dtype=pipe.dtype)\n", 433 | " image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]\n", 434 | " cond_latents = pipe.vae.encode(image_or_video).latent_dist.sample()\n", 435 | " cond_latents = cond_latents * pipe.vae.config.scaling_factor\n", 436 | " cond_latents = cond_latents.to(dtype=pipe.dtype)\n", 437 | " assert not torch.any(torch.isnan(cond_latents))" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": null, 443 | "metadata": {}, 444 | "outputs": [], 445 | "source": [ 446 | "@torch.inference_mode()\n", 447 | "def call_pipe(\n", 448 | " pipe,\n", 449 | " prompt: Union[str, List[str]] = None,\n", 450 | " prompt_2: Union[str, List[str]] = None,\n", 451 | " height: int = 720,\n", 452 | " width: int = 1280,\n", 453 | " num_frames: int = 129,\n", 454 | " num_inference_steps: int = 50,\n", 455 | " sigmas: List[float] = None,\n", 456 | " guidance_scale: float = 6.0,\n", 457 | " num_videos_per_prompt: Optional[int] = 1,\n", 458 | " generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n", 459 | " latents: Optional[torch.Tensor] = None,\n", 460 | " prompt_embeds: Optional[torch.Tensor] = None,\n", 461 | " pooled_prompt_embeds: Optional[torch.Tensor] = None,\n", 462 | " prompt_attention_mask: Optional[torch.Tensor] = None,\n", 463 | " output_type: Optional[str] = \"pil\",\n", 464 | " return_dict: bool = True,\n", 465 | " attention_kwargs: Optional[Dict[str, Any]] = None,\n", 466 | " callback_on_step_end: Optional[Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]] = None,\n", 467 | " callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n", 468 | " prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,\n", 469 | " max_sequence_length: int = 256,\n", 470 | " image_latents: Optional[torch.Tensor] = None,\n", 471 | "):\n", 472 | "\n", 473 | " if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n", 474 | " callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n", 475 | "\n", 476 | " # 1. Check inputs. Raise error if not correct\n", 477 | " pipe.check_inputs(\n", 478 | " prompt,\n", 479 | " prompt_2,\n", 480 | " height,\n", 481 | " width,\n", 482 | " prompt_embeds,\n", 483 | " callback_on_step_end_tensor_inputs,\n", 484 | " prompt_template,\n", 485 | " )\n", 486 | "\n", 487 | " pipe._guidance_scale = guidance_scale\n", 488 | " pipe._attention_kwargs = attention_kwargs\n", 489 | " pipe._current_timestep = None\n", 490 | " pipe._interrupt = False\n", 491 | "\n", 492 | " device = pipe._execution_device\n", 493 | "\n", 494 | " # 2. Define call parameters\n", 495 | " if prompt is not None and isinstance(prompt, str):\n", 496 | " batch_size = 1\n", 497 | " elif prompt is not None and isinstance(prompt, list):\n", 498 | " batch_size = len(prompt)\n", 499 | " else:\n", 500 | " batch_size = prompt_embeds.shape[0]\n", 501 | "\n", 502 | " # 3. Encode input prompt\n", 503 | " prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt(\n", 504 | " prompt=prompt,\n", 505 | " prompt_2=prompt_2,\n", 506 | " prompt_template=prompt_template,\n", 507 | " num_videos_per_prompt=num_videos_per_prompt,\n", 508 | " prompt_embeds=prompt_embeds,\n", 509 | " pooled_prompt_embeds=pooled_prompt_embeds,\n", 510 | " prompt_attention_mask=prompt_attention_mask,\n", 511 | " device=device,\n", 512 | " max_sequence_length=max_sequence_length,\n", 513 | " )\n", 514 | "\n", 515 | " transformer_dtype = pipe.transformer.dtype\n", 516 | " prompt_embeds = prompt_embeds.to(transformer_dtype)\n", 517 | " prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)\n", 518 | " if pooled_prompt_embeds is not None:\n", 519 | " pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)\n", 520 | "\n", 521 | " # 4. Prepare timesteps\n", 522 | " sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas\n", 523 | " timesteps, num_inference_steps = retrieve_timesteps(\n", 524 | " pipe.scheduler,\n", 525 | " num_inference_steps,\n", 526 | " device,\n", 527 | " sigmas=sigmas,\n", 528 | " )\n", 529 | "\n", 530 | " # 5. Prepare latent variables\n", 531 | " num_channels_latents = pipe.transformer.config.in_channels\n", 532 | " num_latent_frames = (num_frames - 1) // pipe.vae_scale_factor_temporal + 1\n", 533 | " latents = pipe.prepare_latents(\n", 534 | " batch_size * num_videos_per_prompt,\n", 535 | " num_channels_latents,\n", 536 | " height,\n", 537 | " width,\n", 538 | " num_latent_frames,\n", 539 | " torch.float32,\n", 540 | " device,\n", 541 | " generator,\n", 542 | " latents,\n", 543 | " )\n", 544 | "\n", 545 | " # 6. Prepare guidance condition\n", 546 | " guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0\n", 547 | "\n", 548 | " # 7. Denoising loop\n", 549 | " num_warmup_steps = len(timesteps) - num_inference_steps * pipe.scheduler.order\n", 550 | " pipe._num_timesteps = len(timesteps)\n", 551 | "\n", 552 | " with pipe.progress_bar(total=num_inference_steps) as progress_bar:\n", 553 | " for i, t in enumerate(timesteps):\n", 554 | " if pipe.interrupt:\n", 555 | " continue\n", 556 | "\n", 557 | " pipe._current_timestep = t\n", 558 | " latent_model_input = latents.to(transformer_dtype)\n", 559 | " timestep = t.expand(latents.shape[0]).to(latents.dtype)\n", 560 | "\n", 561 | " noise_pred = pipe.transformer(\n", 562 | " hidden_states=torch.cat([latent_model_input, image_latents], dim=1),\n", 563 | " timestep=timestep,\n", 564 | " encoder_hidden_states=prompt_embeds,\n", 565 | " encoder_attention_mask=prompt_attention_mask,\n", 566 | " pooled_projections=pooled_prompt_embeds,\n", 567 | " guidance=guidance,\n", 568 | " attention_kwargs=attention_kwargs,\n", 569 | " return_dict=False,\n", 570 | " )[0]\n", 571 | "\n", 572 | " # compute the previous noisy sample x_t -> x_t-1\n", 573 | " latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n", 574 | "\n", 575 | " if callback_on_step_end is not None:\n", 576 | " callback_kwargs = {}\n", 577 | " for k in callback_on_step_end_tensor_inputs:\n", 578 | " callback_kwargs[k] = locals()[k]\n", 579 | " callback_outputs = callback_on_step_end(pipe, i, t, callback_kwargs)\n", 580 | "\n", 581 | " latents = callback_outputs.pop(\"latents\", latents)\n", 582 | " prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n", 583 | "\n", 584 | " # call the callback, if provided\n", 585 | " if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0):\n", 586 | " progress_bar.update()\n", 587 | " pipe._current_timestep = None\n", 588 | "\n", 589 | " if not output_type == \"latent\":\n", 590 | " latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor\n", 591 | " video = pipe.vae.decode(latents, return_dict=False)[0]\n", 592 | " video = pipe.video_processor.postprocess_video(video, output_type=output_type)\n", 593 | " else:\n", 594 | " video = latents\n", 595 | "\n", 596 | " # Offload all models\n", 597 | " pipe.maybe_free_model_hooks()\n", 598 | "\n", 599 | " if not return_dict:\n", 600 | " return (video,)\n", 601 | "\n", 602 | " return HunyuanVideoPipelineOutput(frames=video)" 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": null, 608 | "metadata": {}, 609 | "outputs": [], 610 | "source": [ 611 | "video = call_pipe(\n", 612 | " pipe,\n", 613 | " prompt=prompt,\n", 614 | " num_frames=n_frames,\n", 615 | " num_inference_steps=50,\n", 616 | " image_latents=cond_latents,\n", 617 | " width=width,\n", 618 | " height=height,\n", 619 | " guidance_scale=6.0,\n", 620 | " generator=torch.Generator(device=\"cuda\").manual_seed(0),\n", 621 | ").frames[0]" 622 | ] 623 | }, 624 | { 625 | "cell_type": "code", 626 | "execution_count": null, 627 | "metadata": {}, 628 | "outputs": [], 629 | "source": [ 630 | "time_flag = datetime.fromtimestamp(time.time()).strftime(\"%Y%m%d_%H%M%S\")\n", 631 | "fp = f\"/mnt/data/ayushman/projects/output/samples/hv-CL-{height}x{width}x{n_frames}-{time_flag}.mp4\"\n", 632 | "print(fp)\n", 633 | "save_video(video, fp)" 634 | ] 635 | }, 636 | { 637 | "cell_type": "code", 638 | "execution_count": null, 639 | "metadata": {}, 640 | "outputs": [], 641 | "source": [] 642 | } 643 | ], 644 | "metadata": { 645 | "kernelspec": { 646 | "display_name": "hunyuan_control_env", 647 | "language": "python", 648 | "name": "python3" 649 | }, 650 | "language_info": { 651 | "codemirror_mode": { 652 | "name": "ipython", 653 | "version": 3 654 | }, 655 | "file_extension": ".py", 656 | "mimetype": "text/x-python", 657 | "name": "python", 658 | "nbconvert_exporter": "python", 659 | "pygments_lexer": "ipython3", 660 | "version": "3.10.16" 661 | } 662 | }, 663 | "nbformat": 4, 664 | "nbformat_minor": 2 665 | } 666 | -------------------------------------------------------------------------------- /optim.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Any, Dict 3 | 4 | import torch 5 | from accelerate.logging import get_logger 6 | 7 | try: 8 | from optimi import AdamW as OptimiAdamW 9 | from optimi import StableAdamW as OptimiStableAdamW 10 | except ImportError: 11 | OptimiAdamW, OptimiStableAdamW = None, None 12 | 13 | try: 14 | from bitsandbytes.optim import AdamW8bit, Lion8bit 15 | except ImportError: 16 | AdamW8bit, Lion8bit = None, None 17 | 18 | try: 19 | from came_pytorch import CAME 20 | except ImportError: 21 | CAME = None 22 | 23 | import ast 24 | 25 | logger = get_logger(__name__) 26 | 27 | 28 | OPTIMIZER_FUNC_TO_NAME = { 29 | "adam": torch.optim.Adam, 30 | "adamw": torch.optim.AdamW, 31 | "optimi-adamw": OptimiAdamW, 32 | "optimi-stableadamw": OptimiStableAdamW, 33 | "bnb-adamw8bit": AdamW8bit, 34 | "bnb-lion8bit": Lion8bit, 35 | "came": CAME, 36 | } 37 | 38 | 39 | def get_optimizer( 40 | params_to_optimize, 41 | optimizer_name: str = "adam", 42 | learning_rate: float = 1e-3, 43 | optimizer_args_str: str | None = None, 44 | use_deepspeed: bool = False, 45 | # use_cpu_offload_optimizer: bool = False, 46 | # offload_gradients: bool = False, 47 | ) -> torch.optim.Optimizer: 48 | optimizer_kwargs = {} 49 | 50 | if optimizer_args_str is not None and len(optimizer_args_str) > 0: 51 | for arg in optimizer_args_str: 52 | key, value = arg.split("=") 53 | value = ast.literal_eval(value) 54 | optimizer_kwargs[key] = value 55 | 56 | optimizer_name = optimizer_name.lower() 57 | 58 | if use_deepspeed: 59 | from accelerate.utils import DummyOptim 60 | 61 | return DummyOptim(params_to_optimize, lr=learning_rate, **optimizer_kwargs) 62 | 63 | assert optimizer_name in OPTIMIZER_FUNC_TO_NAME, f"Unknown optimizer: {optimizer_name!r}" 64 | 65 | optimizer_class = OPTIMIZER_FUNC_TO_NAME[optimizer_name] 66 | assert optimizer_class is not None 67 | 68 | optimizer = optimizer_class(params_to_optimize, lr=learning_rate, **optimizer_kwargs) 69 | 70 | logger.info(f"Use {optimizer.__class__.__name__!r} | {optimizer_kwargs!r}") 71 | return optimizer 72 | 73 | 74 | def gradient_norm(parameters): 75 | norm = 0 76 | for param in parameters: 77 | if param.grad is None: 78 | continue 79 | local_norm = param.grad.detach().data.norm(2) 80 | norm += local_norm.item() ** 2 81 | norm = norm**0.5 82 | return norm 83 | 84 | 85 | def max_gradient(parameters): 86 | max_grad_value = float("-inf") 87 | for param in parameters: 88 | if param.grad is None: 89 | continue 90 | local_max_grad = param.grad.detach().data.abs().max() 91 | max_grad_value = max(max_grad_value, local_max_grad.item()) 92 | return max_grad_value 93 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.isort] 2 | profile = "black" 3 | line_length = 120 4 | force_grid_wrap = 0 5 | use_parentheses = true 6 | ensure_newline_before_comments = true 7 | 8 | [tool.black] 9 | line-length = 120 10 | exclude = "(.eggs|.git|.hg|.mypy_cache|.venv|_build|buck-out|build|dist|_notebooks/.*)" 11 | -------------------------------------------------------------------------------- /setup_env.sh: -------------------------------------------------------------------------------- 1 | #!/user/bin/env bash 2 | conda create -n hunyuan_control_env python=3.10 -y && conda activate hunyuan_control_env 3 | conda install -c nvidia/label/cuda-12.4.0 cuda-toolkit cuda -y 4 | 5 | export CUDA_HOME=$(dirname $(dirname $(which nvcc))) 6 | export PATH=$CUDA_HOME/bin:$PATH 7 | export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH 8 | 9 | python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 10 | python -m pip install ninja 11 | python -m pip install --verbose --upgrade git+https://github.com/Dao-AILab/flash-attention.git@v2.6.3 12 | 13 | wget "https://raw.githubusercontent.com/kohya-ss/musubi-tuner/refs/heads/main/requirements.txt" -O requirements.txt 14 | python -m pip install -r requirements.txt && rm requirements.txt 15 | 16 | python -m pip install accelerate==1.2.1 transformers==4.46.3 bitsandbytes==0.45.2 decord==0.6.0 deepspeed==0.16.3 opencv-python==4.10.0.84 pandas==2.2.3 peft==0.14.0 mosaicml-streaming==0.11.0 pyrallis==0.3.1 torch-optimi==0.2.1 17 | python -m pip install huggingface-hub hf_transfer 18 | python -m pip install --upgrade git+https://github.com/huggingface/diffusers@81440fd47493b9f9e817411ca0499d0bf06fde95 19 | python -m pip install icecream pre-commit 20 | -------------------------------------------------------------------------------- /tools/hv_cache_dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | import multiprocessing as mp 3 | import os 4 | import pprint 5 | import traceback 6 | from argparse import ArgumentParser 7 | from ast import literal_eval 8 | from pathlib import Path 9 | from typing import Tuple, Union 10 | 11 | import cv2 12 | import decord 13 | import numpy as np 14 | import pandas as pd 15 | from PIL import Image 16 | from streaming import MDSWriter 17 | from tqdm.auto import tqdm 18 | 19 | decord.bridge.set_bridge("native") 20 | 21 | 22 | def parse_args(): 23 | parser = ArgumentParser() 24 | parser.add_argument("--csv", required=True, type=str) 25 | parser.add_argument("--base_dir", default=None, type=str) 26 | parser.add_argument("--video_column", default="video", type=str) 27 | parser.add_argument("--caption_column", default="caption", type=str) 28 | parser.add_argument("--output_dir", required=True, type=str) 29 | parser.add_argument("--bucket_reso", default=None, type=str, nargs="+") 30 | parser.add_argument("--debug", action="store_true", help="debug mode") 31 | parser.add_argument("--min_bucket_count", type=int, default=1) 32 | parser.add_argument("--head_frame", type=int, default=0) 33 | 34 | args = parser.parse_args() 35 | print("args:", pprint.pformat(args, sort_dicts=True, compact=True)) 36 | return args 37 | 38 | 39 | def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray: 40 | """ 41 | Resize the image to the bucket resolution. 42 | """ 43 | is_pil_image = isinstance(image, Image.Image) 44 | if is_pil_image: 45 | image_width, image_height = image.size 46 | else: 47 | image_height, image_width = image.shape[:2] 48 | 49 | if bucket_reso == (image_width, image_height): 50 | return np.array(image) if is_pil_image else image 51 | 52 | bucket_width, bucket_height = bucket_reso 53 | if bucket_width == image_width or bucket_height == image_height: 54 | image = np.array(image) if is_pil_image else image 55 | else: 56 | # resize the image to the bucket resolution to match the short side 57 | scale_width = bucket_width / image_width 58 | scale_height = bucket_height / image_height 59 | scale = max(scale_width, scale_height) 60 | image_width = int(image_width * scale + 0.5) 61 | image_height = int(image_height * scale + 0.5) 62 | 63 | if scale > 1: 64 | image = Image.fromarray(image) if not is_pil_image else image 65 | image = image.resize((image_width, image_height), Image.LANCZOS) 66 | image = np.array(image) 67 | else: 68 | image = np.array(image) if is_pil_image else image 69 | image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA) 70 | 71 | # crop the image to the bucket resolution 72 | crop_left = (image_width - bucket_width) // 2 73 | crop_top = (image_height - bucket_height) // 2 74 | image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width] 75 | return image 76 | 77 | 78 | def divisible_by(num: int, divisor: int) -> int: 79 | return num - num % divisor 80 | 81 | 82 | class BucketSelector: 83 | RESOLUTION_STEPS_HUNYUAN = 16 84 | 85 | def __init__(self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False): 86 | self.resolution = resolution 87 | self.bucket_area = resolution[0] * resolution[1] 88 | self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN 89 | 90 | if not enable_bucket: 91 | # only define one bucket 92 | self.bucket_resolutions = [resolution] 93 | self.no_upscale = False 94 | else: 95 | # prepare bucket resolution 96 | self.no_upscale = no_upscale 97 | sqrt_size = int(math.sqrt(self.bucket_area)) 98 | min_size = divisible_by(sqrt_size // 2, self.reso_steps) 99 | self.bucket_resolutions = [] 100 | for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps): 101 | h = divisible_by(self.bucket_area // w, self.reso_steps) 102 | self.bucket_resolutions.append((w, h)) 103 | self.bucket_resolutions.append((h, w)) 104 | 105 | self.bucket_resolutions = list(set(self.bucket_resolutions)) 106 | self.bucket_resolutions.sort() 107 | 108 | # calculate aspect ratio to find the nearest resolution 109 | self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions]) 110 | 111 | def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]: 112 | """ 113 | return the bucket resolution for the given image size, (width, height) 114 | """ 115 | area = image_size[0] * image_size[1] 116 | if self.no_upscale and area <= self.bucket_area: 117 | w, h = image_size 118 | w = divisible_by(w, self.reso_steps) 119 | h = divisible_by(h, self.reso_steps) 120 | return w, h 121 | 122 | aspect_ratio = image_size[0] / image_size[1] 123 | ar_errors = self.aspect_ratios - aspect_ratio 124 | bucket_id = np.abs(ar_errors).argmin() 125 | return self.bucket_resolutions[bucket_id] 126 | 127 | 128 | def load_video(video_path, bucket_selector=None, start_frame=None, end_frame=None): 129 | vr = decord.VideoReader(uri=video_path) 130 | video_num_frames = len(vr) 131 | _start_frame, _end_frame = 0, video_num_frames 132 | 133 | if start_frame is not None: 134 | _start_frame = start_frame 135 | if end_frame is not None: 136 | _end_frame = min(_end_frame, end_frame) 137 | 138 | frames = vr.get_batch(range(_start_frame, _end_frame)).asnumpy() 139 | f, h, w, _ = frames.shape 140 | bucket_reso = bucket_selector.get_bucket_resolution(image_size=(w, h)) 141 | frames = [resize_image_to_bucket(frame, bucket_reso=bucket_reso) for frame in frames] 142 | return frames 143 | 144 | 145 | class BucketBatchManager: 146 | def __init__(self, bucketed_item_info, min_bucket_count=0): 147 | self.buckets = bucketed_item_info 148 | self.bucket_resos = list(self.buckets.keys()) 149 | self.bucket_resos.sort() 150 | 151 | for bucket_reso in self.bucket_resos: 152 | bucket = self.buckets[bucket_reso] 153 | if len(bucket) < min_bucket_count: 154 | print( 155 | f"bucket {bucket_reso!r} (n={len(bucket)!r}) has less than {min_bucket_count!r} items, remove it..." 156 | ) 157 | del self.buckets[bucket_reso] 158 | 159 | self.bucket_resos = list(self.buckets.keys()) 160 | self.bucket_resos.sort() 161 | 162 | def show_bucket_info(self): 163 | for bucket_reso in self.bucket_resos: 164 | bucket = self.buckets[bucket_reso] 165 | print(f"bucket: {bucket_reso}, count: {len(bucket)}") 166 | 167 | 168 | def convert_and_make_shards(args, base_df, bucket_batch_manager, bucket_reso, b_idx): 169 | columns = { 170 | "idx": "int32", 171 | "item_key": "str", 172 | "item": "str", 173 | "frame_count": "int32", 174 | "bucket_width": "int32", 175 | "bucket_height": "int32", 176 | "original_width": "int32", 177 | "original_height": "int32", 178 | "caption_str": "str", 179 | "video_arr": "ndarray", 180 | } 181 | print(f"Starting converter processs for bucket {bucket_reso!r}...") 182 | output_path = os.path.join(args.output_dir, "x".join(list(map(str, bucket_reso)))) 183 | Path(output_path).mkdir(parents=True, exist_ok=True) 184 | 185 | bucket = bucket_batch_manager.buckets[bucket_reso] 186 | 187 | writer = MDSWriter(out=output_path, columns=columns, size_limit=256 * (2**20), max_workers=os.cpu_count()) 188 | 189 | for item_info in tqdm(bucket, dynamic_ncols=True, position=b_idx, leave=False, desc=f"bucket {bucket_reso}"): 190 | item_key = item_info["item_key"] 191 | frame_count = item_info["frame_count"] 192 | frame_crop_pos = item_info["frame_crop_pos"] 193 | idx = item_info["idx"] 194 | bucket_reso_wh = (item_info["bucket_width"], item_info["bucket_height"]) 195 | 196 | row = base_df.iloc[idx] 197 | video_path, caption = row[args.video_column], row[args.caption_column] 198 | 199 | try: 200 | vr = decord.VideoReader(uri=video_path) 201 | video = vr.get_batch(range(frame_crop_pos, frame_crop_pos + frame_count)).asnumpy() 202 | original_width, original_height = video.shape[2], video.shape[1] 203 | video = [resize_image_to_bucket(frame, bucket_reso=bucket_reso_wh) for frame in video] 204 | video = np.stack(video, axis=0) 205 | except Exception as e: 206 | print(f"Failed to load video {video_path!r} : {e!r}") 207 | print(traceback.format_exc()) 208 | continue 209 | 210 | sample = {} 211 | sample["idx"] = idx 212 | sample["item_key"] = str(item_key) 213 | sample["item"] = str(Path(video_path).name) 214 | sample["frame_count"] = frame_count 215 | sample["bucket_width"] = bucket_reso_wh[0] 216 | sample["bucket_height"] = bucket_reso_wh[1] 217 | sample["original_width"] = original_height 218 | sample["original_height"] = original_width 219 | sample["caption_str"] = caption 220 | sample["video_arr"] = video 221 | 222 | writer.write(sample) 223 | writer.finish() 224 | print(f"Converter process finished for bucket {bucket_reso!r} !!!") 225 | 226 | 227 | def main(args): 228 | if str(args.csv).endswith(".csv"): 229 | df = pd.read_csv(args.csv) 230 | elif str(args.csv).endswith(".json"): 231 | df = pd.read_json(args.csv) 232 | elif str(args.csv).endswith(".parquet"): 233 | df = pd.read_parquet(args.csv) 234 | elif str(args.csv).endswith(".jsonl"): 235 | df = pd.read_json(args.csv, lines=True, orient="records") 236 | else: 237 | raise ValueError(f"Invalid csv path: {args.csv!r}") 238 | if args.base_dir is not None: 239 | df[args.video_column] = df[args.video_column].apply(lambda x: os.path.join(args.base_dir, x)) 240 | 241 | if args.debug: 242 | df = df.sample(n=10).reset_index(drop=True, inplace=False) 243 | 244 | print("Total number of samples: ", len(df)) 245 | 246 | bucket_selectors = [] 247 | for res in args.bucket_reso: 248 | w, h, f = res.split("x") 249 | bs = BucketSelector(resolution=(int(w), int(h), int(f)), enable_bucket=True, no_upscale=False) 250 | bucket_selectors.append(bs) 251 | 252 | batches = {} 253 | for idx, row in tqdm(df.iterrows(), total=len(df), dynamic_ncols=True, desc="Generating buckets"): 254 | video_path = row[args.video_column] 255 | 256 | vr = decord.VideoReader(uri=video_path) 257 | frame_count = len(vr) 258 | video = vr.get_batch(range(0, 1)).asnumpy() 259 | frame_height, frame_width = video.shape[1], video.shape[2] 260 | frame_size = (frame_width, frame_height) 261 | 262 | for bs in bucket_selectors: 263 | target_frame = bs.resolution[-1] 264 | bucket_reso = bs.get_bucket_resolution(frame_size) 265 | 266 | if frame_count >= (target_frame + args.head_frame): 267 | crop_pos_and_frames = [args.head_frame, target_frame] 268 | body, ext = os.path.splitext(Path(video_path).name) 269 | item_key = f"{body}_{crop_pos_and_frames[0]:05d}-{target_frame:05d}{ext}" 270 | batch_key = (*bucket_reso, target_frame) 271 | item_info = { 272 | "item_key": item_key, 273 | "batch_key": batch_key, 274 | "frame_count": target_frame, 275 | "frame_crop_pos": crop_pos_and_frames[0], 276 | "idx": idx, 277 | "bucket_height": bucket_reso[1], 278 | "bucket_width": bucket_reso[0], 279 | } 280 | batch = batches.get(batch_key, []) 281 | batch.append(item_info) 282 | batches[batch_key] = batch 283 | 284 | bucket_manager = BucketBatchManager(batches, min_bucket_count=args.min_bucket_count) 285 | bucket_manager.show_bucket_info() 286 | 287 | bucket_resos = bucket_manager.bucket_resos 288 | 289 | exporters = [] 290 | for bucket_idx, bucket_reso in enumerate(bucket_resos): 291 | op = mp.Process(target=convert_and_make_shards, args=(args, df, bucket_manager, bucket_reso, bucket_idx)) 292 | op.start() 293 | exporters.append(op) 294 | 295 | for op in exporters: 296 | op.join() 297 | op.close() 298 | 299 | 300 | if __name__ == "__main__": 301 | main(parse_args()) 302 | -------------------------------------------------------------------------------- /tools/hv_precompute_latents.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | from io import BytesIO 4 | from pathlib import Path 5 | from typing import Any, Dict, List, Optional, Tuple, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torchvision.transforms.v2 as transforms 10 | from accelerate.logging import get_logger 11 | from diffusers import ( 12 | AutoencoderKLHunyuanVideo, 13 | FlowMatchEulerDiscreteScheduler, 14 | HunyuanVideoPipeline, 15 | HunyuanVideoTransformer3DModel, 16 | ) 17 | from PIL import Image 18 | from streaming import MDSWriter, Stream, StreamingDataset 19 | from torch.utils.data import DataLoader 20 | from tqdm.auto import tqdm 21 | from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizer 22 | 23 | torch.backends.cuda.matmul.allow_tf32 = True 24 | 25 | 26 | def load_condition_models( 27 | model_id: str = "hunyuanvideo-community/HunyuanVideo", 28 | text_encoder_dtype: torch.dtype = torch.float16, 29 | text_encoder_2_dtype: torch.dtype = torch.float16, 30 | revision: Optional[str] = None, 31 | cache_dir: Optional[str] = None, 32 | device: Optional[torch.device] = None, 33 | **kwargs, 34 | ) -> Dict[str, nn.Module]: 35 | tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir) 36 | text_encoder = LlamaModel.from_pretrained( 37 | model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir 38 | ) 39 | tokenizer_2 = CLIPTokenizer.from_pretrained( 40 | model_id, subfolder="tokenizer_2", revision=revision, cache_dir=cache_dir 41 | ) 42 | text_encoder_2 = CLIPTextModel.from_pretrained( 43 | model_id, subfolder="text_encoder_2", torch_dtype=text_encoder_2_dtype, revision=revision, cache_dir=cache_dir 44 | ) 45 | if device is not None: 46 | text_encoder.to(device) 47 | text_encoder_2.to(device) 48 | return { 49 | "tokenizer": tokenizer, 50 | "text_encoder": text_encoder, 51 | "tokenizer_2": tokenizer_2, 52 | "text_encoder_2": text_encoder_2, 53 | } 54 | 55 | 56 | def load_latent_models( 57 | model_id: str = "hunyuanvideo-community/HunyuanVideo", 58 | vae_dtype: torch.dtype = torch.float16, 59 | revision: Optional[str] = None, 60 | cache_dir: Optional[str] = None, 61 | device: Optional[torch.device] = None, 62 | **kwargs, 63 | ) -> Dict[str, nn.Module]: 64 | vae = AutoencoderKLHunyuanVideo.from_pretrained( 65 | model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir 66 | ) 67 | if device is not None: 68 | vae.to(device) 69 | vae.enable_slicing() 70 | vae.enable_tiling() 71 | return { 72 | "vae": vae, 73 | } 74 | 75 | 76 | def load_diffusion_models( 77 | model_id: str = "hunyuanvideo-community/HunyuanVideo", 78 | transformer_dtype: torch.dtype = torch.bfloat16, 79 | shift: float = 1.0, 80 | revision: Optional[str] = None, 81 | cache_dir: Optional[str] = None, 82 | **kwargs, 83 | ) -> Dict[str, Union[nn.Module, FlowMatchEulerDiscreteScheduler]]: 84 | transformer = HunyuanVideoTransformer3DModel.from_pretrained( 85 | model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir 86 | ) 87 | scheduler = FlowMatchEulerDiscreteScheduler(shift=shift) 88 | return {"transformer": transformer, "scheduler": scheduler} 89 | 90 | 91 | def prepare_conditions( 92 | tokenizer: LlamaTokenizer, 93 | text_encoder: LlamaModel, 94 | tokenizer_2: CLIPTokenizer, 95 | text_encoder_2: CLIPTextModel, 96 | prompt: Union[str, List[str]], 97 | device: Optional[torch.device] = None, 98 | dtype: Optional[torch.dtype] = None, 99 | max_sequence_length: int = 256, 100 | prompt_template: Dict[str, Any] = { 101 | "template": ( 102 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " 103 | "1. The main content and theme of the video." 104 | "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." 105 | "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." 106 | "4. background environment, light, style and atmosphere." 107 | "5. camera angles, movements, and transitions used in the video:<|eot_id|>" 108 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" 109 | ), 110 | "crop_start": 95, 111 | }, 112 | **kwargs, 113 | ) -> torch.Tensor: 114 | device = device or text_encoder.device 115 | dtype = dtype or text_encoder.dtype 116 | 117 | if isinstance(prompt, str): 118 | prompt = [prompt] 119 | 120 | conditions = {} 121 | conditions.update( 122 | _get_llama_prompt_embeds(tokenizer, text_encoder, prompt, prompt_template, device, dtype, max_sequence_length) 123 | ) 124 | conditions.update(_get_clip_prompt_embeds(tokenizer_2, text_encoder_2, prompt, device, dtype)) 125 | 126 | return conditions 127 | 128 | 129 | def prepare_latents( 130 | vae: AutoencoderKLHunyuanVideo, 131 | image_or_video: torch.Tensor, 132 | device: Optional[torch.device] = None, 133 | dtype: Optional[torch.dtype] = None, 134 | generator: Optional[torch.Generator] = None, 135 | **kwargs, 136 | ) -> torch.Tensor: 137 | device = device or vae.device 138 | dtype = dtype or vae.dtype 139 | 140 | if image_or_video.ndim == 4: 141 | image_or_video = image_or_video.unsqueeze(2) 142 | assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor" 143 | 144 | image_or_video = image_or_video.to(device=device, dtype=vae.dtype) 145 | image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, C, F, H, W] -> [B, F, C, H, W] 146 | 147 | latents = vae.encode(image_or_video).latent_dist.sample(generator=generator) 148 | latents = latents * vae.config.scaling_factor 149 | latents = latents.to(dtype=dtype) 150 | return {"latents": latents} 151 | 152 | 153 | def _get_llama_prompt_embeds( 154 | tokenizer: LlamaTokenizer, 155 | text_encoder: LlamaModel, 156 | prompt: List[str], 157 | prompt_template: Dict[str, Any], 158 | device: torch.device, 159 | dtype: torch.dtype, 160 | max_sequence_length: int = 256, 161 | num_hidden_layers_to_skip: int = 2, 162 | ) -> Tuple[torch.Tensor, torch.Tensor]: 163 | batch_size = len(prompt) 164 | prompt = [prompt_template["template"].format(p) for p in prompt] 165 | 166 | crop_start = prompt_template.get("crop_start", None) 167 | if crop_start is None: 168 | prompt_template_input = tokenizer( 169 | prompt_template["template"], 170 | padding="max_length", 171 | return_tensors="pt", 172 | return_length=False, 173 | return_overflowing_tokens=False, 174 | return_attention_mask=False, 175 | ) 176 | crop_start = prompt_template_input["input_ids"].shape[-1] 177 | # Remove <|eot_id|> token and placeholder {} 178 | crop_start -= 2 179 | 180 | max_sequence_length += crop_start 181 | text_inputs = tokenizer( 182 | prompt, 183 | max_length=max_sequence_length, 184 | padding="max_length", 185 | truncation=True, 186 | return_tensors="pt", 187 | return_length=False, 188 | return_overflowing_tokens=False, 189 | return_attention_mask=True, 190 | ) 191 | text_input_ids = text_inputs.input_ids.to(device=device) 192 | prompt_attention_mask = text_inputs.attention_mask.to(device=device) 193 | 194 | prompt_embeds = text_encoder( 195 | input_ids=text_input_ids, 196 | attention_mask=prompt_attention_mask, 197 | output_hidden_states=True, 198 | ).hidden_states[-(num_hidden_layers_to_skip + 1)] 199 | prompt_embeds = prompt_embeds.to(dtype=dtype) 200 | 201 | if crop_start is not None and crop_start > 0: 202 | prompt_embeds = prompt_embeds[:, crop_start:] 203 | prompt_attention_mask = prompt_attention_mask[:, crop_start:] 204 | 205 | prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) 206 | 207 | return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask} 208 | 209 | 210 | def _get_clip_prompt_embeds( 211 | tokenizer_2: CLIPTokenizer, 212 | text_encoder_2: CLIPTextModel, 213 | prompt: Union[str, List[str]], 214 | device: torch.device, 215 | dtype: torch.dtype, 216 | max_sequence_length: int = 77, 217 | ) -> torch.Tensor: 218 | text_inputs = tokenizer_2( 219 | prompt, 220 | padding="max_length", 221 | max_length=max_sequence_length, 222 | truncation=True, 223 | return_tensors="pt", 224 | ) 225 | 226 | prompt_embeds = text_encoder_2(text_inputs.input_ids.to(device), output_hidden_states=False).pooler_output 227 | prompt_embeds = prompt_embeds.to(dtype=dtype) 228 | 229 | return {"pooled_prompt_embeds": prompt_embeds} 230 | 231 | 232 | def main(args): 233 | if "cuda" in args.device.lower(): 234 | assert torch.cuda.is_available() 235 | device = torch.device(args.device) 236 | dtype = getattr(torch, args.dtype) 237 | 238 | video_transforms = transforms.Compose( 239 | [ 240 | transforms.Lambda(lambda x: x / 255.0), 241 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 242 | ] 243 | ) 244 | 245 | stream = [Stream(local=args.mds_data_path)] 246 | dataset = StreamingDataset(streams=stream, batch_size=1) 247 | 248 | save_path = Path(args.output_dir) / Path(args.mds_data_path).name 249 | save_path.mkdir(parents=True, exist_ok=True) 250 | save_path = save_path.as_posix() 251 | 252 | print(f"Saving to => {save_path!r}") 253 | 254 | columns = { 255 | "idx": "int32", 256 | "item_key": "str", 257 | "item": "str", 258 | "frame_count": "int32", 259 | "bucket_width": "int32", 260 | "bucket_height": "int32", 261 | "original_width": "int32", 262 | "original_height": "int32", 263 | "caption_str": "str", 264 | "video_arr": "ndarray", 265 | "prompt_embeds": "ndarray", 266 | "prompt_attention_mask": "ndarray", 267 | "pooled_prompt_embeds": "ndarray", 268 | "latents": "ndarray", 269 | "latents_cond": "ndarray", 270 | } 271 | 272 | writer = MDSWriter( 273 | out=save_path, 274 | columns=columns, 275 | compression=args.mds_shard_compression, 276 | size_limit=256 * (2**20), 277 | max_workers=64, 278 | ) 279 | 280 | models = load_condition_models( 281 | args.pretrained_model_name_or_path, text_encoder_dtype=dtype, text_encoder_2_dtype=torch.bfloat16, device=device 282 | ) 283 | models.update(load_latent_models(args.pretrained_model_name_or_path, vae_dtype=dtype, device=device)) 284 | 285 | for batch in tqdm(dataset, dynamic_ncols=True, desc="Precomputing latents"): 286 | video, caption = batch["video_arr"], batch["caption_str"] 287 | video = torch.from_numpy(video.copy()).permute(0, 3, 1, 2) # n_frames, c, h, w 288 | 289 | control_condition = torch.zeros_like(video) # create an empty video 290 | control_condition[0] = video[0] # keep the first frame 291 | control_condition[-1] = video[-1] # keep the last frame 292 | 293 | video = torch.stack([video_transforms(x) for x in video], dim=0).unsqueeze(0) 294 | control_condition = torch.stack([video_transforms(x) for x in control_condition], dim=0).unsqueeze(0) 295 | 296 | with torch.inference_mode(): # @TODO: add batch support ? 297 | latents = prepare_latents(models["vae"], video, device=device, dtype=dtype)["latents"] 298 | latents_cond = prepare_latents(models["vae"], control_condition, device=device, dtype=dtype)["latents"] 299 | conditions = prepare_conditions( 300 | tokenizer=models["tokenizer"], 301 | text_encoder=models["text_encoder"], 302 | tokenizer_2=models["tokenizer_2"], 303 | text_encoder_2=models["text_encoder_2"], 304 | prompt=caption, 305 | device=device, 306 | dtype=dtype, 307 | ) 308 | prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = ( 309 | conditions["prompt_embeds"], 310 | conditions["prompt_attention_mask"], 311 | conditions["pooled_prompt_embeds"], 312 | ) 313 | 314 | out_batch = {**batch} 315 | out_batch["latents"] = latents[0].float().cpu().numpy() 316 | out_batch["prompt_embeds"] = prompt_embeds[0].float().cpu().numpy() 317 | out_batch["prompt_attention_mask"] = prompt_attention_mask[0].float().cpu().numpy() 318 | out_batch["pooled_prompt_embeds"] = pooled_prompt_embeds[0].float().cpu().numpy() 319 | out_batch["latents_cond"] = latents_cond[0].float().cpu().numpy() 320 | 321 | writer.write(out_batch) 322 | 323 | writer.finish() 324 | 325 | 326 | if __name__ == "__main__": 327 | parser = ArgumentParser() 328 | parser.add_argument("--mds_data_path", required=True, type=str) 329 | parser.add_argument( 330 | "--pretrained_model_name_or_path", required=True, type=str, default="hunyuanvideo-community/HunyuanVideo" 331 | ) 332 | parser.add_argument("--output_dir", required=True, type=str) 333 | parser.add_argument("--device", type=str, default="cuda") 334 | parser.add_argument("--dtype", choices=["float16", "bfloat16"], default="bfloat16") 335 | parser.add_argument("--mds_shard_compression", type=str, default=None) 336 | 337 | args = parser.parse_args() 338 | main(args) 339 | -------------------------------------------------------------------------------- /tools/hv_precompute_latents_dist.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from argparse import ArgumentParser 4 | from datetime import datetime, timedelta 5 | from io import BytesIO 6 | from pathlib import Path 7 | from typing import Any, Dict, List, Optional, Tuple, Union 8 | 9 | import accelerate 10 | import torch 11 | import torch.nn as nn 12 | import torchvision.transforms as transforms 13 | from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs 14 | from accelerate.logging import get_logger 15 | from accelerate.utils import set_seed 16 | from diffusers import ( 17 | AutoencoderKLHunyuanVideo, 18 | FlowMatchEulerDiscreteScheduler, 19 | HunyuanVideoPipeline, 20 | HunyuanVideoTransformer3DModel, 21 | ) 22 | from PIL import Image 23 | from streaming import MDSWriter, Stream, StreamingDataLoader, StreamingDataset 24 | from torch.utils.data import DataLoader 25 | from tqdm.auto import tqdm 26 | from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizer 27 | 28 | torch.backends.cuda.matmul.allow_tf32 = True 29 | logger = get_logger("cache_latents") 30 | logging.basicConfig( 31 | level=logging.INFO, 32 | format="[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d:%(funcName)s] %(message)s", 33 | force=True, 34 | handlers=[logging.StreamHandler()], 35 | ) 36 | 37 | _COMMON_BEGINNING_PHRASES = ( 38 | "This video", 39 | "The video", 40 | "This clip", 41 | "The clip", 42 | "The animation", 43 | "This image", 44 | "The image", 45 | "This picture", 46 | "The picture", 47 | ) 48 | _COMMON_CONTINUATION_WORDS = ("shows", "depicts", "features", "captures", "highlights", "introduces", "presents") 49 | 50 | COMMON_LLM_START_PHRASES = ( 51 | "In the video,", 52 | "In this video,", 53 | "In this video clip,", 54 | "In the clip,", 55 | "Caption:", 56 | *( 57 | f"{beginning} {continuation}" 58 | for beginning in _COMMON_BEGINNING_PHRASES 59 | for continuation in _COMMON_CONTINUATION_WORDS 60 | ), 61 | ) 62 | 63 | 64 | def load_condition_models( 65 | model_id: str = "hunyuanvideo-community/HunyuanVideo", 66 | text_encoder_dtype: torch.dtype = torch.float16, 67 | text_encoder_2_dtype: torch.dtype = torch.float16, 68 | revision: Optional[str] = None, 69 | cache_dir: Optional[str] = None, 70 | device: Optional[torch.device] = None, 71 | **kwargs, 72 | ) -> Dict[str, nn.Module]: 73 | tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir) 74 | text_encoder = LlamaModel.from_pretrained( 75 | model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir 76 | ) 77 | tokenizer_2 = CLIPTokenizer.from_pretrained( 78 | model_id, subfolder="tokenizer_2", revision=revision, cache_dir=cache_dir 79 | ) 80 | text_encoder_2 = CLIPTextModel.from_pretrained( 81 | model_id, subfolder="text_encoder_2", torch_dtype=text_encoder_2_dtype, revision=revision, cache_dir=cache_dir 82 | ) 83 | if device is not None: 84 | text_encoder.to(device) 85 | text_encoder_2.to(device) 86 | return { 87 | "tokenizer": tokenizer, 88 | "text_encoder": text_encoder, 89 | "tokenizer_2": tokenizer_2, 90 | "text_encoder_2": text_encoder_2, 91 | } 92 | 93 | 94 | def load_latent_models( 95 | model_id: str = "hunyuanvideo-community/HunyuanVideo", 96 | vae_dtype: torch.dtype = torch.float16, 97 | revision: Optional[str] = None, 98 | cache_dir: Optional[str] = None, 99 | device: Optional[torch.device] = None, 100 | **kwargs, 101 | ) -> Dict[str, nn.Module]: 102 | vae = AutoencoderKLHunyuanVideo.from_pretrained( 103 | model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir 104 | ) 105 | if device is not None: 106 | vae.to(device) 107 | vae.enable_slicing() 108 | vae.enable_tiling() 109 | return { 110 | "vae": vae, 111 | } 112 | 113 | 114 | def load_diffusion_models( 115 | model_id: str = "hunyuanvideo-community/HunyuanVideo", 116 | transformer_dtype: torch.dtype = torch.bfloat16, 117 | shift: float = 1.0, 118 | revision: Optional[str] = None, 119 | cache_dir: Optional[str] = None, 120 | **kwargs, 121 | ) -> Dict[str, Union[nn.Module, FlowMatchEulerDiscreteScheduler]]: 122 | transformer = HunyuanVideoTransformer3DModel.from_pretrained( 123 | model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir 124 | ) 125 | scheduler = FlowMatchEulerDiscreteScheduler(shift=shift) 126 | return {"transformer": transformer, "scheduler": scheduler} 127 | 128 | 129 | def prepare_conditions( 130 | tokenizer: LlamaTokenizer, 131 | text_encoder: LlamaModel, 132 | tokenizer_2: CLIPTokenizer, 133 | text_encoder_2: CLIPTextModel, 134 | prompt: Union[str, List[str]], 135 | device: Optional[torch.device] = None, 136 | dtype: Optional[torch.dtype] = None, 137 | max_sequence_length: int = 256, 138 | prompt_template: Dict[str, Any] = { 139 | "template": ( 140 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " 141 | "1. The main content and theme of the video." 142 | "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." 143 | "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." 144 | "4. background environment, light, style and atmosphere." 145 | "5. camera angles, movements, and transitions used in the video:<|eot_id|>" 146 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" 147 | ), 148 | "crop_start": 95, 149 | }, 150 | **kwargs, 151 | ) -> torch.Tensor: 152 | if isinstance(prompt, str): 153 | prompt = [prompt] 154 | 155 | conditions = {} 156 | conditions.update( 157 | _get_llama_prompt_embeds(tokenizer, text_encoder, prompt, prompt_template, device, dtype, max_sequence_length) 158 | ) 159 | conditions.update(_get_clip_prompt_embeds(tokenizer_2, text_encoder_2, prompt, device, dtype)) 160 | 161 | return conditions 162 | 163 | 164 | def prepare_latents( 165 | vae: AutoencoderKLHunyuanVideo, 166 | image_or_video: torch.Tensor, 167 | device: Optional[torch.device] = None, 168 | dtype: Optional[torch.dtype] = None, 169 | generator: Optional[torch.Generator] = None, 170 | **kwargs, 171 | ) -> torch.Tensor: 172 | if image_or_video.ndim == 4: 173 | image_or_video = image_or_video.unsqueeze(2) 174 | assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor" 175 | 176 | image_or_video = image_or_video.to(device=device, dtype=dtype) 177 | image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, C, F, H, W] -> [B, F, C, H, W] 178 | 179 | latents = vae.encode(image_or_video).latent_dist.sample(generator=generator) 180 | latents = latents * vae.config.scaling_factor 181 | latents = latents.to(dtype=dtype) 182 | return {"latents": latents} 183 | 184 | 185 | def _get_llama_prompt_embeds( 186 | tokenizer: LlamaTokenizer, 187 | text_encoder: LlamaModel, 188 | prompt: List[str], 189 | prompt_template: Dict[str, Any], 190 | device: torch.device, 191 | dtype: torch.dtype, 192 | max_sequence_length: int = 256, 193 | num_hidden_layers_to_skip: int = 2, 194 | ) -> Tuple[torch.Tensor, torch.Tensor]: 195 | batch_size = len(prompt) 196 | prompt = [prompt_template["template"].format(p) for p in prompt] 197 | 198 | crop_start = prompt_template.get("crop_start", None) 199 | if crop_start is None: 200 | prompt_template_input = tokenizer( 201 | prompt_template["template"], 202 | padding="max_length", 203 | return_tensors="pt", 204 | return_length=False, 205 | return_overflowing_tokens=False, 206 | return_attention_mask=False, 207 | ) 208 | crop_start = prompt_template_input["input_ids"].shape[-1] 209 | # Remove <|eot_id|> token and placeholder {} 210 | crop_start -= 2 211 | 212 | max_sequence_length += crop_start 213 | text_inputs = tokenizer( 214 | prompt, 215 | max_length=max_sequence_length, 216 | padding="max_length", 217 | truncation=True, 218 | return_tensors="pt", 219 | return_length=False, 220 | return_overflowing_tokens=False, 221 | return_attention_mask=True, 222 | ) 223 | text_input_ids = text_inputs.input_ids.to(device=device) 224 | prompt_attention_mask = text_inputs.attention_mask.to(device=device) 225 | 226 | prompt_embeds = text_encoder( 227 | input_ids=text_input_ids, 228 | attention_mask=prompt_attention_mask, 229 | output_hidden_states=True, 230 | ).hidden_states[-(num_hidden_layers_to_skip + 1)] 231 | prompt_embeds = prompt_embeds.to(dtype=dtype) 232 | 233 | if crop_start is not None and crop_start > 0: 234 | prompt_embeds = prompt_embeds[:, crop_start:] 235 | prompt_attention_mask = prompt_attention_mask[:, crop_start:] 236 | 237 | prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) 238 | 239 | return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask} 240 | 241 | 242 | def _get_clip_prompt_embeds( 243 | tokenizer_2: CLIPTokenizer, 244 | text_encoder_2: CLIPTextModel, 245 | prompt: Union[str, List[str]], 246 | device: torch.device, 247 | dtype: torch.dtype, 248 | max_sequence_length: int = 77, 249 | ) -> torch.Tensor: 250 | text_inputs = tokenizer_2( 251 | prompt, 252 | padding="max_length", 253 | max_length=max_sequence_length, 254 | truncation=True, 255 | return_tensors="pt", 256 | ) 257 | 258 | prompt_embeds = text_encoder_2(text_inputs.input_ids.to(device), output_hidden_states=False).pooler_output 259 | prompt_embeds = prompt_embeds.to(dtype=dtype) 260 | 261 | return {"pooled_prompt_embeds": prompt_embeds} 262 | 263 | 264 | def main(args): 265 | accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(hours=36))]) 266 | 267 | video_transforms = transforms.Compose( 268 | [ 269 | transforms.Lambda(lambda x: x / 255.0), 270 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), 271 | ] 272 | ) 273 | 274 | mds_streams = [] 275 | if args.recursive: 276 | for fp in args.mds_data_path: 277 | for dir in os.listdir(fp): 278 | dir_path = os.path.join(fp, dir) 279 | mds_streams.append((Stream(local=dir_path), dir_path)) 280 | else: 281 | for fp in args.mds_data_path: 282 | mds_streams.append((Stream(local=fp), fp)) 283 | 284 | accelerator.print(f"## mds_streams: {len(mds_streams)}") 285 | 286 | dtype = torch.float32 287 | if accelerator.mixed_precision == "fp16": 288 | dtype = torch.float16 289 | elif accelerator.mixed_precision == "bf16": 290 | dtype = torch.bfloat16 291 | 292 | logger.info(f"set dtype to {dtype!r}") 293 | 294 | models = load_condition_models( 295 | args.pretrained_model_name_or_path, 296 | text_encoder_dtype=dtype, 297 | text_encoder_2_dtype=torch.bfloat16, 298 | device=accelerator.device, 299 | ) 300 | models.update(load_latent_models(args.pretrained_model_name_or_path, vae_dtype=dtype, device=accelerator.device)) 301 | 302 | models["vae"].eval().requires_grad_(False) 303 | models["text_encoder"].eval().requires_grad_(False) 304 | models["text_encoder_2"].eval().requires_grad_(False) 305 | 306 | accelerator.wait_for_everyone() 307 | 308 | for varname in ["RANK", "LOCAL_WORLD_SIZE", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"]: 309 | assert os.environ.get(varname) is not None, f"{varname} is not set" 310 | logger.info(f"{varname!r}: {os.environ.get(varname)}") 311 | 312 | for stream, data_path in mds_streams: 313 | logger.info(f"## Processing {data_path!r}") 314 | 315 | dataset = StreamingDataset( 316 | streams=[stream], 317 | batch_size=1, 318 | num_canonical_nodes=(int(os.environ["WORLD_SIZE"]) // 8), 319 | ) 320 | 321 | save_path = Path(args.output_dir) / f"{Path(data_path).name}_{accelerator.local_process_index:02d}" 322 | save_path.mkdir(parents=True, exist_ok=True) 323 | save_path = save_path.as_posix() 324 | 325 | logger.info(f"Saving to => {save_path!r}") 326 | 327 | columns = { 328 | # "idx": "int32", 329 | # "item_key": "str", 330 | # "item": "str", 331 | # "frame_count": "int32", 332 | # "bucket_width": "int32", 333 | # "bucket_height": "int32", 334 | # "original_width": "int32", 335 | # "original_height": "int32", 336 | # "caption_str": "str", 337 | # "video_arr": "ndarray", 338 | "prompt_embeds": "ndarray", 339 | "prompt_attention_mask": "ndarray", 340 | "pooled_prompt_embeds": "ndarray", 341 | "latents": "ndarray", 342 | "latents_cond": "ndarray", 343 | # "latents_cond_2": "ndarray", 344 | # "latents_cond_only_first": "ndarray", 345 | # "latents_cond_only_last": "ndarray", 346 | } 347 | 348 | os.umask(0o000) 349 | writer = MDSWriter( 350 | out=save_path, 351 | columns=columns, 352 | compression=args.mds_shard_compression, 353 | size_limit=256 * (2**20), 354 | max_workers=64, 355 | ) 356 | 357 | def collate_fn(batch): 358 | idx = [x["idx"] for x in batch] 359 | item_key = [x["item_key"] for x in batch] 360 | item = [x["item"] for x in batch] 361 | video = [x["video_arr"] for x in batch] 362 | caption = [x["caption_str"] for x in batch] 363 | 364 | for i in range(len(caption)): 365 | caption[i] = caption[i].strip() 366 | for phrase in COMMON_LLM_START_PHRASES: 367 | if caption[i].startswith(phrase): 368 | caption[i] = caption[i].removeprefix(phrase).strip() 369 | 370 | return {"video_arr": video, "caption_str": caption, "idx": idx, "item_key": item_key, "item": item} 371 | 372 | dl = StreamingDataLoader( 373 | dataset, 374 | batch_size=1, 375 | collate_fn=collate_fn, 376 | shuffle=False, 377 | num_workers=8, 378 | prefetch_factor=2, 379 | pin_memory=True, 380 | ) 381 | 382 | for batch in tqdm(dl, dynamic_ncols=True, desc="Precomputing latents", disable=not accelerator.is_main_process): 383 | # print(accelerator.process_index, batch["idx"], batch["item_key"], batch["item"]) 384 | 385 | assert len(batch["video_arr"]) == 1 386 | 387 | video, caption = batch["video_arr"][0], batch["caption_str"][0] 388 | video = torch.from_numpy(video.copy()).permute(0, 3, 1, 2) # n_frames, c, h, w 389 | 390 | control_condition = torch.zeros_like(video) # create an empty video 391 | control_condition[0] = video[0] # keep the first frame 392 | control_condition[-1] = video[-1] # keep the last frame 393 | 394 | # control_condition_2 = torch.zeros_like(video) # create an empty video 395 | # control_condition_2[0] = video[0] # keep the first frame 396 | # control_condition_2[-1] = video[-1] # keep the last frame 397 | # control_condition_2[video.shape[0] // 2] = video[video.shape[0] // 2] # keep the middle frame 398 | 399 | # control_condition_only_first = torch.zeros_like(video) # create an empty video 400 | # control_condition_only_first[0] = video[0] # keep the first frame 401 | 402 | # control_condition_only_last = torch.zeros_like(video) # create an empty video 403 | # control_condition_only_last[-1] = video[-1] # keep the last frame 404 | 405 | video = torch.stack([video_transforms(x) for x in video], dim=0).unsqueeze(0) 406 | control_condition = torch.stack([video_transforms(x) for x in control_condition], dim=0).unsqueeze(0) 407 | # control_condition_2 = torch.stack([video_transforms(x) for x in control_condition_2], dim=0).unsqueeze(0) 408 | # control_condition_only_first = torch.stack([video_transforms(x) for x in control_condition_only_first], dim=0).unsqueeze(0) 409 | # control_condition_only_last = torch.stack([video_transforms(x) for x in control_condition_only_last], dim=0).unsqueeze(0) 410 | 411 | with torch.inference_mode(): # @TODO: add batch support ? 412 | latents = prepare_latents(models["vae"], video, device=accelerator.device, dtype=dtype)["latents"] 413 | latents_cond = prepare_latents( 414 | models["vae"], control_condition, device=accelerator.device, dtype=dtype 415 | )["latents"] 416 | # latents_cond_2 = prepare_latents(models["vae"], control_condition_2, device=accelerator.device, dtype=dtype)["latents"] 417 | # latents_cond_only_first = prepare_latents(models["vae"], control_condition_only_first, device=accelerator.device, dtype=dtype)["latents"] 418 | # latents_cond_only_last = prepare_latents(models["vae"], control_condition_only_last, device=accelerator.device, dtype=dtype)["latents"] 419 | 420 | conditions = prepare_conditions( 421 | tokenizer=models["tokenizer"], 422 | text_encoder=models["text_encoder"], 423 | tokenizer_2=models["tokenizer_2"], 424 | text_encoder_2=models["text_encoder_2"], 425 | prompt=caption, 426 | device=accelerator.device, 427 | dtype=dtype, 428 | ) 429 | prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = ( 430 | conditions["prompt_embeds"], 431 | conditions["prompt_attention_mask"], 432 | conditions["pooled_prompt_embeds"], 433 | ) 434 | 435 | # out_batch = {**batch} # only takes extra space 436 | out_batch = {} 437 | out_batch["latents"] = latents[0].float().cpu().numpy() 438 | out_batch["prompt_embeds"] = prompt_embeds[0].float().cpu().numpy() 439 | out_batch["prompt_attention_mask"] = prompt_attention_mask[0].float().cpu().numpy() 440 | out_batch["pooled_prompt_embeds"] = pooled_prompt_embeds[0].float().cpu().numpy() 441 | out_batch["latents_cond"] = latents_cond[0].float().cpu().numpy() 442 | # out_batch["latents_cond_2"] = latents_cond_2[0].float().cpu().numpy() 443 | # out_batch["latents_cond_only_first"] = latents_cond_only_first[0].float().cpu().numpy() 444 | # out_batch["latents_cond_only_last"] = latents_cond_only_last[0].float().cpu().numpy() 445 | 446 | assert ( 447 | out_batch.keys() == columns.keys() 448 | ), f"{out_batch.keys()} != {columns.keys()}, missing {set(out_batch.keys()) - set(columns.keys())}" 449 | 450 | os.umask(0o000) 451 | writer.write(out_batch) 452 | 453 | writer.finish() 454 | accelerator.wait_for_everyone() 455 | 456 | 457 | if __name__ == "__main__": 458 | parser = ArgumentParser() 459 | parser.add_argument("--mds_data_path", required=True, type=str, nargs="+") 460 | parser.add_argument("--pretrained_model_name_or_path", type=str, default="hunyuanvideo-community/HunyuanVideo") 461 | parser.add_argument("--output_dir", required=True, type=str) 462 | parser.add_argument("--mds_shard_compression", type=str, default=None) 463 | parser.add_argument("--recursive", action="store_true", default=False) 464 | 465 | args = parser.parse_args() 466 | main(args) 467 | --------------------------------------------------------------------------------