├── LICENSE ├── Notice.txt ├── README.md ├── README_zh.md ├── assets ├── audio │ ├── 2.WAV │ ├── 3.WAV │ └── 4.WAV ├── image │ ├── 1.png │ ├── 2.png │ ├── 3.png │ ├── 4.png │ ├── src1.png │ ├── src2.png │ ├── src3.png │ └── src4.png ├── material │ ├── demo.png │ ├── logo.png │ ├── method.png │ └── teaser.png └── test.csv ├── hymm_gradio ├── flask_audio.py ├── gradio_audio.py └── tool_for_end2end.py ├── hymm_sp ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── config.cpython-310.pyc │ ├── constants.cpython-310.pyc │ ├── helpers.cpython-310.pyc │ ├── inference.cpython-310.pyc │ └── sample_inference_audio.cpython-310.pyc ├── config.py ├── constants.py ├── data_kits │ ├── __pycache__ │ │ ├── audio_dataset.cpython-310.pyc │ │ ├── audio_preprocessor.cpython-310.pyc │ │ ├── data_tools.cpython-310.pyc │ │ └── ffmpeg_utils.cpython-310.pyc │ ├── audio_dataset.py │ ├── audio_preprocessor.py │ ├── data_tools.py │ └── face_align │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── align.cpython-310.pyc │ │ └── detface.cpython-310.pyc │ │ ├── align.py │ │ └── detface.py ├── diffusion │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-310.pyc │ ├── pipelines │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ └── pipeline_hunyuan_video_audio.cpython-310.pyc │ │ └── pipeline_hunyuan_video_audio.py │ └── schedulers │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ └── scheduling_flow_match_discrete.cpython-310.pyc │ │ └── scheduling_flow_match_discrete.py ├── helpers.py ├── inference.py ├── modules │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── activation_layers.cpython-310.pyc │ │ ├── attn_layers.cpython-310.pyc │ │ ├── audio_adapters.cpython-310.pyc │ │ ├── embed_layers.cpython-310.pyc │ │ ├── fp8_optimization.cpython-310.pyc │ │ ├── mlp_layers.cpython-310.pyc │ │ ├── models_audio.cpython-310.pyc │ │ ├── modulate_layers.cpython-310.pyc │ │ ├── norm_layers.cpython-310.pyc │ │ ├── parallel_states.cpython-310.pyc │ │ ├── posemb_layers.cpython-310.pyc │ │ └── token_refiner.cpython-310.pyc │ ├── activation_layers.py │ ├── attn_layers.py │ ├── audio_adapters.py │ ├── embed_layers.py │ ├── fp8_optimization.py │ ├── mlp_layers.py │ ├── models_audio.py │ ├── modulate_layers.py │ ├── norm_layers.py │ ├── parallel_states.py │ ├── posemb_layers.py │ └── token_refiner.py ├── sample_batch.py ├── sample_gpu_poor.py ├── sample_inference_audio.py ├── text_encoder │ ├── __init__.py │ └── __pycache__ │ │ └── __init__.cpython-310.pyc └── vae │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── autoencoder_kl_causal_3d.cpython-310.pyc │ ├── unet_causal_3d_blocks.cpython-310.pyc │ └── vae.cpython-310.pyc │ ├── autoencoder_kl_causal_3d.py │ ├── unet_causal_3d_blocks.py │ └── vae.py ├── requirements.txt ├── scripts ├── run_gradio.sh ├── run_sample_batch_sp.sh ├── run_single_audio.sh └── run_single_poor.sh └── weights └── README.md /LICENSE: -------------------------------------------------------------------------------- 1 | TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT 2 | TencentHunyuanVideo-Avatar Release Date: May 28, 2025 3 | THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW. 
4 | By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately. 5 | 1. DEFINITIONS. 6 | a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A. 7 | b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein. 8 | c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent. 9 | d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means. 10 | e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use. 11 | f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement. 12 | g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives. 13 | h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service. 14 | i. “Tencent,” “We” or “Us” shall mean THL A29 Limited. 15 | j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, TencentHunyuanVideo-Avatar released at [https://github.com/Tencent-Hunyuan/HunyuanVideo-Avatar]. 16 | k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof. 17 | l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea. 18 | m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You. 19 | n. “including” shall mean including but not limited to. 20 | 2. GRANT OF RIGHTS. 
21 | We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy. 22 | 3. DISTRIBUTION. 23 | You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions: 24 | a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement; 25 | b. You must cause any modified files to carry prominent notices stating that You changed the files; 26 | c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and 27 | d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.” 28 | You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You. 29 | 4. ADDITIONAL COMMERCIAL TERMS. 30 | If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights. 31 | 5. RULES OF USE. 32 | a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) 
governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b). 33 | b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof). 34 | c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement. 35 | 6. INTELLECTUAL PROPERTY. 36 | a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You. 37 | b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent. 38 | c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works. 39 | d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses. 40 | 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY. 41 | a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto. 42 | b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT. 43 | c. 
TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. 44 | 8. SURVIVAL AND TERMINATION. 45 | a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. 46 | b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement. 47 | 9. GOVERNING LAW AND JURISDICTION. 48 | a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. 49 | b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute. 50 |   51 | EXHIBIT A 52 | ACCEPTABLE USE POLICY 53 | 54 | Tencent reserves the right to update this Acceptable Use Policy from time to time. 55 | Last modified: November 5, 2024 56 | 57 | Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives: 58 | 1. Outside the Territory; 59 | 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation; 60 | 3. To harm Yourself or others; 61 | 4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others; 62 | 5. To override or circumvent the safety guardrails and safeguards We have put in place; 63 | 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; 64 | 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections; 65 | 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement; 66 | 9. To intentionally defame, disparage or otherwise harass others; 67 | 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems; 68 | 11. To generate or disseminate personal identifiable information with the purpose of harming others; 69 | 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated; 70 | 13. 
To impersonate another individual without consent, authorization, or legal right; 71 | 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance); 72 | 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions; 73 | 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism; 74 | 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics; 75 | 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm; 76 | 19. For military purposes; 77 | 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices. 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 |

4 | 5 | 6 | 7 | # **HunyuanVideo-Avatar** 🌅 8 | 9 |
22 | 23 | ![image](assets/material/teaser.png) 24 | 25 | > [**HunyuanVideo-Avatar: High-Fidelity Audio-Driven Human Animation for Multiple Characters**](https://arxiv.org/pdf/2505.20156) 26 | 27 | ## 🔥🔥🔥 News!! 28 | * Jun 06, 2025: 🔥 HunyuanVideo-Avatar supports **Single GPU** with only **10GB VRAM**, with **TeaCache** included, **HUGE THANKS** to [Wan2GP](https://github.com/deepbeepmeep/Wan2GP) 29 | * May 28, 2025: 🔥 HunyuanVideo-Avatar is available in Cloud-Native-Build (CNB) [HunyuanVideo-Avatar](https://cnb.cool/tencent/hunyuan/HunyuanVideo-Avatar). 30 | * May 28, 2025: 👋 We release the inference code and model weights of HunyuanVideo-Avatar. [Download](weights/README.md). 31 | 32 | 33 | ## 📑 Open-source Plan 34 | 35 | - HunyuanVideo-Avatar 36 | - [x] Inference 37 | - [x] Checkpoints 38 | - [ ] ComfyUI 39 | 40 | ## Contents 41 | - [**HunyuanVideo-Avatar** 🌅](#HunyuanVideo-Avatar-) 42 | - [🔥🔥🔥 News!!](#-news) 43 | - [📑 Open-source Plan](#-open-source-plan) 44 | - [Contents](#contents) 45 | - [**Abstract**](#abstract) 46 | - [**HunyuanVideo-Avatar Overall Architecture**](#HunyuanVideo-Avatar-overall-architecture) 47 | - [🎉 **HunyuanVideo-Avatar Key Features**](#-HunyuanVideo-Avatar-key-features) 48 | - [**Multimodal Video customization**](#multimodal-video-customization) 49 | - [**Various Applications**](#various-applications) 50 | - [📈 Comparisons](#-comparisons) 51 | - [📜 Requirements](#-requirements) 52 | - [🛠️ Dependencies and Installation](#️-dependencies-and-installation) 53 | - [Installation Guide for Linux](#installation-guide-for-linux) 54 | - [🧱 Download Pretrained Models](#-download-pretrained-models) 55 | - [🚀 Parallel Inference on Multiple GPUs](#-parallel-inference-on-multiple-gpus) 56 | - [🔑 Single-gpu Inference](#-single-gpu-inference) 57 | - [Run with very low VRAM](#run-with-very-low-vram) 58 | - [Run a Gradio Server](#run-a-gradio-server) 59 | - [🔗 BibTeX](#-bibtex) 60 | - [Acknowledgements](#acknowledgements) 61 | --- 62 | 63 | ## **Abstract** 64 | 65 | Recent years have witnessed significant progress in audio-driven human animation. However, critical challenges remain in (i) generating highly dynamic videos while preserving character consistency, (ii) achieving precise emotion alignment between characters and audio, and (iii) enabling multi-character audio-driven animation. To address these challenges, we propose HunyuanVideo-Avatar, a multimodal diffusion transformer (MM-DiT)-based model capable of simultaneously generating dynamic, emotion-controllable, and multi-character dialogue videos. Concretely, HunyuanVideo-Avatar introduces three key innovations: (i) A character image injection module is designed to replace the conventional addition-based character conditioning scheme, eliminating the inherent condition mismatch between training and inference. This ensures the dynamic motion and strong character consistency; (ii) An Audio Emotion Module (AEM) is introduced to extract and transfer the emotional cues from an emotion reference image to the target generated video, enabling fine-grained and accurate emotion style control; (iii) A Face-Aware Audio Adapter (FAA) is proposed to isolate the audio-driven character with latent-level face mask, enabling independent audio injection via cross-attention for multi-character scenarios. These innovations empower HunyuanVideo-Avatar to surpass state-of-the-art methods on benchmark datasets and a newly proposed wild dataset, generating realistic avatars in dynamic, immersive scenarios. 
The source code and model weights will be released publicly. 66 | 67 | ## **HunyuanVideo-Avatar Overall Architecture** 68 | 69 | ![image](assets/material/method.png) 70 | 71 | We propose **HunyuanVideo-Avatar**, a multimodal diffusion transformer (MM-DiT)-based model capable of generating **dynamic**, **emotion-controllable**, and **multi-character dialogue** videos. 72 | 73 | ## 🎉 **HunyuanVideo-Avatar Key Features** 74 | 75 | ![image](assets/material/demo.png) 76 | 77 | ### **High-Dynamic and Emotion-Controllable Video Generation** 78 | 79 | HunyuanVideo-Avatar can animate any input **avatar image** into a **high-dynamic**, **emotion-controllable** video driven by simple **audio conditions**. Specifically, it takes as input **multi-style** avatar images at **arbitrary scales and resolutions**, encompassing photorealistic, cartoon, 3D-rendered, and anthropomorphic characters, and supports multi-scale generation spanning portrait, upper-body, and full-body shots. The generated videos feature highly dynamic foregrounds and backgrounds, achieving superior realism and naturalness. In addition, the system can control the characters' facial emotions conditioned on the input audio. 80 | 81 | ### **Various Applications** 82 | 83 | HunyuanVideo-Avatar supports various downstream tasks and applications. For instance, it generates talking-avatar videos that can be applied to e-commerce, online streaming, and social media video production. In addition, its multi-character animation capability broadens applications such as video content creation and editing. 84 | 85 | ## 📜 Requirements 86 | 87 | * An NVIDIA GPU with CUDA support is required. 88 | * The model has been tested on a machine with 8 GPUs. 89 | * **Minimum**: At least 24 GB of GPU memory is required for 704px × 768px × 129 frames, but inference is very slow. 90 | * **Recommended**: We recommend a GPU with 96 GB of memory for better generation quality. 91 | * **Tip**: If out-of-memory (OOM) errors occur on a GPU with 80 GB of memory, try reducing the image resolution. 92 | * Tested operating system: Linux 93 | 94 | 95 | ## 🛠️ Dependencies and Installation 96 | 97 | Begin by cloning the repository: 98 | ```shell 99 | git clone https://github.com/Tencent-Hunyuan/HunyuanVideo-Avatar.git 100 | cd HunyuanVideo-Avatar 101 | ``` 102 | 103 | ### Installation Guide for Linux 104 | 105 | We recommend CUDA 12.4 or 11.8 for manual installation. 106 | 107 | Conda installation instructions are available [here](https://docs.anaconda.com/free/miniconda/index.html). 108 | 109 | ```shell 110 | # 1. Create conda environment 111 | conda create -n HunyuanVideo-Avatar python==3.10.9 112 | 113 | # 2. Activate the environment 114 | conda activate HunyuanVideo-Avatar 115 | 116 | # 3. Install PyTorch and other dependencies using conda 117 | # For CUDA 11.8 118 | conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=11.8 -c pytorch -c nvidia 119 | # For CUDA 12.4 120 | conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.4 -c pytorch -c nvidia 121 | 122 | # 4. Install pip dependencies 123 | python -m pip install -r requirements.txt 124 | # 5.
Install flash attention v2 for acceleration (requires CUDA 11.8 or above) 125 | python -m pip install ninja 126 | python -m pip install git+https://github.com/Dao-AILab/flash-attention.git@v2.6.3 127 | ``` 128 | 129 | If you run into a floating point exception (core dump) on specific GPU types, try one of the following solutions: 130 | 131 | ```shell 132 | # Option 1: Make sure you have installed CUDA 12.4, CUBLAS>=12.4.5.8, and CUDNN>=9.00 (or simply use our CUDA 12 Docker image). 133 | pip install nvidia-cublas-cu12==12.4.5.8 134 | export LD_LIBRARY_PATH=/opt/conda/lib/python3.8/site-packages/nvidia/cublas/lib/ 135 | 136 | # Option 2: Explicitly use the CUDA 11.8 build of PyTorch and reinstall all other packages 137 | pip uninstall -r requirements.txt # uninstall all packages 138 | pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu118 139 | pip install -r requirements.txt 140 | pip install ninja 141 | pip install git+https://github.com/Dao-AILab/flash-attention.git@v2.6.3 142 | ``` 143 | 144 | Alternatively, you can use the HunyuanVideo Docker image. Use the following commands to pull and run it. 145 | 146 | ```shell 147 | # For CUDA 12.4 (updated to avoid the floating point exception) 148 | docker pull hunyuanvideo/hunyuanvideo:cuda_12 149 | docker run -itd --gpus all --init --net=host --uts=host --ipc=host --name hunyuanvideo --security-opt=seccomp=unconfined --ulimit=stack=67108864 --ulimit=memlock=-1 --privileged hunyuanvideo/hunyuanvideo:cuda_12 150 | pip install gradio==3.39.0 diffusers==0.33.0 transformers==4.41.2 151 | 152 | # For CUDA 11.8 153 | docker pull hunyuanvideo/hunyuanvideo:cuda_11 154 | docker run -itd --gpus all --init --net=host --uts=host --ipc=host --name hunyuanvideo --security-opt=seccomp=unconfined --ulimit=stack=67108864 --ulimit=memlock=-1 --privileged hunyuanvideo/hunyuanvideo:cuda_11 155 | pip install gradio==3.39.0 diffusers==0.33.0 transformers==4.41.2 156 | ``` 157 | 158 | 159 | ## 🧱 Download Pretrained Models 160 | 161 | Details on downloading the pretrained models are provided [here](weights/README.md).
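Before running inference, it can help to confirm that the expected checkpoints are actually in place. The sketch below is only a convenience check, not part of the official setup; the paths are the ones referenced by the inference scripts and `hymm_gradio/flask_audio.py`, and `MODEL_BASE` is assumed to be `./weights`.

```python
# Minimal sanity check for the downloaded weights (paths as used by the scripts in this repo).
from pathlib import Path

MODEL_BASE = Path("./weights")  # adjust if you set MODEL_BASE elsewhere
expected = [
    "ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",      # full-precision checkpoint (multi-GPU)
    "ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt",  # fp8 checkpoint (single-GPU / low VRAM)
    "ckpts/whisper-tiny",                                                        # audio feature extractor
    "ckpts/det_align/detface.pt",                                                # face detection / alignment model
]
for rel in expected:
    path = MODEL_BASE / rel
    print(f"[{'ok' if path.exists() else 'MISSING'}] {path}")
```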
162 | 163 | ## 🚀 Parallel Inference on Multiple GPUs 164 | 165 | For example, to generate a video with 8 GPUs, you can use the following command: 166 | 167 | ```bash 168 | cd HunyuanVideo-Avatar 169 | 170 | JOBS_DIR=$(dirname $(dirname "$0")) 171 | export PYTHONPATH=./ 172 | export MODEL_BASE="./weights" 173 | checkpoint_path=${MODEL_BASE}/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt 174 | 175 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_sp/sample_batch.py \ 176 | --input 'assets/test.csv' \ 177 | --ckpt ${checkpoint_path} \ 178 | --sample-n-frames 129 \ 179 | --seed 128 \ 180 | --image-size 704 \ 181 | --cfg-scale 7.5 \ 182 | --infer-steps 50 \ 183 | --use-deepcache 1 \ 184 | --flow-shift-eval-video 5.0 \ 185 | --save-path ${OUTPUT_BASEPATH} 186 | ``` 187 | 188 | ## 🔑 Single-gpu Inference 189 | 190 | For example, to generate a video with 1 GPU, you can use the following command: 191 | 192 | ```bash 193 | cd HunyuanVideo-Avatar 194 | 195 | JOBS_DIR=$(dirname $(dirname "$0")) 196 | export PYTHONPATH=./ 197 | 198 | export MODEL_BASE=./weights 199 | OUTPUT_BASEPATH=./results-single 200 | checkpoint_path=${MODEL_BASE}/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt 201 | 202 | export DISABLE_SP=1 203 | CUDA_VISIBLE_DEVICES=0 python3 hymm_sp/sample_gpu_poor.py \ 204 | --input 'assets/test.csv' \ 205 | --ckpt ${checkpoint_path} \ 206 | --sample-n-frames 129 \ 207 | --seed 128 \ 208 | --image-size 704 \ 209 | --cfg-scale 7.5 \ 210 | --infer-steps 50 \ 211 | --use-deepcache 1 \ 212 | --flow-shift-eval-video 5.0 \ 213 | --save-path ${OUTPUT_BASEPATH} \ 214 | --use-fp8 \ 215 | --infer-min 216 | ``` 217 | 218 | ### Run with very low VRAM 219 | 220 | ```bash 221 | cd HunyuanVideo-Avatar 222 | 223 | JOBS_DIR=$(dirname $(dirname "$0")) 224 | export PYTHONPATH=./ 225 | 226 | export MODEL_BASE=./weights 227 | OUTPUT_BASEPATH=./results-poor 228 | 229 | checkpoint_path=${MODEL_BASE}/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt 230 | 231 | export CPU_OFFLOAD=1 232 | CUDA_VISIBLE_DEVICES=0 python3 hymm_sp/sample_gpu_poor.py \ 233 | --input 'assets/test.csv' \ 234 | --ckpt ${checkpoint_path} \ 235 | --sample-n-frames 129 \ 236 | --seed 128 \ 237 | --image-size 704 \ 238 | --cfg-scale 7.5 \ 239 | --infer-steps 50 \ 240 | --use-deepcache 1 \ 241 | --flow-shift-eval-video 5.0 \ 242 | --save-path ${OUTPUT_BASEPATH} \ 243 | --use-fp8 \ 244 | --cpu-offload \ 245 | --infer-min 246 | ``` 247 | 248 | ### Run with 10GB VRAM GPU (TeaCache supported) 249 | 250 | Thanks to [Wan2GP](https://github.com/deepbeepmeep/Wan2GP), HunyuanVideo-Avatar now supports single GPU mode with even lower VRAM (10GB) without quality degradation. Check out this [great repo](https://github.com/deepbeepmeep/Wan2GP/tree/main/hyvideo). 
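The commands above read their jobs from a CSV file with the columns `videoid,image,audio,prompt,fps` (see `assets/test.csv`). If you want to drive your own image/audio pairs, a minimal sketch for writing such a file is shown below; the output path and row contents are placeholders.

```python
# Write an input CSV in the same format as assets/test.csv
# (columns: videoid, image, audio, prompt, fps).
import csv

rows = [
    # videoid, image path, audio path, text prompt, fps
    (1, "assets/image/1.png", "assets/audio/2.WAV",
     "A person sits cross-legged by a campfire in a forested area.", 25),
]

with open("assets/my_test.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["videoid", "image", "audio", "prompt", "fps"])
    writer.writerows(rows)
```

Pass the resulting file to `--input` in place of `assets/test.csv`.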
251 | 252 | 253 | ## Run a Gradio Server 254 | ```bash 255 | cd HunyuanVideo-Avatar 256 | 257 | bash ./scripts/run_gradio.sh 258 | 259 | ``` 260 | 261 | ## 🔗 BibTeX 262 | 263 | If you find [HunyuanVideo-Avatar](https://arxiv.org/pdf/2505.20156) useful for your research and applications, please cite using this BibTeX: 264 | 265 | ```BibTeX 266 | @misc{hu2025HunyuanVideo-Avatar, 267 | title={HunyuanVideo-Avatar: High-Fidelity Audio-Driven Human Animation for Multiple Characters}, 268 | author={Yi Chen and Sen Liang and Zixiang Zhou and Ziyao Huang and Yifeng Ma and Junshu Tang and Qin Lin and Yuan Zhou and Qinglin Lu}, 269 | year={2025}, 270 | eprint={2505.20156}, 271 | archivePrefix={arXiv}, 272 | primaryClass={cs.CV}, 273 | url={https://arxiv.org/pdf/2505.20156}, 274 | } 275 | ``` 276 | 277 | ## Acknowledgements 278 | 279 | We would like to thank the contributors to the [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [FLUX](https://github.com/black-forest-labs/flux), [Llama](https://github.com/meta-llama/llama), [LLaVA](https://github.com/haotian-liu/LLaVA), [Xtuner](https://github.com/InternLM/xtuner), [diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) repositories, for their open research and exploration. 280 | -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/README_zh.md -------------------------------------------------------------------------------- /assets/audio/2.WAV: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/audio/2.WAV -------------------------------------------------------------------------------- /assets/audio/3.WAV: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/audio/3.WAV -------------------------------------------------------------------------------- /assets/audio/4.WAV: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/audio/4.WAV -------------------------------------------------------------------------------- /assets/image/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/image/1.png -------------------------------------------------------------------------------- /assets/image/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/image/2.png -------------------------------------------------------------------------------- /assets/image/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/image/3.png 
-------------------------------------------------------------------------------- /assets/image/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/image/4.png -------------------------------------------------------------------------------- /assets/image/src1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/image/src1.png -------------------------------------------------------------------------------- /assets/image/src2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/image/src2.png -------------------------------------------------------------------------------- /assets/image/src3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/image/src3.png -------------------------------------------------------------------------------- /assets/image/src4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/image/src4.png -------------------------------------------------------------------------------- /assets/material/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/material/demo.png -------------------------------------------------------------------------------- /assets/material/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/material/logo.png -------------------------------------------------------------------------------- /assets/material/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/material/method.png -------------------------------------------------------------------------------- /assets/material/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/assets/material/teaser.png -------------------------------------------------------------------------------- /assets/test.csv: -------------------------------------------------------------------------------- 1 | videoid,image,audio,prompt,fps 2 | 8,assets/image/1.png,assets/audio/2.WAV,A person sits cross-legged by a campfire in a forested area.,25 3 | 9,assets/image/2.png,assets/audio/2.WAV,"A person with long blonde hair wearing a green jacket, standing in a forested area during twilight.",25 4 | 10,assets/image/3.png,assets/audio/2.WAV,A person playing guitar by a campfire in a forest.,25 5 | 11,assets/image/4.png,assets/audio/2.WAV,"A person wearing a green jacket stands 
in a forested area, with sunlight filtering through the trees.",25 6 | 12,assets/image/src1.png,assets/audio/2.WAV,A person sits cross-legged by a campfire in a forest at dusk.,25 7 | 13,assets/image/src2.png,assets/audio/2.WAV,A person in a green jacket stands in a forest at dusk.,25 8 | 14,assets/image/src3.png,assets/audio/2.WAV,A person playing guitar by a campfire in a forest.,25 9 | 15,assets/image/src4.png,assets/audio/2.WAV,"A person in a green jacket stands in a forest, backlit by sunlight.",25 10 | 16,assets/image/1.png,assets/audio/3.WAV,A person sits cross-legged by a campfire in a forested area.,25 11 | 17,assets/image/2.png,assets/audio/3.WAV,"A person with long blonde hair wearing a green jacket, standing in a forested area during twilight.",25 12 | 18,assets/image/3.png,assets/audio/3.WAV,A person playing guitar by a campfire in a forest.,25 13 | 19,assets/image/4.png,assets/audio/3.WAV,"A person wearing a green jacket stands in a forested area, with sunlight filtering through the trees.",25 14 | 20,assets/image/src1.png,assets/audio/3.WAV,A person sits cross-legged by a campfire in a forest at dusk.,25 15 | 21,assets/image/src2.png,assets/audio/3.WAV,A person in a green jacket stands in a forest at dusk.,25 16 | 22,assets/image/src3.png,assets/audio/3.WAV,A person playing guitar by a campfire in a forest.,25 17 | 23,assets/image/src4.png,assets/audio/3.WAV,"A person in a green jacket stands in a forest, backlit by sunlight.",25 18 | 24,assets/image/1.png,assets/audio/4.WAV,A person sits cross-legged by a campfire in a forested area.,25 19 | 25,assets/image/2.png,assets/audio/4.WAV,"A person with long blonde hair wearing a green jacket, standing in a forested area during twilight.",25 20 | 26,assets/image/3.png,assets/audio/4.WAV,A person playing guitar by a campfire in a forest.,25 21 | 27,assets/image/4.png,assets/audio/4.WAV,"A person wearing a green jacket stands in a forested area, with sunlight filtering through the trees.",25 22 | 28,assets/image/src1.png,assets/audio/4.WAV,A person sits cross-legged by a campfire in a forest at dusk.,25 23 | 29,assets/image/src2.png,assets/audio/4.WAV,A person in a green jacket stands in a forest at dusk.,25 24 | 30,assets/image/src3.png,assets/audio/4.WAV,A person playing guitar by a campfire in a forest.,25 25 | 31,assets/image/src4.png,assets/audio/4.WAV,"A person in a green jacket stands in a forest, backlit by sunlight.",25 26 | -------------------------------------------------------------------------------- /hymm_gradio/flask_audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import warnings 5 | import threading 6 | import traceback 7 | import uvicorn 8 | from fastapi import FastAPI, Body 9 | from pathlib import Path 10 | from datetime import datetime 11 | import torch.distributed as dist 12 | from hymm_gradio.tool_for_end2end import * 13 | from hymm_sp.config import parse_args 14 | from hymm_sp.sample_inference_audio import HunyuanVideoSampler 15 | 16 | from hymm_sp.modules.parallel_states import ( 17 | initialize_distributed, 18 | nccl_info, 19 | ) 20 | 21 | from transformers import WhisperModel 22 | from transformers import AutoFeatureExtractor 23 | from hymm_sp.data_kits.face_align import AlignImage 24 | 25 | 26 | warnings.filterwarnings("ignore") 27 | MODEL_OUTPUT_PATH = os.environ.get('MODEL_BASE') 28 | app = FastAPI() 29 | rlock = threading.RLock() 30 | 31 | 32 | 33 | @app.api_route('/predict2', methods=['GET', 'POST']) 34 
| def predict(data=Body(...)): 35 | is_acquire = False 36 | error_info = "" 37 | try: 38 | is_acquire = rlock.acquire(blocking=False) 39 | if is_acquire: 40 | res = predict_wrap(data) 41 | return res 42 | except Exception as e: 43 | error_info = traceback.format_exc() 44 | print(error_info) 45 | finally: 46 | if is_acquire: 47 | rlock.release() 48 | return {"errCode": -1, "info": "broken"} 49 | 50 | def predict_wrap(input_dict={}): 51 | if nccl_info.sp_size > 1: 52 | device = torch.device(f"cuda:{torch.distributed.get_rank()}") 53 | rank = local_rank = torch.distributed.get_rank() 54 | print(f"sp_size={nccl_info.sp_size}, rank {rank} local_rank {local_rank}") 55 | try: 56 | print(f"----- rank = {rank}") 57 | if rank == 0: 58 | input_dict = process_input_dict(input_dict) 59 | 60 | print('------- start to predict -------') 61 | # Parse input arguments 62 | image_path = input_dict["image_path"] 63 | driving_audio_path = input_dict["audio_path"] 64 | 65 | prompt = input_dict["prompt"] 66 | 67 | save_fps = input_dict.get("save_fps", 25) 68 | 69 | 70 | ret_dict = None 71 | if image_path is None or driving_audio_path is None: 72 | ret_dict = { 73 | "errCode": -3, 74 | "content": [ 75 | { 76 | "buffer": None 77 | }, 78 | ], 79 | "info": "input content is not valid", 80 | } 81 | 82 | print(f"errCode: -3, input content is not valid!") 83 | return ret_dict 84 | 85 | # Preprocess input batch 86 | torch.cuda.synchronize() 87 | 88 | a = datetime.now() 89 | 90 | try: 91 | model_kwargs_tmp = data_preprocess_server( 92 | args, image_path, driving_audio_path, prompt, feature_extractor 93 | ) 94 | except: 95 | ret_dict = { 96 | "errCode": -2, 97 | "content": [ 98 | { 99 | "buffer": None 100 | }, 101 | ], 102 | "info": "failed to preprocess input data" 103 | } 104 | print(f"errCode: -2, preprocess failed!") 105 | return ret_dict 106 | 107 | text_prompt = model_kwargs_tmp["text_prompt"] 108 | audio_path = model_kwargs_tmp["audio_path"] 109 | image_path = model_kwargs_tmp["image_path"] 110 | fps = model_kwargs_tmp["fps"] 111 | audio_prompts = model_kwargs_tmp["audio_prompts"] 112 | audio_len = model_kwargs_tmp["audio_len"] 113 | motion_bucket_id_exps = model_kwargs_tmp["motion_bucket_id_exps"] 114 | motion_bucket_id_heads = model_kwargs_tmp["motion_bucket_id_heads"] 115 | pixel_value_ref = model_kwargs_tmp["pixel_value_ref"] 116 | pixel_value_ref_llava = model_kwargs_tmp["pixel_value_ref_llava"] 117 | 118 | 119 | 120 | torch.cuda.synchronize() 121 | b = datetime.now() 122 | preprocess_time = (b - a).total_seconds() 123 | print("="*100) 124 | print("preprocess time :", preprocess_time) 125 | print("="*100) 126 | 127 | else: 128 | text_prompt = None 129 | audio_path = None 130 | image_path = None 131 | fps = None 132 | audio_prompts = None 133 | audio_len = None 134 | motion_bucket_id_exps = None 135 | motion_bucket_id_heads = None 136 | pixel_value_ref = None 137 | pixel_value_ref_llava = None 138 | 139 | except: 140 | traceback.print_exc() 141 | if rank == 0: 142 | ret_dict = { 143 | "errCode": -1, # Failed to generate video 144 | "content":[ 145 | { 146 | "buffer": None 147 | } 148 | ], 149 | "info": "failed to preprocess", 150 | } 151 | return ret_dict 152 | 153 | try: 154 | broadcast_params = [ 155 | text_prompt, 156 | audio_path, 157 | image_path, 158 | fps, 159 | audio_prompts, 160 | audio_len, 161 | motion_bucket_id_exps, 162 | motion_bucket_id_heads, 163 | pixel_value_ref, 164 | pixel_value_ref_llava, 165 | ] 166 | dist.broadcast_object_list(broadcast_params, src=0) 167 | outputs = 
generate_image_parallel(*broadcast_params) 168 | 169 | if rank == 0: 170 | samples = outputs["samples"] 171 | sample = samples[0].unsqueeze(0) 172 | 173 | sample = sample[:, :, :audio_len[0]] 174 | 175 | video = sample[0].permute(1, 2, 3, 0).clamp(0, 1).numpy() 176 | video = (video * 255.).astype(np.uint8) 177 | 178 | output_dict = { 179 | "err_code": 0, 180 | "err_msg": "succeed", 181 | "video": video, 182 | "audio": input_dict.get("audio_path", None), 183 | "save_fps": save_fps, 184 | } 185 | 186 | ret_dict = process_output_dict(output_dict) 187 | return ret_dict 188 | 189 | except: 190 | traceback.print_exc() 191 | if rank == 0: 192 | ret_dict = { 193 | "errCode": -1, # Failed to generate video 194 | "content":[ 195 | { 196 | "buffer": None 197 | } 198 | ], 199 | "info": "failed to generate video", 200 | } 201 | return ret_dict 202 | 203 | return None 204 | 205 | def generate_image_parallel(text_prompt, 206 | audio_path, 207 | image_path, 208 | fps, 209 | audio_prompts, 210 | audio_len, 211 | motion_bucket_id_exps, 212 | motion_bucket_id_heads, 213 | pixel_value_ref, 214 | pixel_value_ref_llava 215 | ): 216 | if nccl_info.sp_size > 1: 217 | device = torch.device(f"cuda:{torch.distributed.get_rank()}") 218 | 219 | batch = { 220 | "text_prompt": text_prompt, 221 | "audio_path": audio_path, 222 | "image_path": image_path, 223 | "fps": fps, 224 | "audio_prompts": audio_prompts, 225 | "audio_len": audio_len, 226 | "motion_bucket_id_exps": motion_bucket_id_exps, 227 | "motion_bucket_id_heads": motion_bucket_id_heads, 228 | "pixel_value_ref": pixel_value_ref, 229 | "pixel_value_ref_llava": pixel_value_ref_llava 230 | } 231 | 232 | samples = hunyuan_sampler.predict(args, batch, wav2vec, feature_extractor, align_instance) 233 | return samples 234 | 235 | def worker_loop(): 236 | while True: 237 | predict_wrap() 238 | 239 | 240 | if __name__ == "__main__": 241 | audio_args = parse_args() 242 | initialize_distributed(audio_args.seed) 243 | hunyuan_sampler = HunyuanVideoSampler.from_pretrained( 244 | audio_args.ckpt, args=audio_args) 245 | args = hunyuan_sampler.args 246 | 247 | rank = local_rank = 0 248 | device = torch.device("cuda") 249 | if nccl_info.sp_size > 1: 250 | device = torch.device(f"cuda:{torch.distributed.get_rank()}") 251 | rank = local_rank = torch.distributed.get_rank() 252 | 253 | feature_extractor = AutoFeatureExtractor.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/") 254 | wav2vec = WhisperModel.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/").to(device=device, dtype=torch.float32) 255 | wav2vec.requires_grad_(False) 256 | 257 | 258 | BASE_DIR = f'{MODEL_OUTPUT_PATH}/ckpts/det_align/' 259 | det_path = os.path.join(BASE_DIR, 'detface.pt') 260 | align_instance = AlignImage("cuda", det_path=det_path) 261 | 262 | 263 | 264 | if rank == 0: 265 | uvicorn.run(app, host="0.0.0.0", port=80) 266 | else: 267 | worker_loop() 268 | 269 | -------------------------------------------------------------------------------- /hymm_gradio/gradio_audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import json 5 | import datetime 6 | import requests 7 | import gradio as gr 8 | from tool_for_end2end import * 9 | 10 | os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" 11 | DATADIR = './temp' 12 | _HEADER_ = ''' 13 |
14 | ## Tencent HunyuanVideo-Avatar Demo
15 |
16 | 17 | ''' 18 | # flask url 19 | URL = "http://127.0.0.1:80/predict2" 20 | 21 | def post_and_get(audio_input, id_image, prompt): 22 | now = datetime.datetime.now().isoformat() 23 | imgdir = os.path.join(DATADIR, 'reference') 24 | videodir = os.path.join(DATADIR, 'video') 25 | imgfile = os.path.join(imgdir, now + '.png') 26 | output_video_path = os.path.join(videodir, now + '.mp4') 27 | 28 | 29 | os.makedirs(imgdir, exist_ok=True) 30 | os.makedirs(videodir, exist_ok=True) 31 | cv2.imwrite(imgfile, id_image[:,:,::-1]) 32 | 33 | proxies = { 34 | "http": None, 35 | "https": None, 36 | } 37 | 38 | files = { 39 | "image_buffer": encode_image_to_base64(imgfile), 40 | "audio_buffer": encode_wav_to_base64(audio_input), 41 | "text": prompt, 42 | "save_fps": 25, 43 | } 44 | r = requests.get(URL, data = json.dumps(files), proxies=proxies) 45 | ret_dict = json.loads(r.text) 46 | print(ret_dict["info"]) 47 | save_video_base64_to_local( 48 | video_path=None, 49 | base64_buffer=ret_dict["content"][0]["buffer"], 50 | output_video_path=output_video_path) 51 | 52 | 53 | return output_video_path 54 | 55 | def create_demo(): 56 | 57 | with gr.Blocks() as demo: 58 | gr.Markdown(_HEADER_) 59 | with gr.Tab('语音数字人驱动'): 60 | with gr.Row(): 61 | with gr.Column(scale=1): 62 | with gr.Group(): 63 | prompt = gr.Textbox(label="Prompt", value="a man is speaking.") 64 | 65 | audio_input = gr.Audio(sources=["upload"], 66 | type="filepath", 67 | label="Upload Audio", 68 | elem_classes="media-upload", 69 | scale=1) 70 | id_image = gr.Image(label="Input reference image", height=480) 71 | 72 | with gr.Column(scale=2): 73 | with gr.Group(): 74 | output_image = gr.Video(label="Generated Video") 75 | 76 | 77 | with gr.Column(scale=1): 78 | generate_btn = gr.Button("Generate") 79 | 80 | generate_btn.click(fn=post_and_get, 81 | inputs=[audio_input, id_image, prompt], 82 | outputs=[output_image], 83 | ) 84 | 85 | return demo 86 | 87 | if __name__ == "__main__": 88 | allowed_paths = ['/'] 89 | demo = create_demo() 90 | demo.launch(server_name='0.0.0.0', server_port=8080, share=True, allowed_paths=allowed_paths) 91 | -------------------------------------------------------------------------------- /hymm_gradio/tool_for_end2end.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import math 4 | import uuid 5 | import base64 6 | import imageio 7 | import torch 8 | import torchvision 9 | from PIL import Image 10 | import numpy as np 11 | from copy import deepcopy 12 | from einops import rearrange 13 | import torchvision.transforms as transforms 14 | from torchvision.transforms import ToPILImage 15 | from hymm_sp.data_kits.audio_dataset import get_audio_feature 16 | 17 | TEMP_DIR = "./temp" 18 | if not os.path.exists(TEMP_DIR): 19 | os.makedirs(TEMP_DIR, exist_ok=True) 20 | 21 | 22 | def data_preprocess_server(args, image_path, audio_path, prompts, feature_extractor): 23 | llava_transform = transforms.Compose( 24 | [ 25 | transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BILINEAR), 26 | transforms.ToTensor(), 27 | transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)), 28 | ] 29 | ) 30 | 31 | """ 生成prompt """ 32 | if prompts is None: 33 | prompts = "Authentic, Realistic, Natural, High-quality, Lens-Fixed." 
34 | else: 35 | prompts = "Authentic, Realistic, Natural, High-quality, Lens-Fixed, " + prompts 36 | 37 | fps = 25 38 | 39 | img_size = args.image_size 40 | ref_image = Image.open(image_path).convert('RGB') 41 | 42 | # Resize reference image 43 | w, h = ref_image.size 44 | scale = img_size / min(w, h) 45 | new_w = round(w * scale / 64) * 64 46 | new_h = round(h * scale / 64) * 64 47 | 48 | if img_size == 704: 49 | img_size_long = 1216 50 | if new_w * new_h > img_size * img_size_long: 51 | scale = math.sqrt(img_size * img_size_long / w / h) 52 | new_w = round(w * scale / 64) * 64 53 | new_h = round(h * scale / 64) * 64 54 | 55 | ref_image = ref_image.resize((new_w, new_h), Image.LANCZOS) 56 | 57 | ref_image = np.array(ref_image) 58 | ref_image = torch.from_numpy(ref_image) 59 | 60 | audio_input, audio_len = get_audio_feature(feature_extractor, audio_path) 61 | audio_prompts = audio_input[0] 62 | 63 | motion_bucket_id_heads = np.array([25] * 4) 64 | motion_bucket_id_exps = np.array([30] * 4) 65 | motion_bucket_id_heads = torch.from_numpy(motion_bucket_id_heads) 66 | motion_bucket_id_exps = torch.from_numpy(motion_bucket_id_exps) 67 | fps = torch.from_numpy(np.array(fps)) 68 | 69 | to_pil = ToPILImage() 70 | pixel_value_ref = rearrange(ref_image.clone().unsqueeze(0), "b h w c -> b c h w") # (b c h w) 71 | 72 | pixel_value_ref_llava = [llava_transform(to_pil(image)) for image in pixel_value_ref] 73 | pixel_value_ref_llava = torch.stack(pixel_value_ref_llava, dim=0) 74 | 75 | batch = { 76 | "text_prompt": [prompts], 77 | "audio_path": [audio_path], 78 | "image_path": [image_path], 79 | "fps": fps.unsqueeze(0).to(dtype=torch.float16), 80 | "audio_prompts": audio_prompts.unsqueeze(0).to(dtype=torch.float16), 81 | "audio_len": [audio_len], 82 | "motion_bucket_id_exps": motion_bucket_id_exps.unsqueeze(0), 83 | "motion_bucket_id_heads": motion_bucket_id_heads.unsqueeze(0), 84 | "pixel_value_ref": pixel_value_ref.unsqueeze(0).to(dtype=torch.float16), 85 | "pixel_value_ref_llava": pixel_value_ref_llava.unsqueeze(0).to(dtype=torch.float16) 86 | } 87 | 88 | return batch 89 | 90 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, quality=8): 91 | videos = rearrange(videos, "b c t h w -> t b c h w") 92 | outputs = [] 93 | for x in videos: 94 | x = torchvision.utils.make_grid(x, nrow=n_rows) 95 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 96 | if rescale: 97 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 98 | x = torch.clamp(x,0,1) 99 | x = (x * 255).numpy().astype(np.uint8) 100 | outputs.append(x) 101 | 102 | os.makedirs(os.path.dirname(path), exist_ok=True) 103 | imageio.mimsave(path, outputs, fps=fps, quality=quality) 104 | 105 | def encode_image_to_base64(image_path): 106 | try: 107 | with open(image_path, 'rb') as image_file: 108 | image_data = image_file.read() 109 | encoded_data = base64.b64encode(image_data).decode('utf-8') 110 | print(f"Image file '{image_path}' has been successfully encoded to Base64.") 111 | return encoded_data 112 | 113 | except Exception as e: 114 | print(f"Error encoding image: {e}") 115 | return None 116 | 117 | def encode_video_to_base64(video_path): 118 | try: 119 | with open(video_path, 'rb') as video_file: 120 | video_data = video_file.read() 121 | encoded_data = base64.b64encode(video_data).decode('utf-8') 122 | print(f"Video file '{video_path}' has been successfully encoded to Base64.") 123 | return encoded_data 124 | 125 | except Exception as e: 126 | print(f"Error encoding video: {e}") 127 | return None 128 | 129 | def 
encode_wav_to_base64(wav_path): 130 | try: 131 | with open(wav_path, 'rb') as audio_file: 132 | audio_data = audio_file.read() 133 | encoded_data = base64.b64encode(audio_data).decode('utf-8') 134 | print(f"Audio file '{wav_path}' has been successfully encoded to Base64.") 135 | return encoded_data 136 | 137 | except Exception as e: 138 | print(f"Error encoding audio: {e}") 139 | return None 140 | 141 | def encode_pkl_to_base64(pkl_path): 142 | try: 143 | with open(pkl_path, 'rb') as pkl_file: 144 | pkl_data = pkl_file.read() 145 | 146 | encoded_data = base64.b64encode(pkl_data).decode('utf-8') 147 | 148 | print(f"Pickle file '{pkl_path}' has been successfully encoded to Base64.") 149 | return encoded_data 150 | 151 | except Exception as e: 152 | print(f"Error encoding pickle: {e}") 153 | return None 154 | 155 | def decode_base64_to_image(base64_buffer_str): 156 | try: 157 | image_data = base64.b64decode(base64_buffer_str) 158 | image = Image.open(io.BytesIO(image_data)) 159 | image_array = np.array(image) 160 | print(f"Image Base64 string has beed succesfully decoded to image.") 161 | return image_array 162 | except Exception as e: 163 | print(f"Error encdecodingoding image: {e}") 164 | return None 165 | 166 | def decode_base64_to_video(base64_buffer_str): 167 | try: 168 | video_data = base64.b64decode(base64_buffer_str) 169 | video_bytes = io.BytesIO(video_data) 170 | video_bytes.seek(0) 171 | video_reader = imageio.get_reader(video_bytes, 'ffmpeg') 172 | video_frames = [frame for frame in video_reader] 173 | return video_frames 174 | except Exception as e: 175 | print(f"Error decoding video: {e}") 176 | return None 177 | 178 | 179 | def save_video_base64_to_local(video_path=None, base64_buffer=None, output_video_path=None): 180 | if video_path is not None and base64_buffer is None: 181 | video_buffer_base64 = encode_video_to_base64(video_path) 182 | elif video_path is None and base64_buffer is not None: 183 | video_buffer_base64 = deepcopy(base64_buffer) 184 | else: 185 | print("Please pass either 'video_path' or 'base64_buffer'") 186 | return None 187 | 188 | if video_buffer_base64 is not None: 189 | video_data = base64.b64decode(video_buffer_base64) 190 | if output_video_path is None: 191 | uuid_string = str(uuid.uuid4()) 192 | temp_video_path = f'{TEMP_DIR}/{uuid_string}.mp4' 193 | else: 194 | temp_video_path = output_video_path 195 | with open(temp_video_path, 'wb') as video_file: 196 | video_file.write(video_data) 197 | return temp_video_path 198 | else: 199 | return None 200 | 201 | def save_audio_base64_to_local(audio_path=None, base64_buffer=None): 202 | if audio_path is not None and base64_buffer is None: 203 | audio_buffer_base64 = encode_wav_to_base64(audio_path) 204 | elif audio_path is None and base64_buffer is not None: 205 | audio_buffer_base64 = deepcopy(base64_buffer) 206 | else: 207 | print("Please pass either 'audio_path' or 'base64_buffer'") 208 | return None 209 | 210 | if audio_buffer_base64 is not None: 211 | audio_data = base64.b64decode(audio_buffer_base64) 212 | uuid_string = str(uuid.uuid4()) 213 | temp_audio_path = f'{TEMP_DIR}/{uuid_string}.wav' 214 | with open(temp_audio_path, 'wb') as audio_file: 215 | audio_file.write(audio_data) 216 | return temp_audio_path 217 | else: 218 | return None 219 | 220 | def save_pkl_base64_to_local(pkl_path=None, base64_buffer=None): 221 | if pkl_path is not None and base64_buffer is None: 222 | pkl_buffer_base64 = encode_pkl_to_base64(pkl_path) 223 | elif pkl_path is None and base64_buffer is not None: 224 | 
pkl_buffer_base64 = deepcopy(base64_buffer) 225 | else: 226 | print("Please pass either 'pkl_path' or 'base64_buffer'") 227 | return None 228 | 229 | if pkl_buffer_base64 is not None: 230 | pkl_data = base64.b64decode(pkl_buffer_base64) 231 | uuid_string = str(uuid.uuid4()) 232 | temp_pkl_path = f'{TEMP_DIR}/{uuid_string}.pkl' 233 | with open(temp_pkl_path, 'wb') as pkl_file: 234 | pkl_file.write(pkl_data) 235 | return temp_pkl_path 236 | else: 237 | return None 238 | 239 | def remove_temp_fles(input_dict): 240 | for key, val in input_dict.items(): 241 | if "_path" in key and val is not None and os.path.exists(val): 242 | os.remove(val) 243 | print(f"Remove temporary {key} from {val}") 244 | 245 | def process_output_dict(output_dict): 246 | 247 | uuid_string = str(uuid.uuid4()) 248 | temp_video_path = f'{TEMP_DIR}/{uuid_string}.mp4' 249 | imageio.mimsave(temp_video_path, output_dict["video"], fps=output_dict.get("save_fps", 25)) 250 | 251 | # Add audio 252 | if output_dict["audio"] is not None and os.path.exists(output_dict["audio"]): 253 | output_path = temp_video_path 254 | audio_path = output_dict["audio"] 255 | save_path = temp_video_path.replace(".mp4", "_audio.mp4") 256 | print('='*100) 257 | print(f"output_path = {output_path}\n audio_path = {audio_path}\n save_path = {save_path}") 258 | os.system(f"ffmpeg -i '{output_path}' -i '{audio_path}' -shortest '{save_path}' -y -loglevel quiet; rm '{output_path}'") 259 | else: 260 | save_path = temp_video_path 261 | 262 | video_base64_buffer = encode_video_to_base64(save_path) 263 | 264 | encoded_output_dict = { 265 | "errCode": output_dict["err_code"], 266 | "content": [ 267 | { 268 | "buffer": video_base64_buffer 269 | }, 270 | ], 271 | "info":output_dict["err_msg"], 272 | } 273 | 274 | 275 | 276 | return encoded_output_dict 277 | 278 | 279 | def save_image_base64_to_local(image_path=None, base64_buffer=None): 280 | # Encode image to base64 buffer 281 | if image_path is not None and base64_buffer is None: 282 | image_buffer_base64 = encode_image_to_base64(image_path) 283 | elif image_path is None and base64_buffer is not None: 284 | image_buffer_base64 = deepcopy(base64_buffer) 285 | else: 286 | print("Please pass either 'image_path' or 'base64_buffer'") 287 | return None 288 | 289 | # Decode base64 buffer and save to local disk 290 | if image_buffer_base64 is not None: 291 | image_data = base64.b64decode(image_buffer_base64) 292 | uuid_string = str(uuid.uuid4()) 293 | temp_image_path = f'{TEMP_DIR}/{uuid_string}.png' 294 | with open(temp_image_path, 'wb') as image_file: 295 | image_file.write(image_data) 296 | return temp_image_path 297 | else: 298 | return None 299 | 300 | def process_input_dict(input_dict): 301 | 302 | decoded_input_dict = {} 303 | 304 | decoded_input_dict["save_fps"] = input_dict.get("save_fps", 25) 305 | 306 | image_base64_buffer = input_dict.get("image_buffer", None) 307 | if image_base64_buffer is not None: 308 | decoded_input_dict["image_path"] = save_image_base64_to_local( 309 | image_path=None, 310 | base64_buffer=image_base64_buffer) 311 | else: 312 | decoded_input_dict["image_path"] = None 313 | 314 | audio_base64_buffer = input_dict.get("audio_buffer", None) 315 | if audio_base64_buffer is not None: 316 | decoded_input_dict["audio_path"] = save_audio_base64_to_local( 317 | audio_path=None, 318 | base64_buffer=audio_base64_buffer) 319 | else: 320 | decoded_input_dict["audio_path"] = None 321 | 322 | decoded_input_dict["prompt"] = input_dict.get("text", None) 323 | 324 | return decoded_input_dict 
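The helpers above define the JSON contract used by the Flask/Gradio end-to-end tools: process_input_dict() expects base64 strings under "image_buffer" and "audio_buffer" plus an optional "text" prompt and "save_fps", while process_output_dict() returns the generated clip as a base64 "buffer" inside "content". The client-side sketch below illustrates that round trip; it is not part of the repository, and the endpoint URL and prompt text are placeholder assumptions (the real route is whatever hymm_gradio/flask_audio.py registers).

import base64
import requests  # assumed to be available in the client environment

def file_to_base64(path):
    # Mirrors encode_*_to_base64 above: raw bytes -> UTF-8 base64 string
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

payload = {
    "image_buffer": file_to_base64("assets/image/1.png"),
    "audio_buffer": file_to_base64("assets/audio/2.WAV"),
    "text": "A person talking in a fixed shot",  # optional prompt
    "save_fps": 25,
}
# Placeholder URL; replace with the route actually exposed by the Flask service.
response = requests.post("http://127.0.0.1:8080/predict", json=payload, timeout=600).json()
# process_output_dict() packs the video as base64 under content[0]["buffer"].
with open("generated.mp4", "wb") as f:
    f.write(base64.b64decode(response["content"][0]["buffer"]))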
-------------------------------------------------------------------------------- /hymm_sp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/__init__.py -------------------------------------------------------------------------------- /hymm_sp/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/__pycache__/config.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/__pycache__/config.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/__pycache__/constants.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/__pycache__/constants.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/__pycache__/helpers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/__pycache__/helpers.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/__pycache__/inference.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/__pycache__/inference.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/__pycache__/sample_inference_audio.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/__pycache__/sample_inference_audio.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from hymm_sp.constants import * 3 | import re 4 | import collections.abc 5 | 6 | def as_tuple(x): 7 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 8 | return tuple(x) 9 | if x is None or isinstance(x, (int, float, str)): 10 | return (x,) 11 | else: 12 | raise ValueError(f"Unknown type {type(x)}") 13 | 14 | def parse_args(namespace=None): 15 | parser = argparse.ArgumentParser(description="Hunyuan Multimodal training/inference script") 16 | parser = add_extra_args(parser) 17 | args = parser.parse_args(namespace=namespace) 18 | args = sanity_check_args(args) 19 | return args 20 | 21 | def add_extra_args(parser: argparse.ArgumentParser): 22 | parser = add_network_args(parser) 23 | parser = add_extra_models_args(parser) 24 | parser = add_denoise_schedule_args(parser) 25 | parser = 
add_evaluation_args(parser) 26 | return parser 27 | 28 | def add_network_args(parser: argparse.ArgumentParser): 29 | group = parser.add_argument_group(title="Network") 30 | group.add_argument("--model", type=str, default="HYVideo-T/2", 31 | help="Model architecture to use. It is also used to determine the experiment directory.") 32 | group.add_argument("--latent-channels", type=int, default=None, 33 | help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, " 34 | "it still needs to match the latent channels of the VAE model.") 35 | group.add_argument("--rope-theta", type=int, default=256, help="Theta used in RoPE.") 36 | return parser 37 | 38 | def add_extra_models_args(parser: argparse.ArgumentParser): 39 | group = parser.add_argument_group(title="Extra Models (VAE, Text Encoder, Tokenizer)") 40 | 41 | # VAE 42 | group.add_argument("--vae", type=str, default="884-16c-hy0801", help="Name of the VAE model.") 43 | group.add_argument("--vae-precision", type=str, default="fp16", 44 | help="Precision mode for the VAE model.") 45 | group.add_argument("--vae-tiling", action="store_true", default=True, help="Enable tiling for the VAE model.") 46 | group.add_argument("--text-encoder", type=str, default="llava-llama-3-8b", choices=list(TEXT_ENCODER_PATH), 47 | help="Name of the text encoder model.") 48 | group.add_argument("--text-encoder-precision", type=str, default="fp16", choices=PRECISIONS, 49 | help="Precision mode for the text encoder model.") 50 | group.add_argument("--text-states-dim", type=int, default=4096, help="Dimension of the text encoder hidden states.") 51 | group.add_argument("--text-len", type=int, default=256, help="Maximum length of the text input.") 52 | group.add_argument("--tokenizer", type=str, default="llava-llama-3-8b", choices=list(TOKENIZER_PATH), 53 | help="Name of the tokenizer model.") 54 | group.add_argument("--text-encoder-infer-mode", type=str, default="encoder", choices=["encoder", "decoder"], 55 | help="Inference mode for the text encoder model. It should match the text encoder type. 
T5 and " 56 | "CLIP can only work in 'encoder' mode, while Llava/GLM can work in both modes.") 57 | group.add_argument("--prompt-template-video", type=str, default='li-dit-encode-video', choices=PROMPT_TEMPLATE, 58 | help="Video prompt template for the decoder-only text encoder model.") 59 | group.add_argument("--hidden-state-skip-layer", type=int, default=2, 60 | help="Skip layer for hidden states.") 61 | group.add_argument("--apply-final-norm", action="store_true", 62 | help="Apply final normalization to the used text encoder hidden states.") 63 | 64 | # - CLIP 65 | group.add_argument("--text-encoder-2", type=str, default='clipL', choices=list(TEXT_ENCODER_PATH), 66 | help="Name of the second text encoder model.") 67 | group.add_argument("--text-encoder-precision-2", type=str, default="fp16", choices=PRECISIONS, 68 | help="Precision mode for the second text encoder model.") 69 | group.add_argument("--text-states-dim-2", type=int, default=768, 70 | help="Dimension of the second text encoder hidden states.") 71 | group.add_argument("--tokenizer-2", type=str, default='clipL', choices=list(TOKENIZER_PATH), 72 | help="Name of the second tokenizer model.") 73 | group.add_argument("--text-len-2", type=int, default=77, help="Maximum length of the second text input.") 74 | group.set_defaults(use_attention_mask=True) 75 | group.add_argument("--text-projection", type=str, default="single_refiner", choices=TEXT_PROJECTION, 76 | help="A projection layer for bridging the text encoder hidden states and the diffusion model " 77 | "conditions.") 78 | return parser 79 | 80 | 81 | def add_denoise_schedule_args(parser: argparse.ArgumentParser): 82 | group = parser.add_argument_group(title="Denoise schedule") 83 | group.add_argument("--flow-shift-eval-video", type=float, default=None, help="Shift factor for flow matching schedulers when using video data.") 84 | group.add_argument("--flow-reverse", action="store_true", default=True, help="If reverse, learning/sampling from t=1 -> t=0.") 85 | group.add_argument("--flow-solver", type=str, default="euler", help="Solver for flow matching.") 86 | group.add_argument("--use-linear-quadratic-schedule", action="store_true", help="Use linear quadratic schedule for flow matching." 87 | "Follow MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)") 88 | group.add_argument("--linear-schedule-end", type=int, default=25, help="End step for linear quadratic schedule for flow matching.") 89 | return parser 90 | 91 | def add_evaluation_args(parser: argparse.ArgumentParser): 92 | group = parser.add_argument_group(title="Validation Loss Evaluation") 93 | parser.add_argument("--precision", type=str, default="bf16", choices=PRECISIONS, 94 | help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.") 95 | parser.add_argument("--reproduce", action="store_true", 96 | help="Enable reproducibility by setting random seeds and deterministic algorithms.") 97 | parser.add_argument("--ckpt", type=str, help="Path to the checkpoint to evaluate.") 98 | parser.add_argument("--load-key", type=str, default="module", choices=["module", "ema"], 99 | help="Key to load the model states. 
'module' for the main model, 'ema' for the EMA model.") 100 | parser.add_argument("--cpu-offload", action="store_true", help="Use CPU offload for the model load.") 101 | parser.add_argument("--infer-min", action="store_true", help="infer 5s.") 102 | group.add_argument( "--use-fp8", action="store_true", help="Enable use fp8 for inference acceleration.") 103 | group.add_argument("--video-size", type=int, nargs='+', default=512, 104 | help="Video size for training. If a single value is provided, it will be used for both width " 105 | "and height. If two values are provided, they will be used for width and height " 106 | "respectively.") 107 | group.add_argument("--sample-n-frames", type=int, default=1, 108 | help="How many frames to sample from a video. if using 3d vae, the number should be 4n+1") 109 | group.add_argument("--infer-steps", type=int, default=100, help="Number of denoising steps for inference.") 110 | group.add_argument("--val-disable-autocast", action="store_true", 111 | help="Disable autocast for denoising loop and vae decoding in pipeline sampling.") 112 | group.add_argument("--num-images", type=int, default=1, help="Number of images to generate for each prompt.") 113 | group.add_argument("--seed", type=int, default=1024, help="Seed for evaluation.") 114 | group.add_argument("--save-path-suffix", type=str, default="", help="Suffix for the directory of saved samples.") 115 | group.add_argument("--pos-prompt", type=str, default='', help="Prompt for sampling during evaluation.") 116 | group.add_argument("--neg-prompt", type=str, default='', help="Negative prompt for sampling during evaluation.") 117 | group.add_argument("--image-size", type=int, default=704) 118 | group.add_argument("--pad-face-size", type=float, default=0.7, help="Pad bbox for face align.") 119 | group.add_argument("--image-path", type=str, default="", help="") 120 | group.add_argument("--save-path", type=str, default=None, help="Path to save the generated samples.") 121 | group.add_argument("--input", type=str, default=None, help="test data.") 122 | group.add_argument("--item-name", type=str, default=None, help="") 123 | group.add_argument("--cfg-scale", type=float, default=7.5, help="Classifier free guidance scale.") 124 | group.add_argument("--ip-cfg-scale", type=float, default=0, help="Classifier free guidance scale.") 125 | group.add_argument("--use-deepcache", type=int, default=1) 126 | return parser 127 | 128 | def sanity_check_args(args): 129 | # VAE channels 130 | vae_pattern = r"\d{2,3}-\d{1,2}c-\w+" 131 | if not re.match(vae_pattern, args.vae): 132 | raise ValueError( 133 | f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'." 134 | ) 135 | vae_channels = int(args.vae.split("-")[1][:-1]) 136 | if args.latent_channels is None: 137 | args.latent_channels = vae_channels 138 | if vae_channels != args.latent_channels: 139 | raise ValueError( 140 | f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels})." 
141 | ) 142 | return args 143 | -------------------------------------------------------------------------------- /hymm_sp/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | __all__ = [ 5 | "PROMPT_TEMPLATE", "MODEL_BASE", "PRECISION_TO_TYPE", 6 | "PRECISIONS", "VAE_PATH", "TEXT_ENCODER_PATH", "TOKENIZER_PATH", 7 | "TEXT_PROJECTION", 8 | ] 9 | 10 | # =================== Constant Values ===================== 11 | 12 | PRECISION_TO_TYPE = { 13 | 'fp32': torch.float32, 14 | 'fp16': torch.float16, 15 | 'bf16': torch.bfloat16, 16 | } 17 | 18 | PROMPT_TEMPLATE_ENCODE_VIDEO = ( 19 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " 20 | "1. The main content and theme of the video." 21 | "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." 22 | "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." 23 | "4. background environment, light, style and atmosphere." 24 | "5. camera angles, movements, and transitions used in the video:<|eot_id|>" 25 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" 26 | ) 27 | 28 | PROMPT_TEMPLATE = { 29 | "li-dit-encode-video": {"template": PROMPT_TEMPLATE_ENCODE_VIDEO, "crop_start": 95}, 30 | } 31 | 32 | # ======================= Model ====================== 33 | PRECISIONS = {"fp32", "fp16", "bf16"} 34 | 35 | # =================== Model Path ===================== 36 | MODEL_BASE = os.getenv("MODEL_BASE") 37 | MODEL_BASE=f"{MODEL_BASE}/ckpts" 38 | 39 | # 3D VAE 40 | VAE_PATH = { 41 | "884-16c-hy0801": f"{MODEL_BASE}/hunyuan-video-t2v-720p/vae", 42 | } 43 | 44 | # Text Encoder 45 | TEXT_ENCODER_PATH = { 46 | "clipL": f"{MODEL_BASE}/text_encoder_2", 47 | "llava-llama-3-8b": f"{MODEL_BASE}/llava_llama_image", 48 | } 49 | 50 | # Tokenizer 51 | TOKENIZER_PATH = { 52 | "clipL": f"{MODEL_BASE}/text_encoder_2", 53 | "llava-llama-3-8b":f"{MODEL_BASE}/llava_llama_image", 54 | } 55 | 56 | TEXT_PROJECTION = { 57 | "linear", # Default, an nn.Linear() layer 58 | "single_refiner", # Single TokenRefiner. 
Refer to LI-DiT 59 | } 60 | -------------------------------------------------------------------------------- /hymm_sp/data_kits/__pycache__/audio_dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/data_kits/__pycache__/audio_dataset.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/data_kits/__pycache__/audio_preprocessor.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/data_kits/__pycache__/audio_preprocessor.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/data_kits/__pycache__/data_tools.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/data_kits/__pycache__/data_tools.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/data_kits/__pycache__/ffmpeg_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/data_kits/__pycache__/ffmpeg_utils.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/data_kits/audio_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import math 4 | import json 5 | import torch 6 | import random 7 | import librosa 8 | import traceback 9 | import torchvision 10 | import numpy as np 11 | import pandas as pd 12 | from PIL import Image 13 | from einops import rearrange 14 | from torch.utils.data import Dataset 15 | from decord import VideoReader, cpu 16 | from transformers import CLIPImageProcessor 17 | import torchvision.transforms as transforms 18 | from torchvision.transforms import ToPILImage 19 | 20 | 21 | 22 | def get_audio_feature(feature_extractor, audio_path): 23 | audio_input, sampling_rate = librosa.load(audio_path, sr=16000) 24 | assert sampling_rate == 16000 25 | 26 | audio_features = [] 27 | window = 750*640 28 | for i in range(0, len(audio_input), window): 29 | audio_feature = feature_extractor(audio_input[i:i+window], 30 | sampling_rate=sampling_rate, 31 | return_tensors="pt", 32 | ).input_features 33 | audio_features.append(audio_feature) 34 | 35 | audio_features = torch.cat(audio_features, dim=-1) 36 | return audio_features, len(audio_input) // 640 37 | 38 | 39 | class VideoAudioTextLoaderVal(Dataset): 40 | def __init__( 41 | self, 42 | image_size: int, 43 | meta_file: str, 44 | **kwargs, 45 | ): 46 | super().__init__() 47 | self.meta_file = meta_file 48 | self.image_size = image_size 49 | self.text_encoder = kwargs.get("text_encoder", None) # llava_text_encoder 50 | self.text_encoder_2 = kwargs.get("text_encoder_2", None) # clipL_text_encoder 51 | self.feature_extractor = kwargs.get("feature_extractor", None) 52 | self.meta_files = [] 53 | 54 | csv_data = pd.read_csv(meta_file) 55 | for idx in range(len(csv_data)): 56 | self.meta_files.append( 57 | { 58 | "videoid": str(csv_data["videoid"][idx]), 
59 | "image_path": str(csv_data["image"][idx]), 60 | "audio_path": str(csv_data["audio"][idx]), 61 | "prompt": str(csv_data["prompt"][idx]), 62 | "fps": float(csv_data["fps"][idx]) 63 | } 64 | ) 65 | 66 | self.llava_transform = transforms.Compose( 67 | [ 68 | transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BILINEAR), 69 | transforms.ToTensor(), 70 | transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)), 71 | ] 72 | ) 73 | self.clip_image_processor = CLIPImageProcessor() 74 | 75 | self.device = torch.device("cuda") 76 | self.weight_dtype = torch.float16 77 | 78 | 79 | def __len__(self): 80 | return len(self.meta_files) 81 | 82 | @staticmethod 83 | def get_text_tokens(text_encoder, description, dtype_encode="video"): 84 | text_inputs = text_encoder.text2tokens(description, data_type=dtype_encode) 85 | text_ids = text_inputs["input_ids"].squeeze(0) 86 | text_mask = text_inputs["attention_mask"].squeeze(0) 87 | return text_ids, text_mask 88 | 89 | def get_batch_data(self, idx): 90 | meta_file = self.meta_files[idx] 91 | videoid = meta_file["videoid"] 92 | image_path = meta_file["image_path"] 93 | audio_path = meta_file["audio_path"] 94 | prompt = "Authentic, Realistic, Natural, High-quality, Lens-Fixed, " + meta_file["prompt"] 95 | fps = meta_file["fps"] 96 | 97 | img_size = self.image_size 98 | ref_image = Image.open(image_path).convert('RGB') 99 | 100 | # Resize reference image 101 | w, h = ref_image.size 102 | scale = img_size / min(w, h) 103 | new_w = round(w * scale / 64) * 64 104 | new_h = round(h * scale / 64) * 64 105 | 106 | if img_size == 704: 107 | img_size_long = 1216 108 | if new_w * new_h > img_size * img_size_long: 109 | import math 110 | scale = math.sqrt(img_size * img_size_long / w / h) 111 | new_w = round(w * scale / 64) * 64 112 | new_h = round(h * scale / 64) * 64 113 | 114 | ref_image = ref_image.resize((new_w, new_h), Image.LANCZOS) 115 | 116 | ref_image = np.array(ref_image) 117 | ref_image = torch.from_numpy(ref_image) 118 | 119 | audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_path) 120 | audio_prompts = audio_input[0] 121 | 122 | motion_bucket_id_heads = np.array([25] * 4) 123 | motion_bucket_id_exps = np.array([30] * 4) 124 | motion_bucket_id_heads = torch.from_numpy(motion_bucket_id_heads) 125 | motion_bucket_id_exps = torch.from_numpy(motion_bucket_id_exps) 126 | fps = torch.from_numpy(np.array(fps)) 127 | 128 | to_pil = ToPILImage() 129 | pixel_value_ref = rearrange(ref_image.clone().unsqueeze(0), "b h w c -> b c h w") # (b c h w) 130 | 131 | pixel_value_ref_llava = [self.llava_transform(to_pil(image)) for image in pixel_value_ref] 132 | pixel_value_ref_llava = torch.stack(pixel_value_ref_llava, dim=0) 133 | pixel_value_ref_clip = self.clip_image_processor( 134 | images=Image.fromarray((pixel_value_ref[0].permute(1,2,0)).data.cpu().numpy().astype(np.uint8)), 135 | return_tensors="pt" 136 | ).pixel_values[0] 137 | pixel_value_ref_clip = pixel_value_ref_clip.unsqueeze(0) 138 | 139 | # Encode text prompts 140 | 141 | text_ids, text_mask = self.get_text_tokens(self.text_encoder, prompt) 142 | text_ids_2, text_mask_2 = self.get_text_tokens(self.text_encoder_2, prompt) 143 | 144 | # Output batch 145 | batch = { 146 | "text_prompt": prompt, # 147 | "videoid": videoid, 148 | "pixel_value_ref": pixel_value_ref.to(dtype=torch.float16), # Reference image for VAE feature extraction, (1, 3, h, w), value range (0, 255) 149 | "pixel_value_ref_llava": pixel_value_ref_llava.to(dtype=torch.float16), # Reference image for LLaVA feature extraction, (1, 3, 336, 336), value range = CLIP value range
150 | "pixel_value_ref_clip": pixel_value_ref_clip.to(dtype=torch.float16), # Reference image for clip_image_encoder feature extraction, (1, 3, 244, 244), value range = CLIP value range 151 | "audio_prompts": audio_prompts.to(dtype=torch.float16), 152 | "motion_bucket_id_heads": motion_bucket_id_heads.to(dtype=text_ids.dtype), 153 | "motion_bucket_id_exps": motion_bucket_id_exps.to(dtype=text_ids.dtype), 154 | "fps": fps.to(dtype=torch.float16), 155 | "text_ids": text_ids.clone(), # corresponds to llava_text_encoder 156 | "text_mask": text_mask.clone(), # corresponds to llava_text_encoder 157 | "text_ids_2": text_ids_2.clone(), # corresponds to clip_text_encoder 158 | "text_mask_2": text_mask_2.clone(), # corresponds to clip_text_encoder 159 | "audio_len": audio_len, 160 | "image_path": image_path, 161 | "audio_path": audio_path, 162 | } 163 | return batch 164 | 165 | def __getitem__(self, idx): 166 | return self.get_batch_data(idx) 167 | 168 | 169 | 170 | -------------------------------------------------------------------------------- /hymm_sp/data_kits/audio_preprocessor.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import cv2 4 | import json 5 | import time 6 | import decord 7 | import einops 8 | import librosa 9 | import torch 10 | import random 11 | import argparse 12 | import traceback 13 | import numpy as np 14 | from tqdm import tqdm 15 | from PIL import Image 16 | from einops import rearrange 17 | 18 | 19 | 20 | def get_facemask(ref_image, align_instance, area=1.25): 21 | # ref_image: (b f c h w) 22 | bsz, f, c, h, w = ref_image.shape 23 | images = rearrange(ref_image, "b f c h w -> (b f) h w c").data.cpu().numpy().astype(np.uint8) 24 | face_masks = [] 25 | for image in images: 26 | image_pil = Image.fromarray(image).convert("RGB") 27 | _, _, bboxes_list = align_instance(np.array(image_pil)[:,:,[2,1,0]], maxface=True) 28 | try: 29 | bboxSrc = bboxes_list[0] 30 | except: 31 | bboxSrc = [0, 0, w, h] 32 | x1, y1, ww, hh = bboxSrc 33 | x2, y2 = x1 + ww, y1 + hh 34 | ww, hh = (x2-x1) * area, (y2-y1) * area 35 | center = [(x2+x1)//2, (y2+y1)//2] 36 | x1 = max(center[0] - ww//2, 0) 37 | y1 = max(center[1] - hh//2, 0) 38 | x2 = min(center[0] + ww//2, w) 39 | y2 = min(center[1] + hh//2, h) 40 | 41 | face_mask = np.zeros_like(np.array(image_pil)) 42 | face_mask[int(y1):int(y2), int(x1):int(x2)] = 1.0 43 | face_masks.append(torch.from_numpy(face_mask[...,:1])) 44 | face_masks = torch.stack(face_masks, dim=0) # (b*f, h, w, c) 45 | face_masks = rearrange(face_masks, "(b f) h w c -> b c f h w", b=bsz, f=f) 46 | face_masks = face_masks.to(device=ref_image.device, dtype=ref_image.dtype) 47 | return face_masks 48 | 49 | 50 | def encode_audio(wav2vec, audio_feats, fps, num_frames=129): 51 | if fps == 25: 52 | start_ts = [0] 53 | step_ts = [1] 54 | elif fps == 12.5: 55 | start_ts = [0] 56 | step_ts = [2] 57 | num_frames = min(num_frames, 400) 58 | audio_feats = wav2vec.encoder(audio_feats.unsqueeze(0)[:, :, :3000], output_hidden_states=True).hidden_states 59 | audio_feats = torch.stack(audio_feats, dim=2) 60 | audio_feats = torch.cat([torch.zeros_like(audio_feats[:,:4]), audio_feats], 1) 61 | 62 | audio_prompts = [] 63 | for bb in range(1): 64 | audio_feats_list = [] 65 | for f in range(num_frames): 66 | cur_t = (start_ts[bb] + f * step_ts[bb]) * 2 67 | audio_clip = audio_feats[bb:bb+1, cur_t: cur_t+10] 68 | audio_feats_list.append(audio_clip) 69 | audio_feats_list = torch.stack(audio_feats_list, 1) 70 | audio_prompts.append(audio_feats_list) 71 | audio_prompts = torch.cat(audio_prompts) 72 | return audio_prompts 
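encode_audio() above slices the encoder hidden states into one window per video frame: at 25 fps it appears to assume roughly two encoder feature tokens per frame (about 50 tokens per second of audio), stepping the read position by two tokens per frame and taking a 10-token context slice, after four zero tokens have been prepended as padding. A small arithmetic sketch of that indexing, written here for illustration only:

def frame_audio_window(frame_idx, start_ts=0, step_ts=1):
    # Same indexing as the inner loop of encode_audio(): the window start advances
    # by 2 * step_ts encoder tokens per video frame and spans 10 tokens.
    cur_t = (start_ts + frame_idx * step_ts) * 2
    return cur_t, cur_t + 10

for f in (0, 1, 2, 128):
    print(f, frame_audio_window(f))  # -> (0, 10), (2, 12), (4, 14), (256, 266)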
-------------------------------------------------------------------------------- /hymm_sp/data_kits/data_tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import imageio 6 | import torchvision 7 | from einops import rearrange 8 | 9 | 10 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, quality=8): 11 | videos = rearrange(videos, "b c t h w -> t b c h w") 12 | outputs = [] 13 | for x in videos: 14 | x = torchvision.utils.make_grid(x, nrow=n_rows) 15 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 16 | if rescale: 17 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 18 | x = torch.clamp(x,0,1) 19 | x = (x * 255).numpy().astype(np.uint8) 20 | outputs.append(x) 21 | 22 | os.makedirs(os.path.dirname(path), exist_ok=True) 23 | imageio.mimsave(path, outputs, fps=fps, quality=quality) 24 | 25 | def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1): 26 | crop_h, crop_w = crop_img.shape[:2] 27 | target_w, target_h = size 28 | scale_h, scale_w = target_h / crop_h, target_w / crop_w 29 | if scale_w > scale_h: 30 | resize_h = int(target_h*resize_ratio) 31 | resize_w = int(crop_w / crop_h * resize_h) 32 | else: 33 | resize_w = int(target_w*resize_ratio) 34 | resize_h = int(crop_h / crop_w * resize_w) 35 | crop_img = cv2.resize(crop_img, (resize_w, resize_h)) 36 | pad_left = (target_w - resize_w) // 2 37 | pad_top = (target_h - resize_h) // 2 38 | pad_right = target_w - resize_w - pad_left 39 | pad_bottom = target_h - resize_h - pad_top 40 | crop_img = cv2.copyMakeBorder(crop_img, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=color) 41 | return crop_img -------------------------------------------------------------------------------- /hymm_sp/data_kits/face_align/__init__.py: -------------------------------------------------------------------------------- 1 | from .align import AlignImage -------------------------------------------------------------------------------- /hymm_sp/data_kits/face_align/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/data_kits/face_align/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/data_kits/face_align/__pycache__/align.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/data_kits/face_align/__pycache__/align.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/data_kits/face_align/__pycache__/detface.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/data_kits/face_align/__pycache__/detface.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/data_kits/face_align/align.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | from .detface import DetFace 5 | 6 | class AlignImage(object): 7 | def __init__(self, device='cuda', 
det_path=''): 8 | self.facedet = DetFace(pt_path=det_path, confThreshold=0.5, nmsThreshold=0.45, device=device) 9 | 10 | @torch.no_grad() 11 | def __call__(self, im, maxface=False): 12 | bboxes, kpss, scores = self.facedet.detect(im) 13 | face_num = bboxes.shape[0] 14 | 15 | five_pts_list = [] 16 | scores_list = [] 17 | bboxes_list = [] 18 | for i in range(face_num): 19 | five_pts_list.append(kpss[i].reshape(5,2)) 20 | scores_list.append(scores[i]) 21 | bboxes_list.append(bboxes[i]) 22 | 23 | if maxface and face_num>1: 24 | max_idx = 0 25 | max_area = (bboxes[0, 2])*(bboxes[0, 3]) 26 | for i in range(1, face_num): 27 | area = (bboxes[i,2])*(bboxes[i,3]) 28 | if area>max_area: 29 | max_idx = i 30 | five_pts_list = [five_pts_list[max_idx]] 31 | scores_list = [scores_list[max_idx]] 32 | bboxes_list = [bboxes_list[max_idx]] 33 | 34 | return five_pts_list, scores_list, bboxes_list -------------------------------------------------------------------------------- /hymm_sp/data_kits/face_align/detface.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import os 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import torchvision 7 | 8 | 9 | def xyxy2xywh(x): 10 | # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right 11 | y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) 12 | y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center 13 | y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center 14 | y[:, 2] = x[:, 2] - x[:, 0] # width 15 | y[:, 3] = x[:, 3] - x[:, 1] # height 16 | return y 17 | 18 | 19 | def xywh2xyxy(x): 20 | # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right 21 | y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) 22 | y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x 23 | y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y 24 | y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x 25 | y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y 26 | return y 27 | 28 | 29 | def box_iou(box1, box2): 30 | # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py 31 | """ 32 | Return intersection-over-union (Jaccard index) of boxes. 33 | Both sets of boxes are expected to be in (x1, y1, x2, y2) format. 
34 | Arguments: 35 | box1 (Tensor[N, 4]) 36 | box2 (Tensor[M, 4]) 37 | Returns: 38 | iou (Tensor[N, M]): the NxM matrix containing the pairwise 39 | IoU values for every element in boxes1 and boxes2 40 | """ 41 | 42 | def box_area(box): 43 | # box = 4xn 44 | return (box[2] - box[0]) * (box[3] - box[1]) 45 | 46 | area1 = box_area(box1.T) 47 | area2 = box_area(box2.T) 48 | 49 | # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2) 50 | inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - 51 | torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2) 52 | # iou = inter / (area1 + area2 - inter) 53 | return inter / (area1[:, None] + area2 - inter) 54 | 55 | 56 | def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None): 57 | # Rescale coords (xyxy) from img1_shape to img0_shape 58 | if ratio_pad is None: # calculate from img0_shape 59 | gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new 60 | pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding 61 | else: 62 | gain = ratio_pad[0][0] 63 | pad = ratio_pad[1] 64 | 65 | coords[:, [0, 2]] -= pad[0] # x padding 66 | coords[:, [1, 3]] -= pad[1] # y padding 67 | coords[:, :4] /= gain 68 | clip_coords(coords, img0_shape) 69 | return coords 70 | 71 | 72 | def clip_coords(boxes, img_shape): 73 | # Clip bounding xyxy bounding boxes to image shape (height, width) 74 | boxes[:, 0].clamp_(0, img_shape[1]) # x1 75 | boxes[:, 1].clamp_(0, img_shape[0]) # y1 76 | boxes[:, 2].clamp_(0, img_shape[1]) # x2 77 | boxes[:, 3].clamp_(0, img_shape[0]) # y2 78 | 79 | 80 | def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None): 81 | # Rescale coords (xyxy) from img1_shape to img0_shape 82 | if ratio_pad is None: # calculate from img0_shape 83 | gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new 84 | pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding 85 | else: 86 | gain = ratio_pad[0][0] 87 | pad = ratio_pad[1] 88 | 89 | coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding 90 | coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding 91 | coords[:, :10] /= gain 92 | #clip_coords(coords, img0_shape) 93 | coords[:, 0].clamp_(0, img0_shape[1]) # x1 94 | coords[:, 1].clamp_(0, img0_shape[0]) # y1 95 | coords[:, 2].clamp_(0, img0_shape[1]) # x2 96 | coords[:, 3].clamp_(0, img0_shape[0]) # y2 97 | coords[:, 4].clamp_(0, img0_shape[1]) # x3 98 | coords[:, 5].clamp_(0, img0_shape[0]) # y3 99 | coords[:, 6].clamp_(0, img0_shape[1]) # x4 100 | coords[:, 7].clamp_(0, img0_shape[0]) # y4 101 | coords[:, 8].clamp_(0, img0_shape[1]) # x5 102 | coords[:, 9].clamp_(0, img0_shape[0]) # y5 103 | return coords 104 | 105 | 106 | def show_results(img, xywh, conf, landmarks, class_num): 107 | h,w,c = img.shape 108 | tl = 1 or round(0.002 * (h + w) / 2) + 1 # line/font thickness 109 | x1 = int(xywh[0] * w - 0.5 * xywh[2] * w) 110 | y1 = int(xywh[1] * h - 0.5 * xywh[3] * h) 111 | x2 = int(xywh[0] * w + 0.5 * xywh[2] * w) 112 | y2 = int(xywh[1] * h + 0.5 * xywh[3] * h) 113 | cv2.rectangle(img, (x1,y1), (x2, y2), (0,255,0), thickness=tl, lineType=cv2.LINE_AA) 114 | 115 | clors = [(255,0,0),(0,255,0),(0,0,255),(255,255,0),(0,255,255)] 116 | 117 | for i in range(5): 118 | point_x = int(landmarks[2 * i] * w) 119 | point_y = int(landmarks[2 * i + 1] * h) 120 | cv2.circle(img, (point_x, point_y), tl+1, clors[i], -1) 121 | 122 | tf = max(tl - 1, 1) # font thickness 123 | label = str(conf)[:5] 124 | 
cv2.putText(img, label, (x1, y1 - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA) 125 | return img 126 | 127 | 128 | def make_divisible(x, divisor): 129 | # Returns x evenly divisible by divisor 130 | return (x // divisor) * divisor 131 | 132 | 133 | def non_max_suppression_face(prediction, conf_thres=0.5, iou_thres=0.45, classes=None, agnostic=False, labels=()): 134 | """Performs Non-Maximum Suppression (NMS) on inference results 135 | Returns: 136 | detections with shape: nx6 (x1, y1, x2, y2, conf, cls) 137 | """ 138 | 139 | nc = prediction.shape[2] - 15 # number of classes 140 | xc = prediction[..., 4] > conf_thres # candidates 141 | 142 | # Settings 143 | min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height 144 | # time_limit = 10.0 # seconds to quit after 145 | redundant = True # require redundant detections 146 | multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) 147 | merge = False # use merge-NMS 148 | 149 | # t = time.time() 150 | output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0] 151 | for xi, x in enumerate(prediction): # image index, image inference 152 | # Apply constraints 153 | # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height 154 | x = x[xc[xi]] # confidence 155 | 156 | # Cat apriori labels if autolabelling 157 | if labels and len(labels[xi]): 158 | l = labels[xi] 159 | v = torch.zeros((len(l), nc + 15), device=x.device) 160 | v[:, :4] = l[:, 1:5] # box 161 | v[:, 4] = 1.0 # conf 162 | v[range(len(l)), l[:, 0].long() + 15] = 1.0 # cls 163 | x = torch.cat((x, v), 0) 164 | 165 | # If none remain process next image 166 | if not x.shape[0]: 167 | continue 168 | 169 | # Compute conf 170 | x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf 171 | 172 | # Box (center x, center y, width, height) to (x1, y1, x2, y2) 173 | box = xywh2xyxy(x[:, :4]) 174 | 175 | # Detections matrix nx6 (xyxy, conf, landmarks, cls) 176 | if multi_label: 177 | i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T 178 | x = torch.cat((box[i], x[i, j + 15, None], x[i, 5:15] ,j[:, None].float()), 1) 179 | else: # best class only 180 | conf, j = x[:, 15:].max(1, keepdim=True) 181 | x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres] 182 | 183 | # Filter by class 184 | if classes is not None: 185 | x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] 186 | 187 | # If none remain process next image 188 | n = x.shape[0] # number of boxes 189 | if not n: 190 | continue 191 | 192 | # Batched NMS 193 | c = x[:, 15:16] * (0 if agnostic else max_wh) # classes 194 | boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores 195 | i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS 196 | #if i.shape[0] > max_det: # limit detections 197 | # i = i[:max_det] 198 | if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean) 199 | # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) 200 | iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix 201 | weights = iou * scores[None] # box weights 202 | x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes 203 | if redundant: 204 | i = i[iou.sum(1) > 1] # require redundancy 205 | 206 | output[xi] = x[i] 207 | # if (time.time() - t) > time_limit: 208 | # break # time limit exceeded 209 | 210 | return output 211 | 212 | 213 | class DetFace(): 214 | def __init__(self, pt_path, confThreshold=0.5, nmsThreshold=0.45, device='cuda'): 215 | assert 
os.path.exists(pt_path) 216 | 217 | self.inpSize = 416 218 | self.conf_thres = confThreshold 219 | self.iou_thres = nmsThreshold 220 | self.test_device = torch.device(device if torch.cuda.is_available() else "cpu") 221 | self.model = torch.jit.load(pt_path).to(self.test_device) 222 | self.last_w = 416 223 | self.last_h = 416 224 | self.grids = None 225 | 226 | @torch.no_grad() 227 | def detect(self, srcimg): 228 | # t0=time.time() 229 | 230 | h0, w0 = srcimg.shape[:2] # orig hw 231 | r = self.inpSize / min(h0, w0) # resize image to img_size 232 | h1 = int(h0*r+31)//32*32 233 | w1 = int(w0*r+31)//32*32 234 | 235 | img = cv2.resize(srcimg, (w1,h1), interpolation=cv2.INTER_LINEAR) 236 | 237 | # Convert 238 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR to RGB 239 | 240 | # Run inference 241 | img = torch.from_numpy(img).to(self.test_device).permute(2,0,1) 242 | img = img.float()/255 # uint8 to fp16/32 0-1 243 | if img.ndimension() == 3: 244 | img = img.unsqueeze(0) 245 | 246 | # Inference 247 | if h1 != self.last_h or w1 != self.last_w or self.grids is None: 248 | grids = [] 249 | for scale in [8,16,32]: 250 | ny = h1//scale 251 | nx = w1//scale 252 | yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)]) 253 | grid = torch.stack((xv, yv), 2).view((1,1,ny, nx, 2)).float() 254 | grids.append(grid.to(self.test_device)) 255 | self.grids = grids 256 | self.last_w = w1 257 | self.last_h = h1 258 | 259 | pred = self.model(img, self.grids).cpu() 260 | 261 | # Apply NMS 262 | det = non_max_suppression_face(pred, self.conf_thres, self.iou_thres)[0] 263 | # Process detections 264 | # det = pred[0] 265 | bboxes = np.zeros((det.shape[0], 4)) 266 | kpss = np.zeros((det.shape[0], 5, 2)) 267 | scores = np.zeros((det.shape[0])) 268 | # gn = torch.tensor([w0, h0, w0, h0]).to(pred) # normalization gain whwh 269 | # gn_lks = torch.tensor([w0, h0, w0, h0, w0, h0, w0, h0, w0, h0]).to(pred) # normalization gain landmarks 270 | det = det.cpu().numpy() 271 | 272 | for j in range(det.shape[0]): 273 | # xywh = (xyxy2xywh(det[j, :4].view(1, 4)) / gn).view(4).cpu().numpy() 274 | bboxes[j, 0] = det[j, 0] * w0/w1 275 | bboxes[j, 1] = det[j, 1] * h0/h1 276 | bboxes[j, 2] = det[j, 2] * w0/w1 - bboxes[j, 0] 277 | bboxes[j, 3] = det[j, 3] * h0/h1 - bboxes[j, 1] 278 | scores[j] = det[j, 4] 279 | # landmarks = (det[j, 5:15].view(1, 10) / gn_lks).view(5,2).cpu().numpy() 280 | kpss[j, :, :] = det[j, 5:15].reshape(5, 2) * np.array([[w0/w1,h0/h1]]) 281 | # class_num = det[j, 15].cpu().numpy() 282 | # orgimg = show_results(orgimg, xywh, conf, landmarks, class_num) 283 | return bboxes, kpss, scores 284 | -------------------------------------------------------------------------------- /hymm_sp/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipelines import HunyuanVideoAudioPipeline 2 | from .schedulers import FlowMatchDiscreteScheduler 3 | 4 | 5 | def load_diffusion_pipeline(args, rank, vae, text_encoder, text_encoder_2, model, scheduler=None, 6 | device=None, progress_bar_config=None): 7 | """ Load the denoising scheduler for inference. 
""" 8 | if scheduler is None: 9 | scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift_eval_video, reverse=args.flow_reverse, solver=args.flow_solver, ) 10 | 11 | # Only enable progress bar for rank 0 12 | progress_bar_config = progress_bar_config or {'leave': True, 'disable': rank != 0} 13 | 14 | pipeline = HunyuanVideoAudioPipeline(vae=vae, 15 | text_encoder=text_encoder, 16 | text_encoder_2=text_encoder_2, 17 | transformer=model, 18 | scheduler=scheduler, 19 | # safety_checker=None, 20 | # feature_extractor=None, 21 | # requires_safety_checker=False, 22 | progress_bar_config=progress_bar_config, 23 | args=args, 24 | ) 25 | if args.cpu_offload: # avoid oom 26 | pass 27 | else: 28 | pipeline = pipeline.to(device) 29 | 30 | return pipeline -------------------------------------------------------------------------------- /hymm_sp/diffusion/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/diffusion/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/diffusion/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_hunyuan_video_audio import HunyuanVideoAudioPipeline 2 | -------------------------------------------------------------------------------- /hymm_sp/diffusion/pipelines/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/diffusion/pipelines/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/diffusion/pipelines/__pycache__/pipeline_hunyuan_video_audio.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/diffusion/pipelines/__pycache__/pipeline_hunyuan_video_audio.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/diffusion/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler -------------------------------------------------------------------------------- /hymm_sp/diffusion/schedulers/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/diffusion/schedulers/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/diffusion/schedulers/__pycache__/scheduling_flow_match_discrete.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/diffusion/schedulers/__pycache__/scheduling_flow_match_discrete.cpython-310.pyc -------------------------------------------------------------------------------- 
/hymm_sp/diffusion/schedulers/scheduling_flow_match_discrete.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | # 16 | # Modified from diffusers==0.29.2 17 | # 18 | # ============================================================================== 19 | 20 | from dataclasses import dataclass 21 | from typing import Optional, Tuple, Union 22 | 23 | import torch 24 | 25 | from diffusers.configuration_utils import ConfigMixin, register_to_config 26 | from diffusers.utils import BaseOutput, logging 27 | from diffusers.schedulers.scheduling_utils import SchedulerMixin 28 | 29 | 30 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 31 | 32 | 33 | @dataclass 34 | class FlowMatchDiscreteSchedulerOutput(BaseOutput): 35 | """ 36 | Output class for the scheduler's `step` function output. 37 | 38 | Args: 39 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 40 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the 41 | denoising loop. 42 | """ 43 | 44 | prev_sample: torch.FloatTensor 45 | 46 | 47 | class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin): 48 | """ 49 | Euler scheduler. 50 | 51 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic 52 | methods the library implements for all schedulers such as loading and saving. 53 | 54 | Args: 55 | num_train_timesteps (`int`, defaults to 1000): 56 | The number of diffusion steps to train the model. 57 | timestep_spacing (`str`, defaults to `"linspace"`): 58 | The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and 59 | Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. 60 | shift (`float`, defaults to 1.0): 61 | The shift value for the timestep schedule. 62 | reverse (`bool`, defaults to `True`): 63 | Whether to reverse the timestep schedule. 64 | """ 65 | 66 | _compatibles = [] 67 | order = 1 68 | 69 | @register_to_config 70 | def __init__( 71 | self, 72 | num_train_timesteps: int = 1000, 73 | shift: float = 1.0, 74 | reverse: bool = True, 75 | solver: str = "euler", 76 | n_tokens: Optional[int] = None, 77 | ): 78 | sigmas = torch.linspace(1, 0, num_train_timesteps + 1) 79 | 80 | if not reverse: 81 | sigmas = sigmas.flip(0) 82 | 83 | self.sigmas = sigmas 84 | # the value fed to model 85 | self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32) 86 | 87 | self._step_index = None 88 | self._begin_index = None 89 | 90 | self.supported_solver = ["euler"] 91 | if solver not in self.supported_solver: 92 | raise ValueError(f"Solver {solver} not supported. 
Supported solvers: {self.supported_solver}") 93 | 94 | @property 95 | def step_index(self): 96 | """ 97 | The index counter for current timestep. It will increase 1 after each scheduler step. 98 | """ 99 | return self._step_index 100 | 101 | @property 102 | def begin_index(self): 103 | """ 104 | The index for the first timestep. It should be set from pipeline with `set_begin_index` method. 105 | """ 106 | return self._begin_index 107 | 108 | # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index 109 | def set_begin_index(self, begin_index: int = 0): 110 | """ 111 | Sets the begin index for the scheduler. This function should be run from pipeline before the inference. 112 | 113 | Args: 114 | begin_index (`int`): 115 | The begin index for the scheduler. 116 | """ 117 | self._begin_index = begin_index 118 | 119 | def _sigma_to_t(self, sigma): 120 | return sigma * self.config.num_train_timesteps 121 | 122 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, 123 | n_tokens: int = None): 124 | """ 125 | Sets the discrete timesteps used for the diffusion chain (to be run before inference). 126 | 127 | Args: 128 | num_inference_steps (`int`): 129 | The number of diffusion steps used when generating samples with a pre-trained model. 130 | device (`str` or `torch.device`, *optional*): 131 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 132 | n_tokens (`int`, *optional*): 133 | Number of tokens in the input sequence. 134 | """ 135 | self.num_inference_steps = num_inference_steps 136 | 137 | sigmas = torch.linspace(1, 0, num_inference_steps + 1) 138 | sigmas = self.sd3_time_shift(sigmas) 139 | 140 | if not self.config.reverse: 141 | sigmas = 1 - sigmas 142 | 143 | self.sigmas = sigmas 144 | self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device) 145 | 146 | # Reset step index 147 | self._step_index = None 148 | 149 | def index_for_timestep(self, timestep, schedule_timesteps=None): 150 | if schedule_timesteps is None: 151 | schedule_timesteps = self.timesteps 152 | 153 | indices = (schedule_timesteps == timestep).nonzero() 154 | 155 | # The sigma index that is taken for the **very** first `step` 156 | # is always the second index (or the last index if there is only 1) 157 | # This way we can ensure we don't accidentally skip a sigma in 158 | # case we start in the middle of the denoising schedule (e.g. for image-to-image) 159 | pos = 1 if len(indices) > 1 else 0 160 | 161 | return indices[pos].item() 162 | 163 | def _init_step_index(self, timestep): 164 | if self.begin_index is None: 165 | if isinstance(timestep, torch.Tensor): 166 | timestep = timestep.to(self.timesteps.device) 167 | self._step_index = self.index_for_timestep(timestep) 168 | else: 169 | self._step_index = self._begin_index 170 | 171 | def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: 172 | return sample 173 | 174 | def sd3_time_shift(self, t: torch.Tensor): 175 | return (self.config.shift * t) / (1 + (self.config.shift - 1) * t) 176 | 177 | def step( 178 | self, 179 | model_output: torch.FloatTensor, 180 | timestep: Union[float, torch.FloatTensor], 181 | sample: torch.FloatTensor, 182 | return_dict: bool = True, 183 | ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]: 184 | """ 185 | Predict the sample from the previous timestep by reversing the SDE. 
This function propagates the diffusion 186 | process from the learned model outputs (most often the predicted noise). 187 | 188 | Args: 189 | model_output (`torch.FloatTensor`): 190 | The direct output from learned diffusion model. 191 | timestep (`float`): 192 | The current discrete timestep in the diffusion chain. 193 | sample (`torch.FloatTensor`): 194 | A current instance of a sample created by the diffusion process. 195 | return_dict (`bool`): 196 | Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or 197 | tuple. 198 | 199 | Returns: 200 | [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: 201 | If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is 202 | returned, otherwise a tuple is returned where the first element is the sample tensor. 203 | """ 204 | 205 | if ( 206 | isinstance(timestep, int) 207 | or isinstance(timestep, torch.IntTensor) 208 | or isinstance(timestep, torch.LongTensor) 209 | ): 210 | raise ValueError( 211 | ( 212 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" 213 | " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" 214 | " one of the `scheduler.timesteps` as a timestep." 215 | ), 216 | ) 217 | 218 | if self.step_index is None: 219 | self._init_step_index(timestep) 220 | 221 | # Upcast to avoid precision issues when computing prev_sample 222 | sample = sample.to(torch.float32) 223 | 224 | dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index] 225 | 226 | if self.config.solver == "euler": 227 | prev_sample = sample + model_output.float() * dt 228 | else: 229 | raise ValueError(f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}") 230 | 231 | # upon completion increase step index by one 232 | self._step_index += 1 233 | 234 | if not return_dict: 235 | return (prev_sample,) 236 | 237 | return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample) 238 | 239 | def __len__(self): 240 | return self.config.num_train_timesteps 241 | -------------------------------------------------------------------------------- /hymm_sp/helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, List 3 | from hymm_sp.modules.posemb_layers import get_1d_rotary_pos_embed, get_meshgrid_nd 4 | 5 | from itertools import repeat 6 | import collections.abc 7 | 8 | 9 | def _ntuple(n): 10 | def parse(x): 11 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 12 | x = tuple(x) 13 | if len(x) == 1: 14 | x = tuple(repeat(x[0], n)) 15 | return x 16 | return tuple(repeat(x, n)) 17 | return parse 18 | 19 | to_1tuple = _ntuple(1) 20 | to_2tuple = _ntuple(2) 21 | to_3tuple = _ntuple(3) 22 | to_4tuple = _ntuple(4) 23 | 24 | def get_rope_freq_from_size(latents_size, ndim, target_ndim, args, 25 | rope_theta_rescale_factor: Union[float, List[float]]=1.0, 26 | rope_interpolation_factor: Union[float, List[float]]=1.0, 27 | concat_dict={}): 28 | 29 | if isinstance(args.patch_size, int): 30 | assert all(s % args.patch_size == 0 for s in latents_size), \ 31 | f"Latent size(last {ndim} dimensions) should be divisible by patch size({args.patch_size}), " \ 32 | f"but got {latents_size}." 
33 | rope_sizes = [s // args.patch_size for s in latents_size] 34 | elif isinstance(args.patch_size, list): 35 | assert all(s % args.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \ 36 | f"Latent size(last {ndim} dimensions) should be divisible by patch size({args.patch_size}), " \ 37 | f"but got {latents_size}." 38 | rope_sizes = [s // args.patch_size[idx] for idx, s in enumerate(latents_size)] 39 | 40 | if len(rope_sizes) != target_ndim: 41 | rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis 42 | head_dim = args.hidden_size // args.num_heads 43 | rope_dim_list = args.rope_dim_list 44 | if rope_dim_list is None: 45 | rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] 46 | assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer" 47 | freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list, 48 | rope_sizes, 49 | theta=args.rope_theta, 50 | use_real=True, 51 | theta_rescale_factor=rope_theta_rescale_factor, 52 | interpolation_factor=rope_interpolation_factor, 53 | concat_dict=concat_dict) 54 | return freqs_cos, freqs_sin 55 | 56 | def get_nd_rotary_pos_embed_new(rope_dim_list, start, *args, theta=10000., use_real=False, 57 | theta_rescale_factor: Union[float, List[float]]=1.0, 58 | interpolation_factor: Union[float, List[float]]=1.0, 59 | concat_dict={} 60 | ): 61 | 62 | grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] 63 | if len(concat_dict)<1: 64 | pass 65 | else: 66 | if concat_dict['mode']=='timecat': 67 | bias = grid[:,:1].clone() 68 | bias[0] = concat_dict['bias']*torch.ones_like(bias[0]) 69 | grid = torch.cat([bias, grid], dim=1) 70 | 71 | elif concat_dict['mode']=='timecat-w': 72 | bias = grid[:,:1].clone() 73 | bias[0] = concat_dict['bias']*torch.ones_like(bias[0]) 74 | bias[2] += start[-1] ## ref https://github.com/Yuanshi9815/OminiControl/blob/main/src/generate.py#L178 75 | grid = torch.cat([bias, grid], dim=1) 76 | if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): 77 | theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) 78 | elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: 79 | theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list) 80 | assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)" 81 | 82 | if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float): 83 | interpolation_factor = [interpolation_factor] * len(rope_dim_list) 84 | elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: 85 | interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list) 86 | assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)" 87 | 88 | # use 1/ndim of dimensions to encode grid_axis 89 | embs = [] 90 | for i in range(len(rope_dim_list)): 91 | emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=use_real, 92 | theta_rescale_factor=theta_rescale_factor[i], 93 | interpolation_factor=interpolation_factor[i]) # 2 x [WHD, rope_dim_list[i]] 94 | 95 | embs.append(emb) 96 | 97 | if use_real: 98 | cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) 99 | sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) 100 | return cos, sin 101 | else: 102 | emb = torch.cat(embs, dim=1) # (WHD, D/2) 103 | return emb 
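# --- Usage sketch (added for illustration; not part of the original helpers.py) ---
# Shows how the n-d RoPE helper above is typically driven. The (8, 16, 16) patch grid
# and the 16 + 56 + 56 = 128 split of the head dimension are illustrative assumptions,
# not values taken from this repo's configs.
if __name__ == "__main__":
    freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(
        rope_dim_list=[16, 56, 56],   # must sum to the attention head_dim
        start=(8, 16, 16),            # (T, H, W) latent patch grid -> 8*16*16 = 2048 positions
        use_real=True,                # return separate cos/sin tables
    )
    # One row per (t, h, w) position, one column per head-dim channel.
    print(freqs_cos.shape, freqs_sin.shape)  # torch.Size([2048, 128]) torch.Size([2048, 128])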
-------------------------------------------------------------------------------- /hymm_sp/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pathlib import Path 3 | from loguru import logger 4 | from hymm_sp.constants import PROMPT_TEMPLATE, PRECISION_TO_TYPE 5 | from hymm_sp.vae import load_vae 6 | from hymm_sp.modules import load_model 7 | from hymm_sp.text_encoder import TextEncoder 8 | import torch.distributed 9 | from hymm_sp.modules.parallel_states import ( 10 | nccl_info, 11 | ) 12 | from hymm_sp.modules.fp8_optimization import convert_fp8_linear 13 | 14 | 15 | class Inference(object): 16 | def __init__(self, 17 | args, 18 | vae, 19 | vae_kwargs, 20 | text_encoder, 21 | model, 22 | text_encoder_2=None, 23 | pipeline=None, 24 | cpu_offload=False, 25 | device=None, 26 | logger=None): 27 | self.vae = vae 28 | self.vae_kwargs = vae_kwargs 29 | 30 | self.text_encoder = text_encoder 31 | self.text_encoder_2 = text_encoder_2 32 | 33 | self.model = model 34 | self.pipeline = pipeline 35 | self.cpu_offload = cpu_offload 36 | 37 | self.args = args 38 | self.device = device if device is not None else "cuda" if torch.cuda.is_available() else "cpu" 39 | if nccl_info.sp_size > 1: 40 | self.device = torch.device(f"cuda:{torch.distributed.get_rank()}") 41 | 42 | self.logger = logger 43 | 44 | @classmethod 45 | def from_pretrained(cls, 46 | pretrained_model_path, 47 | args, 48 | device=None, 49 | **kwargs): 50 | """ 51 | Initialize the Inference pipeline. 52 | 53 | Args: 54 | pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints. 55 | device (int): The device for inference. Default is 0. 56 | logger (logging.Logger): The logger for the inference pipeline. Default is None. 
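            args: Parsed runtime arguments that control how the transformer, VAE and text
                encoder(s) are built (e.g. precision, cpu_offload, use_fp8, vae, text_encoder).
            kwargs: Additional keyword arguments; accepted for interface compatibility and not
                used in this constructor.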
57 | """ 58 | # ======================================================================== 59 | logger.info(f"Got text-to-video model root path: {pretrained_model_path}") 60 | 61 | # ======================== Get the args path ============================= 62 | 63 | # Set device and disable gradient 64 | if device is None: 65 | device = "cuda" if torch.cuda.is_available() else "cpu" 66 | torch.set_grad_enabled(False) 67 | logger.info("Building model...") 68 | factor_kwargs = {'device': 'cpu' if args.cpu_offload else device, 'dtype': PRECISION_TO_TYPE[args.precision]} 69 | in_channels = args.latent_channels 70 | out_channels = args.latent_channels 71 | print("="*25, f"build model", "="*25) 72 | model = load_model( 73 | args, 74 | in_channels=in_channels, 75 | out_channels=out_channels, 76 | factor_kwargs=factor_kwargs 77 | ) 78 | if args.use_fp8: 79 | convert_fp8_linear(model, pretrained_model_path, original_dtype=PRECISION_TO_TYPE[args.precision]) 80 | if args.cpu_offload: 81 | print(f'='*20, f'load transformer to cpu') 82 | model = model.to('cpu') 83 | torch.cuda.empty_cache() 84 | else: 85 | model = model.to(device) 86 | model = Inference.load_state_dict(args, model, pretrained_model_path) 87 | model.eval() 88 | 89 | # ============================= Build extra models ======================== 90 | # VAE 91 | print("="*25, f"load vae", "="*25) 92 | vae, _, s_ratio, t_ratio = load_vae(args.vae, args.vae_precision, logger=logger, device='cpu' if args.cpu_offload else device) 93 | vae_kwargs = {'s_ratio': s_ratio, 't_ratio': t_ratio} 94 | 95 | # Text encoder 96 | if args.prompt_template_video is not None: 97 | crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0) 98 | else: 99 | crop_start = 0 100 | max_length = args.text_len + crop_start 101 | 102 | # prompt_template_video 103 | prompt_template_video = PROMPT_TEMPLATE[args.prompt_template_video] if args.prompt_template_video is not None else None 104 | print("="*25, f"load llava", "="*25) 105 | text_encoder = TextEncoder(text_encoder_type = args.text_encoder, 106 | max_length = max_length, 107 | text_encoder_precision = args.text_encoder_precision, 108 | tokenizer_type = args.tokenizer, 109 | use_attention_mask = args.use_attention_mask, 110 | prompt_template_video = prompt_template_video, 111 | hidden_state_skip_layer = args.hidden_state_skip_layer, 112 | apply_final_norm = args.apply_final_norm, 113 | reproduce = args.reproduce, 114 | logger = logger, 115 | device = 'cpu' if args.cpu_offload else device , 116 | ) 117 | text_encoder_2 = None 118 | if args.text_encoder_2 is not None: 119 | text_encoder_2 = TextEncoder(text_encoder_type=args.text_encoder_2, 120 | max_length=args.text_len_2, 121 | text_encoder_precision=args.text_encoder_precision_2, 122 | tokenizer_type=args.tokenizer_2, 123 | use_attention_mask=args.use_attention_mask, 124 | reproduce=args.reproduce, 125 | logger=logger, 126 | device='cpu' if args.cpu_offload else device , # if not args.use_cpu_offload else 'cpu' 127 | ) 128 | 129 | return cls(args=args, 130 | vae=vae, 131 | vae_kwargs=vae_kwargs, 132 | text_encoder=text_encoder, 133 | model=model, 134 | text_encoder_2=text_encoder_2, 135 | device=device, 136 | logger=logger) 137 | 138 | @staticmethod 139 | def load_state_dict(args, model, ckpt_path): 140 | load_key = args.load_key 141 | ckpt_path = Path(ckpt_path) 142 | if ckpt_path.is_dir(): 143 | ckpt_path = next(ckpt_path.glob("*_model_states.pt")) 144 | state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage) 145 | if load_key 
in state_dict: 146 | state_dict = state_dict[load_key] 147 | elif load_key == ".": 148 | pass 149 | else: 150 | raise KeyError(f"Key '{load_key}' not found in the checkpoint. Existed keys: {state_dict.keys()}") 151 | model.load_state_dict(state_dict, strict=False) 152 | return model 153 | 154 | def get_exp_dir_and_ckpt_id(self): 155 | if self.ckpt is None: 156 | raise ValueError("The checkpoint path is not provided.") 157 | 158 | ckpt = Path(self.ckpt) 159 | if ckpt.parents[1].name == "checkpoints": 160 | # It should be a standard checkpoint path. We use the parent directory as the default save directory. 161 | exp_dir = ckpt.parents[2] 162 | else: 163 | raise ValueError(f"We cannot infer the experiment directory from the checkpoint path: {ckpt}. " 164 | f"It seems that the checkpoint path is not standard. Please explicitly provide the " 165 | f"save path by --save-path.") 166 | return exp_dir, ckpt.parent.name 167 | 168 | @staticmethod 169 | def parse_size(size): 170 | if isinstance(size, int): 171 | size = [size] 172 | if not isinstance(size, (list, tuple)): 173 | raise ValueError(f"Size must be an integer or (height, width), got {size}.") 174 | if len(size) == 1: 175 | size = [size[0], size[0]] 176 | if len(size) != 2: 177 | raise ValueError(f"Size must be an integer or (height, width), got {size}.") 178 | return size 179 | -------------------------------------------------------------------------------- /hymm_sp/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .models_audio import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG 2 | 3 | def load_model(args, in_channels, out_channels, factor_kwargs): 4 | model = HYVideoDiffusionTransformer( 5 | args, 6 | in_channels=in_channels, 7 | out_channels=out_channels, 8 | **HUNYUAN_VIDEO_CONFIG[args.model], 9 | **factor_kwargs, 10 | ) 11 | return model 12 | -------------------------------------------------------------------------------- /hymm_sp/modules/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/modules/__pycache__/activation_layers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/activation_layers.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/modules/__pycache__/attn_layers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/attn_layers.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/modules/__pycache__/audio_adapters.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/audio_adapters.cpython-310.pyc -------------------------------------------------------------------------------- 
/hymm_sp/modules/__pycache__/embed_layers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/embed_layers.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/modules/__pycache__/fp8_optimization.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/fp8_optimization.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/modules/__pycache__/mlp_layers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/mlp_layers.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/modules/__pycache__/models_audio.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/models_audio.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/modules/__pycache__/modulate_layers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/modulate_layers.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/modules/__pycache__/norm_layers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/norm_layers.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/modules/__pycache__/parallel_states.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/parallel_states.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/modules/__pycache__/posemb_layers.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/posemb_layers.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/modules/__pycache__/token_refiner.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/modules/__pycache__/token_refiner.cpython-310.pyc -------------------------------------------------------------------------------- 
/hymm_sp/modules/activation_layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def get_activation_layer(act_type): 5 | """get activation layer 6 | 7 | Args: 8 | act_type (str): the activation type 9 | 10 | Returns: 11 | torch.nn.functional: the activation layer 12 | """ 13 | if act_type == "gelu": 14 | return lambda: nn.GELU() 15 | elif act_type == "gelu_tanh": 16 | # Approximate `tanh` requires torch >= 1.13 17 | return lambda: nn.GELU(approximate="tanh") 18 | elif act_type == "relu": 19 | return nn.ReLU 20 | elif act_type == "silu": 21 | return nn.SiLU 22 | else: 23 | raise ValueError(f"Unknown activation type: {act_type}") -------------------------------------------------------------------------------- /hymm_sp/modules/audio_adapters.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides the implementation of an Audio Projection Model, which is designed for 3 | audio processing tasks. The model takes audio embeddings as input and outputs context tokens 4 | that can be used for various downstream applications, such as audio analysis or synthesis. 5 | 6 | The AudioProjModel class is based on the ModelMixin class from the diffusers library, which 7 | provides a foundation for building custom models. This implementation includes multiple linear 8 | layers with ReLU activation functions and a LayerNorm for normalization. 9 | 10 | Key Features: 11 | - Audio embedding input with flexible sequence length and block structure. 12 | - Multiple linear layers for feature transformation. 13 | - ReLU activation for non-linear transformation. 14 | - LayerNorm for stabilizing and speeding up training. 15 | - Rearrangement of input embeddings to match the model's expected input shape. 16 | - Customizable number of blocks, channels, and context tokens for adaptability. 17 | 18 | The module is structured to be easily integrated into larger systems or used as a standalone 19 | component for audio feature extraction and processing. 20 | 21 | Classes: 22 | - AudioProjModel: A class representing the audio projection model with configurable parameters. 23 | 24 | Functions: 25 | - (none) 26 | 27 | Dependencies: 28 | - torch: For tensor operations and neural network components. 29 | - diffusers: For the ModelMixin base class. 30 | - einops: For tensor rearrangement operations. 31 | 32 | """ 33 | 34 | import torch 35 | from diffusers import ModelMixin 36 | from einops import rearrange 37 | 38 | import math 39 | import torch.nn as nn 40 | from .parallel_states import ( 41 | initialize_sequence_parallel_state, 42 | nccl_info, 43 | get_sequence_parallel_state, 44 | parallel_attention, 45 | all_gather, 46 | all_to_all_4D, 47 | ) 48 | 49 | class AudioProjNet2(ModelMixin): 50 | """Audio Projection Model 51 | 52 | This class defines an audio projection model that takes audio embeddings as input 53 | and produces context tokens as output. The model is based on the ModelMixin class 54 | and consists of multiple linear layers and activation functions. It can be used 55 | for various audio processing tasks. 56 | 57 | Attributes: 58 | seq_len (int): The length of the audio sequence. 59 | blocks (int): The number of blocks in the audio projection model. 60 | channels (int): The number of channels in the audio projection model. 61 | intermediate_dim (int): The intermediate dimension of the model. 62 | context_tokens (int): The number of context tokens in the output. 
63 | output_dim (int): The output dimension of the context tokens. 64 | 65 | Methods: 66 | __init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, context_tokens=32, output_dim=768): 67 | Initializes the AudioProjModel with the given parameters. 68 | forward(self, audio_embeds): 69 | Defines the forward pass for the AudioProjModel. 70 | Parameters: 71 | audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels). 72 | Returns: 73 | context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim). 74 | 75 | """ 76 | 77 | def __init__( 78 | self, 79 | seq_len=5, 80 | blocks=12, # add a new parameter blocks 81 | channels=768, # add a new parameter channels 82 | intermediate_dim=512, 83 | output_dim=768, 84 | context_tokens=4, 85 | ): 86 | super().__init__() 87 | 88 | self.seq_len = seq_len 89 | self.blocks = blocks 90 | self.channels = channels 91 | self.input_dim = ( 92 | seq_len * blocks * channels 93 | ) 94 | self.intermediate_dim = intermediate_dim 95 | self.context_tokens = context_tokens 96 | self.output_dim = output_dim 97 | 98 | # define multiple linear layers 99 | self.proj1 = nn.Linear(self.input_dim, intermediate_dim) 100 | self.proj2 = nn.Linear(intermediate_dim, intermediate_dim) 101 | self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim) 102 | 103 | self.norm = nn.LayerNorm(output_dim) 104 | 105 | 106 | def forward(self, audio_embeds): 107 | 108 | video_length = audio_embeds.shape[1] 109 | audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") 110 | batch_size, window_size, blocks, channels = audio_embeds.shape 111 | audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) 112 | 113 | audio_embeds = torch.relu(self.proj1(audio_embeds)) 114 | audio_embeds = torch.relu(self.proj2(audio_embeds)) 115 | 116 | context_tokens = self.proj3(audio_embeds).reshape( 117 | batch_size, self.context_tokens, self.output_dim 118 | ) 119 | context_tokens = self.norm(context_tokens) 120 | out_all = rearrange( 121 | context_tokens, "(bz f) m c -> bz f m c", f=video_length 122 | ) 123 | 124 | return out_all 125 | 126 | 127 | def reshape_tensor(x, heads): 128 | bs, length, width = x.shape 129 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head) 130 | x = x.view(bs, length, heads, -1) 131 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 132 | x = x.transpose(1, 2) 133 | # (bs, n_heads, length, dim_per_head) 134 | x = x.reshape(bs, heads, length, -1) 135 | return x 136 | 137 | 138 | class PerceiverAttentionCA(nn.Module): 139 | def __init__(self, *, dim=3072, dim_head=1024, heads=33): 140 | super().__init__() 141 | self.scale = dim_head ** -0.5 142 | self.dim_head = dim_head 143 | self.heads = heads 144 | inner_dim = dim_head #* heads 145 | 146 | self.norm1 = nn.LayerNorm(dim) 147 | self.norm2 = nn.LayerNorm(dim) 148 | 149 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 150 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 151 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 152 | 153 | import torch.nn.init as init 154 | init.zeros_(self.to_out.weight) 155 | if self.to_out.bias is not None: 156 | init.zeros_(self.to_out.bias) 157 | 158 | def forward(self, x, latents): 159 | """ 160 | Args: 161 | x (torch.Tensor): image features 162 | shape (b, t, aa, D) 163 | latent (torch.Tensor): latent features 164 | shape (b, t, hw, D) 165 | """ 166 | x = self.norm1(x) 167 | latents = 
self.norm2(latents) 168 | # print("latents shape: ", latents.shape) 169 | # print("x shape: ", x.shape) 170 | q = self.to_q(latents) 171 | k, v = self.to_kv(x).chunk(2, dim=-1) 172 | 173 | 174 | # attention 175 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 176 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 177 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 178 | out = weight @ v 179 | 180 | # out = out.permute(0, 2, 1, 3) 181 | return self.to_out(out) 182 | #def forward(self, x, latents): 183 | # """ 184 | # Args: 185 | # x (torch.Tensor): image features 186 | # shape (b, t, aa, D) 187 | # latent (torch.Tensor): latent features 188 | # shape (b, t, hw, D) 189 | # """ 190 | # if get_sequence_parallel_state(): 191 | # sp_size = nccl_info.sp_size 192 | # sp_rank = nccl_info.rank_within_group 193 | # print("rank:", latents.shape, sp_size, sp_rank) 194 | # latents = torch.chunk(latents, sp_size, dim=1)[sp_rank] 195 | 196 | # x = self.norm1(x) 197 | # latents = self.norm2(latents) 198 | # # print("latents shape: ", latents.shape) 199 | # # print("x shape: ", x.shape) 200 | # q = self.to_q(latents) 201 | # k, v = self.to_kv(x).chunk(2, dim=-1) 202 | 203 | # # print("q, k, v: ", q.shape, k.shape, v.shape) 204 | 205 | # # attention 206 | # #scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 207 | # #weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 208 | # #weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 209 | # #out = weight @ v 210 | # def shrink_head(encoder_state, dim): 211 | # local_heads = encoder_state.shape[dim] // nccl_info.sp_size 212 | # return encoder_state.narrow(dim, nccl_info.rank_within_group * local_heads, local_heads) 213 | 214 | # if get_sequence_parallel_state(): 215 | # # batch_size, seq_len, attn_heads, head_dim 216 | # q = all_to_all_4D(q, scatter_dim=2, gather_dim=1) # [2, 32256, 24, 128] 217 | # k = shrink_head(k ,dim=2) 218 | # v = shrink_head(v ,dim=2) 219 | # qkv = torch.stack([query, key, value], dim=2) 220 | # attn = flash_attn_no_pad(qkv, causal=False, dropout_p=0.0, softmax_scale=None) 221 | # # out = out.permute(0, 2, 1, 3) 222 | # #b, s, a, d = attn.shape 223 | # #attn = attn.reshape(b, s, -1) 224 | # 225 | # out = self.to_out(attn) 226 | # if get_sequence_parallel_state(): 227 | # out = all_gather(out, dim=1) 228 | # return out 229 | -------------------------------------------------------------------------------- /hymm_sp/modules/embed_layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from hymm_sp.helpers import to_2tuple 5 | 6 | 7 | class PatchEmbed(nn.Module): 8 | """ 2D Image to Patch Embedding 9 | 10 | Image to Patch Embedding using Conv2d 11 | 12 | A convolution based approach to patchifying a 2D image w/ embedding projection. 13 | 14 | Based on the impl in https://github.com/google-research/vision_transformer 15 | 16 | Hacked together by / Copyright 2020 Ross Wightman 17 | 18 | Remove the _assert function in forward function to be compatible with multi-resolution images. 
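    Note: despite the 2D description above, this variant projects with `nn.Conv3d`,
    i.e. it patchifies spatio-temporal video latents rather than plain 2D feature maps.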
19 | """ 20 | def __init__( 21 | self, 22 | patch_size=16, 23 | in_chans=3, 24 | embed_dim=768, 25 | norm_layer=None, 26 | flatten=True, 27 | bias=True, 28 | dtype=None, 29 | device=None 30 | ): 31 | factory_kwargs = {'dtype': dtype, 'device': device} 32 | super().__init__() 33 | patch_size = to_2tuple(patch_size) 34 | self.patch_size = patch_size 35 | self.flatten = flatten 36 | 37 | self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, 38 | **factory_kwargs) 39 | nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1)) 40 | if bias: 41 | nn.init.zeros_(self.proj.bias) 42 | 43 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 44 | 45 | def forward(self, x): 46 | x = self.proj(x) 47 | shape = x.shape 48 | if self.flatten: 49 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 50 | x = self.norm(x) 51 | return x, shape 52 | 53 | 54 | class TextProjection(nn.Module): 55 | """ 56 | Projects text embeddings. Also handles dropout for classifier-free guidance. 57 | 58 | Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py 59 | """ 60 | 61 | def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): 62 | factory_kwargs = {'dtype': dtype, 'device': device} 63 | super().__init__() 64 | self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs) 65 | self.act_1 = act_layer() 66 | self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs) 67 | 68 | def forward(self, caption): 69 | hidden_states = self.linear_1(caption) 70 | hidden_states = self.act_1(hidden_states) 71 | hidden_states = self.linear_2(hidden_states) 72 | return hidden_states 73 | 74 | 75 | def timestep_embedding(t, dim, max_period=10000): 76 | """ 77 | Create sinusoidal timestep embeddings. 78 | 79 | Args: 80 | t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. 81 | dim (int): the dimension of the output. 82 | max_period (int): controls the minimum frequency of the embeddings. 83 | 84 | Returns: 85 | embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. 86 | 87 | .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 88 | """ 89 | half = dim // 2 90 | freqs = torch.exp( 91 | -math.log(max_period) 92 | * torch.arange(start=0, end=half, dtype=torch.float32) 93 | / half 94 | ).to(device=t.device) 95 | args = t[:, None].float() * freqs[None] 96 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 97 | if dim % 2: 98 | embedding = torch.cat( 99 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 100 | ) 101 | return embedding 102 | 103 | 104 | class TimestepEmbedder(nn.Module): 105 | """ 106 | Embeds scalar timesteps into vector representations. 
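    The timestep is first mapped to a sinusoidal frequency embedding of size
    `frequency_embedding_size` and then projected by a two-layer MLP to `out_size`
    (which defaults to `hidden_size`).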
107 | """ 108 | def __init__(self, 109 | hidden_size, 110 | act_layer, 111 | frequency_embedding_size=256, 112 | max_period=10000, 113 | out_size=None, 114 | dtype=None, 115 | device=None 116 | ): 117 | factory_kwargs = {'dtype': dtype, 'device': device} 118 | super().__init__() 119 | self.frequency_embedding_size = frequency_embedding_size 120 | self.max_period = max_period 121 | if out_size is None: 122 | out_size = hidden_size 123 | 124 | self.mlp = nn.Sequential( 125 | nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs), 126 | act_layer(), 127 | nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), 128 | ) 129 | nn.init.normal_(self.mlp[0].weight, std=0.02) 130 | nn.init.normal_(self.mlp[2].weight, std=0.02) 131 | 132 | def forward(self, t): 133 | t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) 134 | t_emb = self.mlp(t_freq) 135 | return t_emb -------------------------------------------------------------------------------- /hymm_sp/modules/fp8_optimization.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1): 8 | _bits = torch.tensor(bits) 9 | _mantissa_bit = torch.tensor(mantissa_bit) 10 | _sign_bits = torch.tensor(sign_bits) 11 | M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits) 12 | E = _bits - _sign_bits - M 13 | bias = 2 ** (E - 1) - 1 14 | mantissa = 1 15 | for i in range(mantissa_bit - 1): 16 | mantissa += 1 / (2 ** (i+1)) 17 | maxval = mantissa * 2 ** (2**E - 1 - bias) 18 | return maxval 19 | 20 | def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1): 21 | """ 22 | Default is E4M3. 
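    E4M3 here means 1 sign bit, 4 exponent bits and 3 mantissa bits; with the default
    bits=8, mantissa_bit=3, sign_bits=1 the largest representable magnitude is
    1.75 * 2**8 = 448, matching `get_fp_maxval()`.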
23 | """ 24 | bits = torch.tensor(bits) 25 | mantissa_bit = torch.tensor(mantissa_bit) 26 | sign_bits = torch.tensor(sign_bits) 27 | M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits) 28 | E = bits - sign_bits - M 29 | bias = 2 ** (E - 1) - 1 30 | mantissa = 1 31 | for i in range(mantissa_bit - 1): 32 | mantissa += 1 / (2 ** (i+1)) 33 | maxval = mantissa * 2 ** (2**E - 1 - bias) 34 | minval = - maxval 35 | minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval) 36 | input_clamp = torch.min(torch.max(x, minval), maxval) 37 | log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0) 38 | log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype)) 39 | # dequant 40 | qdq_out = torch.round(input_clamp / log_scales) * log_scales 41 | return qdq_out, log_scales 42 | 43 | def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1): 44 | for i in range(len(x.shape) - 1): 45 | scale = scale.unsqueeze(-1) 46 | new_x = x / scale 47 | quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits) 48 | return quant_dequant_x, scale, log_scales 49 | 50 | def fp8_activation_dequant(qdq_out, scale, dtype): 51 | qdq_out = qdq_out.type(dtype) 52 | quant_dequant_x = qdq_out * scale.to(dtype) 53 | return quant_dequant_x 54 | 55 | def fp8_linear_forward(cls, original_dtype, input): 56 | weight_dtype = cls.weight.dtype 57 | ##### 58 | if cls.weight.dtype != torch.float8_e4m3fn: 59 | maxval = get_fp_maxval() 60 | scale = torch.max(torch.abs(cls.weight.flatten())) / maxval 61 | linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale) 62 | linear_weight = linear_weight.to(torch.float8_e4m3fn) 63 | weight_dtype = linear_weight.dtype 64 | else: 65 | scale = cls.fp8_scale.to(cls.weight.device) 66 | linear_weight = cls.weight 67 | ##### 68 | 69 | if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0: 70 | if True or len(input.shape) == 3: 71 | cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype) 72 | if cls.bias != None: 73 | output = F.linear(input, cls_dequant, cls.bias) 74 | else: 75 | output = F.linear(input, cls_dequant) 76 | return output 77 | else: 78 | return cls.original_forward(input.to(original_dtype)) 79 | else: 80 | return cls.original_forward(input) 81 | 82 | def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={}): 83 | setattr(module, "fp8_matmul_enabled", True) 84 | 85 | # loading fp8 mapping file 86 | fp8_map_path = dit_weight_path.replace('.pt', '_map.pt') 87 | if os.path.exists(fp8_map_path): 88 | fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)['module'] 89 | else: 90 | raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.") 91 | 92 | fp8_layers = [] 93 | for key, layer in module.named_modules(): 94 | if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key): 95 | fp8_layers.append(key) 96 | original_forward = layer.forward 97 | layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn)) 98 | setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype)) 99 | setattr(layer, "original_forward", original_forward) 100 | setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input)) -------------------------------------------------------------------------------- /hymm_sp/modules/mlp_layers.py: -------------------------------------------------------------------------------- 1 | # Modified from timm library: 2 
| # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13 3 | 4 | from functools import partial 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from .modulate_layers import modulate 10 | from hymm_sp.helpers import to_2tuple 11 | 12 | 13 | class MLP(nn.Module): 14 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 15 | """ 16 | def __init__(self, 17 | in_channels, 18 | hidden_channels=None, 19 | out_features=None, 20 | act_layer=nn.GELU, 21 | norm_layer=None, 22 | bias=True, 23 | drop=0., 24 | use_conv=False, 25 | device=None, 26 | dtype=None 27 | ): 28 | factory_kwargs = {'device': device, 'dtype': dtype} 29 | super().__init__() 30 | out_features = out_features or in_channels 31 | hidden_channels = hidden_channels or in_channels 32 | bias = to_2tuple(bias) 33 | drop_probs = to_2tuple(drop) 34 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear 35 | 36 | self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs) 37 | self.act = act_layer() 38 | self.drop1 = nn.Dropout(drop_probs[0]) 39 | self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity() 40 | self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs) 41 | self.drop2 = nn.Dropout(drop_probs[1]) 42 | 43 | def forward(self, x): 44 | x = self.fc1(x) 45 | x = self.act(x) 46 | x = self.drop1(x) 47 | x = self.norm(x) 48 | x = self.fc2(x) 49 | x = self.drop2(x) 50 | return x 51 | 52 | 53 | class MLPEmbedder(nn.Module): 54 | """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py""" 55 | def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None): 56 | factory_kwargs = {'device': device, 'dtype': dtype} 57 | super().__init__() 58 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs) 59 | self.silu = nn.SiLU() 60 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs) 61 | 62 | def forward(self, x: torch.Tensor) -> torch.Tensor: 63 | return self.out_layer(self.silu(self.in_layer(x))) 64 | 65 | 66 | class FinalLayer(nn.Module): 67 | """The final layer of DiT.""" 68 | 69 | def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None): 70 | factory_kwargs = {'device': device, 'dtype': dtype} 71 | super().__init__() 72 | 73 | # Just use LayerNorm for the final layer 74 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) 75 | if isinstance(patch_size, int): 76 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, **factory_kwargs) 77 | else: 78 | self.linear = nn.Linear(hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=True) 79 | nn.init.zeros_(self.linear.weight) 80 | nn.init.zeros_(self.linear.bias) 81 | 82 | # Here we don't distinguish between the modulate types. Just use the simple one. 
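        # adaLN: the conditioning vector c is projected to a (shift, scale) pair that
        # modulates the normalized tokens in forward(). Both this projection and the
        # output `linear` above are zero-initialized, so the modulation starts as an
        # identity and the layer's output starts at zero until training updates the weights.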
83 | self.adaLN_modulation = nn.Sequential( 84 | act_layer(), 85 | nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs) 86 | ) 87 | # Zero-initialize the modulation 88 | nn.init.zeros_(self.adaLN_modulation[1].weight) 89 | nn.init.zeros_(self.adaLN_modulation[1].bias) 90 | 91 | def forward(self, x, c): 92 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 93 | x = modulate(self.norm_final(x), shift=shift, scale=scale) 94 | x = self.linear(x) 95 | return x 96 | -------------------------------------------------------------------------------- /hymm_sp/modules/modulate_layers.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class ModulateDiT(nn.Module): 8 | """Modulation layer for DiT.""" 9 | def __init__( 10 | self, 11 | hidden_size: int, 12 | factor: int, 13 | act_layer: Callable, 14 | dtype=None, 15 | device=None, 16 | ): 17 | factory_kwargs = {"dtype": dtype, "device": device} 18 | super().__init__() 19 | self.act = act_layer() 20 | self.linear = nn.Linear( 21 | hidden_size, factor * hidden_size, bias=True, **factory_kwargs 22 | ) 23 | # Zero-initialize the modulation 24 | nn.init.zeros_(self.linear.weight) 25 | nn.init.zeros_(self.linear.bias) 26 | 27 | def forward(self, x: torch.Tensor) -> torch.Tensor: 28 | return self.linear(self.act(x)) 29 | 30 | 31 | def modulate(x, shift=None, scale=None): 32 | """modulate by shift and scale 33 | 34 | Args: 35 | x (torch.Tensor): input tensor. 36 | shift (torch.Tensor, optional): shift tensor. Defaults to None. 37 | scale (torch.Tensor, optional): scale tensor. Defaults to None. 38 | 39 | Returns: 40 | torch.Tensor: the output tensor after modulate. 41 | """ 42 | if scale is None and shift is None: 43 | return x 44 | elif shift is None: 45 | return x * (1 + scale.unsqueeze(1)) 46 | elif scale is None: 47 | return x + shift.unsqueeze(1) 48 | else: 49 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 50 | 51 | 52 | def apply_gate(x, gate=None, tanh=False): 53 | """AI is creating summary for apply_gate 54 | 55 | Args: 56 | x (torch.Tensor): input tensor. 57 | gate (torch.Tensor, optional): gate tensor. Defaults to None. 58 | tanh (bool, optional): whether to use tanh function. Defaults to False. 59 | 60 | Returns: 61 | torch.Tensor: the output tensor after apply gate. 62 | """ 63 | if gate is None: 64 | return x 65 | if tanh: 66 | return x * gate.unsqueeze(1).tanh() 67 | else: 68 | return x * gate.unsqueeze(1) 69 | 70 | 71 | def ckpt_wrapper(module): 72 | def ckpt_forward(*inputs): 73 | outputs = module(*inputs) 74 | return outputs 75 | 76 | return ckpt_forward -------------------------------------------------------------------------------- /hymm_sp/modules/norm_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class RMSNorm(nn.Module): 6 | def __init__( 7 | self, 8 | dim: int, 9 | elementwise_affine=True, 10 | eps: float = 1e-6, 11 | device=None, 12 | dtype=None, 13 | ): 14 | """ 15 | Initialize the RMSNorm normalization layer. 16 | 17 | Args: 18 | dim (int): The dimension of the input tensor. 19 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 20 | 21 | Attributes: 22 | eps (float): A small value added to the denominator for numerical stability. 23 | weight (nn.Parameter): Learnable scaling parameter. 
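        The normalization x / sqrt(mean(x**2, dim=-1) + eps) is computed in float32,
        cast back to the input dtype, and then scaled by `weight` when
        `elementwise_affine` is True.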
24 | 25 | """ 26 | factory_kwargs = {"device": device, "dtype": dtype} 27 | super().__init__() 28 | self.eps = eps 29 | if elementwise_affine: 30 | self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) 31 | 32 | def _norm(self, x): 33 | """ 34 | Apply the RMSNorm normalization to the input tensor. 35 | 36 | Args: 37 | x (torch.Tensor): The input tensor. 38 | 39 | Returns: 40 | torch.Tensor: The normalized tensor. 41 | 42 | """ 43 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 44 | 45 | def forward(self, x): 46 | """ 47 | Forward pass through the RMSNorm layer. 48 | 49 | Args: 50 | x (torch.Tensor): The input tensor. 51 | 52 | Returns: 53 | torch.Tensor: The output tensor after applying RMSNorm. 54 | 55 | """ 56 | output = self._norm(x.float()).type_as(x) 57 | if hasattr(self, "weight"): 58 | output = output * self.weight 59 | return output 60 | 61 | 62 | def get_norm_layer(norm_layer): 63 | """ 64 | Get the normalization layer. 65 | 66 | Args: 67 | norm_layer (str): The type of normalization layer. 68 | 69 | Returns: 70 | norm_layer (nn.Module): The normalization layer. 71 | """ 72 | if norm_layer == "layer": 73 | return nn.LayerNorm 74 | elif norm_layer == "rms": 75 | return RMSNorm 76 | else: 77 | raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") -------------------------------------------------------------------------------- /hymm_sp/modules/parallel_states.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import datetime 4 | import torch.distributed as dist 5 | from typing import Any, Tuple 6 | from torch import Tensor 7 | from flash_attn.flash_attn_interface import flash_attn_varlen_func 8 | 9 | 10 | class COMM_INFO: 11 | def __init__(self): 12 | self.group = None 13 | self.sp_size = 1 14 | self.global_rank = 0 15 | self.rank_within_group = 0 16 | self.group_id = 0 17 | 18 | 19 | nccl_info = COMM_INFO() 20 | _SEQUENCE_PARALLEL_STATE = False 21 | 22 | 23 | def get_cu_seqlens(text_mask, img_len): 24 | """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len 25 | 26 | Args: 27 | text_mask (torch.Tensor): the mask of text 28 | img_len (int): the length of image 29 | 30 | Returns: 31 | torch.Tensor: the calculated cu_seqlens for flash attention 32 | """ 33 | batch_size = text_mask.shape[0] 34 | text_len = text_mask.sum(dim=1) 35 | max_len = text_mask.shape[1] + img_len 36 | 37 | cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda") 38 | 39 | for i in range(batch_size): 40 | s = text_len[i] + img_len 41 | s1 = i * max_len + s 42 | s2 = (i + 1) * max_len 43 | cu_seqlens[2 * i + 1] = s1 44 | cu_seqlens[2 * i + 2] = s2 45 | 46 | return cu_seqlens 47 | 48 | def initialize_sequence_parallel_state(sequence_parallel_size): 49 | global _SEQUENCE_PARALLEL_STATE 50 | if sequence_parallel_size > 1: 51 | _SEQUENCE_PARALLEL_STATE = True 52 | initialize_sequence_parallel_group(sequence_parallel_size) 53 | else: 54 | nccl_info.sp_size = 1 55 | nccl_info.global_rank = int(os.getenv("RANK", "0")) 56 | nccl_info.rank_within_group = 0 57 | nccl_info.group_id = int(os.getenv("RANK", "0")) 58 | 59 | def get_sequence_parallel_state(): 60 | return _SEQUENCE_PARALLEL_STATE 61 | 62 | def initialize_sequence_parallel_group(sequence_parallel_size): 63 | """Initialize the sequence parallel group.""" 64 | rank = int(os.getenv("RANK", "0")) 65 | world_size = int(os.getenv("WORLD_SIZE", "1")) 66 | assert ( 67 | world_size % sequence_parallel_size == 0 68 | ), 
"world_size must be divisible by sequence_parallel_size, but got world_size: {}, sequence_parallel_size: {}".format( 69 | world_size, sequence_parallel_size) 70 | nccl_info.sp_size = sequence_parallel_size 71 | nccl_info.global_rank = rank 72 | num_sequence_parallel_groups: int = world_size // sequence_parallel_size 73 | for i in range(num_sequence_parallel_groups): 74 | ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size) 75 | group = dist.new_group(ranks) 76 | if rank in ranks: 77 | nccl_info.group = group 78 | nccl_info.rank_within_group = rank - i * sequence_parallel_size 79 | nccl_info.group_id = i 80 | 81 | def initialize_distributed(seed): 82 | local_rank = int(os.getenv("RANK", 0)) 83 | world_size = int(os.getenv("WORLD_SIZE", 1)) 84 | torch.cuda.set_device(local_rank) 85 | dist.init_process_group(backend="nccl", init_method="env://", timeout=datetime.timedelta(seconds=2**31-1), world_size=world_size, rank=local_rank) 86 | torch.manual_seed(seed) 87 | torch.cuda.manual_seed_all(seed) 88 | initialize_sequence_parallel_state(world_size) 89 | 90 | def _all_to_all_4D(input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.tensor: 91 | """ 92 | all-to-all for QKV 93 | 94 | Args: 95 | input (torch.tensor): a tensor sharded along dim scatter dim 96 | scatter_idx (int): default 1 97 | gather_idx (int): default 2 98 | group : torch process group 99 | 100 | Returns: 101 | torch.tensor: resharded tensor (bs, seqlen/P, hc, hs) 102 | """ 103 | assert (input.dim() == 4), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}" 104 | 105 | seq_world_size = dist.get_world_size(group) 106 | if scatter_idx == 2 and gather_idx == 1: 107 | # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs) 108 | bs, shard_seqlen, hc, hs = input.shape 109 | seqlen = shard_seqlen * seq_world_size 110 | shard_hc = hc // seq_world_size 111 | 112 | # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 113 | # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs) 114 | input_t = (input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs).transpose(0, 2).contiguous()) 115 | 116 | output = torch.empty_like(input_t) 117 | # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single 118 | # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head 119 | if seq_world_size > 1: 120 | dist.all_to_all_single(output, input_t, group=group) 121 | torch.cuda.synchronize() 122 | else: 123 | output = input_t 124 | # if scattering the seq-dim, transpose the heads back to the original dimension 125 | output = output.reshape(seqlen, bs, shard_hc, hs) 126 | 127 | # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs) 128 | output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs) 129 | 130 | return output 131 | 132 | elif scatter_idx == 1 and gather_idx == 2: 133 | # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs) 134 | bs, seqlen, shard_hc, hs = input.shape 135 | hc = shard_hc * seq_world_size 136 | shard_seqlen = seqlen // seq_world_size 137 | seq_world_size = dist.get_world_size(group) 138 | 139 | # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! 
140 | # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs) 141 | input_t = (input.reshape(bs, seq_world_size, shard_seqlen, shard_hc, 142 | hs).transpose(0, 143 | 3).transpose(0, 144 | 1).contiguous().reshape(seq_world_size, shard_hc, 145 | shard_seqlen, bs, hs)) 146 | 147 | output = torch.empty_like(input_t) 148 | # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single 149 | # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head 150 | if seq_world_size > 1: 151 | dist.all_to_all_single(output, input_t, group=group) 152 | torch.cuda.synchronize() 153 | else: 154 | output = input_t 155 | 156 | # if scattering the seq-dim, transpose the heads back to the original dimension 157 | output = output.reshape(hc, shard_seqlen, bs, hs) 158 | 159 | # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs) 160 | output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs) 161 | 162 | return output 163 | else: 164 | raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2") 165 | 166 | 167 | class SeqAllToAll4D(torch.autograd.Function): 168 | @staticmethod 169 | def forward( 170 | ctx: Any, 171 | group: dist.ProcessGroup, 172 | input: Tensor, 173 | scatter_idx: int, 174 | gather_idx: int, 175 | ) -> Tensor: 176 | ctx.group = group 177 | ctx.scatter_idx = scatter_idx 178 | ctx.gather_idx = gather_idx 179 | 180 | return _all_to_all_4D(input, scatter_idx, gather_idx, group=group) 181 | 182 | @staticmethod 183 | def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: 184 | return ( 185 | None, 186 | SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), 187 | None, 188 | None, 189 | ) 190 | 191 | 192 | def all_to_all_4D( 193 | input_: torch.Tensor, 194 | scatter_dim: int = 2, 195 | gather_dim: int = 1, 196 | ): 197 | return SeqAllToAll4D.apply(nccl_info.group, input_, scatter_dim, gather_dim) 198 | 199 | 200 | def _all_to_all( 201 | input_: torch.Tensor, 202 | world_size: int, 203 | group: dist.ProcessGroup, 204 | scatter_dim: int, 205 | gather_dim: int, 206 | ): 207 | input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] 208 | output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] 209 | dist.all_to_all(output_list, input_list, group=group) 210 | return torch.cat(output_list, dim=gather_dim).contiguous() 211 | 212 | 213 | class _AllToAll(torch.autograd.Function): 214 | """All-to-all communication. 
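    Scatters `input_` along `scatter_dim` across the process group and gathers it along
    `gather_dim`; the backward pass performs the inverse exchange with the two dims swapped.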
215 | 216 | Args: 217 | input_: input matrix 218 | process_group: communication group 219 | scatter_dim: scatter dimension 220 | gather_dim: gather dimension 221 | """ 222 | 223 | @staticmethod 224 | def forward(ctx, input_, process_group, scatter_dim, gather_dim): 225 | ctx.process_group = process_group 226 | ctx.scatter_dim = scatter_dim 227 | ctx.gather_dim = gather_dim 228 | ctx.world_size = dist.get_world_size(process_group) 229 | output = _all_to_all(input_, ctx.world_size, process_group, scatter_dim, gather_dim) 230 | return output 231 | 232 | @staticmethod 233 | def backward(ctx, grad_output): 234 | grad_output = _all_to_all( 235 | grad_output, 236 | ctx.world_size, 237 | ctx.process_group, 238 | ctx.gather_dim, 239 | ctx.scatter_dim, 240 | ) 241 | return ( 242 | grad_output, 243 | None, 244 | None, 245 | None, 246 | ) 247 | 248 | def all_to_all( 249 | input_: torch.Tensor, 250 | scatter_dim: int = 2, 251 | gather_dim: int = 1, 252 | ): 253 | return _AllToAll.apply(input_, nccl_info.group, scatter_dim, gather_dim) 254 | 255 | 256 | class _AllGather(torch.autograd.Function): 257 | """All-gather communication with autograd support. 258 | 259 | Args: 260 | input_: input tensor 261 | dim: dimension along which to concatenate 262 | """ 263 | 264 | @staticmethod 265 | def forward(ctx, input_, dim): 266 | ctx.dim = dim 267 | world_size = nccl_info.sp_size 268 | group = nccl_info.group 269 | input_size = list(input_.size()) 270 | 271 | ctx.input_size = input_size[dim] 272 | 273 | tensor_list = [torch.empty_like(input_) for _ in range(world_size)] 274 | input_ = input_.contiguous() 275 | dist.all_gather(tensor_list, input_, group=group) 276 | 277 | output = torch.cat(tensor_list, dim=dim) 278 | return output 279 | 280 | @staticmethod 281 | def backward(ctx, grad_output): 282 | world_size = nccl_info.sp_size 283 | rank = nccl_info.rank_within_group 284 | dim = ctx.dim 285 | input_size = ctx.input_size 286 | 287 | sizes = [input_size] * world_size 288 | 289 | grad_input_list = torch.split(grad_output, sizes, dim=dim) 290 | grad_input = grad_input_list[rank] 291 | 292 | return grad_input, None 293 | 294 | 295 | def all_gather(input_: torch.Tensor, dim: int = 1): 296 | """Performs an all-gather operation on the input tensor along the specified dimension. 297 | 298 | Args: 299 | input_ (torch.Tensor): Input tensor of shape [B, H, S, D]. 300 | dim (int, optional): Dimension along which to concatenate. Defaults to 1. 301 | 302 | Returns: 303 | torch.Tensor: Output tensor after all-gather operation, concatenated along 'dim'. 
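        The size of dimension `dim` grows by a factor of `nccl_info.sp_size`
        (the sequence-parallel world size).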
304 | """ 305 | return _AllGather.apply(input_, dim) 306 | 307 | def parallel_attention(q, k, v, img_q_len, img_kv_len, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv,): 308 | """ 309 | img_q_len,img_kv_len: 32256 310 | text_mask: 2x256 311 | query: [2, 32256, 24, 128]) 312 | encoder_query: [2, 256, 24, 128] 313 | """ 314 | query, encoder_query = q 315 | key, encoder_key = k 316 | value, encoder_value = v 317 | rank = torch.distributed.get_rank() 318 | if get_sequence_parallel_state(): 319 | query = all_to_all_4D(query, scatter_dim=2, gather_dim=1) # [2, 32256, 24, 128] 320 | key = all_to_all_4D(key, scatter_dim=2, gather_dim=1) 321 | value = all_to_all_4D(value, scatter_dim=2, gather_dim=1) 322 | def shrink_head(encoder_state, dim): 323 | local_heads = encoder_state.shape[dim] // nccl_info.sp_size 324 | return encoder_state.narrow(dim, nccl_info.rank_within_group * local_heads, local_heads) 325 | encoder_query = shrink_head(encoder_query, dim=2) 326 | encoder_key = shrink_head(encoder_key, dim=2) 327 | encoder_value = shrink_head(encoder_value, dim=2) 328 | 329 | sequence_length = query.size(1) # 32256 330 | encoder_sequence_length = encoder_query.size(1) # 256 331 | 332 | query = torch.cat([query, encoder_query], dim=1) 333 | key = torch.cat([key, encoder_key], dim=1) 334 | value = torch.cat([value, encoder_value], dim=1) 335 | bsz = query.shape[0] 336 | head = query.shape[-2] 337 | head_dim = query.shape[-1] 338 | query, key, value = [ 339 | x.view(x.shape[0] * x.shape[1], *x.shape[2:]) 340 | for x in [query, key, value] 341 | ] 342 | hidden_states = flash_attn_varlen_func( 343 | query, 344 | key, 345 | value, 346 | cu_seqlens_q, 347 | cu_seqlens_kv, 348 | max_seqlen_q, 349 | max_seqlen_kv, 350 | ) 351 | # B, S, 3, H, D 352 | hidden_states = hidden_states.view(bsz, max_seqlen_q, head, head_dim).contiguous() 353 | 354 | hidden_states, encoder_hidden_states = hidden_states.split_with_sizes((sequence_length, encoder_sequence_length), 355 | dim=1) 356 | if get_sequence_parallel_state(): 357 | hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2) 358 | encoder_hidden_states = all_gather(encoder_hidden_states, dim=2).contiguous() 359 | hidden_states = hidden_states.to(query.dtype) 360 | encoder_hidden_states = encoder_hidden_states.to(query.dtype) 361 | 362 | attn = torch.cat([hidden_states, encoder_hidden_states], dim=1) 363 | 364 | b, s, _, _= attn.shape 365 | attn = attn.reshape(b, s, -1) 366 | return attn, None -------------------------------------------------------------------------------- /hymm_sp/modules/posemb_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, Tuple, List 3 | 4 | 5 | def _to_tuple(x, dim=2): 6 | if isinstance(x, int): 7 | return (x,) * dim 8 | elif len(x) == dim: 9 | return x 10 | else: 11 | raise ValueError(f"Expected length {dim} or int, but got {x}") 12 | 13 | 14 | def get_meshgrid_nd(start, *args, dim=2): 15 | """ 16 | Get n-D meshgrid with start, stop and num. 17 | 18 | Args: 19 | start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, 20 | step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num 21 | should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in 22 | n-tuples. 23 | *args: See above. 24 | dim (int): Dimension of the meshgrid. Defaults to 2. 
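        For example, `get_meshgrid_nd((4, 8))` is equivalent to `get_meshgrid_nd((0, 0), (4, 8))`
        and returns a grid of shape [2, 4, 8] whose two planes hold the row indices 0..3 and
        the column indices 0..7.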
25 | 26 | Returns: 27 | grid (np.ndarray): [dim, ...] 28 | """ 29 | if len(args) == 0: 30 | # start is grid_size 31 | num = _to_tuple(start, dim=dim) 32 | start = (0,) * dim 33 | stop = num 34 | elif len(args) == 1: 35 | # start is start, args[0] is stop, step is 1 36 | start = _to_tuple(start, dim=dim) 37 | stop = _to_tuple(args[0], dim=dim) 38 | num = [stop[i] - start[i] for i in range(dim)] 39 | elif len(args) == 2: 40 | # start is start, args[0] is stop, args[1] is num 41 | start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 42 | stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 43 | num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 44 | else: 45 | raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") 46 | 47 | # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) 48 | axis_grid = [] 49 | for i in range(dim): 50 | a, b, n = start[i], stop[i], num[i] 51 | g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] 52 | axis_grid.append(g) 53 | grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] 54 | grid = torch.stack(grid, dim=0) # [dim, W, H, D] 55 | 56 | return grid 57 | 58 | 59 | ################################################################################# 60 | # Rotary Positional Embedding Functions # 61 | ################################################################################# 62 | # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80 63 | 64 | def get_nd_rotary_pos_embed(rope_dim_list, start, *args, theta=10000., use_real=False, 65 | theta_rescale_factor: Union[float, List[float]]=1.0, 66 | interpolation_factor: Union[float, List[float]]=1.0): 67 | """ 68 | This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. 69 | 70 | Args: 71 | rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. 72 | sum(rope_dim_list) should equal to head_dim of attention layer. 73 | start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, 74 | args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. 75 | *args: See above. 76 | theta (float): Scaling factor for frequency computation. Defaults to 10000.0. 77 | use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. 78 | Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real 79 | part and an imaginary part separately. 80 | theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. 
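        interpolation_factor (float): Scale applied to the position indices before the rotary
            frequencies are computed (used for positional interpolation). Defaults to 1.0.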
81 | 82 | Returns: 83 | pos_embed (torch.Tensor): [HW, D/2] 84 | """ 85 | 86 | grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] 87 | 88 | if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): 89 | theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) 90 | elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: 91 | theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list) 92 | assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)" 93 | 94 | if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float): 95 | interpolation_factor = [interpolation_factor] * len(rope_dim_list) 96 | elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: 97 | interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list) 98 | assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)" 99 | 100 | # use 1/ndim of dimensions to encode grid_axis 101 | embs = [] 102 | for i in range(len(rope_dim_list)): 103 | emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=use_real, 104 | theta_rescale_factor=theta_rescale_factor[i], 105 | interpolation_factor=interpolation_factor[i]) # 2 x [WHD, rope_dim_list[i]] 106 | embs.append(emb) 107 | 108 | if use_real: 109 | cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) 110 | sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) 111 | return cos, sin 112 | else: 113 | emb = torch.cat(embs, dim=1) # (WHD, D/2) 114 | return emb 115 | 116 | 117 | def get_1d_rotary_pos_embed(dim: int, 118 | pos: Union[torch.FloatTensor, int], 119 | theta: float = 10000.0, 120 | use_real: bool = False, 121 | theta_rescale_factor: float = 1.0, 122 | interpolation_factor: float = 1.0, 123 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 124 | """ 125 | Precompute the frequency tensor for complex exponential (cis) with given dimensions. 126 | (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) 127 | 128 | This function calculates a frequency tensor with complex exponential using the given dimension 'dim' 129 | and the end index 'end'. The 'theta' parameter scales the frequencies. 130 | The returned tensor contains complex values in complex64 data type. 131 | 132 | Args: 133 | dim (int): Dimension of the frequency tensor. 134 | pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar 135 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. 136 | use_real (bool, optional): If True, return real part and imaginary part separately. 137 | Otherwise, return complex numbers. 138 | theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. 139 | 140 | Returns: 141 | freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] 142 | freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. 
[S, D] 143 | """ 144 | if isinstance(pos, int): 145 | pos = torch.arange(pos).float() 146 | 147 | # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning 148 | # has some connection to NTK literature 149 | if theta_rescale_factor != 1.0: 150 | theta *= theta_rescale_factor ** (dim / (dim - 2)) 151 | 152 | freqs = 1.0 / ( 153 | theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) 154 | ) # [D/2] 155 | freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2] 156 | if use_real: 157 | freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] 158 | freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] 159 | return freqs_cos, freqs_sin 160 | else: 161 | freqs_cis = torch.polar( 162 | torch.ones_like(freqs), freqs 163 | ) # complex64 # [S, D/2] 164 | return freqs_cis 165 | -------------------------------------------------------------------------------- /hymm_sp/modules/token_refiner.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from einops import rearrange 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .activation_layers import get_activation_layer 8 | from .attn_layers import attention 9 | from .norm_layers import get_norm_layer 10 | from .embed_layers import TimestepEmbedder, TextProjection 11 | from .attn_layers import attention 12 | from .mlp_layers import MLP 13 | from .modulate_layers import apply_gate 14 | 15 | 16 | class IndividualTokenRefinerBlock(nn.Module): 17 | def __init__( 18 | self, 19 | hidden_size, 20 | num_heads, 21 | mlp_ratio: str = 4.0, 22 | mlp_drop_rate: float = 0.0, 23 | act_type: str = "silu", 24 | qk_norm: bool = False, 25 | qk_norm_type: str = "layer", 26 | qkv_bias: bool = True, 27 | dtype: Optional[torch.dtype] = None, 28 | device: Optional[torch.device] = None, 29 | ): 30 | factory_kwargs = {'device': device, 'dtype': dtype} 31 | super().__init__() 32 | self.num_heads = num_heads 33 | head_dim = hidden_size // num_heads 34 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 35 | 36 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs) 37 | self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs) 38 | qk_norm_layer = get_norm_layer(qk_norm_type) 39 | self.self_attn_q_norm = ( 40 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) 41 | if qk_norm 42 | else nn.Identity() 43 | ) 44 | self.self_attn_k_norm = ( 45 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) 46 | if qk_norm 47 | else nn.Identity() 48 | ) 49 | self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs) 50 | 51 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs) 52 | act_layer = get_activation_layer(act_type) 53 | self.mlp = MLP( 54 | in_channels=hidden_size, 55 | hidden_channels=mlp_hidden_dim, 56 | act_layer=act_layer, 57 | drop=mlp_drop_rate, 58 | **factory_kwargs, 59 | ) 60 | 61 | self.adaLN_modulation = nn.Sequential( 62 | act_layer(), 63 | nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs) 64 | ) 65 | # Zero-initialize the modulation 66 | nn.init.zeros_(self.adaLN_modulation[1].weight) 67 | nn.init.zeros_(self.adaLN_modulation[1].bias) 68 | 69 | def forward( 70 | self, 71 | x: torch.Tensor, 72 | c: torch.Tensor, # timestep_aware_representations + context_aware_representations 73 | attn_mask: torch.Tensor = None, 74 
| ): 75 | gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) 76 | 77 | norm_x = self.norm1(x) 78 | qkv = self.self_attn_qkv(norm_x) 79 | q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads) 80 | # Apply QK-Norm if needed 81 | q = self.self_attn_q_norm(q).to(v) 82 | k = self.self_attn_k_norm(k).to(v) 83 | 84 | # Self-Attention 85 | attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) 86 | 87 | x = x + apply_gate(self.self_attn_proj(attn), gate_msa) 88 | 89 | # FFN Layer 90 | x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) 91 | 92 | return x 93 | 94 | 95 | class IndividualTokenRefiner(nn.Module): 96 | def __init__( 97 | self, 98 | hidden_size, 99 | num_heads, 100 | depth, 101 | mlp_ratio: float = 4.0, 102 | mlp_drop_rate: float = 0.0, 103 | act_type: str = "silu", 104 | qk_norm: bool = False, 105 | qk_norm_type: str = "layer", 106 | qkv_bias: bool = True, 107 | dtype: Optional[torch.dtype] = None, 108 | device: Optional[torch.device] = None, 109 | ): 110 | factory_kwargs = {'device': device, 'dtype': dtype} 111 | super().__init__() 112 | self.blocks = nn.ModuleList([ 113 | IndividualTokenRefinerBlock( 114 | hidden_size=hidden_size, 115 | num_heads=num_heads, 116 | mlp_ratio=mlp_ratio, 117 | mlp_drop_rate=mlp_drop_rate, 118 | act_type=act_type, 119 | qk_norm=qk_norm, 120 | qk_norm_type=qk_norm_type, 121 | qkv_bias=qkv_bias, 122 | **factory_kwargs, 123 | ) for _ in range(depth) 124 | ]) 125 | 126 | def forward( 127 | self, 128 | x: torch.Tensor, 129 | c: torch.LongTensor, 130 | mask: Optional[torch.Tensor] = None, 131 | ): 132 | self_attn_mask = None 133 | if mask is not None: 134 | batch_size = mask.shape[0] 135 | seq_len = mask.shape[1] 136 | mask = mask.to(x.device) 137 | # batch_size x 1 x seq_len x seq_len 138 | self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) 139 | # batch_size x 1 x seq_len x seq_len 140 | self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) 141 | # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads 142 | self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() 143 | # avoids self-attention weight being NaN for padding tokens 144 | self_attn_mask[:, :, :, 0] = True 145 | 146 | for block in self.blocks: 147 | x = block(x, c, self_attn_mask) 148 | return x 149 | 150 | 151 | class SingleTokenRefiner(nn.Module): 152 | def __init__( 153 | self, 154 | in_channels, 155 | hidden_size, 156 | num_heads, 157 | depth, 158 | mlp_ratio: float = 4.0, 159 | mlp_drop_rate: float = 0.0, 160 | act_type: str = "silu", 161 | qk_norm: bool = False, 162 | qk_norm_type: str = "layer", 163 | qkv_bias: bool = True, 164 | dtype: Optional[torch.dtype] = None, 165 | device: Optional[torch.device] = None, 166 | ): 167 | factory_kwargs = {'device': device, 'dtype': dtype} 168 | super().__init__() 169 | 170 | self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs) 171 | 172 | act_layer = get_activation_layer(act_type) 173 | # Build timestep embedding layer 174 | self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs) 175 | # Build context embedding layer 176 | self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs) 177 | 178 | self.individual_token_refiner = IndividualTokenRefiner( 179 | hidden_size=hidden_size, 180 | num_heads=num_heads, 181 | depth=depth, 182 | mlp_ratio=mlp_ratio, 183 | mlp_drop_rate=mlp_drop_rate, 184 | act_type=act_type, 185 | qk_norm=qk_norm, 186 | qk_norm_type=qk_norm_type, 187 | 
qkv_bias=qkv_bias, 188 | **factory_kwargs 189 | ) 190 | 191 | def forward( 192 | self, 193 | x: torch.Tensor, 194 | t: torch.LongTensor, 195 | mask: Optional[torch.LongTensor] = None, 196 | ): 197 | timestep_aware_representations = self.t_embedder(t) 198 | 199 | if mask is None: 200 | context_aware_representations = x.mean(dim=1) 201 | else: 202 | mask_float = mask.float().unsqueeze(-1) # [b, s1, 1] 203 | context_aware_representations = ( 204 | (x * mask_float).sum(dim=1) / mask_float.sum(dim=1) 205 | ) 206 | context_aware_representations = self.c_embedder(context_aware_representations) 207 | c = timestep_aware_representations + context_aware_representations 208 | 209 | x = self.input_embedder(x) 210 | 211 | x = self.individual_token_refiner(x, c, mask) 212 | 213 | return x -------------------------------------------------------------------------------- /hymm_sp/sample_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from pathlib import Path 5 | from loguru import logger 6 | from einops import rearrange 7 | import imageio 8 | import torch.distributed 9 | from torch.utils.data.distributed import DistributedSampler 10 | from torch.utils.data import DataLoader 11 | from hymm_sp.config import parse_args 12 | from hymm_sp.sample_inference_audio import HunyuanVideoSampler 13 | from hymm_sp.data_kits.audio_dataset import VideoAudioTextLoaderVal 14 | from hymm_sp.data_kits.data_tools import save_videos_grid 15 | from hymm_sp.data_kits.face_align import AlignImage 16 | from hymm_sp.modules.parallel_states import ( 17 | initialize_distributed, 18 | nccl_info, 19 | ) 20 | 21 | from transformers import WhisperModel 22 | from transformers import AutoFeatureExtractor 23 | 24 | MODEL_OUTPUT_PATH = os.environ.get('MODEL_BASE') 25 | 26 | 27 | def main(): 28 | args = parse_args() 29 | models_root_path = Path(args.ckpt) 30 | print("*"*20) 31 | initialize_distributed(args.seed) 32 | if not models_root_path.exists(): 33 | raise ValueError(f"`models_root` not exists: {models_root_path}") 34 | print("+"*20) 35 | # Create save folder to save the samples 36 | save_path = args.save_path 37 | if not os.path.exists(args.save_path): 38 | os.makedirs(save_path, exist_ok=True) 39 | 40 | # Load models 41 | rank = 0 42 | vae_dtype = torch.float16 43 | device = torch.device("cuda") 44 | if nccl_info.sp_size > 1: 45 | device = torch.device(f"cuda:{torch.distributed.get_rank()}") 46 | rank = torch.distributed.get_rank() 47 | 48 | hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(args.ckpt, args=args, device=device) 49 | # Get the updated args 50 | args = hunyuan_video_sampler.args 51 | 52 | wav2vec = WhisperModel.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/").to(device=device, dtype=torch.float32) 53 | wav2vec.requires_grad_(False) 54 | 55 | BASE_DIR = f'{MODEL_OUTPUT_PATH}/ckpts/det_align/' 56 | det_path = os.path.join(BASE_DIR, 'detface.pt') 57 | align_instance = AlignImage("cuda", det_path=det_path) 58 | 59 | feature_extractor = AutoFeatureExtractor.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/") 60 | 61 | kwargs = { 62 | "text_encoder": hunyuan_video_sampler.text_encoder, 63 | "text_encoder_2": hunyuan_video_sampler.text_encoder_2, 64 | "feature_extractor": feature_extractor, 65 | } 66 | video_dataset = VideoAudioTextLoaderVal( 67 | image_size=args.image_size, 68 | meta_file=args.input, 69 | **kwargs, 70 | ) 71 | 72 | sampler = DistributedSampler(video_dataset, num_replicas=1, rank=0, 
shuffle=False, drop_last=False) 73 | json_loader = DataLoader(video_dataset, batch_size=1, shuffle=False, sampler=sampler, drop_last=False) 74 | 75 | for batch_index, batch in enumerate(json_loader, start=1): 76 | 77 | fps = batch["fps"] 78 | videoid = batch['videoid'][0] 79 | audio_path = str(batch["audio_path"][0]) 80 | save_path = args.save_path 81 | output_path = f"{save_path}/{videoid}.mp4" 82 | output_audio_path = f"{save_path}/{videoid}_audio.mp4" 83 | 84 | samples = hunyuan_video_sampler.predict(args, batch, wav2vec, feature_extractor, align_instance) 85 | 86 | sample = samples['samples'][0].unsqueeze(0) # denoised latent, (bs, 16, t//4, h//8, w//8) 87 | sample = sample[:, :, :batch["audio_len"][0]] 88 | 89 | video = rearrange(sample[0], "c f h w -> f h w c") 90 | video = (video * 255.).data.cpu().numpy().astype(np.uint8) # (f h w c) 91 | 92 | torch.cuda.empty_cache() 93 | 94 | final_frames = [] 95 | for frame in video: 96 | final_frames.append(frame) 97 | final_frames = np.stack(final_frames, axis=0) 98 | 99 | if rank == 0: 100 | imageio.mimsave(output_path, final_frames, fps=fps.item()) 101 | os.system(f"ffmpeg -i '{output_path}' -i '{audio_path}' -shortest '{output_audio_path}' -y -loglevel quiet; rm '{output_path}'") 102 | 103 | 104 | 105 | 106 | if __name__ == "__main__": 107 | main() 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /hymm_sp/sample_gpu_poor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from pathlib import Path 4 | from loguru import logger 5 | import imageio 6 | import torch 7 | from einops import rearrange 8 | import torch.distributed 9 | from torch.utils.data.distributed import DistributedSampler 10 | from torch.utils.data import DataLoader 11 | from hymm_sp.config import parse_args 12 | from hymm_sp.sample_inference_audio import HunyuanVideoSampler 13 | from hymm_sp.data_kits.audio_dataset import VideoAudioTextLoaderVal 14 | from hymm_sp.data_kits.face_align import AlignImage 15 | 16 | from transformers import WhisperModel 17 | from transformers import AutoFeatureExtractor 18 | 19 | MODEL_OUTPUT_PATH = os.environ.get('MODEL_BASE') 20 | 21 | 22 | def main(): 23 | args = parse_args() 24 | models_root_path = Path(args.ckpt) 25 | 26 | if not models_root_path.exists(): 27 | raise ValueError(f"`models_root` not exists: {models_root_path}") 28 | 29 | # Create save folder to save the samples 30 | save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}' 31 | if not os.path.exists(args.save_path): 32 | os.makedirs(save_path, exist_ok=True) 33 | 34 | # Load models 35 | rank = 0 36 | vae_dtype = torch.float16 37 | device = torch.device("cuda") 38 | 39 | hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(args.ckpt, args=args, device=device) 40 | # Get the updated args 41 | args = hunyuan_video_sampler.args 42 | if args.cpu_offload: 43 | from diffusers.hooks import apply_group_offloading 44 | onload_device = torch.device("cuda") 45 | apply_group_offloading(hunyuan_video_sampler.pipeline.transformer, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=1) 46 | 47 | wav2vec = WhisperModel.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/").to(device=device, dtype=torch.float32) 48 | wav2vec.requires_grad_(False) 49 | 50 | BASE_DIR = f'{MODEL_OUTPUT_PATH}/ckpts/det_align/' 51 | det_path = os.path.join(BASE_DIR, 'detface.pt') 52 
| align_instance = AlignImage("cuda", det_path=det_path) 53 | 54 | feature_extractor = AutoFeatureExtractor.from_pretrained(f"{MODEL_OUTPUT_PATH}/ckpts/whisper-tiny/") 55 | 56 | kwargs = { 57 | "text_encoder": hunyuan_video_sampler.text_encoder, 58 | "text_encoder_2": hunyuan_video_sampler.text_encoder_2, 59 | "feature_extractor": feature_extractor, 60 | } 61 | video_dataset = VideoAudioTextLoaderVal( 62 | image_size=args.image_size, 63 | meta_file=args.input, 64 | **kwargs, 65 | ) 66 | 67 | sampler = DistributedSampler(video_dataset, num_replicas=1, rank=0, shuffle=False, drop_last=False) 68 | json_loader = DataLoader(video_dataset, batch_size=1, shuffle=False, sampler=sampler, drop_last=False) 69 | 70 | for batch_index, batch in enumerate(json_loader, start=1): 71 | 72 | fps = batch["fps"] 73 | videoid = batch['videoid'][0] 74 | audio_path = str(batch["audio_path"][0]) 75 | save_path = args.save_path 76 | output_path = f"{save_path}/{videoid}.mp4" 77 | output_audio_path = f"{save_path}/{videoid}_audio.mp4" 78 | 79 | if args.infer_min: 80 | batch["audio_len"][0] = 129 81 | 82 | samples = hunyuan_video_sampler.predict(args, batch, wav2vec, feature_extractor, align_instance) 83 | 84 | sample = samples['samples'][0].unsqueeze(0) # denoised latent, (bs, 16, t//4, h//8, w//8) 85 | sample = sample[:, :, :batch["audio_len"][0]] 86 | 87 | video = rearrange(sample[0], "c f h w -> f h w c") 88 | video = (video * 255.).data.cpu().numpy().astype(np.uint8) # (f h w c) 89 | 90 | torch.cuda.empty_cache() 91 | 92 | final_frames = [] 93 | for frame in video: 94 | final_frames.append(frame) 95 | final_frames = np.stack(final_frames, axis=0) 96 | 97 | if rank == 0: 98 | imageio.mimsave(output_path, final_frames, fps=fps.item()) 99 | os.system(f"ffmpeg -i '{output_path}' -i '{audio_path}' -shortest '{output_audio_path}' -y -loglevel quiet; rm '{output_path}'") 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /hymm_sp/sample_inference_audio.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import torch 4 | import random 5 | from loguru import logger 6 | from einops import rearrange 7 | from hymm_sp.diffusion import load_diffusion_pipeline 8 | from hymm_sp.helpers import get_nd_rotary_pos_embed_new 9 | from hymm_sp.inference import Inference 10 | from hymm_sp.diffusion.schedulers import FlowMatchDiscreteScheduler 11 | from hymm_sp.data_kits.audio_preprocessor import encode_audio, get_facemask 12 | 13 | def align_to(value, alignment): 14 | return int(math.ceil(value / alignment) * alignment) 15 | 16 | class HunyuanVideoSampler(Inference): 17 | def __init__(self, args, vae, vae_kwargs, text_encoder, model, text_encoder_2=None, pipeline=None, 18 | device=0, logger=None): 19 | super().__init__(args, vae, vae_kwargs, text_encoder, model, text_encoder_2=text_encoder_2, 20 | pipeline=pipeline, device=device, logger=logger) 21 | 22 | self.args = args 23 | self.pipeline = load_diffusion_pipeline( 24 | args, 0, self.vae, self.text_encoder, self.text_encoder_2, self.model, 25 | device=self.device) 26 | print('load hunyuan model successful... 
') 27 | 28 | def get_rotary_pos_embed(self, video_length, height, width, concat_dict={}): 29 | target_ndim = 3 30 | ndim = 5 - 2 31 | if '884' in self.args.vae: 32 | latents_size = [(video_length-1)//4+1 , height//8, width//8] 33 | else: 34 | latents_size = [video_length , height//8, width//8] 35 | 36 | if isinstance(self.model.patch_size, int): 37 | assert all(s % self.model.patch_size == 0 for s in latents_size), \ 38 | f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \ 39 | f"but got {latents_size}." 40 | rope_sizes = [s // self.model.patch_size for s in latents_size] 41 | elif isinstance(self.model.patch_size, list): 42 | assert all(s % self.model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \ 43 | f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \ 44 | f"but got {latents_size}." 45 | rope_sizes = [s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)] 46 | 47 | if len(rope_sizes) != target_ndim: 48 | rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis 49 | head_dim = self.model.hidden_size // self.model.num_heads 50 | rope_dim_list = self.model.rope_dim_list 51 | if rope_dim_list is None: 52 | rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] 53 | assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer" 54 | freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list, 55 | rope_sizes, 56 | theta=self.args.rope_theta, 57 | use_real=True, 58 | theta_rescale_factor=1, 59 | concat_dict=concat_dict) 60 | return freqs_cos, freqs_sin 61 | 62 | @torch.no_grad() 63 | def predict(self, 64 | args, batch, wav2vec, feature_extractor, align_instance, 65 | **kwargs): 66 | """ 67 | Predict the image from the given text. 68 | 69 | Args: 70 | prompt (str or List[str]): The input text. 71 | kwargs: 72 | size (int): The (height, width) of the output image/video. Default is (256, 256). 73 | video_length (int): The frame number of the output video. Default is 1. 74 | seed (int or List[str]): The random seed for the generation. Default is a random integer. 75 | negative_prompt (str or List[str]): The negative text prompt. Default is an empty string. 76 | infer_steps (int): The number of inference steps. Default is 100. 77 | guidance_scale (float): The guidance scale for the generation. Default is 6.0. 78 | num_videos_per_prompt (int): The number of videos per prompt. Default is 1. 79 | verbose (int): 0 for no log, 1 for all log, 2 for fewer log. Default is 1. 80 | output_type (str): The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. 81 | Default is 'pil'. 
82 | """ 83 | 84 | out_dict = dict() 85 | 86 | prompt = batch['text_prompt'][0] 87 | image_path = str(batch["image_path"][0]) 88 | audio_path = str(batch["audio_path"][0]) 89 | neg_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, Lens changes" 90 | # videoid = batch['videoid'][0] 91 | fps = batch["fps"].to(self.device) 92 | audio_prompts = batch["audio_prompts"].to(self.device) 93 | weight_dtype = audio_prompts.dtype 94 | 95 | audio_prompts = [encode_audio(wav2vec, audio_feat.to(dtype=wav2vec.dtype), fps.item(), num_frames=batch["audio_len"][0]) for audio_feat in audio_prompts] 96 | audio_prompts = torch.cat(audio_prompts, dim=0).to(device=self.device, dtype=weight_dtype) 97 | if audio_prompts.shape[1] <= 129: 98 | audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1,129-audio_prompts.shape[1], 1, 1, 1)], dim=1) 99 | else: 100 | audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1, 5, 1, 1, 1)], dim=1) 101 | 102 | wav2vec.to("cpu") 103 | torch.cuda.empty_cache() 104 | 105 | uncond_audio_prompts = torch.zeros_like(audio_prompts[:,:129]) 106 | motion_exp = batch["motion_bucket_id_exps"].to(self.device) 107 | motion_pose = batch["motion_bucket_id_heads"].to(self.device) 108 | 109 | pixel_value_ref = batch['pixel_value_ref'].to(self.device) # (b f c h w) 取值范围[0,255] 110 | face_masks = get_facemask(pixel_value_ref.clone(), align_instance, area=3.0) 111 | 112 | pixel_value_ref = pixel_value_ref.clone().repeat(1,129,1,1,1) 113 | uncond_pixel_value_ref = torch.zeros_like(pixel_value_ref) 114 | pixel_value_ref = pixel_value_ref / 127.5 - 1. 115 | uncond_pixel_value_ref = uncond_pixel_value_ref * 2 - 1 116 | 117 | pixel_value_ref_for_vae = rearrange(pixel_value_ref, "b f c h w -> b c f h w") 118 | uncond_uncond_pixel_value_ref = rearrange(uncond_pixel_value_ref, "b f c h w -> b c f h w") 119 | 120 | pixel_value_llava = batch["pixel_value_ref_llava"].to(self.device) 121 | pixel_value_llava = rearrange(pixel_value_llava, "b f c h w -> (b f) c h w") 122 | uncond_pixel_value_llava = pixel_value_llava.clone() 123 | 124 | # ========== Encode reference latents ========== 125 | vae_dtype = self.vae.dtype 126 | with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_dtype != torch.float32): 127 | 128 | if args.cpu_offload: 129 | self.vae.to('cuda') 130 | 131 | self.vae.enable_tiling() 132 | ref_latents = self.vae.encode(pixel_value_ref_for_vae.clone()).latent_dist.sample() 133 | uncond_ref_latents = self.vae.encode(uncond_uncond_pixel_value_ref).latent_dist.sample() 134 | self.vae.disable_tiling() 135 | if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor: 136 | ref_latents.sub_(self.vae.config.shift_factor).mul_(self.vae.config.scaling_factor) 137 | uncond_ref_latents.sub_(self.vae.config.shift_factor).mul_(self.vae.config.scaling_factor) 138 | else: 139 | ref_latents.mul_(self.vae.config.scaling_factor) 140 | uncond_ref_latents.mul_(self.vae.config.scaling_factor) 141 | 142 | if args.cpu_offload: 143 | self.vae.to('cpu') 144 | torch.cuda.empty_cache() 145 | 146 | face_masks = torch.nn.functional.interpolate(face_masks.float().squeeze(2), 147 | (ref_latents.shape[-2], 148 | ref_latents.shape[-1]), 149 | mode="bilinear").unsqueeze(2).to(dtype=ref_latents.dtype) 150 | 151 | 152 | size = (batch['pixel_value_ref'].shape[-2], batch['pixel_value_ref'].shape[-1]) 153 | target_length = 129 154 | 
target_height = align_to(size[0], 16) 155 | target_width = align_to(size[1], 16) 156 | concat_dict = {'mode': 'timecat', 'bias': -1} 157 | # concat_dict = {} 158 | freqs_cos, freqs_sin = self.get_rotary_pos_embed( 159 | target_length, 160 | target_height, 161 | target_width, 162 | concat_dict) 163 | n_tokens = freqs_cos.shape[0] 164 | 165 | generator = torch.Generator(device=self.device).manual_seed(args.seed) 166 | 167 | debug_str = f""" 168 | prompt: {prompt} 169 | image_path: {image_path} 170 | audio_path: {audio_path} 171 | negative_prompt: {neg_prompt} 172 | seed: {args.seed} 173 | fps: {fps.item()} 174 | infer_steps: {args.infer_steps} 175 | target_height: {target_height} 176 | target_width: {target_width} 177 | target_length: {target_length} 178 | guidance_scale: {args.cfg_scale} 179 | """ 180 | self.logger.info(debug_str) 181 | pipeline_kwargs = { 182 | "cpu_offload": args.cpu_offload 183 | } 184 | start_time = time.time() 185 | samples = self.pipeline(prompt=prompt, 186 | height=target_height, 187 | width=target_width, 188 | frame=target_length, 189 | num_inference_steps=args.infer_steps, 190 | guidance_scale=args.cfg_scale, # cfg scale 191 | 192 | negative_prompt=neg_prompt, 193 | num_images_per_prompt=args.num_images, 194 | generator=generator, 195 | prompt_embeds=None, 196 | 197 | ref_latents=ref_latents, # [1, 16, 1, h//8, w//8] 198 | uncond_ref_latents=uncond_ref_latents, 199 | pixel_value_llava=pixel_value_llava, # [1, 3, 336, 336] 200 | uncond_pixel_value_llava=uncond_pixel_value_llava, 201 | face_masks=face_masks, # [b f h w] 202 | audio_prompts=audio_prompts, 203 | uncond_audio_prompts=uncond_audio_prompts, 204 | motion_exp=motion_exp, 205 | motion_pose=motion_pose, 206 | fps=fps, 207 | 208 | num_videos_per_prompt=1, 209 | attention_mask=None, 210 | negative_prompt_embeds=None, 211 | negative_attention_mask=None, 212 | output_type="pil", 213 | freqs_cis=(freqs_cos, freqs_sin), 214 | n_tokens=n_tokens, 215 | data_type='video', 216 | is_progress_bar=True, 217 | vae_ver=self.args.vae, 218 | enable_tiling=self.args.vae_tiling, 219 | **pipeline_kwargs 220 | )[0] 221 | if samples is None: 222 | return None 223 | out_dict['samples'] = samples 224 | gen_time = time.time() - start_time 225 | logger.info(f"Success, time: {gen_time}") 226 | 227 | wav2vec.to(self.device) 228 | 229 | return out_dict 230 | 231 | -------------------------------------------------------------------------------- /hymm_sp/text_encoder/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/text_encoder/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/vae/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pathlib import Path 3 | from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D 4 | from ..constants import VAE_PATH, PRECISION_TO_TYPE 5 | 6 | def load_vae(vae_type, 7 | vae_precision=None, 8 | sample_size=None, 9 | vae_path=None, 10 | logger=None, 11 | device=None 12 | ): 13 | if vae_path is None: 14 | vae_path = VAE_PATH[vae_type] 15 | vae_compress_spec, _, _ = vae_type.split("-") 16 | length = len(vae_compress_spec) 17 | if length == 3: 18 | if logger is not None: 19 | logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}") 20 | config = 
AutoencoderKLCausal3D.load_config(vae_path) 21 | if sample_size: 22 | vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size) 23 | else: 24 | vae = AutoencoderKLCausal3D.from_config(config) 25 | ckpt = torch.load(Path(vae_path) / "pytorch_model.pt", map_location=vae.device) 26 | if "state_dict" in ckpt: 27 | ckpt = ckpt["state_dict"] 28 | # vae_ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")} 29 | vae_ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items()} 30 | vae.load_state_dict(vae_ckpt) 31 | 32 | spatial_compression_ratio = vae.config.spatial_compression_ratio 33 | time_compression_ratio = vae.config.time_compression_ratio 34 | else: 35 | raise ValueError(f"Invalid VAE model: {vae_type}. Must be 3D VAE in the format of '???-*'.") 36 | 37 | if vae_precision is not None: 38 | vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision]) 39 | 40 | vae.requires_grad_(False) 41 | 42 | if logger is not None: 43 | logger.info(f"VAE to dtype: {vae.dtype}") 44 | 45 | if device is not None: 46 | vae = vae.to(device) 47 | 48 | # Set vae to eval mode, even though it's dropout rate is 0. 49 | vae.eval() 50 | 51 | return vae, vae_path, spatial_compression_ratio, time_compression_ratio 52 | -------------------------------------------------------------------------------- /hymm_sp/vae/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/vae/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/vae/__pycache__/autoencoder_kl_causal_3d.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/vae/__pycache__/autoencoder_kl_causal_3d.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/vae/__pycache__/unet_causal_3d_blocks.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/vae/__pycache__/unet_causal_3d_blocks.cpython-310.pyc -------------------------------------------------------------------------------- /hymm_sp/vae/__pycache__/vae.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo-Avatar/f2cfee30d10b2ba9a67045382300f9d560dde9b8/hymm_sp/vae/__pycache__/vae.cpython-310.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python==4.9.0.80 2 | diffusers==0.33.0 3 | transformers==4.45.1 4 | accelerate==1.1.1 5 | pandas==2.0.3 6 | numpy==1.24.4 7 | einops==0.7.0 8 | tqdm==4.66.2 9 | loguru==0.7.2 10 | imageio==2.34.0 11 | imageio-ffmpeg==0.5.1 12 | safetensors==0.4.3 13 | gradio==4.42.0 14 | fastapi==0.115.12 15 | uvicorn==0.34.2 16 | decord==0.6.0 17 | librosa==0.11.0 18 | scikit-video==1.1.11 19 | ffmpeg -------------------------------------------------------------------------------- /scripts/run_gradio.sh: -------------------------------------------------------------------------------- 1 | 
#!/bin/bash 2 | JOBS_DIR=$(dirname $(dirname "$0")) 3 | export PYTHONPATH=./ 4 | 5 | export MODEL_BASE=./weights 6 | 7 | checkpoint_path=${MODEL_BASE}/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt 8 | 9 | 10 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_gradio/flask_audio.py \ 11 | --input 'assets/test.csv' \ 12 | --ckpt ${checkpoint_path} \ 13 | --sample-n-frames 129 \ 14 | --seed 128 \ 15 | --image-size 704 \ 16 | --cfg-scale 7.5 \ 17 | --infer-steps 50 \ 18 | --use-deepcache 1 \ 19 | --flow-shift-eval-video 5.0 & 20 | 21 | 22 | python3 hymm_gradio/gradio_audio.py 23 | -------------------------------------------------------------------------------- /scripts/run_sample_batch_sp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | JOBS_DIR=$(dirname $(dirname "$0")) 3 | export PYTHONPATH=./ 4 | 5 | export MODEL_BASE=./weights 6 | OUTPUT_BASEPATH=./results 7 | checkpoint_path=${MODEL_BASE}/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt 8 | 9 | 10 | torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_sp/sample_batch.py \ 11 | --input 'assets/test.csv' \ 12 | --ckpt ${checkpoint_path} \ 13 | --sample-n-frames 129 \ 14 | --seed 128 \ 15 | --image-size 704 \ 16 | --cfg-scale 7.5 \ 17 | --infer-steps 50 \ 18 | --use-deepcache 1 \ 19 | --flow-shift-eval-video 5.0 \ 20 | --save-path ${OUTPUT_BASEPATH} 21 | -------------------------------------------------------------------------------- /scripts/run_single_audio.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | JOBS_DIR=$(dirname $(dirname "$0")) 3 | export PYTHONPATH=./ 4 | 5 | export MODEL_BASE=./weights 6 | OUTPUT_BASEPATH=./results-single 7 | 8 | # checkpoint_path=${MODEL_BASE}/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt 9 | checkpoint_path=${MODEL_BASE}/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt 10 | 11 | 12 | export DISABLE_SP=1 13 | CUDA_VISIBLE_DEVICES=0 python3 hymm_sp/sample_gpu_poor.py \ 14 | --input 'assets/test.csv' \ 15 | --ckpt ${checkpoint_path} \ 16 | --sample-n-frames 129 \ 17 | --seed 128 \ 18 | --image-size 704 \ 19 | --cfg-scale 7.5 \ 20 | --infer-steps 50 \ 21 | --use-deepcache 1 \ 22 | --flow-shift-eval-video 5.0 \ 23 | --save-path ${OUTPUT_BASEPATH} \ 24 | --use-fp8 \ 25 | --infer-min 26 | -------------------------------------------------------------------------------- /scripts/run_single_poor.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | JOBS_DIR=$(dirname $(dirname "$0")) 3 | export PYTHONPATH=./ 4 | 5 | export MODEL_BASE=./weights 6 | OUTPUT_BASEPATH=./results-poor 7 | 8 | checkpoint_path=${MODEL_BASE}/ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt 9 | 10 | export CPU_OFFLOAD=1 11 | CUDA_VISIBLE_DEVICES=0 python3 hymm_sp/sample_gpu_poor.py \ 12 | --input 'assets/test.csv' \ 13 | --ckpt ${checkpoint_path} \ 14 | --sample-n-frames 129 \ 15 | --seed 128 \ 16 | --image-size 704 \ 17 | --cfg-scale 7.5 \ 18 | --infer-steps 50 \ 19 | --use-deepcache 1 \ 20 | --flow-shift-eval-video 5.0 \ 21 | --save-path ${OUTPUT_BASEPATH} \ 22 | --use-fp8 \ 23 | --cpu-offload \ 24 | --infer-min 25 | -------------------------------------------------------------------------------- /weights/README.md: -------------------------------------------------------------------------------- 1 | # Download Pretrained Models 2 | 3 | All 
models are stored in `HunyuanVideo-Avatar/weights` by default, and the file structure is as follows
4 | ```shell
5 | HunyuanVideo-Avatar
6 | ├──weights
7 | │ ├──ckpts
8 | │ │ ├──README.md
9 | │ │ ├──hunyuan-video-t2v-720p
10 | │ │ │ ├──transformers
11 | │ │ │ │ ├──mp_rank_00_model_states.pt
12 | │ │ │ │ ├──mp_rank_00_model_states_fp8.pt
13 | │ │ │ │ ├──mp_rank_00_model_states_fp8_map.pt
14 | │ │ │ ├──vae
15 | │ │ │ │ ├──pytorch_model.pt
16 | │ │ │ │ ├──config.json
17 | │ │ ├──llava_llama_image
18 | │ │ │ ├──model-00001-of-00004.safetensors
19 | │ │ │ ├──model-00002-of-00004.safetensors
20 | │ │ │ ├──model-00003-of-00004.safetensors
21 | │ │ │ ├──model-00004-of-00004.safetensors
22 | │ │ │ ├──...
23 | │ │ ├──text_encoder_2
24 | │ │ ├──whisper-tiny
25 | │ │ ├──det_align
26 | │ │ ├──...
27 | ```
28 | 
29 | ## Download the HunyuanVideo-Avatar model
30 | To download the HunyuanVideo-Avatar model, first install the huggingface-cli. (Detailed instructions are available [here](https://huggingface.co/docs/huggingface_hub/guides/cli).)
31 | 
32 | ```shell
33 | python -m pip install "huggingface_hub[cli]"
34 | ```
35 | 
36 | Then download the model using the following commands:
37 | 
38 | ```shell
39 | # Switch to the directory named 'HunyuanVideo-Avatar/weights'
40 | cd HunyuanVideo-Avatar/weights
41 | # Use the huggingface-cli tool to download the HunyuanVideo-Avatar model into the HunyuanVideo-Avatar/weights dir.
42 | # The download time may vary from 10 minutes to 1 hour depending on network conditions.
43 | huggingface-cli download tencent/HunyuanVideo-Avatar --local-dir ./
44 | ```
45 | 
--------------------------------------------------------------------------------
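After the download finishes, it can help to verify that the checkpoints landed where the scripts expect them before launching inference. The snippet below is a minimal sanity check, not part of the repository: it assumes the default `./weights` layout shown above and that it is run from the `HunyuanVideo-Avatar` repository root (the provided scripts already export `MODEL_BASE=./weights`).

```shell
# Confirm the transformer checkpoints from the tree above are present
ls -lh weights/ckpts/hunyuan-video-t2v-720p/transformers/
# Confirm the 3D VAE weights and config
ls -lh weights/ckpts/hunyuan-video-t2v-720p/vae/
# Low-VRAM single-GPU run (fp8 weights + CPU offload); results are written to ./results-poor
bash scripts/run_single_poor.sh
```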