├── .gitignore ├── LICENSE.txt ├── README.md ├── assets ├── Alegreya-Regular.ttf ├── Caveat-Medium.ttf ├── DancingScript-Bold.ttf ├── arrow-right.svg ├── background.png ├── bootstrap.min.css ├── bootstrap.min.js ├── icon-new.png ├── icon.png ├── images │ ├── id_eval.png │ ├── ip_eval_m_00.png │ └── ip_eval_s.png ├── index.css ├── index.js ├── roboto.woff2 ├── teaser.png └── videos │ ├── .DS_Store │ ├── demo │ ├── .DS_Store │ ├── demo1.mp4 │ ├── demo2.mp4 │ └── demo3.mp4 │ ├── id │ ├── .DS_Store │ ├── 0_ref.mov │ ├── 1_ref.mov │ ├── 2_ref.mov │ ├── 4_ref.mov │ ├── 5_ref.mov │ ├── 6_ref.mov │ ├── 7_ref.mov │ ├── 8_ref.mov │ └── 9_ref.mov │ ├── multi_ip │ ├── .DS_Store │ ├── 0.mov │ ├── 1.mov │ ├── 10.mov │ ├── 2.mov │ ├── 4.mov │ ├── 5.mov │ ├── 7.mov │ ├── 8.mov │ └── 9.mov │ └── single_ip │ ├── .DS_Store │ ├── 0.mov │ ├── 1.mov │ ├── 2.mov │ ├── 3.mov │ ├── 4.mov │ ├── 5.mov │ ├── 6.mov │ ├── 7.mov │ └── 8.mov ├── examples ├── ref1.png ├── ref10.png ├── ref11.png ├── ref2.png ├── ref3.png ├── ref4.png ├── ref5.png ├── ref6.png ├── ref7.png ├── ref8.png ├── ref9.png └── ref_results │ ├── result1.gif │ ├── result2.gif │ ├── result3.gif │ └── result4.gif ├── generate.py ├── infer.sh ├── phantom_wan ├── __init__.py ├── configs │ ├── __init__.py │ ├── shared_config.py │ ├── wan_i2v_14B.py │ ├── wan_s2v_1_3B.py │ ├── wan_t2v_14B.py │ └── wan_t2v_1_3B.py ├── distributed │ ├── __init__.py │ ├── fsdp.py │ └── xdit_context_parallel.py ├── image2video.py ├── modules │ ├── __init__.py │ ├── attention.py │ ├── clip.py │ ├── model.py │ ├── t5.py │ ├── tokenizers.py │ ├── vae.py │ └── xlm_roberta.py ├── subject2video.py ├── text2video.py └── utils │ ├── __init__.py │ ├── fm_solvers.py │ ├── fm_solvers_unipc.py │ ├── prompt_extend.py │ ├── qwen_vl_utils.py │ └── utils.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .* 2 | *.py[cod] 3 | *.bmp 4 | *.log 5 | *.zip 6 | *.tar 7 | *.pt 8 | *.pth 9 | *.ckpt 10 | *.safetensors 11 | *.json 12 | *.backup 13 | *.pkl 14 | *.html 15 | *.pdf 16 | *.whl 17 | cache 18 | __pycache__/ 19 | storage/ 20 | samples/ 21 | !.gitignore 22 | !requirements.txt 23 | .DS_Store 24 | *DS_Store 25 | assets/.DS_Store 26 | google/ 27 | Wan2.1-T2V-14B/ 28 | Wan2.1-T2V-1.3B/ 29 | Wan2.1-I2V-14B-480P/ 30 | Wan2.1-I2V-14B-720P/ 31 | Phantom-Wan-1.3B/ -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2025 lab-cv 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Phantom: Subject-Consistent Video Generation via Cross-Modal Alignment 2 | 3 |
4 | 5 | [![arXiv](https://img.shields.io/badge/arXiv%20paper-2502.11079-b31b1b.svg)](https://arxiv.org/abs/2502.11079)  6 | [![project page](https://img.shields.io/badge/Project_page-More_visualizations-green)](https://phantom-video.github.io/Phantom/)  7 | 8 |
9 | 10 | 11 | > [**Phantom: Subject-Consistent Video Generation via Cross-Modal Alignment**](https://arxiv.org/abs/2502.11079)
12 | > [Lijie Liu](https://liulj13.github.io/) * , [Tianxiang Ma](https://tianxiangma.github.io/) * , [Bingchuan Li](https://scholar.google.com/citations?user=ac5Se6QAAAAJ) * †, [Zhuowei Chen](https://scholar.google.com/citations?user=ow1jGJkAAAAJ) * , [Jiawei Liu](https://scholar.google.com/citations?user=X21Fz-EAAAAJ), Gen Li, Siyu Zhou, [Qian He](https://scholar.google.com/citations?user=9rWWCgUAAAAJ), Xinglong Wu 13 | >
* Equal contribution, † Project lead 14 | >
Intelligent Creation Team, ByteDance
15 | 16 |

17 | 18 |

19 | 20 | ## 🔥 Latest News! 21 | * Apr 10, 2025: We have updated the full version of the Phantom paper, which now includes more detailed descriptions of the model architecture and dataset pipeline. 22 | * Apr 21, 2025: 👋 Phantom-Wan is coming! We adapted the Phantom framework into the [Wan2.1](https://github.com/Wan-Video/Wan2.1) video generation model. The inference codes and checkpoint have been released. 23 | * Apr 23, 2025: 😊 Thanks to [ComfyUI-WanVideoWrapper](https://github.com/kijai/ComfyUI-WanVideoWrapper/tree/dev) for adapting ComfyUI to Phantom-Wan-1.3B. Everyone is welcome to use it ! 24 | 25 | ## 📑 Todo List 26 | - [x] Inference codes and Checkpoint of Phantom-Wan 1.3B 27 | - [ ] Checkpoint of Phantom-Wan 14B 28 | - [ ] Training codes of Phantom-Wan 29 | 30 | ## 📖 Overview 31 | Phantom is a unified video generation framework for single and multi-subject references, built on existing text-to-video and image-to-video architectures. It achieves cross-modal alignment using text-image-video triplet data by redesigning the joint text-image injection model. Additionally, it emphasizes subject consistency in human generation while enhancing ID-preserving video generation. 32 | 33 | ## ⚡️ Quickstart 34 | 35 | ### Installation 36 | Clone the repo: 37 | ```sh 38 | git clone https://github.com/Phantom-video/Phantom.git 39 | cd Phantom 40 | ``` 41 | 42 | Install dependencies: 43 | ```sh 44 | # Ensure torch >= 2.4.0 45 | pip install -r requirements.txt 46 | ``` 47 | 48 | ### Model Download 49 | First you need to download the 1.3B original model of Wan2.1. Download Wan2.1-1.3B using huggingface-cli: 50 | ``` sh 51 | pip install "huggingface_hub[cli]" 52 | huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir ./Wan2.1-T2V-1.3B 53 | ``` 54 | Then download the Phantom-Wan-1.3B model: 55 | ``` sh 56 | huggingface-cli download bytedance-research/Phantom --local-dir ./Phantom-Wan-1.3B 57 | ``` 58 | 59 | ### Run Subject-to-Video Generation 60 | 61 | - Single-GPU inference 62 | 63 | ``` sh 64 | python generate.py --task s2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --phantom_ckpt ./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth --ref_image "examples/ref1.png,examples/ref2.png" --prompt "暖阳漫过草地,扎着双马尾、头戴绿色蝴蝶结、身穿浅绿色连衣裙的小女孩蹲在盛开的雏菊旁。她身旁一只棕白相间的狗狗吐着舌头,毛茸茸尾巴欢快摇晃。小女孩笑着举起黄红配色、带有蓝色按钮的玩具相机,将和狗狗的欢乐瞬间定格。" --base_seed 42 65 | ``` 66 | 67 | - Multi-GPU inference using FSDP + xDiT USP 68 | 69 | ``` sh 70 | pip install "xfuser>=0.4.1" 71 | torchrun --nproc_per_node=8 generate.py --task s2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --phantom_ckpt ./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth --ref_image "examples/ref3.png,examples/ref4.png" --dit_fsdp --t5_fsdp --ulysses_size 4 --ring_size 2 --prompt "夕阳下,一位有着小麦色肌肤、留着乌黑长发的女人穿上有着大朵立体花朵装饰、肩袖处带有飘逸纱带的红色纱裙,漫步在金色的海滩上,海风轻拂她的长发,画面唯美动人。" --base_seed 42 72 | ``` 73 | 74 | > 💡Note: 75 | > * Changing `--ref_image` can achieve single reference Subject-to-Video generation or multi-reference Subject-to-Video generation. The number of reference images should be within 4. 76 | > * To achieve the best generation results, we recommend that you describe the visual content of the reference image as accurately as possible when writing `--prompt`. For example, "examples/ref1.png" can be described as "a toy camera in yellow and red with blue buttons". 77 | > * When the generated video is unsatisfactory, the most straightforward solution is to try changing the `--base_seed` and modifying the description in the `--prompt`. 78 | 79 | For more inference examples, please refer to "infer.sh". 
You will get the following generated results:

| Reference Images | Generated Videos |
| --- | --- |
| ![ref1](examples/ref1.png) ![ref2](examples/ref2.png) | ![result1](examples/ref_results/result1.gif) |
| ![ref3](examples/ref3.png) ![ref4](examples/ref4.png) | ![result2](examples/ref_results/result2.gif) |
| ![ref5](examples/ref5.png) ![ref6](examples/ref6.png) ![ref7](examples/ref7.png) | ![result3](examples/ref_results/result3.gif) |
| ![ref8](examples/ref8.png) ![ref9](examples/ref9.png) ![ref10](examples/ref10.png) ![ref11](examples/ref11.png) | ![result4](examples/ref_results/result4.gif) |
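If you prefer to call the pipeline from Python rather than the command line, the snippet below is a minimal single-GPU sketch that mirrors what `generate.py` does for the `s2v-1.3B` task; it reuses `load_ref_images` from `generate.py` for reference preprocessing. The checkpoint paths, reference images, English prompt, and output filename are placeholders to adapt to your setup.

```python
import phantom_wan
from phantom_wan.configs import WAN_CONFIGS, SIZE_CONFIGS
from phantom_wan.utils.utils import cache_video
from generate import load_ref_images  # pads references to the target aspect ratio

cfg = WAN_CONFIGS["s2v-1.3B"]
size = SIZE_CONFIGS["832*480"]  # (width, height); the size supported by s2v-1.3B

# Reference images are given as a comma-separated string, at most 4 paths.
ref_images = load_ref_images("examples/ref1.png,examples/ref2.png", size)

pipe = phantom_wan.Phantom_Wan_S2V(
    config=cfg,
    checkpoint_dir="./Wan2.1-T2V-1.3B",                      # Wan2.1 base weights
    phantom_ckpt="./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth",  # Phantom-Wan weights
    device_id=0,
    rank=0,
    t5_fsdp=False,
    dit_fsdp=False,
    use_usp=False,
    t5_cpu=False,
)

video = pipe.generate(
    "A little girl photographs her dog with a yellow-and-red toy camera.",  # placeholder prompt
    ref_images,
    size=size,
    frame_num=81,          # must be 4n+1
    shift=5.0,
    sample_solver="unipc",
    sampling_steps=50,
    guide_scale_img=5.0,
    guide_scale_text=7.5,
    seed=42,
    offload_model=True,    # single-GPU default in generate.py
)

cache_video(
    tensor=video[None],
    save_file="s2v_demo.mp4",  # placeholder output path
    fps=cfg.sample_fps,
    nrow=1,
    normalize=True,
    value_range=(-1, 1),
)
```

The keyword arguments above follow the calls made in `generate.py`; the defaults (50 sampling steps, shift 5.0, guidance scales 5.0/7.5) are the ones `_validate_args` fills in for this task.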
135 | 136 | 137 | 138 | ## 🆚 Comparative Results 139 | - **Identity Preserving Video Generation**. 140 | ![image](./assets/images/id_eval.png) 141 | - **Single Reference Subject-to-Video Generation**. 142 | ![image](./assets/images/ip_eval_s.png) 143 | - **Multi-Reference Subject-to-Video Generation**. 144 | ![image](./assets/images/ip_eval_m_00.png) 145 | 146 | ## Acknowledgements 147 | We would like to express our gratitude to the SEED team for their support. Special thanks to Lu Jiang, Haoyuan Guo, Zhibei Ma, and Sen Wang for their assistance with the model and data. In addition, we are also very grateful to Siying Chen, Qingyang Li, and Wei Han for their help with the evaluation. 148 | 149 | ## BibTeX 150 | ```bibtex 151 | @article{liu2025phantom, 152 | title={Phantom: Subject-Consistent Video Generation via Cross-Modal Alignment}, 153 | author={Liu, Lijie and Ma, Tianxaing and Li, Bingchuan and Chen, Zhuowei and Liu, Jiawei and He, Qian and Wu, Xinglong}, 154 | journal={arXiv preprint arXiv:2502.11079}, 155 | year={2025} 156 | } 157 | ``` 158 | 159 | ## Star History 160 | [![Star History Chart](https://api.star-history.com/svg?repos=Phantom-video/Phantom&type=Date)](https://www.star-history.com/#Phantom-video/Phantom&Date) -------------------------------------------------------------------------------- /assets/Alegreya-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/Alegreya-Regular.ttf -------------------------------------------------------------------------------- /assets/Caveat-Medium.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/Caveat-Medium.ttf -------------------------------------------------------------------------------- /assets/DancingScript-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/DancingScript-Bold.ttf -------------------------------------------------------------------------------- /assets/background.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/background.png -------------------------------------------------------------------------------- /assets/icon-new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/icon-new.png -------------------------------------------------------------------------------- /assets/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/icon.png -------------------------------------------------------------------------------- /assets/images/id_eval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/images/id_eval.png -------------------------------------------------------------------------------- /assets/images/ip_eval_m_00.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/images/ip_eval_m_00.png -------------------------------------------------------------------------------- /assets/images/ip_eval_s.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/images/ip_eval_s.png -------------------------------------------------------------------------------- /assets/index.css: -------------------------------------------------------------------------------- 1 | .nopadding { 2 | padding: 0 !important; 3 | margin: 0 !important; 4 | } 5 | 6 | .flex { 7 | display: flex; 8 | } 9 | 10 | .paper-btn { 11 | position: relative; 12 | text-align: center; 13 | 14 | display: inline-block; 15 | margin: 8px; 16 | padding: 8px 8px; 17 | 18 | border-width: 0; 19 | outline: none; 20 | border-radius: 10px; 21 | 22 | background-color: #3e3e40; 23 | color: white !important; 24 | font-size: 16px; 25 | width: 200px; 26 | font-weight: 350; 27 | } 28 | .paper-btn-parent { 29 | display: flex; 30 | justify-content: center; 31 | margin: 16px 0px; 32 | } 33 | .paper-btn:hover { 34 | opacity: 0.80; 35 | } 36 | 37 | .video { 38 | width: 100%; 39 | height: auto; 40 | } 41 | 42 | .hover-overlay { 43 | position: absolute; 44 | top: 50%; 45 | left: 50%; 46 | transform: translate(-50%, -50%); 47 | width: 80%; 48 | height: auto; 49 | opacity: 1; 50 | display: flex; 51 | flex-direction: column; 52 | justify-content: center; 53 | align-items: center; 54 | font-family: Chalkduster; 55 | font-size: 16px; 56 | text-align: center; 57 | transition: opacity 0.3s ease-in-out; 58 | } 59 | 60 | div.scroll-container { 61 | background-color: #3e3e40; 62 | overflow: auto; 63 | white-space: nowrap; 64 | margin-top: 25px; 65 | margin-bottom: 25px; 66 | padding: 7px 7.5px 2.5px 9px; 67 | } 68 | 69 | .youtube-container { 70 | position: relative; 71 | width: 100%; 72 | padding-bottom: 56.25%; /* 16:9 aspect ratio */ 73 | height: 0; 74 | overflow: hidden; 75 | } 76 | .youtube-container iframe { 77 | position: absolute; 78 | top: 0; 79 | left: 0; 80 | width: 100%; 81 | height: 100%; 82 | } 83 | 84 | .image-set { 85 | width: 80% 86 | } 87 | .image-set img { 88 | cursor: pointer; 89 | border: 0px solid transparent; 90 | opacity: 0.5; 91 | height: min(90px, 6vw); 92 | transition: opacity 0.3s ease; /* Smooth transition */ 93 | } 94 | .image-set img.selected { 95 | opacity: 1.0; 96 | } 97 | .image-set img:hover { 98 | opacity: 1.0; 99 | } 100 | 101 | .show-neighbors { 102 | overflow: hidden; 103 | } 104 | 105 | .show-neighbors .carousel-indicators { 106 | margin-right: 15%; 107 | margin-left: 15%; 108 | } 109 | .show-neighbors .carousel-control-prev, 110 | .show-neighbors .carousel-control-next { 111 | height: 80%; 112 | width: 15%; 113 | z-index: 11; 114 | /* .carousel-caption has z-index 10 */ 115 | } 116 | .show-neighbors .carousel-inner { 117 | width: 200%; 118 | left: -50%; 119 | } 120 | .show-neighbors .carousel-item-next:not(.carousel-item-left), 121 | .show-neighbors .carousel-item-right.active { 122 | -webkit-transform: translate3d(33%, 0, 0); 123 | transform: translate3d(33%, 0, 0); 124 | } 125 | .show-neighbors .carousel-item-prev:not(.carousel-item-right), 126 | .show-neighbors .carousel-item-left.active { 127 | -webkit-transform: translate3d(-33%, 0, 0); 128 | transform: translate3d(-33%, 0, 0); 
129 | } 130 | .show-neighbors .item__third { 131 | display: block !important; 132 | float: left; 133 | position: relative; 134 | width: 33.333333333333%; 135 | padding: 0.5%; 136 | text-align: center; 137 | } 138 | 139 | .custom-button { 140 | color: white !important; 141 | text-decoration: none !important; 142 | background-color: #2f2f30; 143 | background-size: cover; 144 | background-repeat: no-repeat; 145 | background-position: center; 146 | border: none; 147 | cursor: pointer; 148 | opacity: 0.7; /* Initial opacity */ 149 | transition: opacity 0.3s ease; /* Smooth transition */ 150 | } 151 | .custom-button:hover { 152 | opacity: 1.0; /* Increased opacity on hover */ 153 | } 154 | 155 | .block-container { 156 | display: flex; 157 | flex-wrap: wrap; 158 | justify-content: center; 159 | gap: 0.5%; 160 | } 161 | 162 | @media screen and (max-width: 1024px) { 163 | .mobile-break { display: block; } 164 | } -------------------------------------------------------------------------------- /assets/index.js: -------------------------------------------------------------------------------- 1 | 2 | $('.carousel-item', '.show-neighbors').each(function(){ 3 | var next = $(this).next(); 4 | if (! next.length) { 5 | next = $(this).siblings(':first'); 6 | } 7 | next.children(':first-child').clone().appendTo($(this)); 8 | }).each(function(){ 9 | var prev = $(this).prev(); 10 | if (! prev.length) { 11 | prev = $(this).siblings(':last'); 12 | } 13 | prev.children(':nth-last-child(2)').clone().prependTo($(this)); 14 | }); 15 | -------------------------------------------------------------------------------- /assets/roboto.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/roboto.woff2 -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/teaser.png -------------------------------------------------------------------------------- /assets/videos/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/.DS_Store -------------------------------------------------------------------------------- /assets/videos/demo/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/demo/.DS_Store -------------------------------------------------------------------------------- /assets/videos/demo/demo1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/demo/demo1.mp4 -------------------------------------------------------------------------------- /assets/videos/demo/demo2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/demo/demo2.mp4 -------------------------------------------------------------------------------- /assets/videos/demo/demo3.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/demo/demo3.mp4 -------------------------------------------------------------------------------- /assets/videos/id/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/.DS_Store -------------------------------------------------------------------------------- /assets/videos/id/0_ref.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/0_ref.mov -------------------------------------------------------------------------------- /assets/videos/id/1_ref.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/1_ref.mov -------------------------------------------------------------------------------- /assets/videos/id/2_ref.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/2_ref.mov -------------------------------------------------------------------------------- /assets/videos/id/4_ref.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/4_ref.mov -------------------------------------------------------------------------------- /assets/videos/id/5_ref.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/5_ref.mov -------------------------------------------------------------------------------- /assets/videos/id/6_ref.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/6_ref.mov -------------------------------------------------------------------------------- /assets/videos/id/7_ref.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/7_ref.mov -------------------------------------------------------------------------------- /assets/videos/id/8_ref.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/8_ref.mov -------------------------------------------------------------------------------- /assets/videos/id/9_ref.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/9_ref.mov -------------------------------------------------------------------------------- /assets/videos/multi_ip/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/.DS_Store -------------------------------------------------------------------------------- /assets/videos/multi_ip/0.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/0.mov -------------------------------------------------------------------------------- /assets/videos/multi_ip/1.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/1.mov -------------------------------------------------------------------------------- /assets/videos/multi_ip/10.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/10.mov -------------------------------------------------------------------------------- /assets/videos/multi_ip/2.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/2.mov -------------------------------------------------------------------------------- /assets/videos/multi_ip/4.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/4.mov -------------------------------------------------------------------------------- /assets/videos/multi_ip/5.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/5.mov -------------------------------------------------------------------------------- /assets/videos/multi_ip/7.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/7.mov -------------------------------------------------------------------------------- /assets/videos/multi_ip/8.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/8.mov -------------------------------------------------------------------------------- /assets/videos/multi_ip/9.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/9.mov -------------------------------------------------------------------------------- /assets/videos/single_ip/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/.DS_Store -------------------------------------------------------------------------------- /assets/videos/single_ip/0.mov: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/0.mov -------------------------------------------------------------------------------- /assets/videos/single_ip/1.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/1.mov -------------------------------------------------------------------------------- /assets/videos/single_ip/2.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/2.mov -------------------------------------------------------------------------------- /assets/videos/single_ip/3.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/3.mov -------------------------------------------------------------------------------- /assets/videos/single_ip/4.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/4.mov -------------------------------------------------------------------------------- /assets/videos/single_ip/5.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/5.mov -------------------------------------------------------------------------------- /assets/videos/single_ip/6.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/6.mov -------------------------------------------------------------------------------- /assets/videos/single_ip/7.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/7.mov -------------------------------------------------------------------------------- /assets/videos/single_ip/8.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/8.mov -------------------------------------------------------------------------------- /examples/ref1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref1.png -------------------------------------------------------------------------------- /examples/ref10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref10.png -------------------------------------------------------------------------------- /examples/ref11.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref11.png -------------------------------------------------------------------------------- /examples/ref2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref2.png -------------------------------------------------------------------------------- /examples/ref3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref3.png -------------------------------------------------------------------------------- /examples/ref4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref4.png -------------------------------------------------------------------------------- /examples/ref5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref5.png -------------------------------------------------------------------------------- /examples/ref6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref6.png -------------------------------------------------------------------------------- /examples/ref7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref7.png -------------------------------------------------------------------------------- /examples/ref8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref8.png -------------------------------------------------------------------------------- /examples/ref9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref9.png -------------------------------------------------------------------------------- /examples/ref_results/result1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref_results/result1.gif -------------------------------------------------------------------------------- /examples/ref_results/result2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref_results/result2.gif -------------------------------------------------------------------------------- /examples/ref_results/result3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref_results/result3.gif -------------------------------------------------------------------------------- 
/examples/ref_results/result4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref_results/result4.gif -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | from datetime import datetime 17 | import logging 18 | import os 19 | import sys 20 | import warnings 21 | 22 | warnings.filterwarnings('ignore') 23 | 24 | import torch, random 25 | import torch.distributed as dist 26 | from PIL import Image, ImageOps 27 | 28 | import phantom_wan 29 | from phantom_wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES 30 | from phantom_wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander 31 | from phantom_wan.utils.utils import cache_video, cache_image, str2bool 32 | 33 | EXAMPLE_PROMPT = { 34 | "t2v-1.3B": { 35 | "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", 36 | }, 37 | "t2v-14B": { 38 | "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", 39 | }, 40 | "t2i-14B": { 41 | "prompt": "一个朴素端庄的美人", 42 | }, 43 | "i2v-14B": { 44 | "prompt": 45 | "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.", 46 | "image": 47 | "examples/i2v_input.JPG", 48 | }, 49 | } 50 | 51 | 52 | def _validate_args(args): 53 | # Basic check 54 | assert args.ckpt_dir is not None, "Please specify the checkpoint directory." 55 | assert args.phantom_ckpt is not None, "Please specify the Phantom-Wan checkpoint." 56 | assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}" 57 | 58 | # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks. 59 | if args.sample_steps is None: 60 | args.sample_steps = 40 if "i2v" in args.task else 50 61 | 62 | if args.sample_shift is None: 63 | args.sample_shift = 5.0 64 | if "i2v" in args.task and args.size in ["832*480", "480*832"]: 65 | args.sample_shift = 3.0 66 | 67 | # The default number of frames are 1 for text-to-image tasks and 81 for other tasks. 
68 | if args.frame_num is None: 69 | args.frame_num = 1 if "t2i" in args.task else 81 70 | 71 | # T2I frame_num check 72 | if "t2i" in args.task: 73 | assert args.frame_num == 1, f"Unsupport frame_num {args.frame_num} for task {args.task}" 74 | 75 | args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint( 76 | 0, sys.maxsize) 77 | # Size check 78 | assert args.size in SUPPORTED_SIZES[ 79 | args. 80 | task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}" 81 | 82 | 83 | def _parse_args(): 84 | parser = argparse.ArgumentParser( 85 | description="Generate a image or video from a text prompt or image using Wan" 86 | ) 87 | parser.add_argument( 88 | "--task", 89 | type=str, 90 | default="s2v-1.3B", 91 | choices=list(WAN_CONFIGS.keys()), 92 | help="The task to run.") 93 | parser.add_argument( 94 | "--size", 95 | type=str, 96 | default="1280*720", 97 | choices=list(SIZE_CONFIGS.keys()), 98 | help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image." 99 | ) 100 | parser.add_argument( 101 | "--frame_num", 102 | type=int, 103 | default=None, 104 | help="How many frames to sample from a image or video. The number should be 4n+1" 105 | ) 106 | parser.add_argument( 107 | "--ckpt_dir", 108 | type=str, 109 | default=None, 110 | help="The path to the checkpoint directory.") 111 | parser.add_argument( 112 | "--phantom_ckpt", 113 | type=str, 114 | default=None, 115 | help="The path to the Phantom-Wan checkpoint.") 116 | parser.add_argument( 117 | "--offload_model", 118 | type=str2bool, 119 | default=None, 120 | help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage." 
121 | ) 122 | parser.add_argument( 123 | "--ulysses_size", 124 | type=int, 125 | default=1, 126 | help="The size of the ulysses parallelism in DiT.") 127 | parser.add_argument( 128 | "--ring_size", 129 | type=int, 130 | default=1, 131 | help="The size of the ring attention parallelism in DiT.") 132 | parser.add_argument( 133 | "--t5_fsdp", 134 | action="store_true", 135 | default=False, 136 | help="Whether to use FSDP for T5.") 137 | parser.add_argument( 138 | "--t5_cpu", 139 | action="store_true", 140 | default=False, 141 | help="Whether to place T5 model on CPU.") 142 | parser.add_argument( 143 | "--dit_fsdp", 144 | action="store_true", 145 | default=False, 146 | help="Whether to use FSDP for DiT.") 147 | parser.add_argument( 148 | "--save_file", 149 | type=str, 150 | default=None, 151 | help="The file to save the generated image or video to.") 152 | parser.add_argument( 153 | "--prompt", 154 | type=str, 155 | default=None, 156 | help="The prompt to generate the image or video from.") 157 | parser.add_argument( 158 | "--use_prompt_extend", 159 | action="store_true", 160 | default=False, 161 | help="Whether to use prompt extend.") 162 | parser.add_argument( 163 | "--prompt_extend_method", 164 | type=str, 165 | default="local_qwen", 166 | choices=["dashscope", "local_qwen"], 167 | help="The prompt extend method to use.") 168 | parser.add_argument( 169 | "--prompt_extend_model", 170 | type=str, 171 | default=None, 172 | help="The prompt extend model to use.") 173 | parser.add_argument( 174 | "--prompt_extend_target_lang", 175 | type=str, 176 | default="ch", 177 | choices=["ch", "en"], 178 | help="The target language of prompt extend.") 179 | parser.add_argument( 180 | "--base_seed", 181 | type=int, 182 | default=-1, 183 | help="The seed to use for generating the image or video.") 184 | parser.add_argument( 185 | "--image", 186 | type=str, 187 | default=None, 188 | help="The image to generate the video from.") 189 | parser.add_argument( 190 | "--ref_image", 191 | type=str, 192 | default=None, 193 | help="The reference images used by Phantom-Wan.") 194 | parser.add_argument( 195 | "--sample_solver", 196 | type=str, 197 | default='unipc', 198 | choices=['unipc', 'dpm++'], 199 | help="The solver used to sample.") 200 | parser.add_argument( 201 | "--sample_steps", type=int, default=None, help="The sampling steps.") 202 | parser.add_argument( 203 | "--sample_shift", 204 | type=float, 205 | default=None, 206 | help="Sampling shift factor for flow matching schedulers.") 207 | parser.add_argument( 208 | "--sample_guide_scale", 209 | type=float, 210 | default=5.0, 211 | help="Classifier free guidance scale.") 212 | parser.add_argument( 213 | "--sample_guide_scale_img", 214 | type=float, 215 | default=5.0, 216 | help="Classifier free guidance scale for reference images.") 217 | parser.add_argument( 218 | "--sample_guide_scale_text", 219 | type=float, 220 | default=7.5, 221 | help="Classifier free guidance scale for text.") 222 | 223 | args = parser.parse_args() 224 | 225 | _validate_args(args) 226 | 227 | return args 228 | 229 | 230 | def _init_logging(rank): 231 | # logging 232 | if rank == 0: 233 | # set format 234 | logging.basicConfig( 235 | level=logging.INFO, 236 | format="[%(asctime)s] %(levelname)s: %(message)s", 237 | handlers=[logging.StreamHandler(stream=sys.stdout)]) 238 | else: 239 | logging.basicConfig(level=logging.ERROR) 240 | 241 | 242 | def load_ref_images(path, size): 243 | # Load size. 244 | h, w = size[1], size[0] 245 | # Load images. 
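# Each reference is resized to fit inside (w, h) while preserving its aspect
# ratio, then centered and padded with white so every reference is exactly (w, h).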
246 | ref_paths = path.split(",") 247 | ref_images = [] 248 | for image_path in ref_paths: 249 | with Image.open(image_path) as img: 250 | img = img.convert("RGB") 251 | 252 | # Calculate the required size to keep aspect ratio and fill the rest with padding. 253 | img_ratio = img.width / img.height 254 | target_ratio = w / h 255 | 256 | if img_ratio > target_ratio: # Image is wider than target 257 | new_width = w 258 | new_height = int(new_width / img_ratio) 259 | else: # Image is taller than target 260 | new_height = h 261 | new_width = int(new_height * img_ratio) 262 | 263 | img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) 264 | 265 | # Create a new image with the target size and place the resized image in the center 266 | delta_w = w - img.size[0] 267 | delta_h = h - img.size[1] 268 | padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2)) 269 | new_img = ImageOps.expand(img, padding, fill=(255, 255, 255)) 270 | ref_images.append(new_img) 271 | 272 | return ref_images 273 | 274 | 275 | def generate(args): 276 | rank = int(os.getenv("RANK", 0)) 277 | world_size = int(os.getenv("WORLD_SIZE", 1)) 278 | local_rank = int(os.getenv("LOCAL_RANK", 0)) 279 | device = local_rank 280 | _init_logging(rank) 281 | 282 | if args.offload_model is None: 283 | args.offload_model = False if world_size > 1 else True 284 | logging.info( 285 | f"offload_model is not specified, set to {args.offload_model}.") 286 | if world_size > 1: 287 | torch.cuda.set_device(local_rank) 288 | dist.init_process_group( 289 | backend="nccl", 290 | init_method="env://", 291 | rank=rank, 292 | world_size=world_size) 293 | else: 294 | assert not ( 295 | args.t5_fsdp or args.dit_fsdp 296 | ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments." 297 | assert not ( 298 | args.ulysses_size > 1 or args.ring_size > 1 299 | ), f"context parallel are not supported in non-distributed environments." 300 | 301 | if args.ulysses_size > 1 or args.ring_size > 1: 302 | assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size." 303 | from xfuser.core.distributed import (initialize_model_parallel, 304 | init_distributed_environment) 305 | init_distributed_environment( 306 | rank=dist.get_rank(), world_size=dist.get_world_size()) 307 | 308 | initialize_model_parallel( 309 | sequence_parallel_degree=dist.get_world_size(), 310 | ring_degree=args.ring_size, 311 | ulysses_degree=args.ulysses_size, 312 | ) 313 | 314 | if args.use_prompt_extend: 315 | if args.prompt_extend_method == "dashscope": 316 | prompt_expander = DashScopePromptExpander( 317 | model_name=args.prompt_extend_model, is_vl="i2v" in args.task) 318 | elif args.prompt_extend_method == "local_qwen": 319 | prompt_expander = QwenPromptExpander( 320 | model_name=args.prompt_extend_model, 321 | is_vl="i2v" in args.task, 322 | device=rank) 323 | else: 324 | raise NotImplementedError( 325 | f"Unsupport prompt_extend_method: {args.prompt_extend_method}") 326 | 327 | cfg = WAN_CONFIGS[args.task] 328 | if args.ulysses_size > 1: 329 | assert cfg.num_heads % args.ulysses_size == 0, f"`num_heads` must be divisible by `ulysses_size`." 
330 | 331 | logging.info(f"Generation job args: {args}") 332 | logging.info(f"Generation model config: {cfg}") 333 | 334 | if dist.is_initialized(): 335 | base_seed = [args.base_seed] if rank == 0 else [None] 336 | dist.broadcast_object_list(base_seed, src=0) 337 | args.base_seed = base_seed[0] 338 | 339 | if "s2v" in args.task: 340 | 341 | ref_images = load_ref_images(args.ref_image, SIZE_CONFIGS[args.size]) 342 | 343 | logging.info("Creating Phantom-Wan pipeline.") 344 | wan_s2v = phantom_wan.Phantom_Wan_S2V( 345 | config=cfg, 346 | checkpoint_dir=args.ckpt_dir, 347 | phantom_ckpt=args.phantom_ckpt, 348 | device_id=device, 349 | rank=rank, 350 | t5_fsdp=args.t5_fsdp, 351 | dit_fsdp=args.dit_fsdp, 352 | use_usp=(args.ulysses_size > 1 or args.ring_size > 1), 353 | t5_cpu=args.t5_cpu, 354 | ) 355 | 356 | logging.info( 357 | f"Generating {'image' if 't2i' in args.task else 'video'} ...") 358 | video = wan_s2v.generate( 359 | args.prompt, 360 | ref_images, 361 | size=SIZE_CONFIGS[args.size], 362 | frame_num=args.frame_num, 363 | shift=args.sample_shift, 364 | sample_solver=args.sample_solver, 365 | sampling_steps=args.sample_steps, 366 | guide_scale_img=args.sample_guide_scale_img, 367 | guide_scale_text=args.sample_guide_scale_text, 368 | seed=args.base_seed, 369 | offload_model=args.offload_model) 370 | 371 | elif "t2v" in args.task or "t2i" in args.task: 372 | if args.prompt is None: 373 | args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] 374 | logging.info(f"Input prompt: {args.prompt}") 375 | if args.use_prompt_extend: 376 | logging.info("Extending prompt ...") 377 | if rank == 0: 378 | prompt_output = prompt_expander( 379 | args.prompt, 380 | tar_lang=args.prompt_extend_target_lang, 381 | seed=args.base_seed) 382 | if prompt_output.status == False: 383 | logging.info( 384 | f"Extending prompt failed: {prompt_output.message}") 385 | logging.info("Falling back to original prompt.") 386 | input_prompt = args.prompt 387 | else: 388 | input_prompt = prompt_output.prompt 389 | input_prompt = [input_prompt] 390 | else: 391 | input_prompt = [None] 392 | if dist.is_initialized(): 393 | dist.broadcast_object_list(input_prompt, src=0) 394 | args.prompt = input_prompt[0] 395 | logging.info(f"Extended prompt: {args.prompt}") 396 | 397 | logging.info("Creating Phantom-Wan pipeline.") 398 | wan_t2v = phantom_wan.WanT2V( 399 | config=cfg, 400 | checkpoint_dir=args.ckpt_dir, 401 | device_id=device, 402 | rank=rank, 403 | t5_fsdp=args.t5_fsdp, 404 | dit_fsdp=args.dit_fsdp, 405 | use_usp=(args.ulysses_size > 1 or args.ring_size > 1), 406 | t5_cpu=args.t5_cpu, 407 | ) 408 | 409 | logging.info( 410 | f"Generating {'image' if 't2i' in args.task else 'video'} ...") 411 | video = wan_t2v.generate( 412 | args.prompt, 413 | size=SIZE_CONFIGS[args.size], 414 | frame_num=args.frame_num, 415 | shift=args.sample_shift, 416 | sample_solver=args.sample_solver, 417 | sampling_steps=args.sample_steps, 418 | guide_scale=args.sample_guide_scale, 419 | seed=args.base_seed, 420 | offload_model=args.offload_model) 421 | 422 | else: 423 | if args.prompt is None: 424 | args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] 425 | if args.image is None: 426 | args.image = EXAMPLE_PROMPT[args.task]["image"] 427 | logging.info(f"Input prompt: {args.prompt}") 428 | logging.info(f"Input image: {args.image}") 429 | 430 | img = Image.open(args.image).convert("RGB") 431 | if args.use_prompt_extend: 432 | logging.info("Extending prompt ...") 433 | if rank == 0: 434 | prompt_output = prompt_expander( 435 | args.prompt, 436 | 
tar_lang=args.prompt_extend_target_lang, 437 | image=img, 438 | seed=args.base_seed) 439 | if prompt_output.status == False: 440 | logging.info( 441 | f"Extending prompt failed: {prompt_output.message}") 442 | logging.info("Falling back to original prompt.") 443 | input_prompt = args.prompt 444 | else: 445 | input_prompt = prompt_output.prompt 446 | input_prompt = [input_prompt] 447 | else: 448 | input_prompt = [None] 449 | if dist.is_initialized(): 450 | dist.broadcast_object_list(input_prompt, src=0) 451 | args.prompt = input_prompt[0] 452 | logging.info(f"Extended prompt: {args.prompt}") 453 | 454 | logging.info("Creating WanI2V pipeline.") 455 | wan_i2v = phantom_wan.WanI2V( 456 | config=cfg, 457 | checkpoint_dir=args.ckpt_dir, 458 | device_id=device, 459 | rank=rank, 460 | t5_fsdp=args.t5_fsdp, 461 | dit_fsdp=args.dit_fsdp, 462 | use_usp=(args.ulysses_size > 1 or args.ring_size > 1), 463 | t5_cpu=args.t5_cpu, 464 | ) 465 | 466 | logging.info("Generating video ...") 467 | video = wan_i2v.generate( 468 | args.prompt, 469 | img, 470 | max_area=MAX_AREA_CONFIGS[args.size], 471 | frame_num=args.frame_num, 472 | shift=args.sample_shift, 473 | sample_solver=args.sample_solver, 474 | sampling_steps=args.sample_steps, 475 | guide_scale=args.sample_guide_scale, 476 | seed=args.base_seed, 477 | offload_model=args.offload_model) 478 | 479 | if rank == 0: 480 | if args.save_file is None: 481 | formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") 482 | formatted_prompt = args.prompt.replace(" ", "_").replace("/", 483 | "_")[:50] 484 | suffix = '.png' if "t2i" in args.task else '.mp4' 485 | args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix 486 | 487 | if "t2i" in args.task: 488 | logging.info(f"Saving generated image to {args.save_file}") 489 | cache_image( 490 | tensor=video.squeeze(1)[None], 491 | save_file=args.save_file, 492 | nrow=1, 493 | normalize=True, 494 | value_range=(-1, 1)) 495 | else: 496 | logging.info(f"Saving generated video to {args.save_file}") 497 | cache_video( 498 | tensor=video[None], 499 | save_file=args.save_file, 500 | fps=cfg.sample_fps, 501 | nrow=1, 502 | normalize=True, 503 | value_range=(-1, 1)) 504 | logging.info("Finished.") 505 | 506 | 507 | if __name__ == "__main__": 508 | args = _parse_args() 509 | generate(args) 510 | -------------------------------------------------------------------------------- /infer.sh: -------------------------------------------------------------------------------- 1 | torchrun \ 2 | --node_rank=0 \ 3 | --nnodes=1 \ 4 | --rdzv_endpoint=127.0.0.1:23468 \ 5 | --nproc_per_node=8 generate.py --task s2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --phantom_ckpt ./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth --ref_image "examples/ref1.png,examples/ref2.png" --dit_fsdp --t5_fsdp --ulysses_size 4 --ring_size 2 --prompt "暖阳漫过草地,扎着双马尾、头戴绿色蝴蝶结、身穿浅绿色连衣裙的小女孩蹲在盛开的雏菊旁。她身旁一只棕白相间的狗狗吐着舌头,毛茸茸尾巴欢快摇晃。小女孩笑着举起黄红配色、带有蓝色按钮的玩具相机,将和狗狗的欢乐瞬间定格。" --base_seed 42 6 | 7 | torchrun \ 8 | --node_rank=0 \ 9 | --nnodes=1 \ 10 | --rdzv_endpoint=127.0.0.1:23468 \ 11 | --nproc_per_node=8 generate.py --task s2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --phantom_ckpt ./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth --ref_image "examples/ref3.png,examples/ref4.png" --dit_fsdp --t5_fsdp --ulysses_size 4 --ring_size 2 --prompt "夕阳下,一位有着小麦色肌肤、留着乌黑长发的女人穿上有着大朵立体花朵装饰、肩袖处带有飘逸纱带的红色纱裙,漫步在金色的海滩上,海风轻拂她的长发,画面唯美动人。" --base_seed 42 12 | 13 | torchrun \ 14 | --node_rank=0 \ 15 | --nnodes=1 \ 16 | 
--rdzv_endpoint=127.0.0.1:23468 \ 17 | --nproc_per_node=8 generate.py --task s2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --phantom_ckpt ./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth --ref_image "examples/ref5.png,examples/ref6.png,examples/ref7.png" --dit_fsdp --t5_fsdp --ulysses_size 4 --ring_size 2 --prompt "在被冰雪覆盖,周围盛开着粉色花朵,有蝴蝶飞舞,屋内透出暖黄色灯光的梦幻小屋场景下,一位头发灰白、穿着深绿色上衣的老人牵着梳着双丸子头、身着中式传统服饰、外披白色毛绒衣物的小女孩的手,缓缓前行,画面温馨宁静。" --base_seed 42 18 | 19 | torchrun \ 20 | --node_rank=0 \ 21 | --nnodes=1 \ 22 | --rdzv_endpoint=127.0.0.1:23468 \ 23 | --nproc_per_node=8 generate.py --task s2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --phantom_ckpt ./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth --ref_image "examples/ref8.png,examples/ref9.png,examples/ref10.png,examples/ref11.png" --dit_fsdp --t5_fsdp --ulysses_size 4 --ring_size 2 --prompt "一位金色长发的女人身穿棕色带波点网纱长袖、胸前系带设计的泳衣,手持一杯有橙色切片和草莓装饰、插着绿色吸管的分层鸡尾酒,坐在有着棕榈树、铺有蓝白条纹毯子和灰色垫子、摆放着躺椅的沙滩上晒日光浴的慢镜头,捕捉她享受阳光的微笑与海浪轻抚沙滩的美景。" --base_seed 42 -------------------------------------------------------------------------------- /phantom_wan/__init__.py: -------------------------------------------------------------------------------- 1 | from . import configs, distributed, modules 2 | from .image2video import WanI2V 3 | from .text2video import WanT2V 4 | from .subject2video import Phantom_Wan_S2V 5 | -------------------------------------------------------------------------------- /phantom_wan/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import copy 3 | import os 4 | 5 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 6 | 7 | from .wan_s2v_1_3B import s2v_1_3B 8 | from .wan_i2v_14B import i2v_14B 9 | from .wan_t2v_1_3B import t2v_1_3B 10 | from .wan_t2v_14B import t2v_14B 11 | 12 | # the config of t2i_14B is the same as t2v_14B 13 | t2i_14B = copy.deepcopy(t2v_14B) 14 | t2i_14B.__name__ = 'Config: Wan T2I 14B' 15 | 16 | WAN_CONFIGS = { 17 | 's2v-1.3B': s2v_1_3B, 18 | 't2v-14B': t2v_14B, 19 | 't2v-1.3B': t2v_1_3B, 20 | 'i2v-14B': i2v_14B, 21 | 't2i-14B': t2i_14B, 22 | } 23 | 24 | SIZE_CONFIGS = { 25 | '720*1280': (720, 1280), 26 | '1280*720': (1280, 720), 27 | '480*832': (480, 832), 28 | '832*480': (832, 480), 29 | '1024*1024': (1024, 1024), 30 | } 31 | 32 | MAX_AREA_CONFIGS = { 33 | '720*1280': 720 * 1280, 34 | '1280*720': 1280 * 720, 35 | '480*832': 480 * 832, 36 | '832*480': 832 * 480, 37 | } 38 | 39 | SUPPORTED_SIZES = { 40 | 's2v-1.3B': ('832*480',), 41 | 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 42 | 't2v-1.3B': ('480*832', '832*480'), 43 | 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 44 | 't2i-14B': tuple(SIZE_CONFIGS.keys()), 45 | } 46 | -------------------------------------------------------------------------------- /phantom_wan/configs/shared_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
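For reference, the WAN_CONFIGS / SIZE_CONFIGS / SUPPORTED_SIZES tables defined in configs/__init__.py above are what connect the CLI flags used in infer.sh to a concrete model config. A minimal lookup sketch (the resolve_task helper is illustrative, not part of the repo):

from phantom_wan.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS

def resolve_task(task, size):
    # Validate a --task / --size pair against the tables before indexing them.
    assert task in WAN_CONFIGS, f"unknown task: {task}"
    assert size in SUPPORTED_SIZES[task], f"size {size} is not supported for task {task}"
    return WAN_CONFIGS[task], SIZE_CONFIGS[size]

cfg, size = resolve_task("s2v-1.3B", "832*480")  # the combination used throughout infer.sh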
2 | import torch 3 | from easydict import EasyDict 4 | 5 | #------------------------ Wan shared config ------------------------# 6 | wan_shared_cfg = EasyDict() 7 | 8 | # t5 9 | wan_shared_cfg.t5_model = 'umt5_xxl' 10 | wan_shared_cfg.t5_dtype = torch.bfloat16 11 | wan_shared_cfg.text_len = 512 12 | 13 | # transformer 14 | wan_shared_cfg.param_dtype = torch.bfloat16 15 | 16 | # inference 17 | wan_shared_cfg.num_train_timesteps = 1000 18 | wan_shared_cfg.sample_fps = 16 19 | wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走' 20 | -------------------------------------------------------------------------------- /phantom_wan/configs/wan_i2v_14B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import torch 3 | from easydict import EasyDict 4 | 5 | from .shared_config import wan_shared_cfg 6 | 7 | #------------------------ Wan I2V 14B ------------------------# 8 | 9 | i2v_14B = EasyDict(__name__='Config: Wan I2V 14B') 10 | i2v_14B.update(wan_shared_cfg) 11 | 12 | i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 13 | i2v_14B.t5_tokenizer = 'google/umt5-xxl' 14 | 15 | # clip 16 | i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14' 17 | i2v_14B.clip_dtype = torch.float16 18 | i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth' 19 | i2v_14B.clip_tokenizer = 'xlm-roberta-large' 20 | 21 | # vae 22 | i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' 23 | i2v_14B.vae_stride = (4, 8, 8) 24 | 25 | # transformer 26 | i2v_14B.patch_size = (1, 2, 2) 27 | i2v_14B.dim = 5120 28 | i2v_14B.ffn_dim = 13824 29 | i2v_14B.freq_dim = 256 30 | i2v_14B.num_heads = 40 31 | i2v_14B.num_layers = 40 32 | i2v_14B.window_size = (-1, -1) 33 | i2v_14B.qk_norm = True 34 | i2v_14B.cross_attn_norm = True 35 | i2v_14B.eps = 1e-6 36 | -------------------------------------------------------------------------------- /phantom_wan/configs/wan_s2v_1_3B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from easydict import EasyDict 16 | 17 | from .shared_config import wan_shared_cfg 18 | 19 | #------------------------ Phantom-Wan S2V 1.3B ------------------------# 20 | 21 | s2v_1_3B = EasyDict(__name__='Config: Phantom-Wan S2V 1.3B') 22 | s2v_1_3B.update(wan_shared_cfg) 23 | 24 | # t5 25 | s2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 26 | s2v_1_3B.t5_tokenizer = 'google/umt5-xxl' 27 | 28 | # vae 29 | s2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth' 30 | s2v_1_3B.vae_stride = (4, 8, 8) 31 | 32 | # transformer 33 | s2v_1_3B.patch_size = (1, 2, 2) 34 | s2v_1_3B.dim = 1536 35 | s2v_1_3B.ffn_dim = 8960 36 | s2v_1_3B.freq_dim = 256 37 | s2v_1_3B.num_heads = 12 38 | s2v_1_3B.num_layers = 30 39 | s2v_1_3B.window_size = (-1, -1) 40 | s2v_1_3B.qk_norm = True 41 | s2v_1_3B.cross_attn_norm = True 42 | s2v_1_3B.eps = 1e-6 43 | -------------------------------------------------------------------------------- /phantom_wan/configs/wan_t2v_14B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | from easydict import EasyDict 3 | 4 | from .shared_config import wan_shared_cfg 5 | 6 | #------------------------ Wan T2V 14B ------------------------# 7 | 8 | t2v_14B = EasyDict(__name__='Config: Wan T2V 14B') 9 | t2v_14B.update(wan_shared_cfg) 10 | 11 | # t5 12 | t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 13 | t2v_14B.t5_tokenizer = 'google/umt5-xxl' 14 | 15 | # vae 16 | t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' 17 | t2v_14B.vae_stride = (4, 8, 8) 18 | 19 | # transformer 20 | t2v_14B.patch_size = (1, 2, 2) 21 | t2v_14B.dim = 5120 22 | t2v_14B.ffn_dim = 13824 23 | t2v_14B.freq_dim = 256 24 | t2v_14B.num_heads = 40 25 | t2v_14B.num_layers = 40 26 | t2v_14B.window_size = (-1, -1) 27 | t2v_14B.qk_norm = True 28 | t2v_14B.cross_attn_norm = True 29 | t2v_14B.eps = 1e-6 30 | -------------------------------------------------------------------------------- /phantom_wan/configs/wan_t2v_1_3B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | from easydict import EasyDict 3 | 4 | from .shared_config import wan_shared_cfg 5 | 6 | #------------------------ Wan T2V 1.3B ------------------------# 7 | 8 | t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B') 9 | t2v_1_3B.update(wan_shared_cfg) 10 | 11 | # t5 12 | t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 13 | t2v_1_3B.t5_tokenizer = 'google/umt5-xxl' 14 | 15 | # vae 16 | t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth' 17 | t2v_1_3B.vae_stride = (4, 8, 8) 18 | 19 | # transformer 20 | t2v_1_3B.patch_size = (1, 2, 2) 21 | t2v_1_3B.dim = 1536 22 | t2v_1_3B.ffn_dim = 8960 23 | t2v_1_3B.freq_dim = 256 24 | t2v_1_3B.num_heads = 12 25 | t2v_1_3B.num_layers = 30 26 | t2v_1_3B.window_size = (-1, -1) 27 | t2v_1_3B.qk_norm = True 28 | t2v_1_3B.cross_attn_norm = True 29 | t2v_1_3B.eps = 1e-6 30 | -------------------------------------------------------------------------------- /phantom_wan/distributed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/phantom_wan/distributed/__init__.py -------------------------------------------------------------------------------- /phantom_wan/distributed/fsdp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | from functools import partial 3 | 4 | import torch 5 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 6 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy 7 | from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy 8 | 9 | 10 | def shard_model( 11 | model, 12 | device_id, 13 | param_dtype=torch.bfloat16, 14 | reduce_dtype=torch.float32, 15 | buffer_dtype=torch.float32, 16 | process_group=None, 17 | sharding_strategy=ShardingStrategy.FULL_SHARD, 18 | sync_module_states=True, 19 | ): 20 | model = FSDP( 21 | module=model, 22 | process_group=process_group, 23 | sharding_strategy=sharding_strategy, 24 | auto_wrap_policy=partial( 25 | lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks), 26 | mixed_precision=MixedPrecision( 27 | param_dtype=param_dtype, 28 | reduce_dtype=reduce_dtype, 29 | buffer_dtype=buffer_dtype), 30 | device_id=device_id, 31 | sync_module_states=sync_module_states) 32 | return model 33 | -------------------------------------------------------------------------------- /phantom_wan/distributed/xdit_context_parallel.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import torch 3 | import torch.cuda.amp as amp 4 | from xfuser.core.distributed import (get_sequence_parallel_rank, 5 | get_sequence_parallel_world_size, 6 | get_sp_group) 7 | from xfuser.core.long_ctx_attention import xFuserLongContextAttention 8 | 9 | from ..modules.model import sinusoidal_embedding_1d 10 | 11 | 12 | def pad_freqs(original_tensor, target_len): 13 | seq_len, s1, s2 = original_tensor.shape 14 | pad_size = target_len - seq_len 15 | padding_tensor = torch.ones( 16 | pad_size, 17 | s1, 18 | s2, 19 | dtype=original_tensor.dtype, 20 | device=original_tensor.device) 21 | padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) 22 | return padded_tensor 23 | 24 | 25 | @amp.autocast(enabled=False) 26 | def rope_apply(x, grid_sizes, freqs): 27 | """ 28 | x: [B, L, N, C]. 29 | grid_sizes: [B, 3]. 30 | freqs: [M, C // 2]. 
31 | """ 32 | s, n, c = x.size(1), x.size(2), x.size(3) // 2 33 | # split freqs 34 | freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) 35 | 36 | # loop over samples 37 | output = [] 38 | for i, (f, h, w) in enumerate(grid_sizes.tolist()): 39 | seq_len = f * h * w 40 | 41 | # precompute multipliers 42 | x_i = torch.view_as_complex(x[i, :s].to(torch.float32).reshape( 43 | s, n, -1, 2)) 44 | freqs_i = torch.cat([ 45 | freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), 46 | freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), 47 | freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) 48 | ], 49 | dim=-1).reshape(seq_len, 1, -1) 50 | 51 | # apply rotary embedding 52 | sp_size = get_sequence_parallel_world_size() 53 | sp_rank = get_sequence_parallel_rank() 54 | freqs_i = pad_freqs(freqs_i, s * sp_size) 55 | s_per_rank = s 56 | freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * 57 | s_per_rank), :, :] 58 | x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2) 59 | x_i = torch.cat([x_i, x[i, s:]]) 60 | 61 | # append to collection 62 | output.append(x_i) 63 | return torch.stack(output).float() 64 | 65 | 66 | def usp_dit_forward( 67 | self, 68 | x, 69 | t, 70 | context, 71 | seq_len, 72 | clip_fea=None, 73 | y=None, 74 | ): 75 | """ 76 | x: A list of videos each with shape [C, T, H, W]. 77 | t: [B]. 78 | context: A list of text embeddings each with shape [L, C]. 79 | """ 80 | if self.model_type == 'i2v': 81 | assert clip_fea is not None and y is not None 82 | # params 83 | device = self.patch_embedding.weight.device 84 | if self.freqs.device != device: 85 | self.freqs = self.freqs.to(device) 86 | 87 | if y is not None: 88 | x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] 89 | 90 | # embeddings 91 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x] 92 | grid_sizes = torch.stack( 93 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) 94 | x = [u.flatten(2).transpose(1, 2) for u in x] 95 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) 96 | assert seq_lens.max() <= seq_len 97 | x = torch.cat([ 98 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) 99 | for u in x 100 | ]) 101 | 102 | # time embeddings 103 | with amp.autocast(dtype=torch.float32): 104 | e = self.time_embedding( 105 | sinusoidal_embedding_1d(self.freq_dim, t).float()) 106 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) 107 | assert e.dtype == torch.float32 and e0.dtype == torch.float32 108 | 109 | # context 110 | context_lens = None 111 | context = self.text_embedding( 112 | torch.stack([ 113 | torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) 114 | for u in context 115 | ])) 116 | 117 | if clip_fea is not None: 118 | context_clip = self.img_emb(clip_fea) # bs x 257 x dim 119 | context = torch.concat([context_clip, context], dim=1) 120 | 121 | # arguments 122 | kwargs = dict( 123 | e=e0, 124 | seq_lens=seq_lens, 125 | grid_sizes=grid_sizes, 126 | freqs=self.freqs, 127 | context=context, 128 | context_lens=context_lens) 129 | 130 | # Context Parallel 131 | x = torch.chunk( 132 | x, get_sequence_parallel_world_size(), 133 | dim=1)[get_sequence_parallel_rank()] 134 | 135 | for block in self.blocks: 136 | x = block(x, **kwargs) 137 | 138 | # head 139 | x = self.head(x, e) 140 | 141 | # Context Parallel 142 | x = get_sp_group().all_gather(x, dim=1) 143 | 144 | # unpatchify 145 | x = self.unpatchify(x, grid_sizes) 146 | return [u.float() for u in x] 147 | 148 | 149 | def usp_attn_forward(self, 150 | x, 151 | seq_lens, 152 | grid_sizes, 
153 | freqs, 154 | dtype=torch.bfloat16): 155 | b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim 156 | half_dtypes = (torch.float16, torch.bfloat16) 157 | 158 | def half(x): 159 | return x if x.dtype in half_dtypes else x.to(dtype) 160 | 161 | # query, key, value function 162 | def qkv_fn(x): 163 | q = self.norm_q(self.q(x)).view(b, s, n, d) 164 | k = self.norm_k(self.k(x)).view(b, s, n, d) 165 | v = self.v(x).view(b, s, n, d) 166 | return q, k, v 167 | 168 | q, k, v = qkv_fn(x) 169 | q = rope_apply(q, grid_sizes, freqs) 170 | k = rope_apply(k, grid_sizes, freqs) 171 | 172 | # TODO: We should use unpaded q,k,v for attention. 173 | # k_lens = seq_lens // get_sequence_parallel_world_size() 174 | # if k_lens is not None: 175 | # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0) 176 | # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0) 177 | # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0) 178 | 179 | x = xFuserLongContextAttention()( 180 | None, 181 | query=half(q), 182 | key=half(k), 183 | value=half(v), 184 | window_size=self.window_size) 185 | 186 | # TODO: padding after attention. 187 | # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1) 188 | 189 | # output 190 | x = x.flatten(2) 191 | x = self.o(x) 192 | return x -------------------------------------------------------------------------------- /phantom_wan/image2video.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import gc 3 | import logging 4 | import math 5 | import os 6 | import random 7 | import sys 8 | import types 9 | from contextlib import contextmanager 10 | from functools import partial 11 | 12 | import numpy as np 13 | import torch 14 | import torch.cuda.amp as amp 15 | import torch.distributed as dist 16 | import torchvision.transforms.functional as TF 17 | from tqdm import tqdm 18 | 19 | from .distributed.fsdp import shard_model 20 | from .modules.clip import CLIPModel 21 | from .modules.model import WanModel 22 | from .modules.t5 import T5EncoderModel 23 | from .modules.vae import WanVAE 24 | from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, 25 | get_sampling_sigmas, retrieve_timesteps) 26 | from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler 27 | 28 | 29 | class WanI2V: 30 | 31 | def __init__( 32 | self, 33 | config, 34 | checkpoint_dir, 35 | device_id=0, 36 | rank=0, 37 | t5_fsdp=False, 38 | dit_fsdp=False, 39 | use_usp=False, 40 | t5_cpu=False, 41 | # init_on_cpu=True, 42 | ): 43 | r""" 44 | Initializes the image-to-video generation model components. 45 | 46 | Args: 47 | config (EasyDict): 48 | Object containing model parameters initialized from config.py 49 | checkpoint_dir (`str`): 50 | Path to directory containing model checkpoints 51 | device_id (`int`, *optional*, defaults to 0): 52 | Id of target GPU device 53 | rank (`int`, *optional*, defaults to 0): 54 | Process rank for distributed training 55 | t5_fsdp (`bool`, *optional*, defaults to False): 56 | Enable FSDP sharding for T5 model 57 | dit_fsdp (`bool`, *optional*, defaults to False): 58 | Enable FSDP sharding for DiT model 59 | use_usp (`bool`, *optional*, defaults to False): 60 | Enable distribution strategy of USP. 61 | t5_cpu (`bool`, *optional*, defaults to False): 62 | Whether to place T5 model on CPU. Only works without t5_fsdp. 63 | init_on_cpu (`bool`, *optional*, defaults to True): 64 | Enable initializing Transformer Model on CPU. 
Only works without FSDP or USP. 65 | """ 66 | self.device = torch.device(f"cuda:{device_id}") 67 | self.config = config 68 | self.rank = rank 69 | self.use_usp = use_usp 70 | self.t5_cpu = t5_cpu 71 | 72 | self.num_train_timesteps = config.num_train_timesteps 73 | self.param_dtype = config.param_dtype 74 | 75 | shard_fn = partial(shard_model, device_id=device_id) 76 | self.text_encoder = T5EncoderModel( 77 | text_len=config.text_len, 78 | dtype=config.t5_dtype, 79 | device=torch.device('cpu'), 80 | checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), 81 | tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), 82 | shard_fn=shard_fn if t5_fsdp else None, 83 | ) 84 | 85 | self.vae_stride = config.vae_stride 86 | self.patch_size = config.patch_size 87 | self.vae = WanVAE( 88 | vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), 89 | device=self.device) 90 | 91 | self.clip = CLIPModel( 92 | dtype=config.clip_dtype, 93 | device=self.device, 94 | checkpoint_path=os.path.join(checkpoint_dir, 95 | config.clip_checkpoint), 96 | tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer)) 97 | 98 | logging.info(f"Creating WanModel from {checkpoint_dir}") 99 | self.model = WanModel.from_pretrained(checkpoint_dir, ) 100 | self.model.eval().requires_grad_(False) 101 | 102 | # if t5_fsdp or dit_fsdp or use_usp: 103 | # init_on_cpu = False 104 | 105 | if use_usp: 106 | from xfuser.core.distributed import \ 107 | get_sequence_parallel_world_size 108 | 109 | from .distributed.xdit_context_parallel import (usp_attn_forward, 110 | usp_dit_forward) 111 | for block in self.model.blocks: 112 | block.self_attn.forward = types.MethodType( 113 | usp_attn_forward, block.self_attn) 114 | self.model.forward = types.MethodType(usp_dit_forward, self.model) 115 | self.sp_size = get_sequence_parallel_world_size() 116 | else: 117 | self.sp_size = 1 118 | 119 | if dist.is_initialized(): 120 | dist.barrier() 121 | if dit_fsdp: 122 | self.model = shard_fn(self.model) 123 | else: 124 | # if not init_on_cpu: 125 | self.model.to(self.device) 126 | 127 | self.sample_neg_prompt = config.sample_neg_prompt 128 | 129 | def generate(self, 130 | input_prompt, 131 | img, 132 | max_area=720 * 1280, 133 | frame_num=81, 134 | shift=5.0, 135 | sample_solver='unipc', 136 | sampling_steps=40, 137 | guide_scale=5.0, 138 | n_prompt="", 139 | seed=-1, 140 | offload_model=True): 141 | r""" 142 | Generates video frames from input image and text prompt using diffusion process. 143 | 144 | Args: 145 | input_prompt (`str`): 146 | Text prompt for content generation. 147 | img (PIL.Image.Image): 148 | Input image tensor. Shape: [3, H, W] 149 | max_area (`int`, *optional*, defaults to 720*1280): 150 | Maximum pixel area for latent space calculation. Controls video resolution scaling 151 | frame_num (`int`, *optional*, defaults to 81): 152 | How many frames to sample from a video. The number should be 4n+1 153 | shift (`float`, *optional*, defaults to 5.0): 154 | Noise schedule shift parameter. Affects temporal dynamics 155 | [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. 156 | sample_solver (`str`, *optional*, defaults to 'unipc'): 157 | Solver used to sample the video. 158 | sampling_steps (`int`, *optional*, defaults to 40): 159 | Number of diffusion sampling steps. Higher values improve quality but slow generation 160 | guide_scale (`float`, *optional*, defaults 5.0): 161 | Classifier-free guidance scale. Controls prompt adherence vs. 
creativity 162 | n_prompt (`str`, *optional*, defaults to ""): 163 | Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` 164 | seed (`int`, *optional*, defaults to -1): 165 | Random seed for noise generation. If -1, use random seed 166 | offload_model (`bool`, *optional*, defaults to True): 167 | If True, offloads models to CPU during generation to save VRAM 168 | 169 | Returns: 170 | torch.Tensor: 171 | Generated video frames tensor. Dimensions: (C, N H, W) where: 172 | - C: Color channels (3 for RGB) 173 | - N: Number of frames (81) 174 | - H: Frame height (from max_area) 175 | - W: Frame width from max_area) 176 | """ 177 | img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) 178 | 179 | F = frame_num 180 | h, w = img.shape[1:] 181 | aspect_ratio = h / w 182 | lat_h = round( 183 | np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // 184 | self.patch_size[1] * self.patch_size[1]) 185 | lat_w = round( 186 | np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // 187 | self.patch_size[2] * self.patch_size[2]) 188 | h = lat_h * self.vae_stride[1] 189 | w = lat_w * self.vae_stride[2] 190 | 191 | max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // ( 192 | self.patch_size[1] * self.patch_size[2]) 193 | max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size 194 | 195 | seed = seed if seed >= 0 else random.randint(0, sys.maxsize) 196 | seed_g = torch.Generator(device=self.device) 197 | seed_g.manual_seed(seed) 198 | noise = torch.randn( 199 | 16, 200 | 21, 201 | lat_h, 202 | lat_w, 203 | dtype=torch.float32, 204 | generator=seed_g, 205 | device=self.device) 206 | 207 | msk = torch.ones(1, 81, lat_h, lat_w, device=self.device) 208 | msk[:, 1:] = 0 209 | msk = torch.concat([ 210 | torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] 211 | ],dim=1) 212 | msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) 213 | msk = msk.transpose(1, 2)[0] 214 | 215 | if n_prompt == "": 216 | n_prompt = self.sample_neg_prompt 217 | 218 | # preprocess 219 | if not self.t5_cpu: 220 | self.text_encoder.model.to(self.device) 221 | context = self.text_encoder([input_prompt], self.device) 222 | context_null = self.text_encoder([n_prompt], self.device) 223 | if offload_model: 224 | self.text_encoder.model.cpu() 225 | else: 226 | context = self.text_encoder([input_prompt], torch.device('cpu')) 227 | context_null = self.text_encoder([n_prompt], torch.device('cpu')) 228 | context = [t.to(self.device) for t in context] 229 | context_null = [t.to(self.device) for t in context_null] 230 | 231 | self.clip.model.to(self.device) 232 | clip_context = self.clip.visual([img[:, None, :, :]]) 233 | if offload_model: 234 | self.clip.model.cpu() 235 | 236 | y = self.vae.encode([ 237 | torch.concat([ 238 | torch.nn.functional.interpolate( 239 | img[None].cpu(), size=(h, w), mode='bicubic').transpose( 240 | 0, 1), 241 | torch.zeros(3, 80, h, w) 242 | ], 243 | dim=1).to(self.device) 244 | ])[0] 245 | y = torch.concat([msk, y]) 246 | 247 | @contextmanager 248 | def noop_no_sync(): 249 | yield 250 | 251 | no_sync = getattr(self.model, 'no_sync', noop_no_sync) 252 | 253 | # evaluation mode 254 | with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): 255 | 256 | if sample_solver == 'unipc': 257 | sample_scheduler = FlowUniPCMultistepScheduler( 258 | num_train_timesteps=self.num_train_timesteps, 259 | shift=1, 260 | use_dynamic_shifting=False) 261 | sample_scheduler.set_timesteps( 262 | sampling_steps, device=self.device, shift=shift) 
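# Recap of the latent bookkeeping above, with illustrative 480x832 / 81-frame numbers:
#   lat_h = 480 // 8 = 60 and lat_w = 832 // 8 = 104 (spatial vae_stride of 8),
#   latent frames = (81 - 1) // 4 + 1 = 21 (hence the hard-coded 21 in the noise tensor),
#   max_seq_len = 21 * 60 * 104 // (2 * 2) = 32760 tokens with patch_size = (1, 2, 2).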
263 | timesteps = sample_scheduler.timesteps 264 | elif sample_solver == 'dpm++': 265 | sample_scheduler = FlowDPMSolverMultistepScheduler( 266 | num_train_timesteps=self.num_train_timesteps, 267 | shift=1, 268 | use_dynamic_shifting=False) 269 | sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) 270 | timesteps, _ = retrieve_timesteps( 271 | sample_scheduler, 272 | device=self.device, 273 | sigmas=sampling_sigmas) 274 | else: 275 | raise NotImplementedError("Unsupported solver.") 276 | 277 | # sample videos 278 | latent = noise 279 | 280 | arg_c = { 281 | 'context': [context[0]], 282 | 'clip_fea': clip_context, 283 | 'seq_len': max_seq_len, 284 | 'y': [y], 285 | } 286 | 287 | arg_null = { 288 | 'context': context_null, 289 | 'clip_fea': clip_context, 290 | 'seq_len': max_seq_len, 291 | 'y': [y], 292 | } 293 | 294 | if offload_model: 295 | torch.cuda.empty_cache() 296 | 297 | self.model.to(self.device) 298 | for _, t in enumerate(tqdm(timesteps)): 299 | latent_model_input = [latent.to(self.device)] 300 | timestep = [t] 301 | 302 | timestep = torch.stack(timestep).to(self.device) 303 | 304 | noise_pred_cond = self.model( 305 | latent_model_input, t=timestep, **arg_c)[0].to( 306 | torch.device('cpu') if offload_model else self.device) 307 | if offload_model: 308 | torch.cuda.empty_cache() 309 | noise_pred_uncond = self.model( 310 | latent_model_input, t=timestep, **arg_null)[0].to( 311 | torch.device('cpu') if offload_model else self.device) 312 | if offload_model: 313 | torch.cuda.empty_cache() 314 | noise_pred = noise_pred_uncond + guide_scale * ( 315 | noise_pred_cond - noise_pred_uncond) 316 | 317 | latent = latent.to( 318 | torch.device('cpu') if offload_model else self.device) 319 | 320 | temp_x0 = sample_scheduler.step( 321 | noise_pred.unsqueeze(0), 322 | t, 323 | latent.unsqueeze(0), 324 | return_dict=False, 325 | generator=seed_g)[0] 326 | latent = temp_x0.squeeze(0) 327 | 328 | x0 = [latent.to(self.device)] 329 | del latent_model_input, timestep 330 | 331 | if offload_model: 332 | self.model.cpu() 333 | torch.cuda.empty_cache() 334 | 335 | if self.rank == 0: 336 | videos = self.vae.decode(x0) 337 | 338 | del noise, latent 339 | del sample_scheduler 340 | if offload_model: 341 | gc.collect() 342 | torch.cuda.synchronize() 343 | if dist.is_initialized(): 344 | dist.barrier() 345 | 346 | return videos[0] if self.rank == 0 else None 347 | -------------------------------------------------------------------------------- /phantom_wan/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import flash_attention 2 | from .model import WanModel 3 | from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model 4 | from .tokenizers import HuggingfaceTokenizer 5 | from .vae import WanVAE 6 | 7 | __all__ = [ 8 | 'WanVAE', 9 | 'WanModel', 10 | 'T5Model', 11 | 'T5Encoder', 12 | 'T5Decoder', 13 | 'T5EncoderModel', 14 | 'HuggingfaceTokenizer', 15 | 'flash_attention', 16 | ] 17 | -------------------------------------------------------------------------------- /phantom_wan/modules/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
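A minimal smoke test for the attention helpers defined below in this module (the shapes are illustrative; 12 heads x 128 head_dim matches the 1.3B config, and the wrapper falls back to PyTorch's scaled_dot_product_attention when no flash-attn build is installed):

import torch
from phantom_wan.modules.attention import attention

# Query/key/value in the [batch, seq_len, num_heads, head_dim] layout used throughout the model.
q = torch.randn(1, 16, 12, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1, 16, 12, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1, 16, 12, 128, dtype=torch.bfloat16, device="cuda")

out = attention(q, k, v)  # dispatches to FlashAttention 3/2 when available
print(out.shape)          # torch.Size([1, 16, 12, 128])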
2 | import torch 3 | 4 | try: 5 | import flash_attn_interface 6 | FLASH_ATTN_3_AVAILABLE = True 7 | except ModuleNotFoundError: 8 | FLASH_ATTN_3_AVAILABLE = False 9 | 10 | try: 11 | import flash_attn 12 | FLASH_ATTN_2_AVAILABLE = True 13 | except ModuleNotFoundError: 14 | FLASH_ATTN_2_AVAILABLE = False 15 | 16 | import warnings 17 | 18 | __all__ = [ 19 | 'flash_attention', 20 | 'attention', 21 | ] 22 | 23 | 24 | def flash_attention( 25 | q, 26 | k, 27 | v, 28 | q_lens=None, 29 | k_lens=None, 30 | dropout_p=0., 31 | softmax_scale=None, 32 | q_scale=None, 33 | causal=False, 34 | window_size=(-1, -1), 35 | deterministic=False, 36 | dtype=torch.bfloat16, 37 | version=None, 38 | ): 39 | """ 40 | q: [B, Lq, Nq, C1]. 41 | k: [B, Lk, Nk, C1]. 42 | v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. 43 | q_lens: [B]. 44 | k_lens: [B]. 45 | dropout_p: float. Dropout probability. 46 | softmax_scale: float. The scaling of QK^T before applying softmax. 47 | causal: bool. Whether to apply causal attention mask. 48 | window_size: (left right). If not (-1, -1), apply sliding window local attention. 49 | deterministic: bool. If True, slightly slower and uses more memory. 50 | dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16. 51 | """ 52 | half_dtypes = (torch.float16, torch.bfloat16) 53 | assert dtype in half_dtypes 54 | assert q.device.type == 'cuda' and q.size(-1) <= 256 55 | 56 | # params 57 | b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype 58 | 59 | def half(x): 60 | return x if x.dtype in half_dtypes else x.to(dtype) 61 | 62 | # preprocess query 63 | if q_lens is None: 64 | q = half(q.flatten(0, 1)) 65 | q_lens = torch.tensor( 66 | [lq] * b, dtype=torch.int32).to( 67 | device=q.device, non_blocking=True) 68 | else: 69 | q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) 70 | 71 | # preprocess key, value 72 | if k_lens is None: 73 | k = half(k.flatten(0, 1)) 74 | v = half(v.flatten(0, 1)) 75 | k_lens = torch.tensor( 76 | [lk] * b, dtype=torch.int32).to( 77 | device=k.device, non_blocking=True) 78 | else: 79 | k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) 80 | v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) 81 | 82 | q = q.to(v.dtype) 83 | k = k.to(v.dtype) 84 | 85 | if q_scale is not None: 86 | q = q * q_scale 87 | 88 | if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: 89 | warnings.warn( 90 | 'Flash attention 3 is not available, use flash attention 2 instead.' 91 | ) 92 | 93 | # apply attention 94 | if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE: 95 | # Note: dropout_p, window_size are not supported in FA3 now. 
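# The cu_seqlens_* arguments below are the cumulative ("prefix-sum") sequence lengths that the
# varlen kernels expect: e.g. q_lens = [16, 16] becomes cu_seqlens_q = [0, 16, 32]. The batch has
# already been flattened to [sum(lens), num_heads, head_dim], and these offsets mark each sample's slice.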
96 | x = flash_attn_interface.flash_attn_varlen_func( 97 | q=q, 98 | k=k, 99 | v=v, 100 | cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( 101 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 102 | cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( 103 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 104 | seqused_q=None, 105 | seqused_k=None, 106 | max_seqlen_q=lq, 107 | max_seqlen_k=lk, 108 | softmax_scale=softmax_scale, 109 | causal=causal, 110 | deterministic=deterministic)[0].unflatten(0, (b, lq)) 111 | else: 112 | assert FLASH_ATTN_2_AVAILABLE 113 | x = flash_attn.flash_attn_varlen_func( 114 | q=q, 115 | k=k, 116 | v=v, 117 | cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( 118 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 119 | cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( 120 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 121 | max_seqlen_q=lq, 122 | max_seqlen_k=lk, 123 | dropout_p=dropout_p, 124 | softmax_scale=softmax_scale, 125 | causal=causal, 126 | window_size=window_size, 127 | deterministic=deterministic).unflatten(0, (b, lq)) 128 | 129 | # output 130 | return x.type(out_dtype) 131 | 132 | 133 | def attention( 134 | q, 135 | k, 136 | v, 137 | q_lens=None, 138 | k_lens=None, 139 | dropout_p=0., 140 | softmax_scale=None, 141 | q_scale=None, 142 | causal=False, 143 | window_size=(-1, -1), 144 | deterministic=False, 145 | dtype=torch.bfloat16, 146 | fa_version=None, 147 | ): 148 | if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE: 149 | return flash_attention( 150 | q=q, 151 | k=k, 152 | v=v, 153 | q_lens=q_lens, 154 | k_lens=k_lens, 155 | dropout_p=dropout_p, 156 | softmax_scale=softmax_scale, 157 | q_scale=q_scale, 158 | causal=causal, 159 | window_size=window_size, 160 | deterministic=deterministic, 161 | dtype=dtype, 162 | version=fa_version, 163 | ) 164 | else: 165 | if q_lens is not None or k_lens is not None: 166 | warnings.warn( 167 | 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.' 168 | ) 169 | attn_mask = None 170 | 171 | q = q.transpose(1, 2).to(dtype) 172 | k = k.transpose(1, 2).to(dtype) 173 | v = v.transpose(1, 2).to(dtype) 174 | 175 | out = torch.nn.functional.scaled_dot_product_attention( 176 | q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) 177 | 178 | out = out.transpose(1, 2).contiguous() 179 | return out 180 | -------------------------------------------------------------------------------- /phantom_wan/modules/clip.py: -------------------------------------------------------------------------------- 1 | # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip'' 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
3 | import logging 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision.transforms as T 10 | 11 | from .attention import flash_attention 12 | from .tokenizers import HuggingfaceTokenizer 13 | from .xlm_roberta import XLMRoberta 14 | 15 | __all__ = [ 16 | 'XLMRobertaCLIP', 17 | 'clip_xlm_roberta_vit_h_14', 18 | 'CLIPModel', 19 | ] 20 | 21 | 22 | def pos_interpolate(pos, seq_len): 23 | if pos.size(1) == seq_len: 24 | return pos 25 | else: 26 | src_grid = int(math.sqrt(pos.size(1))) 27 | tar_grid = int(math.sqrt(seq_len)) 28 | n = pos.size(1) - src_grid * src_grid 29 | return torch.cat([ 30 | pos[:, :n], 31 | F.interpolate( 32 | pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute( 33 | 0, 3, 1, 2), 34 | size=(tar_grid, tar_grid), 35 | mode='bicubic', 36 | align_corners=False).flatten(2).transpose(1, 2) 37 | ], 38 | dim=1) 39 | 40 | 41 | class QuickGELU(nn.Module): 42 | 43 | def forward(self, x): 44 | return x * torch.sigmoid(1.702 * x) 45 | 46 | 47 | class LayerNorm(nn.LayerNorm): 48 | 49 | def forward(self, x): 50 | return super().forward(x.float()).type_as(x) 51 | 52 | 53 | class SelfAttention(nn.Module): 54 | 55 | def __init__(self, 56 | dim, 57 | num_heads, 58 | causal=False, 59 | attn_dropout=0.0, 60 | proj_dropout=0.0): 61 | assert dim % num_heads == 0 62 | super().__init__() 63 | self.dim = dim 64 | self.num_heads = num_heads 65 | self.head_dim = dim // num_heads 66 | self.causal = causal 67 | self.attn_dropout = attn_dropout 68 | self.proj_dropout = proj_dropout 69 | 70 | # layers 71 | self.to_qkv = nn.Linear(dim, dim * 3) 72 | self.proj = nn.Linear(dim, dim) 73 | 74 | def forward(self, x): 75 | """ 76 | x: [B, L, C]. 77 | """ 78 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim 79 | 80 | # compute query, key, value 81 | q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2) 82 | 83 | # compute attention 84 | p = self.attn_dropout if self.training else 0.0 85 | x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2) 86 | x = x.reshape(b, s, c) 87 | 88 | # output 89 | x = self.proj(x) 90 | x = F.dropout(x, self.proj_dropout, self.training) 91 | return x 92 | 93 | 94 | class SwiGLU(nn.Module): 95 | 96 | def __init__(self, dim, mid_dim): 97 | super().__init__() 98 | self.dim = dim 99 | self.mid_dim = mid_dim 100 | 101 | # layers 102 | self.fc1 = nn.Linear(dim, mid_dim) 103 | self.fc2 = nn.Linear(dim, mid_dim) 104 | self.fc3 = nn.Linear(mid_dim, dim) 105 | 106 | def forward(self, x): 107 | x = F.silu(self.fc1(x)) * self.fc2(x) 108 | x = self.fc3(x) 109 | return x 110 | 111 | 112 | class AttentionBlock(nn.Module): 113 | 114 | def __init__(self, 115 | dim, 116 | mlp_ratio, 117 | num_heads, 118 | post_norm=False, 119 | causal=False, 120 | activation='quick_gelu', 121 | attn_dropout=0.0, 122 | proj_dropout=0.0, 123 | norm_eps=1e-5): 124 | assert activation in ['quick_gelu', 'gelu', 'swi_glu'] 125 | super().__init__() 126 | self.dim = dim 127 | self.mlp_ratio = mlp_ratio 128 | self.num_heads = num_heads 129 | self.post_norm = post_norm 130 | self.causal = causal 131 | self.norm_eps = norm_eps 132 | 133 | # layers 134 | self.norm1 = LayerNorm(dim, eps=norm_eps) 135 | self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, 136 | proj_dropout) 137 | self.norm2 = LayerNorm(dim, eps=norm_eps) 138 | if activation == 'swi_glu': 139 | self.mlp = SwiGLU(dim, int(dim * mlp_ratio)) 140 | else: 141 | self.mlp = nn.Sequential( 142 | nn.Linear(dim, int(dim * mlp_ratio)), 143 | QuickGELU() if 
activation == 'quick_gelu' else nn.GELU(), 144 | nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout)) 145 | 146 | def forward(self, x): 147 | if self.post_norm: 148 | x = x + self.norm1(self.attn(x)) 149 | x = x + self.norm2(self.mlp(x)) 150 | else: 151 | x = x + self.attn(self.norm1(x)) 152 | x = x + self.mlp(self.norm2(x)) 153 | return x 154 | 155 | 156 | class AttentionPool(nn.Module): 157 | 158 | def __init__(self, 159 | dim, 160 | mlp_ratio, 161 | num_heads, 162 | activation='gelu', 163 | proj_dropout=0.0, 164 | norm_eps=1e-5): 165 | assert dim % num_heads == 0 166 | super().__init__() 167 | self.dim = dim 168 | self.mlp_ratio = mlp_ratio 169 | self.num_heads = num_heads 170 | self.head_dim = dim // num_heads 171 | self.proj_dropout = proj_dropout 172 | self.norm_eps = norm_eps 173 | 174 | # layers 175 | gain = 1.0 / math.sqrt(dim) 176 | self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) 177 | self.to_q = nn.Linear(dim, dim) 178 | self.to_kv = nn.Linear(dim, dim * 2) 179 | self.proj = nn.Linear(dim, dim) 180 | self.norm = LayerNorm(dim, eps=norm_eps) 181 | self.mlp = nn.Sequential( 182 | nn.Linear(dim, int(dim * mlp_ratio)), 183 | QuickGELU() if activation == 'quick_gelu' else nn.GELU(), 184 | nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout)) 185 | 186 | def forward(self, x): 187 | """ 188 | x: [B, L, C]. 189 | """ 190 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim 191 | 192 | # compute query, key, value 193 | q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1) 194 | k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2) 195 | 196 | # compute attention 197 | x = flash_attention(q, k, v, version=2) 198 | x = x.reshape(b, 1, c) 199 | 200 | # output 201 | x = self.proj(x) 202 | x = F.dropout(x, self.proj_dropout, self.training) 203 | 204 | # mlp 205 | x = x + self.mlp(self.norm(x)) 206 | return x[:, 0] 207 | 208 | 209 | class VisionTransformer(nn.Module): 210 | 211 | def __init__(self, 212 | image_size=224, 213 | patch_size=16, 214 | dim=768, 215 | mlp_ratio=4, 216 | out_dim=512, 217 | num_heads=12, 218 | num_layers=12, 219 | pool_type='token', 220 | pre_norm=True, 221 | post_norm=False, 222 | activation='quick_gelu', 223 | attn_dropout=0.0, 224 | proj_dropout=0.0, 225 | embedding_dropout=0.0, 226 | norm_eps=1e-5): 227 | if image_size % patch_size != 0: 228 | print( 229 | '[WARNING] image_size is not divisible by patch_size', 230 | flush=True) 231 | assert pool_type in ('token', 'token_fc', 'attn_pool') 232 | out_dim = out_dim or dim 233 | super().__init__() 234 | self.image_size = image_size 235 | self.patch_size = patch_size 236 | self.num_patches = (image_size // patch_size)**2 237 | self.dim = dim 238 | self.mlp_ratio = mlp_ratio 239 | self.out_dim = out_dim 240 | self.num_heads = num_heads 241 | self.num_layers = num_layers 242 | self.pool_type = pool_type 243 | self.post_norm = post_norm 244 | self.norm_eps = norm_eps 245 | 246 | # embeddings 247 | gain = 1.0 / math.sqrt(dim) 248 | self.patch_embedding = nn.Conv2d( 249 | 3, 250 | dim, 251 | kernel_size=patch_size, 252 | stride=patch_size, 253 | bias=not pre_norm) 254 | if pool_type in ('token', 'token_fc'): 255 | self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) 256 | self.pos_embedding = nn.Parameter(gain * torch.randn( 257 | 1, self.num_patches + 258 | (1 if pool_type in ('token', 'token_fc') else 0), dim)) 259 | self.dropout = nn.Dropout(embedding_dropout) 260 | 261 | # transformer 262 | self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm 
else None 263 | self.transformer = nn.Sequential(*[ 264 | AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, 265 | activation, attn_dropout, proj_dropout, norm_eps) 266 | for _ in range(num_layers) 267 | ]) 268 | self.post_norm = LayerNorm(dim, eps=norm_eps) 269 | 270 | # head 271 | if pool_type == 'token': 272 | self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) 273 | elif pool_type == 'token_fc': 274 | self.head = nn.Linear(dim, out_dim) 275 | elif pool_type == 'attn_pool': 276 | self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, 277 | proj_dropout, norm_eps) 278 | 279 | def forward(self, x, interpolation=False, use_31_block=False): 280 | b = x.size(0) 281 | 282 | # embeddings 283 | x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) 284 | if self.pool_type in ('token', 'token_fc'): 285 | x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1) 286 | if interpolation: 287 | e = pos_interpolate(self.pos_embedding, x.size(1)) 288 | else: 289 | e = self.pos_embedding 290 | x = self.dropout(x + e) 291 | if self.pre_norm is not None: 292 | x = self.pre_norm(x) 293 | 294 | # transformer 295 | if use_31_block: 296 | x = self.transformer[:-1](x) 297 | return x 298 | else: 299 | x = self.transformer(x) 300 | return x 301 | 302 | 303 | class XLMRobertaWithHead(XLMRoberta): 304 | 305 | def __init__(self, **kwargs): 306 | self.out_dim = kwargs.pop('out_dim') 307 | super().__init__(**kwargs) 308 | 309 | # head 310 | mid_dim = (self.dim + self.out_dim) // 2 311 | self.head = nn.Sequential( 312 | nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), 313 | nn.Linear(mid_dim, self.out_dim, bias=False)) 314 | 315 | def forward(self, ids): 316 | # xlm-roberta 317 | x = super().forward(ids) 318 | 319 | # average pooling 320 | mask = ids.ne(self.pad_id).unsqueeze(-1).to(x) 321 | x = (x * mask).sum(dim=1) / mask.sum(dim=1) 322 | 323 | # head 324 | x = self.head(x) 325 | return x 326 | 327 | 328 | class XLMRobertaCLIP(nn.Module): 329 | 330 | def __init__(self, 331 | embed_dim=1024, 332 | image_size=224, 333 | patch_size=14, 334 | vision_dim=1280, 335 | vision_mlp_ratio=4, 336 | vision_heads=16, 337 | vision_layers=32, 338 | vision_pool='token', 339 | vision_pre_norm=True, 340 | vision_post_norm=False, 341 | activation='gelu', 342 | vocab_size=250002, 343 | max_text_len=514, 344 | type_size=1, 345 | pad_id=1, 346 | text_dim=1024, 347 | text_heads=16, 348 | text_layers=24, 349 | text_post_norm=True, 350 | text_dropout=0.1, 351 | attn_dropout=0.0, 352 | proj_dropout=0.0, 353 | embedding_dropout=0.0, 354 | norm_eps=1e-5): 355 | super().__init__() 356 | self.embed_dim = embed_dim 357 | self.image_size = image_size 358 | self.patch_size = patch_size 359 | self.vision_dim = vision_dim 360 | self.vision_mlp_ratio = vision_mlp_ratio 361 | self.vision_heads = vision_heads 362 | self.vision_layers = vision_layers 363 | self.vision_pre_norm = vision_pre_norm 364 | self.vision_post_norm = vision_post_norm 365 | self.activation = activation 366 | self.vocab_size = vocab_size 367 | self.max_text_len = max_text_len 368 | self.type_size = type_size 369 | self.pad_id = pad_id 370 | self.text_dim = text_dim 371 | self.text_heads = text_heads 372 | self.text_layers = text_layers 373 | self.text_post_norm = text_post_norm 374 | self.norm_eps = norm_eps 375 | 376 | # models 377 | self.visual = VisionTransformer( 378 | image_size=image_size, 379 | patch_size=patch_size, 380 | dim=vision_dim, 381 | mlp_ratio=vision_mlp_ratio, 382 | out_dim=embed_dim, 383 | num_heads=vision_heads, 384 | 
num_layers=vision_layers, 385 | pool_type=vision_pool, 386 | pre_norm=vision_pre_norm, 387 | post_norm=vision_post_norm, 388 | activation=activation, 389 | attn_dropout=attn_dropout, 390 | proj_dropout=proj_dropout, 391 | embedding_dropout=embedding_dropout, 392 | norm_eps=norm_eps) 393 | self.textual = XLMRobertaWithHead( 394 | vocab_size=vocab_size, 395 | max_seq_len=max_text_len, 396 | type_size=type_size, 397 | pad_id=pad_id, 398 | dim=text_dim, 399 | out_dim=embed_dim, 400 | num_heads=text_heads, 401 | num_layers=text_layers, 402 | post_norm=text_post_norm, 403 | dropout=text_dropout) 404 | self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) 405 | 406 | def forward(self, imgs, txt_ids): 407 | """ 408 | imgs: [B, 3, H, W] of torch.float32. 409 | - mean: [0.48145466, 0.4578275, 0.40821073] 410 | - std: [0.26862954, 0.26130258, 0.27577711] 411 | txt_ids: [B, L] of torch.long. 412 | Encoded by data.CLIPTokenizer. 413 | """ 414 | xi = self.visual(imgs) 415 | xt = self.textual(txt_ids) 416 | return xi, xt 417 | 418 | def param_groups(self): 419 | groups = [{ 420 | 'params': [ 421 | p for n, p in self.named_parameters() 422 | if 'norm' in n or n.endswith('bias') 423 | ], 424 | 'weight_decay': 0.0 425 | }, { 426 | 'params': [ 427 | p for n, p in self.named_parameters() 428 | if not ('norm' in n or n.endswith('bias')) 429 | ] 430 | }] 431 | return groups 432 | 433 | 434 | def _clip(pretrained=False, 435 | pretrained_name=None, 436 | model_cls=XLMRobertaCLIP, 437 | return_transforms=False, 438 | return_tokenizer=False, 439 | tokenizer_padding='eos', 440 | dtype=torch.float32, 441 | device='cpu', 442 | **kwargs): 443 | # init a model on device 444 | with torch.device(device): 445 | model = model_cls(**kwargs) 446 | 447 | # set device 448 | model = model.to(dtype=dtype, device=device) 449 | output = (model,) 450 | 451 | # init transforms 452 | if return_transforms: 453 | # mean and std 454 | if 'siglip' in pretrained_name.lower(): 455 | mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] 456 | else: 457 | mean = [0.48145466, 0.4578275, 0.40821073] 458 | std = [0.26862954, 0.26130258, 0.27577711] 459 | 460 | # transforms 461 | transforms = T.Compose([ 462 | T.Resize((model.image_size, model.image_size), 463 | interpolation=T.InterpolationMode.BICUBIC), 464 | T.ToTensor(), 465 | T.Normalize(mean=mean, std=std) 466 | ]) 467 | output += (transforms,) 468 | return output[0] if len(output) == 1 else output 469 | 470 | 471 | def clip_xlm_roberta_vit_h_14( 472 | pretrained=False, 473 | pretrained_name='open-clip-xlm-roberta-large-vit-huge-14', 474 | **kwargs): 475 | cfg = dict( 476 | embed_dim=1024, 477 | image_size=224, 478 | patch_size=14, 479 | vision_dim=1280, 480 | vision_mlp_ratio=4, 481 | vision_heads=16, 482 | vision_layers=32, 483 | vision_pool='token', 484 | activation='gelu', 485 | vocab_size=250002, 486 | max_text_len=514, 487 | type_size=1, 488 | pad_id=1, 489 | text_dim=1024, 490 | text_heads=16, 491 | text_layers=24, 492 | text_post_norm=True, 493 | text_dropout=0.1, 494 | attn_dropout=0.0, 495 | proj_dropout=0.0, 496 | embedding_dropout=0.0) 497 | cfg.update(**kwargs) 498 | return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg) 499 | 500 | 501 | class CLIPModel: 502 | 503 | def __init__(self, dtype, device, checkpoint_path, tokenizer_path): 504 | self.dtype = dtype 505 | self.device = device 506 | self.checkpoint_path = checkpoint_path 507 | self.tokenizer_path = tokenizer_path 508 | 509 | # init model 510 | self.model, self.transforms = clip_xlm_roberta_vit_h_14( 511 | 
pretrained=False, 512 | return_transforms=True, 513 | return_tokenizer=False, 514 | dtype=dtype, 515 | device=device) 516 | self.model = self.model.eval().requires_grad_(False) 517 | logging.info(f'loading {checkpoint_path}') 518 | self.model.load_state_dict( 519 | torch.load(checkpoint_path, map_location='cpu')) 520 | 521 | # init tokenizer 522 | self.tokenizer = HuggingfaceTokenizer( 523 | name=tokenizer_path, 524 | seq_len=self.model.max_text_len - 2, 525 | clean='whitespace') 526 | 527 | def visual(self, videos): 528 | # preprocess 529 | size = (self.model.image_size,) * 2 530 | videos = torch.cat([ 531 | F.interpolate( 532 | u.transpose(0, 1), 533 | size=size, 534 | mode='bicubic', 535 | align_corners=False) for u in videos 536 | ]) 537 | videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) 538 | 539 | # forward 540 | with torch.cuda.amp.autocast(dtype=self.dtype): 541 | out = self.model.visual(videos, use_31_block=True) 542 | return out 543 | -------------------------------------------------------------------------------- /phantom_wan/modules/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import math 3 | 4 | import torch 5 | import torch.cuda.amp as amp 6 | import torch.nn as nn 7 | from diffusers.configuration_utils import ConfigMixin, register_to_config 8 | from diffusers.models.modeling_utils import ModelMixin 9 | 10 | from .attention import flash_attention 11 | 12 | __all__ = ['WanModel'] 13 | 14 | 15 | def sinusoidal_embedding_1d(dim, position): 16 | # preprocess 17 | assert dim % 2 == 0 18 | half = dim // 2 19 | position = position.type(torch.float64) 20 | 21 | # calculation 22 | sinusoid = torch.outer( 23 | position, torch.pow(10000, -torch.arange(half).to(position).div(half))) 24 | x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) 25 | return x 26 | 27 | 28 | @amp.autocast(enabled=False) 29 | def rope_params(max_seq_len, dim, theta=10000): 30 | assert dim % 2 == 0 31 | freqs = torch.outer( 32 | torch.arange(max_seq_len), 33 | 1.0 / torch.pow(theta, 34 | torch.arange(0, dim, 2).div(dim))) 35 | freqs = torch.polar(torch.ones_like(freqs), freqs) 36 | return freqs 37 | 38 | 39 | @amp.autocast(enabled=False) 40 | def rope_apply(x, grid_sizes, freqs): 41 | n, c = x.size(2), x.size(3) // 2 42 | 43 | # split freqs 44 | freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) 45 | 46 | # loop over samples 47 | output = [] 48 | for i, (f, h, w) in enumerate(grid_sizes.tolist()): 49 | seq_len = f * h * w 50 | 51 | # precompute multipliers 52 | x_i = torch.view_as_complex(x[i, :seq_len].reshape( 53 | seq_len, n, -1, 2)) 54 | freqs_i = torch.cat([ 55 | freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), 56 | freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), 57 | freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) 58 | ], 59 | dim=-1).reshape(seq_len, 1, -1) 60 | 61 | # apply rotary embedding 62 | x_i = torch.view_as_real(x_i * freqs_i).flatten(2) 63 | x_i = torch.cat([x_i, x[i, seq_len:]]) 64 | 65 | # append to collection 66 | output.append(x_i) 67 | return torch.stack(output).float() 68 | 69 | 70 | class WanRMSNorm(nn.Module): 71 | 72 | def __init__(self, dim, eps=1e-5): 73 | super().__init__() 74 | self.dim = dim 75 | self.eps = eps 76 | self.weight = nn.Parameter(torch.ones(dim)) 77 | 78 | def forward(self, x): 79 | r""" 80 | Args: 81 | x(Tensor): Shape [B, L, C] 82 | """ 83 | return self._norm(x.float()).type_as(x) * 
self.weight 84 | 85 | def _norm(self, x): 86 | return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) 87 | 88 | 89 | class WanLayerNorm(nn.LayerNorm): 90 | 91 | def __init__(self, dim, eps=1e-6, elementwise_affine=False): 92 | super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) 93 | 94 | def forward(self, x): 95 | r""" 96 | Args: 97 | x(Tensor): Shape [B, L, C] 98 | """ 99 | return super().forward(x.float()).type_as(x) 100 | 101 | 102 | class WanSelfAttention(nn.Module): 103 | 104 | def __init__(self, 105 | dim, 106 | num_heads, 107 | window_size=(-1, -1), 108 | qk_norm=True, 109 | eps=1e-6): 110 | assert dim % num_heads == 0 111 | super().__init__() 112 | self.dim = dim 113 | self.num_heads = num_heads 114 | self.head_dim = dim // num_heads 115 | self.window_size = window_size 116 | self.qk_norm = qk_norm 117 | self.eps = eps 118 | 119 | # layers 120 | self.q = nn.Linear(dim, dim) 121 | self.k = nn.Linear(dim, dim) 122 | self.v = nn.Linear(dim, dim) 123 | self.o = nn.Linear(dim, dim) 124 | self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() 125 | self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() 126 | 127 | def forward(self, x, seq_lens, grid_sizes, freqs): 128 | r""" 129 | Args: 130 | x(Tensor): Shape [B, L, num_heads, C / num_heads] 131 | seq_lens(Tensor): Shape [B] 132 | grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) 133 | freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] 134 | """ 135 | b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim 136 | 137 | # query, key, value function 138 | def qkv_fn(x): 139 | q = self.norm_q(self.q(x)).view(b, s, n, d) 140 | k = self.norm_k(self.k(x)).view(b, s, n, d) 141 | v = self.v(x).view(b, s, n, d) 142 | return q, k, v 143 | 144 | q, k, v = qkv_fn(x) 145 | 146 | x = flash_attention( 147 | q=rope_apply(q, grid_sizes, freqs), 148 | k=rope_apply(k, grid_sizes, freqs), 149 | v=v, 150 | k_lens=seq_lens, 151 | window_size=self.window_size) 152 | 153 | # output 154 | x = x.flatten(2) 155 | x = self.o(x) 156 | return x 157 | 158 | 159 | class WanT2VCrossAttention(WanSelfAttention): 160 | 161 | def forward(self, x, context, context_lens): 162 | r""" 163 | Args: 164 | x(Tensor): Shape [B, L1, C] 165 | context(Tensor): Shape [B, L2, C] 166 | context_lens(Tensor): Shape [B] 167 | """ 168 | b, n, d = x.size(0), self.num_heads, self.head_dim 169 | 170 | # compute query, key, value 171 | q = self.norm_q(self.q(x)).view(b, -1, n, d) 172 | k = self.norm_k(self.k(context)).view(b, -1, n, d) 173 | v = self.v(context).view(b, -1, n, d) 174 | 175 | # compute attention 176 | x = flash_attention(q, k, v, k_lens=context_lens) 177 | 178 | # output 179 | x = x.flatten(2) 180 | x = self.o(x) 181 | return x 182 | 183 | 184 | class WanI2VCrossAttention(WanSelfAttention): 185 | 186 | def __init__(self, 187 | dim, 188 | num_heads, 189 | window_size=(-1, -1), 190 | qk_norm=True, 191 | eps=1e-6): 192 | super().__init__(dim, num_heads, window_size, qk_norm, eps) 193 | 194 | self.k_img = nn.Linear(dim, dim) 195 | self.v_img = nn.Linear(dim, dim) 196 | # self.alpha = nn.Parameter(torch.zeros((1, ))) 197 | self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() 198 | 199 | def forward(self, x, context, context_lens): 200 | r""" 201 | Args: 202 | x(Tensor): Shape [B, L1, C] 203 | context(Tensor): Shape [B, L2, C] 204 | context_lens(Tensor): Shape [B] 205 | """ 206 | context_img = context[:, :257] 207 | context = context[:, 257:] 208 | b, n, d = x.size(0), 
self.num_heads, self.head_dim 209 | 210 | # compute query, key, value 211 | q = self.norm_q(self.q(x)).view(b, -1, n, d) 212 | k = self.norm_k(self.k(context)).view(b, -1, n, d) 213 | v = self.v(context).view(b, -1, n, d) 214 | k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d) 215 | v_img = self.v_img(context_img).view(b, -1, n, d) 216 | img_x = flash_attention(q, k_img, v_img, k_lens=None) 217 | # compute attention 218 | x = flash_attention(q, k, v, k_lens=context_lens) 219 | 220 | # output 221 | x = x.flatten(2) 222 | img_x = img_x.flatten(2) 223 | x = x + img_x 224 | x = self.o(x) 225 | return x 226 | 227 | 228 | WAN_CROSSATTENTION_CLASSES = { 229 | 't2v_cross_attn': WanT2VCrossAttention, 230 | 'i2v_cross_attn': WanI2VCrossAttention, 231 | } 232 | 233 | 234 | class WanAttentionBlock(nn.Module): 235 | 236 | def __init__(self, 237 | cross_attn_type, 238 | dim, 239 | ffn_dim, 240 | num_heads, 241 | window_size=(-1, -1), 242 | qk_norm=True, 243 | cross_attn_norm=False, 244 | eps=1e-6): 245 | super().__init__() 246 | self.dim = dim 247 | self.ffn_dim = ffn_dim 248 | self.num_heads = num_heads 249 | self.window_size = window_size 250 | self.qk_norm = qk_norm 251 | self.cross_attn_norm = cross_attn_norm 252 | self.eps = eps 253 | 254 | # layers 255 | self.norm1 = WanLayerNorm(dim, eps) 256 | self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, 257 | eps) 258 | self.norm3 = WanLayerNorm( 259 | dim, eps, 260 | elementwise_affine=True) if cross_attn_norm else nn.Identity() 261 | self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, 262 | num_heads, 263 | (-1, -1), 264 | qk_norm, 265 | eps) 266 | self.norm2 = WanLayerNorm(dim, eps) 267 | self.ffn = nn.Sequential( 268 | nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'), 269 | nn.Linear(ffn_dim, dim)) 270 | 271 | # modulation 272 | self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) 273 | 274 | def forward( 275 | self, 276 | x, 277 | e, 278 | seq_lens, 279 | grid_sizes, 280 | freqs, 281 | context, 282 | context_lens, 283 | ): 284 | r""" 285 | Args: 286 | x(Tensor): Shape [B, L, C] 287 | e(Tensor): Shape [B, 6, C] 288 | seq_lens(Tensor): Shape [B], length of each sequence in batch 289 | grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) 290 | freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] 291 | """ 292 | assert e.dtype == torch.float32 293 | with amp.autocast(dtype=torch.float32): 294 | e = (self.modulation + e).chunk(6, dim=1) 295 | assert e[0].dtype == torch.float32 296 | 297 | # self-attention 298 | y = self.self_attn( 299 | self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, 300 | freqs) 301 | with amp.autocast(dtype=torch.float32): 302 | x = x + y * e[2] 303 | 304 | # cross-attention & ffn function 305 | def cross_attn_ffn(x, context, context_lens, e): 306 | x = x + self.cross_attn(self.norm3(x), context, context_lens) 307 | y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3]) 308 | with amp.autocast(dtype=torch.float32): 309 | x = x + y * e[5] 310 | return x 311 | 312 | x = cross_attn_ffn(x, context, context_lens, e) 313 | return x 314 | 315 | 316 | class Head(nn.Module): 317 | 318 | def __init__(self, dim, out_dim, patch_size, eps=1e-6): 319 | super().__init__() 320 | self.dim = dim 321 | self.out_dim = out_dim 322 | self.patch_size = patch_size 323 | self.eps = eps 324 | 325 | # layers 326 | out_dim = math.prod(patch_size) * out_dim 327 | self.norm = WanLayerNorm(dim, eps) 328 | self.head = nn.Linear(dim, out_dim) 329 | 330 | # 
modulation 331 | self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) 332 | 333 | def forward(self, x, e): 334 | r""" 335 | Args: 336 | x(Tensor): Shape [B, L1, C] 337 | e(Tensor): Shape [B, C] 338 | """ 339 | assert e.dtype == torch.float32 340 | with amp.autocast(dtype=torch.float32): 341 | e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) 342 | x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) 343 | return x 344 | 345 | 346 | class MLPProj(torch.nn.Module): 347 | 348 | def __init__(self, in_dim, out_dim): 349 | super().__init__() 350 | 351 | self.proj = torch.nn.Sequential( 352 | torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), 353 | torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), 354 | torch.nn.LayerNorm(out_dim)) 355 | 356 | def forward(self, image_embeds): 357 | clip_extra_context_tokens = self.proj(image_embeds) 358 | return clip_extra_context_tokens 359 | 360 | 361 | class WanModel(ModelMixin, ConfigMixin): 362 | r""" 363 | Wan diffusion backbone supporting both text-to-video and image-to-video. 364 | """ 365 | 366 | ignore_for_config = [ 367 | 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size' 368 | ] 369 | _no_split_modules = ['WanAttentionBlock'] 370 | 371 | @register_to_config 372 | def __init__(self, 373 | model_type='t2v', 374 | patch_size=(1, 2, 2), 375 | text_len=512, 376 | in_dim=16, 377 | dim=5120, 378 | ffn_dim=13824, 379 | freq_dim=256, 380 | text_dim=4096, 381 | out_dim=16, 382 | num_heads=40, 383 | num_layers=40, 384 | window_size=(-1, -1), 385 | qk_norm=True, 386 | cross_attn_norm=True, 387 | eps=1e-6): 388 | r""" 389 | Initialize the diffusion model backbone. 390 | 391 | Args: 392 | model_type (`str`, *optional*, defaults to 't2v'): 393 | Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) 394 | patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): 395 | 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) 396 | text_len (`int`, *optional*, defaults to 512): 397 | Fixed length for text embeddings 398 | in_dim (`int`, *optional*, defaults to 16): 399 | Input video channels (C_in) 400 | dim (`int`, *optional*, defaults to 5120): 401 | Hidden dimension of the transformer 402 | ffn_dim (`int`, *optional*, defaults to 13824): 403 | Intermediate dimension in feed-forward network 404 | freq_dim (`int`, *optional*, defaults to 256): 405 | Dimension for sinusoidal time embeddings 406 | text_dim (`int`, *optional*, defaults to 4096): 407 | Input dimension for text embeddings 408 | out_dim (`int`, *optional*, defaults to 16): 409 | Output video channels (C_out) 410 | num_heads (`int`, *optional*, defaults to 40): 411 | Number of attention heads 412 | num_layers (`int`, *optional*, defaults to 40): 413 | Number of transformer blocks 414 | window_size (`tuple`, *optional*, defaults to (-1, -1)): 415 | Window size for local attention (-1 indicates global attention) 416 | qk_norm (`bool`, *optional*, defaults to True): 417 | Enable query/key normalization 418 | cross_attn_norm (`bool`, *optional*, defaults to True): 419 | Enable cross-attention normalization 420 | eps (`float`, *optional*, defaults to 1e-6): 421 | Epsilon value for normalization layers 422 | """ 423 | 424 | super().__init__() 425 | 426 | assert model_type in ['t2v', 'i2v'] 427 | self.model_type = model_type 428 | 429 | self.patch_size = patch_size 430 | self.text_len = text_len 431 | self.in_dim = in_dim 432 | self.dim = dim 433 | self.ffn_dim = ffn_dim 434 | self.freq_dim = freq_dim 435 | self.text_dim = text_dim 436 | self.out_dim 
= out_dim 437 | self.num_heads = num_heads 438 | self.num_layers = num_layers 439 | self.window_size = window_size 440 | self.qk_norm = qk_norm 441 | self.cross_attn_norm = cross_attn_norm 442 | self.eps = eps 443 | 444 | # embeddings 445 | self.patch_embedding = nn.Conv3d( 446 | in_dim, dim, kernel_size=patch_size, stride=patch_size) 447 | self.text_embedding = nn.Sequential( 448 | nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), 449 | nn.Linear(dim, dim)) 450 | 451 | self.time_embedding = nn.Sequential( 452 | nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) 453 | self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) 454 | 455 | # blocks 456 | cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' 457 | self.blocks = nn.ModuleList([ 458 | WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, 459 | window_size, qk_norm, cross_attn_norm, eps) 460 | for _ in range(num_layers) 461 | ]) 462 | 463 | # head 464 | self.head = Head(dim, out_dim, patch_size, eps) 465 | 466 | # buffers (don't use register_buffer otherwise dtype will be changed in to()) 467 | assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 468 | d = dim // num_heads 469 | self.freqs = torch.cat([ 470 | rope_params(1024, d - 4 * (d // 6)), 471 | rope_params(1024, 2 * (d // 6)), 472 | rope_params(1024, 2 * (d // 6)) 473 | ], 474 | dim=1) 475 | 476 | if model_type == 'i2v': 477 | self.img_emb = MLPProj(1280, dim) 478 | 479 | # initialize weights 480 | self.init_weights() 481 | 482 | 483 | def forward( 484 | self, 485 | x, 486 | t, 487 | context, 488 | seq_len, 489 | clip_fea=None, 490 | y=None, 491 | ): 492 | r""" 493 | Forward pass through the diffusion model 494 | 495 | Args: 496 | x (List[Tensor]): 497 | List of input video tensors, each with shape [C_in, F, H, W] 498 | t (Tensor): 499 | Diffusion timesteps tensor of shape [B] 500 | context (List[Tensor]): 501 | List of text embeddings each with shape [L, C] 502 | seq_len (`int`): 503 | Maximum sequence length for positional encoding 504 | clip_fea (Tensor, *optional*): 505 | CLIP image features for image-to-video mode 506 | y (List[Tensor], *optional*): 507 | Conditional video inputs for image-to-video mode, same shape as x 508 | 509 | Returns: 510 | List[Tensor]: 511 | List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] 512 | """ 513 | if self.model_type == 'i2v': 514 | assert clip_fea is not None and y is not None 515 | # params 516 | device = self.patch_embedding.weight.device 517 | if self.freqs.device != device: 518 | self.freqs = self.freqs.to(device) 519 | 520 | if y is not None: 521 | x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] 522 | 523 | # embeddings 524 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x] 525 | grid_sizes = torch.stack( 526 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) 527 | x = [u.flatten(2).transpose(1, 2) for u in x] 528 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) 529 | assert seq_lens.max() <= seq_len 530 | x = torch.cat([ 531 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], 532 | dim=1) for u in x 533 | ]) 534 | 535 | # time embeddings 536 | with amp.autocast(dtype=torch.float32): 537 | e = self.time_embedding( 538 | sinusoidal_embedding_1d(self.freq_dim, t).float()) 539 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) 540 | assert e.dtype == torch.float32 and e0.dtype == torch.float32 541 | 542 | # context 543 | context_lens = None 544 | context = 
self.text_embedding( 545 | torch.stack([ 546 | torch.cat( 547 | [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) 548 | for u in context 549 | ])) 550 | 551 | if clip_fea is not None: 552 | context_clip = self.img_emb(clip_fea) # bs x 257 x dim 553 | context = torch.concat([context_clip, context], dim=1) 554 | 555 | # arguments 556 | kwargs = dict( 557 | e=e0, 558 | seq_lens=seq_lens, 559 | grid_sizes=grid_sizes, 560 | freqs=self.freqs, 561 | context=context, 562 | context_lens=context_lens) 563 | 564 | for block in self.blocks: 565 | x = block(x, **kwargs) 566 | 567 | # head 568 | x = self.head(x, e) 569 | 570 | # unpatchify 571 | x = self.unpatchify(x, grid_sizes) 572 | return [u.float() for u in x] 573 | 574 | def unpatchify(self, x, grid_sizes): 575 | r""" 576 | Reconstruct video tensors from patch embeddings. 577 | 578 | Args: 579 | x (List[Tensor]): 580 | List of patchified features, each with shape [L, C_out * prod(patch_size)] 581 | grid_sizes (Tensor): 582 | Original spatial-temporal grid dimensions before patching, 583 | shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) 584 | 585 | Returns: 586 | List[Tensor]: 587 | Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] 588 | """ 589 | 590 | c = self.out_dim 591 | out = [] 592 | for u, v in zip(x, grid_sizes.tolist()): 593 | u = u[:math.prod(v)].view(*v, *self.patch_size, c) 594 | u = torch.einsum('fhwpqrc->cfphqwr', u) 595 | u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) 596 | out.append(u) 597 | return out 598 | 599 | def init_weights(self): 600 | r""" 601 | Initialize model parameters using Xavier initialization. 602 | """ 603 | 604 | # basic init 605 | for m in self.modules(): 606 | if isinstance(m, nn.Linear): 607 | nn.init.xavier_uniform_(m.weight) 608 | if m.bias is not None: 609 | nn.init.zeros_(m.bias) 610 | 611 | # init embeddings 612 | nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) 613 | for m in self.text_embedding.modules(): 614 | if isinstance(m, nn.Linear): 615 | nn.init.normal_(m.weight, std=.02) 616 | for m in self.time_embedding.modules(): 617 | if isinstance(m, nn.Linear): 618 | nn.init.normal_(m.weight, std=.02) 619 | 620 | # init output layer 621 | nn.init.zeros_(self.head.head.weight) 622 | -------------------------------------------------------------------------------- /phantom_wan/modules/t5.py: -------------------------------------------------------------------------------- 1 | # Modified from transformers.models.t5.modeling_t5 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
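# ---- Editorial usage sketch (not part of the original file) ----
# The encoder-only T5EncoderModel wrapper defined below is typically driven as:
#     enc = T5EncoderModel(text_len=512,
#                          checkpoint_path='<umt5-xxl encoder .pth>',   # assumed path
#                          tokenizer_path='google/umt5-xxl')            # assumed tokenizer name
#     context = enc(["a cat playing piano"], device=torch.device('cuda'))
# `context` is a list of [L_i, C] tensors, one per prompt, trimmed to each prompt's token length.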
3 | import logging 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .tokenizers import HuggingfaceTokenizer 11 | 12 | __all__ = [ 13 | 'T5Model', 14 | 'T5Encoder', 15 | 'T5Decoder', 16 | 'T5EncoderModel', 17 | ] 18 | 19 | 20 | def fp16_clamp(x): 21 | if x.dtype == torch.float16 and torch.isinf(x).any(): 22 | clamp = torch.finfo(x.dtype).max - 1000 23 | x = torch.clamp(x, min=-clamp, max=clamp) 24 | return x 25 | 26 | 27 | def init_weights(m): 28 | if isinstance(m, T5LayerNorm): 29 | nn.init.ones_(m.weight) 30 | elif isinstance(m, T5Model): 31 | nn.init.normal_(m.token_embedding.weight, std=1.0) 32 | elif isinstance(m, T5FeedForward): 33 | nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) 34 | nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) 35 | nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) 36 | elif isinstance(m, T5Attention): 37 | nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5) 38 | nn.init.normal_(m.k.weight, std=m.dim**-0.5) 39 | nn.init.normal_(m.v.weight, std=m.dim**-0.5) 40 | nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5) 41 | elif isinstance(m, T5RelativeEmbedding): 42 | nn.init.normal_( 43 | m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5) 44 | 45 | 46 | class GELU(nn.Module): 47 | 48 | def forward(self, x): 49 | return 0.5 * x * (1.0 + torch.tanh( 50 | math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 51 | 52 | 53 | class T5LayerNorm(nn.Module): 54 | 55 | def __init__(self, dim, eps=1e-6): 56 | super(T5LayerNorm, self).__init__() 57 | self.dim = dim 58 | self.eps = eps 59 | self.weight = nn.Parameter(torch.ones(dim)) 60 | 61 | def forward(self, x): 62 | x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + 63 | self.eps) 64 | if self.weight.dtype in [torch.float16, torch.bfloat16]: 65 | x = x.type_as(self.weight) 66 | return self.weight * x 67 | 68 | 69 | class T5Attention(nn.Module): 70 | 71 | def __init__(self, dim, dim_attn, num_heads, dropout=0.1): 72 | assert dim_attn % num_heads == 0 73 | super(T5Attention, self).__init__() 74 | self.dim = dim 75 | self.dim_attn = dim_attn 76 | self.num_heads = num_heads 77 | self.head_dim = dim_attn // num_heads 78 | 79 | # layers 80 | self.q = nn.Linear(dim, dim_attn, bias=False) 81 | self.k = nn.Linear(dim, dim_attn, bias=False) 82 | self.v = nn.Linear(dim, dim_attn, bias=False) 83 | self.o = nn.Linear(dim_attn, dim, bias=False) 84 | self.dropout = nn.Dropout(dropout) 85 | 86 | def forward(self, x, context=None, mask=None, pos_bias=None): 87 | """ 88 | x: [B, L1, C]. 89 | context: [B, L2, C] or None. 90 | mask: [B, L2] or [B, L1, L2] or None. 
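        pos_bias: [1, N, L1, L2] additive relative-position bias or None
                  (shape follows T5RelativeEmbedding.forward below).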
91 | """ 92 | # check inputs 93 | context = x if context is None else context 94 | b, n, c = x.size(0), self.num_heads, self.head_dim 95 | 96 | # compute query, key, value 97 | q = self.q(x).view(b, -1, n, c) 98 | k = self.k(context).view(b, -1, n, c) 99 | v = self.v(context).view(b, -1, n, c) 100 | 101 | # attention bias 102 | attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) 103 | if pos_bias is not None: 104 | attn_bias += pos_bias 105 | if mask is not None: 106 | assert mask.ndim in [2, 3] 107 | mask = mask.view(b, 1, 1, 108 | -1) if mask.ndim == 2 else mask.unsqueeze(1) 109 | attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) 110 | 111 | # compute attention (T5 does not use scaling) 112 | attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias 113 | attn = F.softmax(attn.float(), dim=-1).type_as(attn) 114 | x = torch.einsum('bnij,bjnc->binc', attn, v) 115 | 116 | # output 117 | x = x.reshape(b, -1, n * c) 118 | x = self.o(x) 119 | x = self.dropout(x) 120 | return x 121 | 122 | 123 | class T5FeedForward(nn.Module): 124 | 125 | def __init__(self, dim, dim_ffn, dropout=0.1): 126 | super(T5FeedForward, self).__init__() 127 | self.dim = dim 128 | self.dim_ffn = dim_ffn 129 | 130 | # layers 131 | self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) 132 | self.fc1 = nn.Linear(dim, dim_ffn, bias=False) 133 | self.fc2 = nn.Linear(dim_ffn, dim, bias=False) 134 | self.dropout = nn.Dropout(dropout) 135 | 136 | def forward(self, x): 137 | x = self.fc1(x) * self.gate(x) 138 | x = self.dropout(x) 139 | x = self.fc2(x) 140 | x = self.dropout(x) 141 | return x 142 | 143 | 144 | class T5SelfAttention(nn.Module): 145 | 146 | def __init__(self, 147 | dim, 148 | dim_attn, 149 | dim_ffn, 150 | num_heads, 151 | num_buckets, 152 | shared_pos=True, 153 | dropout=0.1): 154 | super(T5SelfAttention, self).__init__() 155 | self.dim = dim 156 | self.dim_attn = dim_attn 157 | self.dim_ffn = dim_ffn 158 | self.num_heads = num_heads 159 | self.num_buckets = num_buckets 160 | self.shared_pos = shared_pos 161 | 162 | # layers 163 | self.norm1 = T5LayerNorm(dim) 164 | self.attn = T5Attention(dim, dim_attn, num_heads, dropout) 165 | self.norm2 = T5LayerNorm(dim) 166 | self.ffn = T5FeedForward(dim, dim_ffn, dropout) 167 | self.pos_embedding = None if shared_pos else T5RelativeEmbedding( 168 | num_buckets, num_heads, bidirectional=True) 169 | 170 | def forward(self, x, mask=None, pos_bias=None): 171 | e = pos_bias if self.shared_pos else self.pos_embedding( 172 | x.size(1), x.size(1)) 173 | x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) 174 | x = fp16_clamp(x + self.ffn(self.norm2(x))) 175 | return x 176 | 177 | 178 | class T5CrossAttention(nn.Module): 179 | 180 | def __init__(self, 181 | dim, 182 | dim_attn, 183 | dim_ffn, 184 | num_heads, 185 | num_buckets, 186 | shared_pos=True, 187 | dropout=0.1): 188 | super(T5CrossAttention, self).__init__() 189 | self.dim = dim 190 | self.dim_attn = dim_attn 191 | self.dim_ffn = dim_ffn 192 | self.num_heads = num_heads 193 | self.num_buckets = num_buckets 194 | self.shared_pos = shared_pos 195 | 196 | # layers 197 | self.norm1 = T5LayerNorm(dim) 198 | self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) 199 | self.norm2 = T5LayerNorm(dim) 200 | self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) 201 | self.norm3 = T5LayerNorm(dim) 202 | self.ffn = T5FeedForward(dim, dim_ffn, dropout) 203 | self.pos_embedding = None if shared_pos else T5RelativeEmbedding( 204 | num_buckets, num_heads, bidirectional=False) 205 
| 206 | def forward(self, 207 | x, 208 | mask=None, 209 | encoder_states=None, 210 | encoder_mask=None, 211 | pos_bias=None): 212 | e = pos_bias if self.shared_pos else self.pos_embedding( 213 | x.size(1), x.size(1)) 214 | x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) 215 | x = fp16_clamp(x + self.cross_attn( 216 | self.norm2(x), context=encoder_states, mask=encoder_mask)) 217 | x = fp16_clamp(x + self.ffn(self.norm3(x))) 218 | return x 219 | 220 | 221 | class T5RelativeEmbedding(nn.Module): 222 | 223 | def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): 224 | super(T5RelativeEmbedding, self).__init__() 225 | self.num_buckets = num_buckets 226 | self.num_heads = num_heads 227 | self.bidirectional = bidirectional 228 | self.max_dist = max_dist 229 | 230 | # layers 231 | self.embedding = nn.Embedding(num_buckets, num_heads) 232 | 233 | def forward(self, lq, lk): 234 | device = self.embedding.weight.device 235 | # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ 236 | # torch.arange(lq).unsqueeze(1).to(device) 237 | rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \ 238 | torch.arange(lq, device=device).unsqueeze(1) 239 | rel_pos = self._relative_position_bucket(rel_pos) 240 | rel_pos_embeds = self.embedding(rel_pos) 241 | rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze( 242 | 0) # [1, N, Lq, Lk] 243 | return rel_pos_embeds.contiguous() 244 | 245 | def _relative_position_bucket(self, rel_pos): 246 | # preprocess 247 | if self.bidirectional: 248 | num_buckets = self.num_buckets // 2 249 | rel_buckets = (rel_pos > 0).long() * num_buckets 250 | rel_pos = torch.abs(rel_pos) 251 | else: 252 | num_buckets = self.num_buckets 253 | rel_buckets = 0 254 | rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) 255 | 256 | # embeddings for small and large positions 257 | max_exact = num_buckets // 2 258 | rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / 259 | math.log(self.max_dist / max_exact) * 260 | (num_buckets - max_exact)).long() 261 | rel_pos_large = torch.min( 262 | rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) 263 | rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) 264 | return rel_buckets 265 | 266 | 267 | class T5Encoder(nn.Module): 268 | 269 | def __init__(self, 270 | vocab, 271 | dim, 272 | dim_attn, 273 | dim_ffn, 274 | num_heads, 275 | num_layers, 276 | num_buckets, 277 | shared_pos=True, 278 | dropout=0.1): 279 | super(T5Encoder, self).__init__() 280 | self.dim = dim 281 | self.dim_attn = dim_attn 282 | self.dim_ffn = dim_ffn 283 | self.num_heads = num_heads 284 | self.num_layers = num_layers 285 | self.num_buckets = num_buckets 286 | self.shared_pos = shared_pos 287 | 288 | # layers 289 | self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ 290 | else nn.Embedding(vocab, dim) 291 | self.pos_embedding = T5RelativeEmbedding( 292 | num_buckets, num_heads, bidirectional=True) if shared_pos else None 293 | self.dropout = nn.Dropout(dropout) 294 | self.blocks = nn.ModuleList([ 295 | T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, 296 | shared_pos, dropout) for _ in range(num_layers) 297 | ]) 298 | self.norm = T5LayerNorm(dim) 299 | 300 | # initialize weights 301 | self.apply(init_weights) 302 | 303 | def forward(self, ids, mask=None): 304 | x = self.token_embedding(ids) 305 | x = self.dropout(x) 306 | e = self.pos_embedding(x.size(1), 307 | x.size(1)) if self.shared_pos else None 308 | for block in self.blocks: 309 | x = block(x, 
mask, pos_bias=e) 310 | x = self.norm(x) 311 | x = self.dropout(x) 312 | return x 313 | 314 | 315 | class T5Decoder(nn.Module): 316 | 317 | def __init__(self, 318 | vocab, 319 | dim, 320 | dim_attn, 321 | dim_ffn, 322 | num_heads, 323 | num_layers, 324 | num_buckets, 325 | shared_pos=True, 326 | dropout=0.1): 327 | super(T5Decoder, self).__init__() 328 | self.dim = dim 329 | self.dim_attn = dim_attn 330 | self.dim_ffn = dim_ffn 331 | self.num_heads = num_heads 332 | self.num_layers = num_layers 333 | self.num_buckets = num_buckets 334 | self.shared_pos = shared_pos 335 | 336 | # layers 337 | self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ 338 | else nn.Embedding(vocab, dim) 339 | self.pos_embedding = T5RelativeEmbedding( 340 | num_buckets, num_heads, bidirectional=False) if shared_pos else None 341 | self.dropout = nn.Dropout(dropout) 342 | self.blocks = nn.ModuleList([ 343 | T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, 344 | shared_pos, dropout) for _ in range(num_layers) 345 | ]) 346 | self.norm = T5LayerNorm(dim) 347 | 348 | # initialize weights 349 | self.apply(init_weights) 350 | 351 | def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None): 352 | b, s = ids.size() 353 | 354 | # causal mask 355 | if mask is None: 356 | mask = torch.tril(torch.ones(1, s, s).to(ids.device)) 357 | elif mask.ndim == 2: 358 | mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) 359 | 360 | # layers 361 | x = self.token_embedding(ids) 362 | x = self.dropout(x) 363 | e = self.pos_embedding(x.size(1), 364 | x.size(1)) if self.shared_pos else None 365 | for block in self.blocks: 366 | x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) 367 | x = self.norm(x) 368 | x = self.dropout(x) 369 | return x 370 | 371 | 372 | class T5Model(nn.Module): 373 | 374 | def __init__(self, 375 | vocab_size, 376 | dim, 377 | dim_attn, 378 | dim_ffn, 379 | num_heads, 380 | encoder_layers, 381 | decoder_layers, 382 | num_buckets, 383 | shared_pos=True, 384 | dropout=0.1): 385 | super(T5Model, self).__init__() 386 | self.vocab_size = vocab_size 387 | self.dim = dim 388 | self.dim_attn = dim_attn 389 | self.dim_ffn = dim_ffn 390 | self.num_heads = num_heads 391 | self.encoder_layers = encoder_layers 392 | self.decoder_layers = decoder_layers 393 | self.num_buckets = num_buckets 394 | 395 | # layers 396 | self.token_embedding = nn.Embedding(vocab_size, dim) 397 | self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn, 398 | num_heads, encoder_layers, num_buckets, 399 | shared_pos, dropout) 400 | self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn, 401 | num_heads, decoder_layers, num_buckets, 402 | shared_pos, dropout) 403 | self.head = nn.Linear(dim, vocab_size, bias=False) 404 | 405 | # initialize weights 406 | self.apply(init_weights) 407 | 408 | def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): 409 | x = self.encoder(encoder_ids, encoder_mask) 410 | x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) 411 | x = self.head(x) 412 | return x 413 | 414 | 415 | def _t5(name, 416 | encoder_only=False, 417 | decoder_only=False, 418 | return_tokenizer=False, 419 | tokenizer_kwargs={}, 420 | dtype=torch.float32, 421 | device='cpu', 422 | **kwargs): 423 | # sanity check 424 | assert not (encoder_only and decoder_only) 425 | 426 | # params 427 | if encoder_only: 428 | model_cls = T5Encoder 429 | kwargs['vocab'] = kwargs.pop('vocab_size') 430 | kwargs['num_layers'] = kwargs.pop('encoder_layers') 431 | _ = 
kwargs.pop('decoder_layers') 432 | elif decoder_only: 433 | model_cls = T5Decoder 434 | kwargs['vocab'] = kwargs.pop('vocab_size') 435 | kwargs['num_layers'] = kwargs.pop('decoder_layers') 436 | _ = kwargs.pop('encoder_layers') 437 | else: 438 | model_cls = T5Model 439 | 440 | # init model 441 | with torch.device(device): 442 | model = model_cls(**kwargs) 443 | 444 | # set device 445 | model = model.to(dtype=dtype, device=device) 446 | 447 | # init tokenizer 448 | if return_tokenizer: 449 | from .tokenizers import HuggingfaceTokenizer 450 | tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs) 451 | return model, tokenizer 452 | else: 453 | return model 454 | 455 | 456 | def umt5_xxl(**kwargs): 457 | cfg = dict( 458 | vocab_size=256384, 459 | dim=4096, 460 | dim_attn=4096, 461 | dim_ffn=10240, 462 | num_heads=64, 463 | encoder_layers=24, 464 | decoder_layers=24, 465 | num_buckets=32, 466 | shared_pos=False, 467 | dropout=0.1) 468 | cfg.update(**kwargs) 469 | return _t5('umt5-xxl', **cfg) 470 | 471 | 472 | class T5EncoderModel: 473 | 474 | def __init__( 475 | self, 476 | text_len, 477 | dtype=torch.bfloat16, 478 | device=torch.cuda.current_device(), 479 | checkpoint_path=None, 480 | tokenizer_path=None, 481 | shard_fn=None, 482 | ): 483 | self.text_len = text_len 484 | self.dtype = dtype 485 | self.device = device 486 | self.checkpoint_path = checkpoint_path 487 | self.tokenizer_path = tokenizer_path 488 | 489 | # init model 490 | model = umt5_xxl( 491 | encoder_only=True, 492 | return_tokenizer=False, 493 | dtype=dtype, 494 | device=device).eval().requires_grad_(False) 495 | logging.info(f'loading {checkpoint_path}') 496 | model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')) 497 | self.model = model 498 | if shard_fn is not None: 499 | self.model = shard_fn(self.model, sync_module_states=False) 500 | else: 501 | self.model.to(self.device) 502 | # init tokenizer 503 | self.tokenizer = HuggingfaceTokenizer( 504 | name=tokenizer_path, seq_len=text_len, clean='whitespace') 505 | 506 | def __call__(self, texts, device): 507 | ids, mask = self.tokenizer( 508 | texts, return_mask=True, add_special_tokens=True) 509 | ids = ids.to(device) 510 | mask = mask.to(device) 511 | seq_lens = mask.gt(0).sum(dim=1).long() 512 | context = self.model(ids, mask) 513 | return [u[:v] for u, v in zip(context, seq_lens)] 514 | -------------------------------------------------------------------------------- /phantom_wan/modules/tokenizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
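# ---- Editorial usage sketch (not part of the original file; the model name is illustrative) ----
#     tok = HuggingfaceTokenizer(name='google/umt5-xxl', seq_len=512, clean='whitespace')
#     ids, mask = tok(["a prompt"], return_mask=True, add_special_tokens=True)
# With seq_len set, `ids` and `mask` are [B, seq_len] tensors padded/truncated to seq_len.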
2 | import html 3 | import string 4 | 5 | import ftfy 6 | import regex as re 7 | from transformers import AutoTokenizer 8 | 9 | __all__ = ['HuggingfaceTokenizer'] 10 | 11 | 12 | def basic_clean(text): 13 | text = ftfy.fix_text(text) 14 | text = html.unescape(html.unescape(text)) 15 | return text.strip() 16 | 17 | 18 | def whitespace_clean(text): 19 | text = re.sub(r'\s+', ' ', text) 20 | text = text.strip() 21 | return text 22 | 23 | 24 | def canonicalize(text, keep_punctuation_exact_string=None): 25 | text = text.replace('_', ' ') 26 | if keep_punctuation_exact_string: 27 | text = keep_punctuation_exact_string.join( 28 | part.translate(str.maketrans('', '', string.punctuation)) 29 | for part in text.split(keep_punctuation_exact_string)) 30 | else: 31 | text = text.translate(str.maketrans('', '', string.punctuation)) 32 | text = text.lower() 33 | text = re.sub(r'\s+', ' ', text) 34 | return text.strip() 35 | 36 | 37 | class HuggingfaceTokenizer: 38 | 39 | def __init__(self, name, seq_len=None, clean=None, **kwargs): 40 | assert clean in (None, 'whitespace', 'lower', 'canonicalize') 41 | self.name = name 42 | self.seq_len = seq_len 43 | self.clean = clean 44 | 45 | # init tokenizer 46 | self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) 47 | self.vocab_size = self.tokenizer.vocab_size 48 | 49 | def __call__(self, sequence, **kwargs): 50 | return_mask = kwargs.pop('return_mask', False) 51 | 52 | # arguments 53 | _kwargs = {'return_tensors': 'pt'} 54 | if self.seq_len is not None: 55 | _kwargs.update({ 56 | 'padding': 'max_length', 57 | 'truncation': True, 58 | 'max_length': self.seq_len 59 | }) 60 | _kwargs.update(**kwargs) 61 | 62 | # tokenization 63 | if isinstance(sequence, str): 64 | sequence = [sequence] 65 | if self.clean: 66 | sequence = [self._clean(u) for u in sequence] 67 | ids = self.tokenizer(sequence, **_kwargs) 68 | 69 | # output 70 | if return_mask: 71 | return ids.input_ids, ids.attention_mask 72 | else: 73 | return ids.input_ids 74 | 75 | def _clean(self, text): 76 | if self.clean == 'whitespace': 77 | text = whitespace_clean(basic_clean(text)) 78 | elif self.clean == 'lower': 79 | text = whitespace_clean(basic_clean(text)).lower() 80 | elif self.clean == 'canonicalize': 81 | text = canonicalize(basic_clean(text)) 82 | return text 83 | -------------------------------------------------------------------------------- /phantom_wan/modules/xlm_roberta.py: -------------------------------------------------------------------------------- 1 | # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | __all__ = ['XLMRoberta', 'xlm_roberta_large'] 8 | 9 | 10 | class SelfAttention(nn.Module): 11 | 12 | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5): 13 | assert dim % num_heads == 0 14 | super().__init__() 15 | self.dim = dim 16 | self.num_heads = num_heads 17 | self.head_dim = dim // num_heads 18 | self.eps = eps 19 | 20 | # layers 21 | self.q = nn.Linear(dim, dim) 22 | self.k = nn.Linear(dim, dim) 23 | self.v = nn.Linear(dim, dim) 24 | self.o = nn.Linear(dim, dim) 25 | self.dropout = nn.Dropout(dropout) 26 | 27 | def forward(self, x, mask): 28 | """ 29 | x: [B, L, C]. 
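        mask: [B, 1, 1, L] additive attention mask (0 for visible tokens, a large
              negative value for padding), as constructed in XLMRoberta.forward.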
30 | """ 31 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim 32 | 33 | # compute query, key, value 34 | q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3) 35 | k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3) 36 | v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3) 37 | 38 | # compute attention 39 | p = self.dropout.p if self.training else 0.0 40 | x = F.scaled_dot_product_attention(q, k, v, mask, p) 41 | x = x.permute(0, 2, 1, 3).reshape(b, s, c) 42 | 43 | # output 44 | x = self.o(x) 45 | x = self.dropout(x) 46 | return x 47 | 48 | 49 | class AttentionBlock(nn.Module): 50 | 51 | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): 52 | super().__init__() 53 | self.dim = dim 54 | self.num_heads = num_heads 55 | self.post_norm = post_norm 56 | self.eps = eps 57 | 58 | # layers 59 | self.attn = SelfAttention(dim, num_heads, dropout, eps) 60 | self.norm1 = nn.LayerNorm(dim, eps=eps) 61 | self.ffn = nn.Sequential( 62 | nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), 63 | nn.Dropout(dropout)) 64 | self.norm2 = nn.LayerNorm(dim, eps=eps) 65 | 66 | def forward(self, x, mask): 67 | if self.post_norm: 68 | x = self.norm1(x + self.attn(x, mask)) 69 | x = self.norm2(x + self.ffn(x)) 70 | else: 71 | x = x + self.attn(self.norm1(x), mask) 72 | x = x + self.ffn(self.norm2(x)) 73 | return x 74 | 75 | 76 | class XLMRoberta(nn.Module): 77 | """ 78 | XLMRobertaModel with no pooler and no LM head. 79 | """ 80 | 81 | def __init__(self, 82 | vocab_size=250002, 83 | max_seq_len=514, 84 | type_size=1, 85 | pad_id=1, 86 | dim=1024, 87 | num_heads=16, 88 | num_layers=24, 89 | post_norm=True, 90 | dropout=0.1, 91 | eps=1e-5): 92 | super().__init__() 93 | self.vocab_size = vocab_size 94 | self.max_seq_len = max_seq_len 95 | self.type_size = type_size 96 | self.pad_id = pad_id 97 | self.dim = dim 98 | self.num_heads = num_heads 99 | self.num_layers = num_layers 100 | self.post_norm = post_norm 101 | self.eps = eps 102 | 103 | # embeddings 104 | self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id) 105 | self.type_embedding = nn.Embedding(type_size, dim) 106 | self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id) 107 | self.dropout = nn.Dropout(dropout) 108 | 109 | # blocks 110 | self.blocks = nn.ModuleList([ 111 | AttentionBlock(dim, num_heads, post_norm, dropout, eps) 112 | for _ in range(num_layers) 113 | ]) 114 | 115 | # norm layer 116 | self.norm = nn.LayerNorm(dim, eps=eps) 117 | 118 | def forward(self, ids): 119 | """ 120 | ids: [B, L] of torch.LongTensor. 121 | """ 122 | b, s = ids.shape 123 | mask = ids.ne(self.pad_id).long() 124 | 125 | # embeddings 126 | x = self.token_embedding(ids) + \ 127 | self.type_embedding(torch.zeros_like(ids)) + \ 128 | self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask) 129 | if self.post_norm: 130 | x = self.norm(x) 131 | x = self.dropout(x) 132 | 133 | # blocks 134 | mask = torch.where( 135 | mask.view(b, 1, 1, s).gt(0), 0.0, 136 | torch.finfo(x.dtype).min) 137 | for block in self.blocks: 138 | x = block(x, mask) 139 | 140 | # output 141 | if not self.post_norm: 142 | x = self.norm(x) 143 | return x 144 | 145 | 146 | def xlm_roberta_large(pretrained=False, 147 | return_tokenizer=False, 148 | device='cpu', 149 | **kwargs): 150 | """ 151 | XLMRobertaLarge adapted from Huggingface. 
152 | """ 153 | # params 154 | cfg = dict( 155 | vocab_size=250002, 156 | max_seq_len=514, 157 | type_size=1, 158 | pad_id=1, 159 | dim=1024, 160 | num_heads=16, 161 | num_layers=24, 162 | post_norm=True, 163 | dropout=0.1, 164 | eps=1e-5) 165 | cfg.update(**kwargs) 166 | 167 | # init a model on device 168 | with torch.device(device): 169 | model = XLMRoberta(**cfg) 170 | return model 171 | -------------------------------------------------------------------------------- /phantom_wan/subject2video.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 Bytedance Ltd. and/or its affiliates. All rights reserved. 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import gc 16 | import logging 17 | import math 18 | import os 19 | import random 20 | import sys 21 | import types 22 | from contextlib import contextmanager 23 | from functools import partial 24 | 25 | import torch 26 | import torch.cuda.amp as amp 27 | import torch.distributed as dist 28 | from tqdm import tqdm 29 | import torchvision.transforms.functional as TF 30 | 31 | from .distributed.fsdp import shard_model 32 | from .modules.model import WanModel 33 | from .modules.t5 import T5EncoderModel 34 | from .modules.vae import WanVAE 35 | from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, 36 | get_sampling_sigmas, retrieve_timesteps) 37 | from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler 38 | 39 | 40 | class Phantom_Wan_S2V: 41 | 42 | def __init__( 43 | self, 44 | config, 45 | checkpoint_dir, 46 | phantom_ckpt, 47 | device_id=0, 48 | rank=0, 49 | t5_fsdp=False, 50 | dit_fsdp=False, 51 | use_usp=False, 52 | t5_cpu=False, 53 | ): 54 | r""" 55 | Initializes the Wan text-to-video generation model components. 56 | 57 | Args: 58 | config (EasyDict): 59 | Object containing model parameters initialized from config.py 60 | checkpoint_dir (`str`): 61 | Path to directory containing model checkpoints 62 | phantom_ckpt (`str`): 63 | Path of Phantom-Wan dit checkpoint 64 | device_id (`int`, *optional*, defaults to 0): 65 | Id of target GPU device 66 | rank (`int`, *optional*, defaults to 0): 67 | Process rank for distributed training 68 | t5_fsdp (`bool`, *optional*, defaults to False): 69 | Enable FSDP sharding for T5 model 70 | dit_fsdp (`bool`, *optional*, defaults to False): 71 | Enable FSDP sharding for DiT model 72 | use_usp (`bool`, *optional*, defaults to False): 73 | Enable distribution strategy of USP. 74 | t5_cpu (`bool`, *optional*, defaults to False): 75 | Whether to place T5 model on CPU. Only works without t5_fsdp. 
76 | """ 77 | self.device = torch.device(f"cuda:{device_id}") 78 | self.config = config 79 | self.rank = rank 80 | self.t5_cpu = t5_cpu 81 | 82 | self.num_train_timesteps = config.num_train_timesteps 83 | self.param_dtype = config.param_dtype 84 | 85 | shard_fn = partial(shard_model, device_id=device_id) 86 | self.text_encoder = T5EncoderModel( 87 | text_len=config.text_len, 88 | dtype=config.t5_dtype, 89 | device=torch.device('cpu'), 90 | checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), 91 | tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), 92 | shard_fn=shard_fn if t5_fsdp else None) 93 | 94 | self.vae_stride = config.vae_stride 95 | self.patch_size = config.patch_size 96 | self.vae = WanVAE( 97 | vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), 98 | device=self.device) 99 | 100 | logging.info(f"Creating WanModel from {phantom_ckpt}") 101 | self.model = WanModel(dim=config.dim, 102 | ffn_dim=config.ffn_dim, 103 | freq_dim=config.freq_dim, 104 | num_heads=config.num_heads, 105 | num_layers=config.num_layers, 106 | window_size=config.window_size, 107 | qk_norm=config.qk_norm, 108 | cross_attn_norm=config.cross_attn_norm, 109 | eps=config.eps) 110 | logging.info(f"loading ckpt.") 111 | state = torch.load(phantom_ckpt, map_location=self.device) 112 | logging.info(f"loading state dict.") 113 | self.model.load_state_dict(state, strict=False) 114 | # self.model = WanModel.from_pretrained(checkpoint_dir) 115 | self.model.eval().requires_grad_(False) 116 | 117 | if use_usp: 118 | from xfuser.core.distributed import \ 119 | get_sequence_parallel_world_size 120 | 121 | from .distributed.xdit_context_parallel import (usp_attn_forward, 122 | usp_dit_forward) 123 | for block in self.model.blocks: 124 | block.self_attn.forward = types.MethodType( 125 | usp_attn_forward, block.self_attn) 126 | self.model.forward = types.MethodType(usp_dit_forward, self.model) 127 | self.sp_size = get_sequence_parallel_world_size() 128 | else: 129 | self.sp_size = 1 130 | 131 | if dist.is_initialized(): 132 | dist.barrier() 133 | if dit_fsdp: 134 | self.model = shard_fn(self.model) 135 | else: 136 | self.model.to(self.device) 137 | 138 | self.sample_neg_prompt = config.sample_neg_prompt 139 | 140 | 141 | def get_vae_latents(self, ref_images, device): 142 | ref_vae_latents = [] 143 | for ref_image in ref_images: 144 | ref_image = TF.to_tensor(ref_image).sub_(0.5).div_(0.5).to(self.device) 145 | img_vae_latent = self.vae.encode([ref_image.unsqueeze(1)]) 146 | ref_vae_latents.append(img_vae_latent[0]) 147 | 148 | return torch.cat(ref_vae_latents, dim=1) 149 | 150 | 151 | def generate(self, 152 | input_prompt, 153 | ref_images, 154 | size=(1280, 720), 155 | frame_num=81, 156 | shift=5.0, 157 | sample_solver='unipc', 158 | sampling_steps=50, 159 | guide_scale_img=5.0, 160 | guide_scale_text=7.5, 161 | n_prompt="", 162 | seed=-1, 163 | offload_model=True): 164 | r""" 165 | Generates video frames from text prompt using diffusion process. 166 | 167 | Args: 168 | input_prompt (`str`): 169 | Text prompt for content generation 170 | ref_images ([`PIL.Image`])`: 171 | Reference images for subject generation. 172 | size (tupele[`int`], *optional*, defaults to (1280,720)): 173 | Controls video resolution, (width,height). 174 | frame_num (`int`, *optional*, defaults to 81): 175 | How many frames to sample from a video. The number should be 4n+1 176 | shift (`float`, *optional*, defaults to 5.0): 177 | Noise schedule shift parameter. 
Affects temporal dynamics 178 | sample_solver (`str`, *optional*, defaults to 'unipc'): 179 | Solver used to sample the video. 180 | sampling_steps (`int`, *optional*, defaults to 40): 181 | Number of diffusion sampling steps. Higher values improve quality but slow generation 182 | guide_scale (`float`, *optional*, defaults 5.0): 183 | Classifier-free guidance scale. Controls prompt adherence vs. creativity 184 | n_prompt (`str`, *optional*, defaults to ""): 185 | Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` 186 | seed (`int`, *optional*, defaults to -1): 187 | Random seed for noise generation. If -1, use random seed. 188 | offload_model (`bool`, *optional*, defaults to True): 189 | If True, offloads models to CPU during generation to save VRAM 190 | 191 | Returns: 192 | torch.Tensor: 193 | Generated video frames tensor. Dimensions: (C, N H, W) where: 194 | - C: Color channels (3 for RGB) 195 | - N: Number of frames (81) 196 | - H: Frame height (from size) 197 | - W: Frame width from size) 198 | """ 199 | # preprocess 200 | 201 | ref_latents = [self.get_vae_latents(ref_images, self.device)] 202 | ref_latents_neg = [torch.zeros_like(ref_latents[0])] 203 | 204 | F = frame_num 205 | target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + ref_latents[0].shape[1], 206 | size[1] // self.vae_stride[1], 207 | size[0] // self.vae_stride[2]) 208 | 209 | seq_len = math.ceil((target_shape[2] * target_shape[3]) / 210 | (self.patch_size[1] * self.patch_size[2]) * 211 | target_shape[1] / self.sp_size) * self.sp_size 212 | 213 | if n_prompt == "": 214 | n_prompt = self.sample_neg_prompt 215 | seed = seed if seed >= 0 else random.randint(0, sys.maxsize) 216 | seed_g = torch.Generator(device=self.device) 217 | seed_g.manual_seed(seed) 218 | 219 | if not self.t5_cpu: 220 | self.text_encoder.model.to(self.device) 221 | context = self.text_encoder([input_prompt], self.device) 222 | context_null = self.text_encoder([n_prompt], self.device) 223 | if offload_model: 224 | self.text_encoder.model.cpu() 225 | else: 226 | context = self.text_encoder([input_prompt], torch.device('cpu')) 227 | context_null = self.text_encoder([n_prompt], torch.device('cpu')) 228 | context = [t.to(self.device) for t in context] 229 | context_null = [t.to(self.device) for t in context_null] 230 | 231 | noise = [ 232 | torch.randn( 233 | target_shape[0], 234 | target_shape[1], 235 | target_shape[2], 236 | target_shape[3], 237 | dtype=torch.float32, 238 | device=self.device, 239 | generator=seed_g) 240 | ] 241 | 242 | @contextmanager 243 | def noop_no_sync(): 244 | yield 245 | 246 | no_sync = getattr(self.model, 'no_sync', noop_no_sync) 247 | 248 | # evaluation mode 249 | with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): 250 | 251 | if sample_solver == 'unipc': 252 | sample_scheduler = FlowUniPCMultistepScheduler( 253 | num_train_timesteps=self.num_train_timesteps, 254 | shift=1, 255 | use_dynamic_shifting=False) 256 | sample_scheduler.set_timesteps( 257 | sampling_steps, device=self.device, shift=shift) 258 | timesteps = sample_scheduler.timesteps 259 | else: 260 | raise NotImplementedError("Unsupported solver.") 261 | 262 | # sample videos 263 | latents = noise 264 | 265 | arg_c = {'context': context, 'seq_len': seq_len} 266 | arg_null = {'context': context_null, 'seq_len': seq_len} 267 | 268 | for _, t in enumerate(tqdm(timesteps)): 269 | 270 | timestep = [t] 271 | timestep = torch.stack(timestep) 272 | 273 | self.model.to(self.device) 274 | pos_it = 
self.model( 275 | [torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latents, ref_latents)], t=timestep, **arg_c 276 | )[0] 277 | pos_i = self.model( 278 | [torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latents, ref_latents)], t=timestep, **arg_null 279 | )[0] 280 | neg = self.model( 281 | [torch.cat([latent[:,:-ref_latent_neg.shape[1]], ref_latent_neg], dim=1) for latent, ref_latent_neg in zip(latents, ref_latents_neg)], t=timestep, **arg_null 282 | )[0] 283 | 284 | noise_pred = neg + guide_scale_img * (pos_i - neg) + guide_scale_text * (pos_it - pos_i) 285 | 286 | temp_x0 = sample_scheduler.step( 287 | noise_pred.unsqueeze(0), 288 | t, 289 | latents[0].unsqueeze(0), 290 | return_dict=False, 291 | generator=seed_g)[0] 292 | latents = [temp_x0.squeeze(0)] 293 | 294 | x0 = latents 295 | x0 = [x0_[:,:-ref_latents[0].shape[1]] for x0_ in x0] 296 | 297 | if offload_model: 298 | self.model.cpu() 299 | if self.rank == 0: 300 | videos = self.vae.decode(x0) 301 | 302 | del noise, latents 303 | del sample_scheduler 304 | if offload_model: 305 | gc.collect() 306 | torch.cuda.synchronize() 307 | if dist.is_initialized(): 308 | dist.barrier() 309 | 310 | return videos[0] if self.rank == 0 else None 311 | -------------------------------------------------------------------------------- /phantom_wan/text2video.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import gc 3 | import logging 4 | import math 5 | import os 6 | import random 7 | import sys 8 | import types 9 | from contextlib import contextmanager 10 | from functools import partial 11 | 12 | import torch 13 | import torch.cuda.amp as amp 14 | import torch.distributed as dist 15 | from tqdm import tqdm 16 | 17 | from .distributed.fsdp import shard_model 18 | from .modules.model import WanModel 19 | from .modules.t5 import T5EncoderModel 20 | from .modules.vae import WanVAE 21 | from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, 22 | get_sampling_sigmas, retrieve_timesteps) 23 | from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler 24 | 25 | 26 | class WanT2V: 27 | 28 | def __init__( 29 | self, 30 | config, 31 | checkpoint_dir, 32 | device_id=0, 33 | rank=0, 34 | t5_fsdp=False, 35 | dit_fsdp=False, 36 | use_usp=False, 37 | t5_cpu=False, 38 | ): 39 | r""" 40 | Initializes the Wan text-to-video generation model components. 41 | 42 | Args: 43 | config (EasyDict): 44 | Object containing model parameters initialized from config.py 45 | checkpoint_dir (`str`): 46 | Path to directory containing model checkpoints 47 | device_id (`int`, *optional*, defaults to 0): 48 | Id of target GPU device 49 | rank (`int`, *optional*, defaults to 0): 50 | Process rank for distributed training 51 | t5_fsdp (`bool`, *optional*, defaults to False): 52 | Enable FSDP sharding for T5 model 53 | dit_fsdp (`bool`, *optional*, defaults to False): 54 | Enable FSDP sharding for DiT model 55 | use_usp (`bool`, *optional*, defaults to False): 56 | Enable distribution strategy of USP. 57 | t5_cpu (`bool`, *optional*, defaults to False): 58 | Whether to place T5 model on CPU. Only works without t5_fsdp. 
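        Example (editorial sketch; the checkpoint directory is an assumption and
        `cfg` is a loaded T2V config object):
            >>> t2v = WanT2V(cfg, checkpoint_dir='./Wan2.1-T2V-14B')
            >>> video = t2v.generate("a corgi surfing a wave at sunset",
            ...                      size=(1280, 720), sampling_steps=50)
            >>> video.shape  # expected (C, N, H, W), e.g. torch.Size([3, 81, 720, 1280])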
59 | """ 60 | self.device = torch.device(f"cuda:{device_id}") 61 | self.config = config 62 | self.rank = rank 63 | self.t5_cpu = t5_cpu 64 | 65 | self.num_train_timesteps = config.num_train_timesteps 66 | self.param_dtype = config.param_dtype 67 | 68 | shard_fn = partial(shard_model, device_id=device_id) 69 | self.text_encoder = T5EncoderModel( 70 | text_len=config.text_len, 71 | dtype=config.t5_dtype, 72 | device=torch.device('cpu'), 73 | checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), 74 | tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), 75 | shard_fn=shard_fn if t5_fsdp else None) 76 | 77 | self.vae_stride = config.vae_stride 78 | self.patch_size = config.patch_size 79 | self.vae = WanVAE( 80 | vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), 81 | device=self.device) 82 | 83 | logging.info(f"Creating WanModel from {checkpoint_dir}") 84 | # self.model = WanModel() 85 | # logging.info(f"loading ckpt.") 86 | # state = torch.load("/mnt/bn/matianxiang-data-hl2/vfm-prod/downloads/Wan2.1-T2V-14B-dit.pth", map_location=self.device) 87 | # logging.info(f"loading state dict.") 88 | # self.model.load_state_dict(state, strict=False) 89 | self.model = WanModel.from_pretrained(checkpoint_dir) 90 | self.model.eval().requires_grad_(False) 91 | 92 | if use_usp: 93 | from xfuser.core.distributed import \ 94 | get_sequence_parallel_world_size 95 | 96 | from .distributed.xdit_context_parallel import (usp_attn_forward, 97 | usp_dit_forward) 98 | for block in self.model.blocks: 99 | block.self_attn.forward = types.MethodType( 100 | usp_attn_forward, block.self_attn) 101 | self.model.forward = types.MethodType(usp_dit_forward, self.model) 102 | self.sp_size = get_sequence_parallel_world_size() 103 | else: 104 | self.sp_size = 1 105 | 106 | if dist.is_initialized(): 107 | dist.barrier() 108 | if dit_fsdp: 109 | self.model = shard_fn(self.model) 110 | else: 111 | self.model.to(self.device) 112 | 113 | self.sample_neg_prompt = config.sample_neg_prompt 114 | 115 | def generate(self, 116 | input_prompt, 117 | size=(1280, 720), 118 | frame_num=81, 119 | shift=5.0, 120 | sample_solver='unipc', 121 | sampling_steps=50, 122 | guide_scale=5.0, 123 | n_prompt="", 124 | seed=-1, 125 | offload_model=True): 126 | r""" 127 | Generates video frames from text prompt using diffusion process. 128 | 129 | Args: 130 | input_prompt (`str`): 131 | Text prompt for content generation 132 | size (tupele[`int`], *optional*, defaults to (1280,720)): 133 | Controls video resolution, (width,height). 134 | frame_num (`int`, *optional*, defaults to 81): 135 | How many frames to sample from a video. The number should be 4n+1 136 | shift (`float`, *optional*, defaults to 5.0): 137 | Noise schedule shift parameter. Affects temporal dynamics 138 | sample_solver (`str`, *optional*, defaults to 'unipc'): 139 | Solver used to sample the video. 140 | sampling_steps (`int`, *optional*, defaults to 40): 141 | Number of diffusion sampling steps. Higher values improve quality but slow generation 142 | guide_scale (`float`, *optional*, defaults 5.0): 143 | Classifier-free guidance scale. Controls prompt adherence vs. creativity 144 | n_prompt (`str`, *optional*, defaults to ""): 145 | Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` 146 | seed (`int`, *optional*, defaults to -1): 147 | Random seed for noise generation. If -1, use random seed. 
148 | offload_model (`bool`, *optional*, defaults to True): 149 | If True, offloads models to CPU during generation to save VRAM 150 | 151 | Returns: 152 | torch.Tensor: 153 | Generated video frames tensor. Dimensions: (C, N H, W) where: 154 | - C: Color channels (3 for RGB) 155 | - N: Number of frames (81) 156 | - H: Frame height (from size) 157 | - W: Frame width from size) 158 | """ 159 | # preprocess 160 | F = frame_num 161 | target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, 162 | size[1] // self.vae_stride[1], 163 | size[0] // self.vae_stride[2]) 164 | 165 | seq_len = math.ceil((target_shape[2] * target_shape[3]) / 166 | (self.patch_size[1] * self.patch_size[2]) * 167 | target_shape[1] / self.sp_size) * self.sp_size 168 | 169 | if n_prompt == "": 170 | n_prompt = self.sample_neg_prompt 171 | seed = seed if seed >= 0 else random.randint(0, sys.maxsize) 172 | seed_g = torch.Generator(device=self.device) 173 | seed_g.manual_seed(seed) 174 | 175 | if not self.t5_cpu: 176 | self.text_encoder.model.to(self.device) 177 | context = self.text_encoder([input_prompt], self.device) 178 | context_null = self.text_encoder([n_prompt], self.device) 179 | if offload_model: 180 | self.text_encoder.model.cpu() 181 | else: 182 | context = self.text_encoder([input_prompt], torch.device('cpu')) 183 | context_null = self.text_encoder([n_prompt], torch.device('cpu')) 184 | context = [t.to(self.device) for t in context] 185 | context_null = [t.to(self.device) for t in context_null] 186 | 187 | noise = [ 188 | torch.randn( 189 | target_shape[0], 190 | target_shape[1], 191 | target_shape[2], 192 | target_shape[3], 193 | dtype=torch.float32, 194 | device=self.device, 195 | generator=seed_g) 196 | ] 197 | 198 | @contextmanager 199 | def noop_no_sync(): 200 | yield 201 | 202 | no_sync = getattr(self.model, 'no_sync', noop_no_sync) 203 | 204 | # evaluation mode 205 | with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): 206 | 207 | if sample_solver == 'unipc': 208 | sample_scheduler = FlowUniPCMultistepScheduler( 209 | num_train_timesteps=self.num_train_timesteps, 210 | shift=1, 211 | use_dynamic_shifting=False) 212 | sample_scheduler.set_timesteps( 213 | sampling_steps, device=self.device, shift=shift) 214 | timesteps = sample_scheduler.timesteps 215 | elif sample_solver == 'dpm++': 216 | sample_scheduler = FlowDPMSolverMultistepScheduler( 217 | num_train_timesteps=self.num_train_timesteps, 218 | shift=1, 219 | use_dynamic_shifting=False) 220 | sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) 221 | timesteps, _ = retrieve_timesteps( 222 | sample_scheduler, 223 | device=self.device, 224 | sigmas=sampling_sigmas) 225 | else: 226 | raise NotImplementedError("Unsupported solver.") 227 | 228 | # sample videos 229 | latents = noise 230 | 231 | arg_c = {'context': context, 'seq_len': seq_len} 232 | arg_null = {'context': context_null, 'seq_len': seq_len} 233 | 234 | for _, t in enumerate(tqdm(timesteps)): 235 | latent_model_input = latents 236 | timestep = [t] 237 | 238 | timestep = torch.stack(timestep) 239 | 240 | self.model.to(self.device) 241 | noise_pred_cond = self.model( 242 | latent_model_input, t=timestep, **arg_c)[0] 243 | noise_pred_uncond = self.model( 244 | latent_model_input, t=timestep, **arg_null)[0] 245 | 246 | noise_pred = noise_pred_uncond + guide_scale * ( 247 | noise_pred_cond - noise_pred_uncond) 248 | 249 | temp_x0 = sample_scheduler.step( 250 | noise_pred.unsqueeze(0), 251 | t, 252 | latents[0].unsqueeze(0), 253 | return_dict=False, 254 | 
generator=seed_g)[0] 255 | latents = [temp_x0.squeeze(0)] 256 | 257 | x0 = latents 258 | if offload_model: 259 | self.model.cpu() 260 | if self.rank == 0: 261 | videos = self.vae.decode(x0) 262 | 263 | del noise, latents 264 | del sample_scheduler 265 | if offload_model: 266 | gc.collect() 267 | torch.cuda.synchronize() 268 | if dist.is_initialized(): 269 | dist.barrier() 270 | 271 | return videos[0] if self.rank == 0 else None 272 | -------------------------------------------------------------------------------- /phantom_wan/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, 2 | retrieve_timesteps) 3 | from .fm_solvers_unipc import FlowUniPCMultistepScheduler 4 | 5 | __all__ = [ 6 | 'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps', 7 | 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler' 8 | ] 9 | -------------------------------------------------------------------------------- /phantom_wan/utils/qwen_vl_utils.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/kq-chen/qwen-vl-utils 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 3 | from __future__ import annotations 4 | 5 | import base64 6 | import logging 7 | import math 8 | import os 9 | import sys 10 | import time 11 | import warnings 12 | from functools import lru_cache 13 | from io import BytesIO 14 | 15 | import requests 16 | import torch 17 | import torchvision 18 | from packaging import version 19 | from PIL import Image 20 | from torchvision import io, transforms 21 | from torchvision.transforms import InterpolationMode 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | IMAGE_FACTOR = 28 26 | MIN_PIXELS = 4 * 28 * 28 27 | MAX_PIXELS = 16384 * 28 * 28 28 | MAX_RATIO = 200 29 | 30 | VIDEO_MIN_PIXELS = 128 * 28 * 28 31 | VIDEO_MAX_PIXELS = 768 * 28 * 28 32 | VIDEO_TOTAL_PIXELS = 24576 * 28 * 28 33 | FRAME_FACTOR = 2 34 | FPS = 2.0 35 | FPS_MIN_FRAMES = 4 36 | FPS_MAX_FRAMES = 768 37 | 38 | 39 | def round_by_factor(number: int, factor: int) -> int: 40 | """Returns the closest integer to 'number' that is divisible by 'factor'.""" 41 | return round(number / factor) * factor 42 | 43 | 44 | def ceil_by_factor(number: int, factor: int) -> int: 45 | """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" 46 | return math.ceil(number / factor) * factor 47 | 48 | 49 | def floor_by_factor(number: int, factor: int) -> int: 50 | """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" 51 | return math.floor(number / factor) * factor 52 | 53 | 54 | def smart_resize(height: int, 55 | width: int, 56 | factor: int = IMAGE_FACTOR, 57 | min_pixels: int = MIN_PIXELS, 58 | max_pixels: int = MAX_PIXELS) -> tuple[int, int]: 59 | """ 60 | Rescales the image so that the following conditions are met: 61 | 62 | 1. Both dimensions (height and width) are divisible by 'factor'. 63 | 64 | 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. 65 | 66 | 3. The aspect ratio of the image is maintained as closely as possible. 
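    Editorial example: with the default factor and pixel bounds,
    smart_resize(1080, 1920) returns (1092, 1932): each side is rounded to the
    nearest multiple of 28, and the resulting pixel count already lies within
    [MIN_PIXELS, MAX_PIXELS], so no further rescaling is applied.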
67 | """ 68 | if max(height, width) / min(height, width) > MAX_RATIO: 69 | raise ValueError( 70 | f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" 71 | ) 72 | h_bar = max(factor, round_by_factor(height, factor)) 73 | w_bar = max(factor, round_by_factor(width, factor)) 74 | if h_bar * w_bar > max_pixels: 75 | beta = math.sqrt((height * width) / max_pixels) 76 | h_bar = floor_by_factor(height / beta, factor) 77 | w_bar = floor_by_factor(width / beta, factor) 78 | elif h_bar * w_bar < min_pixels: 79 | beta = math.sqrt(min_pixels / (height * width)) 80 | h_bar = ceil_by_factor(height * beta, factor) 81 | w_bar = ceil_by_factor(width * beta, factor) 82 | return h_bar, w_bar 83 | 84 | 85 | def fetch_image(ele: dict[str, str | Image.Image], 86 | size_factor: int = IMAGE_FACTOR) -> Image.Image: 87 | if "image" in ele: 88 | image = ele["image"] 89 | else: 90 | image = ele["image_url"] 91 | image_obj = None 92 | if isinstance(image, Image.Image): 93 | image_obj = image 94 | elif image.startswith("http://") or image.startswith("https://"): 95 | image_obj = Image.open(requests.get(image, stream=True).raw) 96 | elif image.startswith("file://"): 97 | image_obj = Image.open(image[7:]) 98 | elif image.startswith("data:image"): 99 | if "base64," in image: 100 | _, base64_data = image.split("base64,", 1) 101 | data = base64.b64decode(base64_data) 102 | image_obj = Image.open(BytesIO(data)) 103 | else: 104 | image_obj = Image.open(image) 105 | if image_obj is None: 106 | raise ValueError( 107 | f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}" 108 | ) 109 | image = image_obj.convert("RGB") 110 | ## resize 111 | if "resized_height" in ele and "resized_width" in ele: 112 | resized_height, resized_width = smart_resize( 113 | ele["resized_height"], 114 | ele["resized_width"], 115 | factor=size_factor, 116 | ) 117 | else: 118 | width, height = image.size 119 | min_pixels = ele.get("min_pixels", MIN_PIXELS) 120 | max_pixels = ele.get("max_pixels", MAX_PIXELS) 121 | resized_height, resized_width = smart_resize( 122 | height, 123 | width, 124 | factor=size_factor, 125 | min_pixels=min_pixels, 126 | max_pixels=max_pixels, 127 | ) 128 | image = image.resize((resized_width, resized_height)) 129 | 130 | return image 131 | 132 | 133 | def smart_nframes( 134 | ele: dict, 135 | total_frames: int, 136 | video_fps: int | float, 137 | ) -> int: 138 | """calculate the number of frames for video used for model inputs. 139 | 140 | Args: 141 | ele (dict): a dict contains the configuration of video. 142 | support either `fps` or `nframes`: 143 | - nframes: the number of frames to extract for model inputs. 144 | - fps: the fps to extract frames for model inputs. 145 | - min_frames: the minimum number of frames of the video, only used when fps is provided. 146 | - max_frames: the maximum number of frames of the video, only used when fps is provided. 147 | total_frames (int): the original total number of frames of the video. 148 | video_fps (int | float): the original fps of the video. 149 | 150 | Raises: 151 | ValueError: nframes should in interval [FRAME_FACTOR, total_frames]. 152 | 153 | Returns: 154 | int: the number of frames for video used for model inputs. 
155 | """ 156 | assert not ("fps" in ele and 157 | "nframes" in ele), "Only accept either `fps` or `nframes`" 158 | if "nframes" in ele: 159 | nframes = round_by_factor(ele["nframes"], FRAME_FACTOR) 160 | else: 161 | fps = ele.get("fps", FPS) 162 | min_frames = ceil_by_factor( 163 | ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR) 164 | max_frames = floor_by_factor( 165 | ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), 166 | FRAME_FACTOR) 167 | nframes = total_frames / video_fps * fps 168 | nframes = min(max(nframes, min_frames), max_frames) 169 | nframes = round_by_factor(nframes, FRAME_FACTOR) 170 | if not (FRAME_FACTOR <= nframes and nframes <= total_frames): 171 | raise ValueError( 172 | f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}." 173 | ) 174 | return nframes 175 | 176 | 177 | def _read_video_torchvision(ele: dict,) -> torch.Tensor: 178 | """read video using torchvision.io.read_video 179 | 180 | Args: 181 | ele (dict): a dict contains the configuration of video. 182 | support keys: 183 | - video: the path of video. support "file://", "http://", "https://" and local path. 184 | - video_start: the start time of video. 185 | - video_end: the end time of video. 186 | Returns: 187 | torch.Tensor: the video tensor with shape (T, C, H, W). 188 | """ 189 | video_path = ele["video"] 190 | if version.parse(torchvision.__version__) < version.parse("0.19.0"): 191 | if "http://" in video_path or "https://" in video_path: 192 | warnings.warn( 193 | "torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0." 194 | ) 195 | if "file://" in video_path: 196 | video_path = video_path[7:] 197 | st = time.time() 198 | video, audio, info = io.read_video( 199 | video_path, 200 | start_pts=ele.get("video_start", 0.0), 201 | end_pts=ele.get("video_end", None), 202 | pts_unit="sec", 203 | output_format="TCHW", 204 | ) 205 | total_frames, video_fps = video.size(0), info["video_fps"] 206 | logger.info( 207 | f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s" 208 | ) 209 | nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) 210 | idx = torch.linspace(0, total_frames - 1, nframes).round().long() 211 | video = video[idx] 212 | return video 213 | 214 | 215 | def is_decord_available() -> bool: 216 | import importlib.util 217 | 218 | return importlib.util.find_spec("decord") is not None 219 | 220 | 221 | def _read_video_decord(ele: dict,) -> torch.Tensor: 222 | """read video using decord.VideoReader 223 | 224 | Args: 225 | ele (dict): a dict contains the configuration of video. 226 | support keys: 227 | - video: the path of video. support "file://", "http://", "https://" and local path. 228 | - video_start: the start time of video. 229 | - video_end: the end time of video. 230 | Returns: 231 | torch.Tensor: the video tensor with shape (T, C, H, W). 
232 | """ 233 | import decord 234 | video_path = ele["video"] 235 | st = time.time() 236 | vr = decord.VideoReader(video_path) 237 | # TODO: support start_pts and end_pts 238 | if 'video_start' in ele or 'video_end' in ele: 239 | raise NotImplementedError( 240 | "not support start_pts and end_pts in decord for now.") 241 | total_frames, video_fps = len(vr), vr.get_avg_fps() 242 | logger.info( 243 | f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s" 244 | ) 245 | nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) 246 | idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() 247 | video = vr.get_batch(idx).asnumpy() 248 | video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format 249 | return video 250 | 251 | 252 | VIDEO_READER_BACKENDS = { 253 | "decord": _read_video_decord, 254 | "torchvision": _read_video_torchvision, 255 | } 256 | 257 | FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None) 258 | 259 | 260 | @lru_cache(maxsize=1) 261 | def get_video_reader_backend() -> str: 262 | if FORCE_QWENVL_VIDEO_READER is not None: 263 | video_reader_backend = FORCE_QWENVL_VIDEO_READER 264 | elif is_decord_available(): 265 | video_reader_backend = "decord" 266 | else: 267 | video_reader_backend = "torchvision" 268 | print( 269 | f"qwen-vl-utils using {video_reader_backend} to read video.", 270 | file=sys.stderr) 271 | return video_reader_backend 272 | 273 | 274 | def fetch_video( 275 | ele: dict, 276 | image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]: 277 | if isinstance(ele["video"], str): 278 | video_reader_backend = get_video_reader_backend() 279 | video = VIDEO_READER_BACKENDS[video_reader_backend](ele) 280 | nframes, _, height, width = video.shape 281 | 282 | min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS) 283 | total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS) 284 | max_pixels = max( 285 | min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), 286 | int(min_pixels * 1.05)) 287 | max_pixels = ele.get("max_pixels", max_pixels) 288 | if "resized_height" in ele and "resized_width" in ele: 289 | resized_height, resized_width = smart_resize( 290 | ele["resized_height"], 291 | ele["resized_width"], 292 | factor=image_factor, 293 | ) 294 | else: 295 | resized_height, resized_width = smart_resize( 296 | height, 297 | width, 298 | factor=image_factor, 299 | min_pixels=min_pixels, 300 | max_pixels=max_pixels, 301 | ) 302 | video = transforms.functional.resize( 303 | video, 304 | [resized_height, resized_width], 305 | interpolation=InterpolationMode.BICUBIC, 306 | antialias=True, 307 | ).float() 308 | return video 309 | else: 310 | assert isinstance(ele["video"], (list, tuple)) 311 | process_info = ele.copy() 312 | process_info.pop("type", None) 313 | process_info.pop("video", None) 314 | images = [ 315 | fetch_image({ 316 | "image": video_element, 317 | **process_info 318 | }, 319 | size_factor=image_factor) 320 | for video_element in ele["video"] 321 | ] 322 | nframes = ceil_by_factor(len(images), FRAME_FACTOR) 323 | if len(images) < nframes: 324 | images.extend([images[-1]] * (nframes - len(images))) 325 | return images 326 | 327 | 328 | def extract_vision_info( 329 | conversations: list[dict] | list[list[dict]]) -> list[dict]: 330 | vision_infos = [] 331 | if isinstance(conversations[0], dict): 332 | conversations = [conversations] 333 | for conversation in conversations: 334 | for message in conversation: 335 | if isinstance(message["content"], 
list): 336 | for ele in message["content"]: 337 | if ("image" in ele or "image_url" in ele or 338 | "video" in ele or 339 | ele["type"] in ("image", "image_url", "video")): 340 | vision_infos.append(ele) 341 | return vision_infos 342 | 343 | 344 | def process_vision_info( 345 | conversations: list[dict] | list[list[dict]], 346 | ) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | 347 | None]: 348 | vision_infos = extract_vision_info(conversations) 349 | ## Read images or videos 350 | image_inputs = [] 351 | video_inputs = [] 352 | for vision_info in vision_infos: 353 | if "image" in vision_info or "image_url" in vision_info: 354 | image_inputs.append(fetch_image(vision_info)) 355 | elif "video" in vision_info: 356 | video_inputs.append(fetch_video(vision_info)) 357 | else: 358 | raise ValueError("image, image_url or video should in content.") 359 | if len(image_inputs) == 0: 360 | image_inputs = None 361 | if len(video_inputs) == 0: 362 | video_inputs = None 363 | return image_inputs, video_inputs 364 | -------------------------------------------------------------------------------- /phantom_wan/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import argparse 3 | import binascii 4 | import os 5 | import os.path as osp 6 | 7 | import imageio 8 | import torch 9 | import torchvision 10 | 11 | __all__ = ['cache_video', 'cache_image', 'str2bool'] 12 | 13 | 14 | def rand_name(length=8, suffix=''): 15 | name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') 16 | if suffix: 17 | if not suffix.startswith('.'): 18 | suffix = '.' + suffix 19 | name += suffix 20 | return name 21 | 22 | 23 | def cache_video(tensor, 24 | save_file=None, 25 | fps=30, 26 | suffix='.mp4', 27 | nrow=8, 28 | normalize=True, 29 | value_range=(-1, 1), 30 | retry=5): 31 | # cache file 32 | cache_file = osp.join('/tmp', rand_name( 33 | suffix=suffix)) if save_file is None else save_file 34 | 35 | # save to cache 36 | error = None 37 | for _ in range(retry): 38 | try: 39 | # preprocess 40 | tensor = tensor.clamp(min(value_range), max(value_range)) 41 | tensor = torch.stack([ 42 | torchvision.utils.make_grid( 43 | u, nrow=nrow, normalize=normalize, value_range=value_range) 44 | for u in tensor.unbind(2) 45 | ], 46 | dim=1).permute(1, 2, 3, 0) 47 | tensor = (tensor * 255).type(torch.uint8).cpu() 48 | 49 | # write video 50 | writer = imageio.get_writer( 51 | cache_file, fps=fps, codec='libx264', quality=8) 52 | for frame in tensor.numpy(): 53 | writer.append_data(frame) 54 | writer.close() 55 | return cache_file 56 | except Exception as e: 57 | error = e 58 | continue 59 | else: 60 | print(f'cache_video failed, error: {error}', flush=True) 61 | return None 62 | 63 | 64 | def cache_image(tensor, 65 | save_file, 66 | nrow=8, 67 | normalize=True, 68 | value_range=(-1, 1), 69 | retry=5): 70 | # cache file 71 | suffix = osp.splitext(save_file)[1] 72 | if suffix.lower() not in [ 73 | '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp' 74 | ]: 75 | suffix = '.png' 76 | 77 | # save to cache 78 | error = None 79 | for _ in range(retry): 80 | try: 81 | tensor = tensor.clamp(min(value_range), max(value_range)) 82 | torchvision.utils.save_image( 83 | tensor, 84 | save_file, 85 | nrow=nrow, 86 | normalize=normalize, 87 | value_range=value_range) 88 | return save_file 89 | except Exception as e: 90 | error = e 91 | continue 92 | 93 | 94 | def str2bool(v): 95 | """ 96 | Convert a string to a 
boolean. 97 | 98 | Supported true values: 'yes', 'true', 't', 'y', '1' 99 | Supported false values: 'no', 'false', 'f', 'n', '0' 100 | 101 | Args: 102 | v (str): String to convert. 103 | 104 | Returns: 105 | bool: Converted boolean value. 106 | 107 | Raises: 108 | argparse.ArgumentTypeError: If the value cannot be converted to boolean. 109 | """ 110 | if isinstance(v, bool): 111 | return v 112 | v_lower = v.lower() 113 | if v_lower in ('yes', 'true', 't', 'y', '1'): 114 | return True 115 | elif v_lower in ('no', 'false', 'f', 'n', '0'): 116 | return False 117 | else: 118 | raise argparse.ArgumentTypeError('Boolean value expected (True/False)') 119 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.4.0 2 | torchvision>=0.19.0 3 | opencv-python>=4.9.0.80 4 | diffusers>=0.31.0 5 | transformers>=4.49.0 6 | tokenizers>=0.20.3 7 | accelerate>=1.1.1 8 | tqdm 9 | imageio 10 | easydict 11 | ftfy 12 | dashscope 13 | imageio-ffmpeg 14 | flash_attn 15 | gradio>=5.0.0 16 | numpy>=1.23.5,<2 17 | xfuser>=0.4.1 --------------------------------------------------------------------------------
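
Note on the denoising loop shown above: each step combines the conditional and unconditional model outputs with standard classifier-free guidance before calling the scheduler. A minimal, self-contained sketch of just that update, where random tensors stand in for the model outputs and the latent shape is purely illustrative (not taken from the repository):

import torch

guide_scale = 5.0                                        # illustrative guidance strength
noise_pred_cond = torch.randn(16, 21, 60, 104)           # stand-in for model(latent, **arg_c)[0]
noise_pred_uncond = torch.randn_like(noise_pred_cond)    # stand-in for model(latent, **arg_null)[0]

# classifier-free guidance: move the prediction away from the unconditional branch
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
print(noise_pred.shape)                                  # torch.Size([16, 21, 60, 104])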
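
The helpers in qwen_vl_utils.py and utils.py can also be exercised on their own. A small usage sketch, assuming the repository root is on PYTHONPATH and the packages from requirements.txt are installed; the input sizes below are illustrative only:

from phantom_wan.utils.qwen_vl_utils import smart_nframes, smart_resize
from phantom_wan.utils.utils import str2bool

# smart_resize keeps both sides divisible by IMAGE_FACTOR (28) and the pixel
# count inside [MIN_PIXELS, MAX_PIXELS] while roughly preserving aspect ratio.
h, w = smart_resize(1080, 1920)
print(h, w)                              # 1092 1932, both divisible by 28

# smart_nframes derives a frame count from an fps target, clamped to
# [min_frames, max_frames] and rounded to a multiple of FRAME_FACTOR (2).
n = smart_nframes({"fps": 2.0}, total_frames=240, video_fps=24.0)
print(n)                                 # 20 frames for a 10 s clip sampled at 2 fps

# str2bool is intended for argparse flags such as --offload_model.
print(str2bool("yes"), str2bool("0"))    # True False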