├── .gitignore
├── LICENSE.txt
├── README.md
├── assets
│   ├── Alegreya-Regular.ttf
│   ├── Caveat-Medium.ttf
│   ├── DancingScript-Bold.ttf
│   ├── arrow-right.svg
│   ├── background.png
│   ├── bootstrap.min.css
│   ├── bootstrap.min.js
│   ├── icon-new.png
│   ├── icon.png
│   ├── images
│   │   ├── id_eval.png
│   │   ├── ip_eval_m_00.png
│   │   └── ip_eval_s.png
│   ├── index.css
│   ├── index.js
│   ├── roboto.woff2
│   ├── teaser.png
│   └── videos
│       ├── .DS_Store
│       ├── demo
│       │   ├── .DS_Store
│       │   ├── demo1.mp4
│       │   ├── demo2.mp4
│       │   └── demo3.mp4
│       ├── id
│       │   ├── .DS_Store
│       │   ├── 0_ref.mov
│       │   ├── 1_ref.mov
│       │   ├── 2_ref.mov
│       │   ├── 4_ref.mov
│       │   ├── 5_ref.mov
│       │   ├── 6_ref.mov
│       │   ├── 7_ref.mov
│       │   ├── 8_ref.mov
│       │   └── 9_ref.mov
│       ├── multi_ip
│       │   ├── .DS_Store
│       │   ├── 0.mov
│       │   ├── 1.mov
│       │   ├── 10.mov
│       │   ├── 2.mov
│       │   ├── 4.mov
│       │   ├── 5.mov
│       │   ├── 7.mov
│       │   ├── 8.mov
│       │   └── 9.mov
│       └── single_ip
│           ├── .DS_Store
│           ├── 0.mov
│           ├── 1.mov
│           ├── 2.mov
│           ├── 3.mov
│           ├── 4.mov
│           ├── 5.mov
│           ├── 6.mov
│           ├── 7.mov
│           └── 8.mov
├── examples
│   ├── ref1.png
│   ├── ref10.png
│   ├── ref11.png
│   ├── ref2.png
│   ├── ref3.png
│   ├── ref4.png
│   ├── ref5.png
│   ├── ref6.png
│   ├── ref7.png
│   ├── ref8.png
│   ├── ref9.png
│   └── ref_results
│       ├── result1.gif
│       ├── result2.gif
│       ├── result3.gif
│       └── result4.gif
├── generate.py
├── infer.sh
├── phantom_wan
│   ├── __init__.py
│   ├── configs
│   │   ├── __init__.py
│   │   ├── shared_config.py
│   │   ├── wan_i2v_14B.py
│   │   ├── wan_s2v_1_3B.py
│   │   ├── wan_t2v_14B.py
│   │   └── wan_t2v_1_3B.py
│   ├── distributed
│   │   ├── __init__.py
│   │   ├── fsdp.py
│   │   └── xdit_context_parallel.py
│   ├── image2video.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── clip.py
│   │   ├── model.py
│   │   ├── t5.py
│   │   ├── tokenizers.py
│   │   ├── vae.py
│   │   └── xlm_roberta.py
│   ├── subject2video.py
│   ├── text2video.py
│   └── utils
│       ├── __init__.py
│       ├── fm_solvers.py
│       ├── fm_solvers_unipc.py
│       ├── prompt_extend.py
│       ├── qwen_vl_utils.py
│       └── utils.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | .*
2 | *.py[cod]
3 | *.bmp
4 | *.log
5 | *.zip
6 | *.tar
7 | *.pt
8 | *.pth
9 | *.ckpt
10 | *.safetensors
11 | *.json
12 | *.backup
13 | *.pkl
14 | *.html
15 | *.pdf
16 | *.whl
17 | cache
18 | __pycache__/
19 | storage/
20 | samples/
21 | !.gitignore
22 | !requirements.txt
23 | .DS_Store
24 | *DS_Store
25 | assets/.DS_Store
26 | google/
27 | Wan2.1-T2V-14B/
28 | Wan2.1-T2V-1.3B/
29 | Wan2.1-I2V-14B-480P/
30 | Wan2.1-I2V-14B-720P/
31 | Phantom-Wan-1.3B/
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2025 lab-cv
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Phantom: Subject-Consistent Video Generation via Cross-Modal Alignment
2 |
3 |
4 |
5 | [arXiv](https://arxiv.org/abs/2502.11079)
6 | [Project Page](https://phantom-video.github.io/Phantom/)
7 |
8 |
9 |
10 |
11 | > [**Phantom: Subject-Consistent Video Generation via Cross-Modal Alignment**](https://arxiv.org/abs/2502.11079)
12 | > [Lijie Liu](https://liulj13.github.io/) * , [Tianxiang Ma](https://tianxiangma.github.io/) * , [Bingchuan Li](https://scholar.google.com/citations?user=ac5Se6QAAAAJ) * †, [Zhuowei Chen](https://scholar.google.com/citations?user=ow1jGJkAAAAJ) * , [Jiawei Liu](https://scholar.google.com/citations?user=X21Fz-EAAAAJ), Gen Li, Siyu Zhou, [Qian He](https://scholar.google.com/citations?user=9rWWCgUAAAAJ), Xinglong Wu
13 | > \* Equal contribution, † Project lead
14 | > Intelligent Creation Team, ByteDance
15 |
16 |
17 |
18 |
19 |
20 | ## 🔥 Latest News!
21 | * Apr 10, 2025: We have updated the full version of the Phantom paper, with more detailed descriptions of the model architecture and the dataset pipeline.
22 | * Apr 21, 2025: 👋 Phantom-Wan is here! We have adapted the Phantom framework to the [Wan2.1](https://github.com/Wan-Video/Wan2.1) video generation model; the inference code and checkpoint are now released.
23 | * Apr 23, 2025: 😊 Thanks to [ComfyUI-WanVideoWrapper](https://github.com/kijai/ComfyUI-WanVideoWrapper/tree/dev) for adding Phantom-Wan-1.3B support to ComfyUI. Everyone is welcome to use it!
24 |
25 | ## 📑 Todo List
26 | - [x] Inference code and checkpoint of Phantom-Wan-1.3B
27 | - [ ] Checkpoint of Phantom-Wan-14B
28 | - [ ] Training code of Phantom-Wan
29 |
30 | ## 📖 Overview
31 | Phantom is a unified video generation framework for single- and multi-subject references, built on existing text-to-video and image-to-video architectures. It achieves cross-modal alignment by redesigning the joint text-image injection model and training on text-image-video triplet data, and it places particular emphasis on subject consistency in human generation, enhancing ID-preserving video generation.
32 |
33 | ## ⚡️ Quickstart
34 |
35 | ### Installation
36 | Clone the repo:
37 | ```sh
38 | git clone https://github.com/Phantom-video/Phantom.git
39 | cd Phantom
40 | ```
41 |
42 | Install dependencies:
43 | ```sh
44 | # Ensure torch >= 2.4.0
45 | pip install -r requirements.txt
46 | ```
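
To confirm the PyTorch requirement before installing the rest, the following quick check may help (a convenience snippet only, not part of the repository):

```python
# Sanity-check the torch >= 2.4.0 requirement (convenience snippet, not part of the repo).
import torch

major, minor = (int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
assert (major, minor) >= (2, 4), f"torch >= 2.4.0 required, found {torch.__version__}"
print(f"torch {torch.__version__} OK")
```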
47 |
48 | ### Model Download
49 | First, download the original Wan2.1 1.3B model using huggingface-cli:
50 | ``` sh
51 | pip install "huggingface_hub[cli]"
52 | huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir ./Wan2.1-T2V-1.3B
53 | ```
54 | Then download the Phantom-Wan-1.3B model:
55 | ``` sh
56 | huggingface-cli download bytedance-research/Phantom --local-dir ./Phantom-Wan-1.3B
57 | ```
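
If you prefer Python over the CLI, the same two checkpoints can be fetched with the `huggingface_hub` API (the local directory names below are simply the ones the commands in this README expect):

```python
# Optional: download both checkpoints via the huggingface_hub Python API,
# equivalent to the huggingface-cli commands above.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./Wan2.1-T2V-1.3B")
snapshot_download(repo_id="bytedance-research/Phantom", local_dir="./Phantom-Wan-1.3B")
```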
58 |
59 | ### Run Subject-to-Video Generation
60 |
61 | - Single-GPU inference
62 |
63 | ``` sh
64 | python generate.py --task s2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --phantom_ckpt ./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth --ref_image "examples/ref1.png,examples/ref2.png" --prompt "暖阳漫过草地,扎着双马尾、头戴绿色蝴蝶结、身穿浅绿色连衣裙的小女孩蹲在盛开的雏菊旁。她身旁一只棕白相间的狗狗吐着舌头,毛茸茸尾巴欢快摇晃。小女孩笑着举起黄红配色、带有蓝色按钮的玩具相机,将和狗狗的欢乐瞬间定格。" --base_seed 42
65 | ```
66 |
67 | - Multi-GPU inference using FSDP + xDiT USP
68 |
69 | ``` sh
70 | pip install "xfuser>=0.4.1"
71 | torchrun --nproc_per_node=8 generate.py --task s2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --phantom_ckpt ./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth --ref_image "examples/ref3.png,examples/ref4.png" --dit_fsdp --t5_fsdp --ulysses_size 4 --ring_size 2 --prompt "夕阳下,一位有着小麦色肌肤、留着乌黑长发的女人穿上有着大朵立体花朵装饰、肩袖处带有飘逸纱带的红色纱裙,漫步在金色的海滩上,海风轻拂她的长发,画面唯美动人。" --base_seed 42
72 | ```
73 |
74 | > 💡Note:
75 | > * `--ref_image` accepts one or more comma-separated image paths, switching between single-reference and multi-reference Subject-to-Video generation. At most 4 reference images are supported.
76 | > * For the best results, describe the visual content of each reference image as accurately as possible in `--prompt`. For example, "examples/ref1.png" can be described as "a toy camera in yellow and red with blue buttons".
77 | > * If a generated video is unsatisfactory, the most straightforward fixes are to change `--base_seed` and to adjust the description in `--prompt`.
78 |
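If you would rather call the pipeline from Python than through the command line, the sketch below mirrors the single-GPU `s2v-1.3B` path of `generate.py`. It is only a sketch: the paths, prompt, and hyper-parameter values are illustrative (taken from the defaults in `generate.py`), and it assumes you run it from the repo root after downloading both checkpoints.

```python
# Minimal programmatic sketch of subject-to-video generation (single GPU, no FSDP/USP).
# It mirrors what generate.py does for the s2v-1.3B task; paths and prompt are placeholders.
import phantom_wan
from phantom_wan.configs import WAN_CONFIGS, SIZE_CONFIGS
from generate import load_ref_images  # reuse the reference-image resize/pad helper

cfg = WAN_CONFIGS["s2v-1.3B"]
size = SIZE_CONFIGS["832*480"]
ref_images = load_ref_images("examples/ref1.png,examples/ref2.png", size)

pipe = phantom_wan.Phantom_Wan_S2V(
    config=cfg,
    checkpoint_dir="./Wan2.1-T2V-1.3B",
    phantom_ckpt="./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth",
    device_id=0,
    rank=0,
    t5_fsdp=False,
    dit_fsdp=False,
    use_usp=False,
    t5_cpu=False,
)

video = pipe.generate(
    "A little girl photographs her dog with a toy camera on a sunny lawn.",  # placeholder prompt
    ref_images,
    size=size,
    frame_num=81,               # must be of the form 4n+1
    shift=5.0,
    sample_solver="unipc",
    sampling_steps=50,
    guide_scale_img=5.0,
    guide_scale_text=7.5,
    seed=42,
    offload_model=True,
)
# `video` can then be written to disk the same way generate.py does (cache_video in phantom_wan.utils.utils).
```
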
79 | For more inference examples, please refer to `infer.sh`. You will get the following generated results:
80 |
81 | <table>
82 |   <tr><th>Reference Images</th><th>Generated Videos</th></tr>
83 |   <!-- each row pairs example reference images with the corresponding generated result video -->
84 | </table>
138 | ## 🆚 Comparative Results
139 | - **Identity Preserving Video Generation**.
140 | ![](assets/images/id_eval.png)
141 | - **Single Reference Subject-to-Video Generation**.
142 | ![](assets/images/ip_eval_s.png)
143 | - **Multi-Reference Subject-to-Video Generation**.
144 | ![](assets/images/ip_eval_m_00.png)
145 |
146 | ## Acknowledgements
147 | We would like to express our gratitude to the SEED team for their support. Special thanks to Lu Jiang, Haoyuan Guo, Zhibei Ma, and Sen Wang for their assistance with the model and data. We are also very grateful to Siying Chen, Qingyang Li, and Wei Han for their help with the evaluation.
148 |
149 | ## BibTeX
150 | ```bibtex
151 | @article{liu2025phantom,
152 | title={Phantom: Subject-Consistent Video Generation via Cross-Modal Alignment},
153 | author={Liu, Lijie and Ma, Tianxiang and Li, Bingchuan and Chen, Zhuowei and Liu, Jiawei and He, Qian and Wu, Xinglong},
154 | journal={arXiv preprint arXiv:2502.11079},
155 | year={2025}
156 | }
157 | ```
158 |
159 | ## Star History
160 | [Star History Chart](https://www.star-history.com/#Phantom-video/Phantom&Date)
--------------------------------------------------------------------------------
/assets/Alegreya-Regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/Alegreya-Regular.ttf
--------------------------------------------------------------------------------
/assets/Caveat-Medium.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/Caveat-Medium.ttf
--------------------------------------------------------------------------------
/assets/DancingScript-Bold.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/DancingScript-Bold.ttf
--------------------------------------------------------------------------------
/assets/background.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/background.png
--------------------------------------------------------------------------------
/assets/icon-new.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/icon-new.png
--------------------------------------------------------------------------------
/assets/icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/icon.png
--------------------------------------------------------------------------------
/assets/images/id_eval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/images/id_eval.png
--------------------------------------------------------------------------------
/assets/images/ip_eval_m_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/images/ip_eval_m_00.png
--------------------------------------------------------------------------------
/assets/images/ip_eval_s.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/images/ip_eval_s.png
--------------------------------------------------------------------------------
/assets/index.css:
--------------------------------------------------------------------------------
1 | .nopadding {
2 | padding: 0 !important;
3 | margin: 0 !important;
4 | }
5 |
6 | .flex {
7 | display: flex;
8 | }
9 |
10 | .paper-btn {
11 | position: relative;
12 | text-align: center;
13 |
14 | display: inline-block;
15 | margin: 8px;
16 | padding: 8px 8px;
17 |
18 | border-width: 0;
19 | outline: none;
20 | border-radius: 10px;
21 |
22 | background-color: #3e3e40;
23 | color: white !important;
24 | font-size: 16px;
25 | width: 200px;
26 | font-weight: 350;
27 | }
28 | .paper-btn-parent {
29 | display: flex;
30 | justify-content: center;
31 | margin: 16px 0px;
32 | }
33 | .paper-btn:hover {
34 | opacity: 0.80;
35 | }
36 |
37 | .video {
38 | width: 100%;
39 | height: auto;
40 | }
41 |
42 | .hover-overlay {
43 | position: absolute;
44 | top: 50%;
45 | left: 50%;
46 | transform: translate(-50%, -50%);
47 | width: 80%;
48 | height: auto;
49 | opacity: 1;
50 | display: flex;
51 | flex-direction: column;
52 | justify-content: center;
53 | align-items: center;
54 | font-family: Chalkduster;
55 | font-size: 16px;
56 | text-align: center;
57 | transition: opacity 0.3s ease-in-out;
58 | }
59 |
60 | div.scroll-container {
61 | background-color: #3e3e40;
62 | overflow: auto;
63 | white-space: nowrap;
64 | margin-top: 25px;
65 | margin-bottom: 25px;
66 | padding: 7px 7.5px 2.5px 9px;
67 | }
68 |
69 | .youtube-container {
70 | position: relative;
71 | width: 100%;
72 | padding-bottom: 56.25%; /* 16:9 aspect ratio */
73 | height: 0;
74 | overflow: hidden;
75 | }
76 | .youtube-container iframe {
77 | position: absolute;
78 | top: 0;
79 | left: 0;
80 | width: 100%;
81 | height: 100%;
82 | }
83 |
84 | .image-set {
85 | width: 80%
86 | }
87 | .image-set img {
88 | cursor: pointer;
89 | border: 0px solid transparent;
90 | opacity: 0.5;
91 | height: min(90px, 6vw);
92 | transition: opacity 0.3s ease; /* Smooth transition */
93 | }
94 | .image-set img.selected {
95 | opacity: 1.0;
96 | }
97 | .image-set img:hover {
98 | opacity: 1.0;
99 | }
100 |
101 | .show-neighbors {
102 | overflow: hidden;
103 | }
104 |
105 | .show-neighbors .carousel-indicators {
106 | margin-right: 15%;
107 | margin-left: 15%;
108 | }
109 | .show-neighbors .carousel-control-prev,
110 | .show-neighbors .carousel-control-next {
111 | height: 80%;
112 | width: 15%;
113 | z-index: 11;
114 | /* .carousel-caption has z-index 10 */
115 | }
116 | .show-neighbors .carousel-inner {
117 | width: 200%;
118 | left: -50%;
119 | }
120 | .show-neighbors .carousel-item-next:not(.carousel-item-left),
121 | .show-neighbors .carousel-item-right.active {
122 | -webkit-transform: translate3d(33%, 0, 0);
123 | transform: translate3d(33%, 0, 0);
124 | }
125 | .show-neighbors .carousel-item-prev:not(.carousel-item-right),
126 | .show-neighbors .carousel-item-left.active {
127 | -webkit-transform: translate3d(-33%, 0, 0);
128 | transform: translate3d(-33%, 0, 0);
129 | }
130 | .show-neighbors .item__third {
131 | display: block !important;
132 | float: left;
133 | position: relative;
134 | width: 33.333333333333%;
135 | padding: 0.5%;
136 | text-align: center;
137 | }
138 |
139 | .custom-button {
140 | color: white !important;
141 | text-decoration: none !important;
142 | background-color: #2f2f30;
143 | background-size: cover;
144 | background-repeat: no-repeat;
145 | background-position: center;
146 | border: none;
147 | cursor: pointer;
148 | opacity: 0.7; /* Initial opacity */
149 | transition: opacity 0.3s ease; /* Smooth transition */
150 | }
151 | .custom-button:hover {
152 | opacity: 1.0; /* Increased opacity on hover */
153 | }
154 |
155 | .block-container {
156 | display: flex;
157 | flex-wrap: wrap;
158 | justify-content: center;
159 | gap: 0.5%;
160 | }
161 |
162 | @media screen and (max-width: 1024px) {
163 | .mobile-break { display: block; }
164 | }
--------------------------------------------------------------------------------
/assets/index.js:
--------------------------------------------------------------------------------
1 | // Clone the neighbouring slides into each carousel item so the "show-neighbors" carousel previews the previous and next slides.
2 | $('.carousel-item', '.show-neighbors').each(function(){
3 | var next = $(this).next();
4 | if (! next.length) {
5 | next = $(this).siblings(':first');
6 | }
7 | next.children(':first-child').clone().appendTo($(this));
8 | }).each(function(){
9 | var prev = $(this).prev();
10 | if (! prev.length) {
11 | prev = $(this).siblings(':last');
12 | }
13 | prev.children(':nth-last-child(2)').clone().prependTo($(this));
14 | });
15 |
--------------------------------------------------------------------------------
/assets/roboto.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/roboto.woff2
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/teaser.png
--------------------------------------------------------------------------------
/assets/videos/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/.DS_Store
--------------------------------------------------------------------------------
/assets/videos/demo/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/demo/.DS_Store
--------------------------------------------------------------------------------
/assets/videos/demo/demo1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/demo/demo1.mp4
--------------------------------------------------------------------------------
/assets/videos/demo/demo2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/demo/demo2.mp4
--------------------------------------------------------------------------------
/assets/videos/demo/demo3.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/demo/demo3.mp4
--------------------------------------------------------------------------------
/assets/videos/id/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/.DS_Store
--------------------------------------------------------------------------------
/assets/videos/id/0_ref.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/0_ref.mov
--------------------------------------------------------------------------------
/assets/videos/id/1_ref.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/1_ref.mov
--------------------------------------------------------------------------------
/assets/videos/id/2_ref.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/2_ref.mov
--------------------------------------------------------------------------------
/assets/videos/id/4_ref.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/4_ref.mov
--------------------------------------------------------------------------------
/assets/videos/id/5_ref.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/5_ref.mov
--------------------------------------------------------------------------------
/assets/videos/id/6_ref.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/6_ref.mov
--------------------------------------------------------------------------------
/assets/videos/id/7_ref.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/7_ref.mov
--------------------------------------------------------------------------------
/assets/videos/id/8_ref.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/8_ref.mov
--------------------------------------------------------------------------------
/assets/videos/id/9_ref.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/id/9_ref.mov
--------------------------------------------------------------------------------
/assets/videos/multi_ip/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/.DS_Store
--------------------------------------------------------------------------------
/assets/videos/multi_ip/0.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/0.mov
--------------------------------------------------------------------------------
/assets/videos/multi_ip/1.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/1.mov
--------------------------------------------------------------------------------
/assets/videos/multi_ip/10.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/10.mov
--------------------------------------------------------------------------------
/assets/videos/multi_ip/2.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/2.mov
--------------------------------------------------------------------------------
/assets/videos/multi_ip/4.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/4.mov
--------------------------------------------------------------------------------
/assets/videos/multi_ip/5.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/5.mov
--------------------------------------------------------------------------------
/assets/videos/multi_ip/7.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/7.mov
--------------------------------------------------------------------------------
/assets/videos/multi_ip/8.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/8.mov
--------------------------------------------------------------------------------
/assets/videos/multi_ip/9.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/multi_ip/9.mov
--------------------------------------------------------------------------------
/assets/videos/single_ip/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/.DS_Store
--------------------------------------------------------------------------------
/assets/videos/single_ip/0.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/0.mov
--------------------------------------------------------------------------------
/assets/videos/single_ip/1.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/1.mov
--------------------------------------------------------------------------------
/assets/videos/single_ip/2.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/2.mov
--------------------------------------------------------------------------------
/assets/videos/single_ip/3.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/3.mov
--------------------------------------------------------------------------------
/assets/videos/single_ip/4.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/4.mov
--------------------------------------------------------------------------------
/assets/videos/single_ip/5.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/5.mov
--------------------------------------------------------------------------------
/assets/videos/single_ip/6.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/6.mov
--------------------------------------------------------------------------------
/assets/videos/single_ip/7.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/7.mov
--------------------------------------------------------------------------------
/assets/videos/single_ip/8.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/assets/videos/single_ip/8.mov
--------------------------------------------------------------------------------
/examples/ref1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref1.png
--------------------------------------------------------------------------------
/examples/ref10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref10.png
--------------------------------------------------------------------------------
/examples/ref11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref11.png
--------------------------------------------------------------------------------
/examples/ref2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref2.png
--------------------------------------------------------------------------------
/examples/ref3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref3.png
--------------------------------------------------------------------------------
/examples/ref4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref4.png
--------------------------------------------------------------------------------
/examples/ref5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref5.png
--------------------------------------------------------------------------------
/examples/ref6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref6.png
--------------------------------------------------------------------------------
/examples/ref7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref7.png
--------------------------------------------------------------------------------
/examples/ref8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref8.png
--------------------------------------------------------------------------------
/examples/ref9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref9.png
--------------------------------------------------------------------------------
/examples/ref_results/result1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref_results/result1.gif
--------------------------------------------------------------------------------
/examples/ref_results/result2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref_results/result2.gif
--------------------------------------------------------------------------------
/examples/ref_results/result3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref_results/result3.gif
--------------------------------------------------------------------------------
/examples/ref_results/result4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/examples/ref_results/result4.gif
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | from datetime import datetime
17 | import logging
18 | import os
19 | import sys
20 | import warnings
21 |
22 | warnings.filterwarnings('ignore')
23 |
24 | import torch, random
25 | import torch.distributed as dist
26 | from PIL import Image, ImageOps
27 |
28 | import phantom_wan
29 | from phantom_wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
30 | from phantom_wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
31 | from phantom_wan.utils.utils import cache_video, cache_image, str2bool
32 |
33 | EXAMPLE_PROMPT = {
34 | "t2v-1.3B": {
35 | "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
36 | },
37 | "t2v-14B": {
38 | "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
39 | },
40 | "t2i-14B": {
41 | "prompt": "一个朴素端庄的美人",
42 | },
43 | "i2v-14B": {
44 | "prompt":
45 | "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
46 | "image":
47 | "examples/i2v_input.JPG",
48 | },
49 | }
50 |
51 |
52 | def _validate_args(args):
53 | # Basic check
54 | assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
55 | assert args.phantom_ckpt is not None, "Please specify the Phantom-Wan checkpoint."
56 |     assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
57 |
58 | # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
59 | if args.sample_steps is None:
60 | args.sample_steps = 40 if "i2v" in args.task else 50
61 |
62 | if args.sample_shift is None:
63 | args.sample_shift = 5.0
64 | if "i2v" in args.task and args.size in ["832*480", "480*832"]:
65 | args.sample_shift = 3.0
66 |
67 |     # The default number of frames is 1 for text-to-image tasks and 81 for other tasks.
68 | if args.frame_num is None:
69 | args.frame_num = 1 if "t2i" in args.task else 81
70 |
71 | # T2I frame_num check
72 | if "t2i" in args.task:
73 |         assert args.frame_num == 1, f"Unsupported frame_num {args.frame_num} for task {args.task}"
74 |
75 | args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
76 | 0, sys.maxsize)
77 | # Size check
78 |     assert args.size in SUPPORTED_SIZES[args.task], (
79 |         f"Unsupported size {args.size} for task {args.task}, "
80 |         f"supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}")
81 |
82 |
83 | def _parse_args():
84 | parser = argparse.ArgumentParser(
85 |         description="Generate an image or video from a text prompt or image using Wan"
86 | )
87 | parser.add_argument(
88 | "--task",
89 | type=str,
90 | default="s2v-1.3B",
91 | choices=list(WAN_CONFIGS.keys()),
92 | help="The task to run.")
93 | parser.add_argument(
94 | "--size",
95 | type=str,
96 | default="1280*720",
97 | choices=list(SIZE_CONFIGS.keys()),
98 | help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
99 | )
100 | parser.add_argument(
101 | "--frame_num",
102 | type=int,
103 | default=None,
104 |         help="How many frames to sample from an image or video. The number should be 4n+1"
105 | )
106 | parser.add_argument(
107 | "--ckpt_dir",
108 | type=str,
109 | default=None,
110 | help="The path to the checkpoint directory.")
111 | parser.add_argument(
112 | "--phantom_ckpt",
113 | type=str,
114 | default=None,
115 | help="The path to the Phantom-Wan checkpoint.")
116 | parser.add_argument(
117 | "--offload_model",
118 | type=str2bool,
119 | default=None,
120 | help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
121 | )
122 | parser.add_argument(
123 | "--ulysses_size",
124 | type=int,
125 | default=1,
126 | help="The size of the ulysses parallelism in DiT.")
127 | parser.add_argument(
128 | "--ring_size",
129 | type=int,
130 | default=1,
131 | help="The size of the ring attention parallelism in DiT.")
132 | parser.add_argument(
133 | "--t5_fsdp",
134 | action="store_true",
135 | default=False,
136 | help="Whether to use FSDP for T5.")
137 | parser.add_argument(
138 | "--t5_cpu",
139 | action="store_true",
140 | default=False,
141 | help="Whether to place T5 model on CPU.")
142 | parser.add_argument(
143 | "--dit_fsdp",
144 | action="store_true",
145 | default=False,
146 | help="Whether to use FSDP for DiT.")
147 | parser.add_argument(
148 | "--save_file",
149 | type=str,
150 | default=None,
151 | help="The file to save the generated image or video to.")
152 | parser.add_argument(
153 | "--prompt",
154 | type=str,
155 | default=None,
156 | help="The prompt to generate the image or video from.")
157 | parser.add_argument(
158 | "--use_prompt_extend",
159 | action="store_true",
160 | default=False,
161 | help="Whether to use prompt extend.")
162 | parser.add_argument(
163 | "--prompt_extend_method",
164 | type=str,
165 | default="local_qwen",
166 | choices=["dashscope", "local_qwen"],
167 | help="The prompt extend method to use.")
168 | parser.add_argument(
169 | "--prompt_extend_model",
170 | type=str,
171 | default=None,
172 | help="The prompt extend model to use.")
173 | parser.add_argument(
174 | "--prompt_extend_target_lang",
175 | type=str,
176 | default="ch",
177 | choices=["ch", "en"],
178 | help="The target language of prompt extend.")
179 | parser.add_argument(
180 | "--base_seed",
181 | type=int,
182 | default=-1,
183 | help="The seed to use for generating the image or video.")
184 | parser.add_argument(
185 | "--image",
186 | type=str,
187 | default=None,
188 | help="The image to generate the video from.")
189 | parser.add_argument(
190 | "--ref_image",
191 | type=str,
192 | default=None,
193 | help="The reference images used by Phantom-Wan.")
194 | parser.add_argument(
195 | "--sample_solver",
196 | type=str,
197 | default='unipc',
198 | choices=['unipc', 'dpm++'],
199 | help="The solver used to sample.")
200 | parser.add_argument(
201 | "--sample_steps", type=int, default=None, help="The sampling steps.")
202 | parser.add_argument(
203 | "--sample_shift",
204 | type=float,
205 | default=None,
206 | help="Sampling shift factor for flow matching schedulers.")
207 | parser.add_argument(
208 | "--sample_guide_scale",
209 | type=float,
210 | default=5.0,
211 | help="Classifier free guidance scale.")
212 | parser.add_argument(
213 | "--sample_guide_scale_img",
214 | type=float,
215 | default=5.0,
216 | help="Classifier free guidance scale for reference images.")
217 | parser.add_argument(
218 | "--sample_guide_scale_text",
219 | type=float,
220 | default=7.5,
221 | help="Classifier free guidance scale for text.")
222 |
223 | args = parser.parse_args()
224 |
225 | _validate_args(args)
226 |
227 | return args
228 |
229 |
230 | def _init_logging(rank):
231 | # logging
232 | if rank == 0:
233 | # set format
234 | logging.basicConfig(
235 | level=logging.INFO,
236 | format="[%(asctime)s] %(levelname)s: %(message)s",
237 | handlers=[logging.StreamHandler(stream=sys.stdout)])
238 | else:
239 | logging.basicConfig(level=logging.ERROR)
240 |
241 |
242 | def load_ref_images(path, size):
243 | # Load size.
244 | h, w = size[1], size[0]
245 | # Load images.
246 | ref_paths = path.split(",")
247 | ref_images = []
248 | for image_path in ref_paths:
249 | with Image.open(image_path) as img:
250 | img = img.convert("RGB")
251 |
252 | # Calculate the required size to keep aspect ratio and fill the rest with padding.
253 | img_ratio = img.width / img.height
254 | target_ratio = w / h
255 |
256 | if img_ratio > target_ratio: # Image is wider than target
257 | new_width = w
258 | new_height = int(new_width / img_ratio)
259 | else: # Image is taller than target
260 | new_height = h
261 | new_width = int(new_height * img_ratio)
262 |
263 | img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
264 |
265 | # Create a new image with the target size and place the resized image in the center
266 | delta_w = w - img.size[0]
267 | delta_h = h - img.size[1]
268 | padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
269 | new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
270 | ref_images.append(new_img)
271 |
272 | return ref_images
273 |
274 |
275 | def generate(args):
276 | rank = int(os.getenv("RANK", 0))
277 | world_size = int(os.getenv("WORLD_SIZE", 1))
278 | local_rank = int(os.getenv("LOCAL_RANK", 0))
279 | device = local_rank
280 | _init_logging(rank)
281 |
282 | if args.offload_model is None:
283 | args.offload_model = False if world_size > 1 else True
284 | logging.info(
285 | f"offload_model is not specified, set to {args.offload_model}.")
286 | if world_size > 1:
287 | torch.cuda.set_device(local_rank)
288 | dist.init_process_group(
289 | backend="nccl",
290 | init_method="env://",
291 | rank=rank,
292 | world_size=world_size)
293 | else:
294 | assert not (
295 | args.t5_fsdp or args.dit_fsdp
296 | ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."
297 | assert not (
298 | args.ulysses_size > 1 or args.ring_size > 1
299 |         ), "context parallelism is not supported in non-distributed environments."
300 |
301 | if args.ulysses_size > 1 or args.ring_size > 1:
302 |         assert args.ulysses_size * args.ring_size == world_size, "The product of ulysses_size and ring_size should equal the world size."
303 | from xfuser.core.distributed import (initialize_model_parallel,
304 | init_distributed_environment)
305 | init_distributed_environment(
306 | rank=dist.get_rank(), world_size=dist.get_world_size())
307 |
308 | initialize_model_parallel(
309 | sequence_parallel_degree=dist.get_world_size(),
310 | ring_degree=args.ring_size,
311 | ulysses_degree=args.ulysses_size,
312 | )
313 |
314 | if args.use_prompt_extend:
315 | if args.prompt_extend_method == "dashscope":
316 | prompt_expander = DashScopePromptExpander(
317 | model_name=args.prompt_extend_model, is_vl="i2v" in args.task)
318 | elif args.prompt_extend_method == "local_qwen":
319 | prompt_expander = QwenPromptExpander(
320 | model_name=args.prompt_extend_model,
321 | is_vl="i2v" in args.task,
322 | device=rank)
323 | else:
324 | raise NotImplementedError(
325 |                 f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
326 |
327 | cfg = WAN_CONFIGS[args.task]
328 | if args.ulysses_size > 1:
329 | assert cfg.num_heads % args.ulysses_size == 0, f"`num_heads` must be divisible by `ulysses_size`."
330 |
331 | logging.info(f"Generation job args: {args}")
332 | logging.info(f"Generation model config: {cfg}")
333 |
334 | if dist.is_initialized():
335 | base_seed = [args.base_seed] if rank == 0 else [None]
336 | dist.broadcast_object_list(base_seed, src=0)
337 | args.base_seed = base_seed[0]
338 |
339 | if "s2v" in args.task:
340 |
341 | ref_images = load_ref_images(args.ref_image, SIZE_CONFIGS[args.size])
342 |
343 | logging.info("Creating Phantom-Wan pipeline.")
344 | wan_s2v = phantom_wan.Phantom_Wan_S2V(
345 | config=cfg,
346 | checkpoint_dir=args.ckpt_dir,
347 | phantom_ckpt=args.phantom_ckpt,
348 | device_id=device,
349 | rank=rank,
350 | t5_fsdp=args.t5_fsdp,
351 | dit_fsdp=args.dit_fsdp,
352 | use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
353 | t5_cpu=args.t5_cpu,
354 | )
355 |
356 | logging.info(
357 | f"Generating {'image' if 't2i' in args.task else 'video'} ...")
358 | video = wan_s2v.generate(
359 | args.prompt,
360 | ref_images,
361 | size=SIZE_CONFIGS[args.size],
362 | frame_num=args.frame_num,
363 | shift=args.sample_shift,
364 | sample_solver=args.sample_solver,
365 | sampling_steps=args.sample_steps,
366 | guide_scale_img=args.sample_guide_scale_img,
367 | guide_scale_text=args.sample_guide_scale_text,
368 | seed=args.base_seed,
369 | offload_model=args.offload_model)
370 |
371 | elif "t2v" in args.task or "t2i" in args.task:
372 | if args.prompt is None:
373 | args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
374 | logging.info(f"Input prompt: {args.prompt}")
375 | if args.use_prompt_extend:
376 | logging.info("Extending prompt ...")
377 | if rank == 0:
378 | prompt_output = prompt_expander(
379 | args.prompt,
380 | tar_lang=args.prompt_extend_target_lang,
381 | seed=args.base_seed)
382 | if prompt_output.status == False:
383 | logging.info(
384 | f"Extending prompt failed: {prompt_output.message}")
385 | logging.info("Falling back to original prompt.")
386 | input_prompt = args.prompt
387 | else:
388 | input_prompt = prompt_output.prompt
389 | input_prompt = [input_prompt]
390 | else:
391 | input_prompt = [None]
392 | if dist.is_initialized():
393 | dist.broadcast_object_list(input_prompt, src=0)
394 | args.prompt = input_prompt[0]
395 | logging.info(f"Extended prompt: {args.prompt}")
396 |
397 | logging.info("Creating Phantom-Wan pipeline.")
398 | wan_t2v = phantom_wan.WanT2V(
399 | config=cfg,
400 | checkpoint_dir=args.ckpt_dir,
401 | device_id=device,
402 | rank=rank,
403 | t5_fsdp=args.t5_fsdp,
404 | dit_fsdp=args.dit_fsdp,
405 | use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
406 | t5_cpu=args.t5_cpu,
407 | )
408 |
409 | logging.info(
410 | f"Generating {'image' if 't2i' in args.task else 'video'} ...")
411 | video = wan_t2v.generate(
412 | args.prompt,
413 | size=SIZE_CONFIGS[args.size],
414 | frame_num=args.frame_num,
415 | shift=args.sample_shift,
416 | sample_solver=args.sample_solver,
417 | sampling_steps=args.sample_steps,
418 | guide_scale=args.sample_guide_scale,
419 | seed=args.base_seed,
420 | offload_model=args.offload_model)
421 |
422 | else:
423 | if args.prompt is None:
424 | args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
425 | if args.image is None:
426 | args.image = EXAMPLE_PROMPT[args.task]["image"]
427 | logging.info(f"Input prompt: {args.prompt}")
428 | logging.info(f"Input image: {args.image}")
429 |
430 | img = Image.open(args.image).convert("RGB")
431 | if args.use_prompt_extend:
432 | logging.info("Extending prompt ...")
433 | if rank == 0:
434 | prompt_output = prompt_expander(
435 | args.prompt,
436 | tar_lang=args.prompt_extend_target_lang,
437 | image=img,
438 | seed=args.base_seed)
439 | if prompt_output.status == False:
440 | logging.info(
441 | f"Extending prompt failed: {prompt_output.message}")
442 | logging.info("Falling back to original prompt.")
443 | input_prompt = args.prompt
444 | else:
445 | input_prompt = prompt_output.prompt
446 | input_prompt = [input_prompt]
447 | else:
448 | input_prompt = [None]
449 | if dist.is_initialized():
450 | dist.broadcast_object_list(input_prompt, src=0)
451 | args.prompt = input_prompt[0]
452 | logging.info(f"Extended prompt: {args.prompt}")
453 |
454 | logging.info("Creating WanI2V pipeline.")
455 | wan_i2v = phantom_wan.WanI2V(
456 | config=cfg,
457 | checkpoint_dir=args.ckpt_dir,
458 | device_id=device,
459 | rank=rank,
460 | t5_fsdp=args.t5_fsdp,
461 | dit_fsdp=args.dit_fsdp,
462 | use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
463 | t5_cpu=args.t5_cpu,
464 | )
465 |
466 | logging.info("Generating video ...")
467 | video = wan_i2v.generate(
468 | args.prompt,
469 | img,
470 | max_area=MAX_AREA_CONFIGS[args.size],
471 | frame_num=args.frame_num,
472 | shift=args.sample_shift,
473 | sample_solver=args.sample_solver,
474 | sampling_steps=args.sample_steps,
475 | guide_scale=args.sample_guide_scale,
476 | seed=args.base_seed,
477 | offload_model=args.offload_model)
478 |
479 | if rank == 0:
480 | if args.save_file is None:
481 | formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
482 | formatted_prompt = args.prompt.replace(" ", "_").replace("/",
483 | "_")[:50]
484 | suffix = '.png' if "t2i" in args.task else '.mp4'
485 | args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
486 |
487 | if "t2i" in args.task:
488 | logging.info(f"Saving generated image to {args.save_file}")
489 | cache_image(
490 | tensor=video.squeeze(1)[None],
491 | save_file=args.save_file,
492 | nrow=1,
493 | normalize=True,
494 | value_range=(-1, 1))
495 | else:
496 | logging.info(f"Saving generated video to {args.save_file}")
497 | cache_video(
498 | tensor=video[None],
499 | save_file=args.save_file,
500 | fps=cfg.sample_fps,
501 | nrow=1,
502 | normalize=True,
503 | value_range=(-1, 1))
504 | logging.info("Finished.")
505 |
506 |
507 | if __name__ == "__main__":
508 | args = _parse_args()
509 | generate(args)
510 |
--------------------------------------------------------------------------------
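A minimal single-process sketch of the "s2v" branch above, assuming the repo and its requirements are installed; the checkpoint paths are placeholders (the same ones used in infer.sh), and the reference-image loading and generate() call are left commented because load_ref_images comes from generate.py's imports outside this excerpt:

    import phantom_wan
    from phantom_wan.configs import WAN_CONFIGS, SIZE_CONFIGS

    cfg = WAN_CONFIGS['s2v-1.3B']
    wan_s2v = phantom_wan.Phantom_Wan_S2V(
        config=cfg,
        checkpoint_dir='./Wan2.1-T2V-1.3B',                      # placeholder path
        phantom_ckpt='./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth',  # placeholder path
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False)
    # ref_images = load_ref_images("examples/ref1.png,examples/ref2.png", SIZE_CONFIGS['832*480'])
    # video = wan_s2v.generate(prompt, ref_images, size=SIZE_CONFIGS['832*480'],
    #                          frame_num=81, seed=42, offload_model=True)
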
/infer.sh:
--------------------------------------------------------------------------------
1 | torchrun \
2 | --node_rank=0 \
3 | --nnodes=1 \
4 | --rdzv_endpoint=127.0.0.1:23468 \
5 | --nproc_per_node=8 generate.py --task s2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --phantom_ckpt ./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth --ref_image "examples/ref1.png,examples/ref2.png" --dit_fsdp --t5_fsdp --ulysses_size 4 --ring_size 2 --prompt "暖阳漫过草地,扎着双马尾、头戴绿色蝴蝶结、身穿浅绿色连衣裙的小女孩蹲在盛开的雏菊旁。她身旁一只棕白相间的狗狗吐着舌头,毛茸茸尾巴欢快摇晃。小女孩笑着举起黄红配色、带有蓝色按钮的玩具相机,将和狗狗的欢乐瞬间定格。" --base_seed 42
6 |
7 | torchrun \
8 | --node_rank=0 \
9 | --nnodes=1 \
10 | --rdzv_endpoint=127.0.0.1:23468 \
11 | --nproc_per_node=8 generate.py --task s2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --phantom_ckpt ./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth --ref_image "examples/ref3.png,examples/ref4.png" --dit_fsdp --t5_fsdp --ulysses_size 4 --ring_size 2 --prompt "夕阳下,一位有着小麦色肌肤、留着乌黑长发的女人穿上有着大朵立体花朵装饰、肩袖处带有飘逸纱带的红色纱裙,漫步在金色的海滩上,海风轻拂她的长发,画面唯美动人。" --base_seed 42
12 |
13 | torchrun \
14 | --node_rank=0 \
15 | --nnodes=1 \
16 | --rdzv_endpoint=127.0.0.1:23468 \
17 | --nproc_per_node=8 generate.py --task s2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --phantom_ckpt ./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth --ref_image "examples/ref5.png,examples/ref6.png,examples/ref7.png" --dit_fsdp --t5_fsdp --ulysses_size 4 --ring_size 2 --prompt "在被冰雪覆盖,周围盛开着粉色花朵,有蝴蝶飞舞,屋内透出暖黄色灯光的梦幻小屋场景下,一位头发灰白、穿着深绿色上衣的老人牵着梳着双丸子头、身着中式传统服饰、外披白色毛绒衣物的小女孩的手,缓缓前行,画面温馨宁静。" --base_seed 42
18 |
19 | torchrun \
20 | --node_rank=0 \
21 | --nnodes=1 \
22 | --rdzv_endpoint=127.0.0.1:23468 \
23 | --nproc_per_node=8 generate.py --task s2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --phantom_ckpt ./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth --ref_image "examples/ref8.png,examples/ref9.png,examples/ref10.png,examples/ref11.png" --dit_fsdp --t5_fsdp --ulysses_size 4 --ring_size 2 --prompt "一位金色长发的女人身穿棕色带波点网纱长袖、胸前系带设计的泳衣,手持一杯有橙色切片和草莓装饰、插着绿色吸管的分层鸡尾酒,坐在有着棕榈树、铺有蓝白条纹毯子和灰色垫子、摆放着躺椅的沙滩上晒日光浴的慢镜头,捕捉她享受阳光的微笑与海浪轻抚沙滩的美景。" --base_seed 42
--------------------------------------------------------------------------------
/phantom_wan/__init__.py:
--------------------------------------------------------------------------------
1 | from . import configs, distributed, modules
2 | from .image2video import WanI2V
3 | from .text2video import WanT2V
4 | from .subject2video import Phantom_Wan_S2V
5 |
--------------------------------------------------------------------------------
/phantom_wan/configs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import copy
3 | import os
4 |
5 | os.environ['TOKENIZERS_PARALLELISM'] = 'false'
6 |
7 | from .wan_s2v_1_3B import s2v_1_3B
8 | from .wan_i2v_14B import i2v_14B
9 | from .wan_t2v_1_3B import t2v_1_3B
10 | from .wan_t2v_14B import t2v_14B
11 |
12 | # the config of t2i_14B is the same as t2v_14B
13 | t2i_14B = copy.deepcopy(t2v_14B)
14 | t2i_14B.__name__ = 'Config: Wan T2I 14B'
15 |
16 | WAN_CONFIGS = {
17 | 's2v-1.3B': s2v_1_3B,
18 | 't2v-14B': t2v_14B,
19 | 't2v-1.3B': t2v_1_3B,
20 | 'i2v-14B': i2v_14B,
21 | 't2i-14B': t2i_14B,
22 | }
23 |
24 | SIZE_CONFIGS = {
25 | '720*1280': (720, 1280),
26 | '1280*720': (1280, 720),
27 | '480*832': (480, 832),
28 | '832*480': (832, 480),
29 | '1024*1024': (1024, 1024),
30 | }
31 |
32 | MAX_AREA_CONFIGS = {
33 | '720*1280': 720 * 1280,
34 | '1280*720': 1280 * 720,
35 | '480*832': 480 * 832,
36 | '832*480': 832 * 480,
37 | }
38 |
39 | SUPPORTED_SIZES = {
40 | 's2v-1.3B': ('832*480',),
41 | 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
42 | 't2v-1.3B': ('480*832', '832*480'),
43 | 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
44 | 't2i-14B': tuple(SIZE_CONFIGS.keys()),
45 | }
46 |
--------------------------------------------------------------------------------
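A minimal sketch of how generate.py consumes the tables above (pure lookups, no model loading); it assumes the package and its requirements (torch, easydict) are importable:

    from phantom_wan.configs import WAN_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES

    task, size = 's2v-1.3B', '832*480'
    assert size in SUPPORTED_SIZES[task], f"{size} is not supported for task {task}"
    cfg = WAN_CONFIGS[task]   # EasyDict with num_heads, vae_stride, sample_fps, ...
    h_w = SIZE_CONFIGS[size]  # (832, 480)
    print(cfg.__name__, h_w)
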
/phantom_wan/configs/shared_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import torch
3 | from easydict import EasyDict
4 |
5 | #------------------------ Wan shared config ------------------------#
6 | wan_shared_cfg = EasyDict()
7 |
8 | # t5
9 | wan_shared_cfg.t5_model = 'umt5_xxl'
10 | wan_shared_cfg.t5_dtype = torch.bfloat16
11 | wan_shared_cfg.text_len = 512
12 |
13 | # transformer
14 | wan_shared_cfg.param_dtype = torch.bfloat16
15 |
16 | # inference
17 | wan_shared_cfg.num_train_timesteps = 1000
18 | wan_shared_cfg.sample_fps = 16
19 | wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'  # default Chinese negative prompt: "oversaturated, overexposed, static, blurry details, subtitles, style, artwork, painting, frame, still, overall gray, worst quality, low quality, JPEG artifacts, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered background, three legs, crowded background, walking backwards"
20 |
--------------------------------------------------------------------------------
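The per-model configs in this directory all extend this shared EasyDict in the same way; a minimal sketch of that pattern with made-up example values:

    from easydict import EasyDict
    from phantom_wan.configs.shared_config import wan_shared_cfg

    my_cfg = EasyDict(__name__='Config: example')
    my_cfg.update(wan_shared_cfg)  # inherit t5/param dtypes, text_len, sample_fps, neg prompt
    my_cfg.dim = 1536              # model-specific override, as in the 1.3B configs
    print(my_cfg.sample_fps, my_cfg.dim)
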
/phantom_wan/configs/wan_i2v_14B.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import torch
3 | from easydict import EasyDict
4 |
5 | from .shared_config import wan_shared_cfg
6 |
7 | #------------------------ Wan I2V 14B ------------------------#
8 |
9 | i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
10 | i2v_14B.update(wan_shared_cfg)
11 |
12 | i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13 | i2v_14B.t5_tokenizer = 'google/umt5-xxl'
14 |
15 | # clip
16 | i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
17 | i2v_14B.clip_dtype = torch.float16
18 | i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
19 | i2v_14B.clip_tokenizer = 'xlm-roberta-large'
20 |
21 | # vae
22 | i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
23 | i2v_14B.vae_stride = (4, 8, 8)
24 |
25 | # transformer
26 | i2v_14B.patch_size = (1, 2, 2)
27 | i2v_14B.dim = 5120
28 | i2v_14B.ffn_dim = 13824
29 | i2v_14B.freq_dim = 256
30 | i2v_14B.num_heads = 40
31 | i2v_14B.num_layers = 40
32 | i2v_14B.window_size = (-1, -1)
33 | i2v_14B.qk_norm = True
34 | i2v_14B.cross_attn_norm = True
35 | i2v_14B.eps = 1e-6
36 |
--------------------------------------------------------------------------------
/phantom_wan/configs/wan_s2v_1_3B.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from easydict import EasyDict
16 |
17 | from .shared_config import wan_shared_cfg
18 |
19 | #------------------------ Phantom-Wan S2V 1.3B ------------------------#
20 |
21 | s2v_1_3B = EasyDict(__name__='Config: Phantom-Wan S2V 1.3B')
22 | s2v_1_3B.update(wan_shared_cfg)
23 |
24 | # t5
25 | s2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
26 | s2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
27 |
28 | # vae
29 | s2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth'
30 | s2v_1_3B.vae_stride = (4, 8, 8)
31 |
32 | # transformer
33 | s2v_1_3B.patch_size = (1, 2, 2)
34 | s2v_1_3B.dim = 1536
35 | s2v_1_3B.ffn_dim = 8960
36 | s2v_1_3B.freq_dim = 256
37 | s2v_1_3B.num_heads = 12
38 | s2v_1_3B.num_layers = 30
39 | s2v_1_3B.window_size = (-1, -1)
40 | s2v_1_3B.qk_norm = True
41 | s2v_1_3B.cross_attn_norm = True
42 | s2v_1_3B.eps = 1e-6
43 |
--------------------------------------------------------------------------------
/phantom_wan/configs/wan_t2v_14B.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | from easydict import EasyDict
3 |
4 | from .shared_config import wan_shared_cfg
5 |
6 | #------------------------ Wan T2V 14B ------------------------#
7 |
8 | t2v_14B = EasyDict(__name__='Config: Wan T2V 14B')
9 | t2v_14B.update(wan_shared_cfg)
10 |
11 | # t5
12 | t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13 | t2v_14B.t5_tokenizer = 'google/umt5-xxl'
14 |
15 | # vae
16 | t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
17 | t2v_14B.vae_stride = (4, 8, 8)
18 |
19 | # transformer
20 | t2v_14B.patch_size = (1, 2, 2)
21 | t2v_14B.dim = 5120
22 | t2v_14B.ffn_dim = 13824
23 | t2v_14B.freq_dim = 256
24 | t2v_14B.num_heads = 40
25 | t2v_14B.num_layers = 40
26 | t2v_14B.window_size = (-1, -1)
27 | t2v_14B.qk_norm = True
28 | t2v_14B.cross_attn_norm = True
29 | t2v_14B.eps = 1e-6
30 |
--------------------------------------------------------------------------------
/phantom_wan/configs/wan_t2v_1_3B.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | from easydict import EasyDict
3 |
4 | from .shared_config import wan_shared_cfg
5 |
6 | #------------------------ Wan T2V 1.3B ------------------------#
7 |
8 | t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B')
9 | t2v_1_3B.update(wan_shared_cfg)
10 |
11 | # t5
12 | t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13 | t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
14 |
15 | # vae
16 | t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth'
17 | t2v_1_3B.vae_stride = (4, 8, 8)
18 |
19 | # transformer
20 | t2v_1_3B.patch_size = (1, 2, 2)
21 | t2v_1_3B.dim = 1536
22 | t2v_1_3B.ffn_dim = 8960
23 | t2v_1_3B.freq_dim = 256
24 | t2v_1_3B.num_heads = 12
25 | t2v_1_3B.num_layers = 30
26 | t2v_1_3B.window_size = (-1, -1)
27 | t2v_1_3B.qk_norm = True
28 | t2v_1_3B.cross_attn_norm = True
29 | t2v_1_3B.eps = 1e-6
30 |
--------------------------------------------------------------------------------
/phantom_wan/distributed/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phantom-video/Phantom/a82240a0f9a480f941718189f39242c23f23b472/phantom_wan/distributed/__init__.py
--------------------------------------------------------------------------------
/phantom_wan/distributed/fsdp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | from functools import partial
3 |
4 | import torch
5 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
6 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
7 | from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
8 |
9 |
10 | def shard_model(
11 | model,
12 | device_id,
13 | param_dtype=torch.bfloat16,
14 | reduce_dtype=torch.float32,
15 | buffer_dtype=torch.float32,
16 | process_group=None,
17 | sharding_strategy=ShardingStrategy.FULL_SHARD,
18 | sync_module_states=True,
19 | ):
20 | model = FSDP(
21 | module=model,
22 | process_group=process_group,
23 | sharding_strategy=sharding_strategy,
24 | auto_wrap_policy=partial(
25 | lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
26 | mixed_precision=MixedPrecision(
27 | param_dtype=param_dtype,
28 | reduce_dtype=reduce_dtype,
29 | buffer_dtype=buffer_dtype),
30 | device_id=device_id,
31 | sync_module_states=sync_module_states)
32 | return model
33 |
--------------------------------------------------------------------------------
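A sketch of how the pipelines in this repo call shard_model (mirroring image2video.py); it assumes torch.distributed is already initialized and that the module exposes a .blocks list, so the actual wrapping calls are left commented:

    from functools import partial

    from phantom_wan.distributed.fsdp import shard_model

    shard_fn = partial(shard_model, device_id=0)
    # model = WanModel.from_pretrained(checkpoint_dir)  # any module with a .blocks list
    # model = shard_fn(model)  # wraps each block in FSDP: bf16 params, fp32 reduce/buffers
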
/phantom_wan/distributed/xdit_context_parallel.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import torch
3 | import torch.cuda.amp as amp
4 | from xfuser.core.distributed import (get_sequence_parallel_rank,
5 | get_sequence_parallel_world_size,
6 | get_sp_group)
7 | from xfuser.core.long_ctx_attention import xFuserLongContextAttention
8 |
9 | from ..modules.model import sinusoidal_embedding_1d
10 |
11 |
12 | def pad_freqs(original_tensor, target_len):
13 | seq_len, s1, s2 = original_tensor.shape
14 | pad_size = target_len - seq_len
15 | padding_tensor = torch.ones(
16 | pad_size,
17 | s1,
18 | s2,
19 | dtype=original_tensor.dtype,
20 | device=original_tensor.device)
21 | padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
22 | return padded_tensor
23 |
24 |
25 | @amp.autocast(enabled=False)
26 | def rope_apply(x, grid_sizes, freqs):
27 | """
28 | x: [B, L, N, C].
29 | grid_sizes: [B, 3].
30 | freqs: [M, C // 2].
31 | """
32 | s, n, c = x.size(1), x.size(2), x.size(3) // 2
33 | # split freqs
34 | freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
35 |
36 | # loop over samples
37 | output = []
38 | for i, (f, h, w) in enumerate(grid_sizes.tolist()):
39 | seq_len = f * h * w
40 |
41 | # precompute multipliers
42 | x_i = torch.view_as_complex(x[i, :s].to(torch.float32).reshape(
43 | s, n, -1, 2))
44 | freqs_i = torch.cat([
45 | freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
46 | freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
47 | freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
48 | ],
49 | dim=-1).reshape(seq_len, 1, -1)
50 |
51 | # apply rotary embedding
52 | sp_size = get_sequence_parallel_world_size()
53 | sp_rank = get_sequence_parallel_rank()
54 | freqs_i = pad_freqs(freqs_i, s * sp_size)
55 | s_per_rank = s
56 | freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
57 | s_per_rank), :, :]
58 | x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
59 | x_i = torch.cat([x_i, x[i, s:]])
60 |
61 | # append to collection
62 | output.append(x_i)
63 | return torch.stack(output).float()
64 |
65 |
66 | def usp_dit_forward(
67 | self,
68 | x,
69 | t,
70 | context,
71 | seq_len,
72 | clip_fea=None,
73 | y=None,
74 | ):
75 | """
76 | x: A list of videos each with shape [C, T, H, W].
77 | t: [B].
78 | context: A list of text embeddings each with shape [L, C].
79 | """
80 | if self.model_type == 'i2v':
81 | assert clip_fea is not None and y is not None
82 | # params
83 | device = self.patch_embedding.weight.device
84 | if self.freqs.device != device:
85 | self.freqs = self.freqs.to(device)
86 |
87 | if y is not None:
88 | x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
89 |
90 | # embeddings
91 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
92 | grid_sizes = torch.stack(
93 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
94 | x = [u.flatten(2).transpose(1, 2) for u in x]
95 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
96 | assert seq_lens.max() <= seq_len
97 | x = torch.cat([
98 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
99 | for u in x
100 | ])
101 |
102 | # time embeddings
103 | with amp.autocast(dtype=torch.float32):
104 | e = self.time_embedding(
105 | sinusoidal_embedding_1d(self.freq_dim, t).float())
106 | e0 = self.time_projection(e).unflatten(1, (6, self.dim))
107 | assert e.dtype == torch.float32 and e0.dtype == torch.float32
108 |
109 | # context
110 | context_lens = None
111 | context = self.text_embedding(
112 | torch.stack([
113 | torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
114 | for u in context
115 | ]))
116 |
117 | if clip_fea is not None:
118 | context_clip = self.img_emb(clip_fea) # bs x 257 x dim
119 | context = torch.concat([context_clip, context], dim=1)
120 |
121 | # arguments
122 | kwargs = dict(
123 | e=e0,
124 | seq_lens=seq_lens,
125 | grid_sizes=grid_sizes,
126 | freqs=self.freqs,
127 | context=context,
128 | context_lens=context_lens)
129 |
130 | # Context Parallel
131 | x = torch.chunk(
132 | x, get_sequence_parallel_world_size(),
133 | dim=1)[get_sequence_parallel_rank()]
134 |
135 | for block in self.blocks:
136 | x = block(x, **kwargs)
137 |
138 | # head
139 | x = self.head(x, e)
140 |
141 | # Context Parallel
142 | x = get_sp_group().all_gather(x, dim=1)
143 |
144 | # unpatchify
145 | x = self.unpatchify(x, grid_sizes)
146 | return [u.float() for u in x]
147 |
148 |
149 | def usp_attn_forward(self,
150 | x,
151 | seq_lens,
152 | grid_sizes,
153 | freqs,
154 | dtype=torch.bfloat16):
155 | b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
156 | half_dtypes = (torch.float16, torch.bfloat16)
157 |
158 | def half(x):
159 | return x if x.dtype in half_dtypes else x.to(dtype)
160 |
161 | # query, key, value function
162 | def qkv_fn(x):
163 | q = self.norm_q(self.q(x)).view(b, s, n, d)
164 | k = self.norm_k(self.k(x)).view(b, s, n, d)
165 | v = self.v(x).view(b, s, n, d)
166 | return q, k, v
167 |
168 | q, k, v = qkv_fn(x)
169 | q = rope_apply(q, grid_sizes, freqs)
170 | k = rope_apply(k, grid_sizes, freqs)
171 |
172 |     # TODO: We should use unpadded q, k, v for attention.
173 | # k_lens = seq_lens // get_sequence_parallel_world_size()
174 | # if k_lens is not None:
175 | # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
176 | # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
177 | # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
178 |
179 | x = xFuserLongContextAttention()(
180 | None,
181 | query=half(q),
182 | key=half(k),
183 | value=half(v),
184 | window_size=self.window_size)
185 |
186 | # TODO: padding after attention.
187 | # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
188 |
189 | # output
190 | x = x.flatten(2)
191 | x = self.o(x)
192 | return x
--------------------------------------------------------------------------------
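A sketch of how these forwards get attached to a model, mirroring the use_usp branch of image2video.py; it assumes xfuser's sequence-parallel state has already been initialized:

    import types

    from phantom_wan.distributed.xdit_context_parallel import (usp_attn_forward,
                                                                usp_dit_forward)

    def enable_usp(model):
        # Rebind the self-attention and top-level forwards so the sequence dim is sharded.
        for block in model.blocks:
            block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
        model.forward = types.MethodType(usp_dit_forward, model)
        return model
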
/phantom_wan/image2video.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import gc
3 | import logging
4 | import math
5 | import os
6 | import random
7 | import sys
8 | import types
9 | from contextlib import contextmanager
10 | from functools import partial
11 |
12 | import numpy as np
13 | import torch
14 | import torch.cuda.amp as amp
15 | import torch.distributed as dist
16 | import torchvision.transforms.functional as TF
17 | from tqdm import tqdm
18 |
19 | from .distributed.fsdp import shard_model
20 | from .modules.clip import CLIPModel
21 | from .modules.model import WanModel
22 | from .modules.t5 import T5EncoderModel
23 | from .modules.vae import WanVAE
24 | from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
25 | get_sampling_sigmas, retrieve_timesteps)
26 | from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
27 |
28 |
29 | class WanI2V:
30 |
31 | def __init__(
32 | self,
33 | config,
34 | checkpoint_dir,
35 | device_id=0,
36 | rank=0,
37 | t5_fsdp=False,
38 | dit_fsdp=False,
39 | use_usp=False,
40 | t5_cpu=False,
41 | # init_on_cpu=True,
42 | ):
43 | r"""
44 | Initializes the image-to-video generation model components.
45 |
46 | Args:
47 | config (EasyDict):
48 | Object containing model parameters initialized from config.py
49 | checkpoint_dir (`str`):
50 | Path to directory containing model checkpoints
51 | device_id (`int`, *optional*, defaults to 0):
52 | Id of target GPU device
53 | rank (`int`, *optional*, defaults to 0):
54 | Process rank for distributed training
55 | t5_fsdp (`bool`, *optional*, defaults to False):
56 | Enable FSDP sharding for T5 model
57 | dit_fsdp (`bool`, *optional*, defaults to False):
58 | Enable FSDP sharding for DiT model
59 | use_usp (`bool`, *optional*, defaults to False):
60 | Enable distribution strategy of USP.
61 | t5_cpu (`bool`, *optional*, defaults to False):
62 | Whether to place T5 model on CPU. Only works without t5_fsdp.
63 | init_on_cpu (`bool`, *optional*, defaults to True):
64 | Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
65 | """
66 | self.device = torch.device(f"cuda:{device_id}")
67 | self.config = config
68 | self.rank = rank
69 | self.use_usp = use_usp
70 | self.t5_cpu = t5_cpu
71 |
72 | self.num_train_timesteps = config.num_train_timesteps
73 | self.param_dtype = config.param_dtype
74 |
75 | shard_fn = partial(shard_model, device_id=device_id)
76 | self.text_encoder = T5EncoderModel(
77 | text_len=config.text_len,
78 | dtype=config.t5_dtype,
79 | device=torch.device('cpu'),
80 | checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
81 | tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
82 | shard_fn=shard_fn if t5_fsdp else None,
83 | )
84 |
85 | self.vae_stride = config.vae_stride
86 | self.patch_size = config.patch_size
87 | self.vae = WanVAE(
88 | vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
89 | device=self.device)
90 |
91 | self.clip = CLIPModel(
92 | dtype=config.clip_dtype,
93 | device=self.device,
94 | checkpoint_path=os.path.join(checkpoint_dir,
95 | config.clip_checkpoint),
96 | tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
97 |
98 | logging.info(f"Creating WanModel from {checkpoint_dir}")
99 |         self.model = WanModel.from_pretrained(checkpoint_dir)
100 | self.model.eval().requires_grad_(False)
101 |
102 | # if t5_fsdp or dit_fsdp or use_usp:
103 | # init_on_cpu = False
104 |
105 | if use_usp:
106 | from xfuser.core.distributed import \
107 | get_sequence_parallel_world_size
108 |
109 | from .distributed.xdit_context_parallel import (usp_attn_forward,
110 | usp_dit_forward)
111 | for block in self.model.blocks:
112 | block.self_attn.forward = types.MethodType(
113 | usp_attn_forward, block.self_attn)
114 | self.model.forward = types.MethodType(usp_dit_forward, self.model)
115 | self.sp_size = get_sequence_parallel_world_size()
116 | else:
117 | self.sp_size = 1
118 |
119 | if dist.is_initialized():
120 | dist.barrier()
121 | if dit_fsdp:
122 | self.model = shard_fn(self.model)
123 | else:
124 | # if not init_on_cpu:
125 | self.model.to(self.device)
126 |
127 | self.sample_neg_prompt = config.sample_neg_prompt
128 |
129 | def generate(self,
130 | input_prompt,
131 | img,
132 | max_area=720 * 1280,
133 | frame_num=81,
134 | shift=5.0,
135 | sample_solver='unipc',
136 | sampling_steps=40,
137 | guide_scale=5.0,
138 | n_prompt="",
139 | seed=-1,
140 | offload_model=True):
141 | r"""
142 | Generates video frames from input image and text prompt using diffusion process.
143 |
144 | Args:
145 | input_prompt (`str`):
146 | Text prompt for content generation.
147 | img (PIL.Image.Image):
148 |                 Input image; it is converted internally to a [3, H, W] tensor.
149 | max_area (`int`, *optional*, defaults to 720*1280):
150 | Maximum pixel area for latent space calculation. Controls video resolution scaling
151 | frame_num (`int`, *optional*, defaults to 81):
152 | How many frames to sample from a video. The number should be 4n+1
153 | shift (`float`, *optional*, defaults to 5.0):
154 | Noise schedule shift parameter. Affects temporal dynamics
155 | [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
156 | sample_solver (`str`, *optional*, defaults to 'unipc'):
157 | Solver used to sample the video.
158 | sampling_steps (`int`, *optional*, defaults to 40):
159 | Number of diffusion sampling steps. Higher values improve quality but slow generation
160 |             guide_scale (`float`, *optional*, defaults to 5.0):
161 | Classifier-free guidance scale. Controls prompt adherence vs. creativity
162 | n_prompt (`str`, *optional*, defaults to ""):
163 | Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
164 | seed (`int`, *optional*, defaults to -1):
165 | Random seed for noise generation. If -1, use random seed
166 | offload_model (`bool`, *optional*, defaults to True):
167 | If True, offloads models to CPU during generation to save VRAM
168 |
169 | Returns:
170 | torch.Tensor:
171 |                 Generated video frames tensor. Dimensions: (C, N, H, W) where:
172 | - C: Color channels (3 for RGB)
173 | - N: Number of frames (81)
174 | - H: Frame height (from max_area)
175 |                 - W: Frame width (from max_area)
176 | """
177 | img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
178 |
179 | F = frame_num
180 | h, w = img.shape[1:]
181 | aspect_ratio = h / w
182 | lat_h = round(
183 | np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
184 | self.patch_size[1] * self.patch_size[1])
185 | lat_w = round(
186 | np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
187 | self.patch_size[2] * self.patch_size[2])
188 | h = lat_h * self.vae_stride[1]
189 | w = lat_w * self.vae_stride[2]
190 |
191 | max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
192 | self.patch_size[1] * self.patch_size[2])
193 | max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
194 |
195 | seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
196 | seed_g = torch.Generator(device=self.device)
197 | seed_g.manual_seed(seed)
198 | noise = torch.randn(
199 | 16,
200 | 21,
201 | lat_h,
202 | lat_w,
203 | dtype=torch.float32,
204 | generator=seed_g,
205 | device=self.device)
206 |
207 | msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
208 | msk[:, 1:] = 0
209 | msk = torch.concat([
210 | torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
211 | ],dim=1)
212 | msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
213 | msk = msk.transpose(1, 2)[0]
214 |
215 | if n_prompt == "":
216 | n_prompt = self.sample_neg_prompt
217 |
218 | # preprocess
219 | if not self.t5_cpu:
220 | self.text_encoder.model.to(self.device)
221 | context = self.text_encoder([input_prompt], self.device)
222 | context_null = self.text_encoder([n_prompt], self.device)
223 | if offload_model:
224 | self.text_encoder.model.cpu()
225 | else:
226 | context = self.text_encoder([input_prompt], torch.device('cpu'))
227 | context_null = self.text_encoder([n_prompt], torch.device('cpu'))
228 | context = [t.to(self.device) for t in context]
229 | context_null = [t.to(self.device) for t in context_null]
230 |
231 | self.clip.model.to(self.device)
232 | clip_context = self.clip.visual([img[:, None, :, :]])
233 | if offload_model:
234 | self.clip.model.cpu()
235 |
236 | y = self.vae.encode([
237 | torch.concat([
238 | torch.nn.functional.interpolate(
239 | img[None].cpu(), size=(h, w), mode='bicubic').transpose(
240 | 0, 1),
241 | torch.zeros(3, 80, h, w)
242 | ],
243 | dim=1).to(self.device)
244 | ])[0]
245 | y = torch.concat([msk, y])
246 |
247 | @contextmanager
248 | def noop_no_sync():
249 | yield
250 |
251 | no_sync = getattr(self.model, 'no_sync', noop_no_sync)
252 |
253 | # evaluation mode
254 | with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
255 |
256 | if sample_solver == 'unipc':
257 | sample_scheduler = FlowUniPCMultistepScheduler(
258 | num_train_timesteps=self.num_train_timesteps,
259 | shift=1,
260 | use_dynamic_shifting=False)
261 | sample_scheduler.set_timesteps(
262 | sampling_steps, device=self.device, shift=shift)
263 | timesteps = sample_scheduler.timesteps
264 | elif sample_solver == 'dpm++':
265 | sample_scheduler = FlowDPMSolverMultistepScheduler(
266 | num_train_timesteps=self.num_train_timesteps,
267 | shift=1,
268 | use_dynamic_shifting=False)
269 | sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
270 | timesteps, _ = retrieve_timesteps(
271 | sample_scheduler,
272 | device=self.device,
273 | sigmas=sampling_sigmas)
274 | else:
275 | raise NotImplementedError("Unsupported solver.")
276 |
277 | # sample videos
278 | latent = noise
279 |
280 | arg_c = {
281 | 'context': [context[0]],
282 | 'clip_fea': clip_context,
283 | 'seq_len': max_seq_len,
284 | 'y': [y],
285 | }
286 |
287 | arg_null = {
288 | 'context': context_null,
289 | 'clip_fea': clip_context,
290 | 'seq_len': max_seq_len,
291 | 'y': [y],
292 | }
293 |
294 | if offload_model:
295 | torch.cuda.empty_cache()
296 |
297 | self.model.to(self.device)
298 | for _, t in enumerate(tqdm(timesteps)):
299 | latent_model_input = [latent.to(self.device)]
300 | timestep = [t]
301 |
302 | timestep = torch.stack(timestep).to(self.device)
303 |
304 | noise_pred_cond = self.model(
305 | latent_model_input, t=timestep, **arg_c)[0].to(
306 | torch.device('cpu') if offload_model else self.device)
307 | if offload_model:
308 | torch.cuda.empty_cache()
309 | noise_pred_uncond = self.model(
310 | latent_model_input, t=timestep, **arg_null)[0].to(
311 | torch.device('cpu') if offload_model else self.device)
312 | if offload_model:
313 | torch.cuda.empty_cache()
314 | noise_pred = noise_pred_uncond + guide_scale * (
315 | noise_pred_cond - noise_pred_uncond)
316 |
317 | latent = latent.to(
318 | torch.device('cpu') if offload_model else self.device)
319 |
320 | temp_x0 = sample_scheduler.step(
321 | noise_pred.unsqueeze(0),
322 | t,
323 | latent.unsqueeze(0),
324 | return_dict=False,
325 | generator=seed_g)[0]
326 | latent = temp_x0.squeeze(0)
327 |
328 | x0 = [latent.to(self.device)]
329 | del latent_model_input, timestep
330 |
331 | if offload_model:
332 | self.model.cpu()
333 | torch.cuda.empty_cache()
334 |
335 | if self.rank == 0:
336 | videos = self.vae.decode(x0)
337 |
338 | del noise, latent
339 | del sample_scheduler
340 | if offload_model:
341 | gc.collect()
342 | torch.cuda.synchronize()
343 | if dist.is_initialized():
344 | dist.barrier()
345 |
346 | return videos[0] if self.rank == 0 else None
347 |
--------------------------------------------------------------------------------
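A minimal single-GPU usage sketch for WanI2V, following the "i2v" branch of generate.py; the checkpoint directory is a placeholder that must point to real Wan2.1-I2V assets:

    from PIL import Image

    import phantom_wan
    from phantom_wan.configs import WAN_CONFIGS, MAX_AREA_CONFIGS

    cfg = WAN_CONFIGS['i2v-14B']
    pipe = phantom_wan.WanI2V(config=cfg, checkpoint_dir='./Wan2.1-I2V-14B-480P')  # placeholder
    img = Image.open('examples/ref1.png').convert('RGB')
    video = pipe.generate(
        'a short descriptive prompt',
        img,
        max_area=MAX_AREA_CONFIGS['832*480'],
        frame_num=81,
        sampling_steps=40,
        guide_scale=5.0,
        seed=42,
        offload_model=True)  # (C, N, H, W) tensor on rank 0, None elsewhere
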
/phantom_wan/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .attention import flash_attention
2 | from .model import WanModel
3 | from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
4 | from .tokenizers import HuggingfaceTokenizer
5 | from .vae import WanVAE
6 |
7 | __all__ = [
8 | 'WanVAE',
9 | 'WanModel',
10 | 'T5Model',
11 | 'T5Encoder',
12 | 'T5Decoder',
13 | 'T5EncoderModel',
14 | 'HuggingfaceTokenizer',
15 | 'flash_attention',
16 | ]
17 |
--------------------------------------------------------------------------------
/phantom_wan/modules/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import torch
3 |
4 | try:
5 | import flash_attn_interface
6 | FLASH_ATTN_3_AVAILABLE = True
7 | except ModuleNotFoundError:
8 | FLASH_ATTN_3_AVAILABLE = False
9 |
10 | try:
11 | import flash_attn
12 | FLASH_ATTN_2_AVAILABLE = True
13 | except ModuleNotFoundError:
14 | FLASH_ATTN_2_AVAILABLE = False
15 |
16 | import warnings
17 |
18 | __all__ = [
19 | 'flash_attention',
20 | 'attention',
21 | ]
22 |
23 |
24 | def flash_attention(
25 | q,
26 | k,
27 | v,
28 | q_lens=None,
29 | k_lens=None,
30 | dropout_p=0.,
31 | softmax_scale=None,
32 | q_scale=None,
33 | causal=False,
34 | window_size=(-1, -1),
35 | deterministic=False,
36 | dtype=torch.bfloat16,
37 | version=None,
38 | ):
39 | """
40 | q: [B, Lq, Nq, C1].
41 | k: [B, Lk, Nk, C1].
42 | v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
43 | q_lens: [B].
44 | k_lens: [B].
45 | dropout_p: float. Dropout probability.
46 | softmax_scale: float. The scaling of QK^T before applying softmax.
47 | causal: bool. Whether to apply causal attention mask.
48 |         window_size: (left, right). If not (-1, -1), apply sliding window local attention.
49 | deterministic: bool. If True, slightly slower and uses more memory.
50 |         dtype: torch.dtype. The dtype q/k/v are cast to when they are not float16/bfloat16.
51 | """
52 | half_dtypes = (torch.float16, torch.bfloat16)
53 | assert dtype in half_dtypes
54 | assert q.device.type == 'cuda' and q.size(-1) <= 256
55 |
56 | # params
57 | b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
58 |
59 | def half(x):
60 | return x if x.dtype in half_dtypes else x.to(dtype)
61 |
62 | # preprocess query
63 | if q_lens is None:
64 | q = half(q.flatten(0, 1))
65 | q_lens = torch.tensor(
66 | [lq] * b, dtype=torch.int32).to(
67 | device=q.device, non_blocking=True)
68 | else:
69 | q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
70 |
71 | # preprocess key, value
72 | if k_lens is None:
73 | k = half(k.flatten(0, 1))
74 | v = half(v.flatten(0, 1))
75 | k_lens = torch.tensor(
76 | [lk] * b, dtype=torch.int32).to(
77 | device=k.device, non_blocking=True)
78 | else:
79 | k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
80 | v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
81 |
82 | q = q.to(v.dtype)
83 | k = k.to(v.dtype)
84 |
85 | if q_scale is not None:
86 | q = q * q_scale
87 |
88 | if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
89 | warnings.warn(
90 |             'Flash attention 3 is not available, falling back to flash attention 2 instead.'
91 | )
92 |
93 | # apply attention
94 | if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
95 | # Note: dropout_p, window_size are not supported in FA3 now.
96 | x = flash_attn_interface.flash_attn_varlen_func(
97 | q=q,
98 | k=k,
99 | v=v,
100 | cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
101 | 0, dtype=torch.int32).to(q.device, non_blocking=True),
102 | cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
103 | 0, dtype=torch.int32).to(q.device, non_blocking=True),
104 | seqused_q=None,
105 | seqused_k=None,
106 | max_seqlen_q=lq,
107 | max_seqlen_k=lk,
108 | softmax_scale=softmax_scale,
109 | causal=causal,
110 | deterministic=deterministic)[0].unflatten(0, (b, lq))
111 | else:
112 | assert FLASH_ATTN_2_AVAILABLE
113 | x = flash_attn.flash_attn_varlen_func(
114 | q=q,
115 | k=k,
116 | v=v,
117 | cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
118 | 0, dtype=torch.int32).to(q.device, non_blocking=True),
119 | cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
120 | 0, dtype=torch.int32).to(q.device, non_blocking=True),
121 | max_seqlen_q=lq,
122 | max_seqlen_k=lk,
123 | dropout_p=dropout_p,
124 | softmax_scale=softmax_scale,
125 | causal=causal,
126 | window_size=window_size,
127 | deterministic=deterministic).unflatten(0, (b, lq))
128 |
129 | # output
130 | return x.type(out_dtype)
131 |
132 |
133 | def attention(
134 | q,
135 | k,
136 | v,
137 | q_lens=None,
138 | k_lens=None,
139 | dropout_p=0.,
140 | softmax_scale=None,
141 | q_scale=None,
142 | causal=False,
143 | window_size=(-1, -1),
144 | deterministic=False,
145 | dtype=torch.bfloat16,
146 | fa_version=None,
147 | ):
148 | if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
149 | return flash_attention(
150 | q=q,
151 | k=k,
152 | v=v,
153 | q_lens=q_lens,
154 | k_lens=k_lens,
155 | dropout_p=dropout_p,
156 | softmax_scale=softmax_scale,
157 | q_scale=q_scale,
158 | causal=causal,
159 | window_size=window_size,
160 | deterministic=deterministic,
161 | dtype=dtype,
162 | version=fa_version,
163 | )
164 | else:
165 | if q_lens is not None or k_lens is not None:
166 | warnings.warn(
167 | 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
168 | )
169 | attn_mask = None
170 |
171 | q = q.transpose(1, 2).to(dtype)
172 | k = k.transpose(1, 2).to(dtype)
173 | v = v.transpose(1, 2).to(dtype)
174 |
175 | out = torch.nn.functional.scaled_dot_product_attention(
176 | q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
177 |
178 | out = out.transpose(1, 2).contiguous()
179 | return out
180 |
--------------------------------------------------------------------------------
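A shape sketch for the generic attention() wrapper above; q/k/v follow the [B, L, N, C] layout from flash_attention's docstring. Without flash-attn installed it takes the scaled_dot_product_attention fallback and can run on CPU; with flash-attn the tensors must be on CUDA:

    import torch

    from phantom_wan.modules.attention import attention

    b, lq, lk, n, d = 1, 16, 16, 12, 128
    q = torch.randn(b, lq, n, d)
    k = torch.randn(b, lk, n, d)
    v = torch.randn(b, lk, n, d)
    out = attention(q, k, v, dtype=torch.bfloat16)  # -> [B, Lq, N, C]
    print(out.shape)
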
/phantom_wan/modules/clip.py:
--------------------------------------------------------------------------------
1 | # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3 | import logging
4 | import math
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torchvision.transforms as T
10 |
11 | from .attention import flash_attention
12 | from .tokenizers import HuggingfaceTokenizer
13 | from .xlm_roberta import XLMRoberta
14 |
15 | __all__ = [
16 | 'XLMRobertaCLIP',
17 | 'clip_xlm_roberta_vit_h_14',
18 | 'CLIPModel',
19 | ]
20 |
21 |
22 | def pos_interpolate(pos, seq_len):
23 | if pos.size(1) == seq_len:
24 | return pos
25 | else:
26 | src_grid = int(math.sqrt(pos.size(1)))
27 | tar_grid = int(math.sqrt(seq_len))
28 | n = pos.size(1) - src_grid * src_grid
29 | return torch.cat([
30 | pos[:, :n],
31 | F.interpolate(
32 | pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
33 | 0, 3, 1, 2),
34 | size=(tar_grid, tar_grid),
35 | mode='bicubic',
36 | align_corners=False).flatten(2).transpose(1, 2)
37 | ],
38 | dim=1)
39 |
40 |
41 | class QuickGELU(nn.Module):
42 |
43 | def forward(self, x):
44 | return x * torch.sigmoid(1.702 * x)
45 |
46 |
47 | class LayerNorm(nn.LayerNorm):
48 |
49 | def forward(self, x):
50 | return super().forward(x.float()).type_as(x)
51 |
52 |
53 | class SelfAttention(nn.Module):
54 |
55 | def __init__(self,
56 | dim,
57 | num_heads,
58 | causal=False,
59 | attn_dropout=0.0,
60 | proj_dropout=0.0):
61 | assert dim % num_heads == 0
62 | super().__init__()
63 | self.dim = dim
64 | self.num_heads = num_heads
65 | self.head_dim = dim // num_heads
66 | self.causal = causal
67 | self.attn_dropout = attn_dropout
68 | self.proj_dropout = proj_dropout
69 |
70 | # layers
71 | self.to_qkv = nn.Linear(dim, dim * 3)
72 | self.proj = nn.Linear(dim, dim)
73 |
74 | def forward(self, x):
75 | """
76 | x: [B, L, C].
77 | """
78 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
79 |
80 | # compute query, key, value
81 | q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
82 |
83 | # compute attention
84 | p = self.attn_dropout if self.training else 0.0
85 | x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
86 | x = x.reshape(b, s, c)
87 |
88 | # output
89 | x = self.proj(x)
90 | x = F.dropout(x, self.proj_dropout, self.training)
91 | return x
92 |
93 |
94 | class SwiGLU(nn.Module):
95 |
96 | def __init__(self, dim, mid_dim):
97 | super().__init__()
98 | self.dim = dim
99 | self.mid_dim = mid_dim
100 |
101 | # layers
102 | self.fc1 = nn.Linear(dim, mid_dim)
103 | self.fc2 = nn.Linear(dim, mid_dim)
104 | self.fc3 = nn.Linear(mid_dim, dim)
105 |
106 | def forward(self, x):
107 | x = F.silu(self.fc1(x)) * self.fc2(x)
108 | x = self.fc3(x)
109 | return x
110 |
111 |
112 | class AttentionBlock(nn.Module):
113 |
114 | def __init__(self,
115 | dim,
116 | mlp_ratio,
117 | num_heads,
118 | post_norm=False,
119 | causal=False,
120 | activation='quick_gelu',
121 | attn_dropout=0.0,
122 | proj_dropout=0.0,
123 | norm_eps=1e-5):
124 | assert activation in ['quick_gelu', 'gelu', 'swi_glu']
125 | super().__init__()
126 | self.dim = dim
127 | self.mlp_ratio = mlp_ratio
128 | self.num_heads = num_heads
129 | self.post_norm = post_norm
130 | self.causal = causal
131 | self.norm_eps = norm_eps
132 |
133 | # layers
134 | self.norm1 = LayerNorm(dim, eps=norm_eps)
135 | self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
136 | proj_dropout)
137 | self.norm2 = LayerNorm(dim, eps=norm_eps)
138 | if activation == 'swi_glu':
139 | self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
140 | else:
141 | self.mlp = nn.Sequential(
142 | nn.Linear(dim, int(dim * mlp_ratio)),
143 | QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
144 | nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
145 |
146 | def forward(self, x):
147 | if self.post_norm:
148 | x = x + self.norm1(self.attn(x))
149 | x = x + self.norm2(self.mlp(x))
150 | else:
151 | x = x + self.attn(self.norm1(x))
152 | x = x + self.mlp(self.norm2(x))
153 | return x
154 |
155 |
156 | class AttentionPool(nn.Module):
157 |
158 | def __init__(self,
159 | dim,
160 | mlp_ratio,
161 | num_heads,
162 | activation='gelu',
163 | proj_dropout=0.0,
164 | norm_eps=1e-5):
165 | assert dim % num_heads == 0
166 | super().__init__()
167 | self.dim = dim
168 | self.mlp_ratio = mlp_ratio
169 | self.num_heads = num_heads
170 | self.head_dim = dim // num_heads
171 | self.proj_dropout = proj_dropout
172 | self.norm_eps = norm_eps
173 |
174 | # layers
175 | gain = 1.0 / math.sqrt(dim)
176 | self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
177 | self.to_q = nn.Linear(dim, dim)
178 | self.to_kv = nn.Linear(dim, dim * 2)
179 | self.proj = nn.Linear(dim, dim)
180 | self.norm = LayerNorm(dim, eps=norm_eps)
181 | self.mlp = nn.Sequential(
182 | nn.Linear(dim, int(dim * mlp_ratio)),
183 | QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
184 | nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
185 |
186 | def forward(self, x):
187 | """
188 | x: [B, L, C].
189 | """
190 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
191 |
192 | # compute query, key, value
193 | q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
194 | k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
195 |
196 | # compute attention
197 | x = flash_attention(q, k, v, version=2)
198 | x = x.reshape(b, 1, c)
199 |
200 | # output
201 | x = self.proj(x)
202 | x = F.dropout(x, self.proj_dropout, self.training)
203 |
204 | # mlp
205 | x = x + self.mlp(self.norm(x))
206 | return x[:, 0]
207 |
208 |
209 | class VisionTransformer(nn.Module):
210 |
211 | def __init__(self,
212 | image_size=224,
213 | patch_size=16,
214 | dim=768,
215 | mlp_ratio=4,
216 | out_dim=512,
217 | num_heads=12,
218 | num_layers=12,
219 | pool_type='token',
220 | pre_norm=True,
221 | post_norm=False,
222 | activation='quick_gelu',
223 | attn_dropout=0.0,
224 | proj_dropout=0.0,
225 | embedding_dropout=0.0,
226 | norm_eps=1e-5):
227 | if image_size % patch_size != 0:
228 | print(
229 | '[WARNING] image_size is not divisible by patch_size',
230 | flush=True)
231 | assert pool_type in ('token', 'token_fc', 'attn_pool')
232 | out_dim = out_dim or dim
233 | super().__init__()
234 | self.image_size = image_size
235 | self.patch_size = patch_size
236 | self.num_patches = (image_size // patch_size)**2
237 | self.dim = dim
238 | self.mlp_ratio = mlp_ratio
239 | self.out_dim = out_dim
240 | self.num_heads = num_heads
241 | self.num_layers = num_layers
242 | self.pool_type = pool_type
243 | self.post_norm = post_norm
244 | self.norm_eps = norm_eps
245 |
246 | # embeddings
247 | gain = 1.0 / math.sqrt(dim)
248 | self.patch_embedding = nn.Conv2d(
249 | 3,
250 | dim,
251 | kernel_size=patch_size,
252 | stride=patch_size,
253 | bias=not pre_norm)
254 | if pool_type in ('token', 'token_fc'):
255 | self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
256 | self.pos_embedding = nn.Parameter(gain * torch.randn(
257 | 1, self.num_patches +
258 | (1 if pool_type in ('token', 'token_fc') else 0), dim))
259 | self.dropout = nn.Dropout(embedding_dropout)
260 |
261 | # transformer
262 | self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
263 | self.transformer = nn.Sequential(*[
264 | AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
265 | activation, attn_dropout, proj_dropout, norm_eps)
266 | for _ in range(num_layers)
267 | ])
268 | self.post_norm = LayerNorm(dim, eps=norm_eps)
269 |
270 | # head
271 | if pool_type == 'token':
272 | self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
273 | elif pool_type == 'token_fc':
274 | self.head = nn.Linear(dim, out_dim)
275 | elif pool_type == 'attn_pool':
276 | self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
277 | proj_dropout, norm_eps)
278 |
279 | def forward(self, x, interpolation=False, use_31_block=False):
280 | b = x.size(0)
281 |
282 | # embeddings
283 | x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
284 | if self.pool_type in ('token', 'token_fc'):
285 | x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
286 | if interpolation:
287 | e = pos_interpolate(self.pos_embedding, x.size(1))
288 | else:
289 | e = self.pos_embedding
290 | x = self.dropout(x + e)
291 | if self.pre_norm is not None:
292 | x = self.pre_norm(x)
293 |
294 | # transformer
295 | if use_31_block:
296 | x = self.transformer[:-1](x)
297 | return x
298 | else:
299 | x = self.transformer(x)
300 | return x
301 |
302 |
303 | class XLMRobertaWithHead(XLMRoberta):
304 |
305 | def __init__(self, **kwargs):
306 | self.out_dim = kwargs.pop('out_dim')
307 | super().__init__(**kwargs)
308 |
309 | # head
310 | mid_dim = (self.dim + self.out_dim) // 2
311 | self.head = nn.Sequential(
312 | nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
313 | nn.Linear(mid_dim, self.out_dim, bias=False))
314 |
315 | def forward(self, ids):
316 | # xlm-roberta
317 | x = super().forward(ids)
318 |
319 | # average pooling
320 | mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
321 | x = (x * mask).sum(dim=1) / mask.sum(dim=1)
322 |
323 | # head
324 | x = self.head(x)
325 | return x
326 |
327 |
328 | class XLMRobertaCLIP(nn.Module):
329 |
330 | def __init__(self,
331 | embed_dim=1024,
332 | image_size=224,
333 | patch_size=14,
334 | vision_dim=1280,
335 | vision_mlp_ratio=4,
336 | vision_heads=16,
337 | vision_layers=32,
338 | vision_pool='token',
339 | vision_pre_norm=True,
340 | vision_post_norm=False,
341 | activation='gelu',
342 | vocab_size=250002,
343 | max_text_len=514,
344 | type_size=1,
345 | pad_id=1,
346 | text_dim=1024,
347 | text_heads=16,
348 | text_layers=24,
349 | text_post_norm=True,
350 | text_dropout=0.1,
351 | attn_dropout=0.0,
352 | proj_dropout=0.0,
353 | embedding_dropout=0.0,
354 | norm_eps=1e-5):
355 | super().__init__()
356 | self.embed_dim = embed_dim
357 | self.image_size = image_size
358 | self.patch_size = patch_size
359 | self.vision_dim = vision_dim
360 | self.vision_mlp_ratio = vision_mlp_ratio
361 | self.vision_heads = vision_heads
362 | self.vision_layers = vision_layers
363 | self.vision_pre_norm = vision_pre_norm
364 | self.vision_post_norm = vision_post_norm
365 | self.activation = activation
366 | self.vocab_size = vocab_size
367 | self.max_text_len = max_text_len
368 | self.type_size = type_size
369 | self.pad_id = pad_id
370 | self.text_dim = text_dim
371 | self.text_heads = text_heads
372 | self.text_layers = text_layers
373 | self.text_post_norm = text_post_norm
374 | self.norm_eps = norm_eps
375 |
376 | # models
377 | self.visual = VisionTransformer(
378 | image_size=image_size,
379 | patch_size=patch_size,
380 | dim=vision_dim,
381 | mlp_ratio=vision_mlp_ratio,
382 | out_dim=embed_dim,
383 | num_heads=vision_heads,
384 | num_layers=vision_layers,
385 | pool_type=vision_pool,
386 | pre_norm=vision_pre_norm,
387 | post_norm=vision_post_norm,
388 | activation=activation,
389 | attn_dropout=attn_dropout,
390 | proj_dropout=proj_dropout,
391 | embedding_dropout=embedding_dropout,
392 | norm_eps=norm_eps)
393 | self.textual = XLMRobertaWithHead(
394 | vocab_size=vocab_size,
395 | max_seq_len=max_text_len,
396 | type_size=type_size,
397 | pad_id=pad_id,
398 | dim=text_dim,
399 | out_dim=embed_dim,
400 | num_heads=text_heads,
401 | num_layers=text_layers,
402 | post_norm=text_post_norm,
403 | dropout=text_dropout)
404 | self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
405 |
406 | def forward(self, imgs, txt_ids):
407 | """
408 | imgs: [B, 3, H, W] of torch.float32.
409 | - mean: [0.48145466, 0.4578275, 0.40821073]
410 | - std: [0.26862954, 0.26130258, 0.27577711]
411 | txt_ids: [B, L] of torch.long.
412 | Encoded by data.CLIPTokenizer.
413 | """
414 | xi = self.visual(imgs)
415 | xt = self.textual(txt_ids)
416 | return xi, xt
417 |
418 | def param_groups(self):
419 | groups = [{
420 | 'params': [
421 | p for n, p in self.named_parameters()
422 | if 'norm' in n or n.endswith('bias')
423 | ],
424 | 'weight_decay': 0.0
425 | }, {
426 | 'params': [
427 | p for n, p in self.named_parameters()
428 | if not ('norm' in n or n.endswith('bias'))
429 | ]
430 | }]
431 | return groups
432 |
433 |
434 | def _clip(pretrained=False,
435 | pretrained_name=None,
436 | model_cls=XLMRobertaCLIP,
437 | return_transforms=False,
438 | return_tokenizer=False,
439 | tokenizer_padding='eos',
440 | dtype=torch.float32,
441 | device='cpu',
442 | **kwargs):
443 | # init a model on device
444 | with torch.device(device):
445 | model = model_cls(**kwargs)
446 |
447 | # set device
448 | model = model.to(dtype=dtype, device=device)
449 | output = (model,)
450 |
451 | # init transforms
452 | if return_transforms:
453 | # mean and std
454 | if 'siglip' in pretrained_name.lower():
455 | mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
456 | else:
457 | mean = [0.48145466, 0.4578275, 0.40821073]
458 | std = [0.26862954, 0.26130258, 0.27577711]
459 |
460 | # transforms
461 | transforms = T.Compose([
462 | T.Resize((model.image_size, model.image_size),
463 | interpolation=T.InterpolationMode.BICUBIC),
464 | T.ToTensor(),
465 | T.Normalize(mean=mean, std=std)
466 | ])
467 | output += (transforms,)
468 | return output[0] if len(output) == 1 else output
469 |
470 |
471 | def clip_xlm_roberta_vit_h_14(
472 | pretrained=False,
473 | pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
474 | **kwargs):
475 | cfg = dict(
476 | embed_dim=1024,
477 | image_size=224,
478 | patch_size=14,
479 | vision_dim=1280,
480 | vision_mlp_ratio=4,
481 | vision_heads=16,
482 | vision_layers=32,
483 | vision_pool='token',
484 | activation='gelu',
485 | vocab_size=250002,
486 | max_text_len=514,
487 | type_size=1,
488 | pad_id=1,
489 | text_dim=1024,
490 | text_heads=16,
491 | text_layers=24,
492 | text_post_norm=True,
493 | text_dropout=0.1,
494 | attn_dropout=0.0,
495 | proj_dropout=0.0,
496 | embedding_dropout=0.0)
497 | cfg.update(**kwargs)
498 | return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
499 |
500 |
501 | class CLIPModel:
502 |
503 | def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
504 | self.dtype = dtype
505 | self.device = device
506 | self.checkpoint_path = checkpoint_path
507 | self.tokenizer_path = tokenizer_path
508 |
509 | # init model
510 | self.model, self.transforms = clip_xlm_roberta_vit_h_14(
511 | pretrained=False,
512 | return_transforms=True,
513 | return_tokenizer=False,
514 | dtype=dtype,
515 | device=device)
516 | self.model = self.model.eval().requires_grad_(False)
517 | logging.info(f'loading {checkpoint_path}')
518 | self.model.load_state_dict(
519 | torch.load(checkpoint_path, map_location='cpu'))
520 |
521 | # init tokenizer
522 | self.tokenizer = HuggingfaceTokenizer(
523 | name=tokenizer_path,
524 | seq_len=self.model.max_text_len - 2,
525 | clean='whitespace')
526 |
527 | def visual(self, videos):
528 | # preprocess
529 | size = (self.model.image_size,) * 2
530 | videos = torch.cat([
531 | F.interpolate(
532 | u.transpose(0, 1),
533 | size=size,
534 | mode='bicubic',
535 | align_corners=False) for u in videos
536 | ])
537 | videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
538 |
539 | # forward
540 | with torch.cuda.amp.autocast(dtype=self.dtype):
541 | out = self.model.visual(videos, use_31_block=True)
542 | return out
543 |
--------------------------------------------------------------------------------
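A sketch of how image2video.py feeds a single reference frame through CLIPModel.visual; the checkpoint and tokenizer paths below are placeholders built from the names in wan_i2v_14B.py and must exist locally:

    import torch

    from phantom_wan.modules.clip import CLIPModel

    clip = CLIPModel(
        dtype=torch.float16,
        device=torch.device('cuda:0'),
        checkpoint_path='./Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth',
        tokenizer_path='./Wan2.1-I2V-14B-480P/xlm-roberta-large')
    # visual() expects a list of [C, T, H, W] clips normalized to [-1, 1].
    frame = torch.rand(3, 480, 832, device='cuda:0').sub_(0.5).div_(0.5)
    feats = clip.visual([frame[:, None, :, :]])
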
/phantom_wan/modules/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import math
3 |
4 | import torch
5 | import torch.cuda.amp as amp
6 | import torch.nn as nn
7 | from diffusers.configuration_utils import ConfigMixin, register_to_config
8 | from diffusers.models.modeling_utils import ModelMixin
9 |
10 | from .attention import flash_attention
11 |
12 | __all__ = ['WanModel']
13 |
14 |
15 | def sinusoidal_embedding_1d(dim, position):
16 | # preprocess
17 | assert dim % 2 == 0
18 | half = dim // 2
19 | position = position.type(torch.float64)
20 |
21 | # calculation
22 | sinusoid = torch.outer(
23 | position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
24 | x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
25 | return x
26 |
27 |
28 | @amp.autocast(enabled=False)
29 | def rope_params(max_seq_len, dim, theta=10000):
30 | assert dim % 2 == 0
31 | freqs = torch.outer(
32 | torch.arange(max_seq_len),
33 | 1.0 / torch.pow(theta,
34 | torch.arange(0, dim, 2).div(dim)))
35 | freqs = torch.polar(torch.ones_like(freqs), freqs)
36 | return freqs
37 |
38 |
39 | @amp.autocast(enabled=False)
40 | def rope_apply(x, grid_sizes, freqs):
41 | n, c = x.size(2), x.size(3) // 2
42 |
43 | # split freqs
44 | freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
45 |
46 | # loop over samples
47 | output = []
48 | for i, (f, h, w) in enumerate(grid_sizes.tolist()):
49 | seq_len = f * h * w
50 |
51 | # precompute multipliers
52 | x_i = torch.view_as_complex(x[i, :seq_len].reshape(
53 | seq_len, n, -1, 2))
54 | freqs_i = torch.cat([
55 | freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
56 | freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
57 | freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
58 | ],
59 | dim=-1).reshape(seq_len, 1, -1)
60 |
61 | # apply rotary embedding
62 | x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
63 | x_i = torch.cat([x_i, x[i, seq_len:]])
64 |
65 | # append to collection
66 | output.append(x_i)
67 | return torch.stack(output).float()
68 |
69 |
70 | class WanRMSNorm(nn.Module):
71 |
72 | def __init__(self, dim, eps=1e-5):
73 | super().__init__()
74 | self.dim = dim
75 | self.eps = eps
76 | self.weight = nn.Parameter(torch.ones(dim))
77 |
78 | def forward(self, x):
79 | r"""
80 | Args:
81 | x(Tensor): Shape [B, L, C]
82 | """
83 | return self._norm(x.float()).type_as(x) * self.weight
84 |
85 | def _norm(self, x):
86 | return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
87 |
88 |
89 | class WanLayerNorm(nn.LayerNorm):
90 |
91 | def __init__(self, dim, eps=1e-6, elementwise_affine=False):
92 | super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
93 |
94 | def forward(self, x):
95 | r"""
96 | Args:
97 | x(Tensor): Shape [B, L, C]
98 | """
99 | return super().forward(x.float()).type_as(x)
100 |
101 |
102 | class WanSelfAttention(nn.Module):
103 |
104 | def __init__(self,
105 | dim,
106 | num_heads,
107 | window_size=(-1, -1),
108 | qk_norm=True,
109 | eps=1e-6):
110 | assert dim % num_heads == 0
111 | super().__init__()
112 | self.dim = dim
113 | self.num_heads = num_heads
114 | self.head_dim = dim // num_heads
115 | self.window_size = window_size
116 | self.qk_norm = qk_norm
117 | self.eps = eps
118 |
119 | # layers
120 | self.q = nn.Linear(dim, dim)
121 | self.k = nn.Linear(dim, dim)
122 | self.v = nn.Linear(dim, dim)
123 | self.o = nn.Linear(dim, dim)
124 | self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
125 | self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
126 |
127 | def forward(self, x, seq_lens, grid_sizes, freqs):
128 | r"""
129 | Args:
130 | x(Tensor): Shape [B, L, num_heads, C / num_heads]
131 | seq_lens(Tensor): Shape [B]
132 | grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
133 | freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
134 | """
135 | b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
136 |
137 | # query, key, value function
138 | def qkv_fn(x):
139 | q = self.norm_q(self.q(x)).view(b, s, n, d)
140 | k = self.norm_k(self.k(x)).view(b, s, n, d)
141 | v = self.v(x).view(b, s, n, d)
142 | return q, k, v
143 |
144 | q, k, v = qkv_fn(x)
145 |
146 | x = flash_attention(
147 | q=rope_apply(q, grid_sizes, freqs),
148 | k=rope_apply(k, grid_sizes, freqs),
149 | v=v,
150 | k_lens=seq_lens,
151 | window_size=self.window_size)
152 |
153 | # output
154 | x = x.flatten(2)
155 | x = self.o(x)
156 | return x
157 |
158 |
159 | class WanT2VCrossAttention(WanSelfAttention):
160 |
161 | def forward(self, x, context, context_lens):
162 | r"""
163 | Args:
164 | x(Tensor): Shape [B, L1, C]
165 | context(Tensor): Shape [B, L2, C]
166 | context_lens(Tensor): Shape [B]
167 | """
168 | b, n, d = x.size(0), self.num_heads, self.head_dim
169 |
170 | # compute query, key, value
171 | q = self.norm_q(self.q(x)).view(b, -1, n, d)
172 | k = self.norm_k(self.k(context)).view(b, -1, n, d)
173 | v = self.v(context).view(b, -1, n, d)
174 |
175 | # compute attention
176 | x = flash_attention(q, k, v, k_lens=context_lens)
177 |
178 | # output
179 | x = x.flatten(2)
180 | x = self.o(x)
181 | return x
182 |
183 |
184 | class WanI2VCrossAttention(WanSelfAttention):
185 |
186 | def __init__(self,
187 | dim,
188 | num_heads,
189 | window_size=(-1, -1),
190 | qk_norm=True,
191 | eps=1e-6):
192 | super().__init__(dim, num_heads, window_size, qk_norm, eps)
193 |
194 | self.k_img = nn.Linear(dim, dim)
195 | self.v_img = nn.Linear(dim, dim)
196 | # self.alpha = nn.Parameter(torch.zeros((1, )))
197 | self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
198 |
199 | def forward(self, x, context, context_lens):
200 | r"""
201 | Args:
202 | x(Tensor): Shape [B, L1, C]
203 | context(Tensor): Shape [B, L2, C]
204 | context_lens(Tensor): Shape [B]
205 | """
206 | context_img = context[:, :257]
207 | context = context[:, 257:]
208 | b, n, d = x.size(0), self.num_heads, self.head_dim
209 |
210 | # compute query, key, value
211 | q = self.norm_q(self.q(x)).view(b, -1, n, d)
212 | k = self.norm_k(self.k(context)).view(b, -1, n, d)
213 | v = self.v(context).view(b, -1, n, d)
214 | k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
215 | v_img = self.v_img(context_img).view(b, -1, n, d)
216 | img_x = flash_attention(q, k_img, v_img, k_lens=None)
217 | # compute attention
218 | x = flash_attention(q, k, v, k_lens=context_lens)
219 |
220 | # output
221 | x = x.flatten(2)
222 | img_x = img_x.flatten(2)
223 | x = x + img_x
224 | x = self.o(x)
225 | return x
226 |
227 |
228 | WAN_CROSSATTENTION_CLASSES = {
229 | 't2v_cross_attn': WanT2VCrossAttention,
230 | 'i2v_cross_attn': WanI2VCrossAttention,
231 | }
232 |
233 |
234 | class WanAttentionBlock(nn.Module):
235 |
236 | def __init__(self,
237 | cross_attn_type,
238 | dim,
239 | ffn_dim,
240 | num_heads,
241 | window_size=(-1, -1),
242 | qk_norm=True,
243 | cross_attn_norm=False,
244 | eps=1e-6):
245 | super().__init__()
246 | self.dim = dim
247 | self.ffn_dim = ffn_dim
248 | self.num_heads = num_heads
249 | self.window_size = window_size
250 | self.qk_norm = qk_norm
251 | self.cross_attn_norm = cross_attn_norm
252 | self.eps = eps
253 |
254 | # layers
255 | self.norm1 = WanLayerNorm(dim, eps)
256 | self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
257 | eps)
258 | self.norm3 = WanLayerNorm(
259 | dim, eps,
260 | elementwise_affine=True) if cross_attn_norm else nn.Identity()
261 | self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
262 | num_heads,
263 | (-1, -1),
264 | qk_norm,
265 | eps)
266 | self.norm2 = WanLayerNorm(dim, eps)
267 | self.ffn = nn.Sequential(
268 | nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
269 | nn.Linear(ffn_dim, dim))
270 |
271 | # modulation
272 | self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
273 |
274 | def forward(
275 | self,
276 | x,
277 | e,
278 | seq_lens,
279 | grid_sizes,
280 | freqs,
281 | context,
282 | context_lens,
283 | ):
284 | r"""
285 | Args:
286 | x(Tensor): Shape [B, L, C]
287 | e(Tensor): Shape [B, 6, C]
288 | seq_lens(Tensor): Shape [B], length of each sequence in batch
289 | grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
290 | freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
291 | """
292 | assert e.dtype == torch.float32
293 | with amp.autocast(dtype=torch.float32):
294 | e = (self.modulation + e).chunk(6, dim=1)
295 | assert e[0].dtype == torch.float32
296 |
297 | # self-attention
298 | y = self.self_attn(
299 | self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
300 | freqs)
301 | with amp.autocast(dtype=torch.float32):
302 | x = x + y * e[2]
303 |
304 | # cross-attention & ffn function
305 | def cross_attn_ffn(x, context, context_lens, e):
306 | x = x + self.cross_attn(self.norm3(x), context, context_lens)
307 | y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
308 | with amp.autocast(dtype=torch.float32):
309 | x = x + y * e[5]
310 | return x
311 |
312 | x = cross_attn_ffn(x, context, context_lens, e)
313 | return x
314 |
315 |
316 | class Head(nn.Module):
317 |
318 | def __init__(self, dim, out_dim, patch_size, eps=1e-6):
319 | super().__init__()
320 | self.dim = dim
321 | self.out_dim = out_dim
322 | self.patch_size = patch_size
323 | self.eps = eps
324 |
325 | # layers
326 | out_dim = math.prod(patch_size) * out_dim
327 | self.norm = WanLayerNorm(dim, eps)
328 | self.head = nn.Linear(dim, out_dim)
329 |
330 | # modulation
331 | self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
332 |
333 | def forward(self, x, e):
334 | r"""
335 | Args:
336 | x(Tensor): Shape [B, L1, C]
337 | e(Tensor): Shape [B, C]
338 | """
339 | assert e.dtype == torch.float32
340 | with amp.autocast(dtype=torch.float32):
341 | e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
342 | x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
343 | return x
344 |
345 |
346 | class MLPProj(torch.nn.Module):
347 |
348 | def __init__(self, in_dim, out_dim):
349 | super().__init__()
350 |
351 | self.proj = torch.nn.Sequential(
352 | torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
353 | torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
354 | torch.nn.LayerNorm(out_dim))
355 |
356 | def forward(self, image_embeds):
357 | clip_extra_context_tokens = self.proj(image_embeds)
358 | return clip_extra_context_tokens
359 |
360 |
361 | class WanModel(ModelMixin, ConfigMixin):
362 | r"""
363 | Wan diffusion backbone supporting both text-to-video and image-to-video.
364 | """
365 |
366 | ignore_for_config = [
367 | 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
368 | ]
369 | _no_split_modules = ['WanAttentionBlock']
370 |
371 | @register_to_config
372 | def __init__(self,
373 | model_type='t2v',
374 | patch_size=(1, 2, 2),
375 | text_len=512,
376 | in_dim=16,
377 | dim=5120,
378 | ffn_dim=13824,
379 | freq_dim=256,
380 | text_dim=4096,
381 | out_dim=16,
382 | num_heads=40,
383 | num_layers=40,
384 | window_size=(-1, -1),
385 | qk_norm=True,
386 | cross_attn_norm=True,
387 | eps=1e-6):
388 | r"""
389 | Initialize the diffusion model backbone.
390 |
391 | Args:
392 | model_type (`str`, *optional*, defaults to 't2v'):
393 | Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
394 | patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
395 | 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
396 | text_len (`int`, *optional*, defaults to 512):
397 | Fixed length for text embeddings
398 | in_dim (`int`, *optional*, defaults to 16):
399 | Input video channels (C_in)
400 | dim (`int`, *optional*, defaults to 5120):
401 | Hidden dimension of the transformer
402 | ffn_dim (`int`, *optional*, defaults to 13824):
403 | Intermediate dimension in feed-forward network
404 | freq_dim (`int`, *optional*, defaults to 256):
405 | Dimension for sinusoidal time embeddings
406 | text_dim (`int`, *optional*, defaults to 4096):
407 | Input dimension for text embeddings
408 | out_dim (`int`, *optional*, defaults to 16):
409 | Output video channels (C_out)
410 | num_heads (`int`, *optional*, defaults to 40):
411 | Number of attention heads
412 | num_layers (`int`, *optional*, defaults to 40):
413 | Number of transformer blocks
414 | window_size (`tuple`, *optional*, defaults to (-1, -1)):
415 | Window size for local attention (-1 indicates global attention)
416 | qk_norm (`bool`, *optional*, defaults to True):
417 | Enable query/key normalization
418 | cross_attn_norm (`bool`, *optional*, defaults to True):
419 | Enable cross-attention normalization
420 | eps (`float`, *optional*, defaults to 1e-6):
421 | Epsilon value for normalization layers
422 | """
423 |
424 | super().__init__()
425 |
426 | assert model_type in ['t2v', 'i2v']
427 | self.model_type = model_type
428 |
429 | self.patch_size = patch_size
430 | self.text_len = text_len
431 | self.in_dim = in_dim
432 | self.dim = dim
433 | self.ffn_dim = ffn_dim
434 | self.freq_dim = freq_dim
435 | self.text_dim = text_dim
436 | self.out_dim = out_dim
437 | self.num_heads = num_heads
438 | self.num_layers = num_layers
439 | self.window_size = window_size
440 | self.qk_norm = qk_norm
441 | self.cross_attn_norm = cross_attn_norm
442 | self.eps = eps
443 |
444 | # embeddings
445 | self.patch_embedding = nn.Conv3d(
446 | in_dim, dim, kernel_size=patch_size, stride=patch_size)
447 | self.text_embedding = nn.Sequential(
448 | nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
449 | nn.Linear(dim, dim))
450 |
451 | self.time_embedding = nn.Sequential(
452 | nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
453 | self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
454 |
455 | # blocks
456 | cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
457 | self.blocks = nn.ModuleList([
458 | WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
459 | window_size, qk_norm, cross_attn_norm, eps)
460 | for _ in range(num_layers)
461 | ])
462 |
463 | # head
464 | self.head = Head(dim, out_dim, patch_size, eps)
465 |
466 | # buffers (don't use register_buffer otherwise dtype will be changed in to())
467 | assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
468 | d = dim // num_heads
469 | self.freqs = torch.cat([
470 | rope_params(1024, d - 4 * (d // 6)),
471 | rope_params(1024, 2 * (d // 6)),
472 | rope_params(1024, 2 * (d // 6))
473 | ],
474 | dim=1)
475 |
476 | if model_type == 'i2v':
477 | self.img_emb = MLPProj(1280, dim)
478 |
479 | # initialize weights
480 | self.init_weights()
481 |
482 |
483 | def forward(
484 | self,
485 | x,
486 | t,
487 | context,
488 | seq_len,
489 | clip_fea=None,
490 | y=None,
491 | ):
492 | r"""
493 | Forward pass through the diffusion model
494 |
495 | Args:
496 | x (List[Tensor]):
497 | List of input video tensors, each with shape [C_in, F, H, W]
498 | t (Tensor):
499 | Diffusion timesteps tensor of shape [B]
500 | context (List[Tensor]):
501 | List of text embeddings each with shape [L, C]
502 | seq_len (`int`):
503 | Maximum sequence length for positional encoding
504 | clip_fea (Tensor, *optional*):
505 | CLIP image features for image-to-video mode
506 | y (List[Tensor], *optional*):
507 | Conditional video inputs for image-to-video mode, same shape as x
508 |
509 | Returns:
510 | List[Tensor]:
511 | List of denoised video tensors, each with shape [C_out, F, H / 8, W / 8]
512 | """
513 | if self.model_type == 'i2v':
514 | assert clip_fea is not None and y is not None
515 | # params
516 | device = self.patch_embedding.weight.device
517 | if self.freqs.device != device:
518 | self.freqs = self.freqs.to(device)
519 |
520 | if y is not None:
521 | x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
522 |
523 | # embeddings
524 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
525 | grid_sizes = torch.stack(
526 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
527 | x = [u.flatten(2).transpose(1, 2) for u in x]
528 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
529 | assert seq_lens.max() <= seq_len
530 | x = torch.cat([
531 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
532 | dim=1) for u in x
533 | ])
534 |
535 | # time embeddings
536 | with amp.autocast(dtype=torch.float32):
537 | e = self.time_embedding(
538 | sinusoidal_embedding_1d(self.freq_dim, t).float())
539 | e0 = self.time_projection(e).unflatten(1, (6, self.dim))
540 | assert e.dtype == torch.float32 and e0.dtype == torch.float32
541 |
542 | # context
543 | context_lens = None
544 | context = self.text_embedding(
545 | torch.stack([
546 | torch.cat(
547 | [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
548 | for u in context
549 | ]))
550 |
551 | if clip_fea is not None:
552 | context_clip = self.img_emb(clip_fea) # bs x 257 x dim
553 | context = torch.concat([context_clip, context], dim=1)
554 |
555 | # arguments
556 | kwargs = dict(
557 | e=e0,
558 | seq_lens=seq_lens,
559 | grid_sizes=grid_sizes,
560 | freqs=self.freqs,
561 | context=context,
562 | context_lens=context_lens)
563 |
564 | for block in self.blocks:
565 | x = block(x, **kwargs)
566 |
567 | # head
568 | x = self.head(x, e)
569 |
570 | # unpatchify
571 | x = self.unpatchify(x, grid_sizes)
572 | return [u.float() for u in x]
573 |
574 | def unpatchify(self, x, grid_sizes):
575 | r"""
576 | Reconstruct video tensors from patch embeddings.
577 |
578 | Args:
579 | x (List[Tensor]):
580 | List of patchified features, each with shape [L, C_out * prod(patch_size)]
581 | grid_sizes (Tensor):
582 | Original spatial-temporal grid dimensions before patching,
583 | shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
584 |
585 | Returns:
586 | List[Tensor]:
587 | Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
588 | """
589 |
590 | c = self.out_dim
591 | out = []
592 | for u, v in zip(x, grid_sizes.tolist()):
593 | u = u[:math.prod(v)].view(*v, *self.patch_size, c)
594 | u = torch.einsum('fhwpqrc->cfphqwr', u)
595 | u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
596 | out.append(u)
597 | return out
598 |
599 | def init_weights(self):
600 | r"""
601 | Initialize model parameters using Xavier initialization.
602 | """
603 |
604 | # basic init
605 | for m in self.modules():
606 | if isinstance(m, nn.Linear):
607 | nn.init.xavier_uniform_(m.weight)
608 | if m.bias is not None:
609 | nn.init.zeros_(m.bias)
610 |
611 | # init embeddings
612 | nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
613 | for m in self.text_embedding.modules():
614 | if isinstance(m, nn.Linear):
615 | nn.init.normal_(m.weight, std=.02)
616 | for m in self.time_embedding.modules():
617 | if isinstance(m, nn.Linear):
618 | nn.init.normal_(m.weight, std=.02)
619 |
620 | # init output layer
621 | nn.init.zeros_(self.head.head.weight)
622 |
--------------------------------------------------------------------------------
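A small sketch (assuming the signature defaults above: dim=5120, num_heads=40) of how `WanModel` splits each attention head's channels across the temporal, height, and width RoPE axes and builds its `freqs` buffer; the helper below mirrors `rope_params`.

import torch

d = 5120 // 40                                            # head dim
parts = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]    # channels for (frame, height, width)
assert sum(parts) == d

def rope(max_len, dim, theta=10000):
    freqs = torch.outer(torch.arange(max_len),
                        1.0 / torch.pow(theta, torch.arange(0, dim, 2).div(dim)))
    return torch.polar(torch.ones_like(freqs), freqs)     # complex, [max_len, dim // 2]

freqs = torch.cat([rope(1024, p) for p in parts], dim=1)  # [1024, d // 2], as in WanModel.freqs
print(parts, tuple(freqs.shape))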
/phantom_wan/modules/t5.py:
--------------------------------------------------------------------------------
1 | # Modified from transformers.models.t5.modeling_t5
2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3 | import logging
4 | import math
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | from .tokenizers import HuggingfaceTokenizer
11 |
12 | __all__ = [
13 | 'T5Model',
14 | 'T5Encoder',
15 | 'T5Decoder',
16 | 'T5EncoderModel',
17 | ]
18 |
19 |
20 | def fp16_clamp(x):
21 | if x.dtype == torch.float16 and torch.isinf(x).any():
22 | clamp = torch.finfo(x.dtype).max - 1000
23 | x = torch.clamp(x, min=-clamp, max=clamp)
24 | return x
25 |
26 |
27 | def init_weights(m):
28 | if isinstance(m, T5LayerNorm):
29 | nn.init.ones_(m.weight)
30 | elif isinstance(m, T5Model):
31 | nn.init.normal_(m.token_embedding.weight, std=1.0)
32 | elif isinstance(m, T5FeedForward):
33 | nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
34 | nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
35 | nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
36 | elif isinstance(m, T5Attention):
37 | nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
38 | nn.init.normal_(m.k.weight, std=m.dim**-0.5)
39 | nn.init.normal_(m.v.weight, std=m.dim**-0.5)
40 | nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
41 | elif isinstance(m, T5RelativeEmbedding):
42 | nn.init.normal_(
43 | m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
44 |
45 |
46 | class GELU(nn.Module):
47 |
48 | def forward(self, x):
49 | return 0.5 * x * (1.0 + torch.tanh(
50 | math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
51 |
52 |
53 | class T5LayerNorm(nn.Module):
54 |
55 | def __init__(self, dim, eps=1e-6):
56 | super(T5LayerNorm, self).__init__()
57 | self.dim = dim
58 | self.eps = eps
59 | self.weight = nn.Parameter(torch.ones(dim))
60 |
61 | def forward(self, x):
62 | x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
63 | self.eps)
64 | if self.weight.dtype in [torch.float16, torch.bfloat16]:
65 | x = x.type_as(self.weight)
66 | return self.weight * x
67 |
68 |
69 | class T5Attention(nn.Module):
70 |
71 | def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
72 | assert dim_attn % num_heads == 0
73 | super(T5Attention, self).__init__()
74 | self.dim = dim
75 | self.dim_attn = dim_attn
76 | self.num_heads = num_heads
77 | self.head_dim = dim_attn // num_heads
78 |
79 | # layers
80 | self.q = nn.Linear(dim, dim_attn, bias=False)
81 | self.k = nn.Linear(dim, dim_attn, bias=False)
82 | self.v = nn.Linear(dim, dim_attn, bias=False)
83 | self.o = nn.Linear(dim_attn, dim, bias=False)
84 | self.dropout = nn.Dropout(dropout)
85 |
86 | def forward(self, x, context=None, mask=None, pos_bias=None):
87 | """
88 | x: [B, L1, C].
89 | context: [B, L2, C] or None.
90 | mask: [B, L2] or [B, L1, L2] or None.
91 | """
92 | # check inputs
93 | context = x if context is None else context
94 | b, n, c = x.size(0), self.num_heads, self.head_dim
95 |
96 | # compute query, key, value
97 | q = self.q(x).view(b, -1, n, c)
98 | k = self.k(context).view(b, -1, n, c)
99 | v = self.v(context).view(b, -1, n, c)
100 |
101 | # attention bias
102 | attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
103 | if pos_bias is not None:
104 | attn_bias += pos_bias
105 | if mask is not None:
106 | assert mask.ndim in [2, 3]
107 | mask = mask.view(b, 1, 1,
108 | -1) if mask.ndim == 2 else mask.unsqueeze(1)
109 | attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
110 |
111 | # compute attention (T5 does not use scaling)
112 | attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
113 | attn = F.softmax(attn.float(), dim=-1).type_as(attn)
114 | x = torch.einsum('bnij,bjnc->binc', attn, v)
115 |
116 | # output
117 | x = x.reshape(b, -1, n * c)
118 | x = self.o(x)
119 | x = self.dropout(x)
120 | return x
121 |
122 |
123 | class T5FeedForward(nn.Module):
124 |
125 | def __init__(self, dim, dim_ffn, dropout=0.1):
126 | super(T5FeedForward, self).__init__()
127 | self.dim = dim
128 | self.dim_ffn = dim_ffn
129 |
130 | # layers
131 | self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
132 | self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
133 | self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
134 | self.dropout = nn.Dropout(dropout)
135 |
136 | def forward(self, x):
137 | x = self.fc1(x) * self.gate(x)
138 | x = self.dropout(x)
139 | x = self.fc2(x)
140 | x = self.dropout(x)
141 | return x
142 |
143 |
144 | class T5SelfAttention(nn.Module):
145 |
146 | def __init__(self,
147 | dim,
148 | dim_attn,
149 | dim_ffn,
150 | num_heads,
151 | num_buckets,
152 | shared_pos=True,
153 | dropout=0.1):
154 | super(T5SelfAttention, self).__init__()
155 | self.dim = dim
156 | self.dim_attn = dim_attn
157 | self.dim_ffn = dim_ffn
158 | self.num_heads = num_heads
159 | self.num_buckets = num_buckets
160 | self.shared_pos = shared_pos
161 |
162 | # layers
163 | self.norm1 = T5LayerNorm(dim)
164 | self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
165 | self.norm2 = T5LayerNorm(dim)
166 | self.ffn = T5FeedForward(dim, dim_ffn, dropout)
167 | self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
168 | num_buckets, num_heads, bidirectional=True)
169 |
170 | def forward(self, x, mask=None, pos_bias=None):
171 | e = pos_bias if self.shared_pos else self.pos_embedding(
172 | x.size(1), x.size(1))
173 | x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
174 | x = fp16_clamp(x + self.ffn(self.norm2(x)))
175 | return x
176 |
177 |
178 | class T5CrossAttention(nn.Module):
179 |
180 | def __init__(self,
181 | dim,
182 | dim_attn,
183 | dim_ffn,
184 | num_heads,
185 | num_buckets,
186 | shared_pos=True,
187 | dropout=0.1):
188 | super(T5CrossAttention, self).__init__()
189 | self.dim = dim
190 | self.dim_attn = dim_attn
191 | self.dim_ffn = dim_ffn
192 | self.num_heads = num_heads
193 | self.num_buckets = num_buckets
194 | self.shared_pos = shared_pos
195 |
196 | # layers
197 | self.norm1 = T5LayerNorm(dim)
198 | self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
199 | self.norm2 = T5LayerNorm(dim)
200 | self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
201 | self.norm3 = T5LayerNorm(dim)
202 | self.ffn = T5FeedForward(dim, dim_ffn, dropout)
203 | self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
204 | num_buckets, num_heads, bidirectional=False)
205 |
206 | def forward(self,
207 | x,
208 | mask=None,
209 | encoder_states=None,
210 | encoder_mask=None,
211 | pos_bias=None):
212 | e = pos_bias if self.shared_pos else self.pos_embedding(
213 | x.size(1), x.size(1))
214 | x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
215 | x = fp16_clamp(x + self.cross_attn(
216 | self.norm2(x), context=encoder_states, mask=encoder_mask))
217 | x = fp16_clamp(x + self.ffn(self.norm3(x)))
218 | return x
219 |
220 |
221 | class T5RelativeEmbedding(nn.Module):
222 |
223 | def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
224 | super(T5RelativeEmbedding, self).__init__()
225 | self.num_buckets = num_buckets
226 | self.num_heads = num_heads
227 | self.bidirectional = bidirectional
228 | self.max_dist = max_dist
229 |
230 | # layers
231 | self.embedding = nn.Embedding(num_buckets, num_heads)
232 |
233 | def forward(self, lq, lk):
234 | device = self.embedding.weight.device
235 | # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
236 | # torch.arange(lq).unsqueeze(1).to(device)
237 | rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
238 | torch.arange(lq, device=device).unsqueeze(1)
239 | rel_pos = self._relative_position_bucket(rel_pos)
240 | rel_pos_embeds = self.embedding(rel_pos)
241 | rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
242 | 0) # [1, N, Lq, Lk]
243 | return rel_pos_embeds.contiguous()
244 |
245 | def _relative_position_bucket(self, rel_pos):
246 | # preprocess
247 | if self.bidirectional:
248 | num_buckets = self.num_buckets // 2
249 | rel_buckets = (rel_pos > 0).long() * num_buckets
250 | rel_pos = torch.abs(rel_pos)
251 | else:
252 | num_buckets = self.num_buckets
253 | rel_buckets = 0
254 | rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
255 |
256 | # embeddings for small and large positions
257 | max_exact = num_buckets // 2
258 | rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
259 | math.log(self.max_dist / max_exact) *
260 | (num_buckets - max_exact)).long()
261 | rel_pos_large = torch.min(
262 | rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
263 | rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
264 | return rel_buckets
265 |
266 |
267 | class T5Encoder(nn.Module):
268 |
269 | def __init__(self,
270 | vocab,
271 | dim,
272 | dim_attn,
273 | dim_ffn,
274 | num_heads,
275 | num_layers,
276 | num_buckets,
277 | shared_pos=True,
278 | dropout=0.1):
279 | super(T5Encoder, self).__init__()
280 | self.dim = dim
281 | self.dim_attn = dim_attn
282 | self.dim_ffn = dim_ffn
283 | self.num_heads = num_heads
284 | self.num_layers = num_layers
285 | self.num_buckets = num_buckets
286 | self.shared_pos = shared_pos
287 |
288 | # layers
289 | self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
290 | else nn.Embedding(vocab, dim)
291 | self.pos_embedding = T5RelativeEmbedding(
292 | num_buckets, num_heads, bidirectional=True) if shared_pos else None
293 | self.dropout = nn.Dropout(dropout)
294 | self.blocks = nn.ModuleList([
295 | T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
296 | shared_pos, dropout) for _ in range(num_layers)
297 | ])
298 | self.norm = T5LayerNorm(dim)
299 |
300 | # initialize weights
301 | self.apply(init_weights)
302 |
303 | def forward(self, ids, mask=None):
304 | x = self.token_embedding(ids)
305 | x = self.dropout(x)
306 | e = self.pos_embedding(x.size(1),
307 | x.size(1)) if self.shared_pos else None
308 | for block in self.blocks:
309 | x = block(x, mask, pos_bias=e)
310 | x = self.norm(x)
311 | x = self.dropout(x)
312 | return x
313 |
314 |
315 | class T5Decoder(nn.Module):
316 |
317 | def __init__(self,
318 | vocab,
319 | dim,
320 | dim_attn,
321 | dim_ffn,
322 | num_heads,
323 | num_layers,
324 | num_buckets,
325 | shared_pos=True,
326 | dropout=0.1):
327 | super(T5Decoder, self).__init__()
328 | self.dim = dim
329 | self.dim_attn = dim_attn
330 | self.dim_ffn = dim_ffn
331 | self.num_heads = num_heads
332 | self.num_layers = num_layers
333 | self.num_buckets = num_buckets
334 | self.shared_pos = shared_pos
335 |
336 | # layers
337 | self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
338 | else nn.Embedding(vocab, dim)
339 | self.pos_embedding = T5RelativeEmbedding(
340 | num_buckets, num_heads, bidirectional=False) if shared_pos else None
341 | self.dropout = nn.Dropout(dropout)
342 | self.blocks = nn.ModuleList([
343 | T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
344 | shared_pos, dropout) for _ in range(num_layers)
345 | ])
346 | self.norm = T5LayerNorm(dim)
347 |
348 | # initialize weights
349 | self.apply(init_weights)
350 |
351 | def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
352 | b, s = ids.size()
353 |
354 | # causal mask
355 | if mask is None:
356 | mask = torch.tril(torch.ones(1, s, s).to(ids.device))
357 | elif mask.ndim == 2:
358 | mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
359 |
360 | # layers
361 | x = self.token_embedding(ids)
362 | x = self.dropout(x)
363 | e = self.pos_embedding(x.size(1),
364 | x.size(1)) if self.shared_pos else None
365 | for block in self.blocks:
366 | x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
367 | x = self.norm(x)
368 | x = self.dropout(x)
369 | return x
370 |
371 |
372 | class T5Model(nn.Module):
373 |
374 | def __init__(self,
375 | vocab_size,
376 | dim,
377 | dim_attn,
378 | dim_ffn,
379 | num_heads,
380 | encoder_layers,
381 | decoder_layers,
382 | num_buckets,
383 | shared_pos=True,
384 | dropout=0.1):
385 | super(T5Model, self).__init__()
386 | self.vocab_size = vocab_size
387 | self.dim = dim
388 | self.dim_attn = dim_attn
389 | self.dim_ffn = dim_ffn
390 | self.num_heads = num_heads
391 | self.encoder_layers = encoder_layers
392 | self.decoder_layers = decoder_layers
393 | self.num_buckets = num_buckets
394 |
395 | # layers
396 | self.token_embedding = nn.Embedding(vocab_size, dim)
397 | self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
398 | num_heads, encoder_layers, num_buckets,
399 | shared_pos, dropout)
400 | self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
401 | num_heads, decoder_layers, num_buckets,
402 | shared_pos, dropout)
403 | self.head = nn.Linear(dim, vocab_size, bias=False)
404 |
405 | # initialize weights
406 | self.apply(init_weights)
407 |
408 | def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
409 | x = self.encoder(encoder_ids, encoder_mask)
410 | x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
411 | x = self.head(x)
412 | return x
413 |
414 |
415 | def _t5(name,
416 | encoder_only=False,
417 | decoder_only=False,
418 | return_tokenizer=False,
419 | tokenizer_kwargs={},
420 | dtype=torch.float32,
421 | device='cpu',
422 | **kwargs):
423 | # sanity check
424 | assert not (encoder_only and decoder_only)
425 |
426 | # params
427 | if encoder_only:
428 | model_cls = T5Encoder
429 | kwargs['vocab'] = kwargs.pop('vocab_size')
430 | kwargs['num_layers'] = kwargs.pop('encoder_layers')
431 | _ = kwargs.pop('decoder_layers')
432 | elif decoder_only:
433 | model_cls = T5Decoder
434 | kwargs['vocab'] = kwargs.pop('vocab_size')
435 | kwargs['num_layers'] = kwargs.pop('decoder_layers')
436 | _ = kwargs.pop('encoder_layers')
437 | else:
438 | model_cls = T5Model
439 |
440 | # init model
441 | with torch.device(device):
442 | model = model_cls(**kwargs)
443 |
444 | # set device
445 | model = model.to(dtype=dtype, device=device)
446 |
447 | # init tokenizer
448 | if return_tokenizer:
449 | from .tokenizers import HuggingfaceTokenizer
450 | tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
451 | return model, tokenizer
452 | else:
453 | return model
454 |
455 |
456 | def umt5_xxl(**kwargs):
457 | cfg = dict(
458 | vocab_size=256384,
459 | dim=4096,
460 | dim_attn=4096,
461 | dim_ffn=10240,
462 | num_heads=64,
463 | encoder_layers=24,
464 | decoder_layers=24,
465 | num_buckets=32,
466 | shared_pos=False,
467 | dropout=0.1)
468 | cfg.update(**kwargs)
469 | return _t5('umt5-xxl', **cfg)
470 |
471 |
472 | class T5EncoderModel:
473 |
474 | def __init__(
475 | self,
476 | text_len,
477 | dtype=torch.bfloat16,
478 | device=torch.cuda.current_device(),
479 | checkpoint_path=None,
480 | tokenizer_path=None,
481 | shard_fn=None,
482 | ):
483 | self.text_len = text_len
484 | self.dtype = dtype
485 | self.device = device
486 | self.checkpoint_path = checkpoint_path
487 | self.tokenizer_path = tokenizer_path
488 |
489 | # init model
490 | model = umt5_xxl(
491 | encoder_only=True,
492 | return_tokenizer=False,
493 | dtype=dtype,
494 | device=device).eval().requires_grad_(False)
495 | logging.info(f'loading {checkpoint_path}')
496 | model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
497 | self.model = model
498 | if shard_fn is not None:
499 | self.model = shard_fn(self.model, sync_module_states=False)
500 | else:
501 | self.model.to(self.device)
502 | # init tokenizer
503 | self.tokenizer = HuggingfaceTokenizer(
504 | name=tokenizer_path, seq_len=text_len, clean='whitespace')
505 |
506 | def __call__(self, texts, device):
507 | ids, mask = self.tokenizer(
508 | texts, return_mask=True, add_special_tokens=True)
509 | ids = ids.to(device)
510 | mask = mask.to(device)
511 | seq_lens = mask.gt(0).sum(dim=1).long()
512 | context = self.model(ids, mask)
513 | return [u[:v] for u, v in zip(context, seq_lens)]
514 |
--------------------------------------------------------------------------------
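A toy sketch (random tensors standing in for real UMT5 activations) of the output contract of `T5EncoderModel.__call__` above: the encoder produces padded `[B, text_len, dim]` features, and each sample is trimmed back to its true token count using the attention mask.

import torch

text_len, dim = 8, 4
mask = torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0],
                     [1, 1, 1, 1, 1, 0, 0, 0]])   # two prompts with 3 and 5 real tokens
context = torch.randn(2, text_len, dim)           # stand-in for self.model(ids, mask)
seq_lens = mask.gt(0).sum(dim=1).long()
trimmed = [u[:v] for u, v in zip(context, seq_lens)]
print([tuple(t.shape) for t in trimmed])          # [(3, 4), (5, 4)]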
/phantom_wan/modules/tokenizers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import html
3 | import string
4 |
5 | import ftfy
6 | import regex as re
7 | from transformers import AutoTokenizer
8 |
9 | __all__ = ['HuggingfaceTokenizer']
10 |
11 |
12 | def basic_clean(text):
13 | text = ftfy.fix_text(text)
14 | text = html.unescape(html.unescape(text))
15 | return text.strip()
16 |
17 |
18 | def whitespace_clean(text):
19 | text = re.sub(r'\s+', ' ', text)
20 | text = text.strip()
21 | return text
22 |
23 |
24 | def canonicalize(text, keep_punctuation_exact_string=None):
25 | text = text.replace('_', ' ')
26 | if keep_punctuation_exact_string:
27 | text = keep_punctuation_exact_string.join(
28 | part.translate(str.maketrans('', '', string.punctuation))
29 | for part in text.split(keep_punctuation_exact_string))
30 | else:
31 | text = text.translate(str.maketrans('', '', string.punctuation))
32 | text = text.lower()
33 | text = re.sub(r'\s+', ' ', text)
34 | return text.strip()
35 |
36 |
37 | class HuggingfaceTokenizer:
38 |
39 | def __init__(self, name, seq_len=None, clean=None, **kwargs):
40 | assert clean in (None, 'whitespace', 'lower', 'canonicalize')
41 | self.name = name
42 | self.seq_len = seq_len
43 | self.clean = clean
44 |
45 | # init tokenizer
46 | self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
47 | self.vocab_size = self.tokenizer.vocab_size
48 |
49 | def __call__(self, sequence, **kwargs):
50 | return_mask = kwargs.pop('return_mask', False)
51 |
52 | # arguments
53 | _kwargs = {'return_tensors': 'pt'}
54 | if self.seq_len is not None:
55 | _kwargs.update({
56 | 'padding': 'max_length',
57 | 'truncation': True,
58 | 'max_length': self.seq_len
59 | })
60 | _kwargs.update(**kwargs)
61 |
62 | # tokenization
63 | if isinstance(sequence, str):
64 | sequence = [sequence]
65 | if self.clean:
66 | sequence = [self._clean(u) for u in sequence]
67 | ids = self.tokenizer(sequence, **_kwargs)
68 |
69 | # output
70 | if return_mask:
71 | return ids.input_ids, ids.attention_mask
72 | else:
73 | return ids.input_ids
74 |
75 | def _clean(self, text):
76 | if self.clean == 'whitespace':
77 | text = whitespace_clean(basic_clean(text))
78 | elif self.clean == 'lower':
79 | text = whitespace_clean(basic_clean(text)).lower()
80 | elif self.clean == 'canonicalize':
81 | text = canonicalize(basic_clean(text))
82 | return text
83 |
--------------------------------------------------------------------------------
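A usage sketch for `HuggingfaceTokenizer` (assumes the `phantom_wan` package is importable and that the `google/umt5-xxl` tokenizer files can be fetched from the Hugging Face Hub or a local cache): `clean='whitespace'` collapses the double space before tokenization, and `seq_len` pads/truncates to a fixed length.

from phantom_wan.modules.tokenizers import HuggingfaceTokenizer

tok = HuggingfaceTokenizer(name='google/umt5-xxl', seq_len=512, clean='whitespace')
ids, mask = tok(['A corgi surfing  a wave at sunset.'],
                return_mask=True, add_special_tokens=True)
print(ids.shape, int(mask.sum()))   # ids padded to [1, 512]; mask counts real tokens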
/phantom_wan/modules/xlm_roberta.py:
--------------------------------------------------------------------------------
1 | # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | __all__ = ['XLMRoberta', 'xlm_roberta_large']
8 |
9 |
10 | class SelfAttention(nn.Module):
11 |
12 | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
13 | assert dim % num_heads == 0
14 | super().__init__()
15 | self.dim = dim
16 | self.num_heads = num_heads
17 | self.head_dim = dim // num_heads
18 | self.eps = eps
19 |
20 | # layers
21 | self.q = nn.Linear(dim, dim)
22 | self.k = nn.Linear(dim, dim)
23 | self.v = nn.Linear(dim, dim)
24 | self.o = nn.Linear(dim, dim)
25 | self.dropout = nn.Dropout(dropout)
26 |
27 | def forward(self, x, mask):
28 | """
29 | x: [B, L, C].
30 | """
31 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
32 |
33 | # compute query, key, value
34 | q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
35 | k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
36 | v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
37 |
38 | # compute attention
39 | p = self.dropout.p if self.training else 0.0
40 | x = F.scaled_dot_product_attention(q, k, v, mask, p)
41 | x = x.permute(0, 2, 1, 3).reshape(b, s, c)
42 |
43 | # output
44 | x = self.o(x)
45 | x = self.dropout(x)
46 | return x
47 |
48 |
49 | class AttentionBlock(nn.Module):
50 |
51 | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
52 | super().__init__()
53 | self.dim = dim
54 | self.num_heads = num_heads
55 | self.post_norm = post_norm
56 | self.eps = eps
57 |
58 | # layers
59 | self.attn = SelfAttention(dim, num_heads, dropout, eps)
60 | self.norm1 = nn.LayerNorm(dim, eps=eps)
61 | self.ffn = nn.Sequential(
62 | nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
63 | nn.Dropout(dropout))
64 | self.norm2 = nn.LayerNorm(dim, eps=eps)
65 |
66 | def forward(self, x, mask):
67 | if self.post_norm:
68 | x = self.norm1(x + self.attn(x, mask))
69 | x = self.norm2(x + self.ffn(x))
70 | else:
71 | x = x + self.attn(self.norm1(x), mask)
72 | x = x + self.ffn(self.norm2(x))
73 | return x
74 |
75 |
76 | class XLMRoberta(nn.Module):
77 | """
78 | XLMRobertaModel with no pooler and no LM head.
79 | """
80 |
81 | def __init__(self,
82 | vocab_size=250002,
83 | max_seq_len=514,
84 | type_size=1,
85 | pad_id=1,
86 | dim=1024,
87 | num_heads=16,
88 | num_layers=24,
89 | post_norm=True,
90 | dropout=0.1,
91 | eps=1e-5):
92 | super().__init__()
93 | self.vocab_size = vocab_size
94 | self.max_seq_len = max_seq_len
95 | self.type_size = type_size
96 | self.pad_id = pad_id
97 | self.dim = dim
98 | self.num_heads = num_heads
99 | self.num_layers = num_layers
100 | self.post_norm = post_norm
101 | self.eps = eps
102 |
103 | # embeddings
104 | self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
105 | self.type_embedding = nn.Embedding(type_size, dim)
106 | self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
107 | self.dropout = nn.Dropout(dropout)
108 |
109 | # blocks
110 | self.blocks = nn.ModuleList([
111 | AttentionBlock(dim, num_heads, post_norm, dropout, eps)
112 | for _ in range(num_layers)
113 | ])
114 |
115 | # norm layer
116 | self.norm = nn.LayerNorm(dim, eps=eps)
117 |
118 | def forward(self, ids):
119 | """
120 | ids: [B, L] of torch.LongTensor.
121 | """
122 | b, s = ids.shape
123 | mask = ids.ne(self.pad_id).long()
124 |
125 | # embeddings
126 | x = self.token_embedding(ids) + \
127 | self.type_embedding(torch.zeros_like(ids)) + \
128 | self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
129 | if self.post_norm:
130 | x = self.norm(x)
131 | x = self.dropout(x)
132 |
133 | # blocks
134 | mask = torch.where(
135 | mask.view(b, 1, 1, s).gt(0), 0.0,
136 | torch.finfo(x.dtype).min)
137 | for block in self.blocks:
138 | x = block(x, mask)
139 |
140 | # output
141 | if not self.post_norm:
142 | x = self.norm(x)
143 | return x
144 |
145 |
146 | def xlm_roberta_large(pretrained=False,
147 | return_tokenizer=False,
148 | device='cpu',
149 | **kwargs):
150 | """
151 | XLMRobertaLarge adapted from Huggingface.
152 | """
153 | # params
154 | cfg = dict(
155 | vocab_size=250002,
156 | max_seq_len=514,
157 | type_size=1,
158 | pad_id=1,
159 | dim=1024,
160 | num_heads=16,
161 | num_layers=24,
162 | post_norm=True,
163 | dropout=0.1,
164 | eps=1e-5)
165 | cfg.update(**kwargs)
166 |
167 | # init a model on device
168 | with torch.device(device):
169 | model = XLMRoberta(**cfg)
170 | return model
171 |
--------------------------------------------------------------------------------
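A short sketch of the RoBERTa-style position ids computed in `XLMRoberta.forward` above: positions count up from `pad_id + 1` for real tokens and stay at `pad_id` for padding (the toy ids below are illustrative, not real tokenizer output).

import torch

pad_id = 1
ids = torch.tensor([[0, 23, 51, 2, 1, 1]])        # toy ids; trailing 1s are padding
mask = ids.ne(pad_id).long()
pos = pad_id + torch.cumsum(mask, dim=1) * mask
print(pos)                                        # tensor([[2, 3, 4, 5, 1, 1]])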
/phantom_wan/subject2video.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 |
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import gc
16 | import logging
17 | import math
18 | import os
19 | import random
20 | import sys
21 | import types
22 | from contextlib import contextmanager
23 | from functools import partial
24 |
25 | import torch
26 | import torch.cuda.amp as amp
27 | import torch.distributed as dist
28 | from tqdm import tqdm
29 | import torchvision.transforms.functional as TF
30 |
31 | from .distributed.fsdp import shard_model
32 | from .modules.model import WanModel
33 | from .modules.t5 import T5EncoderModel
34 | from .modules.vae import WanVAE
35 | from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
36 | get_sampling_sigmas, retrieve_timesteps)
37 | from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
38 |
39 |
40 | class Phantom_Wan_S2V:
41 |
42 | def __init__(
43 | self,
44 | config,
45 | checkpoint_dir,
46 | phantom_ckpt,
47 | device_id=0,
48 | rank=0,
49 | t5_fsdp=False,
50 | dit_fsdp=False,
51 | use_usp=False,
52 | t5_cpu=False,
53 | ):
54 | r"""
55 | Initializes the Wan text-to-video generation model components.
56 |
57 | Args:
58 | config (EasyDict):
59 | Object containing model parameters initialized from config.py
60 | checkpoint_dir (`str`):
61 | Path to directory containing model checkpoints
62 | phantom_ckpt (`str`):
63 | Path of Phantom-Wan dit checkpoint
64 | device_id (`int`, *optional*, defaults to 0):
65 | Id of target GPU device
66 | rank (`int`, *optional*, defaults to 0):
67 | Process rank for distributed training
68 | t5_fsdp (`bool`, *optional*, defaults to False):
69 | Enable FSDP sharding for T5 model
70 | dit_fsdp (`bool`, *optional*, defaults to False):
71 | Enable FSDP sharding for DiT model
72 | use_usp (`bool`, *optional*, defaults to False):
73 | Enable distribution strategy of USP.
74 | t5_cpu (`bool`, *optional*, defaults to False):
75 | Whether to place T5 model on CPU. Only works without t5_fsdp.
76 | """
77 | self.device = torch.device(f"cuda:{device_id}")
78 | self.config = config
79 | self.rank = rank
80 | self.t5_cpu = t5_cpu
81 |
82 | self.num_train_timesteps = config.num_train_timesteps
83 | self.param_dtype = config.param_dtype
84 |
85 | shard_fn = partial(shard_model, device_id=device_id)
86 | self.text_encoder = T5EncoderModel(
87 | text_len=config.text_len,
88 | dtype=config.t5_dtype,
89 | device=torch.device('cpu'),
90 | checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
91 | tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
92 | shard_fn=shard_fn if t5_fsdp else None)
93 |
94 | self.vae_stride = config.vae_stride
95 | self.patch_size = config.patch_size
96 | self.vae = WanVAE(
97 | vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
98 | device=self.device)
99 |
100 | logging.info(f"Creating WanModel from {phantom_ckpt}")
101 | self.model = WanModel(dim=config.dim,
102 | ffn_dim=config.ffn_dim,
103 | freq_dim=config.freq_dim,
104 | num_heads=config.num_heads,
105 | num_layers=config.num_layers,
106 | window_size=config.window_size,
107 | qk_norm=config.qk_norm,
108 | cross_attn_norm=config.cross_attn_norm,
109 | eps=config.eps)
110 | logging.info("loading checkpoint.")
111 | state = torch.load(phantom_ckpt, map_location=self.device)
112 | logging.info("loading state dict.")
113 | self.model.load_state_dict(state, strict=False)
114 | # self.model = WanModel.from_pretrained(checkpoint_dir)
115 | self.model.eval().requires_grad_(False)
116 |
117 | if use_usp:
118 | from xfuser.core.distributed import \
119 | get_sequence_parallel_world_size
120 |
121 | from .distributed.xdit_context_parallel import (usp_attn_forward,
122 | usp_dit_forward)
123 | for block in self.model.blocks:
124 | block.self_attn.forward = types.MethodType(
125 | usp_attn_forward, block.self_attn)
126 | self.model.forward = types.MethodType(usp_dit_forward, self.model)
127 | self.sp_size = get_sequence_parallel_world_size()
128 | else:
129 | self.sp_size = 1
130 |
131 | if dist.is_initialized():
132 | dist.barrier()
133 | if dit_fsdp:
134 | self.model = shard_fn(self.model)
135 | else:
136 | self.model.to(self.device)
137 |
138 | self.sample_neg_prompt = config.sample_neg_prompt
139 |
140 |
141 | def get_vae_latents(self, ref_images, device):
142 | ref_vae_latents = []
143 | for ref_image in ref_images:
144 | ref_image = TF.to_tensor(ref_image).sub_(0.5).div_(0.5).to(self.device)
145 | img_vae_latent = self.vae.encode([ref_image.unsqueeze(1)])
146 | ref_vae_latents.append(img_vae_latent[0])
147 |
148 | return torch.cat(ref_vae_latents, dim=1)
149 |
150 |
151 | def generate(self,
152 | input_prompt,
153 | ref_images,
154 | size=(1280, 720),
155 | frame_num=81,
156 | shift=5.0,
157 | sample_solver='unipc',
158 | sampling_steps=50,
159 | guide_scale_img=5.0,
160 | guide_scale_text=7.5,
161 | n_prompt="",
162 | seed=-1,
163 | offload_model=True):
164 | r"""
165 | Generates video frames from a text prompt and reference images using a diffusion process.
166 |
167 | Args:
168 | input_prompt (`str`):
169 | Text prompt for content generation
170 | ref_images (List[`PIL.Image`]):
171 | Reference images for subject generation.
172 | size (tuple[`int`], *optional*, defaults to (1280,720)):
173 | Controls video resolution, (width,height).
174 | frame_num (`int`, *optional*, defaults to 81):
175 | How many frames to sample from a video. The number should be 4n+1
176 | shift (`float`, *optional*, defaults to 5.0):
177 | Noise schedule shift parameter. Affects temporal dynamics
178 | sample_solver (`str`, *optional*, defaults to 'unipc'):
179 | Solver used to sample the video. Only 'unipc' is currently supported
180 | sampling_steps (`int`, *optional*, defaults to 50):
181 | Number of diffusion sampling steps. Higher values improve quality but slow generation
182 | guide_scale_img / guide_scale_text (`float`, *optional*, default to 5.0 / 7.5):
183 | Classifier-free guidance scales for the reference-image and text conditions. Control prompt adherence vs. creativity
184 | n_prompt (`str`, *optional*, defaults to ""):
185 | Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
186 | seed (`int`, *optional*, defaults to -1):
187 | Random seed for noise generation. If -1, use random seed.
188 | offload_model (`bool`, *optional*, defaults to True):
189 | If True, offloads models to CPU during generation to save VRAM
190 |
191 | Returns:
192 | torch.Tensor:
193 | Generated video frames tensor. Dimensions: (C, N, H, W) where:
194 | - C: Color channels (3 for RGB)
195 | - N: Number of frames (81)
196 | - H: Frame height (from size)
197 | - W: Frame width (from size)
198 | """
199 | # preprocess
200 |
201 | ref_latents = [self.get_vae_latents(ref_images, self.device)]
202 | ref_latents_neg = [torch.zeros_like(ref_latents[0])]
203 |
204 | F = frame_num
205 | target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + ref_latents[0].shape[1],
206 | size[1] // self.vae_stride[1],
207 | size[0] // self.vae_stride[2])
208 |
209 | seq_len = math.ceil((target_shape[2] * target_shape[3]) /
210 | (self.patch_size[1] * self.patch_size[2]) *
211 | target_shape[1] / self.sp_size) * self.sp_size
212 |
213 | if n_prompt == "":
214 | n_prompt = self.sample_neg_prompt
215 | seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
216 | seed_g = torch.Generator(device=self.device)
217 | seed_g.manual_seed(seed)
218 |
219 | if not self.t5_cpu:
220 | self.text_encoder.model.to(self.device)
221 | context = self.text_encoder([input_prompt], self.device)
222 | context_null = self.text_encoder([n_prompt], self.device)
223 | if offload_model:
224 | self.text_encoder.model.cpu()
225 | else:
226 | context = self.text_encoder([input_prompt], torch.device('cpu'))
227 | context_null = self.text_encoder([n_prompt], torch.device('cpu'))
228 | context = [t.to(self.device) for t in context]
229 | context_null = [t.to(self.device) for t in context_null]
230 |
231 | noise = [
232 | torch.randn(
233 | target_shape[0],
234 | target_shape[1],
235 | target_shape[2],
236 | target_shape[3],
237 | dtype=torch.float32,
238 | device=self.device,
239 | generator=seed_g)
240 | ]
241 |
242 | @contextmanager
243 | def noop_no_sync():
244 | yield
245 |
246 | no_sync = getattr(self.model, 'no_sync', noop_no_sync)
247 |
248 | # evaluation mode
249 | with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
250 |
251 | if sample_solver == 'unipc':
252 | sample_scheduler = FlowUniPCMultistepScheduler(
253 | num_train_timesteps=self.num_train_timesteps,
254 | shift=1,
255 | use_dynamic_shifting=False)
256 | sample_scheduler.set_timesteps(
257 | sampling_steps, device=self.device, shift=shift)
258 | timesteps = sample_scheduler.timesteps
259 | else:
260 | raise NotImplementedError("Unsupported solver.")
261 |
262 | # sample videos
263 | latents = noise
264 |
265 | arg_c = {'context': context, 'seq_len': seq_len}
266 | arg_null = {'context': context_null, 'seq_len': seq_len}
267 |
268 | for _, t in enumerate(tqdm(timesteps)):
269 |
270 | timestep = [t]
271 | timestep = torch.stack(timestep)
272 |
273 | self.model.to(self.device)
274 | pos_it = self.model(
275 | [torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latents, ref_latents)], t=timestep, **arg_c
276 | )[0]
277 | pos_i = self.model(
278 | [torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latents, ref_latents)], t=timestep, **arg_null
279 | )[0]
280 | neg = self.model(
281 | [torch.cat([latent[:,:-ref_latent_neg.shape[1]], ref_latent_neg], dim=1) for latent, ref_latent_neg in zip(latents, ref_latents_neg)], t=timestep, **arg_null
282 | )[0]
283 |
284 | noise_pred = neg + guide_scale_img * (pos_i - neg) + guide_scale_text * (pos_it - pos_i)
285 |
286 | temp_x0 = sample_scheduler.step(
287 | noise_pred.unsqueeze(0),
288 | t,
289 | latents[0].unsqueeze(0),
290 | return_dict=False,
291 | generator=seed_g)[0]
292 | latents = [temp_x0.squeeze(0)]
293 |
294 | x0 = latents
295 | x0 = [x0_[:,:-ref_latents[0].shape[1]] for x0_ in x0]
296 |
297 | if offload_model:
298 | self.model.cpu()
299 | if self.rank == 0:
300 | videos = self.vae.decode(x0)
301 |
302 | del noise, latents
303 | del sample_scheduler
304 | if offload_model:
305 | gc.collect()
306 | torch.cuda.synchronize()
307 | if dist.is_initialized():
308 | dist.barrier()
309 |
310 | return videos[0] if self.rank == 0 else None
311 |
--------------------------------------------------------------------------------
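A toy sketch (random tensors, illustrative latent shape) of the two-scale classifier-free guidance used in `Phantom_Wan_S2V.generate`: three denoiser passes, text plus reference latents (`pos_it`), reference latents only (`pos_i`), and fully unconditional (`neg`), are blended with separate image and text guidance scales.

import torch

guide_scale_img, guide_scale_text = 5.0, 7.5
pos_it = torch.randn(16, 24, 60, 104)   # conditioned on text + reference latents
pos_i = torch.randn_like(pos_it)        # conditioned on reference latents only
neg = torch.randn_like(pos_it)          # both conditions dropped
noise_pred = (neg
              + guide_scale_img * (pos_i - neg)
              + guide_scale_text * (pos_it - pos_i))
print(tuple(noise_pred.shape))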
/phantom_wan/text2video.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import gc
3 | import logging
4 | import math
5 | import os
6 | import random
7 | import sys
8 | import types
9 | from contextlib import contextmanager
10 | from functools import partial
11 |
12 | import torch
13 | import torch.cuda.amp as amp
14 | import torch.distributed as dist
15 | from tqdm import tqdm
16 |
17 | from .distributed.fsdp import shard_model
18 | from .modules.model import WanModel
19 | from .modules.t5 import T5EncoderModel
20 | from .modules.vae import WanVAE
21 | from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
22 | get_sampling_sigmas, retrieve_timesteps)
23 | from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
24 |
25 |
26 | class WanT2V:
27 |
28 | def __init__(
29 | self,
30 | config,
31 | checkpoint_dir,
32 | device_id=0,
33 | rank=0,
34 | t5_fsdp=False,
35 | dit_fsdp=False,
36 | use_usp=False,
37 | t5_cpu=False,
38 | ):
39 | r"""
40 | Initializes the Wan text-to-video generation model components.
41 |
42 | Args:
43 | config (EasyDict):
44 | Object containing model parameters initialized from config.py
45 | checkpoint_dir (`str`):
46 | Path to directory containing model checkpoints
47 | device_id (`int`, *optional*, defaults to 0):
48 | Id of target GPU device
49 | rank (`int`, *optional*, defaults to 0):
50 | Process rank for distributed training
51 | t5_fsdp (`bool`, *optional*, defaults to False):
52 | Enable FSDP sharding for T5 model
53 | dit_fsdp (`bool`, *optional*, defaults to False):
54 | Enable FSDP sharding for DiT model
55 | use_usp (`bool`, *optional*, defaults to False):
56 | Enable distribution strategy of USP.
57 | t5_cpu (`bool`, *optional*, defaults to False):
58 | Whether to place T5 model on CPU. Only works without t5_fsdp.
59 | """
60 | self.device = torch.device(f"cuda:{device_id}")
61 | self.config = config
62 | self.rank = rank
63 | self.t5_cpu = t5_cpu
64 |
65 | self.num_train_timesteps = config.num_train_timesteps
66 | self.param_dtype = config.param_dtype
67 |
68 | shard_fn = partial(shard_model, device_id=device_id)
69 | self.text_encoder = T5EncoderModel(
70 | text_len=config.text_len,
71 | dtype=config.t5_dtype,
72 | device=torch.device('cpu'),
73 | checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
74 | tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
75 | shard_fn=shard_fn if t5_fsdp else None)
76 |
77 | self.vae_stride = config.vae_stride
78 | self.patch_size = config.patch_size
79 | self.vae = WanVAE(
80 | vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
81 | device=self.device)
82 |
83 | logging.info(f"Creating WanModel from {checkpoint_dir}")
84 | # self.model = WanModel()
85 | # logging.info(f"loading ckpt.")
86 | # state = torch.load("/mnt/bn/matianxiang-data-hl2/vfm-prod/downloads/Wan2.1-T2V-14B-dit.pth", map_location=self.device)
87 | # logging.info(f"loading state dict.")
88 | # self.model.load_state_dict(state, strict=False)
89 | self.model = WanModel.from_pretrained(checkpoint_dir)
90 | self.model.eval().requires_grad_(False)
91 |
92 | if use_usp:
93 | from xfuser.core.distributed import \
94 | get_sequence_parallel_world_size
95 |
96 | from .distributed.xdit_context_parallel import (usp_attn_forward,
97 | usp_dit_forward)
98 | for block in self.model.blocks:
99 | block.self_attn.forward = types.MethodType(
100 | usp_attn_forward, block.self_attn)
101 | self.model.forward = types.MethodType(usp_dit_forward, self.model)
102 | self.sp_size = get_sequence_parallel_world_size()
103 | else:
104 | self.sp_size = 1
105 |
106 | if dist.is_initialized():
107 | dist.barrier()
108 | if dit_fsdp:
109 | self.model = shard_fn(self.model)
110 | else:
111 | self.model.to(self.device)
112 |
113 | self.sample_neg_prompt = config.sample_neg_prompt
114 |
115 | def generate(self,
116 | input_prompt,
117 | size=(1280, 720),
118 | frame_num=81,
119 | shift=5.0,
120 | sample_solver='unipc',
121 | sampling_steps=50,
122 | guide_scale=5.0,
123 | n_prompt="",
124 | seed=-1,
125 | offload_model=True):
126 | r"""
127 | Generates video frames from a text prompt using a diffusion process.
128 |
129 | Args:
130 | input_prompt (`str`):
131 | Text prompt for content generation
132 | size (tuple[`int`], *optional*, defaults to (1280,720)):
133 | Controls video resolution, (width,height).
134 | frame_num (`int`, *optional*, defaults to 81):
135 | How many frames to sample from a video. The number should be 4n+1
136 | shift (`float`, *optional*, defaults to 5.0):
137 | Noise schedule shift parameter. Affects temporal dynamics
138 | sample_solver (`str`, *optional*, defaults to 'unipc'):
139 | Solver used to sample the video.
140 | sampling_steps (`int`, *optional*, defaults to 50):
141 | Number of diffusion sampling steps. Higher values improve quality but slow generation
142 | guide_scale (`float`, *optional*, defaults to 5.0):
143 | Classifier-free guidance scale. Controls prompt adherence vs. creativity
144 | n_prompt (`str`, *optional*, defaults to ""):
145 | Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
146 | seed (`int`, *optional*, defaults to -1):
147 | Random seed for noise generation. If -1, use random seed.
148 | offload_model (`bool`, *optional*, defaults to True):
149 | If True, offloads models to CPU during generation to save VRAM
150 |
151 | Returns:
152 | torch.Tensor:
153 | Generated video frames tensor. Dimensions: (C, N, H, W) where:
154 | - C: Color channels (3 for RGB)
155 | - N: Number of frames (frame_num)
156 | - H: Frame height (from size)
157 | - W: Frame width (from size)
158 | """
159 | # preprocess
160 | F = frame_num
161 | target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
162 | size[1] // self.vae_stride[1],
163 | size[0] // self.vae_stride[2])
164 |
165 | seq_len = math.ceil((target_shape[2] * target_shape[3]) /
166 | (self.patch_size[1] * self.patch_size[2]) *
167 | target_shape[1] / self.sp_size) * self.sp_size
168 |
169 | if n_prompt == "":
170 | n_prompt = self.sample_neg_prompt
171 | seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
172 | seed_g = torch.Generator(device=self.device)
173 | seed_g.manual_seed(seed)
174 |
175 | if not self.t5_cpu:
176 | self.text_encoder.model.to(self.device)
177 | context = self.text_encoder([input_prompt], self.device)
178 | context_null = self.text_encoder([n_prompt], self.device)
179 | if offload_model:
180 | self.text_encoder.model.cpu()
181 | else:
182 | context = self.text_encoder([input_prompt], torch.device('cpu'))
183 | context_null = self.text_encoder([n_prompt], torch.device('cpu'))
184 | context = [t.to(self.device) for t in context]
185 | context_null = [t.to(self.device) for t in context_null]
186 |
187 | noise = [
188 | torch.randn(
189 | target_shape[0],
190 | target_shape[1],
191 | target_shape[2],
192 | target_shape[3],
193 | dtype=torch.float32,
194 | device=self.device,
195 | generator=seed_g)
196 | ]
197 |
198 | @contextmanager
199 | def noop_no_sync():
200 | yield
201 |
202 | no_sync = getattr(self.model, 'no_sync', noop_no_sync)
203 |
204 | # evaluation mode
205 | with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
206 |
207 | if sample_solver == 'unipc':
208 | sample_scheduler = FlowUniPCMultistepScheduler(
209 | num_train_timesteps=self.num_train_timesteps,
210 | shift=1,
211 | use_dynamic_shifting=False)
212 | sample_scheduler.set_timesteps(
213 | sampling_steps, device=self.device, shift=shift)
214 | timesteps = sample_scheduler.timesteps
215 | elif sample_solver == 'dpm++':
216 | sample_scheduler = FlowDPMSolverMultistepScheduler(
217 | num_train_timesteps=self.num_train_timesteps,
218 | shift=1,
219 | use_dynamic_shifting=False)
220 | sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
221 | timesteps, _ = retrieve_timesteps(
222 | sample_scheduler,
223 | device=self.device,
224 | sigmas=sampling_sigmas)
225 | else:
226 | raise NotImplementedError("Unsupported solver.")
227 |
228 | # sample videos
229 | latents = noise
230 |
231 | arg_c = {'context': context, 'seq_len': seq_len}
232 | arg_null = {'context': context_null, 'seq_len': seq_len}
233 |
234 | for _, t in enumerate(tqdm(timesteps)):
235 | latent_model_input = latents
236 | timestep = [t]
237 |
238 | timestep = torch.stack(timestep)
239 |
240 | self.model.to(self.device)
241 | noise_pred_cond = self.model(
242 | latent_model_input, t=timestep, **arg_c)[0]
243 | noise_pred_uncond = self.model(
244 | latent_model_input, t=timestep, **arg_null)[0]
245 |
246 | noise_pred = noise_pred_uncond + guide_scale * (
247 | noise_pred_cond - noise_pred_uncond)
248 |
249 | temp_x0 = sample_scheduler.step(
250 | noise_pred.unsqueeze(0),
251 | t,
252 | latents[0].unsqueeze(0),
253 | return_dict=False,
254 | generator=seed_g)[0]
255 | latents = [temp_x0.squeeze(0)]
256 |
257 | x0 = latents
258 | if offload_model:
259 | self.model.cpu()
260 | if self.rank == 0:
261 | videos = self.vae.decode(x0)
262 |
263 | del noise, latents
264 | del sample_scheduler
265 | if offload_model:
266 | gc.collect()
267 | torch.cuda.synchronize()
268 | if dist.is_initialized():
269 | dist.barrier()
270 |
271 | return videos[0] if self.rank == 0 else None
272 |
--------------------------------------------------------------------------------
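
The latent-shape arithmetic at the top of `generate` is easy to misread, so the short sketch below works it out for the default 1280x720, 81-frame setting. The VAE stride (4, 8, 8), patch size (1, 2, 2) and 16 latent channels are the values used by the Wan 2.1 T2V configs and are assumptions here; substitute your own config values if they differ.

    import math

    # Assumed Wan 2.1 T2V config values (see phantom_wan/configs); adjust if yours differ.
    vae_stride = (4, 8, 8)   # temporal, height, width compression of the VAE
    patch_size = (1, 2, 2)   # DiT patchification (t, h, w)
    z_dim = 16               # latent channels
    sp_size = 1              # sequence-parallel world size (1 without USP)

    size, frame_num = (1280, 720), 81   # (width, height); frame_num must be 4n + 1

    # Mirrors the target_shape / seq_len computation in generate().
    target_shape = (z_dim,
                    (frame_num - 1) // vae_stride[0] + 1,   # 21 latent frames
                    size[1] // vae_stride[1],               # 90 latent rows
                    size[0] // vae_stride[2])               # 160 latent cols
    seq_len = math.ceil((target_shape[2] * target_shape[3]) /
                        (patch_size[1] * patch_size[2]) *
                        target_shape[1] / sp_size) * sp_size
    print(target_shape, seq_len)   # (16, 21, 90, 160) 75600

The denoising loop above then applies standard classifier-free guidance, noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond), before each scheduler step.
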
/phantom_wan/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
2 | retrieve_timesteps)
3 | from .fm_solvers_unipc import FlowUniPCMultistepScheduler
4 |
5 | __all__ = [
6 | 'get_sampling_sigmas', 'retrieve_timesteps',
7 | 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
8 | ]
9 |
--------------------------------------------------------------------------------
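
For orientation, the sketch below mirrors how `generate` in text2video.py uses these exports: the 'unipc' branch applies the shift when the timestep grid is built, while the 'dpm++' branch folds it into the sigma schedule. The concrete numbers (1000 training timesteps, 50 steps, shift 5.0) are illustrative assumptions rather than values taken from a specific config.

    import torch
    from phantom_wan.utils import (FlowDPMSolverMultistepScheduler,
                                   FlowUniPCMultistepScheduler,
                                   get_sampling_sigmas, retrieve_timesteps)

    device = torch.device('cuda:0')
    num_train_timesteps, sampling_steps, shift = 1000, 50, 5.0  # assumed values

    # 'unipc': the shift is passed to set_timesteps.
    unipc = FlowUniPCMultistepScheduler(
        num_train_timesteps=num_train_timesteps, shift=1, use_dynamic_shifting=False)
    unipc.set_timesteps(sampling_steps, device=device, shift=shift)

    # 'dpm++': the shift goes into the sigma schedule instead.
    dpm = FlowDPMSolverMultistepScheduler(
        num_train_timesteps=num_train_timesteps, shift=1, use_dynamic_shifting=False)
    sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
    timesteps, _ = retrieve_timesteps(dpm, device=device, sigmas=sampling_sigmas)
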
/phantom_wan/utils/qwen_vl_utils.py:
--------------------------------------------------------------------------------
1 | # Copied from https://github.com/kq-chen/qwen-vl-utils
2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3 | from __future__ import annotations
4 |
5 | import base64
6 | import logging
7 | import math
8 | import os
9 | import sys
10 | import time
11 | import warnings
12 | from functools import lru_cache
13 | from io import BytesIO
14 |
15 | import requests
16 | import torch
17 | import torchvision
18 | from packaging import version
19 | from PIL import Image
20 | from torchvision import io, transforms
21 | from torchvision.transforms import InterpolationMode
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 | IMAGE_FACTOR = 28
26 | MIN_PIXELS = 4 * 28 * 28
27 | MAX_PIXELS = 16384 * 28 * 28
28 | MAX_RATIO = 200
29 |
30 | VIDEO_MIN_PIXELS = 128 * 28 * 28
31 | VIDEO_MAX_PIXELS = 768 * 28 * 28
32 | VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
33 | FRAME_FACTOR = 2
34 | FPS = 2.0
35 | FPS_MIN_FRAMES = 4
36 | FPS_MAX_FRAMES = 768
37 |
38 |
39 | def round_by_factor(number: int, factor: int) -> int:
40 | """Returns the closest integer to 'number' that is divisible by 'factor'."""
41 | return round(number / factor) * factor
42 |
43 |
44 | def ceil_by_factor(number: int, factor: int) -> int:
45 | """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
46 | return math.ceil(number / factor) * factor
47 |
48 |
49 | def floor_by_factor(number: int, factor: int) -> int:
50 | """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
51 | return math.floor(number / factor) * factor
52 |
53 |
54 | def smart_resize(height: int,
55 | width: int,
56 | factor: int = IMAGE_FACTOR,
57 | min_pixels: int = MIN_PIXELS,
58 | max_pixels: int = MAX_PIXELS) -> tuple[int, int]:
59 | """
60 | Rescales the image so that the following conditions are met:
61 |
62 | 1. Both dimensions (height and width) are divisible by 'factor'.
63 |
64 | 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
65 |
66 | 3. The aspect ratio of the image is maintained as closely as possible.
67 | """
68 | if max(height, width) / min(height, width) > MAX_RATIO:
69 | raise ValueError(
70 | f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
71 | )
72 | h_bar = max(factor, round_by_factor(height, factor))
73 | w_bar = max(factor, round_by_factor(width, factor))
74 | if h_bar * w_bar > max_pixels:
75 | beta = math.sqrt((height * width) / max_pixels)
76 | h_bar = floor_by_factor(height / beta, factor)
77 | w_bar = floor_by_factor(width / beta, factor)
78 | elif h_bar * w_bar < min_pixels:
79 | beta = math.sqrt(min_pixels / (height * width))
80 | h_bar = ceil_by_factor(height * beta, factor)
81 | w_bar = ceil_by_factor(width * beta, factor)
82 | return h_bar, w_bar
83 |
84 |
85 | def fetch_image(ele: dict[str, str | Image.Image],
86 | size_factor: int = IMAGE_FACTOR) -> Image.Image:
87 | if "image" in ele:
88 | image = ele["image"]
89 | else:
90 | image = ele["image_url"]
91 | image_obj = None
92 | if isinstance(image, Image.Image):
93 | image_obj = image
94 | elif image.startswith("http://") or image.startswith("https://"):
95 | image_obj = Image.open(requests.get(image, stream=True).raw)
96 | elif image.startswith("file://"):
97 | image_obj = Image.open(image[7:])
98 | elif image.startswith("data:image"):
99 | if "base64," in image:
100 | _, base64_data = image.split("base64,", 1)
101 | data = base64.b64decode(base64_data)
102 | image_obj = Image.open(BytesIO(data))
103 | else:
104 | image_obj = Image.open(image)
105 | if image_obj is None:
106 | raise ValueError(
107 | f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
108 | )
109 | image = image_obj.convert("RGB")
110 | ## resize
111 | if "resized_height" in ele and "resized_width" in ele:
112 | resized_height, resized_width = smart_resize(
113 | ele["resized_height"],
114 | ele["resized_width"],
115 | factor=size_factor,
116 | )
117 | else:
118 | width, height = image.size
119 | min_pixels = ele.get("min_pixels", MIN_PIXELS)
120 | max_pixels = ele.get("max_pixels", MAX_PIXELS)
121 | resized_height, resized_width = smart_resize(
122 | height,
123 | width,
124 | factor=size_factor,
125 | min_pixels=min_pixels,
126 | max_pixels=max_pixels,
127 | )
128 | image = image.resize((resized_width, resized_height))
129 |
130 | return image
131 |
132 |
133 | def smart_nframes(
134 | ele: dict,
135 | total_frames: int,
136 | video_fps: int | float,
137 | ) -> int:
138 | """calculate the number of frames for video used for model inputs.
139 |
140 | Args:
141 | ele (dict): a dict contains the configuration of video.
142 | support either `fps` or `nframes`:
143 | - nframes: the number of frames to extract for model inputs.
144 | - fps: the fps to extract frames for model inputs.
145 | - min_frames: the minimum number of frames of the video, only used when fps is provided.
146 | - max_frames: the maximum number of frames of the video, only used when fps is provided.
147 | total_frames (int): the original total number of frames of the video.
148 | video_fps (int | float): the original fps of the video.
149 |
150 | Raises:
151 | ValueError: nframes should be in the interval [FRAME_FACTOR, total_frames].
152 |
153 | Returns:
154 | int: the number of frames for video used for model inputs.
155 | """
156 | assert not ("fps" in ele and
157 | "nframes" in ele), "Only accept either `fps` or `nframes`"
158 | if "nframes" in ele:
159 | nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
160 | else:
161 | fps = ele.get("fps", FPS)
162 | min_frames = ceil_by_factor(
163 | ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
164 | max_frames = floor_by_factor(
165 | ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)),
166 | FRAME_FACTOR)
167 | nframes = total_frames / video_fps * fps
168 | nframes = min(max(nframes, min_frames), max_frames)
169 | nframes = round_by_factor(nframes, FRAME_FACTOR)
170 | if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
171 | raise ValueError(
172 | f"nframes should be in the interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}."
173 | )
174 | return nframes
175 |
176 |
177 | def _read_video_torchvision(ele: dict,) -> torch.Tensor:
178 | """read video using torchvision.io.read_video
179 |
180 | Args:
181 | ele (dict): a dict contains the configuration of video.
182 | support keys:
183 | - video: the path of video. support "file://", "http://", "https://" and local path.
184 | - video_start: the start time of video.
185 | - video_end: the end time of video.
186 | Returns:
187 | torch.Tensor: the video tensor with shape (T, C, H, W).
188 | """
189 | video_path = ele["video"]
190 | if version.parse(torchvision.__version__) < version.parse("0.19.0"):
191 | if "http://" in video_path or "https://" in video_path:
192 | warnings.warn(
193 | "torchvision < 0.19.0 does not support http/https video paths, please upgrade to 0.19.0."
194 | )
195 | if "file://" in video_path:
196 | video_path = video_path[7:]
197 | st = time.time()
198 | video, audio, info = io.read_video(
199 | video_path,
200 | start_pts=ele.get("video_start", 0.0),
201 | end_pts=ele.get("video_end", None),
202 | pts_unit="sec",
203 | output_format="TCHW",
204 | )
205 | total_frames, video_fps = video.size(0), info["video_fps"]
206 | logger.info(
207 | f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s"
208 | )
209 | nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
210 | idx = torch.linspace(0, total_frames - 1, nframes).round().long()
211 | video = video[idx]
212 | return video
213 |
214 |
215 | def is_decord_available() -> bool:
216 | import importlib.util
217 |
218 | return importlib.util.find_spec("decord") is not None
219 |
220 |
221 | def _read_video_decord(ele: dict,) -> torch.Tensor:
222 | """read video using decord.VideoReader
223 |
224 | Args:
225 | ele (dict): a dict contains the configuration of video.
226 | support keys:
227 | - video: the path of video. support "file://", "http://", "https://" and local path.
228 | - video_start: the start time of video.
229 | - video_end: the end time of video.
230 | Returns:
231 | torch.Tensor: the video tensor with shape (T, C, H, W).
232 | """
233 | import decord
234 | video_path = ele["video"]
235 | st = time.time()
236 | vr = decord.VideoReader(video_path)
237 | # TODO: support start_pts and end_pts
238 | if 'video_start' in ele or 'video_end' in ele:
239 | raise NotImplementedError(
240 | "video_start and video_end are not supported by the decord backend yet.")
241 | total_frames, video_fps = len(vr), vr.get_avg_fps()
242 | logger.info(
243 | f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s"
244 | )
245 | nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
246 | idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
247 | video = vr.get_batch(idx).asnumpy()
248 | video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
249 | return video
250 |
251 |
252 | VIDEO_READER_BACKENDS = {
253 | "decord": _read_video_decord,
254 | "torchvision": _read_video_torchvision,
255 | }
256 |
257 | FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
258 |
259 |
260 | @lru_cache(maxsize=1)
261 | def get_video_reader_backend() -> str:
262 | if FORCE_QWENVL_VIDEO_READER is not None:
263 | video_reader_backend = FORCE_QWENVL_VIDEO_READER
264 | elif is_decord_available():
265 | video_reader_backend = "decord"
266 | else:
267 | video_reader_backend = "torchvision"
268 | print(
269 | f"qwen-vl-utils using {video_reader_backend} to read video.",
270 | file=sys.stderr)
271 | return video_reader_backend
272 |
273 |
274 | def fetch_video(
275 | ele: dict,
276 | image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]:
277 | if isinstance(ele["video"], str):
278 | video_reader_backend = get_video_reader_backend()
279 | video = VIDEO_READER_BACKENDS[video_reader_backend](ele)
280 | nframes, _, height, width = video.shape
281 |
282 | min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
283 | total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
284 | max_pixels = max(
285 | min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
286 | int(min_pixels * 1.05))
287 | max_pixels = ele.get("max_pixels", max_pixels)
288 | if "resized_height" in ele and "resized_width" in ele:
289 | resized_height, resized_width = smart_resize(
290 | ele["resized_height"],
291 | ele["resized_width"],
292 | factor=image_factor,
293 | )
294 | else:
295 | resized_height, resized_width = smart_resize(
296 | height,
297 | width,
298 | factor=image_factor,
299 | min_pixels=min_pixels,
300 | max_pixels=max_pixels,
301 | )
302 | video = transforms.functional.resize(
303 | video,
304 | [resized_height, resized_width],
305 | interpolation=InterpolationMode.BICUBIC,
306 | antialias=True,
307 | ).float()
308 | return video
309 | else:
310 | assert isinstance(ele["video"], (list, tuple))
311 | process_info = ele.copy()
312 | process_info.pop("type", None)
313 | process_info.pop("video", None)
314 | images = [
315 | fetch_image({
316 | "image": video_element,
317 | **process_info
318 | },
319 | size_factor=image_factor)
320 | for video_element in ele["video"]
321 | ]
322 | nframes = ceil_by_factor(len(images), FRAME_FACTOR)
323 | if len(images) < nframes:
324 | images.extend([images[-1]] * (nframes - len(images)))
325 | return images
326 |
327 |
328 | def extract_vision_info(
329 | conversations: list[dict] | list[list[dict]]) -> list[dict]:
330 | vision_infos = []
331 | if isinstance(conversations[0], dict):
332 | conversations = [conversations]
333 | for conversation in conversations:
334 | for message in conversation:
335 | if isinstance(message["content"], list):
336 | for ele in message["content"]:
337 | if ("image" in ele or "image_url" in ele or
338 | "video" in ele or
339 | ele["type"] in ("image", "image_url", "video")):
340 | vision_infos.append(ele)
341 | return vision_infos
342 |
343 |
344 | def process_vision_info(
345 | conversations: list[dict] | list[list[dict]],
346 | ) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] |
347 | None]:
348 | vision_infos = extract_vision_info(conversations)
349 | ## Read images or videos
350 | image_inputs = []
351 | video_inputs = []
352 | for vision_info in vision_infos:
353 | if "image" in vision_info or "image_url" in vision_info:
354 | image_inputs.append(fetch_image(vision_info))
355 | elif "video" in vision_info:
356 | video_inputs.append(fetch_video(vision_info))
357 | else:
358 | raise ValueError("image, image_url or video should be in content.")
359 | if len(image_inputs) == 0:
360 | image_inputs = None
361 | if len(video_inputs) == 0:
362 | video_inputs = None
363 | return image_inputs, video_inputs
364 |
--------------------------------------------------------------------------------
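
A worked example makes the resizing rule concrete: smart_resize rounds both sides to the nearest multiple of IMAGE_FACTOR (28) and rescales only if the resulting pixel count falls outside [MIN_PIXELS, MAX_PIXELS]; process_vision_info then walks a chat-style conversation and fetches its media. The video path below is a placeholder; the image path points at one of the bundled examples.

    from phantom_wan.utils.qwen_vl_utils import smart_resize, process_vision_info

    # 1920x1080 already fits the pixel budget, so only the rounding to multiples of 28 applies.
    print(smart_resize(1080, 1920))  # (1092, 1932), i.e. 39*28 and 69*28

    conversation = [{
        "role": "user",
        "content": [
            {"type": "image", "image": "examples/ref1.png"},
            {"type": "video", "video": "/path/to/clip.mp4", "fps": 2.0},  # placeholder path
            {"type": "text", "text": "Describe the subject."},
        ],
    }]
    image_inputs, video_inputs = process_vision_info(conversation)
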
/phantom_wan/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2 | import argparse
3 | import binascii
4 | import os
5 | import os.path as osp
6 |
7 | import imageio
8 | import torch
9 | import torchvision
10 |
11 | __all__ = ['cache_video', 'cache_image', 'str2bool']
12 |
13 |
14 | def rand_name(length=8, suffix=''):
15 | name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
16 | if suffix:
17 | if not suffix.startswith('.'):
18 | suffix = '.' + suffix
19 | name += suffix
20 | return name
21 |
22 |
23 | def cache_video(tensor,
24 | save_file=None,
25 | fps=30,
26 | suffix='.mp4',
27 | nrow=8,
28 | normalize=True,
29 | value_range=(-1, 1),
30 | retry=5):
31 | # cache file
32 | cache_file = osp.join('/tmp', rand_name(
33 | suffix=suffix)) if save_file is None else save_file
34 |
35 | # save to cache
36 | error = None
37 | for _ in range(retry):
38 | try:
39 | # preprocess
40 | tensor = tensor.clamp(min(value_range), max(value_range))
41 | tensor = torch.stack([
42 | torchvision.utils.make_grid(
43 | u, nrow=nrow, normalize=normalize, value_range=value_range)
44 | for u in tensor.unbind(2)
45 | ],
46 | dim=1).permute(1, 2, 3, 0)
47 | tensor = (tensor * 255).type(torch.uint8).cpu()
48 |
49 | # write video
50 | writer = imageio.get_writer(
51 | cache_file, fps=fps, codec='libx264', quality=8)
52 | for frame in tensor.numpy():
53 | writer.append_data(frame)
54 | writer.close()
55 | return cache_file
56 | except Exception as e:
57 | error = e
58 | continue
59 | else:
60 | print(f'cache_video failed, error: {error}', flush=True)
61 | return None
62 |
63 |
64 | def cache_image(tensor,
65 | save_file,
66 | nrow=8,
67 | normalize=True,
68 | value_range=(-1, 1),
69 | retry=5):
70 | # cache file
71 | suffix = osp.splitext(save_file)[1]
72 | if suffix.lower() not in [
73 | '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp'
74 | ]:
75 | suffix = '.png'
76 |
77 | # save to cache
78 | error = None
79 | for _ in range(retry):
80 | try:
81 | tensor = tensor.clamp(min(value_range), max(value_range))
82 | torchvision.utils.save_image(
83 | tensor,
84 | save_file,
85 | nrow=nrow,
86 | normalize=normalize,
87 | value_range=value_range)
88 | return save_file
89 | except Exception as e:
90 | error = e
91 | continue
92 |
93 |
94 | def str2bool(v):
95 | """
96 | Convert a string to a boolean.
97 |
98 | Supported true values: 'yes', 'true', 't', 'y', '1'
99 | Supported false values: 'no', 'false', 'f', 'n', '0'
100 |
101 | Args:
102 | v (str): String to convert.
103 |
104 | Returns:
105 | bool: Converted boolean value.
106 |
107 | Raises:
108 | argparse.ArgumentTypeError: If the value cannot be converted to boolean.
109 | """
110 | if isinstance(v, bool):
111 | return v
112 | v_lower = v.lower()
113 | if v_lower in ('yes', 'true', 't', 'y', '1'):
114 | return True
115 | elif v_lower in ('no', 'false', 'f', 'n', '0'):
116 | return False
117 | else:
118 | raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
119 |
--------------------------------------------------------------------------------
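
A minimal usage sketch for str2bool and cache_video, assuming a (C, N, H, W) video tensor in [-1, 1] as produced by the generation pipelines; cache_video unbinds dim 2, so a leading batch dimension is added before saving. The fps value and output name are illustrative.

    import argparse
    import torch
    from phantom_wan.utils.utils import cache_video, str2bool

    # str2bool as an argparse type: accepts yes/no, true/false, t/f, y/n, 1/0.
    parser = argparse.ArgumentParser()
    parser.add_argument("--offload_model", type=str2bool, default=True)
    args = parser.parse_args(["--offload_model", "false"])   # args.offload_model == False

    # Fake (C, N, H, W) video in [-1, 1]; add a batch dim so unbind(2) iterates over frames.
    video = torch.rand(3, 81, 480, 832) * 2 - 1
    cache_video(video[None], save_file="sample.mp4", fps=16, value_range=(-1, 1))
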
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.4.0
2 | torchvision>=0.19.0
3 | opencv-python>=4.9.0.80
4 | diffusers>=0.31.0
5 | transformers>=4.49.0
6 | tokenizers>=0.20.3
7 | accelerate>=1.1.1
8 | tqdm
9 | imageio
10 | easydict
11 | ftfy
12 | dashscope
13 | imageio-ffmpeg
14 | flash_attn
15 | gradio>=5.0.0
16 | numpy>=1.23.5,<2
17 | xfuser>=0.4.1
--------------------------------------------------------------------------------