├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── attn_processor.py
├── config.py
├── configs
│   ├── accelerate
│   │   ├── uncompiled_8.yaml
│   │   └── uncompiled_8_ds.yaml
│   ├── config_defaults.yaml
│   └── deepspeed
│       └── zero_stage2_config.json
├── ema.py
├── hv_control_lora_inference.py
├── hv_train_control_lora.py
├── mds_dataloaders.py
├── notebooks
│   └── 000-inference.ipynb
├── optim.py
├── pyproject.toml
├── setup_env.sh
└── tools
├── hv_cache_dataset.py
├── hv_precompute_latents.py
└── hv_precompute_latents_dist.py
/.gitignore:
--------------------------------------------------------------------------------
1 | */output/*
2 | */wandb/*
3 | */outputs/*
4 | */input/*
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | cover/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # poetry
103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104 | # This is especially recommended for binary packages to ensure reproducibility, and is more
105 | # commonly ignored for libraries.
106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107 | #poetry.lock
108 |
109 | # pdm
110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111 | #pdm.lock
112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113 | # in version control.
114 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
115 | .pdm.toml
116 | .pdm-python
117 | .pdm-build/
118 |
119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
120 | __pypackages__/
121 |
122 | # Celery stuff
123 | celerybeat-schedule
124 | celerybeat.pid
125 |
126 | # SageMath parsed files
127 | *.sage.py
128 |
129 | # Environments
130 | .env
131 | .venv
132 | env/
133 | venv/
134 | ENV/
135 | env.bak/
136 | venv.bak/
137 |
138 | # Spyder project settings
139 | .spyderproject
140 | .spyproject
141 |
142 | # Rope project settings
143 | .ropeproject
144 |
145 | # mkdocs documentation
146 | /site
147 |
148 | # mypy
149 | .mypy_cache/
150 | .dmypy.json
151 | dmypy.json
152 |
153 | # Pyre type checker
154 | .pyre/
155 |
156 | # pytype static type analyzer
157 | .pytype/
158 |
159 | # Cython debug symbols
160 | cython_debug/
161 |
162 | # PyCharm
163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
165 | # and can be added to the global gitignore or merged into this file. For a more nuclear
166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
167 | #.idea/
168 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v4.4.0
6 | hooks:
7 | - id: end-of-file-fixer
8 | - id: trailing-whitespace
9 | - id: check-docstring-first
10 | - id: check-toml
11 | - id: check-case-conflict
12 |
13 | - repo: https://github.com/asottile/pyupgrade
14 | rev: v3.3.1
15 | hooks:
16 | - id: pyupgrade
17 | args: [--py38-plus]
18 | name: Upgrade code
19 |
20 | - repo: https://github.com/PyCQA/isort
21 | rev: 5.12.0
22 | hooks:
23 | - id: isort
24 | name: imports
25 | require_serial: false
26 | args: [--settings-path=pyproject.toml]
27 |
28 | - repo: https://github.com/psf/black
29 | rev: 23.1.0
30 | hooks:
31 | - id: black
32 | name: Format code
33 | args: [--config=pyproject.toml]
34 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # **HunyuanVideo Keyframe Control Lora**
2 |
3 | ---
4 |
5 | This repo contains PyTorch model definitions, pre-trained weights and inference/sampling code for our experiments on HunyuanVideo Keyframe Control Lora.
6 |
7 | ## Contents
8 |
9 | - [🔥🔥🔥 News!!](#-news)
10 | - [Abstract](#abstract)
11 | - [Demo](#-demo)
12 | - [📜 Recommended Settings](#-recommended-settings)
13 | - [🛠️ Dependencies and Installation](#-dependencies-and-installation)
14 | - [🚀 Inference](#-inference)
15 | - [🚀 Training](#-training)
16 |
17 | ---
18 |
19 | ## 🔥🔥🔥 News!!
20 |
21 | - 27 February 2025: We release the training code of HunyuanVideo Keyframe Control Lora along with an accompanying [blog post](https://insiders.dashtoon.com/introducing-hunyuan-keyframe-lora-open-source-keyframe-based-video-generation/).
22 | - 24 February 2025: We release the inference code and model weights of HunyuanVideo Keyframe Control Lora. [Download](https://huggingface.co/dashtoon/hunyuan-video-keyframe-control-lora/tree/main).
23 |
24 | ## Abstract
25 |
26 | HunyuanVideo Keyframe Control Lora is an adapter for HunyuanVideo T2V model for keyframe-based video generation. Our architecture builds upon existing models, introducing key enhancements to optimize keyframe-based video generation:
27 |
28 | - We modify the input patch embedding projection layer to effectively incorporate keyframe information. By adjusting the convolutional input parameters, we enable the model to process image inputs within the Diffusion Transformer (DiT) framework.
29 | - We apply Low-Rank Adaptation (LoRA) across all linear layers and the convolutional input layer. This approach facilitates efficient fine-tuning by introducing low-rank matrices that approximate the weight updates, thereby preserving the base model's foundational capabilities while reducing the number of trainable parameters.
30 | - The model is conditioned on user-defined keyframes, allowing precise control over the generated video's start and end frames. This conditioning ensures that the generated content aligns seamlessly with the specified keyframes, enhancing the coherence and narrative flow of the video.
31 |
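The first two points can be pictured with a minimal sketch. It assumes the diffusers `HunyuanVideoTransformer3DModel`, whose patch embedding is exposed as `x_embedder.proj` (a `Conv3d`), and a recent `peft` with Conv3d LoRA support; the doubled channel count and the `target_modules` list are illustrative rather than taken from the training script:

```python
import torch
import torch.nn as nn
from diffusers import HunyuanVideoTransformer3DModel
from peft import LoraConfig

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Widen the patch-embedding Conv3d so keyframe/condition latents can be concatenated with
# the noisy latents along the channel dimension. New input channels are zero-initialised,
# so the expanded layer initially behaves exactly like the original one.
old_proj = transformer.x_embedder.proj
new_proj = nn.Conv3d(
    old_proj.in_channels * 2,  # illustrative: noisy latents + condition latents
    old_proj.out_channels,
    kernel_size=old_proj.kernel_size,
    stride=old_proj.stride,
    bias=old_proj.bias is not None,
).to(dtype=old_proj.weight.dtype)
with torch.no_grad():
    new_proj.weight.zero_()
    new_proj.weight[:, : old_proj.in_channels] = old_proj.weight
    if old_proj.bias is not None:
        new_proj.bias.copy_(old_proj.bias)
transformer.x_embedder.proj = new_proj

# LoRA over the attention linear layers and the expanded conv input layer.
transformer.add_adapter(
    LoraConfig(
        r=128,
        lora_alpha=128,
        init_lora_weights=True,
        target_modules=["to_q", "to_k", "to_v", "to_out.0", "x_embedder.proj"],  # illustrative subset
    )
)
```
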
32 | ## 🎥 Demo
33 |
34 | **Click on the first column images to view the generated videos**
35 |
36 | | Generated Video | Image 1 | Image 2 |
37 | | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------- |
38 | | _(demo videos and keyframe images are omitted from this text export)_ | | |
42 |
43 | ## 📜 Recommended Settings
44 |
45 | 1. The model works best on human subjects. Single subject images work slightly better.
46 | 2. It is recommended to use one of the following generation resolutions: `720x1280`, `544x960`, `1280x720`, `960x544`.
47 | 3. It is recommended to set the number of frames between 33 and 97. You can go up to 121 frames as well (but this is not tested much).
48 | 4. Prompting helps a lot, but the model works even without a prompt. The prompt can be as simple as the name of the object you want to generate, or it can be detailed.
49 | 5. `num_inference_steps` is recommended to be 50, but for faster results you can use 30 as well. Anything less than 30 is not recommended.
50 |
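Restating the list above as a small helper for sanity-checking inference settings (purely illustrative; this helper is not part of the repository):

```python
# Recommended resolutions, taken verbatim from the list above.
RECOMMENDED_RESOLUTIONS = {(720, 1280), (544, 960), (1280, 720), (960, 544)}


def check_settings(height: int, width: int, num_frames: int, num_inference_steps: int) -> None:
    """Warn when settings fall outside the recommendations listed above."""
    if (height, width) not in RECOMMENDED_RESOLUTIONS:
        print(f"warning: {height}x{width} is not one of the recommended resolutions")
    if not 33 <= num_frames <= 121:
        print("warning: 33 to 97 frames is recommended (up to 121 is possible but less tested)")
    if num_inference_steps < 30:
        print("warning: fewer than 30 inference steps is not recommended (50 is recommended)")


check_settings(height=960, width=544, num_frames=33, num_inference_steps=50)
```
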
51 | ## 🛠️ Dependencies and Installation
52 |
53 | Begin by cloning the repository:
54 |
55 | ```shell
56 | git clone https://github.com/dashtoon/hunyuan-video-keyframe-control-lora.git
57 | cd hunyuan-video-keyframe-control-lora
58 | ```
59 |
60 | ### Installation Guide for Linux
61 |
62 | We recommend CUDA version 12.4.
63 |
64 | Conda's installation instructions are available [here](https://docs.anaconda.com/free/miniconda/index.html).
65 |
66 | ```shell
67 | bash setup_env.sh
68 | ```
69 |
70 | ## 🚀 Inference
71 |
72 | The model weights can be downloaded from [Huggingface](https://huggingface.co/dashtoon/hunyuan-video-keyframe-control-lora)
73 |
74 | You can run inference using the provided script. The script uses `flash_attn` but can also be modified to use `sage_attn`. Running the command below will output a video saved as `output.mp4`.
75 |
76 | - An NVIDIA GPU with CUDA support is required.
77 | - The model is tested on a single 80G GPU.
78 | - **Minimum**: The minimum GPU memory required is ~60GB for 720x1280 at 129 frames and ~45GB for 544x960 at 129 frames.
79 | - **Recommended**: We recommend using a GPU with 80GB of memory for better generation quality.
80 | - Tested operating system: Linux
81 |
82 | ```shell
83 | export BASE_MODEL="hunyuanvideo-community/HunyuanVideo"
84 | export LORA_PATH=""
85 | export IMAGE_1=""
86 | export IMAGE_2=""
87 | export PROMPT=""
88 | export HEIGHT=960
89 | export WIDTH=544
90 | export n_FRAMES=33
91 |
92 | python hv_control_lora_inference.py \
93 | --model $BASE_MODEL \
94 | --lora $LORA_PATH \
95 | --frame1 $IMAGE_1 --frame2 $IMAGE_2 --prompt "$PROMPT" --frames $n_FRAMES \
96 | --height $HEIGHT --width $WIDTH \
97 | --steps 50 \
98 | --guidance 6.0 \
99 | --seed 123143153 \
100 | --output output.mp4
101 | ```
102 |
103 | ## 🚀 Training
104 |
105 | ### Dataset Preparation
106 |
107 | It is recommended to have at least one GPU with 80GB of VRAM. We use MosaicML Streaming for caching our data. We expect the original data in the following format. Running the `tree` command, you should see:
108 |
109 | ```
110 | dataset
111 | ├── metadata.csv
112 | ├── videos
113 | │   ├── 00000.mp4
114 | │   ├── 00001.mp4
115 | │   └── ...
116 | ```
117 |
118 | The CSV can contain any number of columns, but due to limited support at the moment we only make use of the caption and video columns. The CSV should look like this:
119 |
120 | ```
121 | caption,video_file,other_column1,other_column2
122 | A black and white animated sequence featuring a rabbit, named Rabbity Ribfried, and an anthropomorphic goat in a musical, playful environment, showcasing their evolving interaction.,videos/00000.mp4,...,...
123 | ```
124 |
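Before caching, it can help to sanity-check the metadata. A minimal sketch, assuming pandas is installed and that paths in the video column are relative to `--base_dir` (the column names match the example above; adjust them if your CSV differs):

```python
from pathlib import Path

import pandas as pd

base_dir = Path("dataset")
df = pd.read_csv(base_dir / "metadata.csv")

# Verify that the two columns the caching script reads are present and that the videos exist.
assert {"caption", "video_file"} <= set(df.columns), "expected 'caption' and 'video_file' columns"
missing = [p for p in df["video_file"] if not (base_dir / p).exists()]
print(f"{len(df)} rows, {len(missing)} missing video files")
```
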
125 | For the above format, you would run the following command to start caching the dataset:
126 |
127 | ```shell
128 | python tools/hv_cache_dataset.py \
129 | --csv "dataset/metadata.csv" \
130 | --base_dir "dataset" \
131 | --video_column video_file \
132 | --caption_column "caption" \
133 | --output_dir "dataset/mds_cache" \
134 | --bucket_reso \
135 | "1280x720x33" "1280x720x65" "1280x720x97" "960x544x33" "960x544x65" "960x544x97" \
136 | "720x1280x33" "720x1280x65" "720x1280x97" "544x960x33" "544x960x65" "544x960x97" \
137 | --min_bucket_count 100 \
138 | --head_frame 0
139 | ```
140 |
141 | - `bucket_reso`: specifies the bucket resolutions to train on, in the format `WxHxF` (see the parsing sketch below).
142 | - `head_frame`: the initial frame from which to start extracting frames from a video.
143 |
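For clarity, a bucket string decomposes into width, height, and frame count (a trivial helper for illustration, not part of the tools scripts):

```python
def parse_bucket(reso: str) -> tuple[int, int, int]:
    """Split a bucket string of the form WxHxF into (width, height, frames)."""
    width, height, frames = (int(v) for v in reso.split("x"))
    return width, height, frames


print(parse_bucket("1280x720x33"))  # -> (1280, 720, 33)
```
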
144 | **NOTE:** It is recommended to first split your videos into separate scenes and ensure there is continuity within each scene. [This](https://github.com/aigc-apps/EasyAnimate/tree/main/easyanimate/video_caption) is a good starting point for video dataset preparation.
145 |
146 | The next command will start caching the LLM embeddings and the VAE latents.
147 |
148 | ```shell
149 | NUM_GPUS=8
150 | MIXED_PRECISION="bf16"
151 | accelerate launch --num_processes=$NUM_GPUS --mixed_precision=$MIXED_PRECISION --main_process_port=12345 \
152 | tools/hv_precompute_latents_dist.py \
153 | --pretrained_model_name_or_path="hunyuanvideo-community/HunyuanVideo" \
154 | --mds_data_path "dataset/mds_cache" \
155 | --output_dir "dataset/mds_cache_latents" \
156 | --recursive
157 | ```
158 |
159 | Now you need to add the paths to all the MDS latent folders to the `./configs/config_defaults.yaml` config file under `data.local` as a list. The latent cache is stored under the `--output_dir` folder in folders such as `1280x720x33_00`, where `1280` is the width of the video, `720` is the height, `33` is the number of frames, and `00` is the GPU id.
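For example, if two latent caches were written under `dataset/mds_cache_latents`, the `data` section might look like this (paths are illustrative):

```yaml
data:
  local:
    - "dataset/mds_cache_latents/1280x720x33_00"
    - "dataset/mds_cache_latents/1280x720x33_01"
```
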
160 | Now we are ready to start training!
161 |
162 | ### Starting a training run
163 |
164 | ```shell
165 | NUM_GPUS=8
166 | MIXED_PRECISION="bf16"
167 | EXPERIMENT_NAME="my_first_run"
168 | OUTPUT_DIR="outputs/"
169 | CONFIG_PATH="./configs/config_defaults.yaml"
170 | NUM_EPOCHS=1
171 |
172 | accelerate launch --num_processes=$NUM_GPUS --mixed_precision=$MIXED_PRECISION --main_process_port=12345 \
173 | hv_train_control_lora.py \
174 | --config_path $CONFIG_PATH \
175 | --experiment.run_id=$EXPERIMENT_NAME \
176 | --experiment.output_dirpath=$OUTPUT_DIR \
177 | --network.train_norm_layers=False \
178 | --network.lora_dropout=0.05 \
179 | --hparams.ema.use_ema=False \
180 | --hparams.num_train_epochs=1
181 | ```
182 |
183 | ## Acknowledgements
184 |
185 | - We would like to thank the contributors to the [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [FLUX](https://github.com/black-forest-labs/flux), [Llama](https://github.com/meta-llama/llama), [LLaVA](https://github.com/haotian-liu/LLaVA), [Xtuner](https://github.com/InternLM/xtuner), [diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) repositories, for their open research and exploration.
186 |
187 | - We build on top of a body of great open-source libraries: transformers, accelerate, peft, diffusers, bitsandbytes, torchao, deepspeed, mosaicml-streaming -- to name a few.
188 |
189 | ## Star History
190 |
191 | [](https://www.star-history.com/#dashtoon/hunyuan-video-keyframe-control-lora&Timeline)
192 |
--------------------------------------------------------------------------------
/attn_processor.py:
--------------------------------------------------------------------------------
1 | import flash_attn
2 | import torch
3 | from accelerate.logging import get_logger
4 | from diffusers.models.embeddings import apply_rotary_emb
5 | from flash_attn.flash_attn_interface import _flash_attn_forward, flash_attn_varlen_func
6 |
7 | logger = get_logger(__name__)
8 |
9 |
10 | def get_cu_seqlens(attention_mask):
11 | """Calculate cu_seqlens_q, cu_seqlens_kv
12 |
13 | Args:
14 | attention_mask (torch.Tensor): boolean attention mask of shape: [B, 1, 1, N]
15 |
16 | Returns:
17 | torch.Tensor: the calculated cu_seqlens for flash attention
18 | """
19 | batch_size = attention_mask.shape[0]
20 | text_len = attention_mask.sum(dim=-1, dtype=torch.int)
21 | max_len = attention_mask.shape[-1]
22 |
23 | cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
24 |
25 | for i in range(batch_size):
26 | s = text_len[i]
27 | s1 = i * max_len + s
28 | s2 = (i + 1) * max_len
29 | cu_seqlens[2 * i + 1] = s1
30 | cu_seqlens[2 * i + 2] = s2
31 |
32 | return cu_seqlens
33 |
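# Illustrative example (added note, not part of the original file): for a boolean mask of
# shape [2, 1, 1, 5] with per-sample text lengths [3, 5], get_cu_seqlens returns
# tensor([0, 3, 5, 10, 10], dtype=torch.int32); each sample contributes the end of its text
# segment and the end of its padded slot, i.e. [0, len_0, max_len, max_len + len_1, 2 * max_len].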
34 |
35 | class HunyuanVideoFlashAttnProcessor:
36 | def __init__(self):
37 | pass
38 |
39 | def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
40 | if attn.add_q_proj is None and encoder_hidden_states is not None:
41 | hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
42 |
43 | query = attn.to_q(hidden_states)
44 | key = attn.to_k(hidden_states)
45 | value = attn.to_v(hidden_states)
46 |
47 | query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
48 | key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
49 | value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
50 |
51 | if attn.norm_q is not None:
52 | query = attn.norm_q(query)
53 | if attn.norm_k is not None:
54 | key = attn.norm_k(key)
55 |
56 | if image_rotary_emb is not None:
57 | if attn.add_q_proj is None and encoder_hidden_states is not None:
58 | query = torch.cat(
59 | [
60 | apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
61 | query[:, :, -encoder_hidden_states.shape[1] :],
62 | ],
63 | dim=2,
64 | )
65 | key = torch.cat(
66 | [
67 | apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
68 | key[:, :, -encoder_hidden_states.shape[1] :],
69 | ],
70 | dim=2,
71 | )
72 | else:
73 | query = apply_rotary_emb(query, image_rotary_emb)
74 | key = apply_rotary_emb(key, image_rotary_emb)
75 |
76 | batch_size = hidden_states.shape[0]
77 | img_seq_len = hidden_states.shape[1]
78 | txt_seq_len = 0
79 |
80 | if attn.add_q_proj is not None and encoder_hidden_states is not None:
81 | encoder_query = attn.add_q_proj(encoder_hidden_states)
82 | encoder_key = attn.add_k_proj(encoder_hidden_states)
83 | encoder_value = attn.add_v_proj(encoder_hidden_states)
84 |
85 | encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
86 | encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
87 | encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
88 |
89 | if attn.norm_added_q is not None:
90 | encoder_query = attn.norm_added_q(encoder_query)
91 | if attn.norm_added_k is not None:
92 | encoder_key = attn.norm_added_k(encoder_key)
93 |
94 | query = torch.cat([query, encoder_query], dim=2)
95 | key = torch.cat([key, encoder_key], dim=2)
96 | value = torch.cat([value, encoder_value], dim=2)
97 |
98 | txt_seq_len = encoder_hidden_states.shape[1]
99 |
100 | max_seqlen_q = max_seqlen_kv = img_seq_len + txt_seq_len
101 | cu_seqlens_q = cu_seqlens_kv = get_cu_seqlens(attention_mask)
102 |
103 | query = query.transpose(1, 2).reshape(-1, query.shape[1], query.shape[3])
104 | key = key.transpose(1, 2).reshape(-1, key.shape[1], key.shape[3])
105 | value = value.transpose(1, 2).reshape(-1, value.shape[1], value.shape[3])
106 | hidden_states = flash_attn_varlen_func(
107 | query,
108 | key,
109 | value,
110 | cu_seqlens_q,
111 | cu_seqlens_kv,
112 | max_seqlen_q,
113 | max_seqlen_kv,
114 | softmax_scale=None,
115 | dropout_p=0.0,
116 | causal=False,
117 | )
118 |
119 | hidden_states = hidden_states.reshape(batch_size, max_seqlen_q, -1)
120 | hidden_states = hidden_states.to(query.dtype)
121 |
122 | if encoder_hidden_states is not None:
123 | hidden_states, encoder_hidden_states = (
124 | hidden_states[:, : -encoder_hidden_states.shape[1]],
125 | hidden_states[:, -encoder_hidden_states.shape[1] :],
126 | )
127 |
128 | if getattr(attn, "to_out", None) is not None:
129 | hidden_states = attn.to_out[0](hidden_states)
130 | hidden_states = attn.to_out[1](hidden_states)
131 |
132 | if getattr(attn, "to_add_out", None) is not None:
133 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
134 |
135 | return hidden_states, encoder_hidden_states
136 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import json
2 | import typing
3 | import uuid
4 | from dataclasses import asdict, dataclass
5 |
6 | from peft import LoraConfig
7 | from pyrallis import field
8 |
9 |
10 | @dataclass
11 | class BaseConfig:
12 | def get(self, attribute_name, default=None):
13 | return getattr(self, attribute_name, default)
14 |
15 | def pop(self, attribute_name, default=None):
16 | if hasattr(self, attribute_name):
17 | value = getattr(self, attribute_name)
18 | delattr(self, attribute_name)
19 | return value
20 | else:
21 | return default
22 |
23 | def __str__(self):
24 | return json.dumps(asdict(self), indent=4)
25 |
26 |
27 | @dataclass
28 | class DataLoaderKwargs(BaseConfig):
29 | """Configuration for data loading parameters"""
30 |
31 | drop_last: bool = field(default=True) # Whether to drop the last incomplete batch
32 | num_workers: int = field(default=8) # Number of worker processes for data loading
33 | persistent_workers: bool = field(default=True) # Keep worker processes alive between epochs
34 | pin_memory: bool = field(default=True) # Pin memory for faster data transfer to GPU
35 | prefetch_factor: int = field(default=2) # Number of batches to prefetch per worker
36 |
37 |
38 | @dataclass
39 | class StreamingKwargs(BaseConfig):
40 | """Configuration for data streaming parameters"""
41 |
42 | cache_limit: str = field(default="5tb") # Maximum cache size limit
43 | download_timeout: int = field(default=12000) # Timeout in seconds for downloads
44 | num_canonical_nodes: typing.Optional[int] = field(default=None) # Number of canonical nodes to use
45 | shuffle: bool = field(default=True) # Whether to shuffle the data
46 | batching_method: str = field(default="per_stream") # Method used for batching data
47 |
48 |
49 | @dataclass
50 | class DataConfig(BaseConfig):
51 | """Configuration for data sources and processing"""
52 |
53 | remote: typing.Optional[typing.List[typing.Optional[str]]] = field(default=None) # Remote data source paths
54 | local: typing.Optional[typing.List[typing.Optional[str]]] = field(default=None) # Local data source paths
55 | batch_size: int = field(default=1) # Training batch size
56 | video_key: str = field(default="video") # Key for video data in dataset
57 | caption_key: str = field(default="caption") # Key for caption data in dataset
58 | latents_key: str = field(default="latents") # Key for latents in dataset
59 | prompt_embeds_key: str = field(default="prompt_embeds") # Key for prompt embeddings
60 | latents_cond_key: str = field(default="latents_cond") # Key for conditional latents
61 | prompt_attention_mask_key: str = field(default="prompt_attention_mask") # Key for prompt attention mask
62 | pooled_prompt_embeds_key: str = field(default="pooled_prompt_embeds") # Key for pooled prompt embeddings
63 | repeat: typing.Optional[typing.List] = field(default=None, is_mutable=True) # Number of times to repeat dataset
64 | choose: typing.Optional[typing.List] = field(default=None, is_mutable=True) # Indices to choose from dataset
65 | streaming_kwargs: StreamingKwargs = field(
66 | default_factory=StreamingKwargs, is_mutable=True
67 | ) # Streaming configuration
68 | dataloader_kwargs: DataLoaderKwargs = field(
69 | default_factory=DataLoaderKwargs, is_mutable=True
70 | ) # DataLoader configuration
71 |
72 |
73 | @dataclass
74 | class PretrainedModelConfig(BaseConfig):
75 | """Configuration for pretrained model loading"""
76 |
77 | pretrained_model_name_or_path: str = "hunyuanvideo-community/HunyuanVideo" # Path or name of pretrained model
78 | revision: typing.Optional[str] = field(default=None) # Specific model revision to use
79 | variant: typing.Optional[str] = field(default=None) # Specific model variant to use
80 |
81 |
82 | @dataclass
83 | class NetworkConfig(BaseConfig):
84 | """Configuration for network architecture"""
85 |
86 | lora_rank: int = field(default=16) # Rank for LoRA adaptation
87 | lora_alpha: int = field(default=16) # Alpha scaling for LoRA
88 | target_modules: typing.Optional[typing.List[str]] = field(default=None, is_mutable=True) # Target modules for LoRA
89 | lora_dropout: float = field(default=0.0) # Dropout probability for LoRA layers
90 | train_norm_layers: bool = field(default=False) # Whether to train normalization layers
91 | init_lora_weights: typing.Union[bool, str] = field(
92 | default=True
93 | ) # typing.Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "loftq"]
94 | lora_layers: typing.Optional[str] = field(default=None)
95 |
96 |
97 | @dataclass
98 | class ExperimentConfig(BaseConfig):
99 | """Configuration for experiment tracking"""
100 |
101 | output_dirpath: str = field(default="./outputs") # Directory path for outputs
102 | random_seed: int = field(default=42) # Random seed for reproducibility
103 | run_id: str = field(default_factory=lambda: uuid.uuid4().hex[:8], is_mutable=True) # Unique run identifier
104 | name: str = field(default="whatever") # Name of the experiment
105 | ic_debug: bool = field(default=False) # Whether to use ic for debugging
106 |
107 |
108 | @dataclass
109 | class FlowMatchConfig(BaseConfig):
110 | """Configuration for flow matching parameters"""
111 |
112 | discrete_flow_shift: float = field(default=7.0) # Shift for discrete flow
113 | timestep_sampling: str = field(default="sigma") # "uniform", "sigmoid", "logit_normal"
114 | weighting_scheme: str = field(default="none") # "sigma_sqrt", "cosmap", "none"
115 | sigmoid_scale: float = field(default=1.0) # Scale for sigmoid function
116 | logit_mean: float = field(default=0.0) # Mean for logit distribution
117 | logit_std: float = field(default=1.0) # Standard deviation for logit distribution
118 |
119 |
120 | @dataclass
121 | class EmaConfig(BaseConfig):
122 | """Ema configuration"""
123 |
124 | use_ema: bool = field(default=False)
125 | ema_decay: float = field(default=0.99)
126 | ema_foreach_disable: bool = field(default=False)
127 | ema_device: str = field(default="accelerator") # | typing.Literal["accelerator", "cpu"]
128 | ema_cpu_only: bool = field(default=False)
129 | ema_update_interval: typing.Optional[int] = field(default=None)
130 |
131 |
132 | @dataclass
133 | class TrainingHyperParametersConfig(BaseConfig):
134 | """Configuration for training hyperparameters"""
135 |
136 | mixed_precision: str = field(default="bf16") # Mixed precision training type
137 | gradient_checkpointing: bool = field(default=True) # Whether to use gradient checkpointing
138 | gradient_accumulation_steps: int = field(default=1) # Number of gradient accumulation steps
139 | learning_rate: float = field(default=1e-04) # Learning rate for training
140 | optimizer_type: str = field(default="torch.optim.AdamW") # Type of optimizer to use
141 | optimizer_args: typing.List[str] = field(default=[], is_mutable=True) # Additional optimizer arguments
142 | max_grad_norm: int = field(default=1.0) # Maximum gradient norm for clipping
143 | grad_clip_method: str = field(default="norm")
144 | lr_scheduler: str = field(default="constant") # Learning rate scheduler type
145 | lr_warmup_steps: int = field(default=0) # Number of warmup steps
146 | lr_scheduler_num_cycles: int = field(default=1) # Number of scheduler cycles
147 | lr_scheduler_power: float = field(default=0.9) # Power for scheduler
148 | guidance_scale: int = field(default=1.0) # Scale for guidance
149 | flow_match: FlowMatchConfig = field(default_factory=FlowMatchConfig, is_mutable=True) # Flow matching configuration
150 | num_train_epochs: typing.Optional[int] = field(default=None) # Number of training epochs
151 | max_train_steps: typing.Optional[int] = field(default=None) # Maximum number of training steps
152 | caption_dropout_p: float = field(default=0.0) # Dropout probability for captions
153 | ema: EmaConfig = field(default_factory=EmaConfig, is_mutable=True) # EMA configuration
154 | gradient_precision: str = field(
155 | default="accelerator"
156 | ) # gradient precision from LLAMA paper | typing.Literal["accelerator", "fp32"]
157 |
158 |
159 | @dataclass
160 | class CheckpointConfig(BaseConfig):
161 | """Configuration for model checkpointing"""
162 |
163 | save_every_n_steps: typing.Optional[int] = field(default=None) # Save checkpoint every N steps
164 | save_last_n_steps: typing.Optional[int] = field(default=None) # Keep last N checkpoints
165 | resume_from_checkpoint: typing.Optional[str] = field(default=None) # Path to checkpoint to resume from
166 |
167 |
168 | @dataclass
169 | class TorchDDPKwargs(BaseConfig):
170 | """Configuration for torch distributed parameters"""
171 |
172 | backend: str = field(default="nccl")
173 | find_unused_parameters: bool = field(default=False)
174 | gradient_as_bucket_view: bool = field(default=False)
175 | static_graph: bool = field(default=False)
176 |
177 |
178 | @dataclass
179 | class Config(BaseConfig):
180 | """Main configuration class combining all sub-configurations"""
181 |
182 | experiment: ExperimentConfig = field(default_factory=ExperimentConfig, is_mutable=True) # Experiment configuration
183 | data: DataConfig = field(default_factory=DataConfig, is_mutable=True) # Data configuration
184 | model: PretrainedModelConfig = field(default_factory=PretrainedModelConfig, is_mutable=True) # Model configuration
185 | network: NetworkConfig = field(default_factory=NetworkConfig, is_mutable=True) # Network configuration
186 | hparams: TrainingHyperParametersConfig = field(
187 | default_factory=TrainingHyperParametersConfig, is_mutable=True
188 | ) # Training hyperparameters
189 | checkpointing: CheckpointConfig = field(
190 | default_factory=CheckpointConfig, is_mutable=True
191 | ) # Checkpointing configuration
192 | ddp_kwargs: TorchDDPKwargs = field(default_factory=TorchDDPKwargs, is_mutable=True)
193 |
194 |
195 | if __name__ == "__main__":
196 | import pyrallis
197 |
198 | cfg = pyrallis.parse(config_class=Config)
199 | print(f"Training {cfg}")
200 |
--------------------------------------------------------------------------------
/configs/accelerate/uncompiled_8.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: MULTI_GPU
4 | downcast_bf16: "no"
5 | enable_cpu_affinity: false
6 | gpu_ids: all
7 | machine_rank: 0
8 | main_training_function: main
9 | mixed_precision: bf16
10 | num_machines: 1
11 | num_processes: 8
12 | rdzv_backend: static
13 | same_network: true
14 | tpu_env: []
15 | tpu_use_cluster: false
16 | tpu_use_sudo: false
17 | use_cpu: false
18 |
--------------------------------------------------------------------------------
/configs/accelerate/uncompiled_8_ds.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_config_file: ./configs/deepspeed/zero_stage2_config.json
5 | zero3_init_flag: false
6 | distributed_type: DEEPSPEED
7 | downcast_bf16: "no"
8 | enable_cpu_affinity: false
9 | machine_rank: 0
10 | main_training_function: main
11 | num_machines: 1
12 | num_processes: 8
13 | rdzv_backend: static
14 | same_network: true
15 | tpu_env: []
16 | tpu_use_cluster: false
17 | tpu_use_sudo: false
18 | use_cpu: false
19 |
--------------------------------------------------------------------------------
/configs/config_defaults.yaml:
--------------------------------------------------------------------------------
1 | experiment:
2 | output_dirpath: "./i2v_outputs"
3 | random_seed: 4201
4 | name: "i2v"
5 | run_id: "my_first_experiment"
6 |
7 | model:
8 | pretrained_model_name_or_path: "hunyuanvideo-community/HunyuanVideo"
9 | revision: null
10 | variant: null
11 |
12 | data:
13 | local:
14 | batch_size: 1
15 | video_key: "video"
16 | caption_key: "caption"
17 | latents_key: "latents"
18 | prompt_embeds_key: "prompt_embeds"
19 | latents_cond_key: "latents_cond"
20 | prompt_attention_mask_key: "prompt_attention_mask"
21 | pooled_prompt_embeds_key: "pooled_prompt_embeds"
22 |
23 | dataloader_kwargs:
24 | drop_last: false
25 | num_workers: 8
26 | persistent_workers: true
27 | pin_memory: true
28 | prefetch_factor: 2
29 |
30 | streaming_kwargs:
31 | shuffle: true
32 | batching_method: "per_stream"
33 | num_canonical_nodes: 8
34 |
35 | network:
36 | lora_rank: 128
37 | lora_alpha: 128
38 | lora_dropout: 0.00
39 | init_lora_weights: true
40 | train_norm_layers: false
41 | lora_layers: "all-linear"
42 |
43 | hparams:
44 | caption_dropout_p: 0.00
45 | mixed_precision: "bf16"
46 | gradient_checkpointing: true
47 | gradient_accumulation_steps: 1
48 | learning_rate: 2e-4
49 | optimizer_type: "optimi-stableadamw"
50 | optimizer_args: ["weight_decay=1e-2", "eps=1e-8", "betas=(0.9, 0.95)"]
51 | max_grad_norm: 0.0
52 | lr_scheduler: "constant_with_warmup"
53 | lr_warmup_steps: 100
54 | guidance_scale: 1.0
55 | num_train_epochs: 1
56 | max_train_steps: null
57 | gradient_precision: "accelerator"
58 |
59 | flow_match:
60 | weighting_scheme: "none"
61 | timestep_sampling: "logit_normal"
62 | sigmoid_scale: 1.0
63 | logit_mean: 0.0
64 | logit_std: 1.0
65 | discrete_flow_shift: 7.0
66 |
67 | ema:
68 | use_ema: true
69 | ema_device: "accelerator"
70 | ema_decay: 0.99
71 |
72 | checkpointing:
73 | save_every_n_steps: 500
74 | save_last_n_steps: 1500
75 | resume_from_checkpoint: null
76 |
--------------------------------------------------------------------------------
/configs/deepspeed/zero_stage2_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 2,
4 | "contiguous_gradients": true,
5 | "overlap_comm": true,
6 | "reduce_scatter": true,
7 | "reduce_bucket_size": 5e8,
8 | "allgather_bucket_size": 5e8
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "gradient_accumulation_steps": 1.0,
14 | "train_micro_batch_size_per_gpu": 1.0,
15 | "gradient_clipping": 1.0,
16 | "wall_clock_breakdown": false,
17 | "zero_allow_untested_optimizer": true
18 | }
19 |
--------------------------------------------------------------------------------
/ema.py:
--------------------------------------------------------------------------------
1 | # copy-paste from: https://github.com/bghira/SimpleTuner/blob/main/helpers/training/ema.py
2 | import contextlib
3 | import copy
4 | import logging
5 | import os
6 | from typing import Any, Dict, Iterable, Optional, Union
7 |
8 | import torch
9 | import transformers
10 | from accelerate.logging import get_logger
11 | from diffusers.utils import is_transformers_available
12 | from diffusers.utils.deprecation_utils import deprecate
13 |
14 | logger = get_logger(__name__)
15 |
16 |
17 | def should_update_ema(args, step):
18 | if args.ema_update_interval is None:
19 | # If the EMA update interval is not set, always update the EMA.
20 | return True
21 | else:
22 | should_update = step % args.ema_update_interval == 0
23 | if should_update:
24 | logger.debug("Updating EMA weights...")
25 | return should_update
26 |
27 |
28 | class EMAModel:
29 | """
30 | Exponential Moving Average of models weights
31 | """
32 |
33 | def __init__(
34 | self,
35 | args,
36 | accelerator,
37 | parameters: Iterable[torch.nn.Parameter],
38 | decay: float = 0.9999,
39 | min_decay: float = 0.0,
40 | update_after_step: int = 0,
41 | use_ema_warmup: bool = False,
42 | inv_gamma: Union[float, int] = 1.0,
43 | power: Union[float, int] = 2 / 3,
44 | foreach: bool = True,
45 | model_cls: Optional[Any] = None,
46 | model_config: Dict[str, Any] = None,
47 | **kwargs,
48 | ):
49 | """
50 | Args:
51 | parameters (Iterable[torch.nn.Parameter]): The parameters to track.
52 | decay (float): The decay factor for the exponential moving average.
53 | min_decay (float): The minimum decay factor for the exponential moving average.
54 | update_after_step (int): The number of steps to wait before starting to update the EMA weights.
55 | use_ema_warmup (bool): Whether to use EMA warmup.
56 | inv_gamma (float):
57 | Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
58 | power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
59 | foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
60 | device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
61 | weights will be stored on CPU.
62 |
63 | @crowsonkb's notes on EMA Warmup:
64 | If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
65 | to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
66 | gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
67 | at 215.4k steps).
68 | """
69 |
70 | if isinstance(parameters, torch.nn.Module):
71 | deprecation_message = (
72 | "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
73 | "Please pass the parameters of the module instead."
74 | )
75 | deprecate(
76 | "passing a `torch.nn.Module` to `ExponentialMovingAverage`",
77 | "1.0.0",
78 | deprecation_message,
79 | standard_warn=False,
80 | )
81 | parameters = parameters.parameters()
82 |
83 | # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
84 | use_ema_warmup = True
85 |
86 | if kwargs.get("max_value", None) is not None:
87 | deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
88 | deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
89 | decay = kwargs["max_value"]
90 |
91 | if kwargs.get("min_value", None) is not None:
92 | deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead."
93 | deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False)
94 | min_decay = kwargs["min_value"]
95 |
96 | parameters = list(parameters)
97 | self.shadow_params = [p.clone().detach() for p in parameters]
98 |
99 | if kwargs.get("device", None) is not None:
100 | deprecation_message = "The `device` argument is deprecated. Please use `to` instead."
101 | deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
102 | self.to(device=kwargs["device"])
103 |
104 | self.temp_stored_params = None
105 |
106 | self.decay = decay
107 | self.min_decay = min_decay
108 | self.update_after_step = update_after_step
109 | self.use_ema_warmup = use_ema_warmup
110 | self.inv_gamma = inv_gamma
111 | self.power = power
112 | self.optimization_step = 0
113 | self.cur_decay_value = None # set in `step()`
114 | self.foreach = foreach
115 |
116 | self.model_cls = model_cls
117 | self.model_config = model_config
118 | self.args = args
119 | self.accelerator = accelerator
120 | self.training = True # To emulate nn.Module's training mode
121 |
122 | def save_state_dict(self, path: str) -> None:
123 | """
124 | Save the EMA model's state directly to a file.
125 |
126 | Args:
127 | path (str): The file path where the EMA state will be saved.
128 | """
129 | # if the folder containing the path does not exist, create it
130 | os.makedirs(os.path.dirname(path), exist_ok=True)
131 | # grab state dict
132 | state_dict = self.state_dict()
133 | # save it using torch.save
134 | torch.save(state_dict, path)
135 | logger.info(f"EMA model state saved to {path}")
136 |
137 | def load_state_dict(self, path: str) -> None:
138 | """
139 | Load the EMA model's state from a file and apply it to this instance.
140 |
141 | Args:
142 | path (str): The file path from where the EMA state will be loaded.
143 | """
144 | state_dict = torch.load(path, map_location="cpu", weights_only=True)
145 |
146 | # Load metadata
147 | self.decay = state_dict.get("decay", self.decay)
148 | self.min_decay = state_dict.get("min_decay", self.min_decay)
149 | self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
150 | self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
151 | self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
152 | self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
153 | self.power = state_dict.get("power", self.power)
154 |
155 | # Load shadow parameters
156 | shadow_params = []
157 | idx = 0
158 | while f"shadow_params.{idx}" in state_dict:
159 | shadow_params.append(state_dict[f"shadow_params.{idx}"])
160 | idx += 1
161 |
162 | if len(shadow_params) != len(self.shadow_params):
163 | raise ValueError(
164 | f"Mismatch in number of shadow parameters: expected {len(self.shadow_params)}, "
165 | f"but found {len(shadow_params)} in the state dict."
166 | )
167 |
168 | for current_param, loaded_param in zip(self.shadow_params, shadow_params):
169 | current_param.data.copy_(loaded_param.data)
170 |
171 | logger.info(f"EMA model state loaded from {path}")
172 |
173 | @classmethod
174 | def from_pretrained(cls, path, model_cls) -> "EMAModel":
175 | _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
176 | model = model_cls.from_pretrained(path)
177 |
178 | ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)
179 |
180 | ema_model.load_state_dict(ema_kwargs)
181 | return ema_model
182 |
183 | def save_pretrained(self, path, max_shard_size: str = "10GB"):
184 | if self.model_cls is None:
185 | raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")
186 |
187 | if self.model_config is None:
188 | raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")
189 |
190 | model = self.model_cls.from_config(self.model_config)
191 | state_dict = self.state_dict(exclude_params=True)
192 | state_dict.pop("shadow_params", None)
193 |
194 | model.register_to_config(**state_dict)
195 | self.copy_to(model.parameters())
196 | model.save_pretrained(path, max_shard_size=max_shard_size)
197 |
198 | def get_decay(self, optimization_step: int = None) -> float:
199 | """
200 | Compute the decay factor for the exponential moving average.
201 | """
202 | if optimization_step is None:
203 | optimization_step = self.optimization_step
204 |
205 | step = max(0, optimization_step - self.update_after_step - 1)
206 |
207 | if step <= 0:
208 | return 0.0
209 |
210 | if self.use_ema_warmup:
211 | cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
212 | else:
213 | cur_decay_value = (1 + step) / (10 + step)
214 |
215 | cur_decay_value = min(cur_decay_value, self.decay)
216 | # make sure decay is not smaller than min_decay
217 | cur_decay_value = max(cur_decay_value, self.min_decay)
218 | return cur_decay_value
219 |
220 | @torch.no_grad()
221 | def step(self, parameters: Iterable[torch.nn.Parameter], global_step: int = None):
222 | if not should_update_ema(self.args, global_step):
223 | return
224 |
225 | if self.args.ema_device == "cpu" and not self.args.ema_cpu_only:
226 | # Move EMA to accelerator for faster update.
227 | self.to(device=self.accelerator.device, non_blocking=True)
228 | if isinstance(parameters, torch.nn.Module):
229 | deprecation_message = (
230 | "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
231 | "Please pass the parameters of the module instead."
232 | )
233 | deprecate(
234 | "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
235 | "1.0.0",
236 | deprecation_message,
237 | standard_warn=False,
238 | )
239 | parameters = parameters.parameters()
240 |
241 | parameters = list(parameters)
242 |
243 | if global_step is not None:
244 | # When we're updating the EMA periodically, we can't trust the counter.
245 | self.optimization_step = global_step
246 | else:
247 | self.optimization_step += 1
248 |
249 | # Compute the decay factor for the exponential moving average.
250 | decay = self.get_decay(self.optimization_step)
251 | self.cur_decay_value = decay
252 | one_minus_decay = 1 - decay
253 |
254 | context_manager = contextlib.nullcontext
255 | if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
256 | import deepspeed
257 |
258 | if self.foreach:
259 | if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
260 | context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
261 |
262 | with context_manager():
263 | params_grad = [param for param in parameters if param.requires_grad]
264 | s_params_grad = [
265 | s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
266 | ]
267 |
268 | if len(params_grad) < len(parameters):
269 | torch._foreach_copy_(
270 | [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad],
271 | [param for param in parameters if not param.requires_grad],
272 | non_blocking=True,
273 | )
274 |
275 | torch._foreach_sub_(
276 | s_params_grad,
277 | torch._foreach_sub(s_params_grad, params_grad),
278 | alpha=one_minus_decay,
279 | )
280 |
281 | else:
282 | for s_param, param in zip(self.shadow_params, parameters):
283 | if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
284 |                     context_manager = lambda p=param: deepspeed.zero.GatheredParameters(p, modifier_rank=None)  # keep it callable, like nullcontext
285 |
286 | with context_manager():
287 | if param.requires_grad:
288 | s_param.sub_(one_minus_decay * (s_param - param.to(s_param.device)))
289 | else:
290 | s_param.copy_(param)
291 | if self.args.ema_device == "cpu" and not self.args.ema_cpu_only:
292 | # Move back to CPU for safe-keeping.
293 | self.to(device=self.args.ema_device, non_blocking=True)
294 |
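    # Typical call order for `step` in a training loop (sketch; names are illustrative):
    #   loss.backward(); optimizer.step(); optimizer.zero_grad()
    #   ema_model.step(transformer.parameters(), global_step=global_step)
    # Passing `global_step` keeps the decay schedule correct when the EMA is only
    # updated every few optimizer steps instead of on every call.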
295 | def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
296 | """
297 | Copy current averaged parameters into given collection of parameters.
298 |
299 | Args:
300 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
301 | updated with the stored moving averages. If `None`, the parameters with which this
302 | `ExponentialMovingAverage` was initialized will be used.
303 | """
304 | parameters = list(parameters)
305 | if self.foreach:
306 | torch._foreach_copy_(
307 | [param.data for param in parameters],
308 | [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)],
309 | )
310 | else:
311 | for s_param, param in zip(self.shadow_params, parameters):
312 | param.data.copy_(s_param.to(param.device).data)
313 |
314 | def pin_memory(self) -> None:
315 | r"""
316 | Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for
317 | offloading EMA params to the host.
318 | """
319 | if torch.backends.mps.is_available():
320 | logger.warning("Apple silicon does not support pinned memory. Skipping.")
321 | return
322 |
323 | if self.args.ema_cpu_only:
324 | return
325 |
326 |         # Pinning can fail in some environments, but try it so non-blocking host/device copies can overlap with compute.
327 | self.shadow_params = [p.pin_memory() for p in self.shadow_params]
328 |
329 | def to(self, *args, **kwargs):
330 | for param in self.shadow_params:
331 | param.data = param.data.to(*args, **kwargs)
332 | return self
333 |
334 | def cuda(self, device=None):
335 | return self.to(device="cuda" if device is None else f"cuda:{device}")
336 |
337 | def cpu(self):
338 | return self.to(device="cpu")
339 |
340 | def state_dict(self, destination=None, prefix="", keep_vars=False, exclude_params: bool = False):
341 | r"""
342 | Returns a dictionary containing a whole state of the EMA model.
343 | """
344 | state_dict = {
345 | "decay": self.decay,
346 | "min_decay": self.min_decay,
347 | "optimization_step": self.optimization_step,
348 | "update_after_step": self.update_after_step,
349 | "use_ema_warmup": self.use_ema_warmup,
350 | "inv_gamma": self.inv_gamma,
351 | "power": self.power,
352 | }
353 | if exclude_params:
354 | return state_dict
355 | for idx, param in enumerate(self.shadow_params):
356 | state_dict[f"{prefix}shadow_params.{idx}"] = param if keep_vars else param.detach()
357 | return state_dict
358 |
359 | def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
360 | r"""
361 | Save the current parameters for restoring later.
362 | """
363 | self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
364 |
365 | def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
366 | r"""
367 | Restore the parameters stored with the `store` method.
368 | """
369 | if self.temp_stored_params is None:
370 |             raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
371 | if self.foreach:
372 | torch._foreach_copy_(
373 | [param.data for param in parameters],
374 | [c_param.data for c_param in self.temp_stored_params],
375 | )
376 | else:
377 | for c_param, param in zip(self.temp_stored_params, parameters):
378 | param.data.copy_(c_param.data)
379 |
380 | # Better memory-wise.
381 | self.temp_stored_params = None
382 |
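    # Sketch of the usual `store` / `copy_to` / `restore` evaluation pattern
    # (illustrative; not tied to this repo's training loop):
    #   ema_model.store(model.parameters())     # stash the live weights
    #   ema_model.copy_to(model.parameters())   # run eval / save with EMA weights
    #   ema_model.restore(model.parameters())   # put the live weights back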
383 | def parameter_count(self) -> int:
384 | return sum(p.numel() for p in self.shadow_params)
385 |
386 | # Implementing nn.Module methods to emulate its behavior
387 |
388 | def named_children(self):
389 | # No child modules
390 | return iter([])
391 |
392 | def children(self):
393 | return iter([])
394 |
395 | def modules(self):
396 | yield self
397 |
398 | def named_modules(self, memo=None, prefix=""):
399 | yield prefix, self
400 |
401 | def parameters(self, recurse=True):
402 | return iter(self.shadow_params)
403 |
404 | def named_parameters(self, prefix="", recurse=True):
405 | for i, param in enumerate(self.shadow_params):
406 | name = f"{prefix}shadow_params.{i}"
407 | yield name, param
408 |
409 | def buffers(self, recurse=True):
410 | return iter([])
411 |
412 | def named_buffers(self, prefix="", recurse=True):
413 | return iter([])
414 |
415 | def train(self, mode=True):
416 | self.training = mode
417 | return self
418 |
419 | def eval(self):
420 | return self.train(False)
421 |
422 | def zero_grad(self):
423 | # No gradients to zero in EMA model
424 | pass
425 |
--------------------------------------------------------------------------------
/hv_control_lora_inference.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import os
4 | import time
5 | import uuid
6 | from datetime import datetime
7 | from typing import Any, Callable, Dict, List, Optional, Union
8 |
9 | import av
10 | import cv2
11 | import numpy as np
12 | import torch
13 | import torch.nn.functional as F
14 | import torchvision.transforms.v2 as transforms
15 | from diffusers import FlowMatchEulerDiscreteScheduler, HunyuanVideoPipeline
16 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
17 | from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
18 | from diffusers.models.attention import Attention
19 | from diffusers.models.embeddings import apply_rotary_emb
20 | from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoPatchEmbed
21 | from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE, retrieve_timesteps
22 | from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput
23 | from PIL import Image
24 |
25 | # Try to import flash attention
26 | try:
27 | import flash_attn
28 | from flash_attn.flash_attn_interface import _flash_attn_forward, flash_attn_varlen_func
29 | except ImportError:
30 | flash_attn, _flash_attn_forward, flash_attn_varlen_func = None, None, None
31 |
32 | try:
33 | from sageattention import sageattn, sageattn_varlen
34 | except ImportError:
35 | sageattn, sageattn_varlen = None, None
36 |
37 |
38 | def get_cu_seqlens(attention_mask):
39 | """Calculate cu_seqlens_q, cu_seqlens_kv using attention_mask"""
40 | batch_size = attention_mask.shape[0]
41 | text_len = attention_mask.sum(dim=-1, dtype=torch.int)
42 | max_len = attention_mask.shape[-1]
43 |
44 | cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
45 |
46 | for i in range(batch_size):
47 | s = text_len[i]
48 | s1 = i * max_len + s
49 | s2 = (i + 1) * max_len
50 | cu_seqlens[2 * i + 1] = s1
51 | cu_seqlens[2 * i + 2] = s2
52 |
53 | return cu_seqlens
54 |
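# Illustrative example for get_cu_seqlens (made-up sizes): with batch_size=1, a padded
# sequence of 4000 tokens, and an attention_mask marking 3800 of them as valid, it
# returns tensor([0, 3800, 4000], dtype=torch.int32) -- i.e. [0, valid_len, padded_len]
# per sample, the cumulative-length format expected by flash_attn_varlen_func.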
55 |
56 | class HunyuanVideoFlashAttnProcessor:
57 | def __init__(self, use_flash_attn=True, use_sageattn=False):
58 | self.use_flash_attn = use_flash_attn
59 | self.use_sageattn = use_sageattn
60 | if self.use_flash_attn:
61 | assert flash_attn is not None, "Flash attention not available"
62 | if self.use_sageattn:
63 | assert sageattn is not None, "Sage attention not available"
64 |
65 | def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
66 | if attn.add_q_proj is None and encoder_hidden_states is not None:
67 | hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
68 |
69 | query = attn.to_q(hidden_states)
70 | key = attn.to_k(hidden_states)
71 | value = attn.to_v(hidden_states)
72 |
73 | query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
74 | key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
75 | value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
76 |
77 | if attn.norm_q is not None:
78 | query = attn.norm_q(query)
79 | if attn.norm_k is not None:
80 | key = attn.norm_k(key)
81 |
82 | if image_rotary_emb is not None:
83 | if attn.add_q_proj is None and encoder_hidden_states is not None:
84 | query = torch.cat(
85 | [
86 | apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
87 | query[:, :, -encoder_hidden_states.shape[1] :],
88 | ],
89 | dim=2,
90 | )
91 | key = torch.cat(
92 | [
93 | apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
94 | key[:, :, -encoder_hidden_states.shape[1] :],
95 | ],
96 | dim=2,
97 | )
98 | else:
99 | query = apply_rotary_emb(query, image_rotary_emb)
100 | key = apply_rotary_emb(key, image_rotary_emb)
101 |
102 | batch_size = hidden_states.shape[0]
103 | img_seq_len = hidden_states.shape[1]
104 | txt_seq_len = 0
105 |
106 | if attn.add_q_proj is not None and encoder_hidden_states is not None:
107 | encoder_query = attn.add_q_proj(encoder_hidden_states)
108 | encoder_key = attn.add_k_proj(encoder_hidden_states)
109 | encoder_value = attn.add_v_proj(encoder_hidden_states)
110 |
111 | encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
112 | encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
113 | encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
114 |
115 | if attn.norm_added_q is not None:
116 | encoder_query = attn.norm_added_q(encoder_query)
117 | if attn.norm_added_k is not None:
118 | encoder_key = attn.norm_added_k(encoder_key)
119 |
120 | query = torch.cat([query, encoder_query], dim=2)
121 | key = torch.cat([key, encoder_key], dim=2)
122 | value = torch.cat([value, encoder_value], dim=2)
123 |
124 | txt_seq_len = encoder_hidden_states.shape[1]
125 |
126 | max_seqlen_q = max_seqlen_kv = img_seq_len + txt_seq_len
127 | cu_seqlens_q = cu_seqlens_kv = get_cu_seqlens(attention_mask)
128 |
129 | query = query.transpose(1, 2).reshape(-1, query.shape[1], query.shape[3])
130 | key = key.transpose(1, 2).reshape(-1, key.shape[1], key.shape[3])
131 | value = value.transpose(1, 2).reshape(-1, value.shape[1], value.shape[3])
132 |
133 | if self.use_flash_attn:
134 | hidden_states = flash_attn_varlen_func(
135 | query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
136 | )
137 | elif self.use_sageattn:
138 | hidden_states = sageattn_varlen(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
139 | else:
140 | raise NotImplementedError("Please set use_flash_attn=True or use_sageattn=True")
141 |
142 | hidden_states = hidden_states.reshape(batch_size, max_seqlen_q, -1)
143 | hidden_states = hidden_states.to(query.dtype)
144 |
145 | if encoder_hidden_states is not None:
146 | hidden_states, encoder_hidden_states = (
147 | hidden_states[:, : -encoder_hidden_states.shape[1]],
148 | hidden_states[:, -encoder_hidden_states.shape[1] :],
149 | )
150 |
151 | if getattr(attn, "to_out", None) is not None:
152 | hidden_states = attn.to_out[0](hidden_states)
153 | hidden_states = attn.to_out[1](hidden_states)
154 |
155 | if getattr(attn, "to_add_out", None) is not None:
156 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
157 |
158 | return hidden_states, encoder_hidden_states
159 |
160 |
161 | def save_video(video, output_path="output.mp4"):
162 | """Save frames as a video file"""
163 | width, height = video[0].size
164 |
165 | container = av.open(output_path, mode="w")
166 |
167 | # Create video stream
168 | codec = "libx264"
169 | pixel_format = "yuv420p"
170 | stream = container.add_stream(codec, rate=24)
171 | stream.width = width
172 | stream.height = height
173 | stream.pix_fmt = pixel_format
174 | stream.bit_rate = 4000000 # 4Mbit/s
175 |
176 | for frame_array in video:
177 | frame = av.VideoFrame.from_image(frame_array)
178 | packets = stream.encode(frame)
179 | for packet in packets:
180 | container.mux(packet)
181 |
182 | # Flush remaining packets
183 | for packet in stream.encode():
184 | container.mux(packet)
185 |
186 | container.close()
187 |
188 |
189 | def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
190 | """Resize the image to the bucket resolution."""
191 | is_pil_image = isinstance(image, Image.Image)
192 | if is_pil_image:
193 | image_width, image_height = image.size
194 | else:
195 | image_height, image_width = image.shape[:2]
196 |
197 | if bucket_reso == (image_width, image_height):
198 | return np.array(image) if is_pil_image else image
199 |
200 | bucket_width, bucket_height = bucket_reso
201 |
202 | scale_width = bucket_width / image_width
203 | scale_height = bucket_height / image_height
204 | scale = max(scale_width, scale_height)
205 | image_width = int(image_width * scale + 0.5)
206 | image_height = int(image_height * scale + 0.5)
207 |
208 | if scale > 1:
209 | image = Image.fromarray(image) if not is_pil_image else image
210 | image = image.resize((image_width, image_height), Image.LANCZOS)
211 | image = np.array(image)
212 | else:
213 | image = np.array(image) if is_pil_image else image
214 | image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
215 |
216 | # Crop the image to the bucket resolution
217 | crop_left = (image_width - bucket_width) // 2
218 | crop_top = (image_height - bucket_height) // 2
219 | image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
220 |
221 | return image
222 |
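# Illustrative example for resize_image_to_bucket: a 1024x768 image resized to the
# bucket (1280, 720) uses scale = max(1280 / 1024, 720 / 768) = 1.25, giving a
# 1280x960 intermediate that is then center-cropped to 1280x720 (crop_top = 120).
# Aspect ratio is preserved by scaling; the overflow is cropped, not letterboxed.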
223 |
224 | @torch.inference_mode()
225 | def call_pipe(
226 | pipe,
227 | prompt: Union[str, List[str]] = None,
228 | prompt_2: Union[str, List[str]] = None,
229 | height: int = 720,
230 | width: int = 1280,
231 | num_frames: int = 129,
232 | num_inference_steps: int = 50,
233 | sigmas: List[float] = None,
234 | guidance_scale: float = 6.0,
235 | num_videos_per_prompt: Optional[int] = 1,
236 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
237 | latents: Optional[torch.Tensor] = None,
238 | prompt_embeds: Optional[torch.Tensor] = None,
239 | pooled_prompt_embeds: Optional[torch.Tensor] = None,
240 | prompt_attention_mask: Optional[torch.Tensor] = None,
241 | output_type: Optional[str] = "pil",
242 | return_dict: bool = True,
243 | attention_kwargs: Optional[Dict[str, Any]] = None,
244 | callback_on_step_end: Optional[
245 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
246 | ] = None,
247 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
248 | prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
249 | max_sequence_length: int = 256,
250 | image_latents: Optional[torch.Tensor] = None,
251 | ):
252 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
253 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
254 |
255 | # 1. Check inputs
256 | pipe.check_inputs(
257 | prompt,
258 | prompt_2,
259 | height,
260 | width,
261 | prompt_embeds,
262 | callback_on_step_end_tensor_inputs,
263 | prompt_template,
264 | )
265 |
266 | pipe._guidance_scale = guidance_scale
267 | pipe._attention_kwargs = attention_kwargs
268 | pipe._current_timestep = None
269 | pipe._interrupt = False
270 |
271 | device = pipe._execution_device
272 |
273 | # 2. Define call parameters
274 | if prompt is not None and isinstance(prompt, str):
275 | batch_size = 1
276 | elif prompt is not None and isinstance(prompt, list):
277 | batch_size = len(prompt)
278 | else:
279 | batch_size = prompt_embeds.shape[0]
280 |
281 | # 3. Encode input prompt
282 | prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
283 | prompt=prompt,
284 | prompt_2=prompt_2,
285 | prompt_template=prompt_template,
286 | num_videos_per_prompt=num_videos_per_prompt,
287 | prompt_embeds=prompt_embeds,
288 | pooled_prompt_embeds=pooled_prompt_embeds,
289 | prompt_attention_mask=prompt_attention_mask,
290 | device=device,
291 | max_sequence_length=max_sequence_length,
292 | )
293 |
294 | transformer_dtype = pipe.transformer.dtype
295 | prompt_embeds = prompt_embeds.to(transformer_dtype)
296 | prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
297 | if pooled_prompt_embeds is not None:
298 | pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
299 |
300 | # 4. Prepare timesteps
301 | sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
302 | timesteps, num_inference_steps = retrieve_timesteps(
303 | pipe.scheduler,
304 | num_inference_steps,
305 | device,
306 | sigmas=sigmas,
307 | )
308 |
309 | # 5. Prepare latent variables
310 | num_channels_latents = pipe.transformer.config.in_channels
311 | num_latent_frames = (num_frames - 1) // pipe.vae_scale_factor_temporal + 1
312 | latents = pipe.prepare_latents(
313 | batch_size * num_videos_per_prompt,
314 | num_channels_latents,
315 | height,
316 | width,
317 | num_latent_frames,
318 | torch.float32,
319 | device,
320 | generator,
321 | latents,
322 | )
323 |
324 | # 6. Prepare guidance condition
325 | guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
326 |
327 | # 7. Denoising loop
328 | num_warmup_steps = len(timesteps) - num_inference_steps * pipe.scheduler.order
329 | pipe._num_timesteps = len(timesteps)
330 |
331 | with pipe.progress_bar(total=num_inference_steps) as progress_bar:
332 | for i, t in enumerate(timesteps):
333 | if pipe.interrupt:
334 | continue
335 |
336 | pipe._current_timestep = t
337 | latent_model_input = latents.to(transformer_dtype)
338 | timestep = t.expand(latents.shape[0]).to(latents.dtype)
339 |
340 | noise_pred = pipe.transformer(
341 | hidden_states=torch.cat([latent_model_input, image_latents], dim=1),
342 | timestep=timestep,
343 | encoder_hidden_states=prompt_embeds,
344 | encoder_attention_mask=prompt_attention_mask,
345 | pooled_projections=pooled_prompt_embeds,
346 | guidance=guidance,
347 | attention_kwargs=attention_kwargs,
348 | return_dict=False,
349 | )[0]
350 |
351 | # Compute the previous noisy sample x_t -> x_t-1
352 | latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
353 |
354 | if callback_on_step_end is not None:
355 | callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs}
356 | callback_outputs = callback_on_step_end(pipe, i, t, callback_kwargs)
357 | latents = callback_outputs.pop("latents", latents)
358 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
359 |
360 | # Update progress bar
361 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0):
362 | progress_bar.update()
363 |
364 | pipe._current_timestep = None
365 |
366 | if output_type != "latent":
367 | latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
368 | video = pipe.vae.decode(latents, return_dict=False)[0]
369 | video = pipe.video_processor.postprocess_video(video, output_type=output_type)
370 | else:
371 | video = latents
372 |
373 | # Offload all models
374 | pipe.maybe_free_model_hooks()
375 |
376 | return (video,) if not return_dict else HunyuanVideoPipelineOutput(frames=video)
377 |
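# Note: unlike the stock HunyuanVideo pipeline call, `call_pipe` concatenates
# `image_latents` with the noisy latents along the channel dimension, so it assumes
# the transformer's patch embedding was expanded to 2x in_channels (see
# `setup_pipeline`). A minimal invocation might look like (illustrative values):
#   video = call_pipe(pipe, prompt="...", num_frames=77, image_latents=cond_latents,
#                     height=720, width=1280, guidance_scale=6.0).frames[0]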
378 |
379 | def setup_pipeline(model_path, lora_path=None):
380 | """Set up the HunyuanVideo pipeline with optional LoRA weights"""
381 | pipe = HunyuanVideoPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
382 | pipe.to("cuda")
383 | pipe.vae.enable_slicing()
384 | pipe.vae.enable_tiling()
385 |
386 | # Apply flash attention to all transformer blocks
387 | for block in pipe.transformer.transformer_blocks + pipe.transformer.single_transformer_blocks:
388 | block.attn.processor = HunyuanVideoFlashAttnProcessor(use_flash_attn=True, use_sageattn=False)
389 |
390 | # Modify input channels for conditioning
391 | with torch.no_grad():
392 | initial_input_channels = pipe.transformer.config.in_channels
393 | new_img_in = HunyuanVideoPatchEmbed(
394 | patch_size=(
395 | pipe.transformer.config.patch_size_t,
396 | pipe.transformer.config.patch_size,
397 | pipe.transformer.config.patch_size,
398 | ),
399 | in_chans=pipe.transformer.config.in_channels * 2,
400 | embed_dim=pipe.transformer.config.num_attention_heads * pipe.transformer.config.attention_head_dim,
401 | )
402 | new_img_in = new_img_in.to(pipe.device, dtype=pipe.dtype)
403 | new_img_in.proj.weight.zero_()
404 | new_img_in.proj.weight[:, :initial_input_channels].copy_(pipe.transformer.x_embedder.proj.weight)
405 |
406 | if pipe.transformer.x_embedder.proj.bias is not None:
407 | new_img_in.proj.bias.copy_(pipe.transformer.x_embedder.proj.bias)
408 |
409 | pipe.transformer.x_embedder = new_img_in
410 | pipe.transformer.x_embedder.requires_grad_(False)
411 |
412 | # Load LoRA weights if provided
413 | if lora_path and os.path.exists(lora_path):
414 | lora_state_dict = pipe.lora_state_dict(lora_path)
415 |
416 | # Load transformer LoRA weights
417 | transformer_lora_state_dict = {
418 | k.replace("transformer.", ""): v
419 | for k, v in lora_state_dict.items()
420 | if k.startswith("transformer.") and "lora" in k
421 | }
422 | pipe.load_lora_into_transformer(
423 | transformer_lora_state_dict, transformer=pipe.transformer, adapter_name="i2v", _pipeline=pipe
424 | )
425 | pipe.set_adapters(["i2v"], adapter_weights=[1.0])
426 | pipe.fuse_lora(components=["transformer"], lora_scale=1.0, adapter_names=["i2v"])
427 | pipe.unload_lora_weights()
428 |
429 | # Load norm layers if present
430 | NORM_LAYER_PREFIXES = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
431 | transformer_norm_layers_state_dict = {
432 | k.replace("transformer.", ""): v
433 | for k, v in lora_state_dict.items()
434 | if k.startswith("transformer.") and any(norm_k in k for norm_k in NORM_LAYER_PREFIXES)
435 | }
436 |
437 | if transformer_norm_layers_state_dict:
438 | print("[INFO] Loading normalization layers from state dict...")
439 | transformer_state_dict = pipe.transformer.state_dict()
440 | transformer_keys = set(transformer_state_dict.keys())
441 | state_dict_keys = set(transformer_norm_layers_state_dict.keys())
442 | extra_keys = list(state_dict_keys - transformer_keys)
443 |
444 | if extra_keys:
445 | print(f"[WARNING] Ignoring unsupported keys: {extra_keys}")
446 | for key in extra_keys:
447 | transformer_norm_layers_state_dict.pop(key)
448 |
449 | pipe.transformer.load_state_dict(transformer_norm_layers_state_dict, strict=False)
450 | else:
451 | print("[INFO] No normalization layers found in state dict")
452 |
453 | return pipe
454 |
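# The extra input channels of the replacement patch embedding are zero-initialized, so
# right after patching the transformer ignores the conditioning latents and behaves
# like the base model; the control LoRA loaded above is what learns to use them.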
455 |
456 | def prepare_conditioning(pipe, frame1_path, frame2_path, n_frames, height, width):
457 | """Prepare conditioning frames for the model"""
458 | video_transforms = transforms.Compose(
459 | [
460 | transforms.Lambda(lambda x: x / 255.0),
461 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
462 | ]
463 | )
464 |
465 | # Load and resize conditioning frames
466 | cond_frame1 = Image.open(frame1_path).convert("RGB")
467 | cond_frame2 = Image.open(frame2_path).convert("RGB")
468 |
469 | cond_frame1 = resize_image_to_bucket(cond_frame1, bucket_reso=(width, height))
470 | cond_frame2 = resize_image_to_bucket(cond_frame2, bucket_reso=(width, height))
471 |
472 | # Create conditioning video tensor
473 | cond_video = np.zeros(shape=(n_frames, height, width, 3))
474 | cond_video[0], cond_video[-1] = np.array(cond_frame1), np.array(cond_frame2)
475 |
476 | cond_video = torch.from_numpy(cond_video.copy()).permute(0, 3, 1, 2)
477 | cond_video = torch.stack([video_transforms(x) for x in cond_video], dim=0).unsqueeze(0)
478 |
479 | # Encode to latent space
480 | with torch.inference_mode():
481 | image_or_video = cond_video.to(device="cuda", dtype=pipe.dtype)
482 | image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]
483 | cond_latents = pipe.vae.encode(image_or_video).latent_dist.sample()
484 | cond_latents = cond_latents * pipe.vae.config.scaling_factor
485 | cond_latents = cond_latents.to(dtype=pipe.dtype)
486 |
487 | return cond_latents
488 |
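# Shape sketch for prepare_conditioning (assuming the stock HunyuanVideo VAE with 16
# latent channels, 4x temporal and 8x spatial compression): for n_frames=77 at
# 720x1280 the returned cond_latents is roughly
#   [1, 16, (77 - 1) // 4 + 1 = 20, 720 // 8 = 90, 1280 // 8 = 160].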
489 |
490 | def main():
491 | parser = argparse.ArgumentParser(description="Run HunyuanVideo inference with control frames")
492 | parser.add_argument(
493 | "--model", type=str, default="hunyuanvideo-community/HunyuanVideo", help="Path to HunyuanVideo model"
494 | )
495 | parser.add_argument("--lora", type=str, required=True, help="Path to LoRA weights for image-to-video control")
496 | parser.add_argument("--frame1", type=str, required=True, help="Path to first control frame")
497 | parser.add_argument("--frame2", type=str, required=True, help="Path to second control frame")
498 | parser.add_argument("--prompt", type=str, required=True, help="Text prompt for video generation")
499 | parser.add_argument("--height", type=int, default=720, help="Output video height")
500 | parser.add_argument("--width", type=int, default=1280, help="Output video width")
501 | parser.add_argument("--frames", type=int, default=77, help="Number of frames to generate")
502 | parser.add_argument("--steps", type=int, default=50, help="Number of inference steps")
503 | parser.add_argument("--guidance", type=float, default=6.0, help="Guidance scale")
504 | parser.add_argument("--seed", type=int, default=None, help="Random seed for generation")
505 | parser.add_argument("--output", type=str, default=None, help="Output video path (default: auto-generated)")
506 |
507 | args = parser.parse_args()
508 |
509 | # Validate inputs
510 | for path in [args.frame1, args.frame2]:
511 | if not os.path.exists(path):
512 | raise FileNotFoundError(f"Input frame not found: {path}")
513 |
514 | if args.lora and not os.path.exists(args.lora):
515 | raise FileNotFoundError(f"LoRA weights not found: {args.lora}")
516 |
517 | # Set random seed
518 | seed = args.seed if args.seed is not None else int(time.time()) % 10000
519 | generator = torch.Generator(device="cuda").manual_seed(seed)
520 |
521 | # Setup pipeline
522 | print(f"Loading model from {args.model}")
523 | pipe = setup_pipeline(args.model, args.lora)
524 |
525 | # Prepare conditioning
526 | print("Preparing conditioning frames...")
527 | cond_latents = prepare_conditioning(pipe, args.frame1, args.frame2, args.frames, args.height, args.width)
528 |
529 | # Generate video
530 | print(f"Generating video with prompt: '{args.prompt}'")
531 | video = call_pipe(
532 | pipe,
533 | prompt=args.prompt,
534 | num_frames=args.frames,
535 | num_inference_steps=args.steps,
536 | image_latents=cond_latents,
537 | width=args.width,
538 | height=args.height,
539 | guidance_scale=args.guidance,
540 | generator=generator,
541 | ).frames[0]
542 |
543 | # Save output
544 | if args.output:
545 | output_path = args.output
546 | else:
547 | time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d_%H%M%S")
548 | output_path = f"hv-CL-{args.height}x{args.width}x{args.frames}-{time_flag}.mp4"
549 |
550 | print(f"Saving video to {output_path}")
551 | save_video(video, output_path)
552 | print("Done!")
553 |
554 |
555 | if __name__ == "__main__":
556 | main()
557 |
--------------------------------------------------------------------------------
/hv_train_control_lora.py:
--------------------------------------------------------------------------------
1 | from icecream import ic, install
2 |
3 | install()
4 | ic.configureOutput(includeContext=True)
5 |
6 | import ast
7 | import gc
8 | import importlib
9 | import json
10 | import logging
11 | import math
12 | import os
13 | import random
14 | import shutil
15 | import sys
16 | import time
17 | import typing
18 | import warnings
19 | from copy import deepcopy
20 | from dataclasses import asdict
21 | from datetime import datetime, timedelta
22 | from pathlib import Path
23 | from pprint import pformat
24 |
25 | import diffusers
26 | import numpy as np
27 | import pyrallis
28 | import torch
29 | import torch.optim.adamw
30 | import transformers
31 | from accelerate import Accelerator, DistributedDataParallelKwargs, DistributedType, InitProcessGroupKwargs
32 | from accelerate.logging import get_logger
33 | from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs, ProjectConfiguration, set_seed
34 | from diffusers import FlowMatchEulerDiscreteScheduler, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
35 | from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoPatchEmbed
36 | from diffusers.optimization import get_scheduler
37 | from diffusers.training_utils import (
38 | cast_training_params,
39 | compute_density_for_timestep_sampling,
40 | compute_loss_weighting_for_sd3,
41 | )
42 | from diffusers.utils.state_dict_utils import convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft
43 | from diffusers.utils.torch_utils import is_compiled_module
44 | from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
45 | from tqdm.auto import tqdm
46 |
47 | from attn_processor import HunyuanVideoFlashAttnProcessor # isort: skip
48 | from config import Config # isort:skip
49 | from mds_dataloaders import build_mds_dataloader # isort: skip
50 | from optim import get_optimizer, max_gradient # isort: skip
51 | from ema import EMAModel # isort: skip
52 |
53 | NORM_LAYER_PREFIXES = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
54 |
55 | logging.basicConfig(
56 | level=logging.INFO,
57 | format="[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d:%(funcName)s] %(message)s",
58 | force=True,
59 | handlers=[logging.StreamHandler()],
60 | )
61 | warnings.filterwarnings("ignore") # ignore warning
62 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
63 |
64 | if torch.cuda.is_available():
65 | torch.backends.cuda.matmul.allow_tf32 = True
66 | torch.backends.cudnn.allow_tf32 = True
67 |
68 | logger = get_logger(__name__)
69 |
70 |
71 | def bytes_to_gigabytes(x: int) -> float:
72 | if x is not None:
73 | return x / 1024**3
74 |
75 |
76 | def free_memory() -> None:
77 | if torch.cuda.is_available():
78 | gc.collect()
79 | torch.cuda.empty_cache()
80 | torch.cuda.ipc_collect()
81 |
82 |
83 | def get_memory_statistics(precision: int = 3) -> typing.Dict[str, typing.Any]:
84 | memory_allocated = None
85 | memory_reserved = None
86 | max_memory_allocated = None
87 | max_memory_reserved = None
88 |
89 | if torch.cuda.is_available():
90 | device = torch.cuda.current_device()
91 | memory_allocated = torch.cuda.memory_allocated(device)
92 | memory_reserved = torch.cuda.memory_reserved(device)
93 | max_memory_allocated = torch.cuda.max_memory_allocated(device)
94 | max_memory_reserved = torch.cuda.max_memory_reserved(device)
95 |
96 | elif torch.backends.mps.is_available():
97 | memory_allocated = torch.mps.current_allocated_memory()
98 |
99 | else:
100 | logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.")
101 |
102 | return {
103 |         "memory_allocated": round(bytes_to_gigabytes(memory_allocated), ndigits=precision) if memory_allocated is not None else None,
104 |         "memory_reserved": round(bytes_to_gigabytes(memory_reserved), ndigits=precision) if memory_reserved is not None else None,
105 |         "max_memory_allocated": round(bytes_to_gigabytes(max_memory_allocated), ndigits=precision) if max_memory_allocated is not None else None,
106 |         "max_memory_reserved": round(bytes_to_gigabytes(max_memory_reserved), ndigits=precision) if max_memory_reserved is not None else None,
107 | }
108 |
109 |
110 | def get_nb_trainable_parameters(mod: torch.nn.Module):
111 | r"""
112 | Returns the number of trainable parameters and number of all parameters in the model.
113 | """
114 | # note: same as PeftModel.get_nb_trainable_parameters
115 | trainable_params = 0
116 | all_param = 0
117 | for _, param in mod.named_parameters():
118 | num_params = param.numel()
119 | # if using DS Zero 3 and the weights are initialized empty
120 | if num_params == 0 and hasattr(param, "ds_numel"):
121 | num_params = param.ds_numel
122 |
123 | # Due to the design of 4bit linear layers from bitsandbytes
124 | # one needs to multiply the number of parameters by 2 to get
125 | # the correct number of parameters
126 | if param.__class__.__name__ == "Params4bit":
127 | num_params = num_params * 2
128 |
129 | all_param += num_params
130 | if param.requires_grad:
131 | trainable_params += num_params
132 |
133 | return trainable_params, all_param
134 |
135 |
136 | def get_noisy_model_input_and_timesteps(
137 | cfg: Config, latents, noise, noise_scheduler, device, weight_dtype, scheduler_sigmas, generator=None
138 | ):
139 | batch_size = latents.size(0)
140 | if cfg.hparams.flow_match.timestep_sampling == "uniform":
141 | sigmas = torch.rand((batch_size,), device=device, generator=generator)
142 | elif cfg.hparams.flow_match.timestep_sampling == "sigmoid":
143 | # imported from cloneofsimo's minRF trainer: https://github.com/cloneofsimo/minRF
144 | # also used by: https://github.com/XLabs-AI/x-flux/tree/main
145 | # and: https://github.com/kohya-ss/sd-scripts/commit/8a0f12dde812994ec3facdcdb7c08b362dbceb0f
146 | sigmas = torch.sigmoid(
147 | cfg.hparams.flow_match.sigmoid_scale * torch.randn((batch_size,), device=device, generator=generator)
148 | )
149 | elif cfg.hparams.flow_match.timestep_sampling == "logit_normal":
150 | sigmas = torch.normal(
151 | cfg.hparams.flow_match.logit_mean,
152 | cfg.hparams.flow_match.logit_std,
153 | size=(batch_size,),
154 | device=device,
155 | generator=generator,
156 | )
157 | sigmas = torch.sigmoid(cfg.hparams.flow_match.sigmoid_scale * sigmas)
158 |
159 | if cfg.hparams.flow_match.discrete_flow_shift is not None and cfg.hparams.flow_match.discrete_flow_shift > 0:
160 | sigmas = (sigmas * cfg.hparams.flow_match.discrete_flow_shift) / (
161 | 1 + (cfg.hparams.flow_match.discrete_flow_shift - 1) * sigmas
162 | )
163 |
164 | timesteps, sigmas = (sigmas * 1000.0).long(), sigmas.view(-1, 1, 1, 1, 1)
165 | noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
166 | return noisy_model_input, timesteps
167 |
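# Worked example (illustrative numbers): with timestep_sampling="sigmoid", a raw
# sigma of 0.5 and discrete_flow_shift=7.0 is shifted to
#   sigma' = (0.5 * 7.0) / (1 + (7.0 - 1) * 0.5) = 0.875,
# pushing training towards the high-noise end. The noisy input is then the straight-line
# flow-matching interpolation x_t = (1 - sigma) * latents + sigma * noise.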
168 |
169 | @pyrallis.wrap()
170 | def main(cfg: Config):
171 | if cfg.experiment.ic_debug:
172 | ic.enable()
173 | else:
174 | ic.disable()
175 |
176 | # ======================================================
177 | # 1. configs & runtime variables
178 | # ======================================================
179 | output_dirpath = Path(cfg.experiment.output_dirpath) / cfg.experiment.run_id
180 | logging_dirpath = output_dirpath / "logs"
181 |
182 | accelerator_project_config = ProjectConfiguration(project_dir=output_dirpath, logging_dir=logging_dirpath)
183 | ddp_kwargs = DistributedDataParallelKwargs(
184 | find_unused_parameters=cfg.ddp_kwargs.find_unused_parameters,
185 | gradient_as_bucket_view=cfg.ddp_kwargs.gradient_as_bucket_view,
186 | static_graph=cfg.ddp_kwargs.static_graph,
187 | )
188 | init_kwargs = InitProcessGroupKwargs(backend=cfg.ddp_kwargs.backend, timeout=timedelta(seconds=5400))
189 |
190 | accelerator = Accelerator(
191 | gradient_accumulation_steps=cfg.hparams.gradient_accumulation_steps,
192 | mixed_precision=cfg.hparams.mixed_precision,
193 | log_with=None,
194 | project_config=accelerator_project_config,
195 | kwargs_handlers=[ddp_kwargs, init_kwargs],
196 | )
197 |
198 | print(accelerator.state)
199 |
200 | accelerator.print("\nENVIRONMENT\n")
201 | accelerator.print(f" Python .......................... {sys.version}")
202 | accelerator.print(f" torch.__version__ ............... {torch.__version__}")
203 | accelerator.print(f" torch.version.cuda .............. {torch.version.cuda}")
204 | accelerator.print(f" torch.backends.cudnn.version() .. {torch.backends.cudnn.version()}\n")
205 | accelerator.print("\n")
206 | accelerator.print(f">> Run ID : {cfg.experiment.run_id!r}")
207 |
208 | if accelerator.is_main_process:
209 | transformers.utils.logging.set_verbosity_warning()
210 | diffusers.utils.logging.set_verbosity_info()
211 | else:
212 | transformers.utils.logging.set_verbosity_error()
213 | diffusers.utils.logging.set_verbosity_error()
214 |
215 | if cfg.experiment.random_seed is not None:
216 | set_seed(cfg.experiment.random_seed)
217 |
218 | if accelerator.num_processes > 1:
219 | logger.info("DDP VARS: ")
220 | logger.info(f" WORLD_SIZE: {os.getenv('WORLD_SIZE', 'N/A')}")
221 | logger.info(f" LOCAL_WORLD_SIZE: {os.getenv('LOCAL_WORLD_SIZE', 'N/A')}")
222 | logger.info(f" RANK: {os.getenv('RANK', 'N/A')}")
223 | logger.info(f" MASTER_ADDR: {os.getenv('MASTER_ADDR', 'N/A')}")
224 | logger.info(f" MASTER_PORT: {os.getenv('MASTER_PORT', 'N/A')}")
225 |
226 | if accelerator.is_main_process:
227 | output_dirpath.mkdir(parents=True, exist_ok=True)
228 | if not accelerator.is_main_process:
229 | ic.disable()
230 |
231 | if accelerator.is_main_process:
232 | logger.info(f"Saving config to {output_dirpath / 'config.yaml'}")
233 | yaml_cfg = pyrallis.dump(cfg)
234 | with open(output_dirpath / "config.yaml", "w") as f:
235 | f.write(yaml_cfg)
236 |
237 | logger.info(f"config = \n{pyrallis.dump(cfg)}")
238 |
239 | # ======================================================
240 | # 2. build model
241 | # ======================================================
242 |
243 | noise_scheduler = FlowMatchEulerDiscreteScheduler(shift=cfg.hparams.flow_match.discrete_flow_shift)
244 |
245 | load_dtype = torch.bfloat16
246 | logger.info(f"Load transformer model from {cfg.model.pretrained_model_name_or_path!r}")
247 | t0 = time.time()
248 | transformer = HunyuanVideoTransformer3DModel.from_pretrained(
249 | cfg.model.pretrained_model_name_or_path,
250 | subfolder="transformer",
251 | torch_dtype=load_dtype,
252 | revision=cfg.model.revision,
253 | variant=cfg.model.variant,
254 | )
255 | logger.info(f"Loaded transformer model from {cfg.model.pretrained_model_name_or_path!r} in {time.time() - t0:.2f}s")
256 |
257 | for blk in transformer.transformer_blocks + transformer.single_transformer_blocks:
258 | blk.attn.processor = HunyuanVideoFlashAttnProcessor()
259 | logger.info("transformer patch with flash_attn done ok!")
260 |
261 | with torch.no_grad():
262 | logger.info("expand transformer x_embedder input channels")
263 | t0 = time.time()
264 | initial_input_channels = transformer.config.in_channels
265 | new_img_in = HunyuanVideoPatchEmbed(
266 | patch_size=(transformer.config.patch_size_t, transformer.config.patch_size, transformer.config.patch_size),
267 | in_chans=transformer.config.in_channels * 2,
268 | embed_dim=transformer.config.num_attention_heads * transformer.config.attention_head_dim,
269 | )
270 | new_img_in.proj.weight.zero_()
271 | new_img_in.proj.weight[:, :initial_input_channels].copy_(transformer.x_embedder.proj.weight)
272 | if transformer.x_embedder.proj.bias is not None:
273 | new_img_in.proj.bias.copy_(transformer.x_embedder.proj.bias)
274 | transformer.x_embedder = new_img_in
275 | assert torch.all(transformer.x_embedder.proj.weight[:, initial_input_channels:] == 0)
276 | transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
277 | logger.info(f"expanded transformer x_embedder input channels in {time.time() - t0:.2f}s")
278 | accelerator.wait_for_everyone()
279 |
280 | weight_dtype = torch.float32
281 | if accelerator.state.deepspeed_plugin:
282 | # DeepSpeed is handling precision, use what's in the DeepSpeed config
283 | if (
284 | "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
285 | and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
286 | ):
287 | weight_dtype = torch.float16
288 | if (
289 | "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
290 | and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
291 | ):
292 | weight_dtype = torch.bfloat16
293 | else:
294 | if accelerator.mixed_precision == "fp16":
295 | weight_dtype = torch.float16
296 | elif accelerator.mixed_precision == "bf16":
297 | weight_dtype = torch.bfloat16
298 | logger.info(f"configured weight dtype: {weight_dtype!r}")
299 |
300 | if cfg.hparams.gradient_checkpointing:
301 | transformer.enable_gradient_checkpointing()
302 |
303 | transformer.requires_grad_(False)
304 |
305 | transformer.to(accelerator.device, dtype=weight_dtype)
306 | logger.info(f"dit dtype: {next(transformer.parameters()).dtype!r}")
307 |
308 | if cfg.network.lora_layers is not None:
309 | if cfg.network.lora_layers != "all-linear":
310 | target_modules = [layer.strip() for layer in cfg.network.lora_layers.split(",")]
311 | if "x_embedder" not in target_modules:
312 | target_modules.append("x_embedder.proj")
313 | elif cfg.network.lora_layers == "all-linear":
314 | target_modules = set()
315 | for name, module in transformer.named_modules():
316 | if isinstance(module, torch.nn.Linear):
317 | target_modules.add(name)
318 | target_modules = list(target_modules)
319 | if "x_embedder" not in target_modules:
320 | target_modules.append("x_embedder.proj")
321 | target_modules = [t for t in target_modules if "norm" not in t]
322 | else:
323 | assert cfg.network.target_modules is not None, "either `lora_layers` or `target_modules` must be specified"
324 | target_modules = cfg.network.target_modules
325 |
326 |     logger.info("using LoRA training mode:")
327 | logger.info(f"rank .......................................... {cfg.network.lora_rank!r}")
328 | logger.info(f"alpha ......................................... {cfg.network.lora_alpha!r}")
329 | logger.info(f"target_modules ................................ {json.dumps(target_modules, indent=4)}")
330 |
331 | transformer_lora_config = LoraConfig(
332 | r=cfg.network.lora_rank,
333 | lora_alpha=cfg.network.lora_alpha,
334 | lora_dropout=cfg.network.lora_dropout,
335 | target_modules=target_modules,
336 | init_lora_weights=cfg.network.init_lora_weights,
337 | )
338 | transformer.add_adapter(transformer_lora_config)
339 | accelerator.wait_for_everyone()
340 |
341 | trainable_params, all_param = get_nb_trainable_parameters(transformer)
342 | logger.info(
343 | f"trainable params: {trainable_params:,d} || "
344 | f"all params: {all_param:,d} || "
345 | f"trainable%: {100 * trainable_params / all_param:.4f}"
346 | )
347 |
348 | if cfg.network.train_norm_layers:
349 | train_norm_layers = []
350 | logger.info(f"train norm layers, setting requires_grad to True for layers matching {NORM_LAYER_PREFIXES!r}")
351 | for name, param in transformer.named_parameters():
352 | if any(k in name for k in NORM_LAYER_PREFIXES):
353 | param.requires_grad_(True)
354 | train_norm_layers.append(name)
355 | logger.info(f"train norm layers ............................. {json.dumps(train_norm_layers, indent=4)}")
356 |
357 | if cfg.hparams.mixed_precision == "fp16":
358 | logger.warning("full fp16 training is unstable, casting params to fp32")
359 | cast_training_params([transformer])
360 |
361 | if cfg.hparams.gradient_checkpointing:
362 | transformer.enable_gradient_checkpointing()
363 |
364 | transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
365 |
366 | ema_model = None
367 | if accelerator.is_main_process and cfg.hparams.ema.use_ema:
368 | logger.info("Using EMA. Creating EMAModel.")
369 | ema_model_cls, ema_model_config = transformer.__class__, transformer.config
370 | ema_model = EMAModel(
371 | cfg.hparams.ema,
372 | accelerator,
373 | parameters=transformer_lora_parameters,
374 | model_cls=ema_model_cls,
375 | model_config=ema_model_config,
376 | decay=cfg.hparams.ema.ema_decay,
377 | foreach=not cfg.hparams.ema.ema_foreach_disable,
378 | )
379 | logger.info(f"EMA model creation completed with {ema_model.parameter_count():,} parameters")
380 |
381 | accelerator.wait_for_everyone()
382 |
383 | use_deepspeed_optimizer = (
384 | accelerator.state.deepspeed_plugin is not None
385 | and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
386 | )
387 | use_deepspeed_scheduler = (
388 | accelerator.state.deepspeed_plugin is not None
389 | and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
390 | )
391 |
392 | optimizer = get_optimizer(
393 | transformer_lora_parameters,
394 | optimizer_name=cfg.hparams.optimizer_type,
395 | learning_rate=cfg.hparams.learning_rate,
396 | optimizer_args_str=cfg.hparams.optimizer_args,
397 | use_deepspeed=use_deepspeed_optimizer,
398 | )
399 |
400 | # ======================================================
401 | # 3. build dataset and dataloaders
402 | # ======================================================
403 |
404 | train_dataloader = build_mds_dataloader(
405 | remote=cfg.data.remote,
406 | local=cfg.data.local,
407 | batch_size=cfg.data.batch_size,
408 | video_key=cfg.data.video_key,
409 | caption_key=cfg.data.caption_key,
410 | latents_key=cfg.data.latents_key,
411 | latents_cond_key=cfg.data.latents_cond_key,
412 | prompt_embeds_key=cfg.data.prompt_embeds_key,
413 | prompt_attention_mask_key=cfg.data.prompt_attention_mask_key,
414 | pooled_prompt_embeds_key=cfg.data.pooled_prompt_embeds_key,
415 | streaming_kwargs=asdict(cfg.data.streaming_kwargs),
416 | dataloader_kwargs=asdict(cfg.data.dataloader_kwargs),
417 | latent_dtype=weight_dtype,
418 | )
419 |
420 | # =======================================================
421 | # 4. distributed training preparation with accelerator
422 | # =======================================================
423 |
424 | def unwrap_model(model):
425 | model = accelerator.unwrap_model(model)
426 | model = model._orig_mod if is_compiled_module(model) else model
427 | return model
428 |
429 | def save_model_hook(models, weights, output_dir):
430 | if not accelerator.is_main_process:
431 | return
432 |
433 | if cfg.hparams.ema.use_ema and ema_model is not None:
434 | primary_model = unwrap_model(transformer)
435 | ema_model_path = os.path.join(output_dir, "ema_model.pt")
436 | logger.info(f"Saving EMA model state to {ema_model_path!r}")
437 | try:
438 | ema_model.save_state_dict(ema_model_path)
439 | except Exception as e:
440 | logger.error(f"Error saving EMA model: {e!r}")
441 |
442 | # we'll temporarily overwrite the LoRA parameters with the EMA parameters to save it.
443 | logger.info("Saving EMA model to disk.")
444 | trainable_parameters = [p for p in primary_model.parameters() if p.requires_grad]
445 | ema_model.store(trainable_parameters)
446 | ema_model.copy_to(trainable_parameters)
447 | transformer_lora_layers = get_peft_model_state_dict(primary_model)
448 | HunyuanVideoPipeline.save_lora_weights(
449 | os.path.join(output_dir, "ema"),
450 | transformer_lora_layers=transformer_lora_layers,
451 | weight_name=f"{cfg.experiment.name}.sft",
452 | )
453 | ema_model.restore(trainable_parameters)
454 |
455 | transformer_lora_layers_to_save = None
456 | for model in models:
457 | if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
458 | model = unwrap_model(model)
459 | transformer_lora_layers_to_save = get_peft_model_state_dict(model)
460 |
461 | if cfg.network.train_norm_layers:
462 | transformer_norm_layers_to_save = {
463 | f"transformer.{name}": param
464 | for name, param in model.named_parameters()
465 | if any(k in name for k in NORM_LAYER_PREFIXES)
466 | }
467 | transformer_lora_layers_to_save = {
468 | **transformer_lora_layers_to_save,
469 | **transformer_norm_layers_to_save,
470 | }
471 | else:
472 | raise ValueError(f"unexpected save model: {model.__class__}")
473 |
474 | if weights:
475 | weights.pop()
476 |
477 | HunyuanVideoPipeline.save_lora_weights(
478 | output_dir,
479 | transformer_lora_layers=transformer_lora_layers_to_save,
480 | weight_name=f"{cfg.experiment.name}.sft",
481 | )
482 |
483 | if hasattr(train_dataloader, "state_dict"):
484 | torch.save(train_dataloader.state_dict(), os.path.join(output_dir, "train_dataloader_state.pt"))
485 |
486 | def load_model_hook(models, input_dir):
487 | if hasattr(train_dataloader, "load_state_dict"):
488 | logger.info(f"Loading train dataloader state from Path: {input_dir!r}")
489 | train_dataloader.load_state_dict(torch.load(os.path.join(input_dir, "train_dataloader_state.pt")))
490 |
491 | if cfg.hparams.ema.use_ema and ema_model is not None:
492 | logger.info(f"Loading EMA model from Path: {input_dir!r}")
493 | try:
494 | ema_model.load_state_dict(os.path.join(input_dir, "ema_model.pt"))
495 | except Exception as e:
496 | logger.error(f"Could not load EMA model: {e!r}")
497 |
498 | transformer_ = None
499 | if not accelerator.distributed_type == DistributedType.DEEPSPEED:
500 | while len(models) > 0:
501 | model = models.pop()
502 |
503 | if isinstance(model, type(unwrap_model(transformer))):
504 | transformer_ = model
505 | else:
506 | raise ValueError(f"unexpected save model: {model.__class__}")
507 | else:
508 | transformer_ = HunyuanVideoTransformer3DModel.from_pretrained(
509 | cfg.model.pretrained_model_name_or_path, subfolder="transformer"
510 | )
511 | transformer_.to(accelerator.device, weight_dtype)
512 |
513 |             with torch.no_grad():
514 |                 initial_input_channels = transformer_.config.in_channels
515 |                 new_img_in = HunyuanVideoPatchEmbed(
516 |                     patch_size=(
517 |                         transformer_.config.patch_size_t,
518 |                         transformer_.config.patch_size,
519 |                         transformer_.config.patch_size,
520 |                     ),
521 |                     in_chans=transformer_.config.in_channels * 2,
522 |                     embed_dim=transformer_.config.num_attention_heads
523 |                     * transformer_.config.attention_head_dim,
524 |                 )
525 |                 new_img_in.proj.weight.zero_()
526 |                 new_img_in.proj.weight[:, :initial_input_channels].copy_(transformer_.x_embedder.proj.weight)
527 |                 if transformer_.x_embedder.proj.bias is not None:
528 |                     new_img_in.proj.bias.copy_(transformer_.x_embedder.proj.bias)
529 |                 transformer_.x_embedder = new_img_in
530 |                 transformer_.register_to_config(
531 |                     in_channels=initial_input_channels * 2, out_channels=initial_input_channels
532 |                 )
533 |
534 | transformer_.add_adapter(transformer_lora_config)
535 |
536 | lora_weight_name = os.path.join(input_dir, f"{cfg.experiment.name}.sft")
537 | logger.info(f"Loading LoRA weights from Path: {lora_weight_name!r}")
538 | lora_state_dict = HunyuanVideoPipeline.lora_state_dict(lora_weight_name)
539 | transformer_lora_state_dict = {
540 | f'{k.replace("transformer.", "")}': v
541 | for k, v in lora_state_dict.items()
542 | if k.startswith("transformer.") and "lora" in k
543 | }
544 | # transformer_lora_state_dict = convert_unet_state_dict_to_peft(transformer_lora_state_dict)
545 | incompatible_keys = set_peft_model_state_dict(transformer_, transformer_lora_state_dict, adapter_name="default")
546 |
547 | if incompatible_keys is not None:
548 | # check only for unexpected keys
549 | unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
550 | if unexpected_keys:
551 | logger.warning(
552 | f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
553 | f" {unexpected_keys}. "
554 | )
555 |
556 | if cfg.network.train_norm_layers:
557 | transformer_norm_layers_state_dict = {
558 | f'{k.replace("transformer.", "")}': v
559 | for k, v in lora_state_dict.items()
560 | if k.startswith("transformer.") and any(norm_k in k for norm_k in NORM_LAYER_PREFIXES)
561 | }
562 | for key in list(transformer_norm_layers_state_dict.keys()):
563 | if key.split(".")[0] == "transformer":
564 | transformer_norm_layers_state_dict[
565 | key[len(f"transformer.") :]
566 | ] = transformer_norm_layers_state_dict.pop(key)
567 |
568 | transformer_state_dict = transformer.state_dict()
569 | transformer_keys = set(transformer_state_dict.keys())
570 | state_dict_keys = set(transformer_norm_layers_state_dict.keys())
571 | extra_keys = list(state_dict_keys - transformer_keys)
572 |
573 | if extra_keys:
574 | logger.warning(
575 | f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}."
576 | )
577 |
578 | for key in extra_keys:
579 | transformer_norm_layers_state_dict.pop(key)
580 |
581 | # We can't load with strict=True because the current state_dict does not contain all the transformer keys
582 | incompatible_keys = transformer.load_state_dict(transformer_norm_layers_state_dict, strict=False)
583 | unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
584 |
585 | # We shouldn't expect to see the supported norm keys here being present in the unexpected keys.
586 | if unexpected_keys:
587 | if any(norm_key in k for k in unexpected_keys for norm_key in NORM_LAYER_PREFIXES):
588 | raise ValueError(
589 | f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer."
590 | )
591 |
592 | if cfg.hparams.mixed_precision == "fp16":
593 | cast_training_params([transformer_])
594 |
595 | logger.info(f"Completed loading checkpoint from Path: {input_dir!r}")
596 |
597 | accelerator.register_save_state_pre_hook(save_model_hook)
598 | accelerator.register_load_state_pre_hook(load_model_hook)
599 |
600 | # The epoch_size attribute of StreamingDataset is the number of samples per epoch of training.
601 | # The __len__() method returns the epoch_size divided by the number of devices – it is the number of samples seen per device, per epoch.
602 | # The size() method returns the number of unique samples in the underlying dataset.
603 | # Due to upsampling/downsampling, size() may not be the same as epoch_size.
604 | if cfg.hparams.max_train_steps is None:
605 | len_train_dataloader_after_sharding = len(train_dataloader)
606 | num_update_steps_per_epoch = math.ceil(
607 | len_train_dataloader_after_sharding / cfg.hparams.gradient_accumulation_steps
608 | )
609 | num_training_steps_for_scheduler = (
610 | cfg.hparams.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
611 | )
612 | else:
613 | num_training_steps_for_scheduler = cfg.hparams.max_train_steps * accelerator.num_processes
614 |
615 | if use_deepspeed_scheduler:
616 | from accelerate.utils import DummyScheduler
617 |
618 | lr_scheduler = DummyScheduler(
619 | name=cfg.hparams.lr_scheduler,
620 | optimizer=optimizer,
621 | total_num_steps=num_training_steps_for_scheduler,
622 | num_warmup_steps=cfg.hparams.lr_warmup_steps * accelerator.num_processes,
623 | )
624 | else:
625 | lr_scheduler = get_scheduler(
626 | name=cfg.hparams.lr_scheduler,
627 | optimizer=optimizer,
628 | num_training_steps=num_training_steps_for_scheduler,
629 | num_warmup_steps=cfg.hparams.lr_warmup_steps * accelerator.num_processes,
630 | num_cycles=cfg.hparams.lr_scheduler_num_cycles,
631 | power=cfg.hparams.lr_scheduler_power,
632 | )
633 |
634 |     # no need to wrap the dataloader with accelerator.prepare because mosaicml-streaming handles sharding internally
635 | # the config should be passed via deepspeed json
636 | if accelerator.state.deepspeed_plugin is not None:
637 | d = transformer.config.num_attention_heads * transformer.config.attention_head_dim
638 | accelerator.state.deepspeed_plugin.deepspeed_config["zero_optimization"]["reduce_bucket_size"] = d
639 | accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = cfg.data.batch_size
640 | accelerator.state.deepspeed_plugin.deepspeed_config[
641 | "gradient_accumulation_steps"
642 | ] = cfg.hparams.gradient_accumulation_steps
643 |
644 | # Prepare everything with our `accelerator`.
645 |     # passing the dataloader to prepare would normally resolve DeepSpeed 'auto' params, but we set them explicitly above instead
646 | transformer, optimizer, lr_scheduler = accelerator.prepare(transformer, optimizer, lr_scheduler)
647 |
648 | if cfg.hparams.ema.use_ema and ema_model is not None:
649 | if cfg.hparams.ema.ema_device == "accelerator":
650 | logger.info("Moving EMA model weights to accelerator...")
651 |
652 | ema_model.to((accelerator.device if cfg.hparams.ema.ema_device == "accelerator" else "cpu"), dtype=weight_dtype)
653 |
654 | if cfg.hparams.ema.ema_device == "cpu" and not cfg.hparams.ema.ema_cpu_only:
655 | logger.info("Pinning EMA model weights to CPU...")
656 | try:
657 | ema_model.pin_memory()
658 | except Exception as e:
659 | logger.error(f"Failed to pin EMA model to CPU: {e}")
660 |
661 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.hparams.gradient_accumulation_steps)
662 | if cfg.hparams.max_train_steps is None:
663 | cfg.hparams.max_train_steps = cfg.hparams.num_train_epochs * num_update_steps_per_epoch
664 | if num_training_steps_for_scheduler != cfg.hparams.max_train_steps * accelerator.num_processes:
665 | logger.warning(
666 | f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
667 | f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
668 | f"This inconsistency may result in the learning rate scheduler not functioning properly."
669 | )
670 |
671 | # Afterwards we recalculate our number of training epochs
672 | cfg.hparams.num_train_epochs = math.ceil(cfg.hparams.max_train_steps / num_update_steps_per_epoch)
673 |
674 | total_batch_size = cfg.data.batch_size * accelerator.num_processes * cfg.hparams.gradient_accumulation_steps
675 | num_trainable_parameters = sum(p.numel() for p in transformer_lora_parameters)
676 |
677 | # fmt: off
678 | logger.info("***** Running training *****")
679 | logger.info(f" Num trainable parameters ........................................... {num_trainable_parameters}")
680 | logger.info(f" Num examples ....................................................... {train_dataloader.dataset.size}")
681 | logger.info(f" Num batches each epoch ............................................. {len(train_dataloader)}")
682 | logger.info(f" Num epochs ......................................................... {cfg.hparams.num_train_epochs}")
683 | logger.info(f" Instantaneous batch size per device ................................ {cfg.data.batch_size}")
684 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) ... {total_batch_size}")
685 | logger.info(f" Gradient accumulation steps ........................................ {cfg.hparams.gradient_accumulation_steps}")
686 | logger.info(f" Total optimization steps ........................................... {cfg.hparams.max_train_steps}")
687 | # fmt: on
688 |
689 | global_step, first_epoch = 0, 0
690 |
691 | if not cfg.checkpointing.resume_from_checkpoint:
692 | initial_global_step = 0
693 | else:
694 | if cfg.checkpointing.resume_from_checkpoint != "latest":
695 | path = cfg.checkpointing.resume_from_checkpoint
696 | else:
697 | # Get the most recent checkpoint; directories are named "checkpoint-step{:08d}"
698 | dirs = os.listdir(cfg.experiment.output_dir)
699 | dirs = [d for d in dirs if d.startswith("checkpoint")]
700 | dirs = sorted(dirs, key=lambda x: int(x.split("checkpoint-step")[1]))
701 | path = dirs[-1] if len(dirs) > 0 else None
702 | path = os.path.join(cfg.experiment.output_dir, path) if path is not None else None
703 |
704 | if path is None:
705 | accelerator.print(
706 | f"Checkpoint {cfg.checkpointing.resume_from_checkpoint!r} does not exist. Starting a new training run."
707 | )
708 | cfg.checkpointing.resume_from_checkpoint = None
709 | initial_global_step = 0
710 | else:
711 | accelerator.print(f"Resuming from checkpoint {path!r}")
712 | accelerator.load_state(path)
713 |
714 | global_step = int(path.split("checkpoint-step")[1])
715 | initial_global_step = global_step
716 | first_epoch = global_step // num_update_steps_per_epoch
717 |
718 | logger.info(f"Override: global_step={initial_global_step} | first_epoch={first_epoch}")
719 |
720 | memory_statistics = get_memory_statistics()
721 | logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
722 |
723 | # =======================================================
724 | # 5. training loop
725 | # =======================================================
726 | accelerator.wait_for_everyone()
727 | progress_bar = tqdm(
728 | range(0, cfg.hparams.max_train_steps),
729 | initial=initial_global_step,
730 | desc="Steps",
731 | dynamic_ncols=True,
732 | disable=not accelerator.is_main_process,
733 | )
734 |
735 | model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
736 | generator = torch.Generator(device=accelerator.device)
737 | scheduler_sigmas = noise_scheduler.sigmas.clone().to(device=accelerator.device, dtype=weight_dtype)
738 | if cfg.experiment.random_seed is not None:
739 | generator = generator.manual_seed(cfg.experiment.random_seed)
740 |
741 | for epoch in range(first_epoch, cfg.hparams.num_train_epochs):
742 | logger.info(f"epoch {epoch+1}/{ cfg.hparams.num_train_epochs}")
743 | transformer.train()
744 |
745 | for step, batch in enumerate(train_dataloader):
746 | models_to_accumulate = [transformer]
747 |
748 | with accelerator.accumulate(models_to_accumulate):
749 | latents, cond_latents = batch[cfg.data.latents_key], batch[cfg.data.latents_cond_key]
750 |
751 | latents = latents.to(accelerator.device, dtype=weight_dtype).contiguous()
752 | cond_latents = cond_latents.to(accelerator.device, dtype=weight_dtype).contiguous()
753 |
754 | batch_size = latents.size(0)
755 |
756 | prompt_embeds = batch[cfg.data.prompt_embeds_key].to(accelerator.device, dtype=weight_dtype)
757 | pooled_prompt_embeds = batch[cfg.data.pooled_prompt_embeds_key].to(
758 | accelerator.device, dtype=weight_dtype
759 | )
760 | prompt_attention_mask = batch[cfg.data.prompt_attention_mask_key].to(
761 | accelerator.device, dtype=torch.bool
762 | )
763 |
764 | if random.random() < cfg.hparams.caption_dropout_p:
765 | prompt_embeds.fill_(0)
766 | pooled_prompt_embeds.fill_(0)
767 | prompt_attention_mask.fill_(False)
768 |
769 | noise = torch.randn(latents.shape, device=accelerator.device, dtype=weight_dtype, generator=generator)
770 |
771 | noisy_model_input, timesteps = get_noisy_model_input_and_timesteps(
772 | cfg=cfg,
773 | latents=latents,
774 | noise=noise,
775 | noise_scheduler=noise_scheduler,
776 | device=accelerator.device,
777 | weight_dtype=weight_dtype,
778 | generator=generator,
779 | scheduler_sigmas=scheduler_sigmas,
780 | )
781 | noisy_model_input = noisy_model_input.to(weight_dtype)
782 |
783 | weighting = compute_loss_weighting_for_sd3(
784 | cfg.hparams.flow_match.weighting_scheme, sigmas=scheduler_sigmas
785 | )
786 | while len(weighting.shape) < latents.ndim:
787 | weighting = weighting.unsqueeze(-1)
788 | guidance_vec = (
789 | torch.full((batch_size,), float(cfg.hparams.guidance_scale), device=accelerator.device) * 1000.0
790 | )
791 |
792 | ic(noisy_model_input.shape, cond_latents.shape)
793 | ic(
794 | step,
795 | guidance_vec,
796 | weighting,
797 | timesteps,
798 | prompt_embeds.shape,
799 | prompt_attention_mask.shape,
800 | pooled_prompt_embeds.shape,
801 | )
802 |
803 | denoised_latents = transformer(
804 | hidden_states=torch.cat([noisy_model_input, cond_latents], dim=1),
805 | timestep=timesteps,
806 | encoder_hidden_states=prompt_embeds,
807 | pooled_projections=pooled_prompt_embeds,
808 | encoder_attention_mask=prompt_attention_mask,
809 | guidance=guidance_vec,
810 | return_dict=False,
811 | )[0]
812 |
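# flow-matching velocity target: x_t = (1 - sigma) * x_0 + sigma * noise, so the model regresses (noise - x_0)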
813 | target = noise - latents
814 | loss = torch.nn.functional.mse_loss(denoised_latents.to(weight_dtype), target, reduction="none")
815 |
816 | if weighting is not None:
817 | loss = loss * weighting
818 |
819 | loss = loss.mean()
820 |
821 | assert not torch.isnan(loss).any(), "NaN loss detected"
822 |
823 | accelerator.backward(loss)
824 |
825 | if cfg.hparams.gradient_precision == "fp32":
826 | for param in transformer_lora_parameters:
827 | if param.grad is not None:
828 | param.grad.data = param.grad.data.to(torch.float32)
829 |
830 | grad_norm = max_gradient(transformer_lora_parameters)
831 | if accelerator.sync_gradients:
832 | if accelerator.distributed_type == DistributedType.DEEPSPEED:
833 | grad_norm = transformer.get_global_grad_norm()
834 |
835 | elif cfg.hparams.max_grad_norm > 0:
836 | if cfg.hparams.grad_clip_method == "norm":
837 | grad_norm = accelerator.clip_grad_norm_(
838 | transformer_lora_parameters, cfg.hparams.max_grad_norm
839 | )
840 | elif cfg.hparams.grad_clip_method == "value":
841 | grad_norm = accelerator.clip_grad_value_(
842 | transformer_lora_parameters, cfg.hparams.max_grad_norm
843 | )
844 |
845 | if torch.is_tensor(grad_norm):
846 | grad_norm = grad_norm.item()
847 |
848 | optimizer.step()
849 | optimizer.zero_grad(set_to_none=True)
850 | lr_scheduler.step()
851 |
852 | if accelerator.sync_gradients:
853 | progress_bar.update(1)
854 | global_step += 1
855 |
856 | if cfg.hparams.ema.use_ema and ema_model is not None:
857 | ema_model.step(parameters=transformer_lora_parameters, global_step=global_step)
858 |
859 | if accelerator.is_main_process:
860 | if global_step % cfg.checkpointing.save_every_n_steps == 0:
861 | save_path = os.path.join(output_dirpath, f"checkpoint-step{global_step:08d}")
862 | accelerator.save_state(save_path)
863 | logger.info(f"Saved state to {save_path!r}")
864 |
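# retention example: with save_every_n_steps=500 and save_last_n_steps=1000, at global_step=3000
# remove_step_no = 3000 - 1000 - 1 = 1999 -> 1999 - (1999 % 500) = 1500, so checkpoint-step00001500 is pruned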
865 | if cfg.checkpointing.save_last_n_steps is not None:
866 | remove_step_no = global_step - cfg.checkpointing.save_last_n_steps - 1
867 | remove_step_no = remove_step_no - (remove_step_no % cfg.checkpointing.save_every_n_steps)
868 | if remove_step_no < 0:
869 | remove_step_no = None
870 | if remove_step_no is not None:
871 | remove_ckpt_name = os.path.join(output_dirpath, f"checkpoint-step{remove_step_no:08d}")
872 | if os.path.exists(remove_ckpt_name):
873 | logger.info(f"removing old checkpoint: {remove_ckpt_name!r}")
874 | shutil.rmtree(remove_ckpt_name)
875 |
876 | logs = {}
877 | logs["loss"] = accelerator.reduce(loss.detach().clone(), reduction="mean").item()
878 | logs["grad_norm"] = grad_norm
879 | logs["lr"] = lr_scheduler.get_last_lr()[0]
880 | if ema_model is not None:
881 | logs["ema_decay"] = ema_model.get_decay()
882 | progress_bar.set_postfix(**logs)
883 | accelerator.log(logs, step=global_step)
884 |
885 | if global_step >= cfg.hparams.max_train_steps:
886 | logger.info(f"max training steps={cfg.hparams.max_train_steps!r} reached.")
887 | break
888 |
889 | memory_statistics = get_memory_statistics()
890 | logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")
891 |
892 | accelerator.wait_for_everyone()
893 | if accelerator.is_main_process:
894 | if cfg.hparams.ema.use_ema and ema_model is not None:
895 | ema_model.copy_to(transformer_lora_parameters)
896 |
897 | transformer = unwrap_model(transformer)
898 | transformer_lora_layers = get_peft_model_state_dict(transformer)
899 | HunyuanVideoPipeline.save_lora_weights(
900 | output_dirpath,
901 | transformer_lora_layers=transformer_lora_layers,
902 | safe_serialization=True,
903 | weight_name=f"{cfg.experiment.name}.sft",
904 | )
905 | accelerator.wait_for_everyone()
906 |
907 | memory_statistics = get_memory_statistics()
908 | logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")
909 | accelerator.end_training()
910 |
911 |
912 | if __name__ == "__main__":
913 | main()
914 |
--------------------------------------------------------------------------------
/mds_dataloaders.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from io import BytesIO
3 | from pathlib import Path
4 | from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
5 |
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 | from streaming import Stream, StreamingDataLoader, StreamingDataset
10 | from torch.utils.data import DataLoader
11 |
12 |
13 | def make_streams(remote, local=None, proportion=None, repeat=None, choose=None):
14 | """Helper function to create a list of Stream objects from a set of remotes and stream weights.
15 |
16 | Args:
17 | remote (Union[str, Sequence[str]]): The remote path or paths to stream from.
18 | local (Union[str, Sequence[str]], optional): The local path or paths to cache the data. If not provided, the
19 | default local path is used. Default: ``None``.
20 | proportion (list, optional): Specifies how to sample this Stream relative to other Streams. Default: ``None``.
21 | repeat (list, optional): Specifies the degree to which a Stream is upsampled or downsampled. Default: ``None``.
22 | choose (list, optional): Specifies the number of samples to choose from a Stream. Default: ``None``.
23 |
24 | Returns:
25 | List[Stream]: A list of Stream objects.
26 | """
27 | remote, local = _make_remote_and_local_sequences(remote, local)
28 | proportion, repeat, choose = _make_weighting_sequences(remote, proportion, repeat, choose)
29 |
30 | streams = []
31 | for i, (r, l) in enumerate(zip(remote, local)):
32 | streams.append(Stream(remote=r, local=l, proportion=proportion[i], repeat=repeat[i], choose=choose[i]))
33 | return streams
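# Usage sketch (hypothetical remotes): sample two MDS streams 80/20 into default local caches:
#   streams = make_streams(remote=["s3://bucket/a", "s3://bucket/b"], proportion=[0.8, 0.2])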
34 |
35 |
36 | def _make_remote_and_local_sequences(remote, local=None):
37 | if isinstance(remote, str):
38 | remote = [remote]
39 | if isinstance(local, str):
40 | local = [local]
41 | if not local:
42 | local = [_make_default_local_path(r) for r in remote]
43 |
44 | if isinstance(remote, Sequence) and isinstance(local, Sequence):
45 | if len(remote) != len(local):
46 | raise ValueError(
47 | f"remote and local Sequences must be the same length, got lengths {len(remote)} and {len(local)}"
48 | )
49 | else:
50 | raise ValueError(f"remote and local must be both Strings or Sequences, got types {type(remote)} and {type(local)}.")
51 | return remote, local
52 |
53 |
54 | def _make_default_local_path(remote_path):
55 | return str(Path(*["/tmp"] + list(Path(remote_path).parts[1:])))
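# e.g. "s3://bucket/data/shard-a" -> "/tmp/bucket/data/shard-a"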
56 |
57 |
58 | def _make_weighting_sequences(remote, proportion=None, repeat=None, choose=None):
59 | weights = {"proportion": proportion, "repeat": repeat, "choose": choose}
60 | for name, weight in weights.items():
61 | if weight is not None and len(remote) != len(weight):
62 | ValueError(f"{name} must be the same length as remote, got lengths {len(remote)} and {len(weight)}")
63 | proportion = weights["proportion"] if weights["proportion"] is not None else [None] * len(remote)
64 | repeat = weights["repeat"] if weights["repeat"] is not None else [None] * len(remote)
65 | choose = weights["choose"] if weights["choose"] is not None else [None] * len(remote)
66 | return proportion, repeat, choose
67 |
68 |
69 | class StreamingVideoCaptionLatentsDataset(StreamingDataset):
70 | def __init__(
71 | self,
72 | streams: Sequence[Stream],
73 | video_key: str = "video",
74 | caption_key: str = "caption",
75 | latents_key: str = "latents",
76 | latents_cond_key: str = "latents_cond",
77 | prompt_embeds_key: str = "prompt_embeds",
78 | prompt_attention_mask_key: str = "prompt_attention_mask",
79 | pooled_prompt_embeds_key: str = "pooled_prompt_embeds",
80 | latent_dtype: torch.dtype = torch.bfloat16,
81 | batch_size: Optional[int] = None,
82 | **streaming_kwargs,
83 | ):
84 | streaming_kwargs.setdefault("shuffle_block_size", 1 << 18)
85 | streaming_kwargs.setdefault("shuffle_algo", "py1s")
86 | super().__init__(streams=streams, batch_size=batch_size, **streaming_kwargs)
87 |
88 | self.video_key = video_key
89 | self.caption_key = caption_key
90 | self.latents_key = latents_key
91 | self.prompt_embeds_key = prompt_embeds_key
92 | self.latents_cond_key = latents_cond_key
93 | self.prompt_attention_mask_key = prompt_attention_mask_key
94 | self.pooled_prompt_embeds_key = pooled_prompt_embeds_key
95 | self.latent_dtype = latent_dtype
96 |
97 | def __getitem__(self, index):
98 | sample = super().__getitem__(index)
99 |
100 | out = {}
101 |
102 | latents = torch.from_numpy(sample[self.latents_key].copy()).to(dtype=self.latent_dtype)
103 | latents_cond = torch.from_numpy(sample[self.latents_cond_key].copy()).to(dtype=self.latent_dtype)
104 |
105 | prompt_embeds = torch.from_numpy(sample[self.prompt_embeds_key].copy()).to(dtype=self.latent_dtype)
106 | pooled_prompt_embeds = torch.from_numpy(sample[self.pooled_prompt_embeds_key].copy()).to(
107 | dtype=self.latent_dtype
108 | )
109 | prompt_attention_mask = torch.from_numpy(sample[self.prompt_attention_mask_key].copy()).to(dtype=torch.bool)
110 |
111 | out[self.latents_key] = latents
112 | out[self.latents_cond_key] = latents_cond
113 | out[self.prompt_embeds_key] = prompt_embeds
114 | out[self.pooled_prompt_embeds_key] = pooled_prompt_embeds
115 | out[self.prompt_attention_mask_key] = prompt_attention_mask
116 |
117 | return out
118 |
119 |
120 | def build_mds_dataloader(
121 | remote: Union[str, List],
122 | local: Union[str, List],
123 | batch_size: int,
124 | video_key: str = "video",
125 | caption_key: str = "caption",
126 | latents_key: str = "latents",
127 | latents_cond_key: str = "latents_cond",
128 | prompt_embeds_key: str = "prompt_embeds",
129 | prompt_attention_mask_key: str = "prompt_attention_mask",
130 | pooled_prompt_embeds_key: str = "pooled_prompt_embeds",
131 | latent_dtype: torch.dtype = torch.bfloat16,
132 | proportion: Optional[list] = None,
133 | repeat: Optional[list] = None,
134 | choose: Optional[list] = None,
135 | streaming_kwargs: Optional[Dict] = None,
136 | dataloader_kwargs: Optional[Dict] = None,
137 | ):
138 | if streaming_kwargs is None:
139 | streaming_kwargs = {}
140 | if dataloader_kwargs is None:
141 | dataloader_kwargs = {}
142 |
143 | # streams = make_streams(remote, local=local, proportion=proportion, repeat=repeat, choose=choose)
144 | if isinstance(local, str):
145 | local = [local]
146 | streams = [Stream(local=l) for l in local]
147 |
148 | dataset = StreamingVideoCaptionLatentsDataset(
149 | streams=streams,
150 | video_key=video_key,
151 | caption_key=caption_key,
152 | latents_key=latents_key,
153 | latents_cond_key=latents_cond_key,
154 | prompt_embeds_key=prompt_embeds_key,
155 | prompt_attention_mask_key=prompt_attention_mask_key,
156 | pooled_prompt_embeds_key=pooled_prompt_embeds_key,
157 | latent_dtype=latent_dtype,
158 | batch_size=batch_size,
159 | **streaming_kwargs,
160 | )
161 |
162 | dataloader = StreamingDataLoader(
163 | dataset=dataset,
164 | batch_size=batch_size,
165 | sampler=None,
166 | **dataloader_kwargs,
167 | )
168 |
169 | return dataloader
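# Minimal usage sketch (hypothetical local path); batches are dicts keyed by the *_key arguments:
#   loader = build_mds_dataloader(remote=None, local="/tmp/mds/960x544x77", batch_size=1)
#   batch = next(iter(loader))  # latents, latents_cond, prompt embeds, pooled embeds, attention mask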
170 |
--------------------------------------------------------------------------------
/notebooks/000-inference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%env CUDA_VISIBLE_DEVICES=0\n",
10 | "%env HF_HUB_ENABLE_HF_TRANSFER=1"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "import os\n",
20 | "import sys\n",
21 | "import time\n",
22 | "import uuid\n",
23 | "from datetime import datetime\n",
24 | "from typing import *\n",
25 | "\n",
26 | "import av\n",
27 | "import cv2\n",
28 | "import diffusers\n",
29 | "import numpy as np\n",
30 | "import safetensors.torch\n",
31 | "import torch\n",
32 | "import torch.nn.functional as F\n",
33 | "import torchvision.transforms.v2 as transforms\n",
34 | "from diffusers import FlowMatchEulerDiscreteScheduler, HunyuanVideoPipeline\n",
35 | "from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\n",
36 | "from diffusers.loaders import HunyuanVideoLoraLoaderMixin\n",
37 | "from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel\n",
38 | "from diffusers.models.attention import Attention\n",
39 | "from diffusers.models.embeddings import apply_rotary_emb\n",
40 | "from diffusers.models.transformers.transformer_hunyuan_video import *\n",
41 | "from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoPatchEmbed, HunyuanVideoTransformer3DModel\n",
42 | "from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE, retrieve_timesteps\n",
43 | "from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput\n",
44 | "from diffusers.pipelines.pipeline_utils import DiffusionPipeline\n",
45 | "from diffusers.schedulers import FlowMatchEulerDiscreteScheduler\n",
46 | "from diffusers.utils import export_to_video, is_torch_xla_available, logging, replace_example_docstring\n",
47 | "from diffusers.utils.state_dict_utils import convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft\n",
48 | "from diffusers.utils.torch_utils import randn_tensor\n",
49 | "from diffusers.video_processor import VideoProcessor\n",
50 | "from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict\n",
51 | "from PIL import Image"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "print(\"\\nENVIRONMENT\\n\")\n",
61 | "print(f\" Python {sys.version}\\n\")\n",
62 | "print(f\" torch.__version__ = {torch.__version__}\")\n",
63 | "print(f\" torch.version.cuda = {torch.version.cuda}\")\n",
64 | "print(f\" torch.backends.cudnn.version() = {torch.backends.cudnn.version()}\")\n",
65 | "print(f\" diffusers.__version__ = {diffusers.__version__}\")"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "try:\n",
75 | " import flash_attn\n",
76 | " from flash_attn.flash_attn_interface import _flash_attn_forward, flash_attn_varlen_func\n",
77 | "except ImportError:\n",
78 | " flash_attn, _flash_attn_forward, flash_attn_varlen_func = None, None, None\n",
79 | "\n",
80 | "try:\n",
81 | " from sageattention import sageattn, sageattn_varlen\n",
82 | "except ImportError:\n",
83 | " sageattn, sageattn_varlen = None, None\n",
84 | "\n",
85 | "\n",
86 | "def get_cu_seqlens(attention_mask):\n",
87 | " \"\"\"Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len\n",
88 | "\n",
89 | " Args:\n",
90 | " text_mask (torch.Tensor): the mask of text\n",
91 | " img_len (int): the length of image\n",
92 | "\n",
93 | " Returns:\n",
94 | " torch.Tensor: the calculated cu_seqlens for flash attention\n",
95 | " \"\"\"\n",
96 | " batch_size = attention_mask.shape[0]\n",
97 | " text_len = attention_mask.sum(dim=-1, dtype=torch.int) # .flatten()\n",
98 | " max_len = attention_mask.shape[-1]\n",
99 | "\n",
100 | " cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=\"cuda\")\n",
101 | "\n",
102 | " for i in range(batch_size):\n",
103 | " s = text_len[i]\n",
104 | " s1 = i * max_len + s\n",
105 | " s2 = (i + 1) * max_len\n",
106 | " cu_seqlens[2 * i + 1] = s1\n",
107 | " cu_seqlens[2 * i + 2] = s2\n",
108 | "\n",
109 | " return cu_seqlens\n",
110 | "\n",
111 | "\n",
112 | "class HunyuanVideoFlashAttnProcessor:\n",
113 | " def __init__(self, use_flash_attn=True, use_sageattn=False):\n",
114 | " self.use_flash_attn = use_flash_attn\n",
115 | " self.use_sageattn = use_sageattn\n",
116 | " if self.use_flash_attn:\n",
117 | " assert flash_attn is not None\n",
118 | " if self.use_sageattn:\n",
119 | " assert sageattn is not None\n",
120 | "\n",
121 | " def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):\n",
122 | " if attn.add_q_proj is None and encoder_hidden_states is not None:\n",
123 | " hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)\n",
124 | "\n",
125 | " query = attn.to_q(hidden_states)\n",
126 | " key = attn.to_k(hidden_states)\n",
127 | " value = attn.to_v(hidden_states)\n",
128 | "\n",
129 | " query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)\n",
130 | " key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)\n",
131 | " value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)\n",
132 | "\n",
133 | " if attn.norm_q is not None:\n",
134 | " query = attn.norm_q(query)\n",
135 | " if attn.norm_k is not None:\n",
136 | " key = attn.norm_k(key)\n",
137 | "\n",
138 | " if image_rotary_emb is not None:\n",
139 | " if attn.add_q_proj is None and encoder_hidden_states is not None:\n",
140 | " query = torch.cat([apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), query[:, :, -encoder_hidden_states.shape[1] :]], dim=2)\n",
141 | " key = torch.cat([apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), key[:, :, -encoder_hidden_states.shape[1] :]], dim=2)\n",
142 | " else:\n",
143 | " query = apply_rotary_emb(query, image_rotary_emb)\n",
144 | " key = apply_rotary_emb(key, image_rotary_emb)\n",
145 | "\n",
146 | " batch_size = hidden_states.shape[0]\n",
147 | " img_seq_len = hidden_states.shape[1]\n",
148 | " txt_seq_len = 0\n",
149 | "\n",
150 | " if attn.add_q_proj is not None and encoder_hidden_states is not None:\n",
151 | " encoder_query = attn.add_q_proj(encoder_hidden_states)\n",
152 | " encoder_key = attn.add_k_proj(encoder_hidden_states)\n",
153 | " encoder_value = attn.add_v_proj(encoder_hidden_states)\n",
154 | "\n",
155 | " encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)\n",
156 | " encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)\n",
157 | " encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)\n",
158 | "\n",
159 | " if attn.norm_added_q is not None:\n",
160 | " encoder_query = attn.norm_added_q(encoder_query)\n",
161 | " if attn.norm_added_k is not None:\n",
162 | " encoder_key = attn.norm_added_k(encoder_key)\n",
163 | "\n",
164 | " query = torch.cat([query, encoder_query], dim=2)\n",
165 | " key = torch.cat([key, encoder_key], dim=2)\n",
166 | " value = torch.cat([value, encoder_value], dim=2)\n",
167 | "\n",
168 | " txt_seq_len = encoder_hidden_states.shape[1]\n",
169 | "\n",
170 | " max_seqlen_q = max_seqlen_kv = img_seq_len + txt_seq_len\n",
171 | " cu_seqlens_q = cu_seqlens_kv = get_cu_seqlens(attention_mask)\n",
172 | "\n",
173 | " query = query.transpose(1, 2).reshape(-1, query.shape[1], query.shape[3])\n",
174 | " key = key.transpose(1, 2).reshape(-1, key.shape[1], key.shape[3])\n",
175 | " value = value.transpose(1, 2).reshape(-1, value.shape[1], value.shape[3])\n",
176 | "\n",
177 | " if self.use_flash_attn:\n",
178 | " hidden_states = flash_attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)\n",
179 | " elif self.use_sageattn:\n",
180 | " hidden_states = sageattn_varlen(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)\n",
181 | " else:\n",
182 | " raise NotImplementedError(\"Please set use_flash_attn=True or use_sageattn=True\")\n",
183 | "\n",
184 | " hidden_states = hidden_states.reshape(batch_size, max_seqlen_q, -1)\n",
185 | " hidden_states = hidden_states.to(query.dtype)\n",
186 | "\n",
187 | " if encoder_hidden_states is not None:\n",
188 | " hidden_states, encoder_hidden_states = (hidden_states[:, : -encoder_hidden_states.shape[1]], hidden_states[:, -encoder_hidden_states.shape[1] :])\n",
189 | "\n",
190 | " if getattr(attn, \"to_out\", None) is not None:\n",
191 | " hidden_states = attn.to_out[0](hidden_states)\n",
192 | " hidden_states = attn.to_out[1](hidden_states)\n",
193 | "\n",
194 | " if getattr(attn, \"to_add_out\", None) is not None:\n",
195 | " encoder_hidden_states = attn.to_add_out(encoder_hidden_states)\n",
196 | "\n",
197 | " return hidden_states, encoder_hidden_states"
198 | ]
199 | },
200 | {
201 | "cell_type": "code",
202 | "execution_count": null,
203 | "metadata": {},
204 | "outputs": [],
205 | "source": [
206 | "def save_video(video, output_path=\"output.mp4\"):\n",
207 | " width, height = video[0].size\n",
208 | "\n",
209 | " container = av.open(output_path, mode=\"w\")\n",
210 | "\n",
211 | " # create video stream\n",
212 | " codec = \"libx264\"\n",
213 | " pixel_format = \"yuv420p\"\n",
214 | " stream = container.add_stream(codec, rate=24)\n",
215 | " stream.width = width\n",
216 | " stream.height = height\n",
217 | " stream.pix_fmt = pixel_format\n",
218 | " stream.bit_rate = 4000000 # 4Mbit/s\n",
219 | "\n",
220 | " for frame_array in video:\n",
221 | " frame = av.VideoFrame.from_image(frame_array)\n",
222 | " packets = stream.encode(frame)\n",
223 | " for packet in packets:\n",
224 | " container.mux(packet)\n",
225 | "\n",
226 | " for packet in stream.encode():\n",
227 | " container.mux(packet)\n",
228 | "\n",
229 | " container.close()"
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": null,
235 | "metadata": {},
236 | "outputs": [],
237 | "source": [
238 | "pipe = HunyuanVideoPipeline.from_pretrained(\"hunyuanvideo-community/HunyuanVideo\", torch_dtype=torch.bfloat16)\n",
239 | "pipe.to(\"cuda\")\n",
240 | "pipe.vae.enable_slicing()\n",
241 | "pipe.vae.enable_tiling()\n",
242 | "\n",
243 | "for block in pipe.transformer.transformer_blocks + pipe.transformer.single_transformer_blocks:\n",
244 | " block.attn.processor = HunyuanVideoFlashAttnProcessor(use_flash_attn=True, use_sageattn=False)"
245 | ]
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": null,
250 | "metadata": {},
251 | "outputs": [],
252 | "source": [
253 | "with torch.no_grad():\n",
254 | " initial_input_channels = pipe.transformer.config.in_channels\n",
255 | " new_img_in = HunyuanVideoPatchEmbed(\n",
256 | " patch_size=(pipe.transformer.config.patch_size_t, pipe.transformer.config.patch_size, pipe.transformer.config.patch_size),\n",
257 | " in_chans=pipe.transformer.config.in_channels * 2,\n",
258 | " embed_dim=pipe.transformer.config.num_attention_heads * pipe.transformer.config.attention_head_dim,\n",
259 | " )\n",
260 | " new_img_in = new_img_in.to(pipe.device, dtype=pipe.dtype)\n",
261 | " new_img_in.proj.weight.zero_()\n",
262 | " new_img_in.proj.weight[:, :initial_input_channels].copy_(pipe.transformer.x_embedder.proj.weight)\n",
263 | "\n",
264 | " if pipe.transformer.x_embedder.proj.bias is not None:\n",
265 | " new_img_in.proj.bias.copy_(pipe.transformer.x_embedder.proj.bias)\n",
266 | "\n",
267 | " pipe.transformer.x_embedder = new_img_in\n",
268 | " pipe.transformer.x_embedder.requires_grad_(False)"
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "execution_count": null,
274 | "metadata": {},
275 | "outputs": [],
276 | "source": [
277 | "LORA_PATH = \"/mnt/data/ayushman/projects/output/i2v_outputs/exp006-resume-from-9500/i2v.sft\"\n",
278 | "\n",
279 | "# ------ load lora\n",
280 | "lora_state_dict = pipe.lora_state_dict(LORA_PATH)\n",
281 | "transformer_lora_state_dict = {f'{k.replace(\"transformer.\", \"\")}': v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\") and \"lora\" in k}\n",
282 | "pipe.load_lora_into_transformer(transformer_lora_state_dict, transformer=pipe.transformer, adapter_name=\"i2v\", _pipeline=pipe)\n",
283 | "pipe.set_adapters([\"i2v\"], adapter_weights=[1.0])\n",
284 | "pipe.fuse_lora(components=[\"transformer\"], lora_scale=1.0, adapter_names=[\"i2v\"])\n",
285 | "pipe.unload_lora_weights()\n",
286 | "\n",
287 | "# -------- load norm layers\n",
288 | "NORM_LAYER_PREFIXES = [\"norm_q\", \"norm_k\", \"norm_added_q\", \"norm_added_k\"]\n",
289 | "transformer_norm_layers_state_dict = {\n",
290 | " f'{k.replace(\"transformer.\", \"\")}': v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\") and any(norm_k in k for norm_k in NORM_LAYER_PREFIXES)\n",
291 | "}\n",
292 | "if len(transformer_norm_layers_state_dict) == 0:\n",
293 | " print(\"[INFO] No normalization layers found in state dict. Skipping loading normalization layers....\")\n",
294 | "else:\n",
295 | " print(\"[INFO] Extracting normalization layers from state dict and loading them into the transformer....\")\n",
296 | " for key in list(transformer_norm_layers_state_dict.keys()):\n",
297 | " if key.split(\".\")[0] == \"transformer\":\n",
298 | " transformer_norm_layers_state_dict[key[len(f\"transformer.\") :]] = transformer_norm_layers_state_dict.pop(key)\n",
299 | " transformer_state_dict = pipe.transformer.state_dict()\n",
300 | " transformer_keys = set(transformer_state_dict.keys())\n",
301 | " state_dict_keys = set(transformer_norm_layers_state_dict.keys())\n",
302 | " extra_keys = list(state_dict_keys - transformer_keys)\n",
303 | " if extra_keys:\n",
304 | " print(f\"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\\n{extra_keys}.\")\n",
305 | " for key in extra_keys:\n",
306 | " transformer_norm_layers_state_dict.pop(key)\n",
307 | " incompatible_keys = pipe.transformer.load_state_dict(transformer_norm_layers_state_dict, strict=False)\n",
308 | " unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n",
309 | " if any(norm_key in k for k in unexpected_keys for norm_key in NORM_LAYER_PREFIXES):\n",
310 | " print(f\"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer.\")"
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": null,
316 | "metadata": {},
317 | "outputs": [],
318 | "source": [
319 | "video_transforms = transforms.Compose(\n",
320 | " [\n",
321 | " transforms.Lambda(lambda x: x / 255.0),\n",
322 | " transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),\n",
323 | " ]\n",
324 | ")\n",
325 | "\n",
326 | "\n",
327 | "def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:\n",
328 | " \"\"\"\n",
329 | " Resize the image to the bucket resolution.\n",
330 | " \"\"\"\n",
331 | " is_pil_image = isinstance(image, Image.Image)\n",
332 | " if is_pil_image:\n",
333 | " image_width, image_height = image.size\n",
334 | " else:\n",
335 | " image_height, image_width = image.shape[:2]\n",
336 | "\n",
337 | " if bucket_reso == (image_width, image_height):\n",
338 | " return np.array(image) if is_pil_image else image\n",
339 | "\n",
340 | " bucket_width, bucket_height = bucket_reso\n",
341 | "\n",
342 | " scale_width = bucket_width / image_width\n",
343 | " scale_height = bucket_height / image_height\n",
344 | " scale = max(scale_width, scale_height)\n",
345 | " image_width = int(image_width * scale + 0.5)\n",
346 | " image_height = int(image_height * scale + 0.5)\n",
347 | "\n",
348 | " if scale > 1:\n",
349 | " image = Image.fromarray(image) if not is_pil_image else image\n",
350 | " image = image.resize((image_width, image_height), Image.LANCZOS)\n",
351 | " image = np.array(image)\n",
352 | " else:\n",
353 | " image = np.array(image) if is_pil_image else image\n",
354 | " image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)\n",
355 | "\n",
356 | " # crop the image to the bucket resolution\n",
357 | " crop_left = (image_width - bucket_width) // 2\n",
358 | " crop_top = (image_height - bucket_height) // 2\n",
359 | " image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]\n",
360 | "\n",
361 | " return image"
362 | ]
363 | },
364 | {
365 | "cell_type": "code",
366 | "execution_count": null,
367 | "metadata": {},
368 | "outputs": [],
369 | "source": [
370 | "# n_frames, height, width = 77, 960, 544\n",
371 | "# prompt = 'a woman moving her head'\n",
372 | "# p1 = \"/mnt/data/ayushman/projects/input/img_samples/dvqdv.png\"\n",
373 | "# p2 = \"/mnt/data/ayushman/projects/input/img_samples/F11024.png\"\n",
374 | "\n",
375 | "n_frames, height, width = 77, 720, 1280\n",
376 | "prompt = \"A man and a woman sit at a cozy café table in the warm midday light. She wears a vibrant green dress, and he wears a simple black T-shirt. Their table is adorned with fresh salads, small plates of appetizers, and a sparkling beverage in elegant glassware. They chat casually, occasionally glancing at their phones or laughing together. Large windows in the background reveal a bustling street, where people pass by under soft lanterns. The camera begins with a medium-wide shot, capturing the relaxed ambiance and lively chatter of nearby patrons. It then transitions to a closer view, highlighting the woman's bright smile and the man's attentive expression. Their voices blend with the gentle hum of the café, as they enjoy a pleasant lunch and each other's company.\"\n",
377 | "p1 = \"/mnt/data/ayushman/projects/input/img_samples/009-F1.png\"\n",
378 | "p2 = \"/mnt/data/ayushman/projects/input/img_samples/009-F2.png\"\n",
379 | "\n",
380 | "\n",
381 | "# n_frames, height, width = 73, 960, 544\n",
382 | "# prompt = 'a man'\n",
383 | "# p1 = \"/mnt/data/ayushman/projects/input/img_samples/3.3.PNG\"\n",
384 | "# p2 = \"/mnt/data/ayushman/projects/input/img_samples/4.1 (1).PNG\"\n",
385 | "\n",
386 | "\n",
387 | "# n_frames, height, width = 73, 544, 960\n",
388 | "# prompt = 'realistic style, a man walking out of a cave'\n",
389 | "# p1 = \"/mnt/data/ayushman/projects/input/img_samples/3015505-hd_1920_1080_24fps-Scene-001-01.jpg\"\n",
390 | "# p2 = \"/mnt/data/ayushman/projects/input/img_samples/3015505-hd_1920_1080_24fps-Scene-001-02.jpg\"\n",
391 | "\n",
392 | "# n_frames, height, width = 73, 544, 960\n",
393 | "# prompt = 'a whimsical milk cartoon dancing'\n",
394 | "# p2 = \"/mnt/data/ayushman/projects/input/img_samples/158387-816637360_small-Scene-001-03.jpg\"\n",
395 | "# p1 = \"/mnt/data/ayushman/projects/input/img_samples/158387-816637360_small-Scene-001-01.jpg\"\n",
396 | "\n",
397 | "# n_frames, height, width = 73, 544, 960\n",
398 | "# prompt = 'a whimsical milk cartoon dancing'\n",
399 | "# p2 = \"/mnt/data/ayushman/projects/input/img_samples/158387-816637360_small-Scene-001-03.jpg\"\n",
400 | "# p1 = \"/mnt/data/ayushman/projects/input/img_samples/158387-816637360_small-Scene-001-01.jpg\"\n",
401 | "\n",
402 | "# n_frames, height, width = 73, 960, 544\n",
403 | "# prompt = \"a woman\"\n",
404 | "# p1 = \"/mnt/data/ayushman/projects/input/img_samples/F14866.png\"\n",
405 | "# p2 = \"/mnt/data/ayushman/projects/input/img_samples/F14920.png\"\n",
406 | "\n",
407 | "assert os.path.exists(p1)\n",
408 | "assert os.path.exists(p2)\n",
409 | "\n",
410 | "cond_frame1 = Image.open(p1).convert(\"RGB\")\n",
411 | "cond_frame2 = Image.open(p2).convert(\"RGB\")\n",
412 | "# show_images([cond_frame1, cond_frame2], titles=[\"cond_frame1\", \"cond_frame2\"], imsize=4)\n",
413 | "# plt.show()"
414 | ]
415 | },
416 | {
417 | "cell_type": "code",
418 | "execution_count": null,
419 | "metadata": {},
420 | "outputs": [],
421 | "source": [
422 | "cond_frame1 = resize_image_to_bucket(cond_frame1, bucket_reso=(width, height))\n",
423 | "cond_frame2 = resize_image_to_bucket(cond_frame2, bucket_reso=(width, height))\n",
424 | "\n",
425 | "cond_video = np.zeros(shape=(n_frames, height, width, 3))\n",
426 | "cond_video[0], cond_video[-1] = np.array(cond_frame1), np.array(cond_frame2)\n",
427 | "\n",
428 | "cond_video = torch.from_numpy(cond_video.copy()).permute(0, 3, 1, 2)\n",
429 | "cond_video = torch.stack([video_transforms(x) for x in cond_video], dim=0).unsqueeze(0)\n",
430 | "\n",
431 | "with torch.inference_mode():\n",
432 | " image_or_video = cond_video.to(device=\"cuda\", dtype=pipe.dtype)\n",
433 | " image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]\n",
434 | " cond_latents = pipe.vae.encode(image_or_video).latent_dist.sample()\n",
435 | " cond_latents = cond_latents * pipe.vae.config.scaling_factor\n",
436 | " cond_latents = cond_latents.to(dtype=pipe.dtype)\n",
437 | " assert not torch.any(torch.isnan(cond_latents))"
438 | ]
439 | },
440 | {
441 | "cell_type": "code",
442 | "execution_count": null,
443 | "metadata": {},
444 | "outputs": [],
445 | "source": [
446 | "@torch.inference_mode()\n",
447 | "def call_pipe(\n",
448 | " pipe,\n",
449 | " prompt: Union[str, List[str]] = None,\n",
450 | " prompt_2: Union[str, List[str]] = None,\n",
451 | " height: int = 720,\n",
452 | " width: int = 1280,\n",
453 | " num_frames: int = 129,\n",
454 | " num_inference_steps: int = 50,\n",
455 | " sigmas: List[float] = None,\n",
456 | " guidance_scale: float = 6.0,\n",
457 | " num_videos_per_prompt: Optional[int] = 1,\n",
458 | " generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n",
459 | " latents: Optional[torch.Tensor] = None,\n",
460 | " prompt_embeds: Optional[torch.Tensor] = None,\n",
461 | " pooled_prompt_embeds: Optional[torch.Tensor] = None,\n",
462 | " prompt_attention_mask: Optional[torch.Tensor] = None,\n",
463 | " output_type: Optional[str] = \"pil\",\n",
464 | " return_dict: bool = True,\n",
465 | " attention_kwargs: Optional[Dict[str, Any]] = None,\n",
466 | " callback_on_step_end: Optional[Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]] = None,\n",
467 | " callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n",
468 | " prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,\n",
469 | " max_sequence_length: int = 256,\n",
470 | " image_latents: Optional[torch.Tensor] = None,\n",
471 | "):\n",
472 | "\n",
473 | " if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n",
474 | " callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n",
475 | "\n",
476 | " # 1. Check inputs. Raise error if not correct\n",
477 | " pipe.check_inputs(\n",
478 | " prompt,\n",
479 | " prompt_2,\n",
480 | " height,\n",
481 | " width,\n",
482 | " prompt_embeds,\n",
483 | " callback_on_step_end_tensor_inputs,\n",
484 | " prompt_template,\n",
485 | " )\n",
486 | "\n",
487 | " pipe._guidance_scale = guidance_scale\n",
488 | " pipe._attention_kwargs = attention_kwargs\n",
489 | " pipe._current_timestep = None\n",
490 | " pipe._interrupt = False\n",
491 | "\n",
492 | " device = pipe._execution_device\n",
493 | "\n",
494 | " # 2. Define call parameters\n",
495 | " if prompt is not None and isinstance(prompt, str):\n",
496 | " batch_size = 1\n",
497 | " elif prompt is not None and isinstance(prompt, list):\n",
498 | " batch_size = len(prompt)\n",
499 | " else:\n",
500 | " batch_size = prompt_embeds.shape[0]\n",
501 | "\n",
502 | " # 3. Encode input prompt\n",
503 | " prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt(\n",
504 | " prompt=prompt,\n",
505 | " prompt_2=prompt_2,\n",
506 | " prompt_template=prompt_template,\n",
507 | " num_videos_per_prompt=num_videos_per_prompt,\n",
508 | " prompt_embeds=prompt_embeds,\n",
509 | " pooled_prompt_embeds=pooled_prompt_embeds,\n",
510 | " prompt_attention_mask=prompt_attention_mask,\n",
511 | " device=device,\n",
512 | " max_sequence_length=max_sequence_length,\n",
513 | " )\n",
514 | "\n",
515 | " transformer_dtype = pipe.transformer.dtype\n",
516 | " prompt_embeds = prompt_embeds.to(transformer_dtype)\n",
517 | " prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)\n",
518 | " if pooled_prompt_embeds is not None:\n",
519 | " pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)\n",
520 | "\n",
521 | " # 4. Prepare timesteps\n",
522 | " sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas\n",
523 | " timesteps, num_inference_steps = retrieve_timesteps(\n",
524 | " pipe.scheduler,\n",
525 | " num_inference_steps,\n",
526 | " device,\n",
527 | " sigmas=sigmas,\n",
528 | " )\n",
529 | "\n",
530 | " # 5. Prepare latent variables\n",
531 | " num_channels_latents = pipe.transformer.config.in_channels\n",
532 | " num_latent_frames = (num_frames - 1) // pipe.vae_scale_factor_temporal + 1\n",
533 | " latents = pipe.prepare_latents(\n",
534 | " batch_size * num_videos_per_prompt,\n",
535 | " num_channels_latents,\n",
536 | " height,\n",
537 | " width,\n",
538 | " num_latent_frames,\n",
539 | " torch.float32,\n",
540 | " device,\n",
541 | " generator,\n",
542 | " latents,\n",
543 | " )\n",
544 | "\n",
545 | " # 6. Prepare guidance condition\n",
546 | " guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0\n",
547 | "\n",
548 | " # 7. Denoising loop\n",
549 | " num_warmup_steps = len(timesteps) - num_inference_steps * pipe.scheduler.order\n",
550 | " pipe._num_timesteps = len(timesteps)\n",
551 | "\n",
552 | " with pipe.progress_bar(total=num_inference_steps) as progress_bar:\n",
553 | " for i, t in enumerate(timesteps):\n",
554 | " if pipe.interrupt:\n",
555 | " continue\n",
556 | "\n",
557 | " pipe._current_timestep = t\n",
558 | " latent_model_input = latents.to(transformer_dtype)\n",
559 | " timestep = t.expand(latents.shape[0]).to(latents.dtype)\n",
560 | "\n",
561 | " noise_pred = pipe.transformer(\n",
562 | " hidden_states=torch.cat([latent_model_input, image_latents], dim=1),\n",
563 | " timestep=timestep,\n",
564 | " encoder_hidden_states=prompt_embeds,\n",
565 | " encoder_attention_mask=prompt_attention_mask,\n",
566 | " pooled_projections=pooled_prompt_embeds,\n",
567 | " guidance=guidance,\n",
568 | " attention_kwargs=attention_kwargs,\n",
569 | " return_dict=False,\n",
570 | " )[0]\n",
571 | "\n",
572 | " # compute the previous noisy sample x_t -> x_t-1\n",
573 | " latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n",
574 | "\n",
575 | " if callback_on_step_end is not None:\n",
576 | " callback_kwargs = {}\n",
577 | " for k in callback_on_step_end_tensor_inputs:\n",
578 | " callback_kwargs[k] = locals()[k]\n",
579 | " callback_outputs = callback_on_step_end(pipe, i, t, callback_kwargs)\n",
580 | "\n",
581 | " latents = callback_outputs.pop(\"latents\", latents)\n",
582 | " prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n",
583 | "\n",
584 | " # call the callback, if provided\n",
585 | " if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0):\n",
586 | " progress_bar.update()\n",
587 | " pipe._current_timestep = None\n",
588 | "\n",
589 | " if not output_type == \"latent\":\n",
590 | " latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor\n",
591 | " video = pipe.vae.decode(latents, return_dict=False)[0]\n",
592 | " video = pipe.video_processor.postprocess_video(video, output_type=output_type)\n",
593 | " else:\n",
594 | " video = latents\n",
595 | "\n",
596 | " # Offload all models\n",
597 | " pipe.maybe_free_model_hooks()\n",
598 | "\n",
599 | " if not return_dict:\n",
600 | " return (video,)\n",
601 | "\n",
602 | " return HunyuanVideoPipelineOutput(frames=video)"
603 | ]
604 | },
605 | {
606 | "cell_type": "code",
607 | "execution_count": null,
608 | "metadata": {},
609 | "outputs": [],
610 | "source": [
611 | "video = call_pipe(\n",
612 | " pipe,\n",
613 | " prompt=prompt,\n",
614 | " num_frames=n_frames,\n",
615 | " num_inference_steps=50,\n",
616 | " image_latents=cond_latents,\n",
617 | " width=width,\n",
618 | " height=height,\n",
619 | " guidance_scale=6.0,\n",
620 | " generator=torch.Generator(device=\"cuda\").manual_seed(0),\n",
621 | ").frames[0]"
622 | ]
623 | },
624 | {
625 | "cell_type": "code",
626 | "execution_count": null,
627 | "metadata": {},
628 | "outputs": [],
629 | "source": [
630 | "time_flag = datetime.fromtimestamp(time.time()).strftime(\"%Y%m%d_%H%M%S\")\n",
631 | "fp = f\"/mnt/data/ayushman/projects/output/samples/hv-CL-{height}x{width}x{n_frames}-{time_flag}.mp4\"\n",
632 | "print(fp)\n",
633 | "save_video(video, fp)"
634 | ]
635 | },
636 | {
637 | "cell_type": "code",
638 | "execution_count": null,
639 | "metadata": {},
640 | "outputs": [],
641 | "source": []
642 | }
643 | ],
644 | "metadata": {
645 | "kernelspec": {
646 | "display_name": "hunyuan_control_env",
647 | "language": "python",
648 | "name": "python3"
649 | },
650 | "language_info": {
651 | "codemirror_mode": {
652 | "name": "ipython",
653 | "version": 3
654 | },
655 | "file_extension": ".py",
656 | "mimetype": "text/x-python",
657 | "name": "python",
658 | "nbconvert_exporter": "python",
659 | "pygments_lexer": "ipython3",
660 | "version": "3.10.16"
661 | }
662 | },
663 | "nbformat": 4,
664 | "nbformat_minor": 2
665 | }
666 |
--------------------------------------------------------------------------------
/optim.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from typing import Any, Dict
3 |
4 | import torch
5 | from accelerate.logging import get_logger
6 |
7 | try:
8 | from optimi import AdamW as OptimiAdamW
9 | from optimi import StableAdamW as OptimiStableAdamW
10 | except ImportError:
11 | OptimiAdamW, OptimiStableAdamW = None, None
12 |
13 | try:
14 | from bitsandbytes.optim import AdamW8bit, Lion8bit
15 | except ImportError:
16 | AdamW8bit, Lion8bit = None, None
17 |
18 | try:
19 | from came_pytorch import CAME
20 | except ImportError:
21 | CAME = None
22 |
23 | import ast
24 |
25 | logger = get_logger(__name__)
26 |
27 |
28 | OPTIMIZER_FUNC_TO_NAME = {
29 | "adam": torch.optim.Adam,
30 | "adamw": torch.optim.AdamW,
31 | "optimi-adamw": OptimiAdamW,
32 | "optimi-stableadamw": OptimiStableAdamW,
33 | "bnb-adamw8bit": AdamW8bit,
34 | "bnb-lion8bit": Lion8bit,
35 | "came": CAME,
36 | }
37 |
38 |
39 | def get_optimizer(
40 | params_to_optimize,
41 | optimizer_name: str = "adam",
42 | learning_rate: float = 1e-3,
43 | optimizer_args_str: str | None = None,
44 | use_deepspeed: bool = False,
45 | # use_cpu_offload_optimizer: bool = False,
46 | # offload_gradients: bool = False,
47 | ) -> torch.optim.Optimizer:
48 | optimizer_kwargs = {}
49 |
50 | if optimizer_args_str is not None and len(optimizer_args_str) > 0:
51 | for arg in optimizer_args_str.split():
52 | key, value = arg.split("=")
53 | value = ast.literal_eval(value)
54 | optimizer_kwargs[key] = value
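# optimizer_args_str is assumed to hold space-separated key=value pairs parsed via ast.literal_eval,
# e.g. "weight_decay=0.01 betas=(0.9,0.99)" -> {"weight_decay": 0.01, "betas": (0.9, 0.99)}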
55 |
56 | optimizer_name = optimizer_name.lower()
57 |
58 | if use_deepspeed:
59 | from accelerate.utils import DummyOptim
60 |
61 | return DummyOptim(params_to_optimize, lr=learning_rate, **optimizer_kwargs)
62 |
63 | assert optimizer_name in OPTIMIZER_FUNC_TO_NAME, f"Unknown optimizer: {optimizer_name!r}"
64 |
65 | optimizer_class = OPTIMIZER_FUNC_TO_NAME[optimizer_name]
66 | assert optimizer_class is not None
67 |
68 | optimizer = optimizer_class(params_to_optimize, lr=learning_rate, **optimizer_kwargs)
69 |
70 | logger.info(f"Use {optimizer.__class__.__name__!r} | {optimizer_kwargs!r}")
71 | return optimizer
72 |
73 |
74 | def gradient_norm(parameters):
75 | norm = 0
76 | for param in parameters:
77 | if param.grad is None:
78 | continue
79 | local_norm = param.grad.detach().data.norm(2)
80 | norm += local_norm.item() ** 2
81 | norm = norm**0.5
82 | return norm
83 |
84 |
85 | def max_gradient(parameters):
86 | max_grad_value = float("-inf")
87 | for param in parameters:
88 | if param.grad is None:
89 | continue
90 | local_max_grad = param.grad.detach().data.abs().max()
91 | max_grad_value = max(max_grad_value, local_max_grad.item())
92 | return max_grad_value
93 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.isort]
2 | profile = "black"
3 | line_length = 120
4 | force_grid_wrap = 0
5 | use_parentheses = true
6 | ensure_newline_before_comments = true
7 |
8 | [tool.black]
9 | line-length = 120
10 | exclude = "(.eggs|.git|.hg|.mypy_cache|.venv|_build|buck-out|build|dist|_notebooks/.*)"
11 |
--------------------------------------------------------------------------------
/setup_env.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | conda create -n hunyuan_control_env python=3.10 -y && eval "$(conda shell.bash hook)" && conda activate hunyuan_control_env
3 | conda install -c nvidia/label/cuda-12.4.0 cuda-toolkit cuda -y
4 |
5 | export CUDA_HOME=$(dirname $(dirname $(which nvcc)))
6 | export PATH=$CUDA_HOME/bin:$PATH
7 | export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
8 |
9 | python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
10 | python -m pip install ninja
11 | python -m pip install --verbose --upgrade git+https://github.com/Dao-AILab/flash-attention.git@v2.6.3
12 |
13 | wget "https://raw.githubusercontent.com/kohya-ss/musubi-tuner/refs/heads/main/requirements.txt" -O requirements.txt
14 | python -m pip install -r requirements.txt && rm requirements.txt
15 |
16 | python -m pip install accelerate==1.2.1 transformers==4.46.3 bitsandbytes==0.45.2 decord==0.6.0 deepspeed==0.16.3 opencv-python==4.10.0.84 pandas==2.2.3 peft==0.14.0 mosaicml-streaming==0.11.0 pyrallis==0.3.1 torch-optimi==0.2.1
17 | python -m pip install huggingface-hub hf_transfer
18 | python -m pip install --upgrade git+https://github.com/huggingface/diffusers@81440fd47493b9f9e817411ca0499d0bf06fde95
19 | python -m pip install icecream pre-commit
20 |
--------------------------------------------------------------------------------
/tools/hv_cache_dataset.py:
--------------------------------------------------------------------------------
1 | import math
2 | import multiprocessing as mp
3 | import os
4 | import pprint
5 | import traceback
6 | from argparse import ArgumentParser
7 | from ast import literal_eval
8 | from pathlib import Path
9 | from typing import Tuple, Union
10 |
11 | import cv2
12 | import decord
13 | import numpy as np
14 | import pandas as pd
15 | from PIL import Image
16 | from streaming import MDSWriter
17 | from tqdm.auto import tqdm
18 |
19 | decord.bridge.set_bridge("native")
20 |
21 |
22 | def parse_args():
23 | parser = ArgumentParser()
24 | parser.add_argument("--csv", required=True, type=str)
25 | parser.add_argument("--base_dir", default=None, type=str)
26 | parser.add_argument("--video_column", default="video", type=str)
27 | parser.add_argument("--caption_column", default="caption", type=str)
28 | parser.add_argument("--output_dir", required=True, type=str)
29 | parser.add_argument("--bucket_reso", default=None, type=str, nargs="+")
30 | parser.add_argument("--debug", action="store_true", help="debug mode")
31 | parser.add_argument("--min_bucket_count", type=int, default=1)
32 | parser.add_argument("--head_frame", type=int, default=0)
33 |
34 | args = parser.parse_args()
35 | print("args:", pprint.pformat(args, sort_dicts=True, compact=True))
36 | return args
37 |
38 |
39 | def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
40 | """
41 | Resize the image to the bucket resolution.
42 | """
43 | is_pil_image = isinstance(image, Image.Image)
44 | if is_pil_image:
45 | image_width, image_height = image.size
46 | else:
47 | image_height, image_width = image.shape[:2]
48 |
49 | if bucket_reso == (image_width, image_height):
50 | return np.array(image) if is_pil_image else image
51 |
52 | bucket_width, bucket_height = bucket_reso
53 | if bucket_width == image_width or bucket_height == image_height:
54 | image = np.array(image) if is_pil_image else image
55 | else:
56 | # resize the image to the bucket resolution to match the short side
57 | scale_width = bucket_width / image_width
58 | scale_height = bucket_height / image_height
59 | scale = max(scale_width, scale_height)
60 | image_width = int(image_width * scale + 0.5)
61 | image_height = int(image_height * scale + 0.5)
62 |
63 | if scale > 1:
64 | image = Image.fromarray(image) if not is_pil_image else image
65 | image = image.resize((image_width, image_height), Image.LANCZOS)
66 | image = np.array(image)
67 | else:
68 | image = np.array(image) if is_pil_image else image
69 | image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
70 |
71 | # crop the image to the bucket resolution
72 | crop_left = (image_width - bucket_width) // 2
73 | crop_top = (image_height - bucket_height) // 2
74 | image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
75 | return image
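# e.g. a 1920x1080 frame targeted at bucket (960, 544): scale = max(960/1920, 544/1080) ~= 0.504,
# resize to 967x544 (INTER_AREA since scale < 1), then center-crop the width down to 960x544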
76 |
77 |
78 | def divisible_by(num: int, divisor: int) -> int:
79 | return num - num % divisor
80 |
81 |
82 | class BucketSelector:
83 | RESOLUTION_STEPS_HUNYUAN = 16
84 |
85 | def __init__(self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False):
86 | self.resolution = resolution
87 | self.bucket_area = resolution[0] * resolution[1]
88 | self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN
89 |
90 | if not enable_bucket:
91 | # only define one bucket
92 | self.bucket_resolutions = [resolution]
93 | self.no_upscale = False
94 | else:
95 | # prepare bucket resolution
96 | self.no_upscale = no_upscale
97 | sqrt_size = int(math.sqrt(self.bucket_area))
98 | min_size = divisible_by(sqrt_size // 2, self.reso_steps)
99 | self.bucket_resolutions = []
100 | for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps):
101 | h = divisible_by(self.bucket_area // w, self.reso_steps)
102 | self.bucket_resolutions.append((w, h))
103 | self.bucket_resolutions.append((h, w))
104 |
105 | self.bucket_resolutions = list(set(self.bucket_resolutions))
106 | self.bucket_resolutions.sort()
107 |
108 | # calculate aspect ratio to find the nearest resolution
109 | self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions])
110 |
111 | def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]:
112 | """
113 | return the bucket resolution for the given image size, (width, height)
114 | """
115 | area = image_size[0] * image_size[1]
116 | if self.no_upscale and area <= self.bucket_area:
117 | w, h = image_size
118 | w = divisible_by(w, self.reso_steps)
119 | h = divisible_by(h, self.reso_steps)
120 | return w, h
121 |
122 | aspect_ratio = image_size[0] / image_size[1]
123 | ar_errors = self.aspect_ratios - aspect_ratio
124 | bucket_id = np.abs(ar_errors).argmin()
125 | return self.bucket_resolutions[bucket_id]
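# e.g. with resolution=(960, 544) and enable_bucket=True, a 1280x720 source (AR ~1.78) is assigned
# the (960, 544) bucket (AR ~1.76), the closest aspect ratio among the generated bucket resolutions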
126 |
127 |
128 | def load_video(video_path, bucket_selector=None, start_frame=None, end_frame=None):
129 | vr = decord.VideoReader(uri=video_path)
130 | video_num_frames = len(vr)
131 | _start_frame, _end_frame = 0, video_num_frames
132 |
133 | if start_frame is not None:
134 | _start_frame = start_frame
135 | if end_frame is not None:
136 | _end_frame = min(_end_frame, end_frame)
137 |
138 | frames = vr.get_batch(range(_start_frame, _end_frame)).asnumpy()
139 | f, h, w, _ = frames.shape
140 |     bucket_reso = bucket_selector.get_bucket_resolution(image_size=(w, h)) if bucket_selector is not None else (w, h)
141 | frames = [resize_image_to_bucket(frame, bucket_reso=bucket_reso) for frame in frames]
142 | return frames
143 |
144 |
145 | class BucketBatchManager:
146 | def __init__(self, bucketed_item_info, min_bucket_count=0):
147 | self.buckets = bucketed_item_info
148 | self.bucket_resos = list(self.buckets.keys())
149 | self.bucket_resos.sort()
150 |
151 | for bucket_reso in self.bucket_resos:
152 | bucket = self.buckets[bucket_reso]
153 | if len(bucket) < min_bucket_count:
154 | print(
155 | f"bucket {bucket_reso!r} (n={len(bucket)!r}) has less than {min_bucket_count!r} items, remove it..."
156 | )
157 | del self.buckets[bucket_reso]
158 |
159 | self.bucket_resos = list(self.buckets.keys())
160 | self.bucket_resos.sort()
161 |
162 | def show_bucket_info(self):
163 | for bucket_reso in self.bucket_resos:
164 | bucket = self.buckets[bucket_reso]
165 | print(f"bucket: {bucket_reso}, count: {len(bucket)}")
166 |
167 |
168 | def convert_and_make_shards(args, base_df, bucket_batch_manager, bucket_reso, b_idx):
169 | columns = {
170 | "idx": "int32",
171 | "item_key": "str",
172 | "item": "str",
173 | "frame_count": "int32",
174 | "bucket_width": "int32",
175 | "bucket_height": "int32",
176 | "original_width": "int32",
177 | "original_height": "int32",
178 | "caption_str": "str",
179 | "video_arr": "ndarray",
180 | }
181 | print(f"Starting converter processs for bucket {bucket_reso!r}...")
182 | output_path = os.path.join(args.output_dir, "x".join(list(map(str, bucket_reso))))
183 | Path(output_path).mkdir(parents=True, exist_ok=True)
184 |
185 | bucket = bucket_batch_manager.buckets[bucket_reso]
186 |
187 | writer = MDSWriter(out=output_path, columns=columns, size_limit=256 * (2**20), max_workers=os.cpu_count())
188 |
189 | for item_info in tqdm(bucket, dynamic_ncols=True, position=b_idx, leave=False, desc=f"bucket {bucket_reso}"):
190 | item_key = item_info["item_key"]
191 | frame_count = item_info["frame_count"]
192 | frame_crop_pos = item_info["frame_crop_pos"]
193 | idx = item_info["idx"]
194 | bucket_reso_wh = (item_info["bucket_width"], item_info["bucket_height"])
195 |
196 | row = base_df.iloc[idx]
197 | video_path, caption = row[args.video_column], row[args.caption_column]
198 |
199 | try:
200 | vr = decord.VideoReader(uri=video_path)
201 | video = vr.get_batch(range(frame_crop_pos, frame_crop_pos + frame_count)).asnumpy()
202 | original_width, original_height = video.shape[2], video.shape[1]
203 | video = [resize_image_to_bucket(frame, bucket_reso=bucket_reso_wh) for frame in video]
204 | video = np.stack(video, axis=0)
205 | except Exception as e:
206 | print(f"Failed to load video {video_path!r} : {e!r}")
207 | print(traceback.format_exc())
208 | continue
209 |
210 | sample = {}
211 | sample["idx"] = idx
212 | sample["item_key"] = str(item_key)
213 | sample["item"] = str(Path(video_path).name)
214 | sample["frame_count"] = frame_count
215 | sample["bucket_width"] = bucket_reso_wh[0]
216 | sample["bucket_height"] = bucket_reso_wh[1]
217 | sample["original_width"] = original_height
218 | sample["original_height"] = original_width
219 | sample["caption_str"] = caption
220 | sample["video_arr"] = video
221 |
222 | writer.write(sample)
223 | writer.finish()
224 | print(f"Converter process finished for bucket {bucket_reso!r} !!!")
225 |
226 |
227 | def main(args):
228 | if str(args.csv).endswith(".csv"):
229 | df = pd.read_csv(args.csv)
230 | elif str(args.csv).endswith(".json"):
231 | df = pd.read_json(args.csv)
232 | elif str(args.csv).endswith(".parquet"):
233 | df = pd.read_parquet(args.csv)
234 | elif str(args.csv).endswith(".jsonl"):
235 | df = pd.read_json(args.csv, lines=True, orient="records")
236 | else:
237 |         raise ValueError(f"Unsupported metadata file format: {args.csv!r} (expected .csv, .json, .jsonl, or .parquet)")
238 | if args.base_dir is not None:
239 | df[args.video_column] = df[args.video_column].apply(lambda x: os.path.join(args.base_dir, x))
240 |
241 | if args.debug:
242 | df = df.sample(n=10).reset_index(drop=True, inplace=False)
243 |
244 | print("Total number of samples: ", len(df))
245 |
246 | bucket_selectors = []
247 | for res in args.bucket_reso:
248 | w, h, f = res.split("x")
249 | bs = BucketSelector(resolution=(int(w), int(h), int(f)), enable_bucket=True, no_upscale=False)
250 | bucket_selectors.append(bs)
251 |
252 | batches = {}
253 | for idx, row in tqdm(df.iterrows(), total=len(df), dynamic_ncols=True, desc="Generating buckets"):
254 | video_path = row[args.video_column]
255 |
256 | vr = decord.VideoReader(uri=video_path)
257 | frame_count = len(vr)
258 | video = vr.get_batch(range(0, 1)).asnumpy()
259 | frame_height, frame_width = video.shape[1], video.shape[2]
260 | frame_size = (frame_width, frame_height)
261 |
262 | for bs in bucket_selectors:
263 | target_frame = bs.resolution[-1]
264 | bucket_reso = bs.get_bucket_resolution(frame_size)
265 |
266 | if frame_count >= (target_frame + args.head_frame):
267 | crop_pos_and_frames = [args.head_frame, target_frame]
268 | body, ext = os.path.splitext(Path(video_path).name)
269 | item_key = f"{body}_{crop_pos_and_frames[0]:05d}-{target_frame:05d}{ext}"
270 | batch_key = (*bucket_reso, target_frame)
271 | item_info = {
272 | "item_key": item_key,
273 | "batch_key": batch_key,
274 | "frame_count": target_frame,
275 | "frame_crop_pos": crop_pos_and_frames[0],
276 | "idx": idx,
277 | "bucket_height": bucket_reso[1],
278 | "bucket_width": bucket_reso[0],
279 | }
280 | batch = batches.get(batch_key, [])
281 | batch.append(item_info)
282 | batches[batch_key] = batch
283 |
284 | bucket_manager = BucketBatchManager(batches, min_bucket_count=args.min_bucket_count)
285 | bucket_manager.show_bucket_info()
286 |
287 | bucket_resos = bucket_manager.bucket_resos
288 |
289 | exporters = []
290 | for bucket_idx, bucket_reso in enumerate(bucket_resos):
291 | op = mp.Process(target=convert_and_make_shards, args=(args, df, bucket_manager, bucket_reso, bucket_idx))
292 | op.start()
293 | exporters.append(op)
294 |
295 | for op in exporters:
296 | op.join()
297 | op.close()
298 |
299 |
300 | if __name__ == "__main__":
301 | main(parse_args())
302 |
--------------------------------------------------------------------------------
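Usage note: the bucketing helpers above (BucketSelector, resize_image_to_bucket) can be exercised on their own. The snippet below is a minimal sketch assuming both are importable from this file; the 960x544 target resolution and the synthetic 1280x720 frame are illustrative assumptions, not values taken from this repo.

    import numpy as np

    # hypothetical target area; BucketSelector derives its bucket list from width * height
    selector = BucketSelector(resolution=(960, 544), enable_bucket=True, no_upscale=False)

    # synthetic H x W x C frame standing in for a decoded video frame
    frame = np.random.randint(0, 256, size=(720, 1280, 3), dtype=np.uint8)

    bucket_w, bucket_h = selector.get_bucket_resolution(image_size=(1280, 720))  # returns (width, height)
    resized = resize_image_to_bucket(frame, bucket_reso=(bucket_w, bucket_h))
    assert resized.shape[:2] == (bucket_h, bucket_w)  # scaled to cover the bucket, then center-cropped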
/tools/hv_precompute_latents.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 | from io import BytesIO
4 | from pathlib import Path
5 | from typing import Any, Dict, List, Optional, Tuple, Union
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torchvision.transforms.v2 as transforms
10 | from accelerate.logging import get_logger
11 | from diffusers import (
12 | AutoencoderKLHunyuanVideo,
13 | FlowMatchEulerDiscreteScheduler,
14 | HunyuanVideoPipeline,
15 | HunyuanVideoTransformer3DModel,
16 | )
17 | from PIL import Image
18 | from streaming import MDSWriter, Stream, StreamingDataset
19 | from torch.utils.data import DataLoader
20 | from tqdm.auto import tqdm
21 | from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizer
22 |
23 | torch.backends.cuda.matmul.allow_tf32 = True
24 |
25 |
26 | def load_condition_models(
27 | model_id: str = "hunyuanvideo-community/HunyuanVideo",
28 | text_encoder_dtype: torch.dtype = torch.float16,
29 | text_encoder_2_dtype: torch.dtype = torch.float16,
30 | revision: Optional[str] = None,
31 | cache_dir: Optional[str] = None,
32 | device: Optional[torch.device] = None,
33 | **kwargs,
34 | ) -> Dict[str, nn.Module]:
35 | tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
36 | text_encoder = LlamaModel.from_pretrained(
37 | model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
38 | )
39 | tokenizer_2 = CLIPTokenizer.from_pretrained(
40 | model_id, subfolder="tokenizer_2", revision=revision, cache_dir=cache_dir
41 | )
42 | text_encoder_2 = CLIPTextModel.from_pretrained(
43 | model_id, subfolder="text_encoder_2", torch_dtype=text_encoder_2_dtype, revision=revision, cache_dir=cache_dir
44 | )
45 | if device is not None:
46 | text_encoder.to(device)
47 | text_encoder_2.to(device)
48 | return {
49 | "tokenizer": tokenizer,
50 | "text_encoder": text_encoder,
51 | "tokenizer_2": tokenizer_2,
52 | "text_encoder_2": text_encoder_2,
53 | }
54 |
55 |
56 | def load_latent_models(
57 | model_id: str = "hunyuanvideo-community/HunyuanVideo",
58 | vae_dtype: torch.dtype = torch.float16,
59 | revision: Optional[str] = None,
60 | cache_dir: Optional[str] = None,
61 | device: Optional[torch.device] = None,
62 | **kwargs,
63 | ) -> Dict[str, nn.Module]:
64 | vae = AutoencoderKLHunyuanVideo.from_pretrained(
65 | model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir
66 | )
67 | if device is not None:
68 | vae.to(device)
69 | vae.enable_slicing()
70 | vae.enable_tiling()
71 | return {
72 | "vae": vae,
73 | }
74 |
75 |
76 | def load_diffusion_models(
77 | model_id: str = "hunyuanvideo-community/HunyuanVideo",
78 | transformer_dtype: torch.dtype = torch.bfloat16,
79 | shift: float = 1.0,
80 | revision: Optional[str] = None,
81 | cache_dir: Optional[str] = None,
82 | **kwargs,
83 | ) -> Dict[str, Union[nn.Module, FlowMatchEulerDiscreteScheduler]]:
84 | transformer = HunyuanVideoTransformer3DModel.from_pretrained(
85 | model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
86 | )
87 | scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)
88 | return {"transformer": transformer, "scheduler": scheduler}
89 |
90 |
91 | def prepare_conditions(
92 | tokenizer: LlamaTokenizer,
93 | text_encoder: LlamaModel,
94 | tokenizer_2: CLIPTokenizer,
95 | text_encoder_2: CLIPTextModel,
96 | prompt: Union[str, List[str]],
97 | device: Optional[torch.device] = None,
98 | dtype: Optional[torch.dtype] = None,
99 | max_sequence_length: int = 256,
100 | prompt_template: Dict[str, Any] = {
101 | "template": (
102 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
103 | "1. The main content and theme of the video."
104 | "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
105 | "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
106 | "4. background environment, light, style and atmosphere."
107 | "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
108 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
109 | ),
110 | "crop_start": 95,
111 | },
112 | **kwargs,
113 | ) -> Dict[str, torch.Tensor]:
114 | device = device or text_encoder.device
115 | dtype = dtype or text_encoder.dtype
116 |
117 | if isinstance(prompt, str):
118 | prompt = [prompt]
119 |
120 | conditions = {}
121 | conditions.update(
122 | _get_llama_prompt_embeds(tokenizer, text_encoder, prompt, prompt_template, device, dtype, max_sequence_length)
123 | )
124 | conditions.update(_get_clip_prompt_embeds(tokenizer_2, text_encoder_2, prompt, device, dtype))
125 |
126 | return conditions
127 |
128 |
129 | def prepare_latents(
130 | vae: AutoencoderKLHunyuanVideo,
131 | image_or_video: torch.Tensor,
132 | device: Optional[torch.device] = None,
133 | dtype: Optional[torch.dtype] = None,
134 | generator: Optional[torch.Generator] = None,
135 | **kwargs,
136 | ) -> Dict[str, torch.Tensor]:
137 | device = device or vae.device
138 | dtype = dtype or vae.dtype
139 |
140 | if image_or_video.ndim == 4:
141 | image_or_video = image_or_video.unsqueeze(2)
142 | assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"
143 |
144 | image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
145 |     image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]
146 |
147 | latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
148 | latents = latents * vae.config.scaling_factor
149 | latents = latents.to(dtype=dtype)
150 | return {"latents": latents}
151 |
152 |
153 | def _get_llama_prompt_embeds(
154 | tokenizer: LlamaTokenizer,
155 | text_encoder: LlamaModel,
156 | prompt: List[str],
157 | prompt_template: Dict[str, Any],
158 | device: torch.device,
159 | dtype: torch.dtype,
160 | max_sequence_length: int = 256,
161 | num_hidden_layers_to_skip: int = 2,
162 | ) -> Dict[str, torch.Tensor]:
163 | batch_size = len(prompt)
164 | prompt = [prompt_template["template"].format(p) for p in prompt]
165 |
166 | crop_start = prompt_template.get("crop_start", None)
167 | if crop_start is None:
168 | prompt_template_input = tokenizer(
169 | prompt_template["template"],
170 | padding="max_length",
171 | return_tensors="pt",
172 | return_length=False,
173 | return_overflowing_tokens=False,
174 | return_attention_mask=False,
175 | )
176 | crop_start = prompt_template_input["input_ids"].shape[-1]
177 | # Remove <|eot_id|> token and placeholder {}
178 | crop_start -= 2
179 |
180 | max_sequence_length += crop_start
181 | text_inputs = tokenizer(
182 | prompt,
183 | max_length=max_sequence_length,
184 | padding="max_length",
185 | truncation=True,
186 | return_tensors="pt",
187 | return_length=False,
188 | return_overflowing_tokens=False,
189 | return_attention_mask=True,
190 | )
191 | text_input_ids = text_inputs.input_ids.to(device=device)
192 | prompt_attention_mask = text_inputs.attention_mask.to(device=device)
193 |
194 | prompt_embeds = text_encoder(
195 | input_ids=text_input_ids,
196 | attention_mask=prompt_attention_mask,
197 | output_hidden_states=True,
198 | ).hidden_states[-(num_hidden_layers_to_skip + 1)]
199 | prompt_embeds = prompt_embeds.to(dtype=dtype)
200 |
201 | if crop_start is not None and crop_start > 0:
202 | prompt_embeds = prompt_embeds[:, crop_start:]
203 | prompt_attention_mask = prompt_attention_mask[:, crop_start:]
204 |
205 | prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
206 |
207 | return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask}
208 |
209 |
210 | def _get_clip_prompt_embeds(
211 | tokenizer_2: CLIPTokenizer,
212 | text_encoder_2: CLIPTextModel,
213 | prompt: Union[str, List[str]],
214 | device: torch.device,
215 | dtype: torch.dtype,
216 | max_sequence_length: int = 77,
217 | ) -> Dict[str, torch.Tensor]:
218 | text_inputs = tokenizer_2(
219 | prompt,
220 | padding="max_length",
221 | max_length=max_sequence_length,
222 | truncation=True,
223 | return_tensors="pt",
224 | )
225 |
226 | prompt_embeds = text_encoder_2(text_inputs.input_ids.to(device), output_hidden_states=False).pooler_output
227 | prompt_embeds = prompt_embeds.to(dtype=dtype)
228 |
229 | return {"pooled_prompt_embeds": prompt_embeds}
230 |
231 |
232 | def main(args):
233 | if "cuda" in args.device.lower():
234 | assert torch.cuda.is_available()
235 | device = torch.device(args.device)
236 | dtype = getattr(torch, args.dtype)
237 |
238 | video_transforms = transforms.Compose(
239 | [
240 | transforms.Lambda(lambda x: x / 255.0),
241 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
242 | ]
243 | )
244 |
245 | stream = [Stream(local=args.mds_data_path)]
246 | dataset = StreamingDataset(streams=stream, batch_size=1)
247 |
248 | save_path = Path(args.output_dir) / Path(args.mds_data_path).name
249 | save_path.mkdir(parents=True, exist_ok=True)
250 | save_path = save_path.as_posix()
251 |
252 | print(f"Saving to => {save_path!r}")
253 |
254 | columns = {
255 | "idx": "int32",
256 | "item_key": "str",
257 | "item": "str",
258 | "frame_count": "int32",
259 | "bucket_width": "int32",
260 | "bucket_height": "int32",
261 | "original_width": "int32",
262 | "original_height": "int32",
263 | "caption_str": "str",
264 | "video_arr": "ndarray",
265 | "prompt_embeds": "ndarray",
266 | "prompt_attention_mask": "ndarray",
267 | "pooled_prompt_embeds": "ndarray",
268 | "latents": "ndarray",
269 | "latents_cond": "ndarray",
270 | }
271 |
272 | writer = MDSWriter(
273 | out=save_path,
274 | columns=columns,
275 | compression=args.mds_shard_compression,
276 | size_limit=256 * (2**20),
277 | max_workers=64,
278 | )
279 |
280 | models = load_condition_models(
281 | args.pretrained_model_name_or_path, text_encoder_dtype=dtype, text_encoder_2_dtype=torch.bfloat16, device=device
282 | )
283 | models.update(load_latent_models(args.pretrained_model_name_or_path, vae_dtype=dtype, device=device))
284 |
285 | for batch in tqdm(dataset, dynamic_ncols=True, desc="Precomputing latents"):
286 | video, caption = batch["video_arr"], batch["caption_str"]
287 | video = torch.from_numpy(video.copy()).permute(0, 3, 1, 2) # n_frames, c, h, w
288 |
289 | control_condition = torch.zeros_like(video) # create an empty video
290 | control_condition[0] = video[0] # keep the first frame
291 | control_condition[-1] = video[-1] # keep the last frame
292 |
293 | video = torch.stack([video_transforms(x) for x in video], dim=0).unsqueeze(0)
294 | control_condition = torch.stack([video_transforms(x) for x in control_condition], dim=0).unsqueeze(0)
295 |
296 | with torch.inference_mode(): # @TODO: add batch support ?
297 | latents = prepare_latents(models["vae"], video, device=device, dtype=dtype)["latents"]
298 | latents_cond = prepare_latents(models["vae"], control_condition, device=device, dtype=dtype)["latents"]
299 | conditions = prepare_conditions(
300 | tokenizer=models["tokenizer"],
301 | text_encoder=models["text_encoder"],
302 | tokenizer_2=models["tokenizer_2"],
303 | text_encoder_2=models["text_encoder_2"],
304 | prompt=caption,
305 | device=device,
306 | dtype=dtype,
307 | )
308 | prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = (
309 | conditions["prompt_embeds"],
310 | conditions["prompt_attention_mask"],
311 | conditions["pooled_prompt_embeds"],
312 | )
313 |
314 | out_batch = {**batch}
315 | out_batch["latents"] = latents[0].float().cpu().numpy()
316 | out_batch["prompt_embeds"] = prompt_embeds[0].float().cpu().numpy()
317 | out_batch["prompt_attention_mask"] = prompt_attention_mask[0].float().cpu().numpy()
318 | out_batch["pooled_prompt_embeds"] = pooled_prompt_embeds[0].float().cpu().numpy()
319 | out_batch["latents_cond"] = latents_cond[0].float().cpu().numpy()
320 |
321 | writer.write(out_batch)
322 |
323 | writer.finish()
324 |
325 |
326 | if __name__ == "__main__":
327 | parser = ArgumentParser()
328 | parser.add_argument("--mds_data_path", required=True, type=str)
329 |     parser.add_argument(
330 |         "--pretrained_model_name_or_path", type=str, default="hunyuanvideo-community/HunyuanVideo"
331 |     )
332 | parser.add_argument("--output_dir", required=True, type=str)
333 | parser.add_argument("--device", type=str, default="cuda")
334 | parser.add_argument("--dtype", choices=["float16", "bfloat16"], default="bfloat16")
335 | parser.add_argument("--mds_shard_compression", type=str, default=None)
336 |
337 | args = parser.parse_args()
338 | main(args)
339 |
--------------------------------------------------------------------------------
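Usage note: a minimal invocation sketch for the precompute script above, driving main() directly with an argparse-style Namespace instead of the command line. The data and output paths are placeholders (assumptions); the remaining fields mirror the flags defined in the __main__ block.

    from argparse import Namespace

    args = Namespace(
        mds_data_path="/data/mds/960x544x49",       # placeholder MDS shard directory
        pretrained_model_name_or_path="hunyuanvideo-community/HunyuanVideo",
        output_dir="/data/mds_latents",             # placeholder output directory
        device="cuda",
        dtype="bfloat16",
        mds_shard_compression=None,
    )
    main(args)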
/tools/hv_precompute_latents_dist.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from argparse import ArgumentParser
4 | from datetime import datetime, timedelta
5 | from io import BytesIO
6 | from pathlib import Path
7 | from typing import Any, Dict, List, Optional, Tuple, Union
8 |
9 | import accelerate
10 | import torch
11 | import torch.nn as nn
12 | import torchvision.transforms as transforms
13 | from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs
14 | from accelerate.logging import get_logger
15 | from accelerate.utils import set_seed
16 | from diffusers import (
17 | AutoencoderKLHunyuanVideo,
18 | FlowMatchEulerDiscreteScheduler,
19 | HunyuanVideoPipeline,
20 | HunyuanVideoTransformer3DModel,
21 | )
22 | from PIL import Image
23 | from streaming import MDSWriter, Stream, StreamingDataLoader, StreamingDataset
24 | from torch.utils.data import DataLoader
25 | from tqdm.auto import tqdm
26 | from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizer
27 |
28 | torch.backends.cuda.matmul.allow_tf32 = True
29 | logger = get_logger("cache_latents")
30 | logging.basicConfig(
31 | level=logging.INFO,
32 | format="[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d:%(funcName)s] %(message)s",
33 | force=True,
34 | handlers=[logging.StreamHandler()],
35 | )
36 |
37 | _COMMON_BEGINNING_PHRASES = (
38 | "This video",
39 | "The video",
40 | "This clip",
41 | "The clip",
42 | "The animation",
43 | "This image",
44 | "The image",
45 | "This picture",
46 | "The picture",
47 | )
48 | _COMMON_CONTINUATION_WORDS = ("shows", "depicts", "features", "captures", "highlights", "introduces", "presents")
49 |
50 | COMMON_LLM_START_PHRASES = (
51 | "In the video,",
52 | "In this video,",
53 | "In this video clip,",
54 | "In the clip,",
55 | "Caption:",
56 | *(
57 | f"{beginning} {continuation}"
58 | for beginning in _COMMON_BEGINNING_PHRASES
59 | for continuation in _COMMON_CONTINUATION_WORDS
60 | ),
61 | )
62 |
63 |
64 | def load_condition_models(
65 | model_id: str = "hunyuanvideo-community/HunyuanVideo",
66 | text_encoder_dtype: torch.dtype = torch.float16,
67 | text_encoder_2_dtype: torch.dtype = torch.float16,
68 | revision: Optional[str] = None,
69 | cache_dir: Optional[str] = None,
70 | device: Optional[torch.device] = None,
71 | **kwargs,
72 | ) -> Dict[str, nn.Module]:
73 | tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
74 | text_encoder = LlamaModel.from_pretrained(
75 | model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
76 | )
77 | tokenizer_2 = CLIPTokenizer.from_pretrained(
78 | model_id, subfolder="tokenizer_2", revision=revision, cache_dir=cache_dir
79 | )
80 | text_encoder_2 = CLIPTextModel.from_pretrained(
81 | model_id, subfolder="text_encoder_2", torch_dtype=text_encoder_2_dtype, revision=revision, cache_dir=cache_dir
82 | )
83 | if device is not None:
84 | text_encoder.to(device)
85 | text_encoder_2.to(device)
86 | return {
87 | "tokenizer": tokenizer,
88 | "text_encoder": text_encoder,
89 | "tokenizer_2": tokenizer_2,
90 | "text_encoder_2": text_encoder_2,
91 | }
92 |
93 |
94 | def load_latent_models(
95 | model_id: str = "hunyuanvideo-community/HunyuanVideo",
96 | vae_dtype: torch.dtype = torch.float16,
97 | revision: Optional[str] = None,
98 | cache_dir: Optional[str] = None,
99 | device: Optional[torch.device] = None,
100 | **kwargs,
101 | ) -> Dict[str, nn.Module]:
102 | vae = AutoencoderKLHunyuanVideo.from_pretrained(
103 | model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir
104 | )
105 | if device is not None:
106 | vae.to(device)
107 | vae.enable_slicing()
108 | vae.enable_tiling()
109 | return {
110 | "vae": vae,
111 | }
112 |
113 |
114 | def load_diffusion_models(
115 | model_id: str = "hunyuanvideo-community/HunyuanVideo",
116 | transformer_dtype: torch.dtype = torch.bfloat16,
117 | shift: float = 1.0,
118 | revision: Optional[str] = None,
119 | cache_dir: Optional[str] = None,
120 | **kwargs,
121 | ) -> Dict[str, Union[nn.Module, FlowMatchEulerDiscreteScheduler]]:
122 | transformer = HunyuanVideoTransformer3DModel.from_pretrained(
123 | model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
124 | )
125 | scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)
126 | return {"transformer": transformer, "scheduler": scheduler}
127 |
128 |
129 | def prepare_conditions(
130 | tokenizer: LlamaTokenizer,
131 | text_encoder: LlamaModel,
132 | tokenizer_2: CLIPTokenizer,
133 | text_encoder_2: CLIPTextModel,
134 | prompt: Union[str, List[str]],
135 | device: Optional[torch.device] = None,
136 | dtype: Optional[torch.dtype] = None,
137 | max_sequence_length: int = 256,
138 | prompt_template: Dict[str, Any] = {
139 | "template": (
140 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
141 | "1. The main content and theme of the video."
142 | "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
143 | "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
144 | "4. background environment, light, style and atmosphere."
145 | "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
146 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
147 | ),
148 | "crop_start": 95,
149 | },
150 | **kwargs,
151 | ) -> Dict[str, torch.Tensor]:
152 | if isinstance(prompt, str):
153 | prompt = [prompt]
154 |
155 | conditions = {}
156 | conditions.update(
157 | _get_llama_prompt_embeds(tokenizer, text_encoder, prompt, prompt_template, device, dtype, max_sequence_length)
158 | )
159 | conditions.update(_get_clip_prompt_embeds(tokenizer_2, text_encoder_2, prompt, device, dtype))
160 |
161 | return conditions
162 |
163 |
164 | def prepare_latents(
165 | vae: AutoencoderKLHunyuanVideo,
166 | image_or_video: torch.Tensor,
167 | device: Optional[torch.device] = None,
168 | dtype: Optional[torch.dtype] = None,
169 | generator: Optional[torch.Generator] = None,
170 | **kwargs,
171 | ) -> Dict[str, torch.Tensor]:
172 | if image_or_video.ndim == 4:
173 | image_or_video = image_or_video.unsqueeze(2)
174 | assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"
175 |
176 | image_or_video = image_or_video.to(device=device, dtype=dtype)
177 |     image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]
178 |
179 | latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
180 | latents = latents * vae.config.scaling_factor
181 | latents = latents.to(dtype=dtype)
182 | return {"latents": latents}
183 |
184 |
185 | def _get_llama_prompt_embeds(
186 | tokenizer: LlamaTokenizer,
187 | text_encoder: LlamaModel,
188 | prompt: List[str],
189 | prompt_template: Dict[str, Any],
190 | device: torch.device,
191 | dtype: torch.dtype,
192 | max_sequence_length: int = 256,
193 | num_hidden_layers_to_skip: int = 2,
194 | ) -> Dict[str, torch.Tensor]:
195 | batch_size = len(prompt)
196 | prompt = [prompt_template["template"].format(p) for p in prompt]
197 |
198 | crop_start = prompt_template.get("crop_start", None)
199 | if crop_start is None:
200 | prompt_template_input = tokenizer(
201 | prompt_template["template"],
202 | padding="max_length",
203 | return_tensors="pt",
204 | return_length=False,
205 | return_overflowing_tokens=False,
206 | return_attention_mask=False,
207 | )
208 | crop_start = prompt_template_input["input_ids"].shape[-1]
209 | # Remove <|eot_id|> token and placeholder {}
210 | crop_start -= 2
211 |
212 | max_sequence_length += crop_start
213 | text_inputs = tokenizer(
214 | prompt,
215 | max_length=max_sequence_length,
216 | padding="max_length",
217 | truncation=True,
218 | return_tensors="pt",
219 | return_length=False,
220 | return_overflowing_tokens=False,
221 | return_attention_mask=True,
222 | )
223 | text_input_ids = text_inputs.input_ids.to(device=device)
224 | prompt_attention_mask = text_inputs.attention_mask.to(device=device)
225 |
226 | prompt_embeds = text_encoder(
227 | input_ids=text_input_ids,
228 | attention_mask=prompt_attention_mask,
229 | output_hidden_states=True,
230 | ).hidden_states[-(num_hidden_layers_to_skip + 1)]
231 | prompt_embeds = prompt_embeds.to(dtype=dtype)
232 |
233 | if crop_start is not None and crop_start > 0:
234 | prompt_embeds = prompt_embeds[:, crop_start:]
235 | prompt_attention_mask = prompt_attention_mask[:, crop_start:]
236 |
237 | prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
238 |
239 | return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask}
240 |
241 |
242 | def _get_clip_prompt_embeds(
243 | tokenizer_2: CLIPTokenizer,
244 | text_encoder_2: CLIPTextModel,
245 | prompt: Union[str, List[str]],
246 | device: torch.device,
247 | dtype: torch.dtype,
248 | max_sequence_length: int = 77,
249 | ) -> Dict[str, torch.Tensor]:
250 | text_inputs = tokenizer_2(
251 | prompt,
252 | padding="max_length",
253 | max_length=max_sequence_length,
254 | truncation=True,
255 | return_tensors="pt",
256 | )
257 |
258 | prompt_embeds = text_encoder_2(text_inputs.input_ids.to(device), output_hidden_states=False).pooler_output
259 | prompt_embeds = prompt_embeds.to(dtype=dtype)
260 |
261 | return {"pooled_prompt_embeds": prompt_embeds}
262 |
263 |
264 | def main(args):
265 | accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(hours=36))])
266 |
267 | video_transforms = transforms.Compose(
268 | [
269 | transforms.Lambda(lambda x: x / 255.0),
270 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
271 | ]
272 | )
273 |
274 | mds_streams = []
275 | if args.recursive:
276 | for fp in args.mds_data_path:
277 |             for dir_name in os.listdir(fp):
278 |                 dir_path = os.path.join(fp, dir_name)
279 | mds_streams.append((Stream(local=dir_path), dir_path))
280 | else:
281 | for fp in args.mds_data_path:
282 | mds_streams.append((Stream(local=fp), fp))
283 |
284 | accelerator.print(f"## mds_streams: {len(mds_streams)}")
285 |
286 | dtype = torch.float32
287 | if accelerator.mixed_precision == "fp16":
288 | dtype = torch.float16
289 | elif accelerator.mixed_precision == "bf16":
290 | dtype = torch.bfloat16
291 |
292 | logger.info(f"set dtype to {dtype!r}")
293 |
294 | models = load_condition_models(
295 | args.pretrained_model_name_or_path,
296 | text_encoder_dtype=dtype,
297 | text_encoder_2_dtype=torch.bfloat16,
298 | device=accelerator.device,
299 | )
300 | models.update(load_latent_models(args.pretrained_model_name_or_path, vae_dtype=dtype, device=accelerator.device))
301 |
302 | models["vae"].eval().requires_grad_(False)
303 | models["text_encoder"].eval().requires_grad_(False)
304 | models["text_encoder_2"].eval().requires_grad_(False)
305 |
306 | accelerator.wait_for_everyone()
307 |
308 | for varname in ["RANK", "LOCAL_WORLD_SIZE", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"]:
309 | assert os.environ.get(varname) is not None, f"{varname} is not set"
310 | logger.info(f"{varname!r}: {os.environ.get(varname)}")
311 |
312 | for stream, data_path in mds_streams:
313 | logger.info(f"## Processing {data_path!r}")
314 |
315 | dataset = StreamingDataset(
316 | streams=[stream],
317 | batch_size=1,
318 | num_canonical_nodes=(int(os.environ["WORLD_SIZE"]) // 8),
319 | )
320 |
321 | save_path = Path(args.output_dir) / f"{Path(data_path).name}_{accelerator.local_process_index:02d}"
322 | save_path.mkdir(parents=True, exist_ok=True)
323 | save_path = save_path.as_posix()
324 |
325 | logger.info(f"Saving to => {save_path!r}")
326 |
327 | columns = {
328 | # "idx": "int32",
329 | # "item_key": "str",
330 | # "item": "str",
331 | # "frame_count": "int32",
332 | # "bucket_width": "int32",
333 | # "bucket_height": "int32",
334 | # "original_width": "int32",
335 | # "original_height": "int32",
336 | # "caption_str": "str",
337 | # "video_arr": "ndarray",
338 | "prompt_embeds": "ndarray",
339 | "prompt_attention_mask": "ndarray",
340 | "pooled_prompt_embeds": "ndarray",
341 | "latents": "ndarray",
342 | "latents_cond": "ndarray",
343 | # "latents_cond_2": "ndarray",
344 | # "latents_cond_only_first": "ndarray",
345 | # "latents_cond_only_last": "ndarray",
346 | }
347 |
348 | os.umask(0o000)
349 | writer = MDSWriter(
350 | out=save_path,
351 | columns=columns,
352 | compression=args.mds_shard_compression,
353 | size_limit=256 * (2**20),
354 | max_workers=64,
355 | )
356 |
357 | def collate_fn(batch):
358 | idx = [x["idx"] for x in batch]
359 | item_key = [x["item_key"] for x in batch]
360 | item = [x["item"] for x in batch]
361 | video = [x["video_arr"] for x in batch]
362 | caption = [x["caption_str"] for x in batch]
363 |
364 | for i in range(len(caption)):
365 | caption[i] = caption[i].strip()
366 | for phrase in COMMON_LLM_START_PHRASES:
367 | if caption[i].startswith(phrase):
368 | caption[i] = caption[i].removeprefix(phrase).strip()
369 |
370 | return {"video_arr": video, "caption_str": caption, "idx": idx, "item_key": item_key, "item": item}
371 |
372 | dl = StreamingDataLoader(
373 | dataset,
374 | batch_size=1,
375 | collate_fn=collate_fn,
376 | shuffle=False,
377 | num_workers=8,
378 | prefetch_factor=2,
379 | pin_memory=True,
380 | )
381 |
382 | for batch in tqdm(dl, dynamic_ncols=True, desc="Precomputing latents", disable=not accelerator.is_main_process):
383 | # print(accelerator.process_index, batch["idx"], batch["item_key"], batch["item"])
384 |
385 | assert len(batch["video_arr"]) == 1
386 |
387 | video, caption = batch["video_arr"][0], batch["caption_str"][0]
388 | video = torch.from_numpy(video.copy()).permute(0, 3, 1, 2) # n_frames, c, h, w
389 |
390 | control_condition = torch.zeros_like(video) # create an empty video
391 | control_condition[0] = video[0] # keep the first frame
392 | control_condition[-1] = video[-1] # keep the last frame
393 |
394 | # control_condition_2 = torch.zeros_like(video) # create an empty video
395 | # control_condition_2[0] = video[0] # keep the first frame
396 | # control_condition_2[-1] = video[-1] # keep the last frame
397 | # control_condition_2[video.shape[0] // 2] = video[video.shape[0] // 2] # keep the middle frame
398 |
399 | # control_condition_only_first = torch.zeros_like(video) # create an empty video
400 | # control_condition_only_first[0] = video[0] # keep the first frame
401 |
402 | # control_condition_only_last = torch.zeros_like(video) # create an empty video
403 | # control_condition_only_last[-1] = video[-1] # keep the last frame
404 |
405 | video = torch.stack([video_transforms(x) for x in video], dim=0).unsqueeze(0)
406 | control_condition = torch.stack([video_transforms(x) for x in control_condition], dim=0).unsqueeze(0)
407 | # control_condition_2 = torch.stack([video_transforms(x) for x in control_condition_2], dim=0).unsqueeze(0)
408 | # control_condition_only_first = torch.stack([video_transforms(x) for x in control_condition_only_first], dim=0).unsqueeze(0)
409 | # control_condition_only_last = torch.stack([video_transforms(x) for x in control_condition_only_last], dim=0).unsqueeze(0)
410 |
411 | with torch.inference_mode(): # @TODO: add batch support ?
412 | latents = prepare_latents(models["vae"], video, device=accelerator.device, dtype=dtype)["latents"]
413 | latents_cond = prepare_latents(
414 | models["vae"], control_condition, device=accelerator.device, dtype=dtype
415 | )["latents"]
416 | # latents_cond_2 = prepare_latents(models["vae"], control_condition_2, device=accelerator.device, dtype=dtype)["latents"]
417 | # latents_cond_only_first = prepare_latents(models["vae"], control_condition_only_first, device=accelerator.device, dtype=dtype)["latents"]
418 | # latents_cond_only_last = prepare_latents(models["vae"], control_condition_only_last, device=accelerator.device, dtype=dtype)["latents"]
419 |
420 | conditions = prepare_conditions(
421 | tokenizer=models["tokenizer"],
422 | text_encoder=models["text_encoder"],
423 | tokenizer_2=models["tokenizer_2"],
424 | text_encoder_2=models["text_encoder_2"],
425 | prompt=caption,
426 | device=accelerator.device,
427 | dtype=dtype,
428 | )
429 | prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = (
430 | conditions["prompt_embeds"],
431 | conditions["prompt_attention_mask"],
432 | conditions["pooled_prompt_embeds"],
433 | )
434 |
435 | # out_batch = {**batch} # only takes extra space
436 | out_batch = {}
437 | out_batch["latents"] = latents[0].float().cpu().numpy()
438 | out_batch["prompt_embeds"] = prompt_embeds[0].float().cpu().numpy()
439 | out_batch["prompt_attention_mask"] = prompt_attention_mask[0].float().cpu().numpy()
440 | out_batch["pooled_prompt_embeds"] = pooled_prompt_embeds[0].float().cpu().numpy()
441 | out_batch["latents_cond"] = latents_cond[0].float().cpu().numpy()
442 | # out_batch["latents_cond_2"] = latents_cond_2[0].float().cpu().numpy()
443 | # out_batch["latents_cond_only_first"] = latents_cond_only_first[0].float().cpu().numpy()
444 | # out_batch["latents_cond_only_last"] = latents_cond_only_last[0].float().cpu().numpy()
445 |
446 | assert (
447 | out_batch.keys() == columns.keys()
448 | ), f"{out_batch.keys()} != {columns.keys()}, missing {set(out_batch.keys()) - set(columns.keys())}"
449 |
450 | os.umask(0o000)
451 | writer.write(out_batch)
452 |
453 | writer.finish()
454 | accelerator.wait_for_everyone()
455 |
456 |
457 | if __name__ == "__main__":
458 | parser = ArgumentParser()
459 | parser.add_argument("--mds_data_path", required=True, type=str, nargs="+")
460 | parser.add_argument("--pretrained_model_name_or_path", type=str, default="hunyuanvideo-community/HunyuanVideo")
461 | parser.add_argument("--output_dir", required=True, type=str)
462 | parser.add_argument("--mds_shard_compression", type=str, default=None)
463 | parser.add_argument("--recursive", action="store_true", default=False)
464 |
465 | args = parser.parse_args()
466 | main(args)
467 |
--------------------------------------------------------------------------------
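Note on the caption cleanup in collate_fn above: captions are stripped of common LLM boilerplate prefixes before text encoding. Below is a minimal standalone sketch of that behavior, assuming COMMON_LLM_START_PHRASES from this file; the sample caption is invented.

    caption = "This video shows a red car driving through a tunnel at night."
    caption = caption.strip()
    for phrase in COMMON_LLM_START_PHRASES:
        if caption.startswith(phrase):
            caption = caption.removeprefix(phrase).strip()
    print(caption)  # -> "a red car driving through a tunnel at night."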