├── LICENSE
├── Notice.txt
├── README.md
├── README_zh.md
├── assets
│   ├── audio
│   │   ├── 2.WAV
│   │   ├── 3.WAV
│   │   └── 4.WAV
│   ├── image
│   │   ├── 1.png
│   │   ├── 2.png
│   │   ├── 3.png
│   │   ├── 4.png
│   │   ├── src1.png
│   │   ├── src2.png
│   │   ├── src3.png
│   │   └── src4.png
│   ├── material
│   │   ├── demo.png
│   │   ├── logo.png
│   │   ├── method.png
│   │   └── teaser.png
│   └── test.csv
├── hymm_gradio
│   ├── flask_audio.py
│   ├── gradio_audio.py
│   └── tool_for_end2end.py
├── hymm_sp
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-310.pyc
│   │   ├── config.cpython-310.pyc
│   │   ├── constants.cpython-310.pyc
│   │   ├── helpers.cpython-310.pyc
│   │   ├── inference.cpython-310.pyc
│   │   └── sample_inference_audio.cpython-310.pyc
│   ├── config.py
│   ├── constants.py
│   ├── data_kits
│   │   ├── __pycache__
│   │   │   ├── audio_dataset.cpython-310.pyc
│   │   │   ├── audio_preprocessor.cpython-310.pyc
│   │   │   ├── data_tools.cpython-310.pyc
│   │   │   └── ffmpeg_utils.cpython-310.pyc
│   │   ├── audio_dataset.py
│   │   ├── audio_preprocessor.py
│   │   ├── data_tools.py
│   │   └── face_align
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-310.pyc
│   │       │   ├── align.cpython-310.pyc
│   │       │   └── detface.cpython-310.pyc
│   │       ├── align.py
│   │       └── detface.py
│   ├── diffusion
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-310.pyc
│   │   ├── pipelines
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   └── pipeline_hunyuan_video_audio.cpython-310.pyc
│   │   │   └── pipeline_hunyuan_video_audio.py
│   │   └── schedulers
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-310.pyc
│   │       │   └── scheduling_flow_match_discrete.cpython-310.pyc
│   │       └── scheduling_flow_match_discrete.py
│   ├── helpers.py
│   ├── inference.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── activation_layers.cpython-310.pyc
│   │   │   ├── attn_layers.cpython-310.pyc
│   │   │   ├── audio_adapters.cpython-310.pyc
│   │   │   ├── embed_layers.cpython-310.pyc
│   │   │   ├── fp8_optimization.cpython-310.pyc
│   │   │   ├── mlp_layers.cpython-310.pyc
│   │   │   ├── models_audio.cpython-310.pyc
│   │   │   ├── modulate_layers.cpython-310.pyc
│   │   │   ├── norm_layers.cpython-310.pyc
│   │   │   ├── parallel_states.cpython-310.pyc
│   │   │   ├── posemb_layers.cpython-310.pyc
│   │   │   └── token_refiner.cpython-310.pyc
│   │   ├── activation_layers.py
│   │   ├── attn_layers.py
│   │   ├── audio_adapters.py
│   │   ├── embed_layers.py
│   │   ├── fp8_optimization.py
│   │   ├── mlp_layers.py
│   │   ├── models_audio.py
│   │   ├── modulate_layers.py
│   │   ├── norm_layers.py
│   │   ├── parallel_states.py
│   │   ├── posemb_layers.py
│   │   └── token_refiner.py
│   ├── sample_batch.py
│   ├── sample_gpu_poor.py
│   ├── sample_inference_audio.py
│   ├── text_encoder
│   │   ├── __init__.py
│   │   └── __pycache__
│   │       └── __init__.cpython-310.pyc
│   └── vae
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-310.pyc
│       │   ├── autoencoder_kl_causal_3d.cpython-310.pyc
│       │   ├── unet_causal_3d_blocks.cpython-310.pyc
│       │   └── vae.cpython-310.pyc
│       ├── autoencoder_kl_causal_3d.py
│       ├── unet_causal_3d_blocks.py
│       └── vae.py
├── requirements.txt
├── scripts
│   ├── run_gradio.sh
│   ├── run_sample_batch_sp.sh
│   ├── run_single_audio.sh
│   └── run_single_poor.sh
└── weights
    └── README.md
/LICENSE:
--------------------------------------------------------------------------------
1 | TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
2 | TencentHunyuanVideo-Avatar Release Date: May 28, 2025
3 | THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
4 | By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
5 | 1. DEFINITIONS.
6 | a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
7 | b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
8 | c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
9 | d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
10 | e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
11 | f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
12 | g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
13 | h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
14 | i. “Tencent,” “We” or “Us” shall mean THL A29 Limited.
15 | j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, TencentHunyuanVideo-Avatar released at [https://github.com/Tencent-Hunyuan/HunyuanVideo-Avatar].
16 | k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
17 | l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
18 | m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
19 | n. “including” shall mean including but not limited to.
20 | 2. GRANT OF RIGHTS.
21 | We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
22 | 3. DISTRIBUTION.
23 | You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
24 | a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
25 | b. You must cause any modified files to carry prominent notices stating that You changed the files;
26 | c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
27 | d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
28 | You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
29 | 4. ADDITIONAL COMMERCIAL TERMS.
30 | If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
31 | 5. RULES OF USE.
32 | a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
33 | b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
34 | c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
35 | 6. INTELLECTUAL PROPERTY.
36 | a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
37 | b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
38 | c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
39 | d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
40 | 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
41 | a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
42 | b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
43 | c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
44 | 8. SURVIVAL AND TERMINATION.
45 | a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
46 | b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
47 | 9. GOVERNING LAW AND JURISDICTION.
48 | a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
49 | b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
50 |
51 | EXHIBIT A
52 | ACCEPTABLE USE POLICY
53 |
54 | Tencent reserves the right to update this Acceptable Use Policy from time to time.
55 | Last modified: November 5, 2024
56 |
57 | Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
58 | 1. Outside the Territory;
59 | 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
60 | 3. To harm Yourself or others;
61 | 4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
62 | 5. To override or circumvent the safety guardrails and safeguards We have put in place;
63 | 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
64 | 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
65 | 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
66 | 9. To intentionally defame, disparage or otherwise harass others;
67 | 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
68 | 11. To generate or disseminate personal identifiable information with the purpose of harming others;
69 | 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
70 | 13. To impersonate another individual without consent, authorization, or legal right;
71 | 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
72 | 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
73 | 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
74 | 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
75 | 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
76 | 19. For military purposes;
77 | 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
78 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
7 | # **HunyuanVideo-Avatar** 🌅
8 |
9 | [Project logo, badges, and header images omitted]
24 |
25 | > [**HunyuanVideo-Avatar: High-Fidelity Audio-Driven Human Animation for Multiple Characters**](https://arxiv.org/pdf/2505.20156)
26 |
27 | ## 🔥🔥🔥 News!!
28 | * Jun 06, 2025: 🔥 HunyuanVideo-Avatar supports **Single GPU** with only **10GB VRAM**, with **TeaCache** included, **HUGE THANKS** to [Wan2GP](https://github.com/deepbeepmeep/Wan2GP)
29 | * May 28, 2025: 🔥 HunyuanVideo-Avatar is available in Cloud-Native-Build (CNB) [HunyuanVideo-Avatar](https://cnb.cool/tencent/hunyuan/HunyuanVideo-Avatar).
30 | * May 28, 2025: 👋 We release the inference code and model weights of HunyuanVideo-Avatar. [Download](weights/README.md).
31 |
32 |
33 | ## 📑 Open-source Plan
34 |
35 | - HunyuanVideo-Avatar
36 | - [x] Inference
37 | - [x] Checkpoints
38 | - [ ] ComfyUI
39 |
40 | ## Contents
41 | - [**HunyuanVideo-Avatar** 🌅](#hunyuanvideo-avatar-)
42 | - [🔥🔥🔥 News!!](#-news)
43 | - [📑 Open-source Plan](#-open-source-plan)
44 | - [Contents](#contents)
45 | - [**Abstract**](#abstract)
46 | - [**HunyuanVideo-Avatar Overall Architecture**](#hunyuanvideo-avatar-overall-architecture)
47 | - [🎉 **HunyuanVideo-Avatar Key Features**](#-hunyuanvideo-avatar-key-features)
48 | - [**High-Dynamic and Emotion-Controllable Video Generation**](#high-dynamic-and-emotion-controllable-video-generation)
49 | - [**Various Applications**](#various-applications)
50 | - [📈 Comparisons](#-comparisons)
51 | - [📜 Requirements](#-requirements)
52 | - [🛠️ Dependencies and Installation](#️-dependencies-and-installation)
53 | - [Installation Guide for Linux](#installation-guide-for-linux)
54 | - [🧱 Download Pretrained Models](#-download-pretrained-models)
55 | - [🚀 Parallel Inference on Multiple GPUs](#-parallel-inference-on-multiple-gpus)
56 | - [🔑 Single-gpu Inference](#-single-gpu-inference)
57 | - [Run with very low VRAM](#run-with-very-low-vram)
58 | - [Run a Gradio Server](#run-a-gradio-server)
59 | - [🔗 BibTeX](#-bibtex)
60 | - [Acknowledgements](#acknowledgements)
61 | ---
62 |
63 | ## **Abstract**
64 |
65 | Recent years have witnessed significant progress in audio-driven human animation. However, critical challenges remain in (i) generating highly dynamic videos while preserving character consistency, (ii) achieving precise emotion alignment between characters and audio, and (iii) enabling multi-character audio-driven animation. To address these challenges, we propose HunyuanVideo-Avatar, a multimodal diffusion transformer (MM-DiT)-based model capable of simultaneously generating dynamic, emotion-controllable, and multi-character dialogue videos. Concretely, HunyuanVideo-Avatar introduces three key innovations: (i) A character image injection module is designed to replace the conventional addition-based character conditioning scheme, eliminating the inherent condition mismatch between training and inference. This ensures dynamic motion and strong character consistency; (ii) An Audio Emotion Module (AEM) is introduced to extract and transfer the emotional cues from an emotion reference image to the target generated video, enabling fine-grained and accurate emotion style control; (iii) A Face-Aware Audio Adapter (FAA) is proposed to isolate the audio-driven character with a latent-level face mask, enabling independent audio injection via cross-attention for multi-character scenarios. These innovations empower HunyuanVideo-Avatar to surpass state-of-the-art methods on benchmark datasets and a newly proposed wild dataset, generating realistic avatars in dynamic, immersive scenarios. The source code and model weights will be released publicly.
66 |
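The Face-Aware Audio Adapter described above injects audio features only into the latent tokens covered by a character's face mask. The sketch below is purely illustrative and is not the repository's implementation (which lives in `hymm_sp/modules/audio_adapters.py`); the class name, tensor shapes, and use of plain `nn.MultiheadAttention` are assumptions made for clarity.

```python
import torch
import torch.nn as nn

class MaskedAudioCrossAttention(nn.Module):
    """Toy FAA-style block: audio is injected only where the face mask is 1."""
    def __init__(self, latent_dim: int, audio_dim: int, num_heads: int = 8):
        super().__init__()
        self.norm = nn.LayerNorm(latent_dim)
        self.attn = nn.MultiheadAttention(latent_dim, num_heads,
                                          kdim=audio_dim, vdim=audio_dim,
                                          batch_first=True)

    def forward(self, latents, audio_feats, face_mask):
        # latents:     (B, N, C) flattened video latent tokens
        # audio_feats: (B, M, A) audio embeddings (e.g. Whisper features)
        # face_mask:   (B, N, 1) 1.0 inside the speaking character's face region
        attn_out, _ = self.attn(self.norm(latents), audio_feats, audio_feats)
        # Masked residual: only face-region tokens receive the audio signal, so
        # each character in a multi-character scene can be driven independently.
        return latents + face_mask * attn_out
```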
67 | ## **HunyuanVideo-Avatar Overall Architecture**
68 |
69 | [Overall architecture figure omitted]
70 |
71 | We propose **HunyuanVideo-Avatar**, a multi-modal diffusion transformer (MM-DiT)-based model capable of generating **dynamic**, **emotion-controllable**, and **multi-character dialogue** videos.
72 |
73 | ## 🎉 **HunyuanVideo-Avatar Key Features**
74 |
75 | [Key features demo figure omitted]
76 |
77 | ### **High-Dynamic and Emotion-Controllable Video Generation**
78 |
79 | HunyuanVideo-Avatar can animate any input **avatar image** into a **high-dynamic**, **emotion-controllable** video driven by simple **audio conditions**. It accepts **multi-style** avatar images at **arbitrary scales and resolutions**, covering photorealistic, cartoon, 3D-rendered, and anthropomorphic characters, and supports multi-scale generation spanning portrait, upper-body, and full-body shots. The generated videos feature highly dynamic foregrounds and backgrounds with superior realism and naturalness. In addition, the system can control the characters' facial emotions conditioned on the input audio.
80 |
81 | ### **Various Applications**
82 |
83 | HunyuanVideo-Avatar supports various downstream tasks and applications. For instance, it generates talking-avatar videos that can be applied to e-commerce, online streaming, social media video production, and more. In addition, its multi-character animation capability broadens applications such as video content creation and editing.
84 |
85 | ## 📜 Requirements
86 |
87 | * An NVIDIA GPU with CUDA support is required.
88 | * The model is tested on a machine with 8 GPUs.
89 | * **Minimum**: At least 24GB of GPU memory is required for 704px768px129f generation, but inference will be very slow.
90 | * **Recommended**: We recommend using a GPU with 96GB of memory for better generation quality.
91 | * **Tips**: If an OOM error occurs on a GPU with 80GB of memory, try reducing the image resolution. A quick check of available GPU memory is shown below.
92 | * Tested operating system: Linux
93 |
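As a quick way to check how much GPU memory is available before launching inference (assuming PyTorch is already installed, for example via the steps below):

```python
import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"{props.name}: {props.total_memory / 1024**3:.1f} GB total VRAM")
else:
    print("No CUDA-capable GPU detected.")
```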
94 |
95 | ## 🛠️ Dependencies and Installation
96 |
97 | Begin by cloning the repository:
98 | ```shell
99 | git clone https://github.com/Tencent-Hunyuan/HunyuanVideo-Avatar.git
100 | cd HunyuanVideo-Avatar
101 | ```
102 |
103 | ### Installation Guide for Linux
104 |
105 | We recommend CUDA versions 12.4 or 11.8 for the manual installation.
106 |
107 | Conda's installation instructions are available [here](https://docs.anaconda.com/free/miniconda/index.html).
108 |
109 | ```shell
110 | # 1. Create conda environment
111 | conda create -n HunyuanVideo-Avatar python==3.10.9
112 |
113 | # 2. Activate the environment
114 | conda activate HunyuanVideo-Avatar
115 |
116 | # 3. Install PyTorch and other dependencies using conda
117 | # For CUDA 11.8
118 | conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=11.8 -c pytorch -c nvidia
119 | # For CUDA 12.4
120 | conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.4 -c pytorch -c nvidia
121 |
122 | # 4. Install pip dependencies
123 | python -m pip install -r requirements.txt
124 | # 5. Install flash attention v2 for acceleration (requires CUDA 11.8 or above)
125 | python -m pip install ninja
126 | python -m pip install git+https://github.com/Dao-AILab/flash-attention.git@v2.6.3
127 | ```
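
After installation, a quick sanity check run inside the `HunyuanVideo-Avatar` environment can confirm that PyTorch sees the GPU and that flash-attention built correctly; the expected versions below simply follow the steps above:

```python
import torch

print("torch:", torch.__version__)            # expected: 2.4.0
print("cuda available:", torch.cuda.is_available())

import flash_attn                             # from step 5; ImportError means the build failed
print("flash-attn:", flash_attn.__version__)  # expected: 2.6.3
```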
128 |
129 | If you run into a floating point exception (core dump) on specific GPU types, you may try the following solutions:
130 |
131 | ```shell
132 | # Option 1: Make sure you have installed CUDA 12.4, cuBLAS>=12.4.5.8, and cuDNN>=9.00 (or simply use our CUDA 12 Docker image).
133 | pip install nvidia-cublas-cu12==12.4.5.8
134 | export LD_LIBRARY_PATH=/opt/conda/lib/python3.8/site-packages/nvidia/cublas/lib/
135 |
136 | # Option 2: Force the explicit use of the CUDA 11.8-compiled version of PyTorch and all the other packages
137 | pip uninstall -r requirements.txt # uninstall all packages
138 | pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu118
139 | pip install -r requirements.txt
140 | pip install ninja
141 | pip install git+https://github.com/Dao-AILab/flash-attention.git@v2.6.3
142 | ```
143 |
144 | Alternatively, you can use the HunyuanVideo Docker image. Use the following commands to pull and run it.
145 |
146 | ```shell
147 | # For CUDA 12.4 (updated to avoid float point exception)
148 | docker pull hunyuanvideo/hunyuanvideo:cuda_12
149 | docker run -itd --gpus all --init --net=host --uts=host --ipc=host --name hunyuanvideo --security-opt=seccomp=unconfined --ulimit=stack=67108864 --ulimit=memlock=-1 --privileged hunyuanvideo/hunyuanvideo:cuda_12
150 | pip install gradio==3.39.0 diffusers==0.33.0 transformers==4.41.2
151 |
152 | # For CUDA 11.8
153 | docker pull hunyuanvideo/hunyuanvideo:cuda_11
154 | docker run -itd --gpus all --init --net=host --uts=host --ipc=host --name hunyuanvideo --security-opt=seccomp=unconfined --ulimit=stack=67108864 --ulimit=memlock=-1 --privileged hunyuanvideo/hunyuanvideo:cuda_11
155 | pip install gradio==3.39.0 diffusers==0.33.0 transformers==4.41.2
156 | ```
157 |
158 |
159 | ## 🧱 Download Pretrained Models
160 |
161 | Details on downloading the pretrained models are provided [here](weights/README.md).
162 |
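The `weights/README.md` file is the authoritative reference. As a rough sketch, the weights can typically be fetched with the Hugging Face CLI; the repository id `tencent/HunyuanVideo-Avatar` and the `./weights` target below are assumptions, so check `weights/README.md` for the exact commands:

```shell
python -m pip install "huggingface_hub[cli]"
huggingface-cli download tencent/HunyuanVideo-Avatar --local-dir ./weights
```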
163 | ## 🚀 Parallel Inference on Multiple GPUs
164 |
165 | For example, to generate a video with 8 GPUs, you can use the following command:
166 |
167 | ```bash
168 | cd HunyuanVideo-Avatar
169 |
170 | JOBS_DIR=$(dirname $(dirname "$0"))
171 | export PYTHONPATH=./
172 | export MODEL_BASE="./weights"
173 | checkpoint_path=${MODEL_BASE}/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt
174 |
175 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_sp/sample_batch.py \
176 | --input 'assets/test.csv' \
177 | --ckpt ${checkpoint_path} \
178 | --sample-n-frames 129 \
179 | --seed 128 \
180 | --image-size 704 \
181 | --cfg-scale 7.5 \
182 | --infer-steps 50 \
183 | --use-deepcache 1 \
184 | --flow-shift-eval-video 5.0 \
185 | --save-path ${OUTPUT_BASEPATH}
186 | ```
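
Note that `${OUTPUT_BASEPATH}` is not defined in the snippet above (the single-GPU examples below set it explicitly); define it as any writable directory before launching, for example:

```bash
OUTPUT_BASEPATH=./results   # placeholder; choose your own output directory
```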
187 |
188 | ## 🔑 Single-gpu Inference
189 |
190 | For example, to generate a video with 1 GPU, you can use the following command:
191 |
192 | ```bash
193 | cd HunyuanVideo-Avatar
194 |
195 | JOBS_DIR=$(dirname $(dirname "$0"))
196 | export PYTHONPATH=./
197 |
198 | export MODEL_BASE=./weights
199 | OUTPUT_BASEPATH=./results-single
200 | checkpoint_path=${MODEL_BASE}/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt
201 |
202 | export DISABLE_SP=1
203 | CUDA_VISIBLE_DEVICES=0 python3 hymm_sp/sample_gpu_poor.py \
204 | --input 'assets/test.csv' \
205 | --ckpt ${checkpoint_path} \
206 | --sample-n-frames 129 \
207 | --seed 128 \
208 | --image-size 704 \
209 | --cfg-scale 7.5 \
210 | --infer-steps 50 \
211 | --use-deepcache 1 \
212 | --flow-shift-eval-video 5.0 \
213 | --save-path ${OUTPUT_BASEPATH} \
214 | --use-fp8 \
215 | --infer-min
216 | ```
217 |
218 | ### Run with very low VRAM
219 |
220 | ```bash
221 | cd HunyuanVideo-Avatar
222 |
223 | JOBS_DIR=$(dirname $(dirname "$0"))
224 | export PYTHONPATH=./
225 |
226 | export MODEL_BASE=./weights
227 | OUTPUT_BASEPATH=./results-poor
228 |
229 | checkpoint_path=${MODEL_BASE}/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt
230 |
231 | export CPU_OFFLOAD=1
232 | CUDA_VISIBLE_DEVICES=0 python3 hymm_sp/sample_gpu_poor.py \
233 | --input 'assets/test.csv' \
234 | --ckpt ${checkpoint_path} \
235 | --sample-n-frames 129 \
236 | --seed 128 \
237 | --image-size 704 \
238 | --cfg-scale 7.5 \
239 | --infer-steps 50 \
240 | --use-deepcache 1 \
241 | --flow-shift-eval-video 5.0 \
242 | --save-path ${OUTPUT_BASEPATH} \
243 | --use-fp8 \
244 | --cpu-offload \
245 | --infer-min
246 | ```
247 |
248 | ### Run with 10GB VRAM GPU (TeaCache supported)
249 |
250 | Thanks to [Wan2GP](https://github.com/deepbeepmeep/Wan2GP), HunyuanVideo-Avatar now supports single GPU mode with even lower VRAM (10GB) without quality degradation. Check out this [great repo](https://github.com/deepbeepmeep/Wan2GP/tree/main/hyvideo).
251 |
252 |
253 | ## Run a Gradio Server
254 | ```bash
255 | cd HunyuanVideo-Avatar
256 |
257 | bash ./scripts/run_gradio.sh
258 |
259 | ```
260 |
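The demo consists of two processes: the Gradio UI (`hymm_gradio/gradio_audio.py`, port 8080) and the FastAPI backend (`hymm_gradio/flask_audio.py`, port 80). Once both are running, a minimal programmatic client mirroring what the UI sends might look like the sketch below; the field names and the `/predict2` endpoint are taken from those files, while the sample image, audio, prompt, and output path are placeholders:

```python
import json
import requests

from hymm_gradio.tool_for_end2end import (
    encode_image_to_base64,
    encode_wav_to_base64,
    save_video_base64_to_local,
)

# Build the same payload the Gradio UI sends to the backend.
payload = {
    "image_buffer": encode_image_to_base64("assets/image/1.png"),
    "audio_buffer": encode_wav_to_base64("assets/audio/2.WAV"),
    "text": "A person sits cross-legged by a campfire in a forested area.",
    "save_fps": 25,
}
resp = requests.get("http://127.0.0.1:80/predict2", data=json.dumps(payload))
ret = json.loads(resp.text)
print(ret["info"])

# Decode the returned base64 video buffer and write it to disk.
save_video_base64_to_local(
    base64_buffer=ret["content"][0]["buffer"],
    output_video_path="./generated.mp4",
)
```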
261 | ## 🔗 BibTeX
262 |
263 | If you find [HunyuanVideo-Avatar](https://arxiv.org/pdf/2505.20156) useful for your research and applications, please cite using this BibTeX:
264 |
265 | ```BibTeX
266 | @misc{hu2025HunyuanVideo-Avatar,
267 | title={HunyuanVideo-Avatar: High-Fidelity Audio-Driven Human Animation for Multiple Characters},
268 | author={Yi Chen and Sen Liang and Zixiang Zhou and Ziyao Huang and Yifeng Ma and Junshu Tang and Qin Lin and Yuan Zhou and Qinglin Lu},
269 | year={2025},
270 | eprint={2505.20156},
271 | archivePrefix={arXiv},
272 | primaryClass={cs.CV},
273 | url={https://arxiv.org/pdf/2505.20156},
274 | }
275 | ```
276 |
277 | ## Acknowledgements
278 |
279 | We would like to thank the contributors to the [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [FLUX](https://github.com/black-forest-labs/flux), [Llama](https://github.com/meta-llama/llama), [LLaVA](https://github.com/haotian-liu/LLaVA), [Xtuner](https://github.com/InternLM/xtuner), [diffusers](https://github.com/huggingface/diffusers), and [HuggingFace](https://huggingface.co) repositories for their open research and exploration.
280 |
--------------------------------------------------------------------------------
/README_zh.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/README_zh.md
--------------------------------------------------------------------------------
/assets/audio/2.WAV:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/audio/2.WAV
--------------------------------------------------------------------------------
/assets/audio/3.WAV:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/audio/3.WAV
--------------------------------------------------------------------------------
/assets/audio/4.WAV:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/audio/4.WAV
--------------------------------------------------------------------------------
/assets/image/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/image/1.png
--------------------------------------------------------------------------------
/assets/image/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/image/2.png
--------------------------------------------------------------------------------
/assets/image/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/image/3.png
--------------------------------------------------------------------------------
/assets/image/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/image/4.png
--------------------------------------------------------------------------------
/assets/image/src1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/image/src1.png
--------------------------------------------------------------------------------
/assets/image/src2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/image/src2.png
--------------------------------------------------------------------------------
/assets/image/src3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/image/src3.png
--------------------------------------------------------------------------------
/assets/image/src4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/image/src4.png
--------------------------------------------------------------------------------
/assets/material/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/material/demo.png
--------------------------------------------------------------------------------
/assets/material/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/material/logo.png
--------------------------------------------------------------------------------
/assets/material/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/material/method.png
--------------------------------------------------------------------------------
/assets/material/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/material/teaser.png
--------------------------------------------------------------------------------
/assets/test.csv:
--------------------------------------------------------------------------------
1 | videoid,image,audio,prompt,fps
2 | 8,assets/image/1.png,assets/audio/2.WAV,A person sits cross-legged by a campfire in a forested area.,25
3 | 9,assets/image/2.png,assets/audio/2.WAV,"A person with long blonde hair wearing a green jacket, standing in a forested area during twilight.",25
4 | 10,assets/image/3.png,assets/audio/2.WAV,A person playing guitar by a campfire in a forest.,25
5 | 11,assets/image/4.png,assets/audio/2.WAV,"A person wearing a green jacket stands in a forested area, with sunlight filtering through the trees.",25
6 | 12,assets/image/src1.png,assets/audio/2.WAV,A person sits cross-legged by a campfire in a forest at dusk.,25
7 | 13,assets/image/src2.png,assets/audio/2.WAV,A person in a green jacket stands in a forest at dusk.,25
8 | 14,assets/image/src3.png,assets/audio/2.WAV,A person playing guitar by a campfire in a forest.,25
9 | 15,assets/image/src4.png,assets/audio/2.WAV,"A person in a green jacket stands in a forest, backlit by sunlight.",25
10 | 16,assets/image/1.png,assets/audio/3.WAV,A person sits cross-legged by a campfire in a forested area.,25
11 | 17,assets/image/2.png,assets/audio/3.WAV,"A person with long blonde hair wearing a green jacket, standing in a forested area during twilight.",25
12 | 18,assets/image/3.png,assets/audio/3.WAV,A person playing guitar by a campfire in a forest.,25
13 | 19,assets/image/4.png,assets/audio/3.WAV,"A person wearing a green jacket stands in a forested area, with sunlight filtering through the trees.",25
14 | 20,assets/image/src1.png,assets/audio/3.WAV,A person sits cross-legged by a campfire in a forest at dusk.,25
15 | 21,assets/image/src2.png,assets/audio/3.WAV,A person in a green jacket stands in a forest at dusk.,25
16 | 22,assets/image/src3.png,assets/audio/3.WAV,A person playing guitar by a campfire in a forest.,25
17 | 23,assets/image/src4.png,assets/audio/3.WAV,"A person in a green jacket stands in a forest, backlit by sunlight.",25
18 | 24,assets/image/1.png,assets/audio/4.WAV,A person sits cross-legged by a campfire in a forested area.,25
19 | 25,assets/image/2.png,assets/audio/4.WAV,"A person with long blonde hair wearing a green jacket, standing in a forested area during twilight.",25
20 | 26,assets/image/3.png,assets/audio/4.WAV,A person playing guitar by a campfire in a forest.,25
21 | 27,assets/image/4.png,assets/audio/4.WAV,"A person wearing a green jacket stands in a forested area, with sunlight filtering through the trees.",25
22 | 28,assets/image/src1.png,assets/audio/4.WAV,A person sits cross-legged by a campfire in a forest at dusk.,25
23 | 29,assets/image/src2.png,assets/audio/4.WAV,A person in a green jacket stands in a forest at dusk.,25
24 | 30,assets/image/src3.png,assets/audio/4.WAV,A person playing guitar by a campfire in a forest.,25
25 | 31,assets/image/src4.png,assets/audio/4.WAV,"A person in a green jacket stands in a forest, backlit by sunlight.",25
26 |
--------------------------------------------------------------------------------
/hymm_gradio/flask_audio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import warnings
5 | import threading
6 | import traceback
7 | import uvicorn
8 | from fastapi import FastAPI, Body
9 | from pathlib import Path
10 | from datetime import datetime
11 | import torch.distributed as dist
12 | from hymm_gradio.tool_for_end2end import *
13 | from hymm_sp.config import parse_args
14 | from hymm_sp.sample_inference_audio import HunyuanVideoSampler
15 |
16 | from hymm_sp.modules.parallel_states import (
17 | initialize_distributed,
18 | nccl_info,
19 | )
20 |
21 | from transformers import WhisperModel
22 | from transformers import AutoFeatureExtractor
23 | from hymm_sp.data_kits.face_align import AlignImage
24 |
25 |
26 | warnings.filterwarnings("ignore")
27 | MODEL_OUTPUT_PATH = os.environ.get('MODEL_BASE')
28 | app = FastAPI()
29 | rlock = threading.RLock()
30 |
31 |
32 |
33 | @app.api_route('/predict2', methods=['GET', 'POST'])
34 | def predict(data=Body(...)):
35 | is_acquire = False
36 | error_info = ""
37 | try:
38 | is_acquire = rlock.acquire(blocking=False)
39 | if is_acquire:
40 | res = predict_wrap(data)
41 | return res
42 | except Exception as e:
43 | error_info = traceback.format_exc()
44 | print(error_info)
45 | finally:
46 | if is_acquire:
47 | rlock.release()
48 | return {"errCode": -1, "info": "broken"}
49 |
50 | def predict_wrap(input_dict={}):
51 | if nccl_info.sp_size > 1:
52 | device = torch.device(f"cuda:{torch.distributed.get_rank()}")
53 | rank = local_rank = torch.distributed.get_rank()
54 | print(f"sp_size={nccl_info.sp_size}, rank {rank} local_rank {local_rank}")
55 | try:
56 | print(f"----- rank = {rank}")
57 | if rank == 0:
58 | input_dict = process_input_dict(input_dict)
59 |
60 | print('------- start to predict -------')
61 | # Parse input arguments
62 | image_path = input_dict["image_path"]
63 | driving_audio_path = input_dict["audio_path"]
64 |
65 | prompt = input_dict["prompt"]
66 |
67 | save_fps = input_dict.get("save_fps", 25)
68 |
69 |
70 | ret_dict = None
71 | if image_path is None or driving_audio_path is None:
72 | ret_dict = {
73 | "errCode": -3,
74 | "content": [
75 | {
76 | "buffer": None
77 | },
78 | ],
79 | "info": "input content is not valid",
80 | }
81 |
82 | print(f"errCode: -3, input content is not valid!")
83 | return ret_dict
84 |
85 | # Preprocess input batch
86 | torch.cuda.synchronize()
87 |
88 | a = datetime.now()
89 |
90 | try:
91 | model_kwargs_tmp = data_preprocess_server(
92 | args, image_path, driving_audio_path, prompt, feature_extractor
93 | )
94 | except:
95 | ret_dict = {
96 | "errCode": -2,
97 | "content": [
98 | {
99 | "buffer": None
100 | },
101 | ],
102 | "info": "failed to preprocess input data"
103 | }
104 | print(f"errCode: -2, preprocess failed!")
105 | return ret_dict
106 |
107 | text_prompt = model_kwargs_tmp["text_prompt"]
108 | audio_path = model_kwargs_tmp["audio_path"]
109 | image_path = model_kwargs_tmp["image_path"]
110 | fps = model_kwargs_tmp["fps"]
111 | audio_prompts = model_kwargs_tmp["audio_prompts"]
112 | audio_len = model_kwargs_tmp["audio_len"]
113 | motion_bucket_id_exps = model_kwargs_tmp["motion_bucket_id_exps"]
114 | motion_bucket_id_heads = model_kwargs_tmp["motion_bucket_id_heads"]
115 | pixel_value_ref = model_kwargs_tmp["pixel_value_ref"]
116 | pixel_value_ref_llava = model_kwargs_tmp["pixel_value_ref_llava"]
117 |
118 |
119 |
120 | torch.cuda.synchronize()
121 | b = datetime.now()
122 | preprocess_time = (b - a).total_seconds()
123 | print("="*100)
124 | print("preprocess time :", preprocess_time)
125 | print("="*100)
126 |
127 | else:
128 | text_prompt = None
129 | audio_path = None
130 | image_path = None
131 | fps = None
132 | audio_prompts = None
133 | audio_len = None
134 | motion_bucket_id_exps = None
135 | motion_bucket_id_heads = None
136 | pixel_value_ref = None
137 | pixel_value_ref_llava = None
138 |
139 | except:
140 | traceback.print_exc()
141 | if rank == 0:
142 | ret_dict = {
143 | "errCode": -1, # Failed to generate video
144 | "content":[
145 | {
146 | "buffer": None
147 | }
148 | ],
149 | "info": "failed to preprocess",
150 | }
151 | return ret_dict
152 |
153 | try:
154 | broadcast_params = [
155 | text_prompt,
156 | audio_path,
157 | image_path,
158 | fps,
159 | audio_prompts,
160 | audio_len,
161 | motion_bucket_id_exps,
162 | motion_bucket_id_heads,
163 | pixel_value_ref,
164 | pixel_value_ref_llava,
165 | ]
166 | dist.broadcast_object_list(broadcast_params, src=0)
167 | outputs = generate_image_parallel(*broadcast_params)
168 |
169 | if rank == 0:
170 | samples = outputs["samples"]
171 | sample = samples[0].unsqueeze(0)
172 |
173 | sample = sample[:, :, :audio_len[0]]
174 |
175 | video = sample[0].permute(1, 2, 3, 0).clamp(0, 1).numpy()
176 | video = (video * 255.).astype(np.uint8)
177 |
178 | output_dict = {
179 | "err_code": 0,
180 | "err_msg": "succeed",
181 | "video": video,
182 | "audio": input_dict.get("audio_path", None),
183 | "save_fps": save_fps,
184 | }
185 |
186 | ret_dict = process_output_dict(output_dict)
187 | return ret_dict
188 |
189 | except:
190 | traceback.print_exc()
191 | if rank == 0:
192 | ret_dict = {
193 | "errCode": -1, # Failed to generate video
194 | "content":[
195 | {
196 | "buffer": None
197 | }
198 | ],
199 | "info": "failed to generate video",
200 | }
201 | return ret_dict
202 |
203 | return None
204 |
205 | def generate_image_parallel(text_prompt,
206 | audio_path,
207 | image_path,
208 | fps,
209 | audio_prompts,
210 | audio_len,
211 | motion_bucket_id_exps,
212 | motion_bucket_id_heads,
213 | pixel_value_ref,
214 | pixel_value_ref_llava
215 | ):
216 | if nccl_info.sp_size > 1:
217 | device = torch.device(f"cuda:{torch.distributed.get_rank()}")
218 |
219 | batch = {
220 | "text_prompt": text_prompt,
221 | "audio_path": audio_path,
222 | "image_path": image_path,
223 | "fps": fps,
224 | "audio_prompts": audio_prompts,
225 | "audio_len": audio_len,
226 | "motion_bucket_id_exps": motion_bucket_id_exps,
227 | "motion_bucket_id_heads": motion_bucket_id_heads,
228 | "pixel_value_ref": pixel_value_ref,
229 | "pixel_value_ref_llava": pixel_value_ref_llava
230 | }
231 |
232 | samples = hunyuan_sampler.predict(args, batch, wav2vec, feature_extractor, align_instance)
233 | return samples
234 |
235 | def worker_loop():
236 | while True:
237 | predict_wrap()
238 |
239 |
240 | if __name__ == "__main__":
241 | audio_args = parse_args()
242 | initialize_distributed(audio_args.seed)
243 | hunyuan_sampler = HunyuanVideoSampler.from_pretrained(
244 | audio_args.ckpt, args=audio_args)
245 | args = hunyuan_sampler.args
246 |
247 | rank = local_rank = 0
248 | device = torch.device("cuda")
249 | if nccl_info.sp_size > 1:
250 | device = torch.device(f"cuda:{torch.distributed.get_rank()}")
251 | rank = local_rank = torch.distributed.get_rank()
252 |
253 | feature_extractor = AutoFeatureExtractor.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/")
254 | wav2vec = WhisperModel.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/").to(device=device, dtype=torch.float32)
255 | wav2vec.requires_grad_(False)
256 |
257 |
258 | BASE_DIR = f'{MODEL_OUTPUT_PATH}/ckpts/det_align/'
259 | det_path = os.path.join(BASE_DIR, 'detface.pt')
260 | align_instance = AlignImage("cuda", det_path=det_path)
261 |
262 |
263 |
264 | if rank == 0:
265 | uvicorn.run(app, host="0.0.0.0", port=80)
266 | else:
267 | worker_loop()
268 |
269 |
--------------------------------------------------------------------------------
/hymm_gradio/gradio_audio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import glob
4 | import json
5 | import datetime
6 | import requests
7 | import gradio as gr
8 | from tool_for_end2end import *
9 |
10 | os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
11 | DATADIR = './temp'
12 | _HEADER_ = '''
13 | Tencent HunyuanVideo-Avatar Demo
17 | '''
18 | # flask url
19 | URL = "http://127.0.0.1:80/predict2"
20 |
21 | def post_and_get(audio_input, id_image, prompt):
22 | now = datetime.datetime.now().isoformat()
23 | imgdir = os.path.join(DATADIR, 'reference')
24 | videodir = os.path.join(DATADIR, 'video')
25 | imgfile = os.path.join(imgdir, now + '.png')
26 | output_video_path = os.path.join(videodir, now + '.mp4')
27 |
28 |
29 | os.makedirs(imgdir, exist_ok=True)
30 | os.makedirs(videodir, exist_ok=True)
31 | cv2.imwrite(imgfile, id_image[:,:,::-1])
32 |
33 | proxies = {
34 | "http": None,
35 | "https": None,
36 | }
37 |
38 | files = {
39 | "image_buffer": encode_image_to_base64(imgfile),
40 | "audio_buffer": encode_wav_to_base64(audio_input),
41 | "text": prompt,
42 | "save_fps": 25,
43 | }
44 | r = requests.get(URL, data = json.dumps(files), proxies=proxies)
45 | ret_dict = json.loads(r.text)
46 | print(ret_dict["info"])
47 | save_video_base64_to_local(
48 | video_path=None,
49 | base64_buffer=ret_dict["content"][0]["buffer"],
50 | output_video_path=output_video_path)
51 |
52 |
53 | return output_video_path
54 |
55 | def create_demo():
56 |
57 | with gr.Blocks() as demo:
58 | gr.Markdown(_HEADER_)
59 | with gr.Tab('Audio-Driven Digital Human'):
60 | with gr.Row():
61 | with gr.Column(scale=1):
62 | with gr.Group():
63 | prompt = gr.Textbox(label="Prompt", value="a man is speaking.")
64 |
65 | audio_input = gr.Audio(sources=["upload"],
66 | type="filepath",
67 | label="Upload Audio",
68 | elem_classes="media-upload",
69 | scale=1)
70 | id_image = gr.Image(label="Input reference image", height=480)
71 |
72 | with gr.Column(scale=2):
73 | with gr.Group():
74 | output_image = gr.Video(label="Generated Video")
75 |
76 |
77 | with gr.Column(scale=1):
78 | generate_btn = gr.Button("Generate")
79 |
80 | generate_btn.click(fn=post_and_get,
81 | inputs=[audio_input, id_image, prompt],
82 | outputs=[output_image],
83 | )
84 |
85 | return demo
86 |
87 | if __name__ == "__main__":
88 | allowed_paths = ['/']
89 | demo = create_demo()
90 | demo.launch(server_name='0.0.0.0', server_port=8080, share=True, allowed_paths=allowed_paths)
91 |
--------------------------------------------------------------------------------
/hymm_gradio/tool_for_end2end.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | import math
4 | import uuid
5 | import base64
6 | import imageio
7 | import torch
8 | import torchvision
9 | from PIL import Image
10 | import numpy as np
11 | from copy import deepcopy
12 | from einops import rearrange
13 | import torchvision.transforms as transforms
14 | from torchvision.transforms import ToPILImage
15 | from hymm_sp.data_kits.audio_dataset import get_audio_feature
16 |
17 | TEMP_DIR = "./temp"
18 | if not os.path.exists(TEMP_DIR):
19 | os.makedirs(TEMP_DIR, exist_ok=True)
20 |
21 |
22 | def data_preprocess_server(args, image_path, audio_path, prompts, feature_extractor):
23 | llava_transform = transforms.Compose(
24 | [
25 | transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BILINEAR),
26 | transforms.ToTensor(),
27 | transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
28 | ]
29 | )
30 |
31 | """ 生成prompt """
32 | if prompts is None:
33 | prompts = "Authentic, Realistic, Natural, High-quality, Lens-Fixed."
34 | else:
35 | prompts = "Authentic, Realistic, Natural, High-quality, Lens-Fixed, " + prompts
36 |
37 | fps = 25
38 |
39 | img_size = args.image_size
40 | ref_image = Image.open(image_path).convert('RGB')
41 |
42 | # Resize reference image
43 | w, h = ref_image.size
44 | scale = img_size / min(w, h)
45 | new_w = round(w * scale / 64) * 64
46 | new_h = round(h * scale / 64) * 64
47 |
48 | if img_size == 704:
49 | img_size_long = 1216
50 | if new_w * new_h > img_size * img_size_long:
51 | scale = math.sqrt(img_size * img_size_long / w / h)
52 | new_w = round(w * scale / 64) * 64
53 | new_h = round(h * scale / 64) * 64
54 |
55 | ref_image = ref_image.resize((new_w, new_h), Image.LANCZOS)
56 |
57 | ref_image = np.array(ref_image)
58 | ref_image = torch.from_numpy(ref_image)
59 |
60 | audio_input, audio_len = get_audio_feature(feature_extractor, audio_path)
61 | audio_prompts = audio_input[0]
62 |
63 | motion_bucket_id_heads = np.array([25] * 4)
64 | motion_bucket_id_exps = np.array([30] * 4)
65 | motion_bucket_id_heads = torch.from_numpy(motion_bucket_id_heads)
66 | motion_bucket_id_exps = torch.from_numpy(motion_bucket_id_exps)
67 | fps = torch.from_numpy(np.array(fps))
68 |
69 | to_pil = ToPILImage()
70 | pixel_value_ref = rearrange(ref_image.clone().unsqueeze(0), "b h w c -> b c h w") # (b c h w)
71 |
72 | pixel_value_ref_llava = [llava_transform(to_pil(image)) for image in pixel_value_ref]
73 | pixel_value_ref_llava = torch.stack(pixel_value_ref_llava, dim=0)
74 |
75 | batch = {
76 | "text_prompt": [prompts],
77 | "audio_path": [audio_path],
78 | "image_path": [image_path],
79 | "fps": fps.unsqueeze(0).to(dtype=torch.float16),
80 | "audio_prompts": audio_prompts.unsqueeze(0).to(dtype=torch.float16),
81 | "audio_len": [audio_len],
82 | "motion_bucket_id_exps": motion_bucket_id_exps.unsqueeze(0),
83 | "motion_bucket_id_heads": motion_bucket_id_heads.unsqueeze(0),
84 | "pixel_value_ref": pixel_value_ref.unsqueeze(0).to(dtype=torch.float16),
85 | "pixel_value_ref_llava": pixel_value_ref_llava.unsqueeze(0).to(dtype=torch.float16)
86 | }
87 |
88 | return batch
89 |
90 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, quality=8):
91 | videos = rearrange(videos, "b c t h w -> t b c h w")
92 | outputs = []
93 | for x in videos:
94 | x = torchvision.utils.make_grid(x, nrow=n_rows)
95 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
96 | if rescale:
97 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
98 | x = torch.clamp(x,0,1)
99 | x = (x * 255).numpy().astype(np.uint8)
100 | outputs.append(x)
101 |
102 | os.makedirs(os.path.dirname(path), exist_ok=True)
103 | imageio.mimsave(path, outputs, fps=fps, quality=quality)
104 |
105 | def encode_image_to_base64(image_path):
106 | try:
107 | with open(image_path, 'rb') as image_file:
108 | image_data = image_file.read()
109 | encoded_data = base64.b64encode(image_data).decode('utf-8')
110 | print(f"Image file '{image_path}' has been successfully encoded to Base64.")
111 | return encoded_data
112 |
113 | except Exception as e:
114 | print(f"Error encoding image: {e}")
115 | return None
116 |
117 | def encode_video_to_base64(video_path):
118 | try:
119 | with open(video_path, 'rb') as video_file:
120 | video_data = video_file.read()
121 | encoded_data = base64.b64encode(video_data).decode('utf-8')
122 | print(f"Video file '{video_path}' has been successfully encoded to Base64.")
123 | return encoded_data
124 |
125 | except Exception as e:
126 | print(f"Error encoding video: {e}")
127 | return None
128 |
129 | def encode_wav_to_base64(wav_path):
130 | try:
131 | with open(wav_path, 'rb') as audio_file:
132 | audio_data = audio_file.read()
133 | encoded_data = base64.b64encode(audio_data).decode('utf-8')
134 | print(f"Audio file '{wav_path}' has been successfully encoded to Base64.")
135 | return encoded_data
136 |
137 | except Exception as e:
138 | print(f"Error encoding audio: {e}")
139 | return None
140 |
141 | def encode_pkl_to_base64(pkl_path):
142 | try:
143 | with open(pkl_path, 'rb') as pkl_file:
144 | pkl_data = pkl_file.read()
145 |
146 | encoded_data = base64.b64encode(pkl_data).decode('utf-8')
147 |
148 | print(f"Pickle file '{pkl_path}' has been successfully encoded to Base64.")
149 | return encoded_data
150 |
151 | except Exception as e:
152 | print(f"Error encoding pickle: {e}")
153 | return None
154 |
155 | def decode_base64_to_image(base64_buffer_str):
156 | try:
157 | image_data = base64.b64decode(base64_buffer_str)
158 | image = Image.open(io.BytesIO(image_data))
159 | image_array = np.array(image)
160 | print(f"Image Base64 string has beed succesfully decoded to image.")
161 | return image_array
162 | except Exception as e:
163 | print(f"Error encdecodingoding image: {e}")
164 | return None
165 |
166 | def decode_base64_to_video(base64_buffer_str):
167 | try:
168 | video_data = base64.b64decode(base64_buffer_str)
169 | video_bytes = io.BytesIO(video_data)
170 | video_bytes.seek(0)
171 | video_reader = imageio.get_reader(video_bytes, 'ffmpeg')
172 | video_frames = [frame for frame in video_reader]
173 | return video_frames
174 | except Exception as e:
175 | print(f"Error decoding video: {e}")
176 | return None
177 |
178 |
179 | def save_video_base64_to_local(video_path=None, base64_buffer=None, output_video_path=None):
180 | if video_path is not None and base64_buffer is None:
181 | video_buffer_base64 = encode_video_to_base64(video_path)
182 | elif video_path is None and base64_buffer is not None:
183 | video_buffer_base64 = deepcopy(base64_buffer)
184 | else:
185 | print("Please pass either 'video_path' or 'base64_buffer'")
186 | return None
187 |
188 | if video_buffer_base64 is not None:
189 | video_data = base64.b64decode(video_buffer_base64)
190 | if output_video_path is None:
191 | uuid_string = str(uuid.uuid4())
192 | temp_video_path = f'{TEMP_DIR}/{uuid_string}.mp4'
193 | else:
194 | temp_video_path = output_video_path
195 | with open(temp_video_path, 'wb') as video_file:
196 | video_file.write(video_data)
197 | return temp_video_path
198 | else:
199 | return None
200 |
201 | def save_audio_base64_to_local(audio_path=None, base64_buffer=None):
202 | if audio_path is not None and base64_buffer is None:
203 | audio_buffer_base64 = encode_wav_to_base64(audio_path)
204 | elif audio_path is None and base64_buffer is not None:
205 | audio_buffer_base64 = deepcopy(base64_buffer)
206 | else:
207 | print("Please pass either 'audio_path' or 'base64_buffer'")
208 | return None
209 |
210 | if audio_buffer_base64 is not None:
211 | audio_data = base64.b64decode(audio_buffer_base64)
212 | uuid_string = str(uuid.uuid4())
213 | temp_audio_path = f'{TEMP_DIR}/{uuid_string}.wav'
214 | with open(temp_audio_path, 'wb') as audio_file:
215 | audio_file.write(audio_data)
216 | return temp_audio_path
217 | else:
218 | return None
219 |
220 | def save_pkl_base64_to_local(pkl_path=None, base64_buffer=None):
221 | if pkl_path is not None and base64_buffer is None:
222 | pkl_buffer_base64 = encode_pkl_to_base64(pkl_path)
223 | elif pkl_path is None and base64_buffer is not None:
224 | pkl_buffer_base64 = deepcopy(base64_buffer)
225 | else:
226 | print("Please pass either 'pkl_path' or 'base64_buffer'")
227 | return None
228 |
229 | if pkl_buffer_base64 is not None:
230 | pkl_data = base64.b64decode(pkl_buffer_base64)
231 | uuid_string = str(uuid.uuid4())
232 | temp_pkl_path = f'{TEMP_DIR}/{uuid_string}.pkl'
233 | with open(temp_pkl_path, 'wb') as pkl_file:
234 | pkl_file.write(pkl_data)
235 | return temp_pkl_path
236 | else:
237 | return None
238 |
239 | def remove_temp_fles(input_dict):
240 | for key, val in input_dict.items():
241 | if "_path" in key and val is not None and os.path.exists(val):
242 | os.remove(val)
243 |             print(f"Removed temporary {key} file: {val}")
244 |
245 | def process_output_dict(output_dict):
246 |
247 | uuid_string = str(uuid.uuid4())
248 | temp_video_path = f'{TEMP_DIR}/{uuid_string}.mp4'
249 | imageio.mimsave(temp_video_path, output_dict["video"], fps=output_dict.get("save_fps", 25))
250 |
251 | # Add audio
252 | if output_dict["audio"] is not None and os.path.exists(output_dict["audio"]):
253 | output_path = temp_video_path
254 | audio_path = output_dict["audio"]
255 | save_path = temp_video_path.replace(".mp4", "_audio.mp4")
256 | print('='*100)
257 | print(f"output_path = {output_path}\n audio_path = {audio_path}\n save_path = {save_path}")
258 |         os.system(f"ffmpeg -y -loglevel quiet -i '{output_path}' -i '{audio_path}' -shortest '{save_path}'; rm '{output_path}'")
259 | else:
260 | save_path = temp_video_path
261 |
262 | video_base64_buffer = encode_video_to_base64(save_path)
263 |
264 | encoded_output_dict = {
265 | "errCode": output_dict["err_code"],
266 | "content": [
267 | {
268 | "buffer": video_base64_buffer
269 | },
270 | ],
271 | "info":output_dict["err_msg"],
272 | }
273 |
274 |
275 |
276 | return encoded_output_dict
277 |
278 |
279 | def save_image_base64_to_local(image_path=None, base64_buffer=None):
280 | # Encode image to base64 buffer
281 | if image_path is not None and base64_buffer is None:
282 | image_buffer_base64 = encode_image_to_base64(image_path)
283 | elif image_path is None and base64_buffer is not None:
284 | image_buffer_base64 = deepcopy(base64_buffer)
285 | else:
286 | print("Please pass either 'image_path' or 'base64_buffer'")
287 | return None
288 |
289 | # Decode base64 buffer and save to local disk
290 | if image_buffer_base64 is not None:
291 | image_data = base64.b64decode(image_buffer_base64)
292 | uuid_string = str(uuid.uuid4())
293 | temp_image_path = f'{TEMP_DIR}/{uuid_string}.png'
294 | with open(temp_image_path, 'wb') as image_file:
295 | image_file.write(image_data)
296 | return temp_image_path
297 | else:
298 | return None
299 |
300 | def process_input_dict(input_dict):
301 |
302 | decoded_input_dict = {}
303 |
304 | decoded_input_dict["save_fps"] = input_dict.get("save_fps", 25)
305 |
306 | image_base64_buffer = input_dict.get("image_buffer", None)
307 | if image_base64_buffer is not None:
308 | decoded_input_dict["image_path"] = save_image_base64_to_local(
309 | image_path=None,
310 | base64_buffer=image_base64_buffer)
311 | else:
312 | decoded_input_dict["image_path"] = None
313 |
314 | audio_base64_buffer = input_dict.get("audio_buffer", None)
315 | if audio_base64_buffer is not None:
316 | decoded_input_dict["audio_path"] = save_audio_base64_to_local(
317 | audio_path=None,
318 | base64_buffer=audio_base64_buffer)
319 | else:
320 | decoded_input_dict["audio_path"] = None
321 |
322 | decoded_input_dict["prompt"] = input_dict.get("text", None)
323 |
324 | return decoded_input_dict
--------------------------------------------------------------------------------
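The helpers above (which appear to live in hymm_gradio/tool_for_end2end.py) form a small base64 transport layer between the Gradio front end and the Flask worker: inputs are encoded to base64 buffers on the client, decoded back to temporary files on the server via process_input_dict, and the rendered video is shipped back as a base64 buffer by process_output_dict. A minimal client-side sketch of that round trip follows; the asset paths are illustrative and it assumes the module's TEMP_DIR directory exists.

# Minimal sketch of the base64 round trip used by the gradio/flask tools.
# Paths are illustrative; TEMP_DIR (defined at the top of the module) must exist.
from hymm_gradio.tool_for_end2end import (
    encode_image_to_base64,
    encode_wav_to_base64,
    process_input_dict,
)

# Client side: pack the reference image, driving audio and prompt into one payload.
request_payload = {
    "image_buffer": encode_image_to_base64("assets/image/1.png"),
    "audio_buffer": encode_wav_to_base64("assets/audio/2.WAV"),
    "text": "A person is talking in a studio.",
    "save_fps": 25,
}

# Server side: decode the buffers back to temporary files the sampler can read.
decoded = process_input_dict(request_payload)
print(decoded["image_path"], decoded["audio_path"], decoded["prompt"])
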
/hymm_sp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/__init__.py
--------------------------------------------------------------------------------
/hymm_sp/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/__pycache__/config.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/__pycache__/config.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/__pycache__/constants.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/__pycache__/constants.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/__pycache__/helpers.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/__pycache__/helpers.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/__pycache__/inference.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/__pycache__/inference.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/__pycache__/sample_inference_audio.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/__pycache__/sample_inference_audio.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from hymm_sp.constants import *
3 | import re
4 | import collections.abc
5 |
6 | def as_tuple(x):
7 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
8 | return tuple(x)
9 | if x is None or isinstance(x, (int, float, str)):
10 | return (x,)
11 | else:
12 | raise ValueError(f"Unknown type {type(x)}")
13 |
14 | def parse_args(namespace=None):
15 | parser = argparse.ArgumentParser(description="Hunyuan Multimodal training/inference script")
16 | parser = add_extra_args(parser)
17 | args = parser.parse_args(namespace=namespace)
18 | args = sanity_check_args(args)
19 | return args
20 |
21 | def add_extra_args(parser: argparse.ArgumentParser):
22 | parser = add_network_args(parser)
23 | parser = add_extra_models_args(parser)
24 | parser = add_denoise_schedule_args(parser)
25 | parser = add_evaluation_args(parser)
26 | return parser
27 |
28 | def add_network_args(parser: argparse.ArgumentParser):
29 | group = parser.add_argument_group(title="Network")
30 | group.add_argument("--model", type=str, default="HYVideo-T/2",
31 |                        help="Model architecture to use. It is also used to determine the experiment directory.")
32 | group.add_argument("--latent-channels", type=str, default=None,
33 | help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, "
34 | "it still needs to match the latent channels of the VAE model.")
35 | group.add_argument("--rope-theta", type=int, default=256, help="Theta used in RoPE.")
36 | return parser
37 |
38 | def add_extra_models_args(parser: argparse.ArgumentParser):
39 | group = parser.add_argument_group(title="Extra Models (VAE, Text Encoder, Tokenizer)")
40 |
41 | # VAE
42 | group.add_argument("--vae", type=str, default="884-16c-hy0801", help="Name of the VAE model.")
43 | group.add_argument("--vae-precision", type=str, default="fp16",
44 | help="Precision mode for the VAE model.")
45 | group.add_argument("--vae-tiling", action="store_true", default=True, help="Enable tiling for the VAE model.")
46 | group.add_argument("--text-encoder", type=str, default="llava-llama-3-8b", choices=list(TEXT_ENCODER_PATH),
47 | help="Name of the text encoder model.")
48 | group.add_argument("--text-encoder-precision", type=str, default="fp16", choices=PRECISIONS,
49 | help="Precision mode for the text encoder model.")
50 | group.add_argument("--text-states-dim", type=int, default=4096, help="Dimension of the text encoder hidden states.")
51 | group.add_argument("--text-len", type=int, default=256, help="Maximum length of the text input.")
52 | group.add_argument("--tokenizer", type=str, default="llava-llama-3-8b", choices=list(TOKENIZER_PATH),
53 | help="Name of the tokenizer model.")
54 | group.add_argument("--text-encoder-infer-mode", type=str, default="encoder", choices=["encoder", "decoder"],
55 | help="Inference mode for the text encoder model. It should match the text encoder type. T5 and "
56 | "CLIP can only work in 'encoder' mode, while Llava/GLM can work in both modes.")
57 | group.add_argument("--prompt-template-video", type=str, default='li-dit-encode-video', choices=PROMPT_TEMPLATE,
58 | help="Video prompt template for the decoder-only text encoder model.")
59 | group.add_argument("--hidden-state-skip-layer", type=int, default=2,
60 | help="Skip layer for hidden states.")
61 | group.add_argument("--apply-final-norm", action="store_true",
62 | help="Apply final normalization to the used text encoder hidden states.")
63 |
64 | # - CLIP
65 | group.add_argument("--text-encoder-2", type=str, default='clipL', choices=list(TEXT_ENCODER_PATH),
66 | help="Name of the second text encoder model.")
67 | group.add_argument("--text-encoder-precision-2", type=str, default="fp16", choices=PRECISIONS,
68 | help="Precision mode for the second text encoder model.")
69 | group.add_argument("--text-states-dim-2", type=int, default=768,
70 | help="Dimension of the second text encoder hidden states.")
71 | group.add_argument("--tokenizer-2", type=str, default='clipL', choices=list(TOKENIZER_PATH),
72 | help="Name of the second tokenizer model.")
73 | group.add_argument("--text-len-2", type=int, default=77, help="Maximum length of the second text input.")
74 | group.set_defaults(use_attention_mask=True)
75 | group.add_argument("--text-projection", type=str, default="single_refiner", choices=TEXT_PROJECTION,
76 | help="A projection layer for bridging the text encoder hidden states and the diffusion model "
77 | "conditions.")
78 | return parser
79 |
80 |
81 | def add_denoise_schedule_args(parser: argparse.ArgumentParser):
82 | group = parser.add_argument_group(title="Denoise schedule")
83 | group.add_argument("--flow-shift-eval-video", type=float, default=None, help="Shift factor for flow matching schedulers when using video data.")
84 | group.add_argument("--flow-reverse", action="store_true", default=True, help="If reverse, learning/sampling from t=1 -> t=0.")
85 | group.add_argument("--flow-solver", type=str, default="euler", help="Solver for flow matching.")
86 |     group.add_argument("--use-linear-quadratic-schedule", action="store_true", help="Use linear quadratic schedule for flow matching. "
87 |                        "Follows MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper).")
88 | group.add_argument("--linear-schedule-end", type=int, default=25, help="End step for linear quadratic schedule for flow matching.")
89 | return parser
90 |
91 | def add_evaluation_args(parser: argparse.ArgumentParser):
92 | group = parser.add_argument_group(title="Validation Loss Evaluation")
93 | parser.add_argument("--precision", type=str, default="bf16", choices=PRECISIONS,
94 | help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.")
95 | parser.add_argument("--reproduce", action="store_true",
96 | help="Enable reproducibility by setting random seeds and deterministic algorithms.")
97 | parser.add_argument("--ckpt", type=str, help="Path to the checkpoint to evaluate.")
98 | parser.add_argument("--load-key", type=str, default="module", choices=["module", "ema"],
99 | help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.")
100 | parser.add_argument("--cpu-offload", action="store_true", help="Use CPU offload for the model load.")
101 |     parser.add_argument("--infer-min", action="store_true", help="Infer only 5 seconds.")
102 |     group.add_argument("--use-fp8", action="store_true", help="Enable fp8 for inference acceleration.")
103 | group.add_argument("--video-size", type=int, nargs='+', default=512,
104 | help="Video size for training. If a single value is provided, it will be used for both width "
105 | "and height. If two values are provided, they will be used for width and height "
106 | "respectively.")
107 | group.add_argument("--sample-n-frames", type=int, default=1,
108 |                        help="How many frames to sample from a video. If using a 3D VAE, the number should be 4n+1.")
109 | group.add_argument("--infer-steps", type=int, default=100, help="Number of denoising steps for inference.")
110 | group.add_argument("--val-disable-autocast", action="store_true",
111 | help="Disable autocast for denoising loop and vae decoding in pipeline sampling.")
112 | group.add_argument("--num-images", type=int, default=1, help="Number of images to generate for each prompt.")
113 | group.add_argument("--seed", type=int, default=1024, help="Seed for evaluation.")
114 | group.add_argument("--save-path-suffix", type=str, default="", help="Suffix for the directory of saved samples.")
115 | group.add_argument("--pos-prompt", type=str, default='', help="Prompt for sampling during evaluation.")
116 | group.add_argument("--neg-prompt", type=str, default='', help="Negative prompt for sampling during evaluation.")
117 |     group.add_argument("--image-size", type=int, default=704, help="Target short-side size of the reference image.")
118 | group.add_argument("--pad-face-size", type=float, default=0.7, help="Pad bbox for face align.")
119 |     group.add_argument("--image-path", type=str, default="", help="Path to the reference image.")
120 | group.add_argument("--save-path", type=str, default=None, help="Path to save the generated samples.")
121 |     group.add_argument("--input", type=str, default=None, help="Path to the test CSV file (columns: videoid, image, audio, prompt, fps).")
122 | group.add_argument("--item-name", type=str, default=None, help="")
123 | group.add_argument("--cfg-scale", type=float, default=7.5, help="Classifier free guidance scale.")
124 | group.add_argument("--ip-cfg-scale", type=float, default=0, help="Classifier free guidance scale.")
125 | group.add_argument("--use-deepcache", type=int, default=1)
126 | return parser
127 |
128 | def sanity_check_args(args):
129 | # VAE channels
130 | vae_pattern = r"\d{2,3}-\d{1,2}c-\w+"
131 | if not re.match(vae_pattern, args.vae):
132 | raise ValueError(
133 | f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'."
134 | )
135 | vae_channels = int(args.vae.split("-")[1][:-1])
136 | if args.latent_channels is None:
137 | args.latent_channels = vae_channels
138 | if vae_channels != args.latent_channels:
139 | raise ValueError(
140 | f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels})."
141 | )
142 | return args
143 |
--------------------------------------------------------------------------------
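config.py only wires up the argparse groups above, so parse_args can also be driven programmatically, which is handy for notebooks and tests. The sketch below is illustrative: it fills sys.argv with example values and relies on sanity_check_args deriving latent_channels from the default "884-16c-hy0801" VAE name; the CSV path is a placeholder.

# Sketch: driving the argument parser programmatically (values are illustrative).
import sys
from hymm_sp.config import parse_args

sys.argv = [
    "sample",
    "--input", "assets/test.csv",
    "--video-size", "704",
    "--infer-steps", "50",
    "--cfg-scale", "7.5",
    "--seed", "1024",
]
args = parse_args()
# sanity_check_args filled latent_channels (16) from the "884-16c-hy0801" VAE name.
print(args.model, args.vae, args.latent_channels, args.infer_steps)
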
/hymm_sp/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | __all__ = [
5 | "PROMPT_TEMPLATE", "MODEL_BASE", "PRECISION_TO_TYPE",
6 | "PRECISIONS", "VAE_PATH", "TEXT_ENCODER_PATH", "TOKENIZER_PATH",
7 | "TEXT_PROJECTION",
8 | ]
9 |
10 | # =================== Constant Values =====================
11 |
12 | PRECISION_TO_TYPE = {
13 | 'fp32': torch.float32,
14 | 'fp16': torch.float16,
15 | 'bf16': torch.bfloat16,
16 | }
17 |
18 | PROMPT_TEMPLATE_ENCODE_VIDEO = (
19 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
20 | "1. The main content and theme of the video."
21 | "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
22 | "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
23 | "4. background environment, light, style and atmosphere."
24 | "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
25 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
26 | )
27 |
28 | PROMPT_TEMPLATE = {
29 | "li-dit-encode-video": {"template": PROMPT_TEMPLATE_ENCODE_VIDEO, "crop_start": 95},
30 | }
31 |
32 | # ======================= Model ======================
33 | PRECISIONS = {"fp32", "fp16", "bf16"}
34 |
35 | # =================== Model Path =====================
36 | MODEL_BASE = os.getenv("MODEL_BASE")
37 | MODEL_BASE = f"{MODEL_BASE}/ckpts"
38 |
39 | # 3D VAE
40 | VAE_PATH = {
41 | "884-16c-hy0801": f"{MODEL_BASE}/hunyuan-video-t2v-720p/vae",
42 | }
43 |
44 | # Text Encoder
45 | TEXT_ENCODER_PATH = {
46 | "clipL": f"{MODEL_BASE}/text_encoder_2",
47 | "llava-llama-3-8b": f"{MODEL_BASE}/llava_llama_image",
48 | }
49 |
50 | # Tokenizer
51 | TOKENIZER_PATH = {
52 | "clipL": f"{MODEL_BASE}/text_encoder_2",
53 | "llava-llama-3-8b":f"{MODEL_BASE}/llava_llama_image",
54 | }
55 |
56 | TEXT_PROJECTION = {
57 | "linear", # Default, an nn.Linear() layer
58 | "single_refiner", # Single TokenRefiner. Refer to LI-DiT
59 | }
60 |
--------------------------------------------------------------------------------
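These constants are consumed throughout the package: MODEL_BASE is read from the environment at import time, PRECISION_TO_TYPE maps CLI precision strings to torch dtypes, and the video prompt template wraps the user caption in a fixed system prompt, with crop_start recording how many leading tokens that preamble occupies. A small sketch, assuming MODEL_BASE points at a local weights directory:

# Sketch: how the constants are typically consumed (illustrative only).
import os
os.environ.setdefault("MODEL_BASE", "./weights")  # assumption: local weights directory

from hymm_sp.constants import PROMPT_TEMPLATE, PRECISION_TO_TYPE, VAE_PATH

template = PROMPT_TEMPLATE["li-dit-encode-video"]
prompt = template["template"].format("A person is speaking to the camera.")
print(prompt[:60], "...")                        # system prompt followed by the user caption
print("crop_start tokens:", template["crop_start"])
print("fp16 ->", PRECISION_TO_TYPE["fp16"])
print("VAE checkpoint dir:", VAE_PATH["884-16c-hy0801"])
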
/hymm_sp/data_kits/__pycache__/audio_dataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/data_kits/__pycache__/audio_dataset.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/data_kits/__pycache__/audio_preprocessor.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/data_kits/__pycache__/audio_preprocessor.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/data_kits/__pycache__/data_tools.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/data_kits/__pycache__/data_tools.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/data_kits/__pycache__/ffmpeg_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/data_kits/__pycache__/ffmpeg_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/data_kits/audio_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import math
4 | import json
5 | import torch
6 | import random
7 | import librosa
8 | import traceback
9 | import torchvision
10 | import numpy as np
11 | import pandas as pd
12 | from PIL import Image
13 | from einops import rearrange
14 | from torch.utils.data import Dataset
15 | from decord import VideoReader, cpu
16 | from transformers import CLIPImageProcessor
17 | import torchvision.transforms as transforms
18 | from torchvision.transforms import ToPILImage
19 |
20 |
21 |
22 | def get_audio_feature(feature_extractor, audio_path):
23 | audio_input, sampling_rate = librosa.load(audio_path, sr=16000)
24 | assert sampling_rate == 16000
25 |
26 | audio_features = []
27 | window = 750*640
28 | for i in range(0, len(audio_input), window):
29 | audio_feature = feature_extractor(audio_input[i:i+window],
30 | sampling_rate=sampling_rate,
31 | return_tensors="pt",
32 | ).input_features
33 | audio_features.append(audio_feature)
34 |
35 | audio_features = torch.cat(audio_features, dim=-1)
36 | return audio_features, len(audio_input) // 640
37 |
38 |
39 | class VideoAudioTextLoaderVal(Dataset):
40 | def __init__(
41 | self,
42 | image_size: int,
43 | meta_file: str,
44 | **kwargs,
45 | ):
46 | super().__init__()
47 | self.meta_file = meta_file
48 | self.image_size = image_size
49 | self.text_encoder = kwargs.get("text_encoder", None) # llava_text_encoder
50 | self.text_encoder_2 = kwargs.get("text_encoder_2", None) # clipL_text_encoder
51 | self.feature_extractor = kwargs.get("feature_extractor", None)
52 | self.meta_files = []
53 |
54 | csv_data = pd.read_csv(meta_file)
55 | for idx in range(len(csv_data)):
56 | self.meta_files.append(
57 | {
58 | "videoid": str(csv_data["videoid"][idx]),
59 | "image_path": str(csv_data["image"][idx]),
60 | "audio_path": str(csv_data["audio"][idx]),
61 | "prompt": str(csv_data["prompt"][idx]),
62 | "fps": float(csv_data["fps"][idx])
63 | }
64 | )
65 |
66 | self.llava_transform = transforms.Compose(
67 | [
68 | transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BILINEAR),
69 | transforms.ToTensor(),
70 | transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
71 | ]
72 | )
73 | self.clip_image_processor = CLIPImageProcessor()
74 |
75 | self.device = torch.device("cuda")
76 | self.weight_dtype = torch.float16
77 |
78 |
79 | def __len__(self):
80 | return len(self.meta_files)
81 |
82 | @staticmethod
83 | def get_text_tokens(text_encoder, description, dtype_encode="video"):
84 | text_inputs = text_encoder.text2tokens(description, data_type=dtype_encode)
85 | text_ids = text_inputs["input_ids"].squeeze(0)
86 | text_mask = text_inputs["attention_mask"].squeeze(0)
87 | return text_ids, text_mask
88 |
89 | def get_batch_data(self, idx):
90 | meta_file = self.meta_files[idx]
91 | videoid = meta_file["videoid"]
92 | image_path = meta_file["image_path"]
93 | audio_path = meta_file["audio_path"]
94 | prompt = "Authentic, Realistic, Natural, High-quality, Lens-Fixed, " + meta_file["prompt"]
95 | fps = meta_file["fps"]
96 |
97 | img_size = self.image_size
98 | ref_image = Image.open(image_path).convert('RGB')
99 |
100 | # Resize reference image
101 | w, h = ref_image.size
102 | scale = img_size / min(w, h)
103 | new_w = round(w * scale / 64) * 64
104 | new_h = round(h * scale / 64) * 64
105 |
106 | if img_size == 704:
107 | img_size_long = 1216
108 | if new_w * new_h > img_size * img_size_long:
109 | import math
110 | scale = math.sqrt(img_size * img_size_long / w / h)
111 | new_w = round(w * scale / 64) * 64
112 | new_h = round(h * scale / 64) * 64
113 |
114 | ref_image = ref_image.resize((new_w, new_h), Image.LANCZOS)
115 |
116 | ref_image = np.array(ref_image)
117 | ref_image = torch.from_numpy(ref_image)
118 |
119 | audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_path)
120 | audio_prompts = audio_input[0]
121 |
122 | motion_bucket_id_heads = np.array([25] * 4)
123 | motion_bucket_id_exps = np.array([30] * 4)
124 | motion_bucket_id_heads = torch.from_numpy(motion_bucket_id_heads)
125 | motion_bucket_id_exps = torch.from_numpy(motion_bucket_id_exps)
126 | fps = torch.from_numpy(np.array(fps))
127 |
128 | to_pil = ToPILImage()
129 | pixel_value_ref = rearrange(ref_image.clone().unsqueeze(0), "b h w c -> b c h w") # (b c h w)
130 |
131 | pixel_value_ref_llava = [self.llava_transform(to_pil(image)) for image in pixel_value_ref]
132 | pixel_value_ref_llava = torch.stack(pixel_value_ref_llava, dim=0)
133 | pixel_value_ref_clip = self.clip_image_processor(
134 | images=Image.fromarray((pixel_value_ref[0].permute(1,2,0)).data.cpu().numpy().astype(np.uint8)),
135 | return_tensors="pt"
136 | ).pixel_values[0]
137 | pixel_value_ref_clip = pixel_value_ref_clip.unsqueeze(0)
138 |
139 | # Encode text prompts
140 |
141 | text_ids, text_mask = self.get_text_tokens(self.text_encoder, prompt)
142 | text_ids_2, text_mask_2 = self.get_text_tokens(self.text_encoder_2, prompt)
143 |
144 | # Output batch
145 | batch = {
146 | "text_prompt": prompt, #
147 | "videoid": videoid,
148 |             "pixel_value_ref": pixel_value_ref.to(dtype=torch.float16), # Reference image for VAE feature extraction, (1, 3, h, w), values in (0, 255)
149 |             "pixel_value_ref_llava": pixel_value_ref_llava.to(dtype=torch.float16), # Reference image for LLaVA feature extraction, (1, 3, 336, 336), values in the CLIP range
150 |             "pixel_value_ref_clip": pixel_value_ref_clip.to(dtype=torch.float16), # Reference image for clip_image_encoder feature extraction, (1, 3, 224, 224), values in the CLIP range
151 | "audio_prompts": audio_prompts.to(dtype=torch.float16),
152 | "motion_bucket_id_heads": motion_bucket_id_heads.to(dtype=text_ids.dtype),
153 | "motion_bucket_id_exps": motion_bucket_id_exps.to(dtype=text_ids.dtype),
154 | "fps": fps.to(dtype=torch.float16),
155 |             "text_ids": text_ids.clone(), # From llava_text_encoder
156 |             "text_mask": text_mask.clone(), # From llava_text_encoder
157 |             "text_ids_2": text_ids_2.clone(), # From clip_text_encoder
158 |             "text_mask_2": text_mask_2.clone(), # From clip_text_encoder
159 | "audio_len": audio_len,
160 | "image_path": image_path,
161 | "audio_path": audio_path,
162 | }
163 | return batch
164 |
165 | def __getitem__(self, idx):
166 | return self.get_batch_data(idx)
167 |
168 |
169 |
170 |
--------------------------------------------------------------------------------
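get_audio_feature loads audio at 16 kHz and chunks it into 750*640-sample windows before running the feature extractor. Assuming the intended 25 fps alignment, each video frame covers 640 samples, so a window is exactly 30 seconds and the returned length is the clip duration in frames. The arithmetic, as a standalone sketch:

# Sketch: the arithmetic behind get_audio_feature's windowing (assumes 25 fps alignment).
SAMPLE_RATE = 16000                              # librosa.load(..., sr=16000)
FPS = 25
samples_per_frame = SAMPLE_RATE // FPS           # 640 samples per video frame
window = 750 * samples_per_frame                 # 480000 samples = 30 s per extractor call
clip_seconds = 12.8                              # illustrative clip length
num_samples = int(clip_seconds * SAMPLE_RATE)
audio_len = num_samples // samples_per_frame     # matches len(audio_input) // 640
print(samples_per_frame, window, audio_len)      # 640 480000 320
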
/hymm_sp/data_kits/audio_preprocessor.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import cv2
4 | import json
5 | import time
6 | import decord
7 | import einops
8 | import librosa
9 | import torch
10 | import random
11 | import argparse
12 | import traceback
13 | import numpy as np
14 | from tqdm import tqdm
15 | from PIL import Image
16 | from einops import rearrange
17 |
18 |
19 |
20 | def get_facemask(ref_image, align_instance, area=1.25):
21 | # ref_image: (b f c h w)
22 | bsz, f, c, h, w = ref_image.shape
23 | images = rearrange(ref_image, "b f c h w -> (b f) h w c").data.cpu().numpy().astype(np.uint8)
24 | face_masks = []
25 | for image in images:
26 | image_pil = Image.fromarray(image).convert("RGB")
27 | _, _, bboxes_list = align_instance(np.array(image_pil)[:,:,[2,1,0]], maxface=True)
28 | try:
29 | bboxSrc = bboxes_list[0]
30 |         except Exception:
31 | bboxSrc = [0, 0, w, h]
32 | x1, y1, ww, hh = bboxSrc
33 | x2, y2 = x1 + ww, y1 + hh
34 | ww, hh = (x2-x1) * area, (y2-y1) * area
35 | center = [(x2+x1)//2, (y2+y1)//2]
36 | x1 = max(center[0] - ww//2, 0)
37 | y1 = max(center[1] - hh//2, 0)
38 | x2 = min(center[0] + ww//2, w)
39 | y2 = min(center[1] + hh//2, h)
40 |
41 | face_mask = np.zeros_like(np.array(image_pil))
42 | face_mask[int(y1):int(y2), int(x1):int(x2)] = 1.0
43 | face_masks.append(torch.from_numpy(face_mask[...,:1]))
44 | face_masks = torch.stack(face_masks, dim=0) # (b*f, h, w, c)
45 | face_masks = rearrange(face_masks, "(b f) h w c -> b c f h w", b=bsz, f=f)
46 | face_masks = face_masks.to(device=ref_image.device, dtype=ref_image.dtype)
47 | return face_masks
48 |
49 |
50 | def encode_audio(wav2vec, audio_feats, fps, num_frames=129):
51 | if fps == 25:
52 | start_ts = [0]
53 | step_ts = [1]
54 | elif fps == 12.5:
55 | start_ts = [0]
56 | step_ts = [2]
57 | num_frames = min(num_frames, 400)
58 | audio_feats = wav2vec.encoder(audio_feats.unsqueeze(0)[:, :, :3000], output_hidden_states=True).hidden_states
59 | audio_feats = torch.stack(audio_feats, dim=2)
60 | audio_feats = torch.cat([torch.zeros_like(audio_feats[:,:4]), audio_feats], 1)
61 |
62 | audio_prompts = []
63 | for bb in range(1):
64 | audio_feats_list = []
65 | for f in range(num_frames):
66 | cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
67 | audio_clip = audio_feats[bb:bb+1, cur_t: cur_t+10]
68 | audio_feats_list.append(audio_clip)
69 | audio_feats_list = torch.stack(audio_feats_list, 1)
70 | audio_prompts.append(audio_feats_list)
71 | audio_prompts = torch.cat(audio_prompts)
72 | return audio_prompts
--------------------------------------------------------------------------------
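get_facemask takes a (batch, frames, channels, height, width) reference tensor, detects a face box per frame, enlarges it by the area factor, and rasterizes it into a single-channel mask returned as (b, c, f, h, w). A minimal sketch with a stubbed detector standing in for AlignImage (the real detector needs a checkpoint):

# Sketch: get_facemask with a stubbed face detector (stand-in for AlignImage).
import torch
from hymm_sp.data_kits.audio_preprocessor import get_facemask

class DummyAligner:
    """Mimics AlignImage.__call__: returns landmarks, scores and (x, y, w, h) boxes."""
    def __call__(self, image_bgr, maxface=False):
        h, w = image_bgr.shape[:2]
        return [None], [1.0], [[w // 4, h // 4, w // 2, h // 2]]  # central quarter as the "face"

ref_image = torch.randint(0, 255, (1, 1, 3, 128, 128)).float()   # (b, f, c, h, w)
mask = get_facemask(ref_image, DummyAligner(), area=1.25)
print(mask.shape)   # torch.Size([1, 1, 1, 128, 128]) -> (b, c, f, h, w)
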
/hymm_sp/data_kits/data_tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import numpy as np
5 | import imageio
6 | import torchvision
7 | from einops import rearrange
8 |
9 |
10 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, quality=8):
11 | videos = rearrange(videos, "b c t h w -> t b c h w")
12 | outputs = []
13 | for x in videos:
14 | x = torchvision.utils.make_grid(x, nrow=n_rows)
15 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
16 | if rescale:
17 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
18 | x = torch.clamp(x,0,1)
19 | x = (x * 255).numpy().astype(np.uint8)
20 | outputs.append(x)
21 |
22 | os.makedirs(os.path.dirname(path), exist_ok=True)
23 | imageio.mimsave(path, outputs, fps=fps, quality=quality)
24 |
25 | def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1):
26 | crop_h, crop_w = crop_img.shape[:2]
27 | target_w, target_h = size
28 | scale_h, scale_w = target_h / crop_h, target_w / crop_w
29 | if scale_w > scale_h:
30 | resize_h = int(target_h*resize_ratio)
31 | resize_w = int(crop_w / crop_h * resize_h)
32 | else:
33 | resize_w = int(target_w*resize_ratio)
34 | resize_h = int(crop_h / crop_w * resize_w)
35 | crop_img = cv2.resize(crop_img, (resize_w, resize_h))
36 | pad_left = (target_w - resize_w) // 2
37 | pad_top = (target_h - resize_h) // 2
38 | pad_right = target_w - resize_w - pad_left
39 | pad_bottom = target_h - resize_h - pad_top
40 | crop_img = cv2.copyMakeBorder(crop_img, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=color)
41 | return crop_img
--------------------------------------------------------------------------------
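pad_image letterboxes a crop to a target (width, height): it scales the crop so that it fits inside the target while keeping its aspect ratio, then fills the remainder with a solid color (white by default). A quick sketch:

# Sketch: letterboxing a landscape crop to a square canvas with pad_image.
import numpy as np
from hymm_sp.data_kits.data_tools import pad_image

crop = np.zeros((360, 640, 3), dtype=np.uint8)   # (h, w, c) landscape crop
padded = pad_image(crop, size=(704, 704))        # size is (target_w, target_h)
print(padded.shape)                              # (704, 704, 3): 154 px of padding top and bottom
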
/hymm_sp/data_kits/face_align/__init__.py:
--------------------------------------------------------------------------------
1 | from .align import AlignImage
--------------------------------------------------------------------------------
/hymm_sp/data_kits/face_align/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/data_kits/face_align/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/data_kits/face_align/__pycache__/align.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/data_kits/face_align/__pycache__/align.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/data_kits/face_align/__pycache__/detface.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/data_kits/face_align/__pycache__/detface.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/data_kits/face_align/align.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | from .detface import DetFace
5 |
6 | class AlignImage(object):
7 | def __init__(self, device='cuda', det_path=''):
8 | self.facedet = DetFace(pt_path=det_path, confThreshold=0.5, nmsThreshold=0.45, device=device)
9 |
10 | @torch.no_grad()
11 | def __call__(self, im, maxface=False):
12 | bboxes, kpss, scores = self.facedet.detect(im)
13 | face_num = bboxes.shape[0]
14 |
15 | five_pts_list = []
16 | scores_list = []
17 | bboxes_list = []
18 | for i in range(face_num):
19 | five_pts_list.append(kpss[i].reshape(5,2))
20 | scores_list.append(scores[i])
21 | bboxes_list.append(bboxes[i])
22 |
23 | if maxface and face_num>1:
24 | max_idx = 0
25 | max_area = (bboxes[0, 2])*(bboxes[0, 3])
26 | for i in range(1, face_num):
27 | area = (bboxes[i,2])*(bboxes[i,3])
28 | if area>max_area:
29 |                     max_idx, max_area = i, area  # track the running largest face
30 | five_pts_list = [five_pts_list[max_idx]]
31 | scores_list = [scores_list[max_idx]]
32 | bboxes_list = [bboxes_list[max_idx]]
33 |
34 | return five_pts_list, scores_list, bboxes_list
--------------------------------------------------------------------------------
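AlignImage is a thin wrapper around the TorchScript face detector: it returns per-face 5-point landmarks, scores, and (x, y, w, h) boxes, optionally keeping only the largest face. A usage sketch; the checkpoint path is a placeholder that must point at your downloaded detector weights, and DetFace expects BGR input (as produced by cv2.imread):

# Sketch: detecting the largest face in one frame (checkpoint path is a placeholder).
import cv2
from hymm_sp.data_kits.face_align import AlignImage

det_ckpt = "weights/ckpts/det_align/detface.pt"       # placeholder, adjust to your layout
aligner = AlignImage(device="cuda", det_path=det_ckpt)

frame_bgr = cv2.imread("assets/image/1.png")          # BGR, as DetFace expects
pts_list, scores_list, bboxes_list = aligner(frame_bgr, maxface=True)
print(len(bboxes_list), bboxes_list[:1])              # at most one (x, y, w, h) box
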
/hymm_sp/data_kits/face_align/detface.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | import os
3 | import cv2
4 | import numpy as np
5 | import torch
6 | import torchvision
7 |
8 |
9 | def xyxy2xywh(x):
10 | # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
11 | y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
12 | y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
13 | y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
14 | y[:, 2] = x[:, 2] - x[:, 0] # width
15 | y[:, 3] = x[:, 3] - x[:, 1] # height
16 | return y
17 |
18 |
19 | def xywh2xyxy(x):
20 | # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
21 | y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
22 | y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
23 | y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
24 | y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
25 | y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
26 | return y
27 |
28 |
29 | def box_iou(box1, box2):
30 | # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
31 | """
32 | Return intersection-over-union (Jaccard index) of boxes.
33 | Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
34 | Arguments:
35 | box1 (Tensor[N, 4])
36 | box2 (Tensor[M, 4])
37 | Returns:
38 | iou (Tensor[N, M]): the NxM matrix containing the pairwise
39 | IoU values for every element in boxes1 and boxes2
40 | """
41 |
42 | def box_area(box):
43 | # box = 4xn
44 | return (box[2] - box[0]) * (box[3] - box[1])
45 |
46 | area1 = box_area(box1.T)
47 | area2 = box_area(box2.T)
48 |
49 | # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
50 | inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) -
51 | torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
52 | # iou = inter / (area1 + area2 - inter)
53 | return inter / (area1[:, None] + area2 - inter)
54 |
55 |
56 | def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
57 | # Rescale coords (xyxy) from img1_shape to img0_shape
58 | if ratio_pad is None: # calculate from img0_shape
59 | gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
60 | pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
61 | else:
62 | gain = ratio_pad[0][0]
63 | pad = ratio_pad[1]
64 |
65 | coords[:, [0, 2]] -= pad[0] # x padding
66 | coords[:, [1, 3]] -= pad[1] # y padding
67 | coords[:, :4] /= gain
68 | clip_coords(coords, img0_shape)
69 | return coords
70 |
71 |
72 | def clip_coords(boxes, img_shape):
73 |     # Clip xyxy bounding boxes to image shape (height, width)
74 | boxes[:, 0].clamp_(0, img_shape[1]) # x1
75 | boxes[:, 1].clamp_(0, img_shape[0]) # y1
76 | boxes[:, 2].clamp_(0, img_shape[1]) # x2
77 | boxes[:, 3].clamp_(0, img_shape[0]) # y2
78 |
79 |
80 | def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None):
81 | # Rescale coords (xyxy) from img1_shape to img0_shape
82 | if ratio_pad is None: # calculate from img0_shape
83 | gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
84 | pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
85 | else:
86 | gain = ratio_pad[0][0]
87 | pad = ratio_pad[1]
88 |
89 | coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding
90 | coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding
91 | coords[:, :10] /= gain
92 | #clip_coords(coords, img0_shape)
93 | coords[:, 0].clamp_(0, img0_shape[1]) # x1
94 | coords[:, 1].clamp_(0, img0_shape[0]) # y1
95 | coords[:, 2].clamp_(0, img0_shape[1]) # x2
96 | coords[:, 3].clamp_(0, img0_shape[0]) # y2
97 | coords[:, 4].clamp_(0, img0_shape[1]) # x3
98 | coords[:, 5].clamp_(0, img0_shape[0]) # y3
99 | coords[:, 6].clamp_(0, img0_shape[1]) # x4
100 | coords[:, 7].clamp_(0, img0_shape[0]) # y4
101 | coords[:, 8].clamp_(0, img0_shape[1]) # x5
102 | coords[:, 9].clamp_(0, img0_shape[0]) # y5
103 | return coords
104 |
105 |
106 | def show_results(img, xywh, conf, landmarks, class_num):
107 | h,w,c = img.shape
108 | tl = 1 or round(0.002 * (h + w) / 2) + 1 # line/font thickness
109 | x1 = int(xywh[0] * w - 0.5 * xywh[2] * w)
110 | y1 = int(xywh[1] * h - 0.5 * xywh[3] * h)
111 | x2 = int(xywh[0] * w + 0.5 * xywh[2] * w)
112 | y2 = int(xywh[1] * h + 0.5 * xywh[3] * h)
113 | cv2.rectangle(img, (x1,y1), (x2, y2), (0,255,0), thickness=tl, lineType=cv2.LINE_AA)
114 |
115 |     colors = [(255,0,0),(0,255,0),(0,0,255),(255,255,0),(0,255,255)]
116 |
117 | for i in range(5):
118 | point_x = int(landmarks[2 * i] * w)
119 | point_y = int(landmarks[2 * i + 1] * h)
120 |         cv2.circle(img, (point_x, point_y), tl+1, colors[i], -1)
121 |
122 | tf = max(tl - 1, 1) # font thickness
123 | label = str(conf)[:5]
124 | cv2.putText(img, label, (x1, y1 - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
125 | return img
126 |
127 |
128 | def make_divisible(x, divisor):
129 | # Returns x evenly divisible by divisor
130 | return (x // divisor) * divisor
131 |
132 |
133 | def non_max_suppression_face(prediction, conf_thres=0.5, iou_thres=0.45, classes=None, agnostic=False, labels=()):
134 | """Performs Non-Maximum Suppression (NMS) on inference results
135 | Returns:
136 |          detections with shape: nx16 (x1, y1, x2, y2, conf, 10 landmark coordinates, cls)
137 | """
138 |
139 | nc = prediction.shape[2] - 15 # number of classes
140 | xc = prediction[..., 4] > conf_thres # candidates
141 |
142 | # Settings
143 | min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
144 | # time_limit = 10.0 # seconds to quit after
145 | redundant = True # require redundant detections
146 | multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
147 | merge = False # use merge-NMS
148 |
149 | # t = time.time()
150 | output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0]
151 | for xi, x in enumerate(prediction): # image index, image inference
152 | # Apply constraints
153 | # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
154 | x = x[xc[xi]] # confidence
155 |
156 | # Cat apriori labels if autolabelling
157 | if labels and len(labels[xi]):
158 | l = labels[xi]
159 | v = torch.zeros((len(l), nc + 15), device=x.device)
160 | v[:, :4] = l[:, 1:5] # box
161 | v[:, 4] = 1.0 # conf
162 | v[range(len(l)), l[:, 0].long() + 15] = 1.0 # cls
163 | x = torch.cat((x, v), 0)
164 |
165 | # If none remain process next image
166 | if not x.shape[0]:
167 | continue
168 |
169 | # Compute conf
170 | x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf
171 |
172 | # Box (center x, center y, width, height) to (x1, y1, x2, y2)
173 | box = xywh2xyxy(x[:, :4])
174 |
175 | # Detections matrix nx6 (xyxy, conf, landmarks, cls)
176 | if multi_label:
177 | i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T
178 | x = torch.cat((box[i], x[i, j + 15, None], x[i, 5:15] ,j[:, None].float()), 1)
179 | else: # best class only
180 | conf, j = x[:, 15:].max(1, keepdim=True)
181 | x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres]
182 |
183 | # Filter by class
184 | if classes is not None:
185 | x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
186 |
187 | # If none remain process next image
188 | n = x.shape[0] # number of boxes
189 | if not n:
190 | continue
191 |
192 | # Batched NMS
193 | c = x[:, 15:16] * (0 if agnostic else max_wh) # classes
194 | boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
195 | i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
196 | #if i.shape[0] > max_det: # limit detections
197 | # i = i[:max_det]
198 | if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
199 | # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
200 | iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
201 | weights = iou * scores[None] # box weights
202 | x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
203 | if redundant:
204 | i = i[iou.sum(1) > 1] # require redundancy
205 |
206 | output[xi] = x[i]
207 | # if (time.time() - t) > time_limit:
208 | # break # time limit exceeded
209 |
210 | return output
211 |
212 |
213 | class DetFace():
214 | def __init__(self, pt_path, confThreshold=0.5, nmsThreshold=0.45, device='cuda'):
215 | assert os.path.exists(pt_path)
216 |
217 | self.inpSize = 416
218 | self.conf_thres = confThreshold
219 | self.iou_thres = nmsThreshold
220 | self.test_device = torch.device(device if torch.cuda.is_available() else "cpu")
221 | self.model = torch.jit.load(pt_path).to(self.test_device)
222 | self.last_w = 416
223 | self.last_h = 416
224 | self.grids = None
225 |
226 | @torch.no_grad()
227 | def detect(self, srcimg):
228 | # t0=time.time()
229 |
230 | h0, w0 = srcimg.shape[:2] # orig hw
231 | r = self.inpSize / min(h0, w0) # resize image to img_size
232 | h1 = int(h0*r+31)//32*32
233 | w1 = int(w0*r+31)//32*32
234 |
235 | img = cv2.resize(srcimg, (w1,h1), interpolation=cv2.INTER_LINEAR)
236 |
237 | # Convert
238 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR to RGB
239 |
240 | # Run inference
241 | img = torch.from_numpy(img).to(self.test_device).permute(2,0,1)
242 | img = img.float()/255 # uint8 to fp16/32 0-1
243 | if img.ndimension() == 3:
244 | img = img.unsqueeze(0)
245 |
246 | # Inference
247 | if h1 != self.last_h or w1 != self.last_w or self.grids is None:
248 | grids = []
249 | for scale in [8,16,32]:
250 | ny = h1//scale
251 | nx = w1//scale
252 | yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
253 | grid = torch.stack((xv, yv), 2).view((1,1,ny, nx, 2)).float()
254 | grids.append(grid.to(self.test_device))
255 | self.grids = grids
256 | self.last_w = w1
257 | self.last_h = h1
258 |
259 | pred = self.model(img, self.grids).cpu()
260 |
261 | # Apply NMS
262 | det = non_max_suppression_face(pred, self.conf_thres, self.iou_thres)[0]
263 | # Process detections
264 | # det = pred[0]
265 | bboxes = np.zeros((det.shape[0], 4))
266 | kpss = np.zeros((det.shape[0], 5, 2))
267 | scores = np.zeros((det.shape[0]))
268 | # gn = torch.tensor([w0, h0, w0, h0]).to(pred) # normalization gain whwh
269 | # gn_lks = torch.tensor([w0, h0, w0, h0, w0, h0, w0, h0, w0, h0]).to(pred) # normalization gain landmarks
270 | det = det.cpu().numpy()
271 |
272 | for j in range(det.shape[0]):
273 | # xywh = (xyxy2xywh(det[j, :4].view(1, 4)) / gn).view(4).cpu().numpy()
274 | bboxes[j, 0] = det[j, 0] * w0/w1
275 | bboxes[j, 1] = det[j, 1] * h0/h1
276 | bboxes[j, 2] = det[j, 2] * w0/w1 - bboxes[j, 0]
277 | bboxes[j, 3] = det[j, 3] * h0/h1 - bboxes[j, 1]
278 | scores[j] = det[j, 4]
279 | # landmarks = (det[j, 5:15].view(1, 10) / gn_lks).view(5,2).cpu().numpy()
280 | kpss[j, :, :] = det[j, 5:15].reshape(5, 2) * np.array([[w0/w1,h0/h1]])
281 | # class_num = det[j, 15].cpu().numpy()
282 | # orgimg = show_results(orgimg, xywh, conf, landmarks, class_num)
283 | return bboxes, kpss, scores
284 |
--------------------------------------------------------------------------------
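non_max_suppression_face follows the YOLOv5-face layout: each raw prediction row carries 16 values (xywh box, objectness, 10 landmark coordinates, one class score), and each surviving detection row is likewise 16 values (xyxy box, confidence, 10 landmark coordinates, class index). A synthetic sketch that exercises the suppression path:

# Sketch: feeding synthetic predictions through non_max_suppression_face.
# Row layout: [x, y, w, h, obj_conf, 10 landmark coords, class_score].
import torch
from hymm_sp.data_kits.face_align.detface import non_max_suppression_face

pred = torch.zeros((1, 3, 16))                            # (batch, boxes, 16)
pred[0, :, :4] = torch.tensor([[100., 100., 50., 50.],    # box A
                               [102., 101., 50., 50.],    # heavy overlap with A -> suppressed
                               [300., 300., 40., 40.]])   # box B, far away
pred[0, :, 4] = torch.tensor([0.9, 0.8, 0.7])             # objectness
pred[0, :, 15] = 1.0                                      # single "face" class score
det = non_max_suppression_face(pred, conf_thres=0.5, iou_thres=0.45)[0]
print(det.shape)   # torch.Size([2, 16]): A and B survive, the duplicate is removed
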
/hymm_sp/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipelines import HunyuanVideoAudioPipeline
2 | from .schedulers import FlowMatchDiscreteScheduler
3 |
4 |
5 | def load_diffusion_pipeline(args, rank, vae, text_encoder, text_encoder_2, model, scheduler=None,
6 | device=None, progress_bar_config=None):
7 |     """ Build the Hunyuan video-audio diffusion pipeline for inference, creating a default flow-matching scheduler if none is provided. """
8 | if scheduler is None:
9 | scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift_eval_video, reverse=args.flow_reverse, solver=args.flow_solver, )
10 |
11 | # Only enable progress bar for rank 0
12 | progress_bar_config = progress_bar_config or {'leave': True, 'disable': rank != 0}
13 |
14 | pipeline = HunyuanVideoAudioPipeline(vae=vae,
15 | text_encoder=text_encoder,
16 | text_encoder_2=text_encoder_2,
17 | transformer=model,
18 | scheduler=scheduler,
19 | # safety_checker=None,
20 | # feature_extractor=None,
21 | # requires_safety_checker=False,
22 | progress_bar_config=progress_bar_config,
23 | args=args,
24 | )
25 |     # When CPU offload is enabled, keep the pipeline on CPU to avoid OOM
26 |     if not args.cpu_offload:
27 |         pipeline = pipeline.to(device)
28 | 
29 |
30 | return pipeline
--------------------------------------------------------------------------------
/hymm_sp/diffusion/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/diffusion/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/diffusion/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_hunyuan_video_audio import HunyuanVideoAudioPipeline
2 |
--------------------------------------------------------------------------------
/hymm_sp/diffusion/pipelines/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/diffusion/pipelines/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/diffusion/pipelines/__pycache__/pipeline_hunyuan_video_audio.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/diffusion/pipelines/__pycache__/pipeline_hunyuan_video_audio.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/diffusion/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
--------------------------------------------------------------------------------
/hymm_sp/diffusion/schedulers/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/diffusion/schedulers/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/diffusion/schedulers/__pycache__/scheduling_flow_match_discrete.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/diffusion/schedulers/__pycache__/scheduling_flow_match_discrete.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/diffusion/schedulers/scheduling_flow_match_discrete.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | #
16 | # Modified from diffusers==0.29.2
17 | #
18 | # ==============================================================================
19 |
20 | from dataclasses import dataclass
21 | from typing import Optional, Tuple, Union
22 |
23 | import torch
24 |
25 | from diffusers.configuration_utils import ConfigMixin, register_to_config
26 | from diffusers.utils import BaseOutput, logging
27 | from diffusers.schedulers.scheduling_utils import SchedulerMixin
28 |
29 |
30 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31 |
32 |
33 | @dataclass
34 | class FlowMatchDiscreteSchedulerOutput(BaseOutput):
35 | """
36 | Output class for the scheduler's `step` function output.
37 |
38 | Args:
39 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
40 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
41 | denoising loop.
42 | """
43 |
44 | prev_sample: torch.FloatTensor
45 |
46 |
47 | class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
48 | """
49 | Euler scheduler.
50 |
51 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
52 | methods the library implements for all schedulers such as loading and saving.
53 |
54 | Args:
55 | num_train_timesteps (`int`, defaults to 1000):
56 | The number of diffusion steps to train the model.
57 | timestep_spacing (`str`, defaults to `"linspace"`):
58 | The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
59 | Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
60 | shift (`float`, defaults to 1.0):
61 | The shift value for the timestep schedule.
62 | reverse (`bool`, defaults to `True`):
63 | Whether to reverse the timestep schedule.
64 | """
65 |
66 | _compatibles = []
67 | order = 1
68 |
69 | @register_to_config
70 | def __init__(
71 | self,
72 | num_train_timesteps: int = 1000,
73 | shift: float = 1.0,
74 | reverse: bool = True,
75 | solver: str = "euler",
76 | n_tokens: Optional[int] = None,
77 | ):
78 | sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
79 |
80 | if not reverse:
81 | sigmas = sigmas.flip(0)
82 |
83 | self.sigmas = sigmas
84 | # the value fed to model
85 | self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
86 |
87 | self._step_index = None
88 | self._begin_index = None
89 |
90 | self.supported_solver = ["euler"]
91 | if solver not in self.supported_solver:
92 | raise ValueError(f"Solver {solver} not supported. Supported solvers: {self.supported_solver}")
93 |
94 | @property
95 | def step_index(self):
96 | """
97 | The index counter for current timestep. It will increase 1 after each scheduler step.
98 | """
99 | return self._step_index
100 |
101 | @property
102 | def begin_index(self):
103 | """
104 | The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
105 | """
106 | return self._begin_index
107 |
108 | # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
109 | def set_begin_index(self, begin_index: int = 0):
110 | """
111 | Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
112 |
113 | Args:
114 | begin_index (`int`):
115 | The begin index for the scheduler.
116 | """
117 | self._begin_index = begin_index
118 |
119 | def _sigma_to_t(self, sigma):
120 | return sigma * self.config.num_train_timesteps
121 |
122 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None,
123 | n_tokens: int = None):
124 | """
125 | Sets the discrete timesteps used for the diffusion chain (to be run before inference).
126 |
127 | Args:
128 | num_inference_steps (`int`):
129 | The number of diffusion steps used when generating samples with a pre-trained model.
130 | device (`str` or `torch.device`, *optional*):
131 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
132 | n_tokens (`int`, *optional*):
133 | Number of tokens in the input sequence.
134 | """
135 | self.num_inference_steps = num_inference_steps
136 |
137 | sigmas = torch.linspace(1, 0, num_inference_steps + 1)
138 | sigmas = self.sd3_time_shift(sigmas)
139 |
140 | if not self.config.reverse:
141 | sigmas = 1 - sigmas
142 |
143 | self.sigmas = sigmas
144 | self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device)
145 |
146 | # Reset step index
147 | self._step_index = None
148 |
149 | def index_for_timestep(self, timestep, schedule_timesteps=None):
150 | if schedule_timesteps is None:
151 | schedule_timesteps = self.timesteps
152 |
153 | indices = (schedule_timesteps == timestep).nonzero()
154 |
155 | # The sigma index that is taken for the **very** first `step`
156 | # is always the second index (or the last index if there is only 1)
157 | # This way we can ensure we don't accidentally skip a sigma in
158 | # case we start in the middle of the denoising schedule (e.g. for image-to-image)
159 | pos = 1 if len(indices) > 1 else 0
160 |
161 | return indices[pos].item()
162 |
163 | def _init_step_index(self, timestep):
164 | if self.begin_index is None:
165 | if isinstance(timestep, torch.Tensor):
166 | timestep = timestep.to(self.timesteps.device)
167 | self._step_index = self.index_for_timestep(timestep)
168 | else:
169 | self._step_index = self._begin_index
170 |
171 | def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
172 | return sample
173 |
174 | def sd3_time_shift(self, t: torch.Tensor):
175 | return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
176 |
177 | def step(
178 | self,
179 | model_output: torch.FloatTensor,
180 | timestep: Union[float, torch.FloatTensor],
181 | sample: torch.FloatTensor,
182 | return_dict: bool = True,
183 | ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
184 | """
185 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
186 | process from the learned model outputs (most often the predicted noise).
187 |
188 | Args:
189 | model_output (`torch.FloatTensor`):
190 | The direct output from learned diffusion model.
191 | timestep (`float`):
192 | The current discrete timestep in the diffusion chain.
193 | sample (`torch.FloatTensor`):
194 | A current instance of a sample created by the diffusion process.
195 | return_dict (`bool`):
196 |                 Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or a plain
197 |                 tuple.
198 | 
199 |         Returns:
200 |             [`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
201 |                 If `return_dict` is `True`, a [`FlowMatchDiscreteSchedulerOutput`] is
202 |                 returned, otherwise a tuple is returned where the first element is the sample tensor.
203 | """
204 |
205 | if (
206 | isinstance(timestep, int)
207 | or isinstance(timestep, torch.IntTensor)
208 | or isinstance(timestep, torch.LongTensor)
209 | ):
210 | raise ValueError(
211 | (
212 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
213 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
214 | " one of the `scheduler.timesteps` as a timestep."
215 | ),
216 | )
217 |
218 | if self.step_index is None:
219 | self._init_step_index(timestep)
220 |
221 | # Upcast to avoid precision issues when computing prev_sample
222 | sample = sample.to(torch.float32)
223 |
224 | dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
225 |
226 | if self.config.solver == "euler":
227 | prev_sample = sample + model_output.float() * dt
228 | else:
229 | raise ValueError(f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}")
230 |
231 | # upon completion increase step index by one
232 | self._step_index += 1
233 |
234 | if not return_dict:
235 | return (prev_sample,)
236 |
237 | return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
238 |
239 | def __len__(self):
240 | return self.config.num_train_timesteps
241 |
--------------------------------------------------------------------------------
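A minimal usage sketch (not part of the repository) of the flow-matching Euler schedule above. The values shift=7.0, num_train_timesteps=1000 and the 4 inference steps are illustrative assumptions, and the random tensors stand in for real latents and a real denoiser.

import torch

shift, num_train_timesteps, num_inference_steps = 7.0, 1000, 4   # assumed values

sigmas = torch.linspace(1, 0, num_inference_steps + 1)
sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas)           # sd3_time_shift above
timesteps = sigmas[:-1] * num_train_timesteps                    # roughly [1000.0, 954.5, 875.0, 700.0]

sample = torch.randn(1, 4, 8, 8)                                 # stand-in latent
for i, t in enumerate(timesteps):
    model_output = torch.randn_like(sample)                      # stand-in for the denoiser at timestep t
    dt = sigmas[i + 1] - sigmas[i]                                # negative when sigmas descend
    sample = sample + model_output * dt                          # the "euler" update in step()
print(sigmas)
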
/hymm_sp/helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Union, List
3 | from hymm_sp.modules.posemb_layers import get_1d_rotary_pos_embed, get_meshgrid_nd
4 |
5 | from itertools import repeat
6 | import collections.abc
7 |
8 |
9 | def _ntuple(n):
10 | def parse(x):
11 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
12 | x = tuple(x)
13 | if len(x) == 1:
14 | x = tuple(repeat(x[0], n))
15 | return x
16 | return tuple(repeat(x, n))
17 | return parse
18 |
19 | to_1tuple = _ntuple(1)
20 | to_2tuple = _ntuple(2)
21 | to_3tuple = _ntuple(3)
22 | to_4tuple = _ntuple(4)
23 |
24 | def get_rope_freq_from_size(latents_size, ndim, target_ndim, args,
25 | rope_theta_rescale_factor: Union[float, List[float]]=1.0,
26 | rope_interpolation_factor: Union[float, List[float]]=1.0,
27 | concat_dict={}):
28 |
29 | if isinstance(args.patch_size, int):
30 | assert all(s % args.patch_size == 0 for s in latents_size), \
31 | f"Latent size(last {ndim} dimensions) should be divisible by patch size({args.patch_size}), " \
32 | f"but got {latents_size}."
33 | rope_sizes = [s // args.patch_size for s in latents_size]
34 | elif isinstance(args.patch_size, list):
35 | assert all(s % args.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
36 | f"Latent size(last {ndim} dimensions) should be divisible by patch size({args.patch_size}), " \
37 | f"but got {latents_size}."
38 | rope_sizes = [s // args.patch_size[idx] for idx, s in enumerate(latents_size)]
39 |
40 | if len(rope_sizes) != target_ndim:
41 | rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
42 | head_dim = args.hidden_size // args.num_heads
43 | rope_dim_list = args.rope_dim_list
44 | if rope_dim_list is None:
45 | rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
46 | assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
47 | freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
48 | rope_sizes,
49 | theta=args.rope_theta,
50 | use_real=True,
51 | theta_rescale_factor=rope_theta_rescale_factor,
52 | interpolation_factor=rope_interpolation_factor,
53 | concat_dict=concat_dict)
54 | return freqs_cos, freqs_sin
55 |
56 | def get_nd_rotary_pos_embed_new(rope_dim_list, start, *args, theta=10000., use_real=False,
57 | theta_rescale_factor: Union[float, List[float]]=1.0,
58 | interpolation_factor: Union[float, List[float]]=1.0,
59 | concat_dict={}
60 | ):
61 |
62 | grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
63 | if len(concat_dict)<1:
64 | pass
65 | else:
66 | if concat_dict['mode']=='timecat':
67 | bias = grid[:,:1].clone()
68 | bias[0] = concat_dict['bias']*torch.ones_like(bias[0])
69 | grid = torch.cat([bias, grid], dim=1)
70 |
71 | elif concat_dict['mode']=='timecat-w':
72 | bias = grid[:,:1].clone()
73 | bias[0] = concat_dict['bias']*torch.ones_like(bias[0])
74 | bias[2] += start[-1] ## ref https://github.com/Yuanshi9815/OminiControl/blob/main/src/generate.py#L178
75 | grid = torch.cat([bias, grid], dim=1)
76 | if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
77 | theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
78 | elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
79 | theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
80 | assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
81 |
82 | if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
83 | interpolation_factor = [interpolation_factor] * len(rope_dim_list)
84 | elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
85 | interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
86 | assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)"
87 |
88 | # use 1/ndim of dimensions to encode grid_axis
89 | embs = []
90 | for i in range(len(rope_dim_list)):
91 | emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=use_real,
92 | theta_rescale_factor=theta_rescale_factor[i],
93 | interpolation_factor=interpolation_factor[i]) # 2 x [WHD, rope_dim_list[i]]
94 |
95 | embs.append(emb)
96 |
97 | if use_real:
98 | cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
99 | sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
100 | return cos, sin
101 | else:
102 | emb = torch.cat(embs, dim=1) # (WHD, D/2)
103 | return emb
--------------------------------------------------------------------------------
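Illustration only, assuming hymm_sp and its dependencies are importable: the _ntuple helpers broadcast scalars to fixed-length tuples, and the RoPE grid sizes above are the latent sizes divided by the patch size, left-padded with 1 for the missing time axis. The latent and patch numbers below are made up for the example.

from hymm_sp.helpers import to_2tuple, to_3tuple

print(to_2tuple(16))          # (16, 16)
print(to_3tuple((4, 8, 8)))   # (4, 8, 8)

# RoPE grid sizes, mirroring get_rope_freq_from_size with an integer patch size
latents_size, patch_size, target_ndim = [32, 64, 64], 2, 3
rope_sizes = [s // patch_size for s in latents_size]
if len(rope_sizes) != target_ndim:
    rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes   # pad the time axis
print(rope_sizes)             # [16, 32, 32]
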
/hymm_sp/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pathlib import Path
3 | from loguru import logger
4 | from hymm_sp.constants import PROMPT_TEMPLATE, PRECISION_TO_TYPE
5 | from hymm_sp.vae import load_vae
6 | from hymm_sp.modules import load_model
7 | from hymm_sp.text_encoder import TextEncoder
8 | import torch.distributed
9 | from hymm_sp.modules.parallel_states import (
10 | nccl_info,
11 | )
12 | from hymm_sp.modules.fp8_optimization import convert_fp8_linear
13 |
14 |
15 | class Inference(object):
16 | def __init__(self,
17 | args,
18 | vae,
19 | vae_kwargs,
20 | text_encoder,
21 | model,
22 | text_encoder_2=None,
23 | pipeline=None,
24 | cpu_offload=False,
25 | device=None,
26 | logger=None):
27 | self.vae = vae
28 | self.vae_kwargs = vae_kwargs
29 |
30 | self.text_encoder = text_encoder
31 | self.text_encoder_2 = text_encoder_2
32 |
33 | self.model = model
34 | self.pipeline = pipeline
35 | self.cpu_offload = cpu_offload
36 |
37 | self.args = args
38 | self.device = device if device is not None else "cuda" if torch.cuda.is_available() else "cpu"
39 | if nccl_info.sp_size > 1:
40 | self.device = torch.device(f"cuda:{torch.distributed.get_rank()}")
41 |
42 | self.logger = logger
43 |
44 | @classmethod
45 | def from_pretrained(cls,
46 | pretrained_model_path,
47 | args,
48 | device=None,
49 | **kwargs):
50 | """
51 | Initialize the Inference pipeline.
52 |
53 | Args:
54 | pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints.
55 |             args: Parsed arguments controlling the model variant, precision, fp8 and cpu-offload options.
56 |             device (str or torch.device, optional): The device for inference. Defaults to "cuda" if available, else "cpu".
57 | """
58 | # ========================================================================
59 | logger.info(f"Got text-to-video model root path: {pretrained_model_path}")
60 |
61 | # ======================== Get the args path =============================
62 |
63 | # Set device and disable gradient
64 | if device is None:
65 | device = "cuda" if torch.cuda.is_available() else "cpu"
66 | torch.set_grad_enabled(False)
67 | logger.info("Building model...")
68 | factor_kwargs = {'device': 'cpu' if args.cpu_offload else device, 'dtype': PRECISION_TO_TYPE[args.precision]}
69 | in_channels = args.latent_channels
70 | out_channels = args.latent_channels
71 | print("="*25, f"build model", "="*25)
72 | model = load_model(
73 | args,
74 | in_channels=in_channels,
75 | out_channels=out_channels,
76 | factor_kwargs=factor_kwargs
77 | )
78 | if args.use_fp8:
79 | convert_fp8_linear(model, pretrained_model_path, original_dtype=PRECISION_TO_TYPE[args.precision])
80 | if args.cpu_offload:
81 | print(f'='*20, f'load transformer to cpu')
82 | model = model.to('cpu')
83 | torch.cuda.empty_cache()
84 | else:
85 | model = model.to(device)
86 | model = Inference.load_state_dict(args, model, pretrained_model_path)
87 | model.eval()
88 |
89 | # ============================= Build extra models ========================
90 | # VAE
91 | print("="*25, f"load vae", "="*25)
92 | vae, _, s_ratio, t_ratio = load_vae(args.vae, args.vae_precision, logger=logger, device='cpu' if args.cpu_offload else device)
93 | vae_kwargs = {'s_ratio': s_ratio, 't_ratio': t_ratio}
94 |
95 | # Text encoder
96 | if args.prompt_template_video is not None:
97 | crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
98 | else:
99 | crop_start = 0
100 | max_length = args.text_len + crop_start
101 |
102 | # prompt_template_video
103 | prompt_template_video = PROMPT_TEMPLATE[args.prompt_template_video] if args.prompt_template_video is not None else None
104 | print("="*25, f"load llava", "="*25)
105 | text_encoder = TextEncoder(text_encoder_type = args.text_encoder,
106 | max_length = max_length,
107 | text_encoder_precision = args.text_encoder_precision,
108 | tokenizer_type = args.tokenizer,
109 | use_attention_mask = args.use_attention_mask,
110 | prompt_template_video = prompt_template_video,
111 | hidden_state_skip_layer = args.hidden_state_skip_layer,
112 | apply_final_norm = args.apply_final_norm,
113 | reproduce = args.reproduce,
114 | logger = logger,
115 | device = 'cpu' if args.cpu_offload else device ,
116 | )
117 | text_encoder_2 = None
118 | if args.text_encoder_2 is not None:
119 | text_encoder_2 = TextEncoder(text_encoder_type=args.text_encoder_2,
120 | max_length=args.text_len_2,
121 | text_encoder_precision=args.text_encoder_precision_2,
122 | tokenizer_type=args.tokenizer_2,
123 | use_attention_mask=args.use_attention_mask,
124 | reproduce=args.reproduce,
125 | logger=logger,
126 | device='cpu' if args.cpu_offload else device , # if not args.use_cpu_offload else 'cpu'
127 | )
128 |
129 | return cls(args=args,
130 | vae=vae,
131 | vae_kwargs=vae_kwargs,
132 | text_encoder=text_encoder,
133 | model=model,
134 | text_encoder_2=text_encoder_2,
135 | device=device,
136 | logger=logger)
137 |
138 | @staticmethod
139 | def load_state_dict(args, model, ckpt_path):
140 | load_key = args.load_key
141 | ckpt_path = Path(ckpt_path)
142 | if ckpt_path.is_dir():
143 | ckpt_path = next(ckpt_path.glob("*_model_states.pt"))
144 | state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
145 | if load_key in state_dict:
146 | state_dict = state_dict[load_key]
147 | elif load_key == ".":
148 | pass
149 | else:
150 |             raise KeyError(f"Key '{load_key}' not found in the checkpoint. Existing keys: {state_dict.keys()}")
151 | model.load_state_dict(state_dict, strict=False)
152 | return model
153 |
154 | def get_exp_dir_and_ckpt_id(self):
155 | if self.ckpt is None:
156 | raise ValueError("The checkpoint path is not provided.")
157 |
158 | ckpt = Path(self.ckpt)
159 | if ckpt.parents[1].name == "checkpoints":
160 | # It should be a standard checkpoint path. We use the parent directory as the default save directory.
161 | exp_dir = ckpt.parents[2]
162 | else:
163 | raise ValueError(f"We cannot infer the experiment directory from the checkpoint path: {ckpt}. "
164 | f"It seems that the checkpoint path is not standard. Please explicitly provide the "
165 | f"save path by --save-path.")
166 | return exp_dir, ckpt.parent.name
167 |
168 | @staticmethod
169 | def parse_size(size):
170 | if isinstance(size, int):
171 | size = [size]
172 | if not isinstance(size, (list, tuple)):
173 | raise ValueError(f"Size must be an integer or (height, width), got {size}.")
174 | if len(size) == 1:
175 | size = [size[0], size[0]]
176 | if len(size) != 2:
177 | raise ValueError(f"Size must be an integer or (height, width), got {size}.")
178 | return size
179 |
--------------------------------------------------------------------------------
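A hedged sketch of the checkpoint resolution performed by Inference.load_state_dict above: a directory is resolved to the first *_model_states.pt file it contains, and args.load_key selects a sub-dictionary ("." keeps the whole dict). The default load_key="module" below is a hypothetical placeholder, not necessarily the repository's default.

from pathlib import Path
import torch

def resolve_and_load(ckpt_path, load_key="module"):
    ckpt_path = Path(ckpt_path)
    if ckpt_path.is_dir():
        # pick the first deepspeed-style state file in the directory
        ckpt_path = next(ckpt_path.glob("*_model_states.pt"))
    state_dict = torch.load(ckpt_path, map_location="cpu")
    if load_key in state_dict:
        return state_dict[load_key]
    if load_key == ".":
        return state_dict
    raise KeyError(f"Key '{load_key}' not found; available: {list(state_dict.keys())}")
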
/hymm_sp/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .models_audio import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG
2 |
3 | def load_model(args, in_channels, out_channels, factor_kwargs):
4 | model = HYVideoDiffusionTransformer(
5 | args,
6 | in_channels=in_channels,
7 | out_channels=out_channels,
8 | **HUNYUAN_VIDEO_CONFIG[args.model],
9 | **factor_kwargs,
10 | )
11 | return model
12 |
--------------------------------------------------------------------------------
/hymm_sp/modules/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/modules/__pycache__/activation_layers.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/activation_layers.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/modules/__pycache__/attn_layers.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/attn_layers.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/modules/__pycache__/audio_adapters.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/audio_adapters.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/modules/__pycache__/embed_layers.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/embed_layers.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/modules/__pycache__/fp8_optimization.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/fp8_optimization.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/modules/__pycache__/mlp_layers.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/mlp_layers.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/modules/__pycache__/models_audio.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/models_audio.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/modules/__pycache__/modulate_layers.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/modulate_layers.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/modules/__pycache__/norm_layers.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/norm_layers.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/modules/__pycache__/parallel_states.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/parallel_states.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/modules/__pycache__/posemb_layers.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/posemb_layers.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/modules/__pycache__/token_refiner.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/token_refiner.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/modules/activation_layers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def get_activation_layer(act_type):
5 | """get activation layer
6 |
7 | Args:
8 | act_type (str): the activation type
9 |
10 | Returns:
11 | torch.nn.functional: the activation layer
12 | """
13 | if act_type == "gelu":
14 | return lambda: nn.GELU()
15 | elif act_type == "gelu_tanh":
16 | # Approximate `tanh` requires torch >= 1.13
17 | return lambda: nn.GELU(approximate="tanh")
18 | elif act_type == "relu":
19 | return nn.ReLU
20 | elif act_type == "silu":
21 | return nn.SiLU
22 | else:
23 | raise ValueError(f"Unknown activation type: {act_type}")
--------------------------------------------------------------------------------
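Usage illustration for get_activation_layer above, assuming hymm_sp is on the Python path: it returns a constructor (a class or a zero-argument lambda), so callers instantiate the module with a second call.

import torch
from hymm_sp.modules.activation_layers import get_activation_layer

act_layer = get_activation_layer("gelu_tanh")   # -> lambda: nn.GELU(approximate="tanh")
act = act_layer()                               # instantiate the module
print(act(torch.randn(2, 4)).shape)             # torch.Size([2, 4])
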
/hymm_sp/modules/audio_adapters.py:
--------------------------------------------------------------------------------
1 | """
2 | This module provides the implementation of an Audio Projection Model, which is designed for
3 | audio processing tasks. The model takes audio embeddings as input and outputs context tokens
4 | that can be used for various downstream applications, such as audio analysis or synthesis.
5 |
6 | The AudioProjModel class is based on the ModelMixin class from the diffusers library, which
7 | provides a foundation for building custom models. This implementation includes multiple linear
8 | layers with ReLU activation functions and a LayerNorm for normalization.
9 |
10 | Key Features:
11 | - Audio embedding input with flexible sequence length and block structure.
12 | - Multiple linear layers for feature transformation.
13 | - ReLU activation for non-linear transformation.
14 | - LayerNorm for stabilizing and speeding up training.
15 | - Rearrangement of input embeddings to match the model's expected input shape.
16 | - Customizable number of blocks, channels, and context tokens for adaptability.
17 |
18 | The module is structured to be easily integrated into larger systems or used as a standalone
19 | component for audio feature extraction and processing.
20 |
21 | Classes:
22 | - AudioProjModel: A class representing the audio projection model with configurable parameters.
23 |
24 | Functions:
25 | - (none)
26 |
27 | Dependencies:
28 | - torch: For tensor operations and neural network components.
29 | - diffusers: For the ModelMixin base class.
30 | - einops: For tensor rearrangement operations.
31 |
32 | """
33 |
34 | import torch
35 | from diffusers import ModelMixin
36 | from einops import rearrange
37 |
38 | import math
39 | import torch.nn as nn
40 | from .parallel_states import (
41 | initialize_sequence_parallel_state,
42 | nccl_info,
43 | get_sequence_parallel_state,
44 | parallel_attention,
45 | all_gather,
46 | all_to_all_4D,
47 | )
48 |
49 | class AudioProjNet2(ModelMixin):
50 | """Audio Projection Model
51 |
52 | This class defines an audio projection model that takes audio embeddings as input
53 | and produces context tokens as output. The model is based on the ModelMixin class
54 | and consists of multiple linear layers and activation functions. It can be used
55 | for various audio processing tasks.
56 |
57 | Attributes:
58 | seq_len (int): The length of the audio sequence.
59 | blocks (int): The number of blocks in the audio projection model.
60 | channels (int): The number of channels in the audio projection model.
61 | intermediate_dim (int): The intermediate dimension of the model.
62 | context_tokens (int): The number of context tokens in the output.
63 | output_dim (int): The output dimension of the context tokens.
64 |
65 | Methods:
66 |         __init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, context_tokens=4, output_dim=768):
67 | Initializes the AudioProjModel with the given parameters.
68 | forward(self, audio_embeds):
69 | Defines the forward pass for the AudioProjModel.
70 | Parameters:
71 | audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels).
72 | Returns:
73 | context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
74 |
75 | """
76 |
77 | def __init__(
78 | self,
79 | seq_len=5,
80 | blocks=12, # add a new parameter blocks
81 | channels=768, # add a new parameter channels
82 | intermediate_dim=512,
83 | output_dim=768,
84 | context_tokens=4,
85 | ):
86 | super().__init__()
87 |
88 | self.seq_len = seq_len
89 | self.blocks = blocks
90 | self.channels = channels
91 | self.input_dim = (
92 | seq_len * blocks * channels
93 | )
94 | self.intermediate_dim = intermediate_dim
95 | self.context_tokens = context_tokens
96 | self.output_dim = output_dim
97 |
98 | # define multiple linear layers
99 | self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
100 | self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
101 | self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
102 |
103 | self.norm = nn.LayerNorm(output_dim)
104 |
105 |
106 | def forward(self, audio_embeds):
107 |
108 | video_length = audio_embeds.shape[1]
109 | audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
110 | batch_size, window_size, blocks, channels = audio_embeds.shape
111 | audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
112 |
113 | audio_embeds = torch.relu(self.proj1(audio_embeds))
114 | audio_embeds = torch.relu(self.proj2(audio_embeds))
115 |
116 | context_tokens = self.proj3(audio_embeds).reshape(
117 | batch_size, self.context_tokens, self.output_dim
118 | )
119 | context_tokens = self.norm(context_tokens)
120 | out_all = rearrange(
121 | context_tokens, "(bz f) m c -> bz f m c", f=video_length
122 | )
123 |
124 | return out_all
125 |
126 |
127 | def reshape_tensor(x, heads):
128 | bs, length, width = x.shape
129 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
130 | x = x.view(bs, length, heads, -1)
131 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
132 | x = x.transpose(1, 2)
133 | # (bs, n_heads, length, dim_per_head)
134 | x = x.reshape(bs, heads, length, -1)
135 | return x
136 |
137 |
138 | class PerceiverAttentionCA(nn.Module):
139 | def __init__(self, *, dim=3072, dim_head=1024, heads=33):
140 | super().__init__()
141 | self.scale = dim_head ** -0.5
142 | self.dim_head = dim_head
143 | self.heads = heads
144 | inner_dim = dim_head #* heads
145 |
146 | self.norm1 = nn.LayerNorm(dim)
147 | self.norm2 = nn.LayerNorm(dim)
148 |
149 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
150 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
151 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
152 |
153 | import torch.nn.init as init
154 | init.zeros_(self.to_out.weight)
155 | if self.to_out.bias is not None:
156 | init.zeros_(self.to_out.bias)
157 |
158 | def forward(self, x, latents):
159 | """
160 | Args:
161 | x (torch.Tensor): image features
162 | shape (b, t, aa, D)
163 | latent (torch.Tensor): latent features
164 | shape (b, t, hw, D)
165 | """
166 | x = self.norm1(x)
167 | latents = self.norm2(latents)
168 | # print("latents shape: ", latents.shape)
169 | # print("x shape: ", x.shape)
170 | q = self.to_q(latents)
171 | k, v = self.to_kv(x).chunk(2, dim=-1)
172 |
173 |
174 | # attention
175 | scale = 1 / math.sqrt(math.sqrt(self.dim_head))
176 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
177 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
178 | out = weight @ v
179 |
180 | # out = out.permute(0, 2, 1, 3)
181 | return self.to_out(out)
182 | #def forward(self, x, latents):
183 | # """
184 | # Args:
185 | # x (torch.Tensor): image features
186 | # shape (b, t, aa, D)
187 | # latent (torch.Tensor): latent features
188 | # shape (b, t, hw, D)
189 | # """
190 | # if get_sequence_parallel_state():
191 | # sp_size = nccl_info.sp_size
192 | # sp_rank = nccl_info.rank_within_group
193 | # print("rank:", latents.shape, sp_size, sp_rank)
194 | # latents = torch.chunk(latents, sp_size, dim=1)[sp_rank]
195 |
196 | # x = self.norm1(x)
197 | # latents = self.norm2(latents)
198 | # # print("latents shape: ", latents.shape)
199 | # # print("x shape: ", x.shape)
200 | # q = self.to_q(latents)
201 | # k, v = self.to_kv(x).chunk(2, dim=-1)
202 |
203 | # # print("q, k, v: ", q.shape, k.shape, v.shape)
204 |
205 | # # attention
206 | # #scale = 1 / math.sqrt(math.sqrt(self.dim_head))
207 | # #weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
208 | # #weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
209 | # #out = weight @ v
210 | # def shrink_head(encoder_state, dim):
211 | # local_heads = encoder_state.shape[dim] // nccl_info.sp_size
212 | # return encoder_state.narrow(dim, nccl_info.rank_within_group * local_heads, local_heads)
213 |
214 | # if get_sequence_parallel_state():
215 | # # batch_size, seq_len, attn_heads, head_dim
216 | # q = all_to_all_4D(q, scatter_dim=2, gather_dim=1) # [2, 32256, 24, 128]
217 | # k = shrink_head(k ,dim=2)
218 | # v = shrink_head(v ,dim=2)
219 | # qkv = torch.stack([query, key, value], dim=2)
220 | # attn = flash_attn_no_pad(qkv, causal=False, dropout_p=0.0, softmax_scale=None)
221 | # # out = out.permute(0, 2, 1, 3)
222 | # #b, s, a, d = attn.shape
223 | # #attn = attn.reshape(b, s, -1)
224 | #
225 | # out = self.to_out(attn)
226 | # if get_sequence_parallel_state():
227 | # out = all_gather(out, dim=1)
228 | # return out
229 |
--------------------------------------------------------------------------------
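A shape walk-through for AudioProjNet2.forward above, assuming hymm_sp, diffusers and einops are importable; the batch and frame counts are illustrative. Audio features arrive as (batch, frames, window=seq_len, blocks, channels) and leave as (batch, frames, context_tokens, output_dim).

import torch
from hymm_sp.modules.audio_adapters import AudioProjNet2

proj = AudioProjNet2(seq_len=5, blocks=12, channels=768,
                     intermediate_dim=512, output_dim=768, context_tokens=4)
audio_embeds = torch.randn(2, 16, 5, 12, 768)   # (bz, f, w, b, c)
out = proj(audio_embeds)                        # frames are folded into the batch and restored at the end
print(out.shape)                                # torch.Size([2, 16, 4, 768])
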
/hymm_sp/modules/embed_layers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from hymm_sp.helpers import to_2tuple
5 |
6 |
7 | class PatchEmbed(nn.Module):
8 | """ 2D Image to Patch Embedding
9 |
10 | Image to Patch Embedding using Conv2d
11 |
12 | A convolution based approach to patchifying a 2D image w/ embedding projection.
13 |
14 | Based on the impl in https://github.com/google-research/vision_transformer
15 |
16 | Hacked together by / Copyright 2020 Ross Wightman
17 |
18 | Remove the _assert function in forward function to be compatible with multi-resolution images.
19 | """
20 | def __init__(
21 | self,
22 | patch_size=16,
23 | in_chans=3,
24 | embed_dim=768,
25 | norm_layer=None,
26 | flatten=True,
27 | bias=True,
28 | dtype=None,
29 | device=None
30 | ):
31 | factory_kwargs = {'dtype': dtype, 'device': device}
32 | super().__init__()
33 | patch_size = to_2tuple(patch_size)
34 | self.patch_size = patch_size
35 | self.flatten = flatten
36 |
37 | self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias,
38 | **factory_kwargs)
39 | nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
40 | if bias:
41 | nn.init.zeros_(self.proj.bias)
42 |
43 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
44 |
45 | def forward(self, x):
46 | x = self.proj(x)
47 | shape = x.shape
48 | if self.flatten:
49 |             x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
50 | x = self.norm(x)
51 | return x, shape
52 |
53 |
54 | class TextProjection(nn.Module):
55 | """
56 | Projects text embeddings. Also handles dropout for classifier-free guidance.
57 |
58 | Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
59 | """
60 |
61 | def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
62 | factory_kwargs = {'dtype': dtype, 'device': device}
63 | super().__init__()
64 | self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
65 | self.act_1 = act_layer()
66 | self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
67 |
68 | def forward(self, caption):
69 | hidden_states = self.linear_1(caption)
70 | hidden_states = self.act_1(hidden_states)
71 | hidden_states = self.linear_2(hidden_states)
72 | return hidden_states
73 |
74 |
75 | def timestep_embedding(t, dim, max_period=10000):
76 | """
77 | Create sinusoidal timestep embeddings.
78 |
79 | Args:
80 | t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
81 | dim (int): the dimension of the output.
82 | max_period (int): controls the minimum frequency of the embeddings.
83 |
84 | Returns:
85 | embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
86 |
87 | .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
88 | """
89 | half = dim // 2
90 | freqs = torch.exp(
91 | -math.log(max_period)
92 | * torch.arange(start=0, end=half, dtype=torch.float32)
93 | / half
94 | ).to(device=t.device)
95 | args = t[:, None].float() * freqs[None]
96 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
97 | if dim % 2:
98 | embedding = torch.cat(
99 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
100 | )
101 | return embedding
102 |
103 |
104 | class TimestepEmbedder(nn.Module):
105 | """
106 | Embeds scalar timesteps into vector representations.
107 | """
108 | def __init__(self,
109 | hidden_size,
110 | act_layer,
111 | frequency_embedding_size=256,
112 | max_period=10000,
113 | out_size=None,
114 | dtype=None,
115 | device=None
116 | ):
117 | factory_kwargs = {'dtype': dtype, 'device': device}
118 | super().__init__()
119 | self.frequency_embedding_size = frequency_embedding_size
120 | self.max_period = max_period
121 | if out_size is None:
122 | out_size = hidden_size
123 |
124 | self.mlp = nn.Sequential(
125 | nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
126 | act_layer(),
127 | nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
128 | )
129 | nn.init.normal_(self.mlp[0].weight, std=0.02)
130 | nn.init.normal_(self.mlp[2].weight, std=0.02)
131 |
132 | def forward(self, t):
133 | t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
134 | t_emb = self.mlp(t_freq)
135 | return t_emb
--------------------------------------------------------------------------------
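A small check of timestep_embedding above, assuming hymm_sp is importable: fractional timesteps are accepted and the output stacks cosine terms followed by sine terms along the last dimension.

import torch
from hymm_sp.modules.embed_layers import timestep_embedding

t = torch.tensor([0.0, 10.5, 999.0])
emb = timestep_embedding(t, dim=256)
print(emb.shape)        # torch.Size([3, 256])
print(emb[0, :128])     # cos(0 * freqs): all ones for t = 0
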
/hymm_sp/modules/fp8_optimization.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 | def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
8 | _bits = torch.tensor(bits)
9 | _mantissa_bit = torch.tensor(mantissa_bit)
10 | _sign_bits = torch.tensor(sign_bits)
11 | M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
12 | E = _bits - _sign_bits - M
13 | bias = 2 ** (E - 1) - 1
14 | mantissa = 1
15 | for i in range(mantissa_bit - 1):
16 | mantissa += 1 / (2 ** (i+1))
17 | maxval = mantissa * 2 ** (2**E - 1 - bias)
18 | return maxval
19 |
20 | def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
21 | """
22 | Default is E4M3.
23 | """
24 | bits = torch.tensor(bits)
25 | mantissa_bit = torch.tensor(mantissa_bit)
26 | sign_bits = torch.tensor(sign_bits)
27 | M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
28 | E = bits - sign_bits - M
29 | bias = 2 ** (E - 1) - 1
30 | mantissa = 1
31 | for i in range(mantissa_bit - 1):
32 | mantissa += 1 / (2 ** (i+1))
33 | maxval = mantissa * 2 ** (2**E - 1 - bias)
34 |     # the minimum value depends on whether a sign bit is present
35 | minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval)
36 | input_clamp = torch.min(torch.max(x, minval), maxval)
37 | log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
38 | log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
39 | # dequant
40 | qdq_out = torch.round(input_clamp / log_scales) * log_scales
41 | return qdq_out, log_scales
42 |
43 | def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
44 | for i in range(len(x.shape) - 1):
45 | scale = scale.unsqueeze(-1)
46 | new_x = x / scale
47 | quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
48 | return quant_dequant_x, scale, log_scales
49 |
50 | def fp8_activation_dequant(qdq_out, scale, dtype):
51 | qdq_out = qdq_out.type(dtype)
52 | quant_dequant_x = qdq_out * scale.to(dtype)
53 | return quant_dequant_x
54 |
55 | def fp8_linear_forward(cls, original_dtype, input):
56 | weight_dtype = cls.weight.dtype
57 | #####
58 | if cls.weight.dtype != torch.float8_e4m3fn:
59 | maxval = get_fp_maxval()
60 | scale = torch.max(torch.abs(cls.weight.flatten())) / maxval
61 | linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale)
62 | linear_weight = linear_weight.to(torch.float8_e4m3fn)
63 | weight_dtype = linear_weight.dtype
64 | else:
65 | scale = cls.fp8_scale.to(cls.weight.device)
66 | linear_weight = cls.weight
67 | #####
68 |
69 | if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0:
70 | if True or len(input.shape) == 3:
71 | cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype)
72 |             if cls.bias is not None:
73 | output = F.linear(input, cls_dequant, cls.bias)
74 | else:
75 | output = F.linear(input, cls_dequant)
76 | return output
77 | else:
78 | return cls.original_forward(input.to(original_dtype))
79 | else:
80 | return cls.original_forward(input)
81 |
82 | def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={}):
83 | setattr(module, "fp8_matmul_enabled", True)
84 |
85 | # loading fp8 mapping file
86 | fp8_map_path = dit_weight_path.replace('.pt', '_map.pt')
87 | if os.path.exists(fp8_map_path):
88 | fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)['module']
89 | else:
90 | raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.")
91 |
92 | fp8_layers = []
93 | for key, layer in module.named_modules():
94 | if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
95 | fp8_layers.append(key)
96 | original_forward = layer.forward
97 | layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
98 | setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype))
99 | setattr(layer, "original_forward", original_forward)
100 | setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))
--------------------------------------------------------------------------------
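A worked example for the E4M3 defaults of get_fp_maxval above: M = 3, E = 8 - 1 - 3 = 4, bias = 2**3 - 1 = 7 and mantissa = 1 + 1/2 + 1/4 = 1.75, so maxval = 1.75 * 2**(2**4 - 1 - 7) = 1.75 * 256 = 448, the familiar E4M3 maximum. The snippet below, assuming hymm_sp is importable, checks this and round-trips a random tensor through quantize_to_fp8.

import torch
from hymm_sp.modules.fp8_optimization import get_fp_maxval, quantize_to_fp8

print(get_fp_maxval())              # tensor(448.)

x = torch.randn(4, 4) * 10
qdq, log_scales = quantize_to_fp8(x)            # fake quantize / dequantize round trip
print((x - qdq).abs().max())                    # quantization error, small relative to the inputs
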
/hymm_sp/modules/mlp_layers.py:
--------------------------------------------------------------------------------
1 | # Modified from timm library:
2 | # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
3 |
4 | from functools import partial
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | from .modulate_layers import modulate
10 | from hymm_sp.helpers import to_2tuple
11 |
12 |
13 | class MLP(nn.Module):
14 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks
15 | """
16 | def __init__(self,
17 | in_channels,
18 | hidden_channels=None,
19 | out_features=None,
20 | act_layer=nn.GELU,
21 | norm_layer=None,
22 | bias=True,
23 | drop=0.,
24 | use_conv=False,
25 | device=None,
26 | dtype=None
27 | ):
28 | factory_kwargs = {'device': device, 'dtype': dtype}
29 | super().__init__()
30 | out_features = out_features or in_channels
31 | hidden_channels = hidden_channels or in_channels
32 | bias = to_2tuple(bias)
33 | drop_probs = to_2tuple(drop)
34 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
35 |
36 | self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs)
37 | self.act = act_layer()
38 | self.drop1 = nn.Dropout(drop_probs[0])
39 | self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity()
40 | self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs)
41 | self.drop2 = nn.Dropout(drop_probs[1])
42 |
43 | def forward(self, x):
44 | x = self.fc1(x)
45 | x = self.act(x)
46 | x = self.drop1(x)
47 | x = self.norm(x)
48 | x = self.fc2(x)
49 | x = self.drop2(x)
50 | return x
51 |
52 |
53 | class MLPEmbedder(nn.Module):
54 | """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
55 | def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
56 | factory_kwargs = {'device': device, 'dtype': dtype}
57 | super().__init__()
58 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
59 | self.silu = nn.SiLU()
60 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
61 |
62 | def forward(self, x: torch.Tensor) -> torch.Tensor:
63 | return self.out_layer(self.silu(self.in_layer(x)))
64 |
65 |
66 | class FinalLayer(nn.Module):
67 | """The final layer of DiT."""
68 |
69 | def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None):
70 | factory_kwargs = {'device': device, 'dtype': dtype}
71 | super().__init__()
72 |
73 | # Just use LayerNorm for the final layer
74 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
75 | if isinstance(patch_size, int):
76 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, **factory_kwargs)
77 | else:
78 |             self.linear = nn.Linear(hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=True, **factory_kwargs)
79 | nn.init.zeros_(self.linear.weight)
80 | nn.init.zeros_(self.linear.bias)
81 |
82 | # Here we don't distinguish between the modulate types. Just use the simple one.
83 | self.adaLN_modulation = nn.Sequential(
84 | act_layer(),
85 | nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
86 | )
87 | # Zero-initialize the modulation
88 | nn.init.zeros_(self.adaLN_modulation[1].weight)
89 | nn.init.zeros_(self.adaLN_modulation[1].bias)
90 |
91 | def forward(self, x, c):
92 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
93 | x = modulate(self.norm_final(x), shift=shift, scale=scale)
94 | x = self.linear(x)
95 | return x
96 |
--------------------------------------------------------------------------------
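Usage illustration for MLPEmbedder above (Linear -> SiLU -> Linear), assuming hymm_sp is importable; the dimensions are illustrative.

import torch
from hymm_sp.modules.mlp_layers import MLPEmbedder

embedder = MLPEmbedder(in_dim=256, hidden_dim=3072)
vec = embedder(torch.randn(2, 256))
print(vec.shape)        # torch.Size([2, 3072])
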
/hymm_sp/modules/modulate_layers.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class ModulateDiT(nn.Module):
8 | """Modulation layer for DiT."""
9 | def __init__(
10 | self,
11 | hidden_size: int,
12 | factor: int,
13 | act_layer: Callable,
14 | dtype=None,
15 | device=None,
16 | ):
17 | factory_kwargs = {"dtype": dtype, "device": device}
18 | super().__init__()
19 | self.act = act_layer()
20 | self.linear = nn.Linear(
21 | hidden_size, factor * hidden_size, bias=True, **factory_kwargs
22 | )
23 | # Zero-initialize the modulation
24 | nn.init.zeros_(self.linear.weight)
25 | nn.init.zeros_(self.linear.bias)
26 |
27 | def forward(self, x: torch.Tensor) -> torch.Tensor:
28 | return self.linear(self.act(x))
29 |
30 |
31 | def modulate(x, shift=None, scale=None):
32 | """modulate by shift and scale
33 |
34 | Args:
35 | x (torch.Tensor): input tensor.
36 | shift (torch.Tensor, optional): shift tensor. Defaults to None.
37 | scale (torch.Tensor, optional): scale tensor. Defaults to None.
38 |
39 | Returns:
40 | torch.Tensor: the output tensor after modulate.
41 | """
42 | if scale is None and shift is None:
43 | return x
44 | elif shift is None:
45 | return x * (1 + scale.unsqueeze(1))
46 | elif scale is None:
47 | return x + shift.unsqueeze(1)
48 | else:
49 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
50 |
51 |
52 | def apply_gate(x, gate=None, tanh=False):
53 | """AI is creating summary for apply_gate
54 |
55 | Args:
56 | x (torch.Tensor): input tensor.
57 | gate (torch.Tensor, optional): gate tensor. Defaults to None.
58 | tanh (bool, optional): whether to use tanh function. Defaults to False.
59 |
60 | Returns:
61 | torch.Tensor: the output tensor after apply gate.
62 | """
63 | if gate is None:
64 | return x
65 | if tanh:
66 | return x * gate.unsqueeze(1).tanh()
67 | else:
68 | return x * gate.unsqueeze(1)
69 |
70 |
71 | def ckpt_wrapper(module):
72 | def ckpt_forward(*inputs):
73 | outputs = module(*inputs)
74 | return outputs
75 |
76 | return ckpt_forward
--------------------------------------------------------------------------------
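Broadcasting illustration for modulate and apply_gate above, assuming hymm_sp is importable: shift, scale and gate are per-sample vectors of shape (B, C) that are unsqueezed to (B, 1, C) so they broadcast over the token dimension of x with shape (B, N, C).

import torch
from hymm_sp.modules.modulate_layers import modulate, apply_gate

x = torch.randn(2, 10, 64)                   # (B, N, C)
shift = torch.randn(2, 64)
scale = torch.randn(2, 64)
gate = torch.randn(2, 64)

y = modulate(x, shift=shift, scale=scale)    # x * (1 + scale) + shift, broadcast over N
y = apply_gate(y, gate=gate, tanh=True)      # y * tanh(gate), broadcast over N
print(y.shape)                               # torch.Size([2, 10, 64])
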
/hymm_sp/modules/norm_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class RMSNorm(nn.Module):
6 | def __init__(
7 | self,
8 | dim: int,
9 | elementwise_affine=True,
10 | eps: float = 1e-6,
11 | device=None,
12 | dtype=None,
13 | ):
14 | """
15 | Initialize the RMSNorm normalization layer.
16 |
17 | Args:
18 | dim (int): The dimension of the input tensor.
19 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
20 |
21 | Attributes:
22 | eps (float): A small value added to the denominator for numerical stability.
23 | weight (nn.Parameter): Learnable scaling parameter.
24 |
25 | """
26 | factory_kwargs = {"device": device, "dtype": dtype}
27 | super().__init__()
28 | self.eps = eps
29 | if elementwise_affine:
30 | self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
31 |
32 | def _norm(self, x):
33 | """
34 | Apply the RMSNorm normalization to the input tensor.
35 |
36 | Args:
37 | x (torch.Tensor): The input tensor.
38 |
39 | Returns:
40 | torch.Tensor: The normalized tensor.
41 |
42 | """
43 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
44 |
45 | def forward(self, x):
46 | """
47 | Forward pass through the RMSNorm layer.
48 |
49 | Args:
50 | x (torch.Tensor): The input tensor.
51 |
52 | Returns:
53 | torch.Tensor: The output tensor after applying RMSNorm.
54 |
55 | """
56 | output = self._norm(x.float()).type_as(x)
57 | if hasattr(self, "weight"):
58 | output = output * self.weight
59 | return output
60 |
61 |
62 | def get_norm_layer(norm_layer):
63 | """
64 | Get the normalization layer.
65 |
66 | Args:
67 | norm_layer (str): The type of normalization layer.
68 |
69 | Returns:
70 | norm_layer (nn.Module): The normalization layer.
71 | """
72 | if norm_layer == "layer":
73 | return nn.LayerNorm
74 | elif norm_layer == "rms":
75 | return RMSNorm
76 | else:
77 | raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
--------------------------------------------------------------------------------
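A quick sanity check for RMSNorm above, assuming hymm_sp is importable: after normalization the root-mean-square over the last dimension is approximately 1, since the learnable weight is initialized to ones.

import torch
from hymm_sp.modules.norm_layers import RMSNorm

norm = RMSNorm(64)
x = torch.randn(2, 10, 64) * 5
y = norm(x)
print(y.pow(2).mean(-1).sqrt())   # close to 1.0 everywhere
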
/hymm_sp/modules/parallel_states.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import datetime
4 | import torch.distributed as dist
5 | from typing import Any, Tuple
6 | from torch import Tensor
7 | from flash_attn.flash_attn_interface import flash_attn_varlen_func
8 |
9 |
10 | class COMM_INFO:
11 | def __init__(self):
12 | self.group = None
13 | self.sp_size = 1
14 | self.global_rank = 0
15 | self.rank_within_group = 0
16 | self.group_id = 0
17 |
18 |
19 | nccl_info = COMM_INFO()
20 | _SEQUENCE_PARALLEL_STATE = False
21 |
22 |
23 | def get_cu_seqlens(text_mask, img_len):
24 | """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
25 |
26 | Args:
27 | text_mask (torch.Tensor): the mask of text
28 | img_len (int): the length of image
29 |
30 | Returns:
31 | torch.Tensor: the calculated cu_seqlens for flash attention
32 | """
33 | batch_size = text_mask.shape[0]
34 | text_len = text_mask.sum(dim=1)
35 | max_len = text_mask.shape[1] + img_len
36 |
37 | cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
38 |
39 | for i in range(batch_size):
40 | s = text_len[i] + img_len
41 | s1 = i * max_len + s
42 | s2 = (i + 1) * max_len
43 | cu_seqlens[2 * i + 1] = s1
44 | cu_seqlens[2 * i + 2] = s2
45 |
46 | return cu_seqlens
47 |
48 | def initialize_sequence_parallel_state(sequence_parallel_size):
49 | global _SEQUENCE_PARALLEL_STATE
50 | if sequence_parallel_size > 1:
51 | _SEQUENCE_PARALLEL_STATE = True
52 | initialize_sequence_parallel_group(sequence_parallel_size)
53 | else:
54 | nccl_info.sp_size = 1
55 | nccl_info.global_rank = int(os.getenv("RANK", "0"))
56 | nccl_info.rank_within_group = 0
57 | nccl_info.group_id = int(os.getenv("RANK", "0"))
58 |
59 | def get_sequence_parallel_state():
60 | return _SEQUENCE_PARALLEL_STATE
61 |
62 | def initialize_sequence_parallel_group(sequence_parallel_size):
63 | """Initialize the sequence parallel group."""
64 | rank = int(os.getenv("RANK", "0"))
65 | world_size = int(os.getenv("WORLD_SIZE", "1"))
66 | assert (
67 | world_size % sequence_parallel_size == 0
68 | ), "world_size must be divisible by sequence_parallel_size, but got world_size: {}, sequence_parallel_size: {}".format(
69 | world_size, sequence_parallel_size)
70 | nccl_info.sp_size = sequence_parallel_size
71 | nccl_info.global_rank = rank
72 | num_sequence_parallel_groups: int = world_size // sequence_parallel_size
73 | for i in range(num_sequence_parallel_groups):
74 | ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
75 | group = dist.new_group(ranks)
76 | if rank in ranks:
77 | nccl_info.group = group
78 | nccl_info.rank_within_group = rank - i * sequence_parallel_size
79 | nccl_info.group_id = i
80 |
81 | def initialize_distributed(seed):
82 | local_rank = int(os.getenv("RANK", 0))
83 | world_size = int(os.getenv("WORLD_SIZE", 1))
84 | torch.cuda.set_device(local_rank)
85 | dist.init_process_group(backend="nccl", init_method="env://", timeout=datetime.timedelta(seconds=2**31-1), world_size=world_size, rank=local_rank)
86 | torch.manual_seed(seed)
87 | torch.cuda.manual_seed_all(seed)
88 | initialize_sequence_parallel_state(world_size)
89 |
90 | def _all_to_all_4D(input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.tensor:
91 | """
92 | all-to-all for QKV
93 |
94 | Args:
95 |         input (torch.tensor): a 4D tensor sharded along the gather dimension
96 |         scatter_idx (int): dimension to scatter along, default 2
97 |         gather_idx (int): dimension to gather along, default 1
98 | group : torch process group
99 |
100 | Returns:
101 | torch.tensor: resharded tensor (bs, seqlen/P, hc, hs)
102 | """
103 | assert (input.dim() == 4), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}"
104 |
105 | seq_world_size = dist.get_world_size(group)
106 | if scatter_idx == 2 and gather_idx == 1:
107 | # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
108 | bs, shard_seqlen, hc, hs = input.shape
109 | seqlen = shard_seqlen * seq_world_size
110 | shard_hc = hc // seq_world_size
111 |
112 | # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
113 | # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs)
114 | input_t = (input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs).transpose(0, 2).contiguous())
115 |
116 | output = torch.empty_like(input_t)
117 | # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
118 | # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head
119 | if seq_world_size > 1:
120 | dist.all_to_all_single(output, input_t, group=group)
121 | torch.cuda.synchronize()
122 | else:
123 | output = input_t
124 | # if scattering the seq-dim, transpose the heads back to the original dimension
125 | output = output.reshape(seqlen, bs, shard_hc, hs)
126 |
127 | # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs)
128 | output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
129 |
130 | return output
131 |
132 | elif scatter_idx == 1 and gather_idx == 2:
133 | # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
134 | bs, seqlen, shard_hc, hs = input.shape
135 | hc = shard_hc * seq_world_size
136 | shard_seqlen = seqlen // seq_world_size
137 | seq_world_size = dist.get_world_size(group)
138 |
139 | # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
140 | # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs)
141 | input_t = (input.reshape(bs, seq_world_size, shard_seqlen, shard_hc,
142 | hs).transpose(0,
143 | 3).transpose(0,
144 | 1).contiguous().reshape(seq_world_size, shard_hc,
145 | shard_seqlen, bs, hs))
146 |
147 | output = torch.empty_like(input_t)
148 | # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
149 | # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
150 | if seq_world_size > 1:
151 | dist.all_to_all_single(output, input_t, group=group)
152 | torch.cuda.synchronize()
153 | else:
154 | output = input_t
155 |
156 | # if scattering the seq-dim, transpose the heads back to the original dimension
157 | output = output.reshape(hc, shard_seqlen, bs, hs)
158 |
159 |         # (hc, seqlen/N, bs, hs) -transpose(0,2)-> (bs, seqlen/N, hc, hs)
160 | output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs)
161 |
162 | return output
163 | else:
164 | raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2")
165 |
166 |
167 | class SeqAllToAll4D(torch.autograd.Function):
168 | @staticmethod
169 | def forward(
170 | ctx: Any,
171 | group: dist.ProcessGroup,
172 | input: Tensor,
173 | scatter_idx: int,
174 | gather_idx: int,
175 | ) -> Tensor:
176 | ctx.group = group
177 | ctx.scatter_idx = scatter_idx
178 | ctx.gather_idx = gather_idx
179 |
180 | return _all_to_all_4D(input, scatter_idx, gather_idx, group=group)
181 |
182 | @staticmethod
183 | def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
184 | return (
185 | None,
186 | SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx),
187 | None,
188 | None,
189 | )
190 |
191 |
192 | def all_to_all_4D(
193 | input_: torch.Tensor,
194 | scatter_dim: int = 2,
195 | gather_dim: int = 1,
196 | ):
197 | return SeqAllToAll4D.apply(nccl_info.group, input_, scatter_dim, gather_dim)
198 |
199 |
200 | def _all_to_all(
201 | input_: torch.Tensor,
202 | world_size: int,
203 | group: dist.ProcessGroup,
204 | scatter_dim: int,
205 | gather_dim: int,
206 | ):
207 | input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
208 | output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
209 | dist.all_to_all(output_list, input_list, group=group)
210 | return torch.cat(output_list, dim=gather_dim).contiguous()
211 |
212 |
213 | class _AllToAll(torch.autograd.Function):
214 | """All-to-all communication.
215 |
216 | Args:
217 | input_: input matrix
218 | process_group: communication group
219 | scatter_dim: scatter dimension
220 | gather_dim: gather dimension
221 | """
222 |
223 | @staticmethod
224 | def forward(ctx, input_, process_group, scatter_dim, gather_dim):
225 | ctx.process_group = process_group
226 | ctx.scatter_dim = scatter_dim
227 | ctx.gather_dim = gather_dim
228 | ctx.world_size = dist.get_world_size(process_group)
229 | output = _all_to_all(input_, ctx.world_size, process_group, scatter_dim, gather_dim)
230 | return output
231 |
232 | @staticmethod
233 | def backward(ctx, grad_output):
234 | grad_output = _all_to_all(
235 | grad_output,
236 | ctx.world_size,
237 | ctx.process_group,
238 | ctx.gather_dim,
239 | ctx.scatter_dim,
240 | )
241 | return (
242 | grad_output,
243 | None,
244 | None,
245 | None,
246 | )
247 |
248 | def all_to_all(
249 | input_: torch.Tensor,
250 | scatter_dim: int = 2,
251 | gather_dim: int = 1,
252 | ):
253 | return _AllToAll.apply(input_, nccl_info.group, scatter_dim, gather_dim)
254 |
255 |
256 | class _AllGather(torch.autograd.Function):
257 | """All-gather communication with autograd support.
258 |
259 | Args:
260 | input_: input tensor
261 | dim: dimension along which to concatenate
262 | """
263 |
264 | @staticmethod
265 | def forward(ctx, input_, dim):
266 | ctx.dim = dim
267 | world_size = nccl_info.sp_size
268 | group = nccl_info.group
269 | input_size = list(input_.size())
270 |
271 | ctx.input_size = input_size[dim]
272 |
273 | tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
274 | input_ = input_.contiguous()
275 | dist.all_gather(tensor_list, input_, group=group)
276 |
277 | output = torch.cat(tensor_list, dim=dim)
278 | return output
279 |
280 | @staticmethod
281 | def backward(ctx, grad_output):
282 | world_size = nccl_info.sp_size
283 | rank = nccl_info.rank_within_group
284 | dim = ctx.dim
285 | input_size = ctx.input_size
286 |
287 | sizes = [input_size] * world_size
288 |
289 | grad_input_list = torch.split(grad_output, sizes, dim=dim)
290 | grad_input = grad_input_list[rank]
291 |
292 | return grad_input, None
293 |
294 |
295 | def all_gather(input_: torch.Tensor, dim: int = 1):
296 | """Performs an all-gather operation on the input tensor along the specified dimension.
297 |
298 | Args:
299 | input_ (torch.Tensor): Input tensor of shape [B, H, S, D].
300 | dim (int, optional): Dimension along which to concatenate. Defaults to 1.
301 |
302 | Returns:
303 | torch.Tensor: Output tensor after all-gather operation, concatenated along 'dim'.
304 | """
305 | return _AllGather.apply(input_, dim)
306 |
307 | def parallel_attention(q, k, v, img_q_len, img_kv_len, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,):
308 | """
309 | img_q_len,img_kv_len: 32256
310 | text_mask: 2x256
311 | query: [2, 32256, 24, 128])
312 | encoder_query: [2, 256, 24, 128]
313 | """
314 | query, encoder_query = q
315 | key, encoder_key = k
316 | value, encoder_value = v
317 | rank = torch.distributed.get_rank()
318 | if get_sequence_parallel_state():
319 | query = all_to_all_4D(query, scatter_dim=2, gather_dim=1) # [2, 32256, 24, 128]
320 | key = all_to_all_4D(key, scatter_dim=2, gather_dim=1)
321 | value = all_to_all_4D(value, scatter_dim=2, gather_dim=1)
322 | def shrink_head(encoder_state, dim):
323 | local_heads = encoder_state.shape[dim] // nccl_info.sp_size
324 | return encoder_state.narrow(dim, nccl_info.rank_within_group * local_heads, local_heads)
325 | encoder_query = shrink_head(encoder_query, dim=2)
326 | encoder_key = shrink_head(encoder_key, dim=2)
327 | encoder_value = shrink_head(encoder_value, dim=2)
328 |
329 | sequence_length = query.size(1) # 32256
330 | encoder_sequence_length = encoder_query.size(1) # 256
331 |
332 | query = torch.cat([query, encoder_query], dim=1)
333 | key = torch.cat([key, encoder_key], dim=1)
334 | value = torch.cat([value, encoder_value], dim=1)
335 | bsz = query.shape[0]
336 | head = query.shape[-2]
337 | head_dim = query.shape[-1]
338 | query, key, value = [
339 | x.view(x.shape[0] * x.shape[1], *x.shape[2:])
340 | for x in [query, key, value]
341 | ]
342 | hidden_states = flash_attn_varlen_func(
343 | query,
344 | key,
345 | value,
346 | cu_seqlens_q,
347 | cu_seqlens_kv,
348 | max_seqlen_q,
349 | max_seqlen_kv,
350 | )
351 |     # attention output reshaped back to (B, S, H, D)
352 | hidden_states = hidden_states.view(bsz, max_seqlen_q, head, head_dim).contiguous()
353 |
354 | hidden_states, encoder_hidden_states = hidden_states.split_with_sizes((sequence_length, encoder_sequence_length),
355 | dim=1)
356 | if get_sequence_parallel_state():
357 | hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2)
358 | encoder_hidden_states = all_gather(encoder_hidden_states, dim=2).contiguous()
359 | hidden_states = hidden_states.to(query.dtype)
360 | encoder_hidden_states = encoder_hidden_states.to(query.dtype)
361 |
362 | attn = torch.cat([hidden_states, encoder_hidden_states], dim=1)
363 |
364 |     b, s, _, _ = attn.shape
365 | attn = attn.reshape(b, s, -1)
366 | return attn, None
--------------------------------------------------------------------------------
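The `_AllGather` autograd function above gathers each rank's shard in `forward` and hands each rank back only its own gradient slice in `backward`. Below is a minimal, single-process sketch of that gather/split pattern; the `gloo` backend, world size of 1, and toy shapes are assumptions chosen so it runs on CPU without the repo's multi-GPU NCCL setup.

```python
import os
import torch
import torch.distributed as dist


def demo():
    # Single-process process group; the repo itself runs across GPUs with NCCL.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29512")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    world_size = dist.get_world_size()
    x = torch.randn(2, 4, 8)                                  # local shard, e.g. [B, S_local, D]

    # forward: gather every rank's shard, then concatenate along the sharded dim
    shards = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(shards, x.contiguous())
    full = torch.cat(shards, dim=1)                           # [B, S_total, D]

    # backward (as in _AllGather.backward): keep only this rank's slice of the gradient
    grad_full = torch.ones_like(full)
    grad_local = torch.split(grad_full, [x.shape[1]] * world_size, dim=1)[dist.get_rank()]

    print(full.shape, grad_local.shape)                       # both torch.Size([2, 4, 8]) with world_size=1
    dist.destroy_process_group()


if __name__ == "__main__":
    demo()
```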
/hymm_sp/modules/posemb_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Union, Tuple, List
3 |
4 |
5 | def _to_tuple(x, dim=2):
6 | if isinstance(x, int):
7 | return (x,) * dim
8 | elif len(x) == dim:
9 | return x
10 | else:
11 | raise ValueError(f"Expected length {dim} or int, but got {x}")
12 |
13 |
14 | def get_meshgrid_nd(start, *args, dim=2):
15 | """
16 | Get n-D meshgrid with start, stop and num.
17 |
18 | Args:
19 | start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
20 | step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
21 | should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
22 | n-tuples.
23 | *args: See above.
24 | dim (int): Dimension of the meshgrid. Defaults to 2.
25 |
26 | Returns:
27 | grid (np.ndarray): [dim, ...]
28 | """
29 | if len(args) == 0:
30 | # start is grid_size
31 | num = _to_tuple(start, dim=dim)
32 | start = (0,) * dim
33 | stop = num
34 | elif len(args) == 1:
35 | # start is start, args[0] is stop, step is 1
36 | start = _to_tuple(start, dim=dim)
37 | stop = _to_tuple(args[0], dim=dim)
38 | num = [stop[i] - start[i] for i in range(dim)]
39 | elif len(args) == 2:
40 | # start is start, args[0] is stop, args[1] is num
41 | start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
42 | stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
43 | num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
44 | else:
45 | raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
46 |
47 | # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
48 | axis_grid = []
49 | for i in range(dim):
50 | a, b, n = start[i], stop[i], num[i]
51 | g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
52 | axis_grid.append(g)
53 | grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
54 | grid = torch.stack(grid, dim=0) # [dim, W, H, D]
55 |
56 | return grid
57 |
58 |
59 | #################################################################################
60 | # Rotary Positional Embedding Functions #
61 | #################################################################################
62 | # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
63 |
64 | def get_nd_rotary_pos_embed(rope_dim_list, start, *args, theta=10000., use_real=False,
65 | theta_rescale_factor: Union[float, List[float]]=1.0,
66 | interpolation_factor: Union[float, List[float]]=1.0):
67 | """
68 | This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
69 |
70 | Args:
71 |         rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal n.
72 |             sum(rope_dim_list) should equal the head_dim of the attention layer.
73 | start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
74 | args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
75 | *args: See above.
76 | theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
77 | use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
78 |             Some libraries, such as TensorRT, do not support the complex64 data type, so it is useful to provide the
79 |             real part and the imaginary part separately.
80 | theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
81 |
82 | Returns:
83 | pos_embed (torch.Tensor): [HW, D/2]
84 | """
85 |
86 | grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
87 |
88 | if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
89 | theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
90 | elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
91 | theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
92 | assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
93 |
94 | if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
95 | interpolation_factor = [interpolation_factor] * len(rope_dim_list)
96 | elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
97 | interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
98 | assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)"
99 |
100 | # use 1/ndim of dimensions to encode grid_axis
101 | embs = []
102 | for i in range(len(rope_dim_list)):
103 | emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=use_real,
104 | theta_rescale_factor=theta_rescale_factor[i],
105 | interpolation_factor=interpolation_factor[i]) # 2 x [WHD, rope_dim_list[i]]
106 | embs.append(emb)
107 |
108 | if use_real:
109 | cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
110 | sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
111 | return cos, sin
112 | else:
113 | emb = torch.cat(embs, dim=1) # (WHD, D/2)
114 | return emb
115 |
116 |
117 | def get_1d_rotary_pos_embed(dim: int,
118 | pos: Union[torch.FloatTensor, int],
119 | theta: float = 10000.0,
120 | use_real: bool = False,
121 | theta_rescale_factor: float = 1.0,
122 | interpolation_factor: float = 1.0,
123 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
124 | """
125 | Precompute the frequency tensor for complex exponential (cis) with given dimensions.
126 | (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
127 |
128 | This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
129 |     and the position indices 'pos'. The 'theta' parameter scales the frequencies.
130 | The returned tensor contains complex values in complex64 data type.
131 |
132 | Args:
133 | dim (int): Dimension of the frequency tensor.
134 | pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
135 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
136 | use_real (bool, optional): If True, return real part and imaginary part separately.
137 | Otherwise, return complex numbers.
138 | theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
139 |
140 | Returns:
141 | freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
142 | freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
143 | """
144 | if isinstance(pos, int):
145 | pos = torch.arange(pos).float()
146 |
147 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
148 | # has some connection to NTK literature
149 | if theta_rescale_factor != 1.0:
150 | theta *= theta_rescale_factor ** (dim / (dim - 2))
151 |
152 | freqs = 1.0 / (
153 | theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
154 | ) # [D/2]
155 | freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
156 | if use_real:
157 | freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
158 | freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
159 | return freqs_cos, freqs_sin
160 | else:
161 | freqs_cis = torch.polar(
162 | torch.ones_like(freqs), freqs
163 | ) # complex64 # [S, D/2]
164 | return freqs_cis
165 |
--------------------------------------------------------------------------------
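`get_1d_rotary_pos_embed` with `use_real=True` boils down to an outer product of positions and per-pair frequencies followed by `cos`/`sin`. A self-contained sketch of that computation with toy sizes (it deliberately re-implements the formula rather than importing the module, so it runs with only `torch` installed):

```python
import torch

dim, num_positions, theta = 8, 4, 10000.0            # toy head_dim slice and sequence length
pos = torch.arange(num_positions).float()

# per-pair frequencies, then an outer product with the positions, as in the function above
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))   # [dim/2]
freqs = torch.outer(pos, freqs)                                                # [S, dim/2]

cos = freqs.cos().repeat_interleave(2, dim=1)         # [S, dim]
sin = freqs.sin().repeat_interleave(2, dim=1)         # [S, dim]
print(cos.shape, sin.shape)                           # torch.Size([4, 8]) torch.Size([4, 8])
```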
/hymm_sp/modules/token_refiner.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from einops import rearrange
4 | import torch
5 | import torch.nn as nn
6 |
7 | from .activation_layers import get_activation_layer
8 | from .attn_layers import attention
9 | from .norm_layers import get_norm_layer
10 | from .embed_layers import TimestepEmbedder, TextProjection
11 | 
12 | from .mlp_layers import MLP
13 | from .modulate_layers import apply_gate
14 |
15 |
16 | class IndividualTokenRefinerBlock(nn.Module):
17 | def __init__(
18 | self,
19 | hidden_size,
20 | num_heads,
21 |         mlp_ratio: float = 4.0,
22 | mlp_drop_rate: float = 0.0,
23 | act_type: str = "silu",
24 | qk_norm: bool = False,
25 | qk_norm_type: str = "layer",
26 | qkv_bias: bool = True,
27 | dtype: Optional[torch.dtype] = None,
28 | device: Optional[torch.device] = None,
29 | ):
30 | factory_kwargs = {'device': device, 'dtype': dtype}
31 | super().__init__()
32 | self.num_heads = num_heads
33 | head_dim = hidden_size // num_heads
34 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
35 |
36 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
37 | self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
38 | qk_norm_layer = get_norm_layer(qk_norm_type)
39 | self.self_attn_q_norm = (
40 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
41 | if qk_norm
42 | else nn.Identity()
43 | )
44 | self.self_attn_k_norm = (
45 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
46 | if qk_norm
47 | else nn.Identity()
48 | )
49 | self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
50 |
51 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
52 | act_layer = get_activation_layer(act_type)
53 | self.mlp = MLP(
54 | in_channels=hidden_size,
55 | hidden_channels=mlp_hidden_dim,
56 | act_layer=act_layer,
57 | drop=mlp_drop_rate,
58 | **factory_kwargs,
59 | )
60 |
61 | self.adaLN_modulation = nn.Sequential(
62 | act_layer(),
63 | nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
64 | )
65 | # Zero-initialize the modulation
66 | nn.init.zeros_(self.adaLN_modulation[1].weight)
67 | nn.init.zeros_(self.adaLN_modulation[1].bias)
68 |
69 | def forward(
70 | self,
71 | x: torch.Tensor,
72 | c: torch.Tensor, # timestep_aware_representations + context_aware_representations
73 | attn_mask: torch.Tensor = None,
74 | ):
75 | gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
76 |
77 | norm_x = self.norm1(x)
78 | qkv = self.self_attn_qkv(norm_x)
79 | q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
80 | # Apply QK-Norm if needed
81 | q = self.self_attn_q_norm(q).to(v)
82 | k = self.self_attn_k_norm(k).to(v)
83 |
84 | # Self-Attention
85 | attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
86 |
87 | x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
88 |
89 | # FFN Layer
90 | x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
91 |
92 | return x
93 |
94 |
95 | class IndividualTokenRefiner(nn.Module):
96 | def __init__(
97 | self,
98 | hidden_size,
99 | num_heads,
100 | depth,
101 | mlp_ratio: float = 4.0,
102 | mlp_drop_rate: float = 0.0,
103 | act_type: str = "silu",
104 | qk_norm: bool = False,
105 | qk_norm_type: str = "layer",
106 | qkv_bias: bool = True,
107 | dtype: Optional[torch.dtype] = None,
108 | device: Optional[torch.device] = None,
109 | ):
110 | factory_kwargs = {'device': device, 'dtype': dtype}
111 | super().__init__()
112 | self.blocks = nn.ModuleList([
113 | IndividualTokenRefinerBlock(
114 | hidden_size=hidden_size,
115 | num_heads=num_heads,
116 | mlp_ratio=mlp_ratio,
117 | mlp_drop_rate=mlp_drop_rate,
118 | act_type=act_type,
119 | qk_norm=qk_norm,
120 | qk_norm_type=qk_norm_type,
121 | qkv_bias=qkv_bias,
122 | **factory_kwargs,
123 | ) for _ in range(depth)
124 | ])
125 |
126 | def forward(
127 | self,
128 | x: torch.Tensor,
129 | c: torch.LongTensor,
130 | mask: Optional[torch.Tensor] = None,
131 | ):
132 | self_attn_mask = None
133 | if mask is not None:
134 | batch_size = mask.shape[0]
135 | seq_len = mask.shape[1]
136 | mask = mask.to(x.device)
137 | # batch_size x 1 x seq_len x seq_len
138 | self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
139 | # batch_size x 1 x seq_len x seq_len
140 | self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
141 | # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads
142 | self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
143 | # avoids self-attention weight being NaN for padding tokens
144 | self_attn_mask[:, :, :, 0] = True
145 |
146 | for block in self.blocks:
147 | x = block(x, c, self_attn_mask)
148 | return x
149 |
150 |
151 | class SingleTokenRefiner(nn.Module):
152 | def __init__(
153 | self,
154 | in_channels,
155 | hidden_size,
156 | num_heads,
157 | depth,
158 | mlp_ratio: float = 4.0,
159 | mlp_drop_rate: float = 0.0,
160 | act_type: str = "silu",
161 | qk_norm: bool = False,
162 | qk_norm_type: str = "layer",
163 | qkv_bias: bool = True,
164 | dtype: Optional[torch.dtype] = None,
165 | device: Optional[torch.device] = None,
166 | ):
167 | factory_kwargs = {'device': device, 'dtype': dtype}
168 | super().__init__()
169 |
170 | self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs)
171 |
172 | act_layer = get_activation_layer(act_type)
173 | # Build timestep embedding layer
174 | self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
175 | # Build context embedding layer
176 | self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs)
177 |
178 | self.individual_token_refiner = IndividualTokenRefiner(
179 | hidden_size=hidden_size,
180 | num_heads=num_heads,
181 | depth=depth,
182 | mlp_ratio=mlp_ratio,
183 | mlp_drop_rate=mlp_drop_rate,
184 | act_type=act_type,
185 | qk_norm=qk_norm,
186 | qk_norm_type=qk_norm_type,
187 | qkv_bias=qkv_bias,
188 | **factory_kwargs
189 | )
190 |
191 | def forward(
192 | self,
193 | x: torch.Tensor,
194 | t: torch.LongTensor,
195 | mask: Optional[torch.LongTensor] = None,
196 | ):
197 | timestep_aware_representations = self.t_embedder(t)
198 |
199 | if mask is None:
200 | context_aware_representations = x.mean(dim=1)
201 | else:
202 | mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
203 | context_aware_representations = (
204 | (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
205 | )
206 | context_aware_representations = self.c_embedder(context_aware_representations)
207 | c = timestep_aware_representations + context_aware_representations
208 |
209 | x = self.input_embedder(x)
210 |
211 | x = self.individual_token_refiner(x, c, mask)
212 |
213 | return x
--------------------------------------------------------------------------------
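The padding-mask expansion in `IndividualTokenRefiner.forward` turns a per-token mask into a pairwise attention mask and pins the first key column so fully padded queries never attend to nothing. A toy illustration with a length-4 sequence whose last token is padding (values are illustrative; no model weights needed):

```python
import torch

mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.bool)    # [batch, seq_len], last token is padding

m1 = mask.view(1, 1, 1, 4).repeat(1, 1, 4, 1)            # [batch, 1, seq_len, seq_len]
m2 = m1.transpose(2, 3)
self_attn_mask = m1 & m2                                  # True only where both query and key are real tokens
self_attn_mask[:, :, :, 0] = True                         # keep the first key so padded rows are not all-False

print(self_attn_mask[0, 0].int())
# tensor([[1, 1, 1, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 0],
#         [1, 0, 0, 0]], dtype=torch.int32)
```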
/hymm_sp/sample_batch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from pathlib import Path
5 | from loguru import logger
6 | from einops import rearrange
7 | import imageio
8 | import torch.distributed
9 | from torch.utils.data.distributed import DistributedSampler
10 | from torch.utils.data import DataLoader
11 | from hymm_sp.config import parse_args
12 | from hymm_sp.sample_inference_audio import HunyuanVideoSampler
13 | from hymm_sp.data_kits.audio_dataset import VideoAudioTextLoaderVal
14 | from hymm_sp.data_kits.data_tools import save_videos_grid
15 | from hymm_sp.data_kits.face_align import AlignImage
16 | from hymm_sp.modules.parallel_states import (
17 | initialize_distributed,
18 | nccl_info,
19 | )
20 |
21 | from transformers import WhisperModel
22 | from transformers import AutoFeatureExtractor
23 |
24 | MODEL_OUTPUT_PATH = os.environ.get('MODEL_BASE')
25 |
26 |
27 | def main():
28 | args = parse_args()
29 | models_root_path = Path(args.ckpt)
30 | print("*"*20)
31 | initialize_distributed(args.seed)
32 | if not models_root_path.exists():
33 |         raise ValueError(f"`models_root` does not exist: {models_root_path}")
34 | print("+"*20)
35 | # Create save folder to save the samples
36 | save_path = args.save_path
37 | if not os.path.exists(args.save_path):
38 | os.makedirs(save_path, exist_ok=True)
39 |
40 | # Load models
41 | rank = 0
42 | vae_dtype = torch.float16
43 | device = torch.device("cuda")
44 | if nccl_info.sp_size > 1:
45 | device = torch.device(f"cuda:{torch.distributed.get_rank()}")
46 | rank = torch.distributed.get_rank()
47 |
48 | hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(args.ckpt, args=args, device=device)
49 | # Get the updated args
50 | args = hunyuan_video_sampler.args
51 |
52 | wav2vec = WhisperModel.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/").to(device=device, dtype=torch.float32)
53 | wav2vec.requires_grad_(False)
54 |
55 | BASE_DIR = f'{MODEL_OUTPUT_PATH}/ckpts/det_align/'
56 | det_path = os.path.join(BASE_DIR, 'detface.pt')
57 | align_instance = AlignImage("cuda", det_path=det_path)
58 |
59 | feature_extractor = AutoFeatureExtractor.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/")
60 |
61 | kwargs = {
62 | "text_encoder": hunyuan_video_sampler.text_encoder,
63 | "text_encoder_2": hunyuan_video_sampler.text_encoder_2,
64 | "feature_extractor": feature_extractor,
65 | }
66 | video_dataset = VideoAudioTextLoaderVal(
67 | image_size=args.image_size,
68 | meta_file=args.input,
69 | **kwargs,
70 | )
71 |
72 | sampler = DistributedSampler(video_dataset, num_replicas=1, rank=0, shuffle=False, drop_last=False)
73 | json_loader = DataLoader(video_dataset, batch_size=1, shuffle=False, sampler=sampler, drop_last=False)
74 |
75 | for batch_index, batch in enumerate(json_loader, start=1):
76 |
77 | fps = batch["fps"]
78 | videoid = batch['videoid'][0]
79 | audio_path = str(batch["audio_path"][0])
80 | save_path = args.save_path
81 | output_path = f"{save_path}/{videoid}.mp4"
82 | output_audio_path = f"{save_path}/{videoid}_audio.mp4"
83 |
84 | samples = hunyuan_video_sampler.predict(args, batch, wav2vec, feature_extractor, align_instance)
85 |
86 |         sample = samples['samples'][0].unsqueeze(0)  # decoded video frames in [0, 1], (b, c, f, h, w)
87 | sample = sample[:, :, :batch["audio_len"][0]]
88 |
89 | video = rearrange(sample[0], "c f h w -> f h w c")
90 | video = (video * 255.).data.cpu().numpy().astype(np.uint8) # (f h w c)
91 |
92 | torch.cuda.empty_cache()
93 |
94 | final_frames = []
95 | for frame in video:
96 | final_frames.append(frame)
97 | final_frames = np.stack(final_frames, axis=0)
98 |
99 | if rank == 0:
100 | imageio.mimsave(output_path, final_frames, fps=fps.item())
101 | os.system(f"ffmpeg -i '{output_path}' -i '{audio_path}' -shortest '{output_audio_path}' -y -loglevel quiet; rm '{output_path}'")
102 |
103 |
104 |
105 |
106 | if __name__ == "__main__":
107 | main()
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
--------------------------------------------------------------------------------
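The tail of the loop above rearranges the sampled tensor to frame-major layout, scales it to `uint8`, and writes it with `imageio` before muxing in the audio with ffmpeg. A toy version of just that export path (random data and the file name `toy.mp4` are illustrative; it needs the `imageio-ffmpeg` backend listed in requirements.txt):

```python
import numpy as np
import imageio
import torch
from einops import rearrange

video = torch.rand(3, 8, 64, 64)                      # (c, f, h, w) toy tensor with values in [0, 1]
frames = rearrange(video, "c f h w -> f h w c")       # frame-major, channels last
frames = (frames * 255.).cpu().numpy().astype(np.uint8)
imageio.mimsave("toy.mp4", frames, fps=25)            # same call pattern as the script above
```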
/hymm_sp/sample_gpu_poor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from pathlib import Path
4 | from loguru import logger
5 | import imageio
6 | import torch
7 | from einops import rearrange
8 | import torch.distributed
9 | from torch.utils.data.distributed import DistributedSampler
10 | from torch.utils.data import DataLoader
11 | from hymm_sp.config import parse_args
12 | from hymm_sp.sample_inference_audio import HunyuanVideoSampler
13 | from hymm_sp.data_kits.audio_dataset import VideoAudioTextLoaderVal
14 | from hymm_sp.data_kits.face_align import AlignImage
15 |
16 | from transformers import WhisperModel
17 | from transformers import AutoFeatureExtractor
18 |
19 | MODEL_OUTPUT_PATH = os.environ.get('MODEL_BASE')
20 |
21 |
22 | def main():
23 | args = parse_args()
24 | models_root_path = Path(args.ckpt)
25 |
26 | if not models_root_path.exists():
27 |         raise ValueError(f"`models_root` does not exist: {models_root_path}")
28 |
29 | # Create save folder to save the samples
30 | save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}'
31 |     if not os.path.exists(save_path):
32 |         os.makedirs(save_path, exist_ok=True)
33 |
34 | # Load models
35 | rank = 0
36 | vae_dtype = torch.float16
37 | device = torch.device("cuda")
38 |
39 | hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(args.ckpt, args=args, device=device)
40 | # Get the updated args
41 | args = hunyuan_video_sampler.args
42 | if args.cpu_offload:
43 | from diffusers.hooks import apply_group_offloading
44 | onload_device = torch.device("cuda")
45 | apply_group_offloading(hunyuan_video_sampler.pipeline.transformer, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=1)
46 |
47 | wav2vec = WhisperModel.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/").to(device=device, dtype=torch.float32)
48 | wav2vec.requires_grad_(False)
49 |
50 | BASE_DIR = f'{MODEL_OUTPUT_PATH}/ckpts/det_align/'
51 | det_path = os.path.join(BASE_DIR, 'detface.pt')
52 | align_instance = AlignImage("cuda", det_path=det_path)
53 |
54 | feature_extractor = AutoFeatureExtractor.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/")
55 |
56 | kwargs = {
57 | "text_encoder": hunyuan_video_sampler.text_encoder,
58 | "text_encoder_2": hunyuan_video_sampler.text_encoder_2,
59 | "feature_extractor": feature_extractor,
60 | }
61 | video_dataset = VideoAudioTextLoaderVal(
62 | image_size=args.image_size,
63 | meta_file=args.input,
64 | **kwargs,
65 | )
66 |
67 | sampler = DistributedSampler(video_dataset, num_replicas=1, rank=0, shuffle=False, drop_last=False)
68 | json_loader = DataLoader(video_dataset, batch_size=1, shuffle=False, sampler=sampler, drop_last=False)
69 |
70 | for batch_index, batch in enumerate(json_loader, start=1):
71 |
72 | fps = batch["fps"]
73 | videoid = batch['videoid'][0]
74 | audio_path = str(batch["audio_path"][0])
75 |         # save_path (including args.save_path_suffix, if any) was resolved above
76 | output_path = f"{save_path}/{videoid}.mp4"
77 | output_audio_path = f"{save_path}/{videoid}_audio.mp4"
78 |
79 | if args.infer_min:
80 | batch["audio_len"][0] = 129
81 |
82 | samples = hunyuan_video_sampler.predict(args, batch, wav2vec, feature_extractor, align_instance)
83 |
84 |         sample = samples['samples'][0].unsqueeze(0)  # decoded video frames in [0, 1], (b, c, f, h, w)
85 | sample = sample[:, :, :batch["audio_len"][0]]
86 |
87 | video = rearrange(sample[0], "c f h w -> f h w c")
88 | video = (video * 255.).data.cpu().numpy().astype(np.uint8) # (f h w c)
89 |
90 | torch.cuda.empty_cache()
91 |
92 | final_frames = []
93 | for frame in video:
94 | final_frames.append(frame)
95 | final_frames = np.stack(final_frames, axis=0)
96 |
97 | if rank == 0:
98 | imageio.mimsave(output_path, final_frames, fps=fps.item())
99 | os.system(f"ffmpeg -i '{output_path}' -i '{audio_path}' -shortest '{output_audio_path}' -y -loglevel quiet; rm '{output_path}'")
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 | if __name__ == "__main__":
108 | main()
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
/hymm_sp/sample_inference_audio.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import torch
4 | import random
5 | from loguru import logger
6 | from einops import rearrange
7 | from hymm_sp.diffusion import load_diffusion_pipeline
8 | from hymm_sp.helpers import get_nd_rotary_pos_embed_new
9 | from hymm_sp.inference import Inference
10 | from hymm_sp.diffusion.schedulers import FlowMatchDiscreteScheduler
11 | from hymm_sp.data_kits.audio_preprocessor import encode_audio, get_facemask
12 |
13 | def align_to(value, alignment):
14 | return int(math.ceil(value / alignment) * alignment)
15 |
16 | class HunyuanVideoSampler(Inference):
17 | def __init__(self, args, vae, vae_kwargs, text_encoder, model, text_encoder_2=None, pipeline=None,
18 | device=0, logger=None):
19 | super().__init__(args, vae, vae_kwargs, text_encoder, model, text_encoder_2=text_encoder_2,
20 | pipeline=pipeline, device=device, logger=logger)
21 |
22 | self.args = args
23 | self.pipeline = load_diffusion_pipeline(
24 | args, 0, self.vae, self.text_encoder, self.text_encoder_2, self.model,
25 | device=self.device)
26 |         print('Loaded Hunyuan video model successfully.')
27 |
28 | def get_rotary_pos_embed(self, video_length, height, width, concat_dict={}):
29 | target_ndim = 3
30 | ndim = 5 - 2
31 | if '884' in self.args.vae:
32 | latents_size = [(video_length-1)//4+1 , height//8, width//8]
33 | else:
34 | latents_size = [video_length , height//8, width//8]
35 |
36 | if isinstance(self.model.patch_size, int):
37 | assert all(s % self.model.patch_size == 0 for s in latents_size), \
38 | f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
39 | f"but got {latents_size}."
40 | rope_sizes = [s // self.model.patch_size for s in latents_size]
41 | elif isinstance(self.model.patch_size, list):
42 | assert all(s % self.model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
43 | f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
44 | f"but got {latents_size}."
45 | rope_sizes = [s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)]
46 |
47 | if len(rope_sizes) != target_ndim:
48 | rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
49 | head_dim = self.model.hidden_size // self.model.num_heads
50 | rope_dim_list = self.model.rope_dim_list
51 | if rope_dim_list is None:
52 | rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
53 | assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
54 | freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
55 | rope_sizes,
56 | theta=self.args.rope_theta,
57 | use_real=True,
58 | theta_rescale_factor=1,
59 | concat_dict=concat_dict)
60 | return freqs_cos, freqs_sin
61 |
62 | @torch.no_grad()
63 | def predict(self,
64 | args, batch, wav2vec, feature_extractor, align_instance,
65 | **kwargs):
66 |         """
67 |         Generate an audio-driven video clip for a single dataloader batch.
68 | 
69 |         Args:
70 |             args: Parsed command-line arguments (seed, cfg_scale, infer_steps, num_images, vae settings, ...).
71 |             batch (dict): One item from VideoAudioTextLoaderVal, providing the text prompt, reference image,
72 |                 audio features, fps, audio length and motion bucket ids.
73 |             wav2vec (WhisperModel): Audio encoder used by encode_audio to build per-frame audio prompts.
74 |             feature_extractor (AutoFeatureExtractor): Whisper feature extractor matching the audio encoder.
75 |             align_instance (AlignImage): Face detector/aligner used by get_facemask to build face masks.
76 |             kwargs: Reserved for additional pipeline overrides.
77 | 
78 |         Returns:
79 |             dict: A dictionary with key 'samples' holding the generated video samples (frames in [0, 1]).
80 |                 Returns None if the pipeline produced no samples.
81 | 
82 |         """
83 |
84 | out_dict = dict()
85 |
86 | prompt = batch['text_prompt'][0]
87 | image_path = str(batch["image_path"][0])
88 | audio_path = str(batch["audio_path"][0])
89 | neg_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, Lens changes"
90 | # videoid = batch['videoid'][0]
91 | fps = batch["fps"].to(self.device)
92 | audio_prompts = batch["audio_prompts"].to(self.device)
93 | weight_dtype = audio_prompts.dtype
94 |
95 | audio_prompts = [encode_audio(wav2vec, audio_feat.to(dtype=wav2vec.dtype), fps.item(), num_frames=batch["audio_len"][0]) for audio_feat in audio_prompts]
96 | audio_prompts = torch.cat(audio_prompts, dim=0).to(device=self.device, dtype=weight_dtype)
97 | if audio_prompts.shape[1] <= 129:
98 | audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1,129-audio_prompts.shape[1], 1, 1, 1)], dim=1)
99 | else:
100 | audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1, 5, 1, 1, 1)], dim=1)
101 |
102 | wav2vec.to("cpu")
103 | torch.cuda.empty_cache()
104 |
105 | uncond_audio_prompts = torch.zeros_like(audio_prompts[:,:129])
106 | motion_exp = batch["motion_bucket_id_exps"].to(self.device)
107 | motion_pose = batch["motion_bucket_id_heads"].to(self.device)
108 |
109 |         pixel_value_ref = batch['pixel_value_ref'].to(self.device)  # (b, f, c, h, w), values in [0, 255]
110 | face_masks = get_facemask(pixel_value_ref.clone(), align_instance, area=3.0)
111 |
112 | pixel_value_ref = pixel_value_ref.clone().repeat(1,129,1,1,1)
113 | uncond_pixel_value_ref = torch.zeros_like(pixel_value_ref)
114 | pixel_value_ref = pixel_value_ref / 127.5 - 1.
115 | uncond_pixel_value_ref = uncond_pixel_value_ref * 2 - 1
116 |
117 | pixel_value_ref_for_vae = rearrange(pixel_value_ref, "b f c h w -> b c f h w")
118 | uncond_uncond_pixel_value_ref = rearrange(uncond_pixel_value_ref, "b f c h w -> b c f h w")
119 |
120 | pixel_value_llava = batch["pixel_value_ref_llava"].to(self.device)
121 | pixel_value_llava = rearrange(pixel_value_llava, "b f c h w -> (b f) c h w")
122 | uncond_pixel_value_llava = pixel_value_llava.clone()
123 |
124 | # ========== Encode reference latents ==========
125 | vae_dtype = self.vae.dtype
126 | with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_dtype != torch.float32):
127 |
128 | if args.cpu_offload:
129 | self.vae.to('cuda')
130 |
131 | self.vae.enable_tiling()
132 | ref_latents = self.vae.encode(pixel_value_ref_for_vae.clone()).latent_dist.sample()
133 | uncond_ref_latents = self.vae.encode(uncond_uncond_pixel_value_ref).latent_dist.sample()
134 | self.vae.disable_tiling()
135 | if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor:
136 | ref_latents.sub_(self.vae.config.shift_factor).mul_(self.vae.config.scaling_factor)
137 | uncond_ref_latents.sub_(self.vae.config.shift_factor).mul_(self.vae.config.scaling_factor)
138 | else:
139 | ref_latents.mul_(self.vae.config.scaling_factor)
140 | uncond_ref_latents.mul_(self.vae.config.scaling_factor)
141 |
142 | if args.cpu_offload:
143 | self.vae.to('cpu')
144 | torch.cuda.empty_cache()
145 |
146 | face_masks = torch.nn.functional.interpolate(face_masks.float().squeeze(2),
147 | (ref_latents.shape[-2],
148 | ref_latents.shape[-1]),
149 | mode="bilinear").unsqueeze(2).to(dtype=ref_latents.dtype)
150 |
151 |
152 | size = (batch['pixel_value_ref'].shape[-2], batch['pixel_value_ref'].shape[-1])
153 | target_length = 129
154 | target_height = align_to(size[0], 16)
155 | target_width = align_to(size[1], 16)
156 | concat_dict = {'mode': 'timecat', 'bias': -1}
157 | # concat_dict = {}
158 | freqs_cos, freqs_sin = self.get_rotary_pos_embed(
159 | target_length,
160 | target_height,
161 | target_width,
162 | concat_dict)
163 | n_tokens = freqs_cos.shape[0]
164 |
165 | generator = torch.Generator(device=self.device).manual_seed(args.seed)
166 |
167 | debug_str = f"""
168 | prompt: {prompt}
169 | image_path: {image_path}
170 | audio_path: {audio_path}
171 | negative_prompt: {neg_prompt}
172 | seed: {args.seed}
173 | fps: {fps.item()}
174 | infer_steps: {args.infer_steps}
175 | target_height: {target_height}
176 | target_width: {target_width}
177 | target_length: {target_length}
178 | guidance_scale: {args.cfg_scale}
179 | """
180 | self.logger.info(debug_str)
181 | pipeline_kwargs = {
182 | "cpu_offload": args.cpu_offload
183 | }
184 | start_time = time.time()
185 | samples = self.pipeline(prompt=prompt,
186 | height=target_height,
187 | width=target_width,
188 | frame=target_length,
189 | num_inference_steps=args.infer_steps,
190 | guidance_scale=args.cfg_scale, # cfg scale
191 |
192 | negative_prompt=neg_prompt,
193 | num_images_per_prompt=args.num_images,
194 | generator=generator,
195 | prompt_embeds=None,
196 |
197 | ref_latents=ref_latents, # [1, 16, 1, h//8, w//8]
198 | uncond_ref_latents=uncond_ref_latents,
199 | pixel_value_llava=pixel_value_llava, # [1, 3, 336, 336]
200 | uncond_pixel_value_llava=uncond_pixel_value_llava,
201 | face_masks=face_masks, # [b f h w]
202 | audio_prompts=audio_prompts,
203 | uncond_audio_prompts=uncond_audio_prompts,
204 | motion_exp=motion_exp,
205 | motion_pose=motion_pose,
206 | fps=fps,
207 |
208 | num_videos_per_prompt=1,
209 | attention_mask=None,
210 | negative_prompt_embeds=None,
211 | negative_attention_mask=None,
212 | output_type="pil",
213 | freqs_cis=(freqs_cos, freqs_sin),
214 | n_tokens=n_tokens,
215 | data_type='video',
216 | is_progress_bar=True,
217 | vae_ver=self.args.vae,
218 | enable_tiling=self.args.vae_tiling,
219 | **pipeline_kwargs
220 | )[0]
221 | if samples is None:
222 | return None
223 | out_dict['samples'] = samples
224 | gen_time = time.time() - start_time
225 | logger.info(f"Success, time: {gen_time}")
226 |
227 | wav2vec.to(self.device)
228 |
229 | return out_dict
230 |
231 |
--------------------------------------------------------------------------------
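`predict` aligns the reference image's height and width to multiples of 16 via `align_to` before building the rotary embeddings. A quick worked example of that helper (copied here so it runs standalone):

```python
import math


def align_to(value, alignment):
    return int(math.ceil(value / alignment) * alignment)


print(align_to(704, 16))   # 704: already a multiple of 16, unchanged
print(align_to(710, 16))   # 720: rounded up to the next multiple of 16
```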
/hymm_sp/text_encoder/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/text_encoder/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/vae/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pathlib import Path
3 | from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
4 | from ..constants import VAE_PATH, PRECISION_TO_TYPE
5 |
6 | def load_vae(vae_type,
7 | vae_precision=None,
8 | sample_size=None,
9 | vae_path=None,
10 | logger=None,
11 | device=None
12 | ):
13 | if vae_path is None:
14 | vae_path = VAE_PATH[vae_type]
15 | vae_compress_spec, _, _ = vae_type.split("-")
16 | length = len(vae_compress_spec)
17 | if length == 3:
18 | if logger is not None:
19 | logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")
20 | config = AutoencoderKLCausal3D.load_config(vae_path)
21 | if sample_size:
22 | vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
23 | else:
24 | vae = AutoencoderKLCausal3D.from_config(config)
25 | ckpt = torch.load(Path(vae_path) / "pytorch_model.pt", map_location=vae.device)
26 | if "state_dict" in ckpt:
27 | ckpt = ckpt["state_dict"]
28 | # vae_ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
29 | vae_ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items()}
30 | vae.load_state_dict(vae_ckpt)
31 |
32 | spatial_compression_ratio = vae.config.spatial_compression_ratio
33 | time_compression_ratio = vae.config.time_compression_ratio
34 | else:
35 | raise ValueError(f"Invalid VAE model: {vae_type}. Must be 3D VAE in the format of '???-*'.")
36 |
37 | if vae_precision is not None:
38 | vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision])
39 |
40 | vae.requires_grad_(False)
41 |
42 | if logger is not None:
43 | logger.info(f"VAE to dtype: {vae.dtype}")
44 |
45 | if device is not None:
46 | vae = vae.to(device)
47 |
48 |     # Set vae to eval mode, even though its dropout rate is 0.
49 | vae.eval()
50 |
51 | return vae, vae_path, spatial_compression_ratio, time_compression_ratio
52 |
--------------------------------------------------------------------------------
/hymm_sp/vae/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/vae/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/vae/__pycache__/autoencoder_kl_causal_3d.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/vae/__pycache__/autoencoder_kl_causal_3d.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/vae/__pycache__/unet_causal_3d_blocks.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/vae/__pycache__/unet_causal_3d_blocks.cpython-310.pyc
--------------------------------------------------------------------------------
/hymm_sp/vae/__pycache__/vae.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/vae/__pycache__/vae.cpython-310.pyc
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python==4.9.0.80
2 | diffusers==0.33.0
3 | transformers==4.45.1
4 | accelerate==1.1.1
5 | pandas==2.0.3
6 | numpy==1.24.4
7 | einops==0.7.0
8 | tqdm==4.66.2
9 | loguru==0.7.2
10 | imageio==2.34.0
11 | imageio-ffmpeg==0.5.1
12 | safetensors==0.4.3
13 | gradio==4.42.0
14 | fastapi==0.115.12
15 | uvicorn==0.34.2
16 | decord==0.6.0
17 | librosa==0.11.0
18 | scikit-video==1.1.11
19 | ffmpeg
--------------------------------------------------------------------------------
/scripts/run_gradio.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | JOBS_DIR=$(dirname $(dirname "$0"))
3 | export PYTHONPATH=./
4 |
5 | export MODEL_BASE=./weights
6 |
7 | checkpoint_path=${MODEL_BASE}/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt
8 |
9 |
10 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_gradio/flask_audio.py \
11 | --input 'assets/test.csv' \
12 | --ckpt ${checkpoint_path} \
13 | --sample-n-frames 129 \
14 | --seed 128 \
15 | --image-size 704 \
16 | --cfg-scale 7.5 \
17 | --infer-steps 50 \
18 | --use-deepcache 1 \
19 | --flow-shift-eval-video 5.0 &
20 |
21 |
22 | python3 hymm_gradio/gradio_audio.py
23 |
--------------------------------------------------------------------------------
/scripts/run_sample_batch_sp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | JOBS_DIR=$(dirname $(dirname "$0"))
3 | export PYTHONPATH=./
4 |
5 | export MODEL_BASE=./weights
6 | OUTPUT_BASEPATH=./results
7 | checkpoint_path=${MODEL_BASE}/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt
8 |
9 |
10 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_sp/sample_batch.py \
11 | --input 'assets/test.csv' \
12 | --ckpt ${checkpoint_path} \
13 | --sample-n-frames 129 \
14 | --seed 128 \
15 | --image-size 704 \
16 | --cfg-scale 7.5 \
17 | --infer-steps 50 \
18 | --use-deepcache 1 \
19 | --flow-shift-eval-video 5.0 \
20 | --save-path ${OUTPUT_BASEPATH}
21 |
--------------------------------------------------------------------------------
/scripts/run_single_audio.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | JOBS_DIR=$(dirname $(dirname "$0"))
3 | export PYTHONPATH=./
4 |
5 | export MODEL_BASE=./weights
6 | OUTPUT_BASEPATH=./results-single
7 |
8 | # checkpoint_path=${MODEL_BASE}/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt
9 | checkpoint_path=${MODEL_BASE}/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt
10 |
11 |
12 | export DISABLE_SP=1
13 | CUDA_VISIBLE_DEVICES=0 python3 hymm_sp/sample_gpu_poor.py \
14 | --input 'assets/test.csv' \
15 | --ckpt ${checkpoint_path} \
16 | --sample-n-frames 129 \
17 | --seed 128 \
18 | --image-size 704 \
19 | --cfg-scale 7.5 \
20 | --infer-steps 50 \
21 | --use-deepcache 1 \
22 | --flow-shift-eval-video 5.0 \
23 | --save-path ${OUTPUT_BASEPATH} \
24 | --use-fp8 \
25 | --infer-min
26 |
--------------------------------------------------------------------------------
/scripts/run_single_poor.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | JOBS_DIR=$(dirname $(dirname "$0"))
3 | export PYTHONPATH=./
4 |
5 | export MODEL_BASE=./weights
6 | OUTPUT_BASEPATH=./results-poor
7 |
8 | checkpoint_path=${MODEL_BASE}/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt
9 |
10 | export CPU_OFFLOAD=1
11 | CUDA_VISIBLE_DEVICES=0 python3 hymm_sp/sample_gpu_poor.py \
12 | --input 'assets/test.csv' \
13 | --ckpt ${checkpoint_path} \
14 | --sample-n-frames 129 \
15 | --seed 128 \
16 | --image-size 704 \
17 | --cfg-scale 7.5 \
18 | --infer-steps 50 \
19 | --use-deepcache 1 \
20 | --flow-shift-eval-video 5.0 \
21 | --save-path ${OUTPUT_BASEPATH} \
22 | --use-fp8 \
23 | --cpu-offload \
24 | --infer-min
25 |
--------------------------------------------------------------------------------
/weights/README.md:
--------------------------------------------------------------------------------
1 | # Download Pretrained Models
2 |
3 | All models are stored in `HunyuanVideo-Avatar/weights` by default, and the file structure is as follows
4 | ```shell
5 | HunyuanVideo-Avatar
6 | ├──weights
7 | │ ├──ckpts
8 | │ │ ├──README.md
9 | │ │ ├──hunyuan-video-t2v-720p
10 | │ │ │ ├──transformers
11 | │ │ │ │ ├──mp_rank_00_model_states.pt
12 | │ │ │ │ ├──mp_rank_00_model_states_fp8.pt
13 | │ │ │ │ ├──mp_rank_00_model_states_fp8_map.pt
14 | │ │ │ ├──vae
15 | │ │ │ │ ├──pytorch_model.pt
16 | │ │ │ │ ├──config.json
17 | │ │ ├──llava_llama_image
18 | │ │ │ ├──model-00001-of-00004.safetensors
19 | │ │ │ ├──model-00002-of-00004.safetensors
20 | │ │ │ ├──model-00003-of-00004.safetensors
21 | │ │ │ ├──model-00004-of-00004.safetensors
22 | │ │ │ ├──...
23 | │ │ ├──text_encoder_2
24 | │ │ ├──whisper-tiny
25 | │ │ ├──det_align
26 | │ │ ├──...
27 | ```
28 |
29 | ## Download HunyuanVideo-Avatar model
30 | To download the HunyuanVideo-Avatar model, first install the huggingface-cli. (Detailed instructions are available [here](https://huggingface.co/docs/huggingface_hub/guides/cli).)
31 |
32 | ```shell
33 | python -m pip install "huggingface_hub[cli]"
34 | ```
35 |
36 | Then download the model using the following commands:
37 |
38 | ```shell
39 | # Switch to the directory named 'HunyuanVideo-Avatar/weights'
40 | cd HunyuanVideo-Avatar/weights
41 | # Use the huggingface-cli tool to download the HunyuanVideo-Avatar model into the HunyuanVideo-Avatar/weights dir.
42 | # The download time may vary from 10 minutes to 1 hour depending on network conditions.
43 | huggingface-cli download tencent/HunyuanVideo-Avatar --local-dir ./
44 | ```
45 |
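If you prefer scripting the download from Python, the same snapshot can be fetched with `huggingface_hub` (a minimal sketch; it assumes you are already inside `HunyuanVideo-Avatar/weights`):

```python
from huggingface_hub import snapshot_download

# Mirrors the huggingface-cli command above: downloads the full
# tencent/HunyuanVideo-Avatar snapshot into the current directory.
snapshot_download(repo_id="tencent/HunyuanVideo-Avatar", local_dir="./")
```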
--------------------------------------------------------------------------------