├── .gitignore
├── LICENSE.txt
├── Notice
├── README.md
├── README_zh.md
├── assets
│   ├── 3dvae.png
│   ├── PenguinVideoBenchmark.csv
│   ├── WECHAT.md
│   ├── backbone.png
│   ├── hunyuanvideo.pdf
│   ├── logo.png
│   ├── overall.png
│   ├── text_encoder.png
│   ├── video_poster.png
│   └── wechat.jpg
├── ckpts
│   └── README.md
├── gradio_server.py
├── hyvideo
│   ├── __init__.py
│   ├── config.py
│   ├── constants.py
│   ├── diffusion
│   │   ├── __init__.py
│   │   ├── pipelines
│   │   │   ├── __init__.py
│   │   │   └── pipeline_hunyuan_video.py
│   │   └── schedulers
│   │       ├── __init__.py
│   │       └── scheduling_flow_match_discrete.py
│   ├── inference.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── activation_layers.py
│   │   ├── attenion.py
│   │   ├── embed_layers.py
│   │   ├── fp8_optimization.py
│   │   ├── mlp_layers.py
│   │   ├── models.py
│   │   ├── modulate_layers.py
│   │   ├── norm_layers.py
│   │   ├── posemb_layers.py
│   │   └── token_refiner.py
│   ├── prompt_rewrite.py
│   ├── text_encoder
│   │   └── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── data_utils.py
│   │   ├── file_utils.py
│   │   ├── helpers.py
│   │   └── preprocess_text_encoder_tokenizer_utils.py
│   └── vae
│       ├── __init__.py
│       ├── autoencoder_kl_causal_3d.py
│       ├── unet_causal_3d_blocks.py
│       └── vae.py
├── requirements.txt
├── sample_video.py
├── scripts
│   ├── run_sample_video.sh
│   ├── run_sample_video_fp8.sh
│   └── run_sample_video_multigpu.sh
├── tests
│   └── test_attention.py
└── utils
    └── collect_env.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | /ckpts/**/
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
2 | Tencent HunyuanVideo Release Date: December 3, 2024
3 | THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
4 | By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
5 | 1. DEFINITIONS.
6 | a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
7 | b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
8 | c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
9 | d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
10 | e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
11 | f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
12 | g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
13 | h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
14 | i. “Tencent,” “We” or “Us” shall mean THL A29 Limited.
15 | j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent HunyuanVideo released at [https://github.com/Tencent/HunyuanVideo].
16 | k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
17 | l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
18 | m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
19 | n. “including” shall mean including but not limited to.
20 | 2. GRANT OF RIGHTS.
21 | We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
22 | 3. DISTRIBUTION.
23 | You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
24 | a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
25 | b. You must cause any modified files to carry prominent notices stating that You changed the files;
26 | c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
27 | d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2024 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
28 | You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
29 | 4. ADDITIONAL COMMERCIAL TERMS.
30 | If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
31 | 5. RULES OF USE.
32 | a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
33 | b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
34 | c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
35 | 6. INTELLECTUAL PROPERTY.
36 | a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
37 | b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
38 | c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
39 | d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
40 | 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
41 | a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
42 | b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
43 | c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
44 | 8. SURVIVAL AND TERMINATION.
45 | a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
46 | b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
47 | 9. GOVERNING LAW AND JURISDICTION.
48 | a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
49 | b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
50 |
51 | EXHIBIT A
52 | ACCEPTABLE USE POLICY
53 |
54 | Tencent reserves the right to update this Acceptable Use Policy from time to time.
55 | Last modified: November 5, 2024
56 |
57 | Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
58 | 1. Outside the Territory;
59 | 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
60 | 3. To harm Yourself or others;
61 | 4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
62 | 5. To override or circumvent the safety guardrails and safeguards We have put in place;
63 | 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
64 | 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
65 | 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
66 | 9. To intentionally defame, disparage or otherwise harass others;
67 | 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
68 | 11. To generate or disseminate personal identifiable information with the purpose of harming others;
69 | 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
70 | 13. To impersonate another individual without consent, authorization, or legal right;
71 | 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
72 | 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
73 | 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
74 | 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
75 | 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
76 | 19. For military purposes;
77 | 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
78 |
--------------------------------------------------------------------------------
/Notice:
--------------------------------------------------------------------------------
1 | Usage and Legal Notices:
2 |
3 | Tencent is pleased to support the open source community by making Tencent HunyuanVideo available.
4 |
5 | Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. The below software and/or models in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) THL A29 Limited.
6 |
7 | Tencent HunyuanVideo is licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT except for the third-party components listed below. Tencent HunyuanVideo does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
8 |
9 | For the avoidance of doubt, Tencent HunyuanVideo means the large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing that may be made publicly available by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
10 |
11 |
12 | Other dependencies and licenses:
13 |
14 |
15 | Open Source Model Licensed under the Apache License Version 2.0:
16 | The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
17 | --------------------------------------------------------------------
18 | 1. diffusers
19 | Copyright (c) diffusers original author and authors
20 | Please note this software has been modified by Tencent in this distribution.
21 |
22 | 2. transformers
23 | Copyright (c) transformers original author and authors
24 |
25 | 3. safetensors
26 | Copyright (c) safetensors original author and authors
27 |
28 | 4. flux
29 | Copyright (c) flux original author and authors
30 |
31 |
32 | Terms of the Apache License Version 2.0:
33 | --------------------------------------------------------------------
34 | Apache License
35 |
36 | Version 2.0, January 2004
37 |
38 | http://www.apache.org/licenses/
39 |
40 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
41 | 1. Definitions.
42 |
43 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
44 |
45 | "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
46 |
47 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
48 |
49 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
50 |
51 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
52 |
53 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
54 |
55 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
56 |
57 | "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
58 |
59 | "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
60 |
61 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
62 |
63 | 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
64 |
65 | 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
66 |
67 | 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
68 |
69 | You must give any other recipients of the Work or Derivative Works a copy of this License; and
70 |
71 | You must cause any modified files to carry prominent notices stating that You changed the files; and
72 |
73 | You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
74 |
75 | If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
76 |
77 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
78 |
79 | 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
80 |
81 | 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
82 |
83 | 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
84 |
85 | 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
86 |
87 | 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
88 |
89 | END OF TERMS AND CONDITIONS
90 |
91 |
92 |
93 | Open Source Software Licensed under the BSD 2-Clause License:
94 | --------------------------------------------------------------------
95 | 1. imageio
96 | Copyright (c) 2014-2022, imageio developers
97 | All rights reserved.
98 |
99 |
100 | Terms of the BSD 2-Clause License:
101 | --------------------------------------------------------------------
102 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
103 |
104 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
105 |
106 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
107 |
108 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
109 |
110 |
111 |
112 | Open Source Software Licensed under the BSD 3-Clause License:
113 | --------------------------------------------------------------------
114 | 1. torchvision
115 | Copyright (c) Soumith Chintala 2016,
116 | All rights reserved.
117 |
118 | 2. flash-attn
119 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
120 | All rights reserved.
121 |
122 |
123 | Terms of the BSD 3-Clause License:
124 | --------------------------------------------------------------------
125 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
126 |
127 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
128 |
129 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
130 |
131 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
132 |
133 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
134 |
135 |
136 |
137 | Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
138 | --------------------------------------------------------------------
139 | 1. torch
140 | Copyright (c) 2016- Facebook, Inc (Adam Paszke)
141 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
142 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
143 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
144 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
145 | Copyright (c) 2011-2013 NYU (Clement Farabet)
146 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
147 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
148 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
149 |
150 |
151 | A copy of the BSD 3-Clause License is included in this file.
152 |
153 | For the license of other third party components, please refer to the following URL:
154 | https://github.com/pytorch/pytorch/tree/v2.1.1/third_party
155 |
156 |
157 | Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
158 | --------------------------------------------------------------------
159 | 1. pandas
160 | Copyright (c) 2008-2011, AQR Capital Management, LLC, Lambda Foundry, Inc. and PyData Development Team
161 | All rights reserved.
162 |
163 | Copyright (c) 2011-2023, Open source contributors.
164 |
165 |
166 | A copy of the BSD 3-Clause License is included in this file.
167 |
168 | For the license of other third party components, please refer to the following URL:
169 | https://github.com/pandas-dev/pandas/tree/v2.0.3/LICENSES
170 |
171 |
172 | Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
173 | --------------------------------------------------------------------
174 | 1. numpy
175 | Copyright (c) 2005-2022, NumPy Developers.
176 | All rights reserved.
177 |
178 |
179 | A copy of the BSD 3-Clause License is included in this file.
180 |
181 | For the license of other third party components, please refer to the following URL:
182 | https://github.com/numpy/numpy/blob/v1.24.4/LICENSES_bundled.txt
183 |
184 |
185 | Open Source Software Licensed under the MIT License:
186 | --------------------------------------------------------------------
187 | 1. einops
188 | Copyright (c) 2018 Alex Rogozhnikov
189 |
190 | 2. loguru
191 | Copyright (c) 2017
192 |
193 |
194 | Terms of the MIT License:
195 | --------------------------------------------------------------------
196 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
197 |
198 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
199 |
200 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
201 |
202 |
203 |
204 | Open Source Software Licensed under the MIT License and Other Licenses of the Third-Party Components therein:
205 | --------------------------------------------------------------------
206 | 1. tqdm
207 | Copyright (c) 2013 noamraph
208 |
209 |
210 | A copy of the MIT License is included in this file.
211 |
212 | For the license of other third party components, please refer to the following URL:
213 | https://github.com/tqdm/tqdm/blob/v4.66.2/LICENCE
214 |
215 |
216 |
217 | Open Source Model Licensed under the MIT License:
218 | --------------------------------------------------------------------
219 | 1. clip-large
220 | Copyright (c) 2021 OpenAI
221 |
222 |
223 | A copy of the MIT License is included in this file.
224 |
225 |
226 | --------------------------------------------------------------------
227 | We may also use other third-party components:
228 |
229 | 1. llava-llama3
230 |
231 | Copyright (c) llava-llama3 original author and authors
232 |
233 | URL: https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers#model
--------------------------------------------------------------------------------
/README_zh.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | [English](./README.md)
4 |
5 |
6 |
7 |
8 |
9 | # HunyuanVideo: A Systematic Framework For Large Video Generation Model
10 |
11 |
16 |
17 |

18 |
19 |
20 |
21 |

22 |

23 |

24 |
25 |
26 | [Replicate](https://replicate.com/zsxkib/hunyuan-video)
27 |
28 |
29 |
30 |
31 | 👋 Join our WeChat and Discord
32 |
33 |
34 |
35 |
36 | -----
37 |
38 | This repository contains the PyTorch model definitions, pretrained weights, and inference/sampling code for the HunyuanVideo project. See our [project page](https://aivideo.hunyuan.tencent.com) for more.
39 |
40 | > [**HunyuanVideo: A Systematic Framework For Large Video Generation Model**](https://arxiv.org/abs/2412.03603)
41 |
42 |
43 |
44 | ## 🔥🔥🔥 News!!
45 |
46 | * May 28, 2025: 💃 We released [HunyuanVideo-Avatar](https://github.com/Tencent-Hunyuan/HunyuanVideo-Avatar), Tencent Hunyuan's audio-driven digital human model.
47 | * May 9, 2025: 🙆 We released [HunyuanCustom](https://github.com/Tencent-Hunyuan/HunyuanCustom), Tencent Hunyuan's subject-consistent video generation model.
48 | * Mar 6, 2025: 🌅 We released [HunyuanVideo-I2V](https://github.com/Tencent-Hunyuan/HunyuanVideo-I2V), which supports high-quality image-to-video generation.
49 | * Jan 13, 2025: 📈 We released the Penguin Video [benchmark](https://github.com/Tencent-Hunyuan/HunyuanVideo/blob/main/assets/PenguinVideoBenchmark.csv).
50 | * Dec 18, 2024: 🏃‍♂️ We released the HunyuanVideo [FP8 model weights](https://huggingface.co/tencent/HunyuanVideo/blob/main/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt), which save more GPU memory.
51 | * Dec 17, 2024: 🤗 HunyuanVideo has been integrated into [Diffusers](https://huggingface.co/docs/diffusers/main/api/pipelines/hunyuan_video).
52 | * Dec 3, 2024: 🚀 We released the multi-GPU parallel inference code for HunyuanVideo, powered by [xDiT](https://github.com/xdit-project/xDiT).
53 | * Dec 3, 2024: 👋 We released the inference code and model weights for HunyuanVideo text-to-video generation.
54 |
55 |
56 |
57 | ## 🎥 Demo
58 |
59 |
60 |
61 |
62 |
63 |
64 | ## 🧩 Community Contributions
65 |
66 | If you develop or use HunyuanVideo in your projects, please let us know.
67 |
68 | - ComfyUI (with support for FP8 inference, V2V, and IP2V generation): [ComfyUI-HunyuanVideoWrapper](https://github.com/kijai/ComfyUI-HunyuanVideoWrapper) by [Kijai](https://github.com/kijai)
69 |
70 | - ComfyUI-Native (official native support in ComfyUI): [ComfyUI-HunyuanVideo](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/) by [ComfyUI Official](https://github.com/comfyanonymous/ComfyUI)
71 |
72 | - FastVideo (consistency-distilled model and Sliding Tile Attention): [FastVideo](https://github.com/hao-ai-lab/FastVideo) and [Sliding Tile Attention](https://hao-ai-lab.github.io/blogs/sta/) by [Hao AI Lab](https://hao-ai-lab.github.io/)
73 |
74 | - HunyuanVideo-gguf (GGUF and quantization): [HunyuanVideo-gguf](https://huggingface.co/city96/HunyuanVideo-gguf) by [city96](https://huggingface.co/city96)
75 |
76 | - Enhance-A-Video (generates higher-quality videos): [Enhance-A-Video](https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video) by [NUS-HPC-AI-Lab](https://ai.comp.nus.edu.sg/)
77 |
78 | - TeaCache (cache-based accelerated sampling): [TeaCache](https://github.com/LiewFeng/TeaCache) by [Feng Liu](https://github.com/LiewFeng)
79 |
80 | - HunyuanVideoGP (a version for low-end GPUs): [HunyuanVideoGP](https://github.com/deepbeepmeep/HunyuanVideoGP) by [DeepBeepMeep](https://github.com/deepbeepmeep)
81 |
82 | - RIFLEx (video length extrapolation): [RIFLEx](https://riflex-video.github.io/) by [Tsinghua University](https://riflex-video.github.io/)
83 | - HunyuanVideo Keyframe Control Lora (keyframe control LoRA for video): [hunyuan-video-keyframe-control-lora](https://github.com/dashtoon/hunyuan-video-keyframe-control-lora) by [dashtoon](https://github.com/dashtoon)
84 | - Sparse-VideoGen (accelerated video generation with high pixel-level fidelity): [Sparse-VideoGen](https://github.com/svg-project/Sparse-VideoGen) by [University of California, Berkeley](https://svg-project.github.io/)
85 | - FramePack (packs the input frame context into a next-frame prediction model for video generation): [FramePack](https://github.com/lllyasviel/FramePack) by [Lvmin Zhang](https://github.com/lllyasviel)
86 | - Jenga (accelerated sampling): [Jenga](https://github.com/dvlab-research/Jenga) by [DV Lab](https://github.com/dvlab-research)
87 |
88 |
89 |
90 |
91 | ## 📑 Open-source Plan
92 |
93 | - HunyuanVideo (Text-to-Video Model)
94 |   - [x] Inference code
95 |   - [x] Model weights
96 |   - [x] Multi-GPU sequence-parallel inference (faster inference with more GPUs)
97 |   - [x] Web Demo (Gradio)
98 |   - [x] Diffusers
99 |   - [x] FP8 quantized version
100 |   - [x] Penguin Video benchmark
101 |   - [x] ComfyUI
102 | - [HunyuanVideo (Image-to-Video Model)](https://github.com/Tencent-Hunyuan/HunyuanVideo-I2V)
103 |   - [x] Inference code
104 |   - [x] Model weights
105 |
106 |
107 |
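108 | ## Contents
109 |
110 | - [HunyuanVideo: A Systematic Framework For Large Video Generation Model](#hunyuanvideo-a-systematic-framework-for-large-video-generation-model)
111 |   - [🔥🔥🔥 News!!](#-news)
112 |   - [🎥 Demo](#-demo)
113 |   - [🧩 Community Contributions](#-community-contributions)
114 |   - [📑 Open-source Plan](#-open-source-plan)
115 |   - [Contents](#contents)
116 |   - [**Abstract**](#abstract)
117 |   - [**HunyuanVideo Overall Architecture**](#hunyuanvideo-overall-architecture)
118 |   - [🎉 **HunyuanVideo Key Features**](#-hunyuanvideo-key-features)
119 |     - [**Unified Image and Video Generative Architecture**](#unified-image-and-video-generative-architecture)
120 |     - [**MLLM Text Encoder**](#mllm-text-encoder)
121 |     - [**3D VAE**](#3d-vae)
122 |     - [**Prompt Rewrite**](#prompt-rewrite)
123 |   - [📈 Comparisons](#-comparisons)
124 |   - [📜 Requirements](#-requirements)
125 |   - [🛠️ Dependencies and Installation](#️-dependencies-and-installation)
126 |     - [Installation Guide for Linux](#installation-guide-for-linux)
127 |   - [🧱 Download Pretrained Models](#-download-pretrained-models)
128 |   - [🔑 Single-GPU Inference](#-single-gpu-inference)
129 |     - [Using Command Line](#using-command-line)
130 |     - [Run a Gradio Server](#run-a-gradio-server)
131 |     - [More Configurations](#more-configurations)
132 |   - [🚀 Parallel Inference on Multiple GPUs with xDiT](#-parallel-inference-on-multiple-gpus-with-xdit)
133 |     - [Using Command Line](#using-command-line-1)
134 |   - [🚀 FP8 Inference](#-fp8-inference)
135 |     - [Using Command Line](#using-command-line-2)
136 |   - [🔗 BibTeX](#-bibtex)
137 |   - [Acknowledgements](#acknowledgements)
138 |   - [Star History](#star-history)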
139 | ---
140 |
141 |
142 |
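143 | ## **Abstract**
144 |
145 | HunyuanVideo is a novel open-source video generation foundation model whose video generation performance is comparable to, or even better than, leading closed-source models. To train HunyuanVideo, we adopted a comprehensive framework that integrates data curation, joint image-video model training, and an efficient infrastructure for large-scale model training and inference. In addition, through an effective strategy for scaling the model architecture and dataset, we successfully trained a video generation model with over 13 billion parameters, making it one of the largest open-source video generation models.
146 |
147 | We ran extensive experiments on the model architecture to ensure high visual quality, motion diversity, text-video alignment, and generation stability. According to professional human evaluation, HunyuanVideo outperforms previous state-of-the-art models on the overall metric, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generation models. **By releasing the code and weights of the foundation model and its applications, we aim to bridge the gap between closed-source and open-source video foundation models, help everyone in the community experiment with their own ideas, and foster a more dynamic and vibrant video generation ecosystem.**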
148 |
149 |
150 |
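151 | ## **HunyuanVideo Overall Architecture**
152 |
153 | HunyuanVideo is a latent-space model. During training, it uses a 3D VAE to compress features along the temporal and spatial dimensions. The text prompt is encoded by a large language model and fed to the model as a condition, guiding it to denoise Gaussian noise over multiple steps and output a latent representation of the video. Finally, at inference time, the 3D VAE decoder decodes this latent representation into a video.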
154 |
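To make the multi-step denoising loop above concrete, here is a minimal sketch of sampling under a flow-matching formulation, integrated with plain Euler steps from t = 1 (pure noise) to t = 0 (clean latent). It is not the repository's scheduler, which lives in `hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py`. The trained transformer is replaced by a hypothetical `oracle_velocity` that already knows the target latent, and a 3-element list stands in for a real latent tensor.

```python
import random
from typing import Callable, List

Vector = List[float]


def euler_sample(velocity: Callable[[Vector, float], Vector],
                 noise: Vector, num_steps: int) -> Vector:
    """Integrate dx/dt = velocity(x, t) from t = 1 (pure noise) to t = 0 (clean latent)."""
    x = list(noise)
    ts = [1.0 - i / num_steps for i in range(num_steps + 1)]  # 1.0, ..., 0.0
    for t, t_next in zip(ts[:-1], ts[1:]):
        v = velocity(x, t)
        # (t_next - t) is negative: each step moves from noise toward data
        x = [xi + (t_next - t) * vi for xi, vi in zip(x, v)]
    return x


# Toy stand-in for the transformer: on the straight path x_t = (1 - t) * x0 + t * noise,
# the true velocity is noise - x0 = (x_t - x0) / t, which an oracle that knows x0 can compute.
target = [0.5, -1.0, 2.0]


def oracle_velocity(x: Vector, t: float) -> Vector:
    return [(xi - x0) / t for xi, x0 in zip(x, target)]


random.seed(0)
noise = [random.gauss(0.0, 1.0) for _ in target]
latent = euler_sample(oracle_velocity, noise, num_steps=50)
print(latent)  # approximately equal to target
```

With the exact oracle velocity, Euler steps land on the straight path at every step. A learned velocity is only approximate, which is one reason the number of sampling steps (`--infer-steps`) matters.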
155 |
156 |
157 |
158 |
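159 | ## 🎉 **HunyuanVideo Key Features**
160 |
161 | ### **Unified Image and Video Generative Architecture**
162 |
163 | HunyuanVideo uses a Transformer design with Full Attention for video generation. Specifically, it uses a "dual-stream to single-stream" hybrid design. In the dual-stream phase, video and text tokens are processed independently by parallel Transformer blocks, so each modality can learn its own modulation mechanisms without interference. In the single-stream phase, the video and text tokens are concatenated and fed into subsequent Transformer blocks for effective multimodal information fusion. This design captures complex interactions between visual and semantic information and improves overall model performance.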
164 |
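The sketch below illustrates only the data flow of the dual-stream to single-stream design, and it relies on heavy simplifications. Tokens are plain lists of floats, `modulate` is an adaLN-style scale/shift, and `mix` is a hypothetical stand-in for full attention that adds the sequence mean to every token. Following the description above, the dual-stream block keeps video and text in separate streams with their own parameters. The real HunyuanVideo blocks may exchange information between modalities differently.

```python
from typing import List, Tuple

Tokens = List[List[float]]


def modulate(tokens: Tokens, scale: float, shift: float) -> Tokens:
    # adaLN-style modulation x * (1 + scale) + shift
    return [[x * (1.0 + scale) + shift for x in tok] for tok in tokens]


def mix(tokens: Tokens) -> Tokens:
    # Stand-in for full attention: add the mean token to every token (residual),
    # so each output depends on every token in the same sequence.
    n, d = len(tokens), len(tokens[0])
    mean = [sum(tok[i] for tok in tokens) / n for i in range(d)]
    return [[x + m for x, m in zip(tok, mean)] for tok in tokens]


def dual_stream_block(video: Tokens, text: Tokens,
                      video_params: Tuple[float, float],
                      text_params: Tuple[float, float]) -> Tuple[Tokens, Tokens]:
    # Each modality has its own parameters and is processed in its own stream.
    return mix(modulate(video, *video_params)), mix(modulate(text, *text_params))


def single_stream_block(video: Tokens, text: Tokens,
                        params: Tuple[float, float]) -> Tuple[Tokens, Tokens]:
    # Concatenate video and text tokens, apply one shared block, then split again.
    joint = mix(modulate(video + text, *params))
    return joint[:len(video)], joint[len(video):]


def forward(video: Tokens, text: Tokens) -> Tuple[Tokens, Tokens]:
    video, text = dual_stream_block(video, text, (0.1, 0.0), (-0.2, 0.5))
    return single_stream_block(video, text, (0.0, 0.0))
```

In this toy, the text output of the dual-stream block does not change when the video tokens change. After concatenation in the single-stream block, it does.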
165 |
166 |
167 |
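168 | ### **MLLM Text Encoder**
169 | Previous video generation models typically use pretrained CLIP and T5-XXL as text encoders, where CLIP uses a Transformer encoder and T5 uses an encoder-decoder structure. HunyuanVideo instead uses a pretrained Multimodal Large Language Model (MLLM) as its text encoder, which has the following advantages:
170 | * Compared with T5, an MLLM that has been instruction-tuned on image-text data has better image-text alignment in its feature space, which makes image-text alignment easier for the diffusion model;
171 | * Compared with CLIP, an MLLM shows stronger capabilities in detailed image description and complex reasoning;
172 | * An MLLM can act as a zero-shot learner by following system instructions, helping the text features focus on key information.
173 |
174 | However, the MLLM is based on causal attention, whereas T5-XXL uses bidirectional attention, which provides better text guidance for diffusion models. We therefore introduce an additional bidirectional token refiner to enhance the text features.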
175 |
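To see why causality matters here, the sketch below builds both attention-visibility patterns as boolean masks in pure Python. `causal_mask`, `bidirectional_mask`, and `visible_context` are illustrative helpers, not repository code. Under a causal mask, an early prompt token never sees the tokens that follow it, and a refiner that attends bidirectionally over the text tokens is meant to close that gap.

```python
from typing import List

Mask = List[List[bool]]


def causal_mask(n: int) -> Mask:
    # Decoder-style (causal) attention: token i may attend only to tokens 0..i.
    return [[j <= i for j in range(n)] for i in range(n)]


def bidirectional_mask(n: int) -> Mask:
    # Encoder-style attention: every token may attend to every token.
    return [[True] * n for _ in range(n)]


def visible_context(mask: Mask, i: int) -> List[int]:
    """Indices of the tokens that token i can attend to."""
    return [j for j, ok in enumerate(mask[i]) if ok]


print(visible_context(causal_mask(4), 0))         # [0]
print(visible_context(bidirectional_mask(4), 0))  # [0, 1, 2, 3]
```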
176 |
177 |
178 |
179 | ### **3D VAE**
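180 | HunyuanVideo's VAE uses CausalConv3D in its encoder and decoder to compress videos along the temporal and spatial dimensions: the temporal dimension is compressed by 4x, the spatial dimensions by 8x, and the result has 16 latent channels. This significantly reduces the number of tokens for the subsequent Transformer, allowing us to train the video generation model at the original resolution and frame rate.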
181 |
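The 4x / 8x / 16-channel figures above translate into concrete tensor shapes. The sketch below computes them under two labeled assumptions: (1) the causal VAE keeps the first frame and compresses the remaining frames 4x, so inputs have 4n + 1 frames (consistent with the 129-frame setting used throughout this README), and (2) the transformer patchifies latents with a (t, h, w) patch of (1, 2, 2). Both helpers are hypothetical.

```python
from typing import Tuple


def latent_shape(frames: int, height: int, width: int,
                 t_ratio: int = 4, s_ratio: int = 8,
                 channels: int = 16) -> Tuple[int, int, int, int]:
    """(channels, latent_frames, latent_height, latent_width) for a video."""
    # Assumption: first frame kept separately, remaining frames compressed t_ratio x.
    if (frames - 1) % t_ratio != 0:
        raise ValueError(f"frames must have the form {t_ratio}n + 1, got {frames}")
    if height % s_ratio or width % s_ratio:
        raise ValueError(f"height and width must be divisible by {s_ratio}")
    return channels, (frames - 1) // t_ratio + 1, height // s_ratio, width // s_ratio


def num_tokens(latent: Tuple[int, int, int, int],
               patch: Tuple[int, int, int] = (1, 2, 2)) -> int:
    """Number of transformer tokens, assuming a (t, h, w) patch size."""
    _, t, h, w = latent
    pt, ph, pw = patch
    return (t // pt) * (h // ph) * (w // pw)


shape = latent_shape(129, 720, 1280)
print(shape, num_tokens(shape))  # (16, 33, 90, 160) 118800
```

Under these assumptions, a 720p, 129-frame video becomes a 16 x 33 x 90 x 160 latent and about 119K transformer tokens.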
182 |
183 |
184 |
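185 | ### **Prompt Rewrite**
186 | To handle the variability and inconsistency of user-provided prompts, we fine-tuned the [Hunyuan-Large model](https://github.com/Tencent/Tencent-Hunyuan-Large) as our prompt rewrite model. It rewrites the user's original prompt into a form the video model prefers.
187 |
188 | We provide two rewrite modes: Normal mode and Master mode. The prompts for both modes are available [here](hyvideo/prompt_rewrite.py). Normal mode improves the video generation model's understanding of user intent, so that the given instructions are interpreted more accurately. Master mode enriches descriptions of composition, lighting, camera movement, and similar aspects, and tends to produce videos with higher visual quality. Note that this emphasis can occasionally cause some semantic details to be lost.
189 |
190 | The prompt rewrite model can be deployed and run directly with [Hunyuan-Large](https://github.com/Tencent/Tencent-Hunyuan-Large). We have released its weights [here](https://huggingface.co/Tencent/HunyuanVideo-PromptRewrite).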
191 |
192 |
193 |
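194 | ## 📈 Comparisons
195 |
196 | To evaluate HunyuanVideo, we compared it against four closed-source video generation models. We used 1,533 prompts in total and generated the same number of video samples, one per prompt, in a single inference run. For a fair comparison, we ran inference only once and did not cherry-pick any results. When comparing with other methods, we kept the default settings of every selected model and used consistent video resolutions. Videos were assessed on three criteria: text alignment, motion quality, and visual quality. After evaluation by more than 60 professional evaluators, HunyuanVideo achieved the best overall performance, standing out particularly in motion quality.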
197 |
198 |
199 |
200 |
201 |
202 | | Model | Open Source | Duration | Text Alignment | Motion Quality | Visual Quality | Overall | Ranking |
203 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
204 | | HunyuanVideo (Ours) | ✔ | 5s | 61.8% | 66.5% | 95.7% | 41.3% | 1 |
205 | | Chinese Model A (API) | ✘ | 5s | 62.6% | 61.7% | 95.6% | 37.7% | 2 |
206 | | Chinese Model B (Web) | ✘ | 5s | 60.1% | 62.9% | 97.7% | 37.5% | 3 |
207 | | GEN-3 alpha (Web) | ✘ | 6s | 47.7% | 54.7% | 97.5% | 27.4% | 4 |
208 | | Luma1.6 (API) | ✘ | 5s | 57.6% | 44.2% | 94.1% | 24.8% | 5 |
220 |
221 |
222 |
223 |
224 |
225 |
226 | ## 📜 Requirements
227 |
228 | The following table lists the recommended configurations for running HunyuanVideo text-to-video generation (batch size = 1):
229 |
230 | | Model | Resolution (height/width/frame) | Peak GPU Memory |
231 | |:--------------:|:--------------------------------:|:----------------:|
232 | | HunyuanVideo | 720px1280px129f | 60 GB |
233 | | HunyuanVideo | 544px960px129f | 45 GB |
234 |
235 | * An NVIDIA GPU with CUDA support is required.
236 | * The model was tested on a single 80 GB GPU.
237 | * The minimum GPU memory required is 60 GB for 720px1280px129f and 45 GB for 544px960px129f.
238 | * Tested operating system: Linux
239 |
240 |
241 |
242 | ## 🛠️ Dependencies and Installation
243 |
244 | Begin by cloning the repository:
245 | ```shell
246 | git clone https://github.com/Tencent-Hunyuan/HunyuanVideo
247 | cd HunyuanVideo
248 | ```
249 |
250 | ### Installation Guide for Linux
251 |
252 | We recommend CUDA 12.4 or 11.8.
253 |
254 | Conda installation instructions are available [here](https://docs.anaconda.com/free/miniconda/index.html).
255 |
256 | ```shell
257 | # 1. Create conda environment
258 | conda create -n HunyuanVideo python==3.10.9
259 |
260 | # 2. Activate the environment
261 | conda activate HunyuanVideo
262 |
263 | # 3. Install PyTorch and other dependencies using conda
264 | # For CUDA 11.8
265 | conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=11.8 -c pytorch -c nvidia
266 | # For CUDA 12.4
267 | conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.4 -c pytorch -c nvidia
268 |
269 | # 4. Install pip dependencies
270 | python -m pip install -r requirements.txt
271 |
272 | # 5. Install flash attention v2 for acceleration (requires CUDA 11.8 or above)
273 | python -m pip install ninja
274 | python -m pip install git+https://github.com/Dao-AILab/flash-attention.git@v2.6.3
275 |
276 | # 6. Install xDiT for parallel inference (It is recommended to use torch 2.4.0 and flash-attn 2.6.3)
277 | python -m pip install xfuser==0.4.0
278 | ```
279 |
280 | If you hit a floating point exception (core dump) on certain GPU types, try one of the following fixes:
281 |
282 | ```shell
283 | # Option 1: Make sure CUDA 12.4, CUBLAS>=12.4.5.8, and CUDNN>=9.00 are installed correctly (or use our prebuilt CUDA 12 image)
284 | pip install nvidia-cublas-cu12==12.4.5.8
285 | export LD_LIBRARY_PATH=/opt/conda/lib/python3.8/site-packages/nvidia/cublas/lib/  # adjust to your environment's Python version and site-packages path
286 |
287 | # Option 2: Force the use of PyTorch and all other packages built for CUDA 11.8
288 | pip uninstall -y -r requirements.txt # make sure all dependencies are uninstalled
289 | pip uninstall -y xfuser
290 | pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu118
291 | pip install -r requirements.txt
292 | pip install ninja
293 | pip install git+https://github.com/Dao-AILab/flash-attention.git@v2.6.3
294 | pip install xfuser==0.4.0
295 | ```
296 |
297 | Alternatively, we provide a prebuilt Docker image, which you can pull and run with the following commands:
298 | ```shell
299 | # For CUDA 12.4 (updated to avoid the floating point exception)
300 | docker pull hunyuanvideo/hunyuanvideo:cuda_12
301 | docker run -itd --gpus all --init --net=host --uts=host --ipc=host --name hunyuanvideo --security-opt=seccomp=unconfined --ulimit=stack=67108864 --ulimit=memlock=-1 --privileged hunyuanvideo/hunyuanvideo:cuda_12
302 |
303 | # For CUDA 11.8
304 | docker pull hunyuanvideo/hunyuanvideo:cuda_11
305 | docker run -itd --gpus all --init --net=host --uts=host --ipc=host --name hunyuanvideo --security-opt=seccomp=unconfined --ulimit=stack=67108864 --ulimit=memlock=-1 --privileged hunyuanvideo/hunyuanvideo:cuda_11
306 | ```
307 |
308 | ## 🧱 Download Pretrained Models
309 |
310 | For instructions on downloading the pretrained models, see [here](ckpts/README.md).
311 |
312 |
313 |
314 | ## 🔑 Single-GPU Inference
315 |
316 | The following table lists the supported height/width/frame settings.
317 |
318 | | Resolution | h/w=9:16 | h/w=16:9 | h/w=4:3 | h/w=3:4 | h/w=1:1 |
319 | |:---------------------:|:----------------------------:|:---------------:|:---------------:|:---------------:|:---------------:|
320 | | 540p | 544px960px129f | 960px544px129f | 832px624px129f | 624px832px129f | 720px720px129f |
321 | | 720p (recommended) | 720px1280px129f | 1280px720px129f | 1104px832px129f | 832px1104px129f | 960px960px129f |
322 |
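The table writes each size as height x width x frames, and the column labels give the h/w ratio. The helpers below are hypothetical, not part of the repository. They classify a (height, width) pair into one of those columns and assemble the matching `--video-size` / `--video-length` arguments for `sample_video.py`. The 4n + 1 frame-count check is an assumption based on the listed 129-frame setting and the VAE's 4x temporal compression.

```python
from typing import List

# h/w aspect ratios used as column labels in the resolution table (height / width)
RATIOS = {"9:16": 9 / 16, "16:9": 16 / 9, "4:3": 4 / 3, "3:4": 3 / 4, "1:1": 1.0}


def aspect_label(height: int, width: int) -> str:
    """Return the h/w column label whose ratio is closest to height / width."""
    ratio = height / width
    return min(RATIOS, key=lambda label: abs(RATIOS[label] - ratio))


def sample_video_args(height: int, width: int, frames: int = 129) -> List[str]:
    """Build the size-related arguments for sample_video.py (height first, then width)."""
    # Assumption: the frame count must have the form 4n + 1 (e.g. 65 or 129).
    if (frames - 1) % 4 != 0:
        raise ValueError(f"video length should be 4n + 1, got {frames}")
    return ["--video-size", str(height), str(width), "--video-length", str(frames)]


print(aspect_label(1104, 832))       # 4:3
print(sample_video_args(720, 1280))  # ['--video-size', '720', '1280', '--video-length', '129']
```

Matching on the nearest ratio, rather than an exact fraction, is deliberate: sizes such as 1104x832 are only approximately 4:3 because they are rounded to convenient multiples.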
323 | ### Using Command Line
324 |
325 | ```bash
326 | cd HunyuanVideo
327 |
328 | python3 sample_video.py \
329 | --video-size 720 1280 \
330 | --video-length 129 \
331 | --infer-steps 50 \
332 | --prompt "A cat walks on the grass, realistic style." \
333 | --flow-reverse \
334 | --use-cpu-offload \
335 | --save-path ./results
336 | ```
337 |
338 | ### Run a Gradio Server
339 | ```bash
340 | python3 gradio_server.py --flow-reverse
341 |
342 | # set SERVER_NAME and SERVER_PORT manually
343 | # SERVER_NAME=0.0.0.0 SERVER_PORT=8081 python3 gradio_server.py --flow-reverse
344 | ```
345 |
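346 | ### More Configurations
347 |
348 | Some of the most important configuration options are listed below:
349 |
350 | | Argument | Default | Description |
351 | |:----------------------:|:---------:|:-----------------------------------------:|
352 | | `--prompt` | None | The text prompt for video generation |
353 | | `--video-size` | 720 1280 | The height and width of the generated video |
354 | | `--video-length` | 129 | The number of frames in the generated video |
355 | | `--infer-steps` | 50 | The number of sampling steps |
356 | | `--embedded-cfg-scale` | 6.0 | Embedded classifier-free guidance scale (strength of text control) |
357 | | `--flow-shift` | 7.0 | Shift factor of the flow-matching timestep schedule; larger values spend more sampling steps in the high-noise region |
358 | | `--flow-reverse` | False | If set, learn/sample from t=1 to t=0 |
359 | | `--neg-prompt` | None | The negative prompt |
360 | | `--seed` | 0 | The random seed |
361 | | `--use-cpu-offload` | False | Enable CPU offloading to save GPU memory |
362 | | `--save-path` | ./results | Path for saving the generated video |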
363 |
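To illustrate what `--flow-shift` does, the sketch below builds a noise-level schedule from 1 (pure noise) to 0 (clean) and warps it with a time shift. It assumes the SD3-style shift sigma' = shift * s / (1 + (shift - 1) * s). This is the usual form for flow-matching schedulers, but check `hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py` for the exact behavior. `shifted_sigmas` and `steps_above` are hypothetical helpers.

```python
from typing import List


def shifted_sigmas(num_steps: int, shift: float) -> List[float]:
    """Noise levels from 1 (pure noise) to 0 (clean), warped by a time shift.

    Assumption: SD3-style shift sigma' = shift * s / (1 + (shift - 1) * s).
    """
    base = [1.0 - i / num_steps for i in range(num_steps + 1)]
    return [shift * s / (1.0 + (shift - 1.0) * s) for s in base]


def steps_above(sigmas: List[float], threshold: float = 0.5) -> int:
    # Count the sampling steps that start in the high-noise region (sigma > threshold).
    return sum(1 for s in sigmas[:-1] if s > threshold)


print(steps_above(shifted_sigmas(50, 1.0)))  # 25: unshifted, half the steps are high-noise
print(steps_above(shifted_sigmas(50, 7.0)))  # 44: the default shift favors high noise
```

Under this assumption, the default shift of 7.0 spends 44 of 50 steps at noise levels above 0.5, compared with 25 for an unshifted linear schedule.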
364 |
365 |
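366 | ## 🚀 Parallel Inference on Multiple GPUs with xDiT
367 |
368 | [xDiT](https://github.com/xdit-project/xDiT) is a scalable inference engine for Diffusion Transformers (DiTs) on multi-GPU clusters.
369 | It has provided low-latency parallel inference solutions for a variety of DiT models, including mochi-1, CogVideoX, Flux.1, and SD3. This repository uses the [Unified Sequence Parallelism (USP)](https://arxiv.org/abs/2405.07719) API for parallel inference of the HunyuanVideo model.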
370 |
371 | ### Using Command Line
372 |
373 | For example, to run inference on 8 GPUs, use the following command:
374 |
375 | ```bash
376 | cd HunyuanVideo
377 |
378 | torchrun --nproc_per_node=8 sample_video.py \
379 | --video-size 1280 720 \
380 | --video-length 129 \
381 | --infer-steps 50 \
382 | --prompt "A cat walks on the grass, realistic style." \
383 | --flow-reverse \
384 | --seed 42 \
385 | --ulysses-degree 8 \
386 | --ring-degree 1 \
387 | --save-path ./results
388 | ```
389 |
390 | You can set `--ulysses-degree` and `--ring-degree` to control the parallel configuration. The supported options are listed below.
391 |
392 |
393 | Supported parallel configurations
394 |
395 | | --video-size | --video-length | --ulysses-degree x --ring-degree | --nproc_per_node |
396 | |----------------------|----------------|----------------------------------|------------------|
397 | | 1280 720 or 720 1280 | 129 | 8x1,4x2,2x4,1x8 | 8 |
398 | | 1280 720 or 720 1280 | 129 | 1x5 | 5 |
399 | | 1280 720 or 720 1280 | 129 | 4x1,2x2,1x4 | 4 |
400 | | 1280 720 or 720 1280 | 129 | 3x1,1x3 | 3 |
401 | | 1280 720 or 720 1280 | 129 | 2x1,1x2 | 2 |
402 | | 1104 832 or 832 1104 | 129 | 4x1,2x2,1x4 | 4 |
403 | | 1104 832 or 832 1104 | 129 | 3x1,1x3 | 3 |
404 | | 1104 832 or 832 1104 | 129 | 2x1,1x2 | 2 |
405 | | 960 960 | 129 | 6x1,3x2,2x3,1x6 | 6 |
406 | | 960 960 | 129 | 4x1,2x2,1x4 | 4 |
407 | | 960 960 | 129 | 3x1,1x3 | 3 |
408 | | 960 960 | 129 | 1x2,2x1 | 2 |
409 | | 960 544 or 544 960 | 129 | 6x1,3x2,2x3,1x6 | 6 |
410 | | 960 544 or 544 960 | 129 | 4x1,2x2,1x4 | 4 |
411 | | 960 544 or 544 960 | 129 | 3x1,1x3 | 3 |
412 | | 960 544 or 544 960 | 129 | 1x2,2x1 | 2 |
413 | | 832 624 or 624 832 | 129 | 4x1,2x2,1x4 | 4 |
414 | | 832 624 or 624 832 | 129 | 3x1,1x3 | 3 |
415 | | 832 624 or 624 832 | 129 | 2x1,1x2 | 2 |
416 | | 720 720 | 129 | 1x5 | 5 |
417 | | 720 720 | 129 | 3x1,1x3 | 3 |
418 |
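Every row of the table satisfies `ulysses-degree x ring-degree = nproc_per_node`. A hypothetical pre-flight helper could enumerate the candidate pairs and check a launch configuration before calling `torchrun`. It encodes only that product rule. The table shows that not every factorization is supported for every size (for example, 5 GPUs only list `1x5`), presumably because Ulysses parallelism splits attention heads across GPUs, so the table remains authoritative.

```python
from typing import List, Tuple


def parallel_configs(num_gpus: int) -> List[Tuple[int, int]]:
    """Candidate (ulysses_degree, ring_degree) pairs whose product equals num_gpus."""
    return [(u, num_gpus // u) for u in range(num_gpus, 0, -1) if num_gpus % u == 0]


def check_config(ulysses_degree: int, ring_degree: int, nproc_per_node: int) -> None:
    """Raise if the two degrees do not multiply to the number of launched processes."""
    if ulysses_degree * ring_degree != nproc_per_node:
        raise ValueError(
            f"{ulysses_degree}x{ring_degree} needs {ulysses_degree * ring_degree} "
            f"processes, but --nproc_per_node is {nproc_per_node}"
        )


for u, r in parallel_configs(8):
    check_config(u, r, 8)
    print(f"--ulysses-degree {u} --ring-degree {r}")
```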
419 |
420 | Latency (seconds) of generating a 1280x720 video (129 frames, 50 steps) on 8xGPU:
421 |
422 | | Number of GPUs | 1 | 2 | 4 | 8 |
423 | |:---:|:---:|:---:|:---:|:---:|
424 | | Latency (s) | 1904.08 | 934.09 (2.04x) | 514.08 (3.70x) | 337.58 (5.64x) |
447 |
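The speedups in parentheses are ratios against the single-GPU latency. The short script below recomputes them from the measured latencies. It also reports parallel efficiency (speedup divided by GPU count), a derived quantity that the table does not give.

```python
# Measured latencies (seconds) from the table above, keyed by number of GPUs
latency_s = {1: 1904.08, 2: 934.09, 4: 514.08, 8: 337.58}


def speedup(n: int) -> float:
    return latency_s[1] / latency_s[n]


def efficiency(n: int) -> float:
    # Fraction of ideal linear scaling achieved with n GPUs
    return speedup(n) / n


for n in sorted(latency_s):
    print(f"{n} GPU(s): {speedup(n):.2f}x speedup, {efficiency(n):.0%} efficiency")
```

By this measure, the 2-GPU run is slightly superlinear (about 102% efficiency), and 8 GPUs reach roughly 70% of ideal linear scaling.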
448 | ## 🚀 FP8 Inference
449 |
450 | The FP8-quantized HunyuanVideo model saves about 10 GB of GPU memory. Before use, download the [FP8 weights](https://huggingface.co/tencent/HunyuanVideo/blob/main/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt) and the [per-layer scale parameters](https://huggingface.co/tencent/HunyuanVideo/blob/main/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8_map.pt) of the quantized weights from Hugging Face.
451 |
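The README does not describe the quantization recipe. As a rough, hypothetical illustration of what a per-layer scale is for, the sketch below scales one weight tensor so that its largest magnitude maps to 448 (the largest finite FP8 E4M3 value). It then rounds to E4M3's 3-bit mantissa grid in pure Python and dequantizes by multiplying back by the scale. The function names are invented for this example; they are not taken from `hyvideo/modules/fp8_optimization.py`.

```python
import math
from typing import List, Tuple

E4M3_MAX = 448.0  # largest finite value of the FP8 E4M3 format


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest value representable in FP8 E4M3 (3 mantissa bits)."""
    if x == 0.0:
        return 0.0
    exp = max(math.floor(math.log2(abs(x))), -6)  # -6 is the smallest normal exponent
    step = 2.0 ** (exp - 3)                       # spacing of representable values
    q = round(x / step) * step
    return max(-E4M3_MAX, min(E4M3_MAX, q))       # saturate instead of overflowing


def quantize(weights: List[float]) -> Tuple[List[float], float]:
    # One scale per tensor ("per layer"): map the largest |w| onto E4M3_MAX.
    scale = max(abs(w) for w in weights) / E4M3_MAX
    if scale == 0.0:
        scale = 1.0  # all-zero tensor: any scale works
    return [round_to_e4m3(w / scale) for w in weights], scale


def dequantize(q: List[float], scale: float) -> List[float]:
    return [v * scale for v in q]


w = [0.02, -0.5, 1.0, 0.3]
q, s = quantize(w)
print(q, dequantize(q, s))
```

The scale uses the full FP8 range for each layer. Storing it alongside the 8-bit values is what makes the reconstruction error relative (at most about 6% for normal values) rather than absolute.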
452 | ### Using Command Line
453 |
454 | You must explicitly specify the path to the FP8 weights. For example, to run inference with the FP8 model:
455 |
456 | ```bash
457 | cd HunyuanVideo
458 |
459 | DIT_CKPT_PATH={PATH_TO_FP8_WEIGHTS}/{WEIGHT_NAME}_fp8.pt
460 |
461 | python3 sample_video.py \
462 | --dit-weight ${DIT_CKPT_PATH} \
463 | --video-size 1280 720 \
464 | --video-length 129 \
465 | --infer-steps 50 \
466 | --prompt "A cat walks on the grass, realistic style." \
467 | --seed 42 \
468 | --embedded-cfg-scale 6.0 \
469 | --flow-shift 7.0 \
470 | --flow-reverse \
471 | --use-cpu-offload \
472 | --use-fp8 \
473 | --save-path ./results
474 | ```
475 |
476 |
477 |
478 | ## 🔗 BibTeX
479 |
480 | If you find [HunyuanVideo](https://arxiv.org/abs/2412.03603) useful for your research and applications, please cite it as follows:
481 |
482 |
483 | ```BibTeX
484 | @article{kong2024hunyuanvideo,
485 | title={Hunyuanvideo: A systematic framework for large video generative models},
486 | author={Kong, Weijie and Tian, Qi and Zhang, Zijian and Min, Rox and Dai, Zuozhuo and Zhou, Jin and Xiong, Jiangfeng and Li, Xin and Wu, Bo and Zhang, Jianwei and others},
487 | journal={arXiv preprint arXiv:2412.03603},
488 | year={2024}
489 | }
490 | ```
491 |
492 |
493 |
494 | ## Acknowledgements
495 |
496 | HunyuanVideo builds on many open-source projects. We especially thank [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [FLUX](https://github.com/black-forest-labs/flux), [Llama](https://github.com/meta-llama/llama), [LLaVA](https://github.com/haotian-liu/LLaVA), [Xtuner](https://github.com/InternLM/xtuner), [diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) for their open research and exploration. We also thank the Tencent Hunyuan Multimodal team for their support in adapting HunyuanVideo to multiple text encoders.
497 |
498 |
499 |
500 | ## Star History
501 |
502 |
503 |
504 |
505 |
506 |
507 |
508 |
509 |
--------------------------------------------------------------------------------
/assets/3dvae.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo/757d06f2696fd7a457b888c092f5d57d086f51eb/assets/3dvae.png
--------------------------------------------------------------------------------
/assets/WECHAT.md:
--------------------------------------------------------------------------------
1 |
2 |

3 |
4 |
Scan the QR code to follow the Hunyuan series and join the "Hunyuan Video Discussion Group"
5 |
Scan the QR code to join the "Hunyuan Discussion Group"
6 |
7 |
8 |
--------------------------------------------------------------------------------
/assets/backbone.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo/757d06f2696fd7a457b888c092f5d57d086f51eb/assets/backbone.png
--------------------------------------------------------------------------------
/assets/hunyuanvideo.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo/757d06f2696fd7a457b888c092f5d57d086f51eb/assets/hunyuanvideo.pdf
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo/757d06f2696fd7a457b888c092f5d57d086f51eb/assets/logo.png
--------------------------------------------------------------------------------
/assets/overall.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo/757d06f2696fd7a457b888c092f5d57d086f51eb/assets/overall.png
--------------------------------------------------------------------------------
/assets/text_encoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo/757d06f2696fd7a457b888c092f5d57d086f51eb/assets/text_encoder.png
--------------------------------------------------------------------------------
/assets/video_poster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo/757d06f2696fd7a457b888c092f5d57d086f51eb/assets/video_poster.png
--------------------------------------------------------------------------------
/assets/wechat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo/757d06f2696fd7a457b888c092f5d57d086f51eb/assets/wechat.jpg
--------------------------------------------------------------------------------
/ckpts/README.md:
--------------------------------------------------------------------------------
1 | # Download Pretrained Models
2 |
3 | All models are stored in `HunyuanVideo/ckpts` by default, with the following file structure:
4 | ```shell
5 | HunyuanVideo
6 | ├──ckpts
7 | │ ├──README.md
8 | │ ├──hunyuan-video-t2v-720p
9 | │ │ ├──transformers
10 | │ │ │ ├──mp_rank_00_model_states.pt
11 | │ │ │ ├──mp_rank_00_model_states_fp8.pt
12 | │ │ │ ├──mp_rank_00_model_states_fp8_map.pt
13 | │ │ ├──vae
14 | │ ├──text_encoder
15 | │ ├──text_encoder_2
16 | ├──...
17 | ```
18 |
19 | ## Download HunyuanVideo model
20 | To download the HunyuanVideo model, first install the huggingface-cli. (Detailed instructions are available [here](https://huggingface.co/docs/huggingface_hub/guides/cli).)
21 |
22 | ```shell
23 | python -m pip install "huggingface_hub[cli]"
24 | ```
25 |
26 | Then download the model using the following commands:
27 |
28 | ```shell
29 | # Switch to the directory named 'HunyuanVideo'
30 | cd HunyuanVideo
31 | # Use the huggingface-cli tool to download HunyuanVideo model in HunyuanVideo/ckpts dir.
32 | # The download time may vary from 10 minutes to 1 hour depending on network conditions.
33 | huggingface-cli download tencent/HunyuanVideo --local-dir ./ckpts
34 | ```
35 |
36 |
37 | 💡Tips for using huggingface-cli (network problem)
38 |
39 | ##### 1. Using HF-Mirror
40 |
41 | If you encounter slow download speeds in China, you can try a mirror to speed up the download process. For example,
42 |
43 | ```shell
44 | HF_ENDPOINT=https://hf-mirror.com huggingface-cli download tencent/HunyuanVideo --local-dir ./ckpts
45 | ```
46 |
47 | ##### 2. Resume Download
48 |
49 | `huggingface-cli` supports resuming downloads. If the download is interrupted, you can just rerun the download
50 | command to resume the download process.
51 |
52 | Note: If an error like `No such file or directory: 'ckpts/.huggingface/.gitignore.lock'` occurs during the download
53 | process, you can ignore it and rerun the download command.
54 |
55 |
56 |
57 | ---
58 |
59 | ## Download Text Encoder
60 |
61 | HunyuanVideo uses an MLLM and a CLIP model as text encoders.
62 |
63 | 1. MLLM model (text_encoder folder)
64 |
65 | HunyuanVideo supports different MLLMs (including HunyuanMLLM and open-source MLLMs). We have not yet released HunyuanMLLM. For community users, we recommend [llava-llama-3-8b](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers) provided by [Xtuner](https://huggingface.co/xtuner), which can be downloaded with the following command:
66 |
67 | ```shell
68 | cd HunyuanVideo/ckpts
69 | huggingface-cli download xtuner/llava-llama-3-8b-v1_1-transformers --local-dir ./llava-llama-3-8b-v1_1-transformers
70 | ```
71 |
72 | To save GPU memory when loading the model, we extract the language model part of `llava-llama-3-8b-v1_1-transformers` into `text_encoder`:
73 | ```shell
74 | cd HunyuanVideo
75 | python hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py --input_dir ckpts/llava-llama-3-8b-v1_1-transformers --output_dir ckpts/text_encoder
76 | ```
77 |
78 | 2. CLIP model (text_encoder_2 folder)
79 |
80 | We use [CLIP](https://huggingface.co/openai/clip-vit-large-patch14) provided by [OpenAI](https://openai.com) as the second text encoder. Community users can download it with the following command:
81 |
82 | ```shell
83 | cd HunyuanVideo/ckpts
84 | huggingface-cli download openai/clip-vit-large-patch14 --local-dir ./text_encoder_2
85 | ```
86 |
--------------------------------------------------------------------------------
/gradio_server.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from pathlib import Path
4 | from loguru import logger
5 | from datetime import datetime
6 | import gradio as gr
7 | import random
8 |
9 | from hyvideo.utils.file_utils import save_videos_grid
10 | from hyvideo.config import parse_args
11 | from hyvideo.inference import HunyuanVideoSampler
12 | from hyvideo.constants import NEGATIVE_PROMPT
13 |
14 | def initialize_model(model_path):
15 | args = parse_args()
16 | models_root_path = Path(model_path)
17 | if not models_root_path.exists():
18 | raise ValueError(f"`models_root` not exists: {models_root_path}")
19 |
20 | hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
21 | return hunyuan_video_sampler
22 |
23 | def generate_video(
24 | model,
25 | prompt,
26 | resolution,
27 | video_length,
28 | seed,
29 | num_inference_steps,
30 | guidance_scale,
31 | flow_shift,
32 | embedded_guidance_scale
33 | ):
34 | seed = None if seed == -1 else seed
35 | width, height = resolution.split("x")
36 | width, height = int(width), int(height)
37 | negative_prompt = "" # not applicable in the inference
38 |
39 | outputs = model.predict(
40 | prompt=prompt,
41 | height=height,
42 | width=width,
43 | video_length=video_length,
44 | seed=seed,
45 | negative_prompt=negative_prompt,
46 | infer_steps=num_inference_steps,
47 | guidance_scale=guidance_scale,
48 | num_videos_per_prompt=1,
49 | flow_shift=flow_shift,
50 | batch_size=1,
51 | embedded_guidance_scale=embedded_guidance_scale
52 | )
53 |
54 | samples = outputs['samples']
55 | sample = samples[0].unsqueeze(0)
56 |
57 | save_path = os.path.join(os.getcwd(), "gradio_outputs")
58 | os.makedirs(save_path, exist_ok=True)
59 |
60 | time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
61 | video_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][0]}_{outputs['prompts'][0][:100].replace('/','')}.mp4"
62 | save_videos_grid(sample, video_path, fps=24)
63 | logger.info(f'Sample saved to: {video_path}')
64 |
65 | return video_path
66 |
67 | def create_demo(model_path, save_path):
68 | model = initialize_model(model_path)
69 |
70 | with gr.Blocks() as demo:
71 | gr.Markdown("# Hunyuan Video Generation")
72 |
73 | with gr.Row():
74 | with gr.Column():
75 | prompt = gr.Textbox(label="Prompt", value="A cat walks on the grass, realistic style.")
76 | with gr.Row():
77 | resolution = gr.Dropdown(
78 | choices=[
79 | # 720p
80 | ("1280x720 (16:9, 720p)", "1280x720"),
81 | ("720x1280 (9:16, 720p)", "720x1280"),
82 | ("1104x832 (4:3, 720p)", "1104x832"),
83 | ("832x1104 (3:4, 720p)", "832x1104"),
84 | ("960x960 (1:1, 720p)", "960x960"),
85 | # 540p
86 | ("960x544 (16:9, 540p)", "960x544"),
87 | ("544x960 (9:16, 540p)", "544x960"),
88 | ("832x624 (4:3, 540p)", "832x624"),
89 | ("624x832 (3:4, 540p)", "624x832"),
90 | ("720x720 (1:1, 540p)", "720x720"),
91 | ],
92 | value="1280x720",
93 | label="Resolution"
94 | )
95 | video_length = gr.Dropdown(
96 | label="Video Length",
97 | choices=[
98 | ("2s(65f)", 65),
99 | ("5s(129f)", 129),
100 | ],
101 | value=129,
102 | )
103 | num_inference_steps = gr.Slider(1, 100, value=50, step=1, label="Number of Inference Steps")
104 | show_advanced = gr.Checkbox(label="Show Advanced Options", value=False)
105 | with gr.Row(visible=False) as advanced_row:
106 | with gr.Column():
107 | seed = gr.Number(value=-1, label="Seed (-1 for random)")
108 | guidance_scale = gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="Guidance Scale")
109 | flow_shift = gr.Slider(0.0, 10.0, value=7.0, step=0.1, label="Flow Shift")
110 | embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale")
111 | show_advanced.change(fn=lambda x: gr.Row(visible=x), inputs=[show_advanced], outputs=[advanced_row])
112 | generate_btn = gr.Button("Generate")
113 |
114 | with gr.Column():
115 | output = gr.Video(label="Generated Video")
116 |
117 | generate_btn.click(
118 | fn=lambda *inputs: generate_video(model, *inputs),
119 | inputs=[
120 | prompt,
121 | resolution,
122 | video_length,
123 | seed,
124 | num_inference_steps,
125 | guidance_scale,
126 | flow_shift,
127 | embedded_guidance_scale
128 | ],
129 | outputs=output
130 | )
131 |
132 | return demo
133 |
134 | if __name__ == "__main__":
135 | os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
136 | server_name = os.getenv("SERVER_NAME", "0.0.0.0")
137 | server_port = int(os.getenv("SERVER_PORT", "8081"))
138 | args = parse_args()
139 | print(args)
140 | demo = create_demo(args.model_base, args.save_path)
141 | demo.launch(server_name=server_name, server_port=server_port)
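Most of the request handling in `generate_video` comes down to three small conversions. The `"WIDTHxHEIGHT"` dropdown value becomes two ints, the `-1` seed sentinel becomes `None`, and the output file name is built from a timestamp, the seed, and the first 100 prompt characters with `/` removed. Below is a minimal standalone sketch of those conversions. The helper names `parse_resolution`, `normalize_seed` and `video_filename` are hypothetical and do not exist in the repository:

```python
from datetime import datetime
from typing import Optional, Tuple


def parse_resolution(resolution: str) -> Tuple[int, int]:
    # "1280x720" -> (1280, 720); width comes first, as in generate_video.
    width, height = resolution.split("x")
    return int(width), int(height)


def normalize_seed(seed: float) -> Optional[int]:
    # gr.Number yields a float; -1 means "pick a random seed" (None).
    return None if seed == -1 else int(seed)


def video_filename(when: datetime, seed: int, prompt: str) -> str:
    # Timestamp, seed and the first 100 prompt characters with "/" stripped,
    # so the prompt cannot introduce extra path components.
    time_flag = when.strftime("%Y-%m-%d-%H:%M:%S")
    return f"{time_flag}_seed{seed}_{prompt[:100].replace('/', '')}.mp4"


if __name__ == "__main__":
    print(parse_resolution("1280x720"))
    print(normalize_seed(-1.0), normalize_seed(42.0))
    print(video_filename(datetime(2024, 12, 3, 10, 0, 0), 42, "A cat/dog walks"))
```

The `%H:%M:%S` timestamp puts `:` characters in the file name, and Windows does not allow those. A platform-neutral variant would use `-` or `_` instead.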
--------------------------------------------------------------------------------
/hyvideo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo/757d06f2696fd7a457b888c092f5d57d086f51eb/hyvideo/__init__.py
--------------------------------------------------------------------------------
/hyvideo/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from .constants import *
3 | import re
4 | from .modules.models import HUNYUAN_VIDEO_CONFIG
5 |
6 |
7 | def parse_args(namespace=None):
8 | parser = argparse.ArgumentParser(description="HunyuanVideo inference script")
9 |
10 | parser = add_network_args(parser)
11 | parser = add_extra_models_args(parser)
12 | parser = add_denoise_schedule_args(parser)
13 | parser = add_inference_args(parser)
14 | parser = add_parallel_args(parser)
15 |
16 | args = parser.parse_args(namespace=namespace)
17 | args = sanity_check_args(args)
18 |
19 | return args
20 |
21 |
22 | def add_network_args(parser: argparse.ArgumentParser):
23 | group = parser.add_argument_group(title="HunyuanVideo network args")
24 |
25 | # Main model
26 | group.add_argument(
27 | "--model",
28 | type=str,
29 | choices=list(HUNYUAN_VIDEO_CONFIG.keys()),
30 | default="HYVideo-T/2-cfgdistill",
31 | )
32 | group.add_argument(
33 | "--latent-channels",
34 | type=int,
35 | default=16,
36 | help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, "
37 | "it still needs to match the latent channels of the VAE model.",
38 | )
39 | group.add_argument(
40 | "--precision",
41 | type=str,
42 | default="bf16",
43 | choices=PRECISIONS,
44 | help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.",
45 | )
46 |
47 | # RoPE
48 | group.add_argument(
49 | "--rope-theta", type=int, default=256, help="Theta used in RoPE."
50 | )
51 | return parser
52 |
53 |
54 | def add_extra_models_args(parser: argparse.ArgumentParser):
55 | group = parser.add_argument_group(
56 | title="Extra model args (VAE, text encoders and tokenizers)"
57 | )
58 |
59 | # - VAE
60 | group.add_argument(
61 | "--vae",
62 | type=str,
63 | default="884-16c-hy",
64 | choices=list(VAE_PATH),
65 | help="Name of the VAE model.",
66 | )
67 | group.add_argument(
68 | "--vae-precision",
69 | type=str,
70 | default="fp16",
71 | choices=PRECISIONS,
72 | help="Precision mode for the VAE model.",
73 | )
74 | group.add_argument(
75 | "--vae-tiling",
76 | action="store_true",
77 | help="Enable tiling for the VAE model to save GPU memory (enabled by default).",
78 | )
79 | group.set_defaults(vae_tiling=True)
80 |
81 | group.add_argument(
82 | "--text-encoder",
83 | type=str,
84 | default="llm",
85 | choices=list(TEXT_ENCODER_PATH),
86 | help="Name of the text encoder model.",
87 | )
88 | group.add_argument(
89 | "--text-encoder-precision",
90 | type=str,
91 | default="fp16",
92 | choices=PRECISIONS,
93 | help="Precision mode for the text encoder model.",
94 | )
95 | group.add_argument(
96 | "--text-states-dim",
97 | type=int,
98 | default=4096,
99 | help="Dimension of the text encoder hidden states.",
100 | )
101 | group.add_argument(
102 | "--text-len", type=int, default=256, help="Maximum length of the text input."
103 | )
104 | group.add_argument(
105 | "--tokenizer",
106 | type=str,
107 | default="llm",
108 | choices=list(TOKENIZER_PATH),
109 | help="Name of the tokenizer model.",
110 | )
111 | group.add_argument(
112 | "--prompt-template",
113 | type=str,
114 | default="dit-llm-encode",
115 | choices=PROMPT_TEMPLATE,
116 | help="Image prompt template for the decoder-only text encoder model.",
117 | )
118 | group.add_argument(
119 | "--prompt-template-video",
120 | type=str,
121 | default="dit-llm-encode-video",
122 | choices=PROMPT_TEMPLATE,
123 | help="Video prompt template for the decoder-only text encoder model.",
124 | )
125 | group.add_argument(
126 | "--hidden-state-skip-layer",
127 | type=int,
128 | default=2,
129 | help="Skip layer for hidden states.",
130 | )
131 | group.add_argument(
132 | "--apply-final-norm",
133 | action="store_true",
134 | help="Apply final normalization to the used text encoder hidden states.",
135 | )
136 |
137 | # - CLIP
138 | group.add_argument(
139 | "--text-encoder-2",
140 | type=str,
141 | default="clipL",
142 | choices=list(TEXT_ENCODER_PATH),
143 | help="Name of the second text encoder model.",
144 | )
145 | group.add_argument(
146 | "--text-encoder-precision-2",
147 | type=str,
148 | default="fp16",
149 | choices=PRECISIONS,
150 | help="Precision mode for the second text encoder model.",
151 | )
152 | group.add_argument(
153 | "--text-states-dim-2",
154 | type=int,
155 | default=768,
156 | help="Dimension of the second text encoder hidden states.",
157 | )
158 | group.add_argument(
159 | "--tokenizer-2",
160 | type=str,
161 | default="clipL",
162 | choices=list(TOKENIZER_PATH),
163 | help="Name of the second tokenizer model.",
164 | )
165 | group.add_argument(
166 | "--text-len-2",
167 | type=int,
168 | default=77,
169 | help="Maximum length of the second text input.",
170 | )
171 |
172 | return parser
173 |
174 |
175 | def add_denoise_schedule_args(parser: argparse.ArgumentParser):
176 | group = parser.add_argument_group(title="Denoise schedule args")
177 |
178 | group.add_argument(
179 | "--denoise-type",
180 | type=str,
181 | default="flow",
182 | help="Denoise type for noised inputs.",
183 | )
184 |
185 | # Flow Matching
186 | group.add_argument(
187 | "--flow-shift",
188 | type=float,
189 | default=7.0,
190 | help="Shift factor for flow matching schedulers.",
191 | )
192 | group.add_argument(
193 | "--flow-reverse",
194 | action="store_true",
195 | help="If set, learning/sampling runs from t=1 to t=0.",
196 | )
197 | group.add_argument(
198 | "--flow-solver",
199 | type=str,
200 | default="euler",
201 | help="Solver for flow matching.",
202 | )
203 | group.add_argument(
204 | "--use-linear-quadratic-schedule",
205 | action="store_true",
206 | help="Use linear quadratic schedule for flow matching, "
207 | "following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper).",
208 | )
209 | group.add_argument(
210 | "--linear-schedule-end",
211 | type=int,
212 | default=25,
213 | help="End step for linear quadratic schedule for flow matching.",
214 | )
215 |
216 | return parser
217 |
218 |
219 | def add_inference_args(parser: argparse.ArgumentParser):
220 | group = parser.add_argument_group(title="Inference args")
221 |
222 | # ======================== Model loads ========================
223 | group.add_argument(
224 | "--model-base",
225 | type=str,
226 | default="ckpts",
227 | help="Root path of all the models, including t2v models and extra models.",
228 | )
229 | group.add_argument(
230 | "--dit-weight",
231 | type=str,
232 | default="ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
233 | help="Path to the HunyuanVideo model. If None, search for the model under `--model-base`. "
234 | "1. If it is a file, load the model directly. "
235 | "2. If it is a directory, search the model in the directory. Two naming schemes are supported: "
236 | "1) `pytorch_model_*.pt`; "
237 | "2) `*_model_states.pt`, where * can be `mp_rank_00`.",
238 | )
239 | group.add_argument(
240 | "--model-resolution",
241 | type=str,
242 | default="540p",
243 | choices=["540p", "720p"],
244 | help="Resolution variant of the model (540p or 720p).",
245 | )
246 | group.add_argument(
247 | "--load-key",
248 | type=str,
249 | default="module",
250 | help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
251 | )
252 | group.add_argument(
253 | "--use-cpu-offload",
254 | action="store_true",
255 | help="Use CPU offload for the model load.",
256 | )
257 |
258 | # ======================== Inference general setting ========================
259 | group.add_argument(
260 | "--batch-size",
261 | type=int,
262 | default=1,
263 | help="Batch size for inference and evaluation.",
264 | )
265 | group.add_argument(
266 | "--infer-steps",
267 | type=int,
268 | default=50,
269 | help="Number of denoising steps for inference.",
270 | )
271 | group.add_argument(
272 | "--disable-autocast",
273 | action="store_true",
274 | help="Disable autocast for denoising loop and vae decoding in pipeline sampling.",
275 | )
276 | group.add_argument(
277 | "--save-path",
278 | type=str,
279 | default="./results",
280 | help="Path to save the generated samples.",
281 | )
282 | group.add_argument(
283 | "--save-path-suffix",
284 | type=str,
285 | default="",
286 | help="Suffix for the directory of saved samples.",
287 | )
288 | group.add_argument(
289 | "--name-suffix",
290 | type=str,
291 | default="",
292 | help="Suffix for the names of saved samples.",
293 | )
294 | group.add_argument(
295 | "--num-videos",
296 | type=int,
297 | default=1,
298 | help="Number of videos to generate for each prompt.",
299 | )
300 | # ---sample size---
301 | group.add_argument(
302 | "--video-size",
303 | type=int,
304 | nargs="+",
305 | default=(720, 1280),
306 | help="Video size for sampling. If a single value is provided, it is used for both height "
307 | "and width. If two values are provided, they are used as height and width "
308 | "respectively.",
309 | )
310 | group.add_argument(
311 | "--video-length",
312 | type=int,
313 | default=129,
314 | help="Number of frames to generate. When using the 3D VAE, the number should be 4n+1.",
315 | )
316 | # --- prompt ---
317 | group.add_argument(
318 | "--prompt",
319 | type=str,
320 | default=None,
321 | help="Prompt for sampling during evaluation.",
322 | )
323 | group.add_argument(
324 | "--seed-type",
325 | type=str,
326 | default="auto",
327 | choices=["file", "random", "fixed", "auto"],
328 | help="Seed type for evaluation. If file, use the seed from the CSV file. If random, generate a "
329 | "random seed. If fixed, use the fixed seed given by `--seed`. If auto, CSV input uses the "
330 | "seed column if available and otherwise the fixed `--seed` value; a `--prompt` input uses the "
331 | "fixed `--seed` value.",
332 | )
333 | group.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
334 |
335 | # Classifier-Free Guidance
336 | group.add_argument(
337 | "--neg-prompt", type=str, default=None, help="Negative prompt for sampling."
338 | )
339 | group.add_argument(
340 | "--cfg-scale", type=float, default=1.0, help="Classifier free guidance scale."
341 | )
342 | group.add_argument(
343 | "--embedded-cfg-scale",
344 | type=float,
345 | default=6.0,
346 | help="Embedded classifier-free guidance scale.",
347 | )
348 |
349 | group.add_argument(
350 | "--use-fp8",
351 | action="store_true",
352 | help="Use FP8 for inference acceleration."
353 | )
354 |
355 | group.add_argument(
356 | "--reproduce",
357 | action="store_true",
358 | help="Enable reproducibility by setting random seeds and deterministic algorithms.",
359 | )
360 |
361 | return parser
362 |
363 |
364 | def add_parallel_args(parser: argparse.ArgumentParser):
365 | group = parser.add_argument_group(title="Parallel args")
366 |
367 | # ======================== Sequence parallelism ========================
368 | group.add_argument(
369 | "--ulysses-degree",
370 | type=int,
371 | default=1,
372 | help="Ulysses degree.",
373 | )
374 | group.add_argument(
375 | "--ring-degree",
376 | type=int,
377 | default=1,
378 | help="Ring attention degree.",
379 | )
380 |
381 | return parser
382 |
383 |
384 | def sanity_check_args(args):
385 | # VAE channels
386 | vae_pattern = r"\d{2,3}-\d{1,2}c-\w+"
387 | if not re.match(vae_pattern, args.vae):
388 | raise ValueError(
389 | f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'."
390 | )
391 | vae_channels = int(args.vae.split("-")[1][:-1])
392 | if args.latent_channels is None:
393 | args.latent_channels = vae_channels
394 | if vae_channels != args.latent_channels:
395 | raise ValueError(
396 | f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels})."
397 | )
398 | return args
399 |
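`sanity_check_args` derives the latent channel count from the VAE name and rejects a mismatched `--latent-channels`. For example, the `16c` in `884-16c-hy` means 16 channels. The `--video-length` help also asks for 4n+1 frames when the 3D VAE is used. Below is a standalone sketch of both checks. `check_vae_channels` and `is_valid_video_length` are hypothetical names, and the frame-count check is an illustration only: `sanity_check_args` itself does not perform it.

```python
import re
from typing import Optional

VAE_PATTERN = r"\d{2,3}-\d{1,2}c-\w+"  # same pattern as sanity_check_args


def check_vae_channels(vae: str, latent_channels: Optional[int]) -> int:
    # Validate the VAE name and return the latent channel count it encodes.
    if not re.match(VAE_PATTERN, vae):
        raise ValueError(f"Invalid VAE model: {vae}")
    vae_channels = int(vae.split("-")[1][:-1])  # "16c" -> 16
    if latent_channels is None:
        return vae_channels
    if latent_channels != vae_channels:
        raise ValueError(
            f"Latent channels ({latent_channels}) must match the VAE channels ({vae_channels})."
        )
    return latent_channels


def is_valid_video_length(num_frames: int) -> bool:
    # Hypothetical check for the 4n+1 rule, e.g. 65 or 129 frames as offered in the Gradio demo.
    return num_frames >= 1 and (num_frames - 1) % 4 == 0


if __name__ == "__main__":
    print(check_vae_channels("884-16c-hy", None))
    print(is_valid_video_length(129), is_valid_video_length(128))
```

The channel comparison is between integers, so `--latent-channels` must be parsed with `type=int`. If it were parsed as a string, a value passed on the command line (for example `"16"`) would never equal the integer 16.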
--------------------------------------------------------------------------------
/hyvideo/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | __all__ = [
5 | "C_SCALE",
6 | "PROMPT_TEMPLATE",
7 | "MODEL_BASE",
8 | "PRECISIONS",
9 | "NORMALIZATION_TYPE",
10 | "ACTIVATION_TYPE",
11 | "VAE_PATH",
12 | "TEXT_ENCODER_PATH",
13 | "TOKENIZER_PATH",
14 | "TEXT_PROJECTION",
15 | "DATA_TYPE",
16 | "NEGATIVE_PROMPT",
17 | ]
18 |
19 | PRECISION_TO_TYPE = {
20 | 'fp32': torch.float32,
21 | 'fp16': torch.float16,
22 | 'bf16': torch.bfloat16,
23 | }
24 |
25 | # =================== Constant Values =====================
26 | # Computation scale factor, 1P = 1_000_000_000_000_000. Tensorboard will display the value in PetaFLOPS to avoid
27 | # overflow error when tensorboard logging values.
28 | C_SCALE = 1_000_000_000_000_000
29 |
30 | # When using decoder-only models, we must provide a prompt template to instruct the text encoder
31 | # on how to generate the text.
32 | # --------------------------------------------------------------------
33 | PROMPT_TEMPLATE_ENCODE = (
34 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
35 | "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
36 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
37 | )
38 | PROMPT_TEMPLATE_ENCODE_VIDEO = (
39 | "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
40 | "1. The main content and theme of the video."
41 | "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
42 | "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
43 | "4. background environment, light, style and atmosphere."
44 | "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
45 | "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
46 | )
47 |
48 | NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
49 |
50 | PROMPT_TEMPLATE = {
51 | "dit-llm-encode": {
52 | "template": PROMPT_TEMPLATE_ENCODE,
53 | "crop_start": 36,
54 | },
55 | "dit-llm-encode-video": {
56 | "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
57 | "crop_start": 95,
58 | },
59 | }
60 |
61 | # ======================= Model ======================
62 | PRECISIONS = {"fp32", "fp16", "bf16"}
63 | NORMALIZATION_TYPE = {"layer", "rms"}
64 | ACTIVATION_TYPE = {"relu", "silu", "gelu", "gelu_tanh"}
65 |
66 | # =================== Model Path =====================
67 | MODEL_BASE = os.getenv("MODEL_BASE", "./ckpts")
68 |
69 | # =================== Data =======================
70 | DATA_TYPE = {"image", "video", "image_video"}
71 |
72 | # 3D VAE
73 | VAE_PATH = {"884-16c-hy": f"{MODEL_BASE}/hunyuan-video-t2v-720p/vae"}
74 |
75 | # Text Encoder
76 | TEXT_ENCODER_PATH = {
77 | "clipL": f"{MODEL_BASE}/text_encoder_2",
78 | "llm": f"{MODEL_BASE}/text_encoder",
79 | }
80 |
81 | # Tokenizer
82 | TOKENIZER_PATH = {
83 | "clipL": f"{MODEL_BASE}/text_encoder_2",
84 | "llm": f"{MODEL_BASE}/text_encoder",
85 | }
86 |
87 | TEXT_PROJECTION = {
88 | "linear", # Default, an nn.Linear() layer
89 | "single_refiner", # Single TokenRefiner. Refer to LI-DiT
90 | }
91 |
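Each `PROMPT_TEMPLATE` entry pairs a template with a `crop_start` token count. The template wraps the user prompt via `str.format`. A reasonable reading of `crop_start` is that it counts the leading template tokens whose hidden states are dropped, so that only the prompt's tokens reach the DiT. This is an assumption, since the text encoder code is not part of this section. The sketch below uses a hypothetical toy template and a whitespace split as a stand-in tokenizer. Real counts such as 36 and 95 come from the LLM tokenizer and cannot be reproduced this way.

```python
from typing import List

# Hypothetical, shortened stand-in for PROMPT_TEMPLATE_ENCODE; "{}" marks the prompt slot.
TOY_TEMPLATE = "<sys> Describe the image: <user> {} <eot>"


def apply_template(template: str, prompt: str) -> str:
    # Same mechanism as PROMPT_TEMPLATE[...]["template"].format(prompt).
    return template.format(prompt)


def template_prefix_len(template: str) -> int:
    # Number of toy (whitespace) tokens before the "{}" placeholder.
    return len(template.split("{}")[0].split())


def crop_template_tokens(tokens: List[str], crop_start: int) -> List[str]:
    # Drop the per-token entries that belong to the template prefix.
    return tokens[crop_start:]


if __name__ == "__main__":
    text = apply_template(TOY_TEMPLATE, "A cat walks on the grass")
    tokens = text.split()  # toy tokenizer, not the real LLM tokenizer
    print(crop_template_tokens(tokens, template_prefix_len(TOY_TEMPLATE)))
```

Cropping only removes the prefix. The closing end-of-turn marker after the prompt survives, as `<eot>` does in this toy.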
--------------------------------------------------------------------------------
/hyvideo/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipelines import HunyuanVideoPipeline
2 | from .schedulers import FlowMatchDiscreteScheduler
3 |
--------------------------------------------------------------------------------
/hyvideo/diffusion/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_hunyuan_video import HunyuanVideoPipeline
2 |
--------------------------------------------------------------------------------
/hyvideo/diffusion/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
2 |
--------------------------------------------------------------------------------
/hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | #
16 | # Modified from diffusers==0.29.2
17 | #
18 | # ==============================================================================
19 |
20 | from dataclasses import dataclass
21 | from typing import Optional, Tuple, Union
22 |
23 | import numpy as np
24 | import torch
25 |
26 | from diffusers.configuration_utils import ConfigMixin, register_to_config
27 | from diffusers.utils import BaseOutput, logging
28 | from diffusers.schedulers.scheduling_utils import SchedulerMixin
29 |
30 |
31 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32 |
33 |
34 | @dataclass
35 | class FlowMatchDiscreteSchedulerOutput(BaseOutput):
36 | """
37 | Output class for the scheduler's `step` function output.
38 |
39 | Args:
40 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
42 | denoising loop.
43 | """
44 |
45 | prev_sample: torch.FloatTensor
46 |
47 |
48 | class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
49 | """
50 | Euler scheduler.
51 |
52 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
53 | methods the library implements for all schedulers such as loading and saving.
54 |
55 | Args:
56 | num_train_timesteps (`int`, defaults to 1000):
57 | The number of diffusion steps to train the model.
58 | shift (`float`, defaults to 1.0):
59 | The shift value for the timestep schedule.
60 | reverse (`bool`, defaults to `True`):
61 | Whether to reverse the timestep schedule.
62 | solver (`str`, defaults to `"euler"`):
63 | The ODE solver used by `step`. Only `"euler"` is currently supported.
64 | n_tokens (`int`, *optional*): Number of tokens in the input sequence (currently unused).
65 | """
66 |
67 | _compatibles = []
68 | order = 1
69 |
70 | @register_to_config
71 | def __init__(
72 | self,
73 | num_train_timesteps: int = 1000,
74 | shift: float = 1.0,
75 | reverse: bool = True,
76 | solver: str = "euler",
77 | n_tokens: Optional[int] = None,
78 | ):
79 | sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
80 |
81 | if not reverse:
82 | sigmas = sigmas.flip(0)
83 |
84 | self.sigmas = sigmas
85 | # the value fed to model
86 | self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
87 |
88 | self._step_index = None
89 | self._begin_index = None
90 |
91 | self.supported_solver = ["euler"]
92 | if solver not in self.supported_solver:
93 | raise ValueError(
94 | f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
95 | )
96 |
97 | @property
98 | def step_index(self):
99 | """
100 | The index counter for the current timestep. It increases by 1 after each scheduler step.
101 | """
102 | return self._step_index
103 |
104 | @property
105 | def begin_index(self):
106 | """
107 | The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
108 | """
109 | return self._begin_index
110 |
111 | # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
112 | def set_begin_index(self, begin_index: int = 0):
113 | """
114 | Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
115 |
116 | Args:
117 | begin_index (`int`):
118 | The begin index for the scheduler.
119 | """
120 | self._begin_index = begin_index
121 |
122 | def _sigma_to_t(self, sigma):
123 | return sigma * self.config.num_train_timesteps
124 |
125 | def set_timesteps(
126 | self,
127 | num_inference_steps: int,
128 | device: Union[str, torch.device] = None,
129 | n_tokens: int = None,
130 | ):
131 | """
132 | Sets the discrete timesteps used for the diffusion chain (to be run before inference).
133 |
134 | Args:
135 | num_inference_steps (`int`):
136 | The number of diffusion steps used when generating samples with a pre-trained model.
137 | device (`str` or `torch.device`, *optional*):
138 | The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
139 | n_tokens (`int`, *optional*):
140 | Number of tokens in the input sequence (currently unused).
141 | """
142 | self.num_inference_steps = num_inference_steps
143 |
144 | sigmas = torch.linspace(1, 0, num_inference_steps + 1)
145 | sigmas = self.sd3_time_shift(sigmas)
146 |
147 | if not self.config.reverse:
148 | sigmas = 1 - sigmas
149 |
150 | self.sigmas = sigmas
151 | self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
152 | dtype=torch.float32, device=device
153 | )
154 |
155 | # Reset step index
156 | self._step_index = None
157 |
158 | def index_for_timestep(self, timestep, schedule_timesteps=None):
159 | if schedule_timesteps is None:
160 | schedule_timesteps = self.timesteps
161 |
162 | indices = (schedule_timesteps == timestep).nonzero()
163 |
164 | # The sigma index that is taken for the **very** first `step`
165 | # is always the second index (or the last index if there is only 1)
166 | # This way we can ensure we don't accidentally skip a sigma in
167 | # case we start in the middle of the denoising schedule (e.g. for image-to-image)
168 | pos = 1 if len(indices) > 1 else 0
169 |
170 | return indices[pos].item()
171 |
172 | def _init_step_index(self, timestep):
173 | if self.begin_index is None:
174 | if isinstance(timestep, torch.Tensor):
175 | timestep = timestep.to(self.timesteps.device)
176 | self._step_index = self.index_for_timestep(timestep)
177 | else:
178 | self._step_index = self._begin_index
179 |
180 | def scale_model_input(
181 | self, sample: torch.Tensor, timestep: Optional[int] = None
182 | ) -> torch.Tensor:
183 | return sample
184 |
185 | def sd3_time_shift(self, t: torch.Tensor):
186 | return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
187 |
188 | def step(
189 | self,
190 | model_output: torch.FloatTensor,
191 | timestep: Union[float, torch.FloatTensor],
192 | sample: torch.FloatTensor,
193 | return_dict: bool = True,
194 | ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
195 | """
196 | Predict the sample at the previous timestep by taking one solver step along the learned flow. The
197 | model output is interpreted as a velocity, and the Euler update is `prev_sample = sample + model_output * dt`,
198 | where `dt = sigmas[step_index + 1] - sigmas[step_index]`.
199 |
200 | Args:
201 | model_output (`torch.FloatTensor`):
202 | The direct output from the learned diffusion model (the predicted velocity).
203 | timestep (`float` or `torch.FloatTensor`):
204 | The current discrete timestep in the diffusion chain. It must be one of `scheduler.timesteps`;
205 | integer indices are rejected.
206 | sample (`torch.FloatTensor`):
207 | A current instance of a sample created by the diffusion process.
208 | return_dict (`bool`, defaults to `True`):
209 | Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or
210 | tuple.
211 |
212 | Returns:
213 | [`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
214 | If return_dict is `True`, [`FlowMatchDiscreteSchedulerOutput`] is
215 | returned, otherwise a tuple is returned where the first element is the sample tensor.
216 | The returned sample is float32 because `sample` and `model_output` are upcast before the
217 | update.
218 | """
219 |
220 | if (
221 | isinstance(timestep, int)
222 | or isinstance(timestep, torch.IntTensor)
223 | or isinstance(timestep, torch.LongTensor)
224 | ):
225 | raise ValueError(
226 | (
227 | "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
228 | " `FlowMatchDiscreteScheduler.step()` is not supported. Make sure to pass"
229 | " one of the `scheduler.timesteps` as a timestep."
230 | ),
231 | )
232 |
233 | if self.step_index is None:
234 | self._init_step_index(timestep)
235 |
236 | # Upcast to avoid precision issues when computing prev_sample
237 | sample = sample.to(torch.float32)
238 |
239 | dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
240 |
241 | if self.config.solver == "euler":
242 | prev_sample = sample + model_output.to(torch.float32) * dt
243 | else:
244 | raise ValueError(
245 | f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
246 | )
247 |
248 | # upon completion increase step index by one
249 | self._step_index += 1
250 |
251 | if not return_dict:
252 | return (prev_sample,)
253 |
254 | return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
255 |
256 | def __len__(self):
257 | return self.config.num_train_timesteps
258 |
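With `reverse=True`, `set_timesteps` builds `num_inference_steps + 1` sigmas from 1 down to 0 and warps them with `sd3_time_shift`. `step` then applies the Euler update `sample + model_output * dt`. Below is a pure-Python sketch of that loop. `velocity_fn` is a hypothetical stand-in for the transformer, and the function names are illustrative, not part of the repository. The sketch assumes samples are interpolated as `sigma * noise + (1 - sigma) * data` during training, which is not stated in this file. Under that assumption, a constant velocity of `noise - data` recovers `data` exactly, because the `dt` values telescope to -1.

```python
from typing import Callable, List

Vector = List[float]


def sd3_time_shift(t: float, shift: float) -> float:
    # Same formula as FlowMatchDiscreteScheduler.sd3_time_shift.
    return (shift * t) / (1 + (shift - 1) * t)


def shifted_sigmas(num_steps: int, shift: float) -> Vector:
    # Mirrors set_timesteps with reverse=True: num_steps + 1 sigmas from 1 down to 0.
    linear = [1 - i / num_steps for i in range(num_steps + 1)]
    return [sd3_time_shift(t, shift) for t in linear]


def euler_sample(
    x: Vector,
    velocity_fn: Callable[[Vector, float], Vector],
    num_steps: int,
    shift: float,
) -> Vector:
    # One Euler step per interval, as in step(): x <- x + v * dt.
    sigmas = shifted_sigmas(num_steps, shift)
    for i in range(num_steps):
        dt = sigmas[i + 1] - sigmas[i]  # negative, since sigma decreases
        v = velocity_fn(x, sigmas[i])
        x = [xi + vi * dt for xi, vi in zip(x, v)]
    return x


if __name__ == "__main__":
    noise, data = [1.0, -2.0], [0.5, 0.25]
    # Constant velocity noise - data: the flow-matching target under the assumed
    # interpolation x_sigma = sigma * noise + (1 - sigma) * data.
    target_v = [n - d for n, d in zip(noise, data)]
    print(euler_sample(noise, lambda x, sigma: target_v, num_steps=50, shift=7.0))
    print(shifted_sigmas(4, 7.0))
```

Larger `shift` values keep sigma close to 1 for more of the schedule. For example, `shift=7.0` maps t = 0.5 to 0.875, so more of the uniformly spaced steps land at high noise levels.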
--------------------------------------------------------------------------------
/hyvideo/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG
2 |
3 |
4 | def load_model(args, in_channels, out_channels, factor_kwargs):
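5 | """Load the HunyuanVideo diffusion transformer.
6 |
7 | Args:
8 | args (argparse.Namespace): parsed model args; `args.model` must be a key of HUNYUAN_VIDEO_CONFIG
9 | in_channels (int): number of input channels
10 | out_channels (int): number of output channels
11 | factor_kwargs (dict): extra keyword arguments forwarded to the model constructor
12 |
13 | Returns:
14 | model (nn.Module): The HunyuanVideo model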
15 | """
16 | if args.model in HUNYUAN_VIDEO_CONFIG.keys():
17 | model = HYVideoDiffusionTransformer(
18 | args,
19 | in_channels=in_channels,
20 | out_channels=out_channels,
21 | **HUNYUAN_VIDEO_CONFIG[args.model],
22 | **factor_kwargs,
23 | )
24 | return model
25 | else:
26 | raise NotImplementedError(f"Unknown model: {args.model}")
27 |
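`load_model` is a name-to-config dispatch. `args.model` selects a kwargs dict from `HUNYUAN_VIDEO_CONFIG`, and that dict is merged with the channel counts and `factor_kwargs` into a single constructor call. Below is a self-contained sketch of the same pattern. `TinyModel`, `TINY_CONFIG` and `load_tiny_model` are hypothetical; the real config keys and model class are not reproduced here.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Dict


@dataclass
class TinyModel:
    # Hypothetical stand-in for HYVideoDiffusionTransformer.
    in_channels: int
    out_channels: int
    depth: int
    hidden_size: int
    dtype: str = "float32"


# Hypothetical registry: model name -> architecture kwargs.
TINY_CONFIG: Dict[str, Dict[str, Any]] = {
    "tiny-small": {"depth": 2, "hidden_size": 64},
    "tiny-large": {"depth": 8, "hidden_size": 256},
}


def load_tiny_model(
    args: Any, in_channels: int, out_channels: int, factor_kwargs: Dict[str, Any]
) -> TinyModel:
    if args.model not in TINY_CONFIG:
        raise NotImplementedError(f"Unknown model: {args.model}")
    # Registry kwargs plus per-call factor kwargs; a key present in both
    # raises TypeError ("got multiple values for keyword argument").
    return TinyModel(
        in_channels=in_channels,
        out_channels=out_channels,
        **TINY_CONFIG[args.model],
        **factor_kwargs,
    )


if __name__ == "__main__":
    print(load_tiny_model(SimpleNamespace(model="tiny-small"), 16, 16, {"dtype": "bfloat16"}))
```

Keeping architecture hyperparameters in the registry means a new variant only needs a new dict entry. The `--model` choices in `config.py` are built from the registry keys in the same way.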
--------------------------------------------------------------------------------
/hyvideo/modules/activation_layers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def get_activation_layer(act_type):
5 | """Get an activation layer factory.
6 |
7 | Args:
8 | act_type (str): the activation type, one of "gelu", "gelu_tanh", "relu", "silu"
9 |
10 | Returns:
11 | Callable[[], nn.Module]: a zero-argument callable that builds the activation module
12 | """
13 | if act_type == "gelu":
14 | return lambda: nn.GELU()
15 | elif act_type == "gelu_tanh":
16 | # Approximate `tanh` requires torch >= 1.13
17 | return lambda: nn.GELU(approximate="tanh")
18 | elif act_type == "relu":
19 | return nn.ReLU
20 | elif act_type == "silu":
21 | return nn.SiLU
22 | else:
23 | raise ValueError(f"Unknown activation type: {act_type}")
24 |
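`gelu` and `gelu_tanh` differ only in how GELU is computed. PyTorch documents `nn.GELU(approximate="tanh")` as `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))`. The exact form is `0.5 * x * (1 + erf(x / sqrt(2)))`. Below is a torch-free sketch that compares the two using only the standard library:

```python
import math


def gelu_exact(x: float) -> float:
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    # Tanh approximation, as documented for nn.GELU(approximate="tanh").
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


if __name__ == "__main__":
    xs = [i / 10 for i in range(-50, 51)]
    max_err = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in xs)
    print(f"max |exact - tanh| on [-5, 5]: {max_err:.2e}")
```

The factory returned by `get_activation_layer` takes no arguments. For example, `get_activation_layer("gelu_tanh")()` builds `nn.GELU(approximate="tanh")` and `get_activation_layer("relu")()` builds `nn.ReLU()`.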
--------------------------------------------------------------------------------
/hyvideo/modules/attenion.py:
--------------------------------------------------------------------------------
1 | import importlib.metadata
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | try:
9 | import flash_attn
10 | from flash_attn.flash_attn_interface import _flash_attn_forward
11 | from flash_attn.flash_attn_interface import flash_attn_varlen_func
12 | except ImportError:
13 | flash_attn = None
14 | flash_attn_varlen_func = None
15 | _flash_attn_forward = None
16 |
17 |
18 | MEMORY_LAYOUT = {
19 | "flash": (
20 | lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
21 | lambda x: x,
22 | ),
23 | "torch": (
24 | lambda x: x.transpose(1, 2),
25 | lambda x: x.transpose(1, 2),
26 | ),
27 | "vanilla": (
28 | lambda x: x.transpose(1, 2),
29 | lambda x: x.transpose(1, 2),
30 | ),
31 | }
32 |
33 |
34 | def get_cu_seqlens(text_mask, img_len):
35 | """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
36 |
37 | Args:
38 | text_mask (torch.Tensor): the mask of text
39 | img_len (int): the length of image
40 |
41 | Returns:
42 | torch.Tensor: the calculated cu_seqlens for flash attention
43 | """
44 | batch_size = text_mask.shape[0]
45 | text_len = text_mask.sum(dim=1)
46 | max_len = text_mask.shape[1] + img_len
47 |
48 | cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
49 |
50 | for i in range(batch_size):
51 | s = text_len[i] + img_len
52 | s1 = i * max_len + s
53 | s2 = (i + 1) * max_len
54 | cu_seqlens[2 * i + 1] = s1
55 | cu_seqlens[2 * i + 2] = s2
56 |
57 | return cu_seqlens
58 |
59 |
60 | def attention(
61 | q,
62 | k,
63 | v,
64 | mode="flash",
65 | drop_rate=0,
66 | attn_mask=None,
67 | causal=False,
68 | cu_seqlens_q=None,
69 | cu_seqlens_kv=None,
70 | max_seqlen_q=None,
71 | max_seqlen_kv=None,
72 | batch_size=1,
73 | ):
74 | """
75 | Perform QKV self attention.
76 |
77 | Args:
78 | q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
79 | k (torch.Tensor): Key tensor with shape [b, s1, a, d]
80 | v (torch.Tensor): Value tensor with shape [b, s1, a, d]
81 | mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
82 | drop_rate (float): Dropout rate in attention map. (default: 0)
83 | attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
84 | (default: None)
85 | causal (bool): Whether to use causal attention. (default: False)
86 | cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
87 | used to index into q.
88 | cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
89 | used to index into kv.
90 | max_seqlen_q (int): The maximum sequence length in the batch of q.
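91 | max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
92 | batch_size (int): Batch size, used only to reshape the flash output back to [b, s, a, d]. (default: 1)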
93 | Returns:
94 | torch.Tensor: Output tensor after self attention with shape [b, s, a * d]
95 | """
96 | pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
97 | q = pre_attn_layout(q)
98 | k = pre_attn_layout(k)
99 | v = pre_attn_layout(v)
100 |
101 | if mode == "torch":
102 | if attn_mask is not None and attn_mask.dtype != torch.bool:
103 | attn_mask = attn_mask.to(q.dtype)
104 | x = F.scaled_dot_product_attention(
105 | q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
106 | )
107 | elif mode == "flash":
108 | x = flash_attn_varlen_func(
109 | q,
110 | k,
111 | v,
112 | cu_seqlens_q,
113 | cu_seqlens_kv,
114 | max_seqlen_q,
115 | max_seqlen_kv,
116 | )
117 | # x with shape [(bxs), a, d]
118 | x = x.view(
119 | batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
120 | ) # reshape x to [b, s, a, d]
121 | elif mode == "vanilla":
122 | scale_factor = 1 / math.sqrt(q.size(-1))
123 |
124 | b, a, s, _ = q.shape
125 | s1 = k.size(2)
126 | attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
127 | if causal:
128 | # Only applied to self attention
129 | assert (
130 | attn_mask is None
131 | ), "Causal mask and attn_mask cannot be used together"
132 | temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
133 | diagonal=0
134 | )
135 | attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
136 | attn_bias = attn_bias.to(q.dtype)
137 |
138 | if attn_mask is not None:
139 | if attn_mask.dtype == torch.bool:
140 | attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
141 | else:
142 | attn_bias += attn_mask
143 |
144 | # TODO: Maybe force q and k to be float32 to avoid numerical overflow
145 | attn = (q @ k.transpose(-2, -1)) * scale_factor
146 | attn += attn_bias
147 | attn = attn.softmax(dim=-1)
148 | attn = torch.dropout(attn, p=drop_rate, train=True)
149 | x = attn @ v
150 | else:
151 | raise NotImplementedError(f"Unsupported attention mode: {mode}")
152 |
153 | x = post_attn_layout(x)
154 | b, s, a, d = x.shape
155 | out = x.reshape(b, s, -1)
156 | return out
157 |
158 |
159 | def parallel_attention(
160 | hybrid_seq_parallel_attn,
161 | q,
162 | k,
163 | v,
164 | img_q_len,
165 | img_kv_len,
166 | cu_seqlens_q,
167 | cu_seqlens_kv
168 | ):
169 | attn1 = hybrid_seq_parallel_attn(
170 | None,
171 | q[:, :img_q_len, :, :],
172 | k[:, :img_kv_len, :, :],
173 | v[:, :img_kv_len, :, :],
174 | dropout_p=0.0,
175 | causal=False,
176 | joint_tensor_query=q[:,img_q_len:cu_seqlens_q[1]],
177 | joint_tensor_key=k[:,img_kv_len:cu_seqlens_kv[1]],
178 | joint_tensor_value=v[:,img_kv_len:cu_seqlens_kv[1]],
179 | joint_strategy="rear",
180 | )
181 | if tuple(int(p) for p in flash_attn.__version__.split(".")[:2]) >= (2, 7):
182 | attn2, *_ = _flash_attn_forward(
183 | q[:,cu_seqlens_q[1]:],
184 | k[:,cu_seqlens_kv[1]:],
185 | v[:,cu_seqlens_kv[1]:],
186 | dropout_p=0.0,
187 | softmax_scale=q.shape[-1] ** (-0.5),
188 | causal=False,
189 | window_size_left=-1,
190 | window_size_right=-1,
191 | softcap=0.0,
192 | alibi_slopes=None,
193 | return_softmax=False,
194 | )
195 | else:
196 | attn2, *_ = _flash_attn_forward(
197 | q[:,cu_seqlens_q[1]:],
198 | k[:,cu_seqlens_kv[1]:],
199 | v[:,cu_seqlens_kv[1]:],
200 | dropout_p=0.0,
201 | softmax_scale=q.shape[-1] ** (-0.5),
202 | causal=False,
203 | window_size=(-1, -1),
204 | softcap=0.0,
205 | alibi_slopes=None,
206 | return_softmax=False,
207 | )
208 | attn = torch.cat([attn1, attn2], dim=1)
209 | b, s, a, d = attn.shape
210 | attn = attn.reshape(b, s, -1)
211 |
212 | return attn
213 |
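get_cu_seqlens lays out each padded sample as two variable-length segments for flash_attn_varlen_func: the valid tokens (all image tokens plus the unpadded text tokens) and the trailing text padding. The pure-Python sketch below re-derives those boundaries for illustration only; `cu_seqlens_sketch` is a hypothetical helper, not part of the repo, and it assumes a 0/1 text mask given as nested lists.

```python
from typing import List


def cu_seqlens_sketch(text_mask: List[List[int]], img_len: int) -> List[int]:
    """Hypothetical pure-Python mirror of get_cu_seqlens (illustration only)."""
    max_len = len(text_mask[0]) + img_len  # padded token count per sample
    cu = [0]
    for i, row in enumerate(text_mask):
        start = i * max_len                    # first slot of sample i
        cu.append(start + img_len + sum(row))  # end of image + real text tokens
        cu.append(start + max_len)             # end of the text padding
    return cu


print(cu_seqlens_sketch([[1, 1, 0, 0], [1, 1, 1, 0]], img_len=3))  # [0, 5, 7, 13, 14]
```

Here sample 0 occupies slots 0..6 with 5 valid tokens, and sample 1 occupies slots 7..13 with 6 valid tokens, so each sample's padding becomes its own short segment instead of leaking into the valid tokens' attention.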
--------------------------------------------------------------------------------
/hyvideo/modules/embed_layers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from einops import rearrange, repeat
5 |
6 | from ..utils.helpers import to_2tuple
7 |
8 |
9 | class PatchEmbed(nn.Module):
10 | """3D Video to Patch Embedding
11 |
12 | Video to Patch Embedding using Conv3d
13 |
14 | A convolution-based approach to patchifying a video latent [B, C, T, H, W] with an embedding projection.
15 |
16 | Based on the impl in https://github.com/google-research/vision_transformer
17 |
18 | Hacked together by / Copyright 2020 Ross Wightman
19 |
20 | The _assert shape check in forward has been removed to support multi-resolution inputs.
21 | """
22 |
23 | def __init__(
24 | self,
25 | patch_size=16,
26 | in_chans=3,
27 | embed_dim=768,
28 | norm_layer=None,
29 | flatten=True,
30 | bias=True,
31 | dtype=None,
32 | device=None,
33 | ):
34 | factory_kwargs = {"dtype": dtype, "device": device}
35 | super().__init__()
36 | patch_size = to_2tuple(patch_size)
37 | self.patch_size = patch_size
38 | self.flatten = flatten
39 |
40 | self.proj = nn.Conv3d(
41 | in_chans,
42 | embed_dim,
43 | kernel_size=patch_size,
44 | stride=patch_size,
45 | bias=bias,
46 | **factory_kwargs
47 | )
48 | nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
49 | if bias:
50 | nn.init.zeros_(self.proj.bias)
51 |
52 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
53 |
54 | def forward(self, x):
55 | x = self.proj(x)
56 | if self.flatten:
57 | x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
58 | x = self.norm(x)
59 | return x
60 |
61 |
62 | class TextProjection(nn.Module):
63 | """
64 | Projects text embeddings to the hidden size with a two-layer MLP (Linear -> act -> Linear). Classifier-free guidance dropout is not handled here.
65 |
66 | Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
67 | """
68 |
69 | def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
70 | factory_kwargs = {"dtype": dtype, "device": device}
71 | super().__init__()
72 | self.linear_1 = nn.Linear(
73 | in_features=in_channels,
74 | out_features=hidden_size,
75 | bias=True,
76 | **factory_kwargs
77 | )
78 | self.act_1 = act_layer()
79 | self.linear_2 = nn.Linear(
80 | in_features=hidden_size,
81 | out_features=hidden_size,
82 | bias=True,
83 | **factory_kwargs
84 | )
85 |
86 | def forward(self, caption):
87 | hidden_states = self.linear_1(caption)
88 | hidden_states = self.act_1(hidden_states)
89 | hidden_states = self.linear_2(hidden_states)
90 | return hidden_states
91 |
92 |
93 | def timestep_embedding(t, dim, max_period=10000):
94 | """
95 | Create sinusoidal timestep embeddings.
96 |
97 | Args:
98 | t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
99 | dim (int): the dimension of the output.
100 | max_period (int): controls the minimum frequency of the embeddings.
101 |
102 | Returns:
103 | embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
104 |
105 | .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
106 | """
107 | half = dim // 2
108 | freqs = torch.exp(
109 | -math.log(max_period)
110 | * torch.arange(start=0, end=half, dtype=torch.float32)
111 | / half
112 | ).to(device=t.device)
113 | args = t[:, None].float() * freqs[None]
114 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
115 | if dim % 2:
116 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
117 | return embedding
118 |
119 |
120 | class TimestepEmbedder(nn.Module):
121 | """
122 | Embeds scalar timesteps into vector representations.
123 | """
124 |
125 | def __init__(
126 | self,
127 | hidden_size,
128 | act_layer,
129 | frequency_embedding_size=256,
130 | max_period=10000,
131 | out_size=None,
132 | dtype=None,
133 | device=None,
134 | ):
135 | factory_kwargs = {"dtype": dtype, "device": device}
136 | super().__init__()
137 | self.frequency_embedding_size = frequency_embedding_size
138 | self.max_period = max_period
139 | if out_size is None:
140 | out_size = hidden_size
141 |
142 | self.mlp = nn.Sequential(
143 | nn.Linear(
144 | frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
145 | ),
146 | act_layer(),
147 | nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
148 | )
149 | nn.init.normal_(self.mlp[0].weight, std=0.02)
150 | nn.init.normal_(self.mlp[2].weight, std=0.02)
151 |
152 | def forward(self, t):
153 | t_freq = timestep_embedding(
154 | t, self.frequency_embedding_size, self.max_period
155 | ).type(self.mlp[0].weight.dtype)
156 | t_emb = self.mlp(t_freq)
157 | return t_emb
158 |
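timestep_embedding maps each (possibly fractional) timestep t to cos(t * f_k) followed by sin(t * f_k), with frequencies f_k = max_period ** (-k / half) for k = 0, ..., half - 1, so the highest frequency is 1 and the lowest approaches 1 / max_period. A scalar sketch of the same formula in plain Python, for illustration; `sinusoidal_embedding` is a hypothetical name.

```python
import math
from typing import List


def sinusoidal_embedding(t: float, dim: int, max_period: float = 10000.0) -> List[float]:
    """Scalar restatement of timestep_embedding: cosine half first, then sine half."""
    half = dim // 2
    # f_k = max_period ** (-k / half), written as exp(-ln(max_period) * k / half)
    freqs = [math.exp(-math.log(max_period) * k / half) for k in range(half)]
    emb = [math.cos(t * f) for f in freqs] + [math.sin(t * f) for f in freqs]
    if dim % 2:
        emb.append(0.0)  # odd dim: zero-pad the last channel
    return emb


print(sinusoidal_embedding(0.0, 8))  # [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
```

At t = 0 every cosine is 1 and every sine is 0; TimestepEmbedder then feeds this fixed encoding through a two-layer MLP to obtain a learned vector.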
--------------------------------------------------------------------------------
/hyvideo/modules/fp8_optimization.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 | def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
8 | _bits = torch.tensor(bits)
9 | _mantissa_bit = torch.tensor(mantissa_bit)
10 | _sign_bits = torch.tensor(sign_bits)
11 | M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
12 | E = _bits - _sign_bits - M
13 | bias = 2 ** (E - 1) - 1
14 | mantissa = 1
15 | for i in range(mantissa_bit - 1):
16 | mantissa += 1 / (2 ** (i+1))
17 | maxval = mantissa * 2 ** (2**E - 1 - bias)
18 | return maxval
19 |
20 | def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
21 | """
22 | Simulated FP8 quantize-dequantize of x (default format: E4M3). Returns (qdq_out, log_scales).
23 | """
24 | bits = torch.tensor(bits)
25 | mantissa_bit = torch.tensor(mantissa_bit)
26 | sign_bits = torch.tensor(sign_bits)
27 | M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
28 | E = bits - sign_bits - M
29 | bias = 2 ** (E - 1) - 1
30 | mantissa = 1
31 | for i in range(mantissa_bit - 1):
32 | mantissa += 1 / (2 ** (i+1))
33 | maxval = mantissa * 2 ** (2**E - 1 - bias)
34 | # Signed formats are symmetric around zero; unsigned formats clamp at zero.
35 | minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval)
36 | input_clamp = torch.min(torch.max(x, minval), maxval)
37 | log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
38 | log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
39 | # dequant
40 | qdq_out = torch.round(input_clamp / log_scales) * log_scales
41 | return qdq_out, log_scales
42 |
43 | def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
44 | for i in range(len(x.shape) - 1):
45 | scale = scale.unsqueeze(-1)
46 | new_x = x / scale
47 | quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
48 | return quant_dequant_x, scale, log_scales
49 |
50 | def fp8_activation_dequant(qdq_out, scale, dtype):
51 | qdq_out = qdq_out.type(dtype)
52 | quant_dequant_x = qdq_out * scale.to(dtype)
53 | return quant_dequant_x
54 |
55 | def fp8_linear_forward(cls, original_dtype, input):
56 | weight_dtype = cls.weight.dtype
57 | #####
58 | if cls.weight.dtype != torch.float8_e4m3fn:
59 | maxval = get_fp_maxval()
60 | scale = torch.max(torch.abs(cls.weight.flatten())) / maxval
61 | linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale)
62 | linear_weight = linear_weight.to(torch.float8_e4m3fn)
63 | weight_dtype = linear_weight.dtype
64 | else:
65 | scale = cls.fp8_scale.to(cls.weight.device)
66 | linear_weight = cls.weight
67 | #####
68 |
69 | if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0:
70 | if True or len(input.shape) == 3:
71 | cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype)
72 | if cls.bias is not None:
73 | output = F.linear(input, cls_dequant, cls.bias)
74 | else:
75 | output = F.linear(input, cls_dequant)
76 | return output
77 | else:
78 | return cls.original_forward(input.to(original_dtype))
79 | else:
80 | return cls.original_forward(input)
81 |
82 | def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={}):
83 | setattr(module, "fp8_matmul_enabled", True)
84 |
85 | # loading fp8 mapping file
86 | fp8_map_path = dit_weight_path.replace('.pt', '_map.pt')
87 | if os.path.exists(fp8_map_path):
88 | fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)
89 | else:
90 | raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.")
91 |
92 | fp8_layers = []
93 | for key, layer in module.named_modules():
94 | if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
95 | fp8_layers.append(key)
96 | original_forward = layer.forward
97 | layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
98 | setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype))
99 | setattr(layer, "original_forward", original_forward)
100 | setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))
101 |
102 |
103 |
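get_fp_maxval returns the largest finite magnitude of the simulated format, and fp8_linear_forward divides max|W| by it to obtain a per-tensor scale. A plain-Python sketch of the same arithmetic (`fp_maxval` and `per_tensor_scale` are hypothetical names): for the default E4M3 setting it gives 1.75 * 2**8 = 448.0, which matches torch.finfo(torch.float8_e4m3fn).max. The mantissa sum stops at 2**-(M-1) because E4M3fn encodes NaN with the all-ones mantissa at the top exponent, so the formula appears specific to that encoding; for an IEEE-style format such as E5M2 it gives 98304 rather than the true maximum of 57344.

```python
def fp_maxval(bits: int = 8, mantissa_bits: int = 3, sign_bits: int = 1) -> float:
    """Hypothetical plain-Python restatement of get_fp_maxval."""
    exp_bits = bits - sign_bits - mantissa_bits
    bias = 2 ** (exp_bits - 1) - 1
    # Largest mantissa 1 + 1/2 + ... + 1/2**(M-1); E4M3fn reserves the all-ones pattern for NaN.
    mantissa = 1.0 + sum(1.0 / 2 ** (i + 1) for i in range(mantissa_bits - 1))
    return mantissa * 2.0 ** (2 ** exp_bits - 1 - bias)


def per_tensor_scale(weights, maxval: float = fp_maxval()) -> float:
    # Same idea as fp8_linear_forward: map the largest |w| onto the FP8 maximum.
    return max(abs(w) for w in weights) / maxval


print(fp_maxval())                          # 448.0
print(per_tensor_scale([0.5, -2.24, 1.0]))  # approximately 0.005
```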
--------------------------------------------------------------------------------
/hyvideo/modules/mlp_layers.py:
--------------------------------------------------------------------------------
1 | # Modified from timm library:
2 | # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
3 |
4 | from functools import partial
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | from .modulate_layers import modulate
10 | from ..utils.helpers import to_2tuple
11 |
12 |
13 | class MLP(nn.Module):
14 | """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
15 |
16 | def __init__(
17 | self,
18 | in_channels,
19 | hidden_channels=None,
20 | out_features=None,
21 | act_layer=nn.GELU,
22 | norm_layer=None,
23 | bias=True,
24 | drop=0.0,
25 | use_conv=False,
26 | device=None,
27 | dtype=None,
28 | ):
29 | factory_kwargs = {"device": device, "dtype": dtype}
30 | super().__init__()
31 | out_features = out_features or in_channels
32 | hidden_channels = hidden_channels or in_channels
33 | bias = to_2tuple(bias)
34 | drop_probs = to_2tuple(drop)
35 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
36 |
37 | self.fc1 = linear_layer(
38 | in_channels, hidden_channels, bias=bias[0], **factory_kwargs
39 | )
40 | self.act = act_layer()
41 | self.drop1 = nn.Dropout(drop_probs[0])
42 | self.norm = (
43 | norm_layer(hidden_channels, **factory_kwargs)
44 | if norm_layer is not None
45 | else nn.Identity()
46 | )
47 | self.fc2 = linear_layer(
48 | hidden_channels, out_features, bias=bias[1], **factory_kwargs
49 | )
50 | self.drop2 = nn.Dropout(drop_probs[1])
51 |
52 | def forward(self, x):
53 | x = self.fc1(x)
54 | x = self.act(x)
55 | x = self.drop1(x)
56 | x = self.norm(x)
57 | x = self.fc2(x)
58 | x = self.drop2(x)
59 | return x
60 |
61 |
62 | #
63 | class MLPEmbedder(nn.Module):
64 | """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
65 | def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
66 | factory_kwargs = {"device": device, "dtype": dtype}
67 | super().__init__()
68 | self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
69 | self.silu = nn.SiLU()
70 | self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
71 |
72 | def forward(self, x: torch.Tensor) -> torch.Tensor:
73 | return self.out_layer(self.silu(self.in_layer(x)))
74 |
75 |
76 | class FinalLayer(nn.Module):
77 | """The final layer of DiT."""
78 |
79 | def __init__(
80 | self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
81 | ):
82 | factory_kwargs = {"device": device, "dtype": dtype}
83 | super().__init__()
84 |
85 | # Just use LayerNorm for the final layer
86 | self.norm_final = nn.LayerNorm(
87 | hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
88 | )
89 | if isinstance(patch_size, int):
90 | self.linear = nn.Linear(
91 | hidden_size,
92 | patch_size * patch_size * out_channels,
93 | bias=True,
94 | **factory_kwargs
95 | )
96 | else:
97 | self.linear = nn.Linear(
98 | hidden_size,
99 | patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
100 | bias=True, **factory_kwargs
101 | )
102 | nn.init.zeros_(self.linear.weight)
103 | nn.init.zeros_(self.linear.bias)
104 |
105 | # Here we don't distinguish between the modulate types. Just use the simple one.
106 | self.adaLN_modulation = nn.Sequential(
107 | act_layer(),
108 | nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
109 | )
110 | # Zero-initialize the modulation
111 | nn.init.zeros_(self.adaLN_modulation[1].weight)
112 | nn.init.zeros_(self.adaLN_modulation[1].bias)
113 |
114 | def forward(self, x, c):
115 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
116 | x = modulate(self.norm_final(x), shift=shift, scale=scale)
117 | x = self.linear(x)
118 | return x
119 |
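FinalLayer maps each token from hidden_size to prod(patch_size) * out_channels values, one per latent element inside its patch, and both its output projection and its adaLN modulation are zero-initialized, so the untrained layer outputs exactly zero. The sketch below checks both properties with toy sizes; hidden_size=32, out_channels=16 and patch_size=(1, 2, 2) are illustrative assumptions, and the modulation step is skipped because zero shift and scale leave the normalized input unchanged.

```python
import math

import torch
import torch.nn as nn

hidden_size, out_channels = 32, 16
patch_size = (1, 2, 2)  # (t, h, w); toy values

proj_dim = math.prod(patch_size) * out_channels  # latent values each token must produce
norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
linear = nn.Linear(hidden_size, proj_dim, bias=True)
nn.init.zeros_(linear.weight)  # same zero-init as FinalLayer.linear
nn.init.zeros_(linear.bias)

tokens = torch.randn(2, 10, hidden_size)  # [batch, tokens, hidden]
out = linear(norm_final(tokens))

print(out.shape)                        # torch.Size([2, 10, 64])
print(torch.count_nonzero(out).item())  # 0
```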
--------------------------------------------------------------------------------
/hyvideo/modules/modulate_layers.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class ModulateDiT(nn.Module):
8 | """Modulation layer for DiT."""
9 | def __init__(
10 | self,
11 | hidden_size: int,
12 | factor: int,
13 | act_layer: Callable,
14 | dtype=None,
15 | device=None,
16 | ):
17 | factory_kwargs = {"dtype": dtype, "device": device}
18 | super().__init__()
19 | self.act = act_layer()
20 | self.linear = nn.Linear(
21 | hidden_size, factor * hidden_size, bias=True, **factory_kwargs
22 | )
23 | # Zero-initialize the modulation
24 | nn.init.zeros_(self.linear.weight)
25 | nn.init.zeros_(self.linear.bias)
26 |
27 | def forward(self, x: torch.Tensor) -> torch.Tensor:
28 | return self.linear(self.act(x))
29 |
30 |
31 | def modulate(x, shift=None, scale=None):
32 | """Modulate x by per-sample shift and scale (adaLN style): x * (1 + scale) + shift, broadcast over the token axis.
33 |
34 | Args:
35 | x (torch.Tensor): input tensor.
36 | shift (torch.Tensor, optional): shift tensor. Defaults to None.
37 | scale (torch.Tensor, optional): scale tensor. Defaults to None.
38 |
39 | Returns:
40 | torch.Tensor: the output tensor after modulate.
41 | """
42 | if scale is None and shift is None:
43 | return x
44 | elif shift is None:
45 | return x * (1 + scale.unsqueeze(1))
46 | elif scale is None:
47 | return x + shift.unsqueeze(1)
48 | else:
49 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
50 |
51 |
52 | def apply_gate(x, gate=None, tanh=False):
53 | """Scale x by a per-sample gate, optionally squashed with tanh, broadcast over the token axis.
54 |
55 | Args:
56 | x (torch.Tensor): input tensor.
57 | gate (torch.Tensor, optional): gate tensor. Defaults to None.
58 | tanh (bool, optional): whether to use tanh function. Defaults to False.
59 |
60 | Returns:
61 | torch.Tensor: the output tensor after apply gate.
62 | """
63 | if gate is None:
64 | return x
65 | if tanh:
66 | return x * gate.unsqueeze(1).tanh()
67 | else:
68 | return x * gate.unsqueeze(1)
69 |
70 |
71 | def ckpt_wrapper(module):
72 | def ckpt_forward(*inputs):
73 | outputs = module(*inputs)
74 | return outputs
75 |
76 | return ckpt_forward
77 |
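modulate and apply_gate implement adaLN-style conditioning: a per-sample vector is projected (as in ModulateDiT) to shift, scale and gate tensors of shape [B, hidden], which unsqueeze(1) broadcasts over the token axis. Because the projection is zero-initialized, shift = scale = gate = 0 at the start of training, so a gated residual branch contributes nothing and the block starts out as the identity. The toy sketch below assumes a factor of 3 (shift, scale, gate) and uses an nn.Linear stand-in for the attention or MLP branch; both are illustrative choices, not the model's actual configuration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, L, H = 2, 5, 8  # batch, tokens, hidden size (toy values)

x = torch.randn(B, L, H)  # token features
vec = torch.randn(B, H)   # per-sample conditioning, e.g. a timestep embedding

# Zero-initialized projection in the style of ModulateDiT, assumed factor=3
mod = nn.Sequential(nn.SiLU(), nn.Linear(H, 3 * H))
nn.init.zeros_(mod[1].weight)
nn.init.zeros_(mod[1].bias)
shift, scale, gate = mod(vec).chunk(3, dim=-1)  # each [B, H]

norm = nn.LayerNorm(H, elementwise_affine=False, eps=1e-6)
body = nn.Linear(H, H)  # hypothetical stand-in for the attention/MLP branch

# unsqueeze(1) turns [B, H] into [B, 1, H] so it broadcasts across the L tokens
h = norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)  # modulate
out = x + body(h) * gate.unsqueeze(1)                        # gated residual

print(torch.equal(out, x))  # True: the zero gate makes the block an identity at init
```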
--------------------------------------------------------------------------------
/hyvideo/modules/norm_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class RMSNorm(nn.Module):
6 | def __init__(
7 | self,
8 | dim: int,
9 | elementwise_affine=True,
10 | eps: float = 1e-6,
11 | device=None,
12 | dtype=None,
13 | ):
14 | """
15 | Initialize the RMSNorm normalization layer.
16 |
17 | Args:
18 | dim (int): The dimension of the input tensor.
19 | elementwise_affine (bool, optional): If True, learn a per-channel scale `weight`. Default is True.
20 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
21 |
22 | Attributes:
23 | eps (float): A small value added to the denominator for numerical stability.
24 | weight (nn.Parameter): Learnable scaling parameter, present only if elementwise_affine is True.
25 | """
26 | factory_kwargs = {"device": device, "dtype": dtype}
27 | super().__init__()
28 | self.eps = eps
29 | if elementwise_affine:
30 | self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
31 |
32 | def _norm(self, x):
33 | """
34 | Apply the RMSNorm normalization to the input tensor.
35 |
36 | Args:
37 | x (torch.Tensor): The input tensor.
38 |
39 | Returns:
40 | torch.Tensor: The normalized tensor.
41 |
42 | """
43 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
44 |
45 | def forward(self, x):
46 | """
47 | Forward pass through the RMSNorm layer.
48 |
49 | Args:
50 | x (torch.Tensor): The input tensor.
51 |
52 | Returns:
53 | torch.Tensor: The output tensor after applying RMSNorm.
54 |
55 | """
56 | output = self._norm(x.float()).type_as(x)
57 | if hasattr(self, "weight"):
58 | output = output * self.weight
59 | return output
60 |
61 |
62 | def get_norm_layer(norm_layer):
63 | """
64 | Get the normalization layer.
65 |
66 | Args:
67 | norm_layer (str): The type of normalization layer.
68 |
69 | Returns:
70 | norm_layer (nn.Module): The normalization layer.
71 | """
72 | if norm_layer == "layer":
73 | return nn.LayerNorm
74 | elif norm_layer == "rms":
75 | return RMSNorm
76 | else:
77 | raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
78 |
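RMSNorm rescales each vector by the reciprocal root-mean-square of its last dimension, x / sqrt(mean(x**2) + eps), and, unlike LayerNorm, does not subtract the mean; the optional weight is a per-channel gain. The module computes the normalization in float32 and casts back to the input dtype before applying weight. A plain-Python sketch of the formula for a single vector (`rms_norm` is a hypothetical name):

```python
import math
from typing import List, Optional


def rms_norm(x: List[float], weight: Optional[List[float]] = None, eps: float = 1e-6) -> List[float]:
    """Hypothetical single-vector restatement of RMSNorm.forward."""
    rms_inv = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)  # like torch.rsqrt(...)
    out = [v * rms_inv for v in x]
    if weight is not None:
        out = [o * w for o, w in zip(out, weight)]  # per-channel gain
    return out


print(rms_norm([1.0, 1.0], eps=0.0))  # [1.0, 1.0]; LayerNorm would return zeros here
```

The constant input shows the practical difference: because RMSNorm keeps the mean, a constant vector survives normalization, whereas LayerNorm maps it to zero.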
--------------------------------------------------------------------------------
/hyvideo/modules/posemb_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Union, Tuple, List
3 |
4 |
5 | def _to_tuple(x, dim=2):
6 | if isinstance(x, int):
7 | return (x,) * dim
8 | elif len(x) == dim:
9 | return x
10 | else:
11 | raise ValueError(f"Expected length {dim} or int, but got {x}")
12 |
13 |
14 | def get_meshgrid_nd(start, *args, dim=2):
15 | """
16 | Get n-D meshgrid with start, stop and num.
17 |
18 | Args:
19 | start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
20 | step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
21 | should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
22 | n-tuples.
23 | *args: See above.
24 | dim (int): Dimension of the meshgrid. Defaults to 2.
25 |
26 | Returns:
27 | grid (torch.Tensor): [dim, ...]
28 | """
29 | if len(args) == 0:
30 | # start is grid_size
31 | num = _to_tuple(start, dim=dim)
32 | start = (0,) * dim
33 | stop = num
34 | elif len(args) == 1:
35 | # start is start, args[0] is stop, step is 1
36 | start = _to_tuple(start, dim=dim)
37 | stop = _to_tuple(args[0], dim=dim)
38 | num = [stop[i] - start[i] for i in range(dim)]
39 | elif len(args) == 2:
40 | # start is start, args[0] is stop, args[1] is num
41 | start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
42 | stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
43 | num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
44 | else:
45 | raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
46 |
47 | # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
48 | axis_grid = []
49 | for i in range(dim):
50 | a, b, n = start[i], stop[i], num[i]
51 | g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
52 | axis_grid.append(g)
53 | grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
54 | grid = torch.stack(grid, dim=0) # [dim, W, H, D]
55 |
56 | return grid
57 |
58 |
59 | #################################################################################
60 | # Rotary Positional Embedding Functions #
61 | #################################################################################
62 | # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
63 |
64 |
65 | def reshape_for_broadcast(
66 | freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
67 | x: torch.Tensor,
68 | head_first=False,
69 | ):
70 | """
71 | Reshape frequency tensor for broadcasting it with another tensor.
72 |
73 | This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
74 | for the purpose of broadcasting the frequency tensor during element-wise operations.
75 |
76 | Notes:
77 | When using FlashMHAModified, head_first should be False.
78 | When using Attention, head_first should be True.
79 |
80 | Args:
81 | freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
82 | x (torch.Tensor): Target tensor for broadcasting compatibility.
83 | head_first (bool): head dimension first (except batch dim) or not.
84 |
85 | Returns:
86 | torch.Tensor: Reshaped frequency tensor.
87 |
88 | Raises:
89 | AssertionError: If the frequency tensor doesn't match the expected shape.
90 | AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
91 | """
92 | ndim = x.ndim
93 | assert 0 <= 1 < ndim
94 |
95 | if isinstance(freqs_cis, tuple):
96 | # freqs_cis: (cos, sin) in real space
97 | if head_first:
98 | assert freqs_cis[0].shape == (
99 | x.shape[-2],
100 | x.shape[-1],
101 | ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
102 | shape = [
103 | d if i == ndim - 2 or i == ndim - 1 else 1
104 | for i, d in enumerate(x.shape)
105 | ]
106 | else:
107 | assert freqs_cis[0].shape == (
108 | x.shape[1],
109 | x.shape[-1],
110 | ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
111 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
112 | return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
113 | else:
114 | # freqs_cis: values in complex space
115 | if head_first:
116 | assert freqs_cis.shape == (
117 | x.shape[-2],
118 | x.shape[-1],
119 | ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
120 | shape = [
121 | d if i == ndim - 2 or i == ndim - 1 else 1
122 | for i, d in enumerate(x.shape)
123 | ]
124 | else:
125 | assert freqs_cis.shape == (
126 | x.shape[1],
127 | x.shape[-1],
128 | ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
129 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
130 | return freqs_cis.view(*shape)
131 |
132 |
133 | def rotate_half(x):
134 | x_real, x_imag = (
135 | x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
136 | ) # [B, S, H, D//2]
137 | return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
138 |
139 |
140 | def apply_rotary_emb(
141 | xq: torch.Tensor,
142 | xk: torch.Tensor,
143 | freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
144 | head_first: bool = False,
145 | ) -> Tuple[torch.Tensor, torch.Tensor]:
146 | """
147 | Apply rotary embeddings to input tensors using the given frequency tensor.
148 |
149 | This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
150 | frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
151 | is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
152 | returned as real tensors.
153 |
154 | Args:
155 | xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
156 | xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
157 | freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
158 | head_first (bool): head dimension first (except batch dim) or not.
159 |
160 | Returns:
161 | Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
162 |
163 | """
164 | xk_out = None
165 | if isinstance(freqs_cis, tuple):
166 | cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
167 | cos, sin = cos.to(xq.device), sin.to(xq.device)
168 | # real * cos - imag * sin
169 | # imag * cos + real * sin
170 | xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
171 | xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
172 | else:
173 | # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
174 | xq_ = torch.view_as_complex(
175 | xq.float().reshape(*xq.shape[:-1], -1, 2)
176 | ) # [B, S, H, D//2]
177 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
178 | xq.device
179 | ) # [S, D//2] --> [1, S, 1, D//2]
180 | # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
181 | # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
182 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
183 | xk_ = torch.view_as_complex(
184 | xk.float().reshape(*xk.shape[:-1], -1, 2)
185 | ) # [B, S, H, D//2]
186 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
187 |
188 | return xq_out, xk_out
189 |
190 |
191 | def get_nd_rotary_pos_embed(
192 | rope_dim_list,
193 | start,
194 | *args,
195 | theta=10000.0,
196 | use_real=False,
197 | theta_rescale_factor: Union[float, List[float]] = 1.0,
198 | interpolation_factor: Union[float, List[float]] = 1.0,
199 | ):
200 | """
201 | This is an n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
202 |
203 | Args:
204 | rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal n.
205 | sum(rope_dim_list) should equal the head_dim of the attention layer.
206 | start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
207 | args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
208 | *args: See above.
209 | theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
210 | use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
211 | Some libraries such as TensorRT do not support the complex64 data type, so it is useful to provide the real
212 | and imaginary parts separately.
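213 | theta_rescale_factor (float or list of float): Rescale factor for theta, per axis if a list. Defaults to 1.0.
214 | interpolation_factor (float or list of float): Factor multiplied into positions, per axis if a list. Defaults to 1.0.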
215 | Returns:
216 | pos_embed (torch.Tensor): [HW, D/2] complex tensor, or a tuple (cos, sin) of [HW, D] real tensors when use_real is True.
217 | """
218 |
219 | grid = get_meshgrid_nd(
220 | start, *args, dim=len(rope_dim_list)
221 | ) # [3, W, H, D] / [2, W, H]
222 |
223 | if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
224 | theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
225 | elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
226 | theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
227 | assert len(theta_rescale_factor) == len(
228 | rope_dim_list
229 | ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
230 |
231 | if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
232 | interpolation_factor = [interpolation_factor] * len(rope_dim_list)
233 | elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
234 | interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
235 | assert len(interpolation_factor) == len(
236 | rope_dim_list
237 | ), "len(interpolation_factor) should equal to len(rope_dim_list)"
238 |
239 | # use 1/ndim of dimensions to encode grid_axis
240 | embs = []
241 | for i in range(len(rope_dim_list)):
242 | emb = get_1d_rotary_pos_embed(
243 | rope_dim_list[i],
244 | grid[i].reshape(-1),
245 | theta,
246 | use_real=use_real,
247 | theta_rescale_factor=theta_rescale_factor[i],
248 | interpolation_factor=interpolation_factor[i],
249 | ) # 2 x [WHD, rope_dim_list[i]]
250 | embs.append(emb)
251 |
252 | if use_real:
253 | cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D)
254 | sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D)
255 | return cos, sin
256 | else:
257 | emb = torch.cat(embs, dim=1) # (WHD, D/2)
258 | return emb
259 |
260 |
261 | def get_1d_rotary_pos_embed(
262 | dim: int,
263 | pos: Union[torch.FloatTensor, int],
264 | theta: float = 10000.0,
265 | use_real: bool = False,
266 | theta_rescale_factor: float = 1.0,
267 | interpolation_factor: float = 1.0,
268 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
269 | """
270 | Precompute the frequency tensor for complex exponential (cis) with given dimensions.
271 | (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
272 |
273 | This function computes the angles pos * interpolation_factor * theta ** (-2k / dim), k = 0..dim/2-1,
274 | for the positions given by 'pos'; 'theta' is the base of the frequency progression.
275 | Without use_real, the result is a complex tensor (complex64 for float32 inputs).
276 |
277 | Args:
278 | dim (int): Dimension of the frequency tensor.
279 | pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
280 | theta (float, optional): Base for the frequency computation. Defaults to 10000.0.
281 | use_real (bool, optional): If True, return real part and imaginary part separately.
282 | Otherwise, return complex numbers.
283 | theta_rescale_factor (float, optional): NTK-style rescale factor for theta. Defaults to 1.0.
284 | interpolation_factor (float, optional): Multiplier applied to the positions. Defaults to 1.0.
285 | Returns:
286 | freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
287 | freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
288 | """
289 | if isinstance(pos, int):
290 | pos = torch.arange(pos).float()
291 |
292 | # NTK-aware scaling (proposed by Reddit user bloc97) rescales rotary embeddings to longer sequence lengths without fine-tuning:
293 | # the highest frequency is unchanged and the lowest is divided by theta_rescale_factor.
294 | if theta_rescale_factor != 1.0:
295 | theta *= theta_rescale_factor ** (dim / (dim - 2))
296 |
297 | freqs = 1.0 / (
298 | theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
299 | ) # [D/2]
300 | # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
301 | freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
302 | if use_real:
303 | freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
304 | freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
305 | return freqs_cos, freqs_sin
306 | else:
307 | freqs_cis = torch.polar(
308 | torch.ones_like(freqs), freqs
309 | ) # complex64 # [S, D/2]
310 | return freqs_cis
311 |
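The two functions above only build the RoPE tables. As a quick check of what they compute, here is a minimal pure-Python sketch (no torch; all helper names are hypothetical) of the angle formula, the NTK-style `theta` rescaling, the interleaved `use_real` layout and the per-axis concatenation done by `get_nd_rotary_pos_embed`. It assumes the tables are applied to adjacent channel pairs, which matches the `repeat_interleave(2, dim=1)` layout; the code that applies them is not part of this section.

```python
import math
from typing import List, Sequence, Tuple


def rope_angles(dim: int, pos: float, theta: float = 10000.0,
                theta_rescale_factor: float = 1.0,
                interpolation_factor: float = 1.0) -> List[float]:
    """Angles for one position, one per channel pair (dim // 2 values)."""
    if theta_rescale_factor != 1.0:
        # NTK-style rescaling of the base, as in get_1d_rotary_pos_embed.
        theta *= theta_rescale_factor ** (dim / (dim - 2))
    # freq_k = theta ** (-2k / dim); positions are scaled by interpolation_factor.
    return [pos * interpolation_factor / theta ** (2 * k / dim) for k in range(dim // 2)]


def cos_sin_interleaved(dim: int, pos: float) -> Tuple[List[float], List[float]]:
    """use_real=True layout: every angle repeated twice, so both lists have length dim."""
    angles = rope_angles(dim, pos)
    cos = [math.cos(a) for a in angles for _ in range(2)]
    sin = [math.sin(a) for a in angles for _ in range(2)]
    return cos, sin


def rotate(x: Sequence[float], pos: float) -> List[float]:
    """Rotate adjacent pairs (x[2k], x[2k+1]) as complex numbers by the k-th angle."""
    out: List[float] = []
    for k, a in enumerate(rope_angles(len(x), pos)):
        z = complex(x[2 * k], x[2 * k + 1]) * complex(math.cos(a), math.sin(a))
        out += [z.real, z.imag]
    return out


def nd_angles(rope_dim_list: Sequence[int], coords: Sequence[float]) -> List[float]:
    """Concatenate per-axis angles for one grid point, like get_nd_rotary_pos_embed."""
    angles: List[float] = []
    for dim, c in zip(rope_dim_list, coords):
        angles += rope_angles(dim, c)
    return angles


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


q = [0.3, -1.2, 0.5, 0.9, -0.4, 0.1, 2.0, -0.7]
k = [1.1, 0.2, -0.6, 0.4, 0.8, -1.5, 0.3, 0.05]
# Same relative offset (5 - 2 == 13 - 10), so both scores match up to rounding.
print(dot(rotate(q, 5), rotate(k, 2)), dot(rotate(q, 13), rotate(k, 10)))
```

The last line prints two numerically equal scores: after rotation, a query/key dot product depends only on the offset between positions, which is the property RoPE relies on.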
--------------------------------------------------------------------------------
/hyvideo/modules/token_refiner.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from einops import rearrange
4 | import torch
5 | import torch.nn as nn
6 |
7 | from .activation_layers import get_activation_layer
8 | from .attenion import attention
9 | from .norm_layers import get_norm_layer
10 | from .embed_layers import TimestepEmbedder, TextProjection
11 | from .mlp_layers import MLP
12 | from .modulate_layers import modulate, apply_gate
13 |
14 |
15 |
16 | class IndividualTokenRefinerBlock(nn.Module):
17 | def __init__(
18 | self,
19 | hidden_size,
20 | heads_num,
21 | mlp_width_ratio: float = 4.0,
22 | mlp_drop_rate: float = 0.0,
23 | act_type: str = "silu",
24 | qk_norm: bool = False,
25 | qk_norm_type: str = "layer",
26 | qkv_bias: bool = True,
27 | dtype: Optional[torch.dtype] = None,
28 | device: Optional[torch.device] = None,
29 | ):
30 | factory_kwargs = {"device": device, "dtype": dtype}
31 | super().__init__()
32 | self.heads_num = heads_num
33 | head_dim = hidden_size // heads_num
34 | mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
35 |
36 | self.norm1 = nn.LayerNorm(
37 | hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
38 | )
39 | self.self_attn_qkv = nn.Linear(
40 | hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
41 | )
42 | qk_norm_layer = get_norm_layer(qk_norm_type)
43 | self.self_attn_q_norm = (
44 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
45 | if qk_norm
46 | else nn.Identity()
47 | )
48 | self.self_attn_k_norm = (
49 | qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
50 | if qk_norm
51 | else nn.Identity()
52 | )
53 | self.self_attn_proj = nn.Linear(
54 | hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
55 | )
56 |
57 | self.norm2 = nn.LayerNorm(
58 | hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
59 | )
60 | act_layer = get_activation_layer(act_type)
61 | self.mlp = MLP(
62 | in_channels=hidden_size,
63 | hidden_channels=mlp_hidden_dim,
64 | act_layer=act_layer,
65 | drop=mlp_drop_rate,
66 | **factory_kwargs,
67 | )
68 |
69 | self.adaLN_modulation = nn.Sequential(
70 | act_layer(),
71 | nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
72 | )
73 | # Zero-initialize the modulation
74 | nn.init.zeros_(self.adaLN_modulation[1].weight)
75 | nn.init.zeros_(self.adaLN_modulation[1].bias)
76 |
77 | def forward(
78 | self,
79 | x: torch.Tensor,
80 | c: torch.Tensor, # timestep_aware_representations + context_aware_representations
81 | attn_mask: torch.Tensor = None,
82 | ):
83 | gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
84 |
85 | norm_x = self.norm1(x)
86 | qkv = self.self_attn_qkv(norm_x)
87 | q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
88 | # Apply QK-Norm if needed
89 | q = self.self_attn_q_norm(q).to(v)
90 | k = self.self_attn_k_norm(k).to(v)
91 |
92 | # Self-Attention
93 | attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
94 |
95 | x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
96 |
97 | # FFN Layer
98 | x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
99 |
100 | return x
101 |
102 |
103 | class IndividualTokenRefiner(nn.Module):
104 | def __init__(
105 | self,
106 | hidden_size,
107 | heads_num,
108 | depth,
109 | mlp_width_ratio: float = 4.0,
110 | mlp_drop_rate: float = 0.0,
111 | act_type: str = "silu",
112 | qk_norm: bool = False,
113 | qk_norm_type: str = "layer",
114 | qkv_bias: bool = True,
115 | dtype: Optional[torch.dtype] = None,
116 | device: Optional[torch.device] = None,
117 | ):
118 | factory_kwargs = {"device": device, "dtype": dtype}
119 | super().__init__()
120 | self.blocks = nn.ModuleList(
121 | [
122 | IndividualTokenRefinerBlock(
123 | hidden_size=hidden_size,
124 | heads_num=heads_num,
125 | mlp_width_ratio=mlp_width_ratio,
126 | mlp_drop_rate=mlp_drop_rate,
127 | act_type=act_type,
128 | qk_norm=qk_norm,
129 | qk_norm_type=qk_norm_type,
130 | qkv_bias=qkv_bias,
131 | **factory_kwargs,
132 | )
133 | for _ in range(depth)
134 | ]
135 | )
136 |
137 | def forward(
138 | self,
139 | x: torch.Tensor,
140 | c: torch.Tensor, # conditioning vector, [B, hidden_size]
141 | mask: Optional[torch.Tensor] = None,
142 | ):
143 | self_attn_mask = None
144 | if mask is not None:
145 | batch_size = mask.shape[0]
146 | seq_len = mask.shape[1]
147 | mask = mask.to(x.device)
148 | # batch_size x 1 x seq_len x seq_len
149 | self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
150 | 1, 1, seq_len, 1
151 | )
152 | # batch_size x 1 x seq_len x seq_len
153 | self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
154 | # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
155 | self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
156 | # avoids self-attention weight being NaN for padding tokens
157 | self_attn_mask[:, :, :, 0] = True
158 |
159 | for block in self.blocks:
160 | x = block(x, c, self_attn_mask)
161 | return x
162 |
163 |
164 | class SingleTokenRefiner(nn.Module):
165 | """
166 | Refines LLM text embeddings with self-attention blocks conditioned on the timestep and the pooled text context.
167 | """
168 | def __init__(
169 | self,
170 | in_channels,
171 | hidden_size,
172 | heads_num,
173 | depth,
174 | mlp_width_ratio: float = 4.0,
175 | mlp_drop_rate: float = 0.0,
176 | act_type: str = "silu",
177 | qk_norm: bool = False,
178 | qk_norm_type: str = "layer",
179 | qkv_bias: bool = True,
180 | attn_mode: str = "torch",
181 | dtype: Optional[torch.dtype] = None,
182 | device: Optional[torch.device] = None,
183 | ):
184 | factory_kwargs = {"device": device, "dtype": dtype}
185 | super().__init__()
186 | self.attn_mode = attn_mode
187 | assert self.attn_mode == "torch", "Only the 'torch' attention mode is supported by the token refiner."
188 |
189 | self.input_embedder = nn.Linear(
190 | in_channels, hidden_size, bias=True, **factory_kwargs
191 | )
192 |
193 | act_layer = get_activation_layer(act_type)
194 | # Build timestep embedding layer
195 | self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
196 | # Build context embedding layer
197 | self.c_embedder = TextProjection(
198 | in_channels, hidden_size, act_layer, **factory_kwargs
199 | )
200 |
201 | self.individual_token_refiner = IndividualTokenRefiner(
202 | hidden_size=hidden_size,
203 | heads_num=heads_num,
204 | depth=depth,
205 | mlp_width_ratio=mlp_width_ratio,
206 | mlp_drop_rate=mlp_drop_rate,
207 | act_type=act_type,
208 | qk_norm=qk_norm,
209 | qk_norm_type=qk_norm_type,
210 | qkv_bias=qkv_bias,
211 | **factory_kwargs,
212 | )
213 |
214 | def forward(
215 | self,
216 | x: torch.Tensor,
217 | t: torch.LongTensor,
218 | mask: Optional[torch.LongTensor] = None,
219 | ):
220 | timestep_aware_representations = self.t_embedder(t)
221 |
222 | if mask is None:
223 | context_aware_representations = x.mean(dim=1)
224 | else:
225 | mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
226 | context_aware_representations = (x * mask_float).sum(
227 | dim=1
228 | ) / mask_float.sum(dim=1)
229 | context_aware_representations = self.c_embedder(context_aware_representations)
230 | c = timestep_aware_representations + context_aware_representations
231 |
232 | x = self.input_embedder(x)
233 |
234 | x = self.individual_token_refiner(x, c, mask)
235 |
236 | return x
237 |
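The masking in `IndividualTokenRefiner.forward` and the masked mean in `SingleTokenRefiner.forward` are easiest to follow on a tiny example. The sketch below restates them in pure Python for a single sample (function names are hypothetical; the real code operates on batched tensors and adds a broadcast axis for heads).

```python
from typing import List, Sequence


def build_self_attn_mask(mask: Sequence[int]) -> List[List[bool]]:
    """Pairwise mask for one sample: query i may attend to key j iff both are valid tokens."""
    n = len(mask)
    attn = [[bool(mask[i]) and bool(mask[j]) for j in range(n)] for i in range(n)]
    for row in attn:
        # Like self_attn_mask[:, :, :, 0] = True: every query (padding included) can see
        # key 0, so no softmax row is fully masked, which would otherwise produce NaN.
        row[0] = True
    return attn


def masked_mean(x: Sequence[Sequence[float]], mask: Sequence[int]) -> List[float]:
    """Average token vectors over valid positions only (mask == 1)."""
    count = sum(mask)
    dim = len(x[0])
    return [sum(x[t][d] * mask[t] for t in range(len(x))) / count for d in range(dim)]


mask = [1, 1, 0]
for row in build_self_attn_mask(mask):
    print(row)
# [True, True, False]
# [True, True, False]
# [True, False, False]
print(masked_mean([[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]], mask))  # [2.0, 3.0]
```

A related design point: `adaLN_modulation` is zero-initialized, so both gates start at zero and, assuming `apply_gate` multiplies its input by the gate, each refiner block begins training as an identity map.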
--------------------------------------------------------------------------------
/hyvideo/prompt_rewrite.py:
--------------------------------------------------------------------------------
1 | normal_mode_prompt = """Normal mode - Video Recaption Task:
2 |
3 | You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description.
4 |
5 | 0. Preserve ALL information, including style words and technical terms.
6 |
7 | 1. If the input is in Chinese, translate the entire description to English.
8 |
9 | 2. If the input is just one or two words describing an object or person, provide a brief, simple description focusing on basic visual characteristics. Limit the description to 1-2 short sentences.
10 |
11 | 3. If the input does not include style, lighting, atmosphere, you can make reasonable associations.
12 |
13 | 4. Output ALL must be in English.
14 |
15 | Given Input:
16 | input: "{input}"
17 | """
18 |
19 |
20 | master_mode_prompt = """Master mode - Video Recaption Task:
21 |
22 | You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description.
23 |
24 | 0. Preserve ALL information, including style words and technical terms.
25 |
26 | 1. If the input is in Chinese, translate the entire description to English.
27 |
28 | 2. To generate high-quality visual scenes with aesthetic appeal, it is necessary to carefully depict each visual element to create a unique aesthetic.
29 |
30 | 3. If the input does not include style, lighting, atmosphere, you can make reasonable associations.
31 |
32 | 4. Output ALL must be in English.
33 |
34 | Given Input:
35 | input: "{input}"
36 | """
37 |
38 | def get_rewrite_prompt(ori_prompt, mode="Normal"):
39 | if mode == "Normal":
40 | prompt = normal_mode_prompt.format(input=ori_prompt)
41 | elif mode == "Master":
42 | prompt = master_mode_prompt.format(input=ori_prompt)
43 | else:
44 | raise ValueError(f"Unsupported mode {mode!r}; expected 'Normal' or 'Master'.")
45 | return prompt
46 |
47 | ori_prompt = "一只小狗在草地上奔跑。" # "A puppy is running on the grass."
48 | normal_prompt = get_rewrite_prompt(ori_prompt, mode="Normal")
49 | master_prompt = get_rewrite_prompt(ori_prompt, mode="Master")
50 |
51 | # Pass normal_prompt or master_prompt to the Hunyuan-Large rewrite model to obtain the final prompt.
52 |
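`get_rewrite_prompt` selects a template and fills in `{input}`. A minimal alternative sketch keeps the supported modes in one dict, so the error message cannot drift out of sync with the accepted values. The templates here are shortened placeholders, not the real prompts.

```python
from typing import Dict

# Hypothetical, shortened templates standing in for normal_mode_prompt / master_mode_prompt.
TEMPLATES: Dict[str, str] = {
    "Normal": 'Normal mode - Video Recaption Task:\n...\ninput: "{input}"\n',
    "Master": 'Master mode - Video Recaption Task:\n...\ninput: "{input}"\n',
}


def get_rewrite_prompt(ori_prompt: str, mode: str = "Normal") -> str:
    """Look up the template for `mode` and substitute the user prompt."""
    try:
        template = TEMPLATES[mode]
    except KeyError:
        raise ValueError(f"Unsupported mode {mode!r}; expected one of {sorted(TEMPLATES)}") from None
    return template.format(input=ori_prompt)


print(get_rewrite_prompt("A puppy is running on the grass.", mode="Master"))
```

Because `str.format` is used, braces inside the user prompt are inserted verbatim; only literal braces in a template itself would need to be doubled (`{{ }}`).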
--------------------------------------------------------------------------------
/hyvideo/text_encoder/__init__.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional, Tuple
3 | from copy import deepcopy
4 |
5 | import torch
6 | import torch.nn as nn
7 | from transformers import CLIPTextModel, CLIPTokenizer, AutoTokenizer, AutoModel
8 | from transformers.utils import ModelOutput
9 |
10 | from ..constants import TEXT_ENCODER_PATH, TOKENIZER_PATH
11 | from ..constants import PRECISION_TO_TYPE
12 |
13 |
14 | def use_default(value, default):
15 | return value if value is not None else default
16 |
17 |
18 | def load_text_encoder(
19 | text_encoder_type,
20 | text_encoder_precision=None,
21 | text_encoder_path=None,
22 | logger=None,
23 | device=None,
24 | ):
25 | if text_encoder_path is None:
26 | text_encoder_path = TEXT_ENCODER_PATH[text_encoder_type]
27 | if logger is not None:
28 | logger.info(
29 | f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}"
30 | )
31 |
32 | if text_encoder_type == "clipL":
33 | text_encoder = CLIPTextModel.from_pretrained(text_encoder_path)
34 | text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm
35 | elif text_encoder_type == "llm":
36 | text_encoder = AutoModel.from_pretrained(
37 | text_encoder_path, low_cpu_mem_usage=True
38 | )
39 | text_encoder.final_layer_norm = text_encoder.norm
40 | else:
41 | raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
42 | # from_pretrained will ensure that the model is in eval mode.
43 |
44 | if text_encoder_precision is not None:
45 | text_encoder = text_encoder.to(dtype=PRECISION_TO_TYPE[text_encoder_precision])
46 |
47 | text_encoder.requires_grad_(False)
48 |
49 | if logger is not None:
50 | logger.info(f"Text encoder to dtype: {text_encoder.dtype}")
51 |
52 | if device is not None:
53 | text_encoder = text_encoder.to(device)
54 |
55 | return text_encoder, text_encoder_path
56 |
57 |
58 | def load_tokenizer(
59 | tokenizer_type, tokenizer_path=None, padding_side="right", logger=None
60 | ):
61 | if tokenizer_path is None:
62 | tokenizer_path = TOKENIZER_PATH[tokenizer_type]
63 | if logger is not None:
64 | logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}")
65 |
66 | if tokenizer_type == "clipL":
67 | tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77)
68 | elif tokenizer_type == "llm":
69 | tokenizer = AutoTokenizer.from_pretrained(
70 | tokenizer_path, padding_side=padding_side
71 | )
72 | else:
73 | raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
74 |
75 | return tokenizer, tokenizer_path
76 |
77 |
78 | @dataclass
79 | class TextEncoderModelOutput(ModelOutput):
80 | """
81 | Output of `TextEncoder.encode`. No pooling is applied to the hidden states.
82 |
83 | Args:
84 | hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
85 | Hidden states of the selected layer (the last layer, or an earlier one if `hidden_state_skip_layer` is set).
86 | attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
87 | Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``.
88 | hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
89 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
90 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
91 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
92 | text_outputs (`list`, *optional*, returned when `return_texts=True` is passed):
93 | List of decoded texts.
94 | """
95 |
96 | hidden_state: torch.FloatTensor = None
97 | attention_mask: Optional[torch.LongTensor] = None
98 | hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
99 | text_outputs: Optional[list] = None
100 |
101 |
102 | class TextEncoder(nn.Module):
103 | def __init__(
104 | self,
105 | text_encoder_type: str,
106 | max_length: int,
107 | text_encoder_precision: Optional[str] = None,
108 | text_encoder_path: Optional[str] = None,
109 | tokenizer_type: Optional[str] = None,
110 | tokenizer_path: Optional[str] = None,
111 | output_key: Optional[str] = None,
112 | use_attention_mask: bool = True,
113 | input_max_length: Optional[int] = None,
114 | prompt_template: Optional[dict] = None,
115 | prompt_template_video: Optional[dict] = None,
116 | hidden_state_skip_layer: Optional[int] = None,
117 | apply_final_norm: bool = False,
118 | reproduce: bool = False,
119 | logger=None,
120 | device=None,
121 | ):
122 | super().__init__()
123 | self.text_encoder_type = text_encoder_type
124 | self.max_length = max_length
125 | self.precision = text_encoder_precision
126 | self.model_path = text_encoder_path
127 | self.tokenizer_type = (
128 | tokenizer_type if tokenizer_type is not None else text_encoder_type
129 | )
130 | self.tokenizer_path = (
131 | tokenizer_path if tokenizer_path is not None else text_encoder_path
132 | )
133 | self.use_attention_mask = use_attention_mask
134 | if prompt_template_video is not None:
135 | assert (
136 | use_attention_mask is True
137 | ), "use_attention_mask must be True when prompt_template_video is set."
138 | self.input_max_length = (
139 | input_max_length if input_max_length is not None else max_length
140 | )
141 | self.prompt_template = prompt_template
142 | self.prompt_template_video = prompt_template_video
143 | self.hidden_state_skip_layer = hidden_state_skip_layer
144 | self.apply_final_norm = apply_final_norm
145 | self.reproduce = reproduce
146 | self.logger = logger
147 |
148 | self.use_template = self.prompt_template is not None
149 | if self.use_template:
150 | assert (
151 | isinstance(self.prompt_template, dict)
152 | and "template" in self.prompt_template
153 | ), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}"
154 | assert "{}" in str(self.prompt_template["template"]), (
155 | "`prompt_template['template']` must contain a placeholder `{}` for the input text, "
156 | f"got {self.prompt_template['template']}"
157 | )
158 |
159 | self.use_video_template = self.prompt_template_video is not None
160 | if self.use_video_template:
161 | if self.prompt_template_video is not None:
162 | assert (
163 | isinstance(self.prompt_template_video, dict)
164 | and "template" in self.prompt_template_video
165 | ), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}"
166 | assert "{}" in str(self.prompt_template_video["template"]), (
167 | "`prompt_template_video['template']` must contain a placeholder `{}` for the input text, "
168 | f"got {self.prompt_template_video['template']}"
169 | )
170 |
171 | if "t5" in text_encoder_type:
172 | self.output_key = output_key or "last_hidden_state"
173 | elif "clip" in text_encoder_type:
174 | self.output_key = output_key or "pooler_output"
175 | elif "llm" in text_encoder_type or "glm" in text_encoder_type:
176 | self.output_key = output_key or "last_hidden_state"
177 | else:
178 | raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
179 |
180 | self.model, self.model_path = load_text_encoder(
181 | text_encoder_type=self.text_encoder_type,
182 | text_encoder_precision=self.precision,
183 | text_encoder_path=self.model_path,
184 | logger=self.logger,
185 | device=device,
186 | )
187 | self.dtype = self.model.dtype
188 | self.device = self.model.device
189 |
190 | self.tokenizer, self.tokenizer_path = load_tokenizer(
191 | tokenizer_type=self.tokenizer_type,
192 | tokenizer_path=self.tokenizer_path,
193 | padding_side="right",
194 | logger=self.logger,
195 | )
196 |
197 | def __repr__(self):
198 | return f"{self.text_encoder_type} ({self.precision} - {self.model_path})"
199 |
200 | @staticmethod
201 | def apply_text_to_template(text, template, prevent_empty_text=True):
202 | """
203 | Apply text to template.
204 |
205 | Args:
206 | text (str): Input text.
207 | template (str or list): Template string or list of chat conversation.
208 | prevent_empty_text (bool): Intended to keep the user text non-empty by adding a space;
209 | not used by the current implementation. Defaults to True.
210 | """
211 | if isinstance(template, str):
212 | # Will send string to tokenizer. Used for llm
213 | return template.format(text)
214 | else:
215 | raise TypeError(f"Unsupported template type: {type(template)}")
216 |
217 | def text2tokens(self, text, data_type="image"):
218 | """
219 | Tokenize the input text.
220 |
221 | Args:
222 | text (str or list): Input text.
223 | """
224 | tokenize_input_type = "str"
225 | if self.use_template:
226 | if data_type == "image":
227 | prompt_template = self.prompt_template["template"]
228 | elif data_type == "video":
229 | prompt_template = self.prompt_template_video["template"]
230 | else:
231 | raise ValueError(f"Unsupported data type: {data_type}")
232 | if isinstance(text, (list, tuple)):
233 | text = [
234 | self.apply_text_to_template(one_text, prompt_template)
235 | for one_text in text
236 | ]
237 | if isinstance(text[0], list):
238 | tokenize_input_type = "list"
239 | elif isinstance(text, str):
240 | text = self.apply_text_to_template(text, prompt_template)
241 | if isinstance(text, list):
242 | tokenize_input_type = "list"
243 | else:
244 | raise TypeError(f"Unsupported text type: {type(text)}")
245 |
246 | kwargs = dict(
247 | truncation=True,
248 | max_length=self.max_length,
249 | padding="max_length",
250 | return_tensors="pt",
251 | )
252 | if tokenize_input_type == "str":
253 | return self.tokenizer(
254 | text,
255 | return_length=False,
256 | return_overflowing_tokens=False,
257 | return_attention_mask=True,
258 | **kwargs,
259 | )
260 | elif tokenize_input_type == "list":
261 | return self.tokenizer.apply_chat_template(
262 | text,
263 | add_generation_prompt=True,
264 | tokenize=True,
265 | return_dict=True,
266 | **kwargs,
267 | )
268 | else:
269 | raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")
270 |
271 | def encode(
272 | self,
273 | batch_encoding,
274 | use_attention_mask=None,
275 | output_hidden_states=False,
276 | do_sample=None,
277 | hidden_state_skip_layer=None,
278 | return_texts=False,
279 | data_type="image",
280 | device=None,
281 | ):
282 | """
283 | Args:
284 | batch_encoding (dict): Batch encoding from tokenizer.
285 | use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask.
286 | Defaults to None.
287 | output_hidden_states (bool): Whether to output hidden states. If False, return the value of
288 | self.output_key. If True, return the entire output. If set self.hidden_state_skip_layer,
289 | output_hidden_states will be set True. Defaults to False.
290 | do_sample (bool): Sampling flag for decoder-only LLMs; resolved here but not used by this method. Defaults to None.
291 | When self.reproduce is False, do_sample defaults to True.
292 | hidden_state_skip_layer (int): Number of final layers to skip when selecting the hidden state; 0 means the last layer.
293 | If None, falls back to self.hidden_state_skip_layer, and then to self.output_key. Defaults to None.
294 | return_texts (bool): Accepted for interface compatibility; decoded texts are not returned. Defaults to False.
295 | """
296 | device = self.model.device if device is None else device
297 | use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
298 | hidden_state_skip_layer = use_default(
299 | hidden_state_skip_layer, self.hidden_state_skip_layer
300 | )
301 | do_sample = use_default(do_sample, not self.reproduce)
302 | attention_mask = (
303 | batch_encoding["attention_mask"].to(device) if use_attention_mask else None
304 | )
305 | outputs = self.model(
306 | input_ids=batch_encoding["input_ids"].to(device),
307 | attention_mask=attention_mask,
308 | output_hidden_states=output_hidden_states
309 | or hidden_state_skip_layer is not None,
310 | )
311 | if hidden_state_skip_layer is not None:
312 | last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
313 | # Real last hidden state already has layer norm applied. So here we only apply it
314 | # for intermediate layers.
315 | if hidden_state_skip_layer > 0 and self.apply_final_norm:
316 | last_hidden_state = self.model.final_layer_norm(last_hidden_state)
317 | else:
318 | last_hidden_state = outputs[self.output_key]
319 |
320 | # Remove hidden states of instruction tokens, only keep prompt tokens.
321 | if self.use_template:
322 | if data_type == "image":
323 | crop_start = self.prompt_template.get("crop_start", -1)
324 | elif data_type == "video":
325 | crop_start = self.prompt_template_video.get("crop_start", -1)
326 | else:
327 | raise ValueError(f"Unsupported data type: {data_type}")
328 | if crop_start > 0:
329 | last_hidden_state = last_hidden_state[:, crop_start:]
330 | attention_mask = (
331 | attention_mask[:, crop_start:] if use_attention_mask else None
332 | )
333 |
334 | if output_hidden_states:
335 | return TextEncoderModelOutput(
336 | last_hidden_state, attention_mask, outputs.hidden_states
337 | )
338 | return TextEncoderModelOutput(last_hidden_state, attention_mask)
339 |
340 | def forward(
341 | self,
342 | text,
343 | use_attention_mask=None,
344 | output_hidden_states=False,
345 | do_sample=False,
346 | hidden_state_skip_layer=None,
347 | return_texts=False,
348 | ):
349 | batch_encoding = self.text2tokens(text)
350 | return self.encode(
351 | batch_encoding,
352 | use_attention_mask=use_attention_mask,
353 | output_hidden_states=output_hidden_states,
354 | do_sample=do_sample,
355 | hidden_state_skip_layer=hidden_state_skip_layer,
356 | return_texts=return_texts,
357 | )
358 |
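`TextEncoder.encode` combines three steps: the prompt is wrapped in a template and padded to `max_length`, one hidden-state layer is picked with `hidden_state_skip_layer`, and the first `crop_start` positions (the template prefix) are dropped from both the states and the mask. The sketch below mimics that bookkeeping with strings. The whitespace tokenizer and the template are hypothetical stand-ins; in the real configuration `crop_start` is supplied by the prompt-template dict and counted in tokenizer tokens.

```python
from typing import List, Sequence, Tuple

PAD = "<pad>"


def tokenize(text: str, max_length: int) -> Tuple[List[str], List[int]]:
    """Hypothetical whitespace tokenizer: truncate, then right-pad to max_length."""
    tokens = text.split()[:max_length]
    n_pad = max_length - len(tokens)
    return tokens + [PAD] * n_pad, [1] * len(tokens) + [0] * n_pad


def select_hidden_state(hidden_states: Sequence, skip: int):
    """hidden_states = embedding output plus one entry per layer; skip=0 is the last layer."""
    return hidden_states[-(skip + 1)]


def crop(tokens: List[str], mask: List[int], crop_start: int):
    """Drop the template prefix so only the user-prompt positions remain."""
    if crop_start > 0:
        return tokens[crop_start:], mask[crop_start:]
    return tokens, mask


template = "Describe the video: {}"  # hypothetical template
crop_start = len(template.split("{}")[0].split())  # 3 prefix tokens for this toy tokenizer
tokens, mask = tokenize(template.format("a cat jumps"), max_length=8)
print(crop(tokens, mask, crop_start))
# (['a', 'cat', 'jumps', '<pad>', '<pad>'], [1, 1, 1, 0, 0])
print(select_hidden_state(["emb", "layer1", "layer2", "layer3"], skip=2))  # layer1
```

Since padding uses `padding="max_length"`, the cropped sequence always has length `max_length - crop_start`. Note also that `encode` applies `final_layer_norm` only when `hidden_state_skip_layer > 0` and `apply_final_norm` is set, because the last layer's output is already normalized.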
--------------------------------------------------------------------------------
/hyvideo/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tencent-Hunyuan/HunyuanVideo/757d06f2696fd7a457b888c092f5d57d086f51eb/hyvideo/utils/__init__.py
--------------------------------------------------------------------------------
/hyvideo/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 |
4 |
5 | def align_to(value, alignment):
6 | """Round a height or width up to the nearest multiple of `alignment`.
7 |
8 | Args:
9 | value (int): height or width
10 | alignment (int): target alignment factor
11 |
12 | Returns:
13 | int: the aligned value
14 | """
15 | return int(math.ceil(value / alignment) * alignment)
16 |
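`align_to` rounds up to the next multiple of `alignment`. The sketch below compares it with an integer-only formulation (a hypothetical alternative, not used by the repository) that avoids the float division. The two agree for ordinary image sizes; the float version can only lose precision for values far larger than any image dimension (beyond 2**53).

```python
import math


def align_to(value: int, alignment: int) -> int:
    """Float-based rounding up, as in data_utils.align_to."""
    return int(math.ceil(value / alignment) * alignment)


def align_to_int(value: int, alignment: int) -> int:
    """Integer-only equivalent: ceiling division via floor division of the negation."""
    return -(-value // alignment) * alignment


for v in (0, 1, 544, 545, 720, 1280):
    print(v, "->", align_to(v, 16), align_to_int(v, 16))
# 545 -> 560, every other value above is already a multiple of 16 (or rounds 1 up to 16)
```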
--------------------------------------------------------------------------------
/hyvideo/utils/file_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from einops import rearrange
4 |
5 | import torch
6 | import torchvision
7 | import numpy as np
8 | import imageio
9 |
10 | CODE_SUFFIXES = {
11 | ".py", # Python codes
12 | ".sh", # Shell scripts
13 | ".yaml",
14 | ".yml", # Configuration files
15 | }
16 |
17 |
18 | def safe_dir(path):
19 | """
20 | Create a directory, including any missing parents, if it does not exist.
21 |
22 | Args:
23 | path (str or Path): Path to the directory.
24 |
25 | Returns:
26 | path (Path): Path object of the directory.
27 | """
28 | path = Path(path)
29 | path.mkdir(exist_ok=True, parents=True)
30 | return path
31 |
32 |
33 | def safe_file(path):
34 | """
35 | Create the parent directory of a file if it does not exist.
36 |
37 | Args:
38 | path (str or Path): Path to the file.
39 |
40 | Returns:
41 | path (Path): Path object of the file.
42 | """
43 | path = Path(path)
44 | path.parent.mkdir(exist_ok=True, parents=True)
45 | return path
46 |
47 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24):
48 | """Save a batch of videos as a single grid animation, one grid image per frame.
49 | Adapted from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61
50 |
51 | Args:
52 | videos (torch.Tensor): CPU video tensor of shape [B, C, T, H, W] predicted by the model
53 | path (str): path to save the video; imageio picks the format from the file extension
54 | rescale (bool, optional): rescale the video tensor from [-1, 1] to [0, 1]. Defaults to False.
55 | n_rows (int, optional): number of videos per grid row (passed to make_grid as nrow). Defaults to 1.
56 | fps (int, optional): frames per second of the saved video. Defaults to 24.
57 | """
58 | videos = rearrange(videos, "b c t h w -> t b c h w")
59 | outputs = []
60 | for x in videos:
61 | x = torchvision.utils.make_grid(x, nrow=n_rows)
62 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
63 | if rescale:
64 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
65 | x = torch.clamp(x, 0, 1)
66 | x = (x * 255).numpy().astype(np.uint8)
67 | outputs.append(x)
68 |
69 | safe_file(path) # create the parent directory; unlike os.makedirs(""), this also works for a bare filename
70 | imageio.mimsave(path, outputs, fps=fps)
71 |
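`save_videos_grid` loops over frames after moving the time axis first, and converts each pixel with an optional rescale, a clamp and a truncating cast. A pure-Python sketch of those two steps (helper names are hypothetical; frames are treated as opaque objects) makes the behaviour concrete.

```python
from typing import List, Sequence


def to_uint8(v: float, rescale: bool = False) -> int:
    """Per-pixel conversion: optional [-1, 1] -> [0, 1], clamp, scale to 255, truncate."""
    if rescale:
        v = (v + 1.0) / 2.0
    v = min(max(v, 0.0), 1.0)
    # astype(np.uint8) also truncates toward zero for these non-negative values.
    return int(v * 255)


def bcthw_to_tbchw(video: Sequence[Sequence[Sequence[object]]]) -> List[List[List[object]]]:
    """Same axis order as rearrange(videos, "b c t h w -> t b c h w"); frames stay opaque."""
    b, c, t = len(video), len(video[0]), len(video[0][0])
    return [[[video[bi][ci][ti] for ci in range(c)] for bi in range(b)] for ti in range(t)]


print([to_uint8(v, rescale=True) for v in (-1.5, -1.0, 0.0, 1.0)])  # [0, 0, 127, 255]
```

Note the truncation: a value of exactly 0 in [-1, 1] becomes 127, not 128.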
--------------------------------------------------------------------------------
/hyvideo/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import collections.abc
2 |
3 | from itertools import repeat
4 |
5 |
6 | def _ntuple(n):
7 | def parse(x):
8 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
9 | x = tuple(x)
10 | if len(x) == 1:
11 | x = tuple(repeat(x[0], n))
12 | return x
13 | return tuple(repeat(x, n))
14 | return parse
15 |
16 |
17 | to_1tuple = _ntuple(1)
18 | to_2tuple = _ntuple(2)
19 | to_3tuple = _ntuple(3)
20 | to_4tuple = _ntuple(4)
21 |
22 |
23 | def as_tuple(x):
24 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
25 | return tuple(x)
26 | if x is None or isinstance(x, (int, float, str)):
27 | return (x,)
28 | else:
29 | raise ValueError(f"Unknown type {type(x)}")
30 |
31 |
32 | def as_list_of_2tuple(x):
33 | x = as_tuple(x)
34 | if len(x) == 1:
35 | x = (x[0], x[0])
36 | assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
37 | lst = []
38 | for i in range(0, len(x), 2):
39 | lst.append((x[i], x[i + 1]))
40 | return lst
41 |
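The helpers above have a few edge cases worth knowing: strings are treated as scalars, single-element iterables are broadcast, and `_ntuple` does not check the length of longer iterables. The sketch below restates the same logic so it can run on its own, and adds a hypothetical `to_ntuple_strict` that rejects wrong lengths.

```python
import collections.abc
from itertools import repeat


def _ntuple(n):
    # Same behaviour as helpers._ntuple: iterables (except str) become tuples,
    # length-1 iterables are broadcast to n, scalars are repeated n times.
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            x = tuple(x)
            if len(x) == 1:
                x = tuple(repeat(x[0], n))
            return x
        return tuple(repeat(x, n))
    return parse


to_2tuple = _ntuple(2)


def as_tuple(x):
    if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
        return tuple(x)
    if x is None or isinstance(x, (int, float, str)):
        return (x,)
    raise ValueError(f"Unknown type {type(x)}")


def as_list_of_2tuple(x):
    x = as_tuple(x)
    if len(x) == 1:
        x = (x[0], x[0])
    assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
    return [(x[i], x[i + 1]) for i in range(0, len(x), 2)]


def to_ntuple_strict(x, n):
    """Hypothetical stricter variant that also rejects iterables of the wrong length."""
    out = _ntuple(n)(x)
    if len(out) != n:
        raise ValueError(f"expected {n} values, got {len(out)}")
    return out


print(to_2tuple(3), to_2tuple([5]), to_2tuple((1, 2)), to_2tuple("ab"))
# (3, 3) (5, 5) (1, 2) ('ab', 'ab')
print(to_2tuple((1, 2, 3)))  # (1, 2, 3): the length is not checked
```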
--------------------------------------------------------------------------------
/hyvideo/utils/preprocess_text_encoder_tokenizer_utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from transformers import (
4 | AutoProcessor,
5 | LlavaForConditionalGeneration,
6 | )
7 |
8 |
9 | def preprocess_text_encoder_tokenizer(args):
10 |
11 | processor = AutoProcessor.from_pretrained(args.input_dir)
12 | model = LlavaForConditionalGeneration.from_pretrained(
13 | args.input_dir,
14 | torch_dtype=torch.float16,
15 | low_cpu_mem_usage=True,
16 | ).to(0)
17 |
18 | model.language_model.save_pretrained(
19 | f"{args.output_dir}"
20 | )
21 | processor.tokenizer.save_pretrained(
22 | f"{args.output_dir}"
23 | )
24 |
25 | if __name__ == "__main__":
26 |
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument(
29 | "--input_dir",
30 | type=str,
31 | required=True,
32 | help="The path to the llava-llama-3-8b-v1_1-transformers.",
33 | )
34 | parser.add_argument(
35 | "--output_dir",
36 | type=str,
37 | default="",
38 | help="The output path of the llava-llama-3-8b-text-encoder-tokenizer. "
39 | "If empty, it defaults to the parent directory of --input_dir.",
40 | )
41 | args = parser.parse_args()
42 |
43 | if len(args.output_dir) == 0:
44 | args.output_dir = "/".join(args.input_dir.rstrip("/").split("/")[:-1]) # ignore a trailing slash
45 |
46 | preprocess_text_encoder_tokenizer(args)
47 |
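Deriving the default `--output_dir` with a plain split on `/` is fragile: for a path with a trailing slash it returns the input directory itself, and for a bare directory name it returns an empty string. A minimal pathlib-based sketch (hypothetical helper names; example paths are illustrative) avoids both cases.

```python
from pathlib import PurePosixPath


def naive_parent(input_dir: str) -> str:
    # Plain split-and-join on "/".
    return "/".join(input_dir.split("/")[:-1])


def default_output_dir(input_dir: str) -> str:
    # pathlib ignores a trailing slash and returns "." for a bare name.
    return str(PurePosixPath(input_dir).parent)


for d in ("ckpts/llava-llama-3-8b-v1_1-transformers",
          "ckpts/llava-llama-3-8b-v1_1-transformers/",
          "model"):
    print(repr(naive_parent(d)), repr(default_output_dir(d)))
```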
--------------------------------------------------------------------------------
/hyvideo/vae/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 |
5 | from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
6 | from ..constants import VAE_PATH, PRECISION_TO_TYPE
7 |
8 | def load_vae(vae_type: str="884-16c-hy",
9 | vae_precision: str=None,
10 | sample_size: tuple=None,
11 | vae_path: str=None,
12 | logger=None,
13 | device=None
14 | ):
15 | """Load the 3D VAE model.
16 |
17 | Args:
18 | vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy".
19 | vae_precision (str, optional): the precision to load vae. Defaults to None.
20 | sample_size (tuple, optional): the tiling size. Defaults to None.
21 | vae_path (str, optional): the path to vae. Defaults to None.
22 | logger (_type_, optional): logger. Defaults to None.
23 | device (_type_, optional): device to load vae. Defaults to None.
24 | """
25 | if vae_path is None:
26 | vae_path = VAE_PATH[vae_type]
27 |
28 | if logger is not None:
29 | logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")
30 | config = AutoencoderKLCausal3D.load_config(vae_path)
31 | if sample_size:
32 | vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
33 | else:
34 | vae = AutoencoderKLCausal3D.from_config(config)
35 |
36 | vae_ckpt = Path(vae_path) / "pytorch_model.pt"
37 | assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}"
38 |
39 | ckpt = torch.load(vae_ckpt, map_location=vae.device)
40 | if "state_dict" in ckpt:
41 | ckpt = ckpt["state_dict"]
42 | if any(k.startswith("vae.") for k in ckpt.keys()):
43 | ckpt = {k[len("vae."):]: v for k, v in ckpt.items() if k.startswith("vae.")}
44 | vae.load_state_dict(ckpt)
45 |
46 | spatial_compression_ratio = vae.config.spatial_compression_ratio
47 | time_compression_ratio = vae.config.time_compression_ratio
48 |
49 | if vae_precision is not None:
50 | vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision])
51 |
52 | vae.requires_grad_(False)
53 |
54 | if logger is not None:
55 | logger.info(f"VAE to dtype: {vae.dtype}")
56 |
57 | if device is not None:
58 | vae = vae.to(device)
59 |
60 | vae.eval()
61 |
62 | return vae, vae_path, spatial_compression_ratio, time_compression_ratio
63 |
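`load_vae` accepts a checkpoint that is either a bare state dict or wrapped under `"state_dict"`, and whose VAE weights may all carry a `vae.` prefix (presumably because the VAE was saved as an attribute named `vae` of a larger module). A plain-Python sketch of that normalization, with integers standing in for tensors and a hypothetical helper name:

```python
from typing import Any, Dict


def normalize_vae_state_dict(ckpt: Dict[str, Any], prefix: str = "vae.") -> Dict[str, Any]:
    """Unwrap and de-prefix a checkpoint the way load_vae does (hypothetical helper)."""
    # 1. Unwrap {"state_dict": {...}} checkpoints.
    if "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]
    # 2. If any key is prefixed, keep only prefixed keys and strip the leading prefix.
    if any(k.startswith(prefix) for k in ckpt):
        # Slicing only removes the leading prefix; str.replace would also rewrite
        # any later occurrence of "vae." inside the key.
        ckpt = {k[len(prefix):]: v for k, v in ckpt.items() if k.startswith(prefix)}
    return ckpt


ckpt = {"state_dict": {"vae.encoder.conv_in.weight": 1, "vae.decoder.conv_out.bias": 2, "dit.blocks.0.weight": 3}}
print(normalize_vae_state_dict(ckpt))
# {'encoder.conv_in.weight': 1, 'decoder.conv_out.bias': 2}
```

Keys without the prefix (here the `dit.` entry) are dropped once any prefixed key is present, which matches the filter in `load_vae`.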
--------------------------------------------------------------------------------
/hyvideo/vae/vae.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional, Tuple
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 |
8 | from diffusers.utils import BaseOutput, is_torch_version
9 | from diffusers.utils.torch_utils import randn_tensor
10 | from diffusers.models.attention_processor import SpatialNorm
11 | from .unet_causal_3d_blocks import (
12 | CausalConv3d,
13 | UNetMidBlockCausal3D,
14 | get_down_block3d,
15 | get_up_block3d,
16 | )
17 |
18 |
19 | @dataclass
20 | class DecoderOutput(BaseOutput):
21 | r"""
22 | Output of decoding method.
23 |
24 | Args:
25 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
26 | The decoded output sample from the last layer of the model.
27 | """
28 |
29 | sample: torch.FloatTensor
30 |
31 |
32 | class EncoderCausal3D(nn.Module):
33 | r"""
34 | The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
35 | """
36 |
37 | def __init__(
38 | self,
39 | in_channels: int = 3,
40 | out_channels: int = 3,
41 | down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
42 | block_out_channels: Tuple[int, ...] = (64,),
43 | layers_per_block: int = 2,
44 | norm_num_groups: int = 32,
45 | act_fn: str = "silu",
46 | double_z: bool = True,
47 | mid_block_add_attention=True,
48 | time_compression_ratio: int = 4,
49 | spatial_compression_ratio: int = 8,
50 | ):
51 | super().__init__()
52 | self.layers_per_block = layers_per_block
53 |
54 | self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
55 | self.mid_block = None
56 | self.down_blocks = nn.ModuleList([])
57 |
58 | # down
59 | output_channel = block_out_channels[0]
60 | for i, down_block_type in enumerate(down_block_types):
61 | input_channel = output_channel
62 | output_channel = block_out_channels[i]
63 | is_final_block = i == len(block_out_channels) - 1
64 | num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
65 | num_time_downsample_layers = int(np.log2(time_compression_ratio))
66 |
67 | if time_compression_ratio == 4:
68 | add_spatial_downsample = bool(i < num_spatial_downsample_layers)
69 | add_time_downsample = bool(
70 | i >= (len(block_out_channels) - 1 - num_time_downsample_layers)
71 | and not is_final_block
72 | )
73 | else:
74 | raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
75 |
76 | downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
77 | downsample_stride_T = (2,) if add_time_downsample else (1,)
78 | downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
79 | down_block = get_down_block3d(
80 | down_block_type,
81 | num_layers=self.layers_per_block,
82 | in_channels=input_channel,
83 | out_channels=output_channel,
84 | add_downsample=bool(add_spatial_downsample or add_time_downsample),
85 | downsample_stride=downsample_stride,
86 | resnet_eps=1e-6,
87 | downsample_padding=0,
88 | resnet_act_fn=act_fn,
89 | resnet_groups=norm_num_groups,
90 | attention_head_dim=output_channel,
91 | temb_channels=None,
92 | )
93 | self.down_blocks.append(down_block)
94 |
95 | # mid
96 | self.mid_block = UNetMidBlockCausal3D(
97 | in_channels=block_out_channels[-1],
98 | resnet_eps=1e-6,
99 | resnet_act_fn=act_fn,
100 | output_scale_factor=1,
101 | resnet_time_scale_shift="default",
102 | attention_head_dim=block_out_channels[-1],
103 | resnet_groups=norm_num_groups,
104 | temb_channels=None,
105 | add_attention=mid_block_add_attention,
106 | )
107 |
108 | # out
109 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
110 | self.conv_act = nn.SiLU()
111 |
112 | conv_out_channels = 2 * out_channels if double_z else out_channels
113 | self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
114 |
115 | def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
116 | r"""The forward method of the `EncoderCausal3D` class."""
117 | assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
118 |
119 | sample = self.conv_in(sample)
120 |
121 | # down
122 | for down_block in self.down_blocks:
123 | sample = down_block(sample)
124 |
125 | # middle
126 | sample = self.mid_block(sample)
127 |
128 | # post-process
129 | sample = self.conv_norm_out(sample)
130 | sample = self.conv_act(sample)
131 | sample = self.conv_out(sample)
132 |
133 | return sample
134 |
135 |
136 | class DecoderCausal3D(nn.Module):
137 | r"""
138 | The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
139 | """
140 |
141 | def __init__(
142 | self,
143 | in_channels: int = 3,
144 | out_channels: int = 3,
145 | up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
146 | block_out_channels: Tuple[int, ...] = (64,),
147 | layers_per_block: int = 2,
148 | norm_num_groups: int = 32,
149 | act_fn: str = "silu",
150 | norm_type: str = "group", # group, spatial
151 | mid_block_add_attention=True,
152 | time_compression_ratio: int = 4,
153 | spatial_compression_ratio: int = 8,
154 | ):
155 | super().__init__()
156 | self.layers_per_block = layers_per_block
157 |
158 | self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
159 | self.mid_block = None
160 | self.up_blocks = nn.ModuleList([])
161 |
162 | temb_channels = in_channels if norm_type == "spatial" else None
163 |
164 | # mid
165 | self.mid_block = UNetMidBlockCausal3D(
166 | in_channels=block_out_channels[-1],
167 | resnet_eps=1e-6,
168 | resnet_act_fn=act_fn,
169 | output_scale_factor=1,
170 | resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
171 | attention_head_dim=block_out_channels[-1],
172 | resnet_groups=norm_num_groups,
173 | temb_channels=temb_channels,
174 | add_attention=mid_block_add_attention,
175 | )
176 |
177 | # up
178 | reversed_block_out_channels = list(reversed(block_out_channels))
179 | output_channel = reversed_block_out_channels[0]
180 | for i, up_block_type in enumerate(up_block_types):
181 | prev_output_channel = output_channel
182 | output_channel = reversed_block_out_channels[i]
183 | is_final_block = i == len(block_out_channels) - 1
184 | num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
185 | num_time_upsample_layers = int(np.log2(time_compression_ratio))
186 |
187 | if time_compression_ratio == 4:
188 | add_spatial_upsample = bool(i < num_spatial_upsample_layers)
189 | add_time_upsample = bool(
190 | i >= len(block_out_channels) - 1 - num_time_upsample_layers
191 | and not is_final_block
192 | )
193 | else:
194 | raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
195 |
196 | upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
197 | upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
198 | upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
199 | up_block = get_up_block3d(
200 | up_block_type,
201 | num_layers=self.layers_per_block + 1,
202 | in_channels=prev_output_channel,
203 | out_channels=output_channel,
204 | prev_output_channel=None,
205 | add_upsample=bool(add_spatial_upsample or add_time_upsample),
206 | upsample_scale_factor=upsample_scale_factor,
207 | resnet_eps=1e-6,
208 | resnet_act_fn=act_fn,
209 | resnet_groups=norm_num_groups,
210 | attention_head_dim=output_channel,
211 | temb_channels=temb_channels,
212 | resnet_time_scale_shift=norm_type,
213 | )
214 | self.up_blocks.append(up_block)
215 | prev_output_channel = output_channel
216 |
217 | # out
218 | if norm_type == "spatial":
219 | self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
220 | else:
221 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
222 | self.conv_act = nn.SiLU()
223 | self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
224 |
225 | self.gradient_checkpointing = False
226 |
227 | def forward(
228 | self,
229 | sample: torch.FloatTensor,
230 | latent_embeds: Optional[torch.FloatTensor] = None,
231 | ) -> torch.FloatTensor:
232 | r"""The forward method of the `DecoderCausal3D` class."""
233 | assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."
234 |
235 | sample = self.conv_in(sample)
236 |
237 | upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
238 | if self.training and self.gradient_checkpointing:
239 |
240 | def create_custom_forward(module):
241 | def custom_forward(*inputs):
242 | return module(*inputs)
243 |
244 | return custom_forward
245 |
246 | if is_torch_version(">=", "1.11.0"):
247 | # middle
248 | sample = torch.utils.checkpoint.checkpoint(
249 | create_custom_forward(self.mid_block),
250 | sample,
251 | latent_embeds,
252 | use_reentrant=False,
253 | )
254 | sample = sample.to(upscale_dtype)
255 |
256 | # up
257 | for up_block in self.up_blocks:
258 | sample = torch.utils.checkpoint.checkpoint(
259 | create_custom_forward(up_block),
260 | sample,
261 | latent_embeds,
262 | use_reentrant=False,
263 | )
264 | else:
265 | # middle
266 | sample = torch.utils.checkpoint.checkpoint(
267 | create_custom_forward(self.mid_block), sample, latent_embeds
268 | )
269 | sample = sample.to(upscale_dtype)
270 |
271 | # up
272 | for up_block in self.up_blocks:
273 | sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
274 | else:
275 | # middle
276 | sample = self.mid_block(sample, latent_embeds)
277 | sample = sample.to(upscale_dtype)
278 |
279 | # up
280 | for up_block in self.up_blocks:
281 | sample = up_block(sample, latent_embeds)
282 |
283 | # post-process
284 | if latent_embeds is None:
285 | sample = self.conv_norm_out(sample)
286 | else:
287 | sample = self.conv_norm_out(sample, latent_embeds)
288 | sample = self.conv_act(sample)
289 | sample = self.conv_out(sample)
290 |
291 | return sample
292 |
293 |
294 | class DiagonalGaussianDistribution(object):
295 | def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
296 | if parameters.ndim == 3:
297 | dim = 2 # (B, L, C)
298 | elif parameters.ndim == 5 or parameters.ndim == 4:
299 | dim = 1 # (B, C, T, H, W) / (B, C, H, W)
300 | else:
301 | raise NotImplementedError
302 | self.parameters = parameters
303 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
304 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
305 | self.deterministic = deterministic
306 | self.std = torch.exp(0.5 * self.logvar)
307 | self.var = torch.exp(self.logvar)
308 | if self.deterministic:
309 | self.var = self.std = torch.zeros_like(
310 | self.mean, device=self.parameters.device, dtype=self.parameters.dtype
311 | )
312 |
313 | def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
314 | # make sure sample is on the same device as the parameters and has same dtype
315 | sample = randn_tensor(
316 | self.mean.shape,
317 | generator=generator,
318 | device=self.parameters.device,
319 | dtype=self.parameters.dtype,
320 | )
321 | x = self.mean + self.std * sample
322 | return x
323 |
324 | def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
325 | if self.deterministic:
326 | return torch.Tensor([0.0])
327 | else:
328 | reduce_dim = list(range(1, self.mean.ndim))
329 | if other is None:
330 | return 0.5 * torch.sum(
331 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
332 | dim=reduce_dim,
333 | )
334 | else:
335 | return 0.5 * torch.sum(
336 | torch.pow(self.mean - other.mean, 2) / other.var
337 | + self.var / other.var
338 | - 1.0
339 | - self.logvar
340 | + other.logvar,
341 | dim=reduce_dim,
342 | )
343 |
344 | def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
345 | if self.deterministic:
346 | return torch.Tensor([0.0])
347 | logtwopi = np.log(2.0 * np.pi)
348 | return 0.5 * torch.sum(
349 | logtwopi + self.logvar +
350 | torch.pow(sample - self.mean, 2) / self.var,
351 | dim=dims,
352 | )
353 |
354 | def mode(self) -> torch.Tensor:
355 | return self.mean
356 |
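`EncoderCausal3D` derives each down block's `(T, H, W)` stride from its index: the first `log2(spatial_compression_ratio)` blocks halve H and W, and the `log2(time_compression_ratio)` blocks just before the final one halve T. A pure-Python sketch of that schedule, assuming a four-block configuration (the real count comes from the VAE config) and using `math.log2` in place of `np.log2`. The `latent_shape` helper is an added assumption: it uses the `(T - 1) // 4 + 1` frame mapping, which is consistent with the 129-frame videos in the scripts producing 33 latent frames.

```python
import math
from typing import List, Tuple


def causal3d_strides(num_blocks: int, spatial_ratio: int = 8, time_ratio: int = 4) -> List[Tuple[int, int, int]]:
    """Per-block (T, H, W) strides following the rule in EncoderCausal3D.__init__."""
    if time_ratio != 4:
        raise ValueError(f"Unsupported time_compression_ratio: {time_ratio}.")
    n_spatial = int(math.log2(spatial_ratio))  # number of 2x spatial stages
    n_time = int(math.log2(time_ratio))        # number of 2x temporal stages
    strides = []
    for i in range(num_blocks):
        is_final = i == num_blocks - 1
        spatial = i < n_spatial
        # Temporal stages sit in the n_time blocks just before the final block.
        temporal = i >= num_blocks - 1 - n_time and not is_final
        s = 2 if spatial else 1
        strides.append((2 if temporal else 1, s, s))
    return strides


def latent_shape(frames: int, height: int, width: int, spatial_ratio: int = 8, time_ratio: int = 4) -> Tuple[int, int, int]:
    """Latent (T, H, W), assuming causal downsampling maps T frames to (T - 1) // 4 + 1."""
    return (frames - 1) // time_ratio + 1, height // spatial_ratio, width // spatial_ratio


print(causal3d_strides(4))           # [(1, 2, 2), (2, 2, 2), (2, 2, 2), (1, 1, 1)]
print(latent_shape(129, 720, 1280))  # (33, 90, 160)
```

The temporal strides multiply to 4 and the spatial ones to 8, matching the defaults; `DecoderCausal3D` applies the same per-index rule to its upsampling factors.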
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python==4.9.0.80
2 | diffusers==0.31.0
3 | transformers==4.46.3
4 | tokenizers==0.20.3
5 | accelerate==1.1.1
6 | pandas==2.0.3
7 | numpy==1.24.4
8 | einops==0.7.0
9 | tqdm==4.66.2
10 | loguru==0.7.2
11 | imageio==2.34.0
12 | imageio-ffmpeg==0.5.1
13 | safetensors==0.4.3
14 | gradio==5.0.0
15 |
--------------------------------------------------------------------------------
/sample_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from pathlib import Path
4 | from loguru import logger
5 | from datetime import datetime
6 |
7 | from hyvideo.utils.file_utils import save_videos_grid
8 | from hyvideo.config import parse_args
9 | from hyvideo.inference import HunyuanVideoSampler
10 |
11 |
12 | def main():
13 | args = parse_args()
14 | print(args)
15 | models_root_path = Path(args.model_base)
16 | if not models_root_path.exists():
17 | raise ValueError(f"`models_root` does not exist: {models_root_path}")
18 |
19 | # Create save folder to save the samples
20 | save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}'
21 | if not os.path.exists(save_path):
22 | os.makedirs(save_path, exist_ok=True)
23 |
24 | # Load models
25 | hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
26 |
27 | # Get the updated args
28 | args = hunyuan_video_sampler.args
29 |
30 | # Start sampling
31 | # TODO: batch inference check
32 | outputs = hunyuan_video_sampler.predict(
33 | prompt=args.prompt,
34 | height=args.video_size[0],
35 | width=args.video_size[1],
36 | video_length=args.video_length,
37 | seed=args.seed,
38 | negative_prompt=args.neg_prompt,
39 | infer_steps=args.infer_steps,
40 | guidance_scale=args.cfg_scale,
41 | num_videos_per_prompt=args.num_videos,
42 | flow_shift=args.flow_shift,
43 | batch_size=args.batch_size,
44 | embedded_guidance_scale=args.embedded_cfg_scale
45 | )
46 | samples = outputs['samples']
47 |
48 | # Save samples
49 | if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0:
50 | for i, sample in enumerate(samples):
51 | sample = sample.unsqueeze(0)
52 | time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
53 | cur_save_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4"
54 | save_videos_grid(sample, cur_save_path, fps=24)
55 | logger.info(f'Sample saved to: {cur_save_path}')
56 |
57 | if __name__ == "__main__":
58 | main()
59 |
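Each sample is saved as `<save_path>/<timestamp>_seed<seed>_<prompt>.mp4`, with the prompt cut to 100 characters and `/` removed so it cannot create subdirectories. A sketch of that naming rule as a standalone function; `build_save_path` is a hypothetical name, and it takes a `datetime` instead of calling `time.time()` so the result is reproducible:

```python
from datetime import datetime


def build_save_path(save_path: str, seed: int, prompt: str, when: datetime, max_prompt_chars: int = 100) -> str:
    """Compose an output path in the same pattern sample_video.py uses (hypothetical helper)."""
    time_flag = when.strftime("%Y-%m-%d-%H:%M:%S")
    # Truncate the prompt and drop '/' so it cannot add directory levels.
    safe_prompt = prompt[:max_prompt_chars].replace("/", "")
    return f"{save_path}/{time_flag}_seed{seed}_{safe_prompt}.mp4"


print(build_save_path("./results", 42, "A cat walks on the grass, realistic style.", datetime(2024, 12, 3, 10, 30, 0)))
# ./results/2024-12-03-10:30:00_seed42_A cat walks on the grass, realistic style..mp4
```

Because the timestamp contains `:`, these file names are not valid on Windows; a format such as `%Y-%m-%d-%H-%M-%S` would avoid that.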
--------------------------------------------------------------------------------
/scripts/run_sample_video.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Description: This script demonstrates how to generate a video with the HunyuanVideo model.
3 |
4 | python3 sample_video.py \
5 | --video-size 720 1280 \
6 | --video-length 129 \
7 | --infer-steps 50 \
8 | --prompt "A cat walks on the grass, realistic style." \
9 | --seed 42 \
10 | --embedded-cfg-scale 6.0 \
11 | --flow-shift 7.0 \
12 | --flow-reverse \
13 | --use-cpu-offload \
14 | --save-path ./results
15 |
--------------------------------------------------------------------------------
/scripts/run_sample_video_fp8.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Description: This script demonstrates how to generate a video with the HunyuanVideo model using FP8 DiT weights.
3 | DIT_CKPT_PATH={PATH_TO}/{MODEL_NAME}_model_states_fp8.pt
4 |
5 | python3 sample_video.py \
6 | --dit-weight ${DIT_CKPT_PATH} \
7 | --video-size 720 1280 \
8 | --video-length 129 \
9 | --infer-steps 50 \
10 | --prompt "A cat walks on the grass, realistic style." \
11 | --seed 42 \
12 | --embedded-cfg-scale 6.0 \
13 | --flow-shift 7.0 \
14 | --flow-reverse \
15 | --use-cpu-offload \
16 | --use-fp8 \
17 | --save-path ./results
18 |
--------------------------------------------------------------------------------
/scripts/run_sample_video_multigpu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Description: This script demonstrates how to generate a video with the HunyuanVideo model on multiple GPUs using sequence parallelism (Ulysses and Ring attention).
3 |
4 | # Supported Parallel Configurations
5 | # | --video-size | --video-length | --ulysses-degree x --ring-degree | --nproc_per_node |
6 | # |----------------------|----------------|----------------------------------|------------------|
7 | # | 1280 720 or 720 1280 | 129 | 8x1,4x2,2x4,1x8 | 8 |
8 | # | 1280 720 or 720 1280 | 129 | 1x5 | 5 |
9 | # | 1280 720 or 720 1280 | 129 | 4x1,2x2,1x4 | 4 |
10 | # | 1280 720 or 720 1280 | 129 | 3x1,1x3 | 3 |
11 | # | 1280 720 or 720 1280 | 129 | 2x1,1x2 | 2 |
12 | # | 1104 832 or 832 1104 | 129 | 4x1,2x2,1x4 | 4 |
13 | # | 1104 832 or 832 1104 | 129 | 3x1,1x3 | 3 |
14 | # | 1104 832 or 832 1104 | 129 | 2x1,1x2 | 2 |
15 | # | 960 960 | 129 | 6x1,3x2,2x3,1x6 | 6 |
16 | # | 960 960 | 129 | 4x1,2x2,1x4 | 4 |
17 | # | 960 960 | 129 | 3x1,1x3 | 3 |
18 | # | 960 960 | 129 | 1x2,2x1 | 2 |
19 | # | 960 544 or 544 960 | 129 | 6x1,3x2,2x3,1x6 | 6 |
20 | # | 960 544 or 544 960 | 129 | 4x1,2x2,1x4 | 4 |
21 | # | 960 544 or 544 960 | 129 | 3x1,1x3 | 3 |
22 | # | 960 544 or 544 960 | 129 | 1x2,2x1 | 2 |
23 | # | 832 624 or 624 832 | 129 | 4x1,2x2,1x4 | 4 |
24 | # | 832 624 or 624 832 | 129 | 3x1,1x3 | 3 |
25 | # | 832 624 or 624 832 | 129 | 2x1,1x2 | 2 |
26 | # | 720 720 | 129 | 1x5 | 5 |
27 | # | 720 720 | 129 | 3x1,1x3 | 3 |
28 |
29 | export TOKENIZERS_PARALLELISM=false
30 |
31 | export NPROC_PER_NODE=8
32 | export ULYSSES_DEGREE=8
33 | export RING_DEGREE=1
34 |
35 | torchrun --nproc_per_node=$NPROC_PER_NODE sample_video.py \
36 | --video-size 720 1280 \
37 | --video-length 129 \
38 | --infer-steps 50 \
39 | --prompt "A cat walks on the grass, realistic style." \
40 | --seed 42 \
41 | --embedded-cfg-scale 6.0 \
42 | --flow-shift 7.0 \
43 | --flow-reverse \
44 | --ulysses-degree=$ULYSSES_DEGREE \
45 | --ring-degree=$RING_DEGREE \
46 | --save-path ./results
47 |
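The supported-configuration table above can be summarized by three checks, sketched below with hypothetical helper names. The first (`ulysses_degree * ring_degree == nproc_per_node`) is required for every rank to belong to one group. The second follows from Ulysses attention scattering attention heads across ranks; 24 heads is assumed, matching `heads_num` in `tests/test_attention.py`. The third is inferred from the table rather than taken from the inference code: assuming 8x spatial VAE compression and 2x2 patches, the per-frame token grid must split evenly along its rows or its columns. This rule reproduces every row listed above, but it may also accept combinations the table omits, so treat the table as authoritative.

```python
from typing import Tuple


def latent_patch_grid(height: int, width: int, spatial_ratio: int = 8, patch_size: int = 2) -> Tuple[int, int]:
    """Token grid (rows, cols) per latent frame, assuming 8x VAE compression and 2x2 patches."""
    return height // spatial_ratio // patch_size, width // spatial_ratio // patch_size


def is_supported_parallel_config(height: int, width: int, ulysses_degree: int, ring_degree: int,
                                 nproc_per_node: int, heads_num: int = 24) -> bool:
    # 1. Every rank must belong to exactly one (ulysses, ring) group.
    if ulysses_degree * ring_degree != nproc_per_node:
        return False
    # 2. Ulysses all-to-all scatters attention heads, so they must split evenly.
    if heads_num % ulysses_degree != 0:
        return False
    # 3. Assumed rule: the token grid splits evenly along rows or columns.
    rows, cols = latent_patch_grid(height, width)
    return rows % nproc_per_node == 0 or cols % nproc_per_node == 0


print(latent_patch_grid(720, 1280))                    # (45, 80)
print(is_supported_parallel_config(720, 1280, 8, 1, 8))  # True
print(is_supported_parallel_config(720, 720, 5, 1, 5))   # False: 24 heads do not split 5 ways
```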
--------------------------------------------------------------------------------
/tests/test_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sys
3 | import os
4 | current_dir = os.path.dirname(os.path.abspath(__file__))
5 | project_root = os.path.dirname(current_dir)
6 | sys.path.append(project_root)
7 |
8 | from hyvideo.modules.attenion import attention
9 | from xfuser.core.long_ctx_attention import xFuserLongContextAttention
10 | from xfuser.core.distributed import (
11 | init_distributed_environment,
12 | initialize_model_parallel,
13 | # initialize_runtime_state,
14 | )
15 |
16 | def init_dist(backend="nccl"):
17 | local_rank = int(os.environ["LOCAL_RANK"])
18 | rank = int(os.environ["RANK"])
19 | world_size = int(os.environ["WORLD_SIZE"])
20 |
21 | print(
22 | f"Initializing distributed environment with rank {rank}, world size {world_size}, local rank {local_rank}"
23 | )
24 |
25 | torch.cuda.set_device(local_rank)
26 | init_distributed_environment(rank=rank, world_size=world_size)
27 | # dist.init_process_group(backend=backend)
28 | # construct a hybrid sequence parallel config (ulysses=2, ring = world_size // 2)
29 |
30 | if world_size > 1:
31 | ring_degree = world_size // 2
32 | ulysses_degree = 2
33 | else:
34 | ring_degree = 1
35 | ulysses_degree = 1
36 | initialize_model_parallel(
37 | sequence_parallel_degree=world_size,
38 | ring_degree=ring_degree,
39 | ulysses_degree=ulysses_degree,
40 | )
41 |
42 | return rank, world_size
43 |
44 | def test_mm_double_stream_block_attention(rank, world_size):
45 | device = torch.device(f"cuda:{rank}")
46 | dtype = torch.bfloat16
47 | batch_size = 1
48 | seq_len_img = 118800
49 | seq_len_txt = 256
50 | heads_num = 24
51 | head_dim = 128
52 |
53 | img_q = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
54 | img_k = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
55 | img_v = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
56 | txt_q = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
57 | txt_k = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
58 | txt_v = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
59 |
60 | with torch.no_grad():
61 | torch.distributed.broadcast(img_q, src=0)
62 | torch.distributed.broadcast(img_k, src=0)
63 | torch.distributed.broadcast(img_v, src=0)
64 | torch.distributed.broadcast(txt_q, src=0)
65 | torch.distributed.broadcast(txt_k, src=0)
66 | torch.distributed.broadcast(txt_v, src=0)
67 | q = torch.cat((img_q, txt_q), dim=1)
68 | k = torch.cat((img_k, txt_k), dim=1)
69 | v = torch.cat((img_v, txt_v), dim=1)
70 |
71 |
72 | cu_seqlens_q = torch.tensor([0, 118811, 119056], device=device, dtype=torch.int32)
73 | cu_seqlens_kv = torch.tensor([0, 118811, 119056], device=device, dtype=torch.int32)
74 | max_seqlen_q = 119056
75 | max_seqlen_kv = 119056
76 | mode = "torch" # "torch", "vanilla", "flash"
77 |
78 | original_output = attention(
79 | q,
80 | k,
81 | v,
82 | mode=mode,
83 | cu_seqlens_q=cu_seqlens_q,
84 | cu_seqlens_kv=cu_seqlens_kv,
85 | max_seqlen_q=max_seqlen_q,
86 | max_seqlen_kv=max_seqlen_kv,
87 | batch_size=batch_size
88 | )
89 |
90 | hybrid_seq_parallel_attn = xFuserLongContextAttention()
91 | hybrid_seq_parallel_output = hybrid_seq_parallel_attn(
92 | None,
93 | img_q,
94 | img_k,
95 | img_v,
96 | dropout_p=0.0,
97 | causal=False,
98 | joint_tensor_query=txt_q,
99 | joint_tensor_key=txt_k,
100 | joint_tensor_value=txt_v,
101 | joint_strategy="rear",
102 | )
103 |
104 | b, s, a, d = hybrid_seq_parallel_output.shape
105 | hybrid_seq_parallel_output = hybrid_seq_parallel_output.reshape(b, s, -1)
106 |
107 | assert original_output.shape == hybrid_seq_parallel_output.shape, f"Shape mismatch: {original_output.shape} vs {hybrid_seq_parallel_output.shape}"
108 |
109 | torch.testing.assert_close(original_output, hybrid_seq_parallel_output, rtol=1e-3, atol=1e-3)
110 | print("test_mm_double_stream_block_attention Passed")
111 |
112 | def test_mm_single_stream_block_attention(rank, world_size):
113 | device = torch.device(f"cuda:{rank}")
114 | dtype = torch.bfloat16
115 | txt_len = 256
116 | batch_size = 1
117 | seq_len_img = 118800
118 | seq_len_txt = 256
119 | heads_num = 24
120 | head_dim = 128
121 |
122 | with torch.no_grad():
123 | img_q = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
124 | img_k = torch.randn(batch_size, seq_len_img, heads_num, head_dim, device=device, dtype=dtype)
125 | txt_q = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
126 | txt_k = torch.randn(batch_size, seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
127 | v = torch.randn(batch_size, seq_len_img + seq_len_txt, heads_num, head_dim, device=device, dtype=dtype)
128 |
129 | torch.distributed.broadcast(img_q, src=0)
130 | torch.distributed.broadcast(img_k, src=0)
131 | torch.distributed.broadcast(txt_q, src=0)
132 | torch.distributed.broadcast(txt_k, src=0)
133 | torch.distributed.broadcast(v, src=0)
134 |
135 | q = torch.cat((img_q, txt_q), dim=1)
136 | k = torch.cat((img_k, txt_k), dim=1)
137 |
138 | cu_seqlens_q = torch.tensor([0, 118811, 119056], device=device, dtype=torch.int32)
139 | cu_seqlens_kv = torch.tensor([0, 118811, 119056], device=device, dtype=torch.int32)
140 | max_seqlen_q = 119056
141 | max_seqlen_kv = 119056
142 | mode = "torch" # "torch", "vanilla", "flash"
143 |
144 | original_output = attention(
145 | q,
146 | k,
147 | v,
148 | mode=mode,
149 | cu_seqlens_q=cu_seqlens_q,
150 | cu_seqlens_kv=cu_seqlens_kv,
151 | max_seqlen_q=max_seqlen_q,
152 | max_seqlen_kv=max_seqlen_kv,
153 | batch_size=batch_size
154 | )
155 |
156 | hybrid_seq_parallel_attn = xFuserLongContextAttention()
157 | hybrid_seq_parallel_output = hybrid_seq_parallel_attn(
158 | None,
159 | q[:, :-txt_len, :, :],
160 | k[:, :-txt_len, :, :],
161 | v[:, :-txt_len, :, :],
162 | dropout_p=0.0,
163 | causal=False,
164 | joint_tensor_query=q[:, -txt_len:, :, :],
165 | joint_tensor_key=k[:, -txt_len:, :, :],
166 | joint_tensor_value=v[:, -txt_len:, :, :],
167 | joint_strategy="rear",
168 | )
169 | b, s, a, d = hybrid_seq_parallel_output.shape
170 | hybrid_seq_parallel_output = hybrid_seq_parallel_output.reshape(b, s, -1)
171 |
172 | assert original_output.shape == hybrid_seq_parallel_output.shape, f"Shape mismatch: {original_output.shape} vs {hybrid_seq_parallel_output.shape}"
173 |
174 | torch.testing.assert_close(original_output, hybrid_seq_parallel_output, rtol=1e-3, atol=1e-3)
175 | print("test_mm_single_stream_block_attention Passed")
176 |
177 | if __name__ == "__main__":
178 | rank, world_size = init_dist()
179 | test_mm_double_stream_block_attention(rank, world_size)
180 | test_mm_single_stream_block_attention(rank, world_size)
181 |
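The hard-coded boundaries `[0, 118811, 119056]` in these tests are consistent with one sample of 118800 image tokens followed by 256 text slots, of which 11 hold real tokens: the first segment ends after the image plus the valid text, and the second covers the text padding. A sketch that derives such boundaries from per-sample valid text lengths; the helper name and the two-segments-per-sample layout are assumptions inferred from these numbers, not taken from `hyvideo.modules.attenion`:

```python
from typing import List, Sequence


def build_cu_seqlens(img_len: int, txt_len: int, valid_txt_lens: Sequence[int]) -> List[int]:
    """Cumulative segment boundaries for per-sample [image | valid text | text padding] (hypothetical)."""
    max_len = img_len + txt_len
    cu = [0]
    for i, valid in enumerate(valid_txt_lens):
        cu.append(i * max_len + img_len + valid)  # end of image + valid text tokens
        cu.append((i + 1) * max_len)              # end of this sample's padding
    return cu


print(build_cu_seqlens(118800, 256, [11]))  # [0, 118811, 119056]
```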
--------------------------------------------------------------------------------
/utils/collect_env.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | """This file holds environment helpers shared by other files."""
3 | import os
4 | import os.path as osp
5 | import subprocess
6 | import sys
7 | from collections import OrderedDict, defaultdict
8 |
9 | import numpy as np
10 | import torch
11 |
12 |
13 | def is_rocm_pytorch() -> bool:
14 | """Check whether PyTorch was built with ROCm support."""
15 | is_rocm = False
16 | if TORCH_VERSION != 'parrots':
17 | try:
18 | from torch.utils.cpp_extension import ROCM_HOME
19 | is_rocm = True if ((torch.version.hip is not None) and
20 | (ROCM_HOME is not None)) else False
21 | except ImportError:
22 | pass
23 | return is_rocm
24 |
25 | TORCH_VERSION = torch.__version__
26 |
27 | def get_build_config():
28 | """Obtain the build information of PyTorch or Parrots."""
29 | if TORCH_VERSION == 'parrots':
30 | from parrots.config import get_build_info
31 | return get_build_info()
32 | else:
33 | return torch.__config__.show()
34 |
35 | try:
36 | import torch_musa # noqa: F401
37 | IS_MUSA_AVAILABLE = True
38 | except Exception:
39 | IS_MUSA_AVAILABLE = False
40 |
41 | def is_musa_available() -> bool:
42 | return IS_MUSA_AVAILABLE
43 |
44 | def is_cuda_available() -> bool:
45 | """Returns True if cuda devices exist."""
46 | return torch.cuda.is_available()
47 |
48 | def _get_cuda_home():
49 | if TORCH_VERSION == 'parrots':
50 | from parrots.utils.build_extension import CUDA_HOME
51 | else:
52 | if is_rocm_pytorch():
53 | from torch.utils.cpp_extension import ROCM_HOME
54 | CUDA_HOME = ROCM_HOME
55 | else:
56 | from torch.utils.cpp_extension import CUDA_HOME
57 | return CUDA_HOME
58 |
59 |
60 | def _get_musa_home():
61 | return os.environ.get('MUSA_HOME')
62 |
63 |
64 | def collect_env():
65 | """Collect the information of the running environments.
66 |
67 | Returns:
68 | dict: The environment information. The following fields are contained.
69 |
70 | - sys.platform: The variable of ``sys.platform``.
71 | - Python: Python version.
72 | - CUDA available: Bool, indicating if CUDA is available.
73 | - GPU devices: Device type of each GPU.
74 | - CUDA_HOME (optional): The env var ``CUDA_HOME``.
75 | - NVCC (optional): NVCC version.
76 | - GCC: GCC version, "n/a" if GCC is not installed.
77 | - MSVC: Microsoft Visual C++ Compiler version, Windows only.
78 | - PyTorch: PyTorch version.
79 | - PyTorch compiling details: The output of \
80 | ``torch.__config__.show()``.
81 | - TorchVision (optional): TorchVision version.
82 | - OpenCV (optional): OpenCV version.
83 | """
84 | from distutils import errors
85 |
86 | env_info = OrderedDict()
87 | env_info['sys.platform'] = sys.platform
88 | env_info['Python'] = sys.version.replace('\n', '')
89 |
90 | cuda_available = is_cuda_available()
91 | musa_available = is_musa_available()
92 | env_info['CUDA available'] = cuda_available
93 | env_info['MUSA available'] = musa_available
94 | env_info['numpy_random_seed'] = np.random.get_state()[1][0]
95 |
96 | if cuda_available:
97 | devices = defaultdict(list)
98 | for k in range(torch.cuda.device_count()):
99 | devices[torch.cuda.get_device_name(k)].append(str(k))
100 | for name, device_ids in devices.items():
101 | env_info['GPU ' + ','.join(device_ids)] = name
102 |
103 | CUDA_HOME = _get_cuda_home()
104 | env_info['CUDA_HOME'] = CUDA_HOME
105 |
106 | if CUDA_HOME is not None and osp.isdir(CUDA_HOME):
107 | if CUDA_HOME == '/opt/rocm':
108 | try:
109 | nvcc = osp.join(CUDA_HOME, 'hip/bin/hipcc')
110 | nvcc = subprocess.check_output(
111 | f'"{nvcc}" --version', shell=True)
112 | nvcc = nvcc.decode('utf-8').strip()
113 | release = nvcc.rfind('HIP version:')
114 | build = nvcc.rfind('')
115 | nvcc = nvcc[release:build].strip()
116 | except subprocess.SubprocessError:
117 | nvcc = 'Not Available'
118 | else:
119 | try:
120 | nvcc = osp.join(CUDA_HOME, 'bin/nvcc')
121 | nvcc = subprocess.check_output(f'"{nvcc}" -V', shell=True)
122 | nvcc = nvcc.decode('utf-8').strip()
123 | release = nvcc.rfind('Cuda compilation tools')
124 | build = nvcc.rfind('Build ')
125 | nvcc = nvcc[release:build].strip()
126 | except subprocess.SubprocessError:
127 | nvcc = 'Not Available'
128 | env_info['NVCC'] = nvcc
129 | elif musa_available:
130 | devices = defaultdict(list)
131 | for k in range(torch.musa.device_count()):
132 | devices[torch.musa.get_device_name(k)].append(str(k))
133 | for name, device_ids in devices.items():
134 | env_info['GPU ' + ','.join(device_ids)] = name
135 |
136 | MUSA_HOME = _get_musa_home()
137 | env_info['MUSA_HOME'] = MUSA_HOME
138 |
139 | if MUSA_HOME is not None and osp.isdir(MUSA_HOME):
140 | try:
141 | mcc = osp.join(MUSA_HOME, 'bin/mcc')
142 | subprocess.check_output(f'"{mcc}" -v', shell=True)
143 | except subprocess.SubprocessError:
144 | mcc = 'Not Available'
145 | env_info['mcc'] = mcc
146 | try:
147 | # Check C++ Compiler.
148 | # On Unix-like systems, sysconfig's 'CC' variable (e.g. 'gcc -pthread ...')
149 | # names the compiler used to build Python; we take the compiler name from it.
150 | import io
151 | import sysconfig
152 | cc = sysconfig.get_config_var('CC')
153 | if cc:
154 | cc = osp.basename(cc.split()[0])
155 | cc_info = subprocess.check_output(f'{cc} --version', shell=True)
156 | env_info['GCC'] = cc_info.decode('utf-8').partition(
157 | '\n')[0].strip()
158 | else:
159 | # On Windows, cl.exe is usually not on PATH, so we have to locate it:
160 | # distutils.ccompiler.new_compiler() returns an MSVC compiler object, and
161 | # calling its initialize() method finds the path to cl.exe.
162 | import locale
163 | import os
164 | from distutils.ccompiler import new_compiler
165 | ccompiler = new_compiler()
166 | ccompiler.initialize()
167 | cc = subprocess.check_output(
168 | f'{ccompiler.cc}', stderr=subprocess.STDOUT, shell=True)
169 | encoding = os.device_encoding(
170 | sys.stdout.fileno()) or locale.getpreferredencoding()
171 | env_info['MSVC'] = cc.decode(encoding).partition('\n')[0].strip()
172 | env_info['GCC'] = 'n/a'
173 | except (subprocess.CalledProcessError, errors.DistutilsPlatformError):
174 | env_info['GCC'] = 'n/a'
175 | except io.UnsupportedOperation as e:
176 | # JupyterLab on Windows replaces sys.stdout with an object whose fileno() raises
177 | # Refer to: https://github.com/open-mmlab/mmengine/issues/931
178 | # TODO: find a solution to get compiler info in Windows JupyterLab,
179 | # while preserving backward-compatibility in other systems.
180 | env_info['MSVC'] = f'n/a, reason: {str(e)}'
181 |
182 | env_info['PyTorch'] = torch.__version__
183 | env_info['PyTorch compiling details'] = get_build_config()
184 |
185 | try:
186 | import torchvision
187 | env_info['TorchVision'] = torchvision.__version__
188 | except ModuleNotFoundError:
189 | pass
190 |
191 | try:
192 | import cv2
193 | env_info['OpenCV'] = cv2.__version__
194 | except ImportError:
195 | pass
196 |
197 |
198 | return env_info
199 |
200 | if __name__ == '__main__':
201 | for name, val in collect_env().items():
202 | print(f'{name}: {val}')
--------------------------------------------------------------------------------
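The GPU-grouping and `nvcc -V` parsing steps in `collect_env()` are plain data handling, so they can be checked without a GPU or a CUDA toolkit. Below is a minimal sketch that mirrors that logic; `group_devices`, `parse_nvcc_version` and `SAMPLE_NVCC` are hypothetical names introduced for illustration (they are not part of this repository), and the sample compiler output is made up rather than captured from a real toolkit.

```python
from collections import defaultdict
from typing import Dict, List


def group_devices(names: List[str]) -> Dict[str, str]:
    """Hypothetical helper: group device indices that share a name.

    `names[k]` stands in for torch.cuda.get_device_name(k); identical
    devices end up under one key such as 'GPU 0,1', as in collect_env().
    """
    devices: Dict[str, List[str]] = defaultdict(list)
    for k, name in enumerate(names):
        devices[name].append(str(k))
    # Dicts preserve insertion order, so keys follow first appearance.
    return {'GPU ' + ','.join(ids): name for name, ids in devices.items()}


def parse_nvcc_version(output: str) -> str:
    """Hypothetical helper: slice `nvcc -V` text from 'Cuda compilation tools' up to 'Build '.

    Like the original code, this assumes both markers are present; if
    'Build ' is missing, rfind returns -1 and the last character is dropped.
    """
    release = output.rfind('Cuda compilation tools')
    build = output.rfind('Build ')
    return output[release:build].strip()


# Illustrative nvcc output (hypothetical sample, not captured from a real run).
SAMPLE_NVCC = (
    'nvcc: NVIDIA (R) Cuda compiler driver\n'
    'Built on Tue_Aug_15_22:02:13_PDT_2023\n'
    'Cuda compilation tools, release 12.2, V12.2.140\n'
    'Build cuda_12.2.r12.2/compiler.33191640_0'
)

if __name__ == '__main__':
    print(group_devices(['A100', 'A100', 'T4']))
    # {'GPU 0,1': 'A100', 'GPU 2': 'T4'}
    print(parse_nvcc_version(SAMPLE_NVCC))
    # Cuda compilation tools, release 12.2, V12.2.140
```

Grouping by name keeps the report short on multi-GPU nodes: eight identical cards collapse into a single `GPU 0,1,2,3,4,5,6,7` entry instead of eight separate lines.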