├── .gitattributes ├── .gitignore ├── README.md ├── assets ├── 16_comparison.gif ├── 1_comparison.gif ├── 3_comparison.gif ├── 4_comparison.gif ├── 7_comparison.gif ├── 8_comparison.gif ├── Cogview4 │ ├── cfg.png │ └── ours.png ├── HiDream │ ├── cat_cfg.png │ └── cat_ours.png ├── Qwen2.5 │ ├── output-origin.mp3 │ └── output-ours.mp3 ├── easycontrol │ ├── image.webp │ ├── image_CFG.webp │ └── image_CFG_zero_star.webp ├── flux │ ├── image_cfg.png │ ├── image_ours.png │ └── lora │ │ ├── image_cfg_ds.png │ │ └── image_ours_ds.png ├── hunyuan │ ├── 376559893_output_cfg.gif │ └── 376559893_output_ours.gif ├── repo_teaser.jpg ├── sd3 │ ├── output_cfg.png │ └── output_ours.png └── wan2.1 │ ├── 1270611998_base.gif │ ├── 1270611998_ours.gif │ ├── 1306980124_base.gif │ ├── 1306980124_ours.gif │ ├── 1322140014_base.gif │ ├── 1322140014_ours.gif │ ├── 158241056_base.gif │ ├── 158241056_ours.gif │ ├── I2V_CFG.gif │ ├── I2V_Ours.gif │ ├── i2v-14B_832_480_cfg_3549111921.gif │ ├── i2v-14B_832_480_ours_3549111921.gif │ └── i2v_input.JPG ├── demo.py ├── models ├── Cogview4 │ ├── infer.py │ └── pipeline.py ├── HiDream │ └── pipeline.py ├── Qwen2.5 │ ├── infer.py │ └── qw_model.py ├── easycontrol │ ├── infer.py │ └── src │ │ ├── __init__.py │ │ ├── layers_cache.py │ │ ├── lora_helper.py │ │ ├── pipeline.py │ │ └── transformer_flux.py ├── flux │ ├── Guidance_distilled.py │ ├── infer_lora.py │ └── pipeline.py ├── hunyuan │ ├── pipeline.py │ └── t2v.py ├── sd │ ├── infer.py │ └── sd3_pipeline.py └── wan │ ├── T2V_infer.py │ ├── image2video_cfg_zero_star.py │ └── wan_pipeline.py ├── requirements.txt └── tools └── convert_to_gif.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.gif filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # .* 2 | *.py[cod] 3 | # *.jpg 4 | *.jpeg 5 | # *.png 6 | # *.gif 7 | *.bmp 8 | *.mp4 9 | *.mov 10 | *.mkv 11 | *.log 12 | *.zip 13 | *.pt 14 | *.pth 15 | *.ckpt 16 | *.safetensors 17 | *.json 18 | # *.txt 19 | *.backup 20 | *.pkl 21 | *.html 22 | *.pdf 23 | *.whl 24 | cache 25 | __pycache__/ 26 | storage/ 27 | samples/ 28 | !.gitignore 29 | !requirements.txt 30 | .DS_Store 31 | *DS_Store 32 | 33 | generated_videos 34 | output 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CFG-Zero*: Improved Classifier-Free Guidance for Flow Matching Models 2 | 3 |
4 | 5 | 6 | Weichen Fan1, 7 | 8 | Amber Yijia Zheng2, 9 | 10 | Raymond A. Yeh2, 11 | 12 | Ziwei Liu1✉ 13 | 14 |
15 |
16 | S-Lab, Nanyang Technological University1      Department of Computer Science, Purdue University 2 17 |
Corresponding Author.
18 |
19 | 20 |

21 | 22 |
23 | Paper | 24 | Project Page | 25 | Demo | 26 | Demo for Ghibli style 27 |
---

🔥 [Huggingface demo for Ghibli style generation](https://huggingface.co/spaces/jamesliu1217/EasyControl_Ghibli) supported by [EasyControl](https://github.com/Xiaojiu-z/EasyControl).

⚡️ [Huggingface demo](https://huggingface.co/spaces/weepiess2383/CFG-Zero-Star) now supports text-to-image generation with SD3 and SD3.5.

💰 Bonus tip: You can even use pure zero-init (zeroing out the prediction of the first step) as a quick test: if it improves your flow-matching model a lot, it may indicate that the model has not converged yet.

**🧪 Usage Tip: Use both optimized-scale and zero-init together. Adjust the zero-init steps based on the total inference steps; 4% is generally a good starting point.**

## 🔥 Update and News
- [2025.4.14] [HiDream](https://github.com/HiDream-ai/HiDream-I1) is supported now!
- [2025.4.14] 🔥 Supported by [sdnext](https://github.com/vladmandic/sdnext/blob/dev/CHANGELOG.md#update-for-2025-04-13) now!
- [2025.4.6] 📙 Supported by [EasyControl](https://github.com/Xiaojiu-z/EasyControl) now!
- [2025.4.4] 🤗 Supported by [Diffusers](https://github.com/huggingface/diffusers) now!
- [2025.4.2] 🙌 Mentioned by [Wan2.1](https://github.com/Wan-Video/Wan2.1)!
- [2025.4.1] Qwen2.5-Omni is supported now!
- [2025.3.30] Hunyuan is officially supported now!
- [2025.3.29] Flux is officially supported now!
- [2025.3.29] Both Wan2.1-14B I2V & T2V are now supported!
- [2025.3.28] Wan2.1-14B T2V is now supported! (Note: The default setting has been updated to zero out 4% of total steps for this scenario.)
- [2025.3.27] 📙 Supported by [ComfyUI-KJNodes](https://github.com/kijai/ComfyUI-KJNodes) now!
- [2025.03.26] 📙 Supported by [Wan2.1GP](https://github.com/deepbeepmeep/Wan2GP) now!
- [2025.03.25] Paper|Demo|Code have been officially released.

## Community Works
If you find that CFG-Zero* helps improve your model, we'd love to hear about it!

Thanks to the following projects for supporting our method!
- [blissful-tuner](https://github.com/Sarania/blissful-tuner/tree/main)
- [SD.Next](https://github.com/vladmandic/sdnext)
- [EasyControl](https://huggingface.co/spaces/jamesliu1217/EasyControl_Ghibli)
- [ComfyUI-KJNodes](https://github.com/kijai/ComfyUI-KJNodes)
- [Wan2.1GP](https://github.com/deepbeepmeep/Wan2GP)
- [ComfyUI](https://github.com/comfyanonymous/ComfyUI) **Note that ComfyUI's implementation differs from ours.**

## 📑 Todo List
- Wan2.1
  - [x] 14B Text-to-Video
  - [x] 14B Image-to-Video
- Hunyuan
  - [x] Text-to-Video
- SD3/SD3.5
  - [x] Text-to-Image
- Flux
  - [x] Text-to-Image (Guidance-distilled version)
  - [x] LoRA
- CogView4
  - [x] Text-to-Image
- Qwen2.5-Omni
  - [x] Audio generation
- EasyControl
  - [x] Ghibli-Style Portrait Generation
- HiDream
  - [x] Text-to-Image pipeline

## :astonished: Gallery
94 | ▶ Click to expand comparison Images 95 | 96 | 97 | 98 | 99 | 100 |
101 |
102 | 103 | 104 |
105 | ▶ Click to expand comparison GIFs 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 |
121 | 122 |
## Installation

### 1. Create a conda environment and install PyTorch

Note: You may want to adjust the CUDA version [according to your driver version](https://docs.nvidia.com/deploy/cuda-compatibility/#default-to-minor-version).

```bash
conda create -n CFG_Zero_Star python=3.10
conda activate CFG_Zero_Star

# Install PyTorch according to your CUDA version
conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.4 -c pytorch -c nvidia
```

### 2. Install dependencies

```bash
pip install -r requirements.txt
apt install -y ffmpeg
```

## Local demo
Host a demo on your local machine.
~~~bash
python demo.py
~~~

## Inference
### 1. Wan2.1

#### a. Text-to-Video Generation
Simply run the following command to generate videos in the output folder. Note that the current version uses Wan-AI/Wan2.1-T2V-14B-Diffusers with the default settings.

**Note that zero-steps for Wan2.1 T2V is set to 1, i.e. the first 2 steps (4% of the total steps) are zeroed out.**

~~~bash
python models/wan/T2V_infer.py
~~~

All results shown below were generated using this script on an H100 80G GPU.
CFG | CFG-Zero*
180 | Prompt: "A cat walks on the grass, realistic"
181 | Seed: 1322140014 182 |
CFG | CFG-Zero*
196 | Prompt: "A dynamic interaction between the ocean and a large rock. The rock, with its rough texture and jagged edges, is partially submerged in the water, suggesting it is a natural feature of the coastline. The water around the rock is in motion, with white foam and waves crashing against the rock, indicating the force of the ocean's movement. The background is a vast expanse of the ocean, with small ripples and waves, suggesting a moderate sea state. The overall style of the scene is a realistic depiction of a natural landscape, with a focus on the interplay between the rock and the water."
197 | Seed: 1306980124 198 |
CFG | CFG-Zero*
212 | Prompt: "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about."
213 | Seed: 1270611998 214 |
CFG | CFG-Zero*
228 | Prompt: "The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from it’s tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds."
229 | Seed: 158241056 230 |
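For orientation, the call that `models/wan/T2V_infer.py` wraps looks roughly like the `WanPipeline` usage in `demo.py` from this repo. The sketch below is a minimal, assumption-based version: the 14B checkpoint id follows the note above, while the guidance scale, negative prompt omission, and output path are illustrative choices rather than the script's exact defaults.

~~~python
import torch
from diffusers import AutoencoderKLWan
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video
from models.wan.wan_pipeline import WanPipeline  # this repo's pipeline with CFG-Zero* support

model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16).to("cuda")
pipe.scheduler = UniPCMultistepScheduler(
    prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=8.0
)

frames = pipe(
    prompt="A cat walks on the grass, realistic",
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0,        # illustrative value
    use_cfg_zero_star=True,    # optimized scale st*
    use_zero_init=True,        # zero out the first steps
    zero_steps=1,              # zero-steps=1 -> first 2 steps are zeroed
).frames[0]

export_to_video(frames, "cat_cfg_zero_star.mp4", fps=16)
~~~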
234 | 235 | #### b. Image-to-Video Generation 236 | Follow Wan2.1 to clone the repo and finish the installation, then copy 'models/wan/image2video_cfg_zero_star.py' in this repo to the Wan2.1 repo (Wan2.1/wan). Modify 'Wan2.1/wan/__init__.py': replace 'from .image2video import WanI2V' with 'from .image2video_cfg_zero_star import WanI2V'. 237 | 238 | **Note: For I2V, zero_init_steps is set to 0 [2.5% zero out] by default to ensure stable generation. If you prefer more creative results, you can set it to 1 [5% zero out], though this may lead to instability in certain cases.** 239 | 240 | ~~~bash 241 | python generate.py --task i2v-14B --size 832*480 --ckpt_dir ./Wan2.1-I2V-14B-480P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --base_seed 0 242 | ~~~ 243 | 244 | All results shown below were generated using this script on an H100 80G GPU. 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 280 | 281 |
Input Image | CFG | CFG-Zero*
259 | Prompt: "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
260 | Seed: 0 261 |
Input Image | CFG | CFG-Zero*
276 | Prompt: "Summer beach vacation style. A white cat wearing sunglasses lounges confidently on a surfboard, gently bobbing with the ocean waves under the bright sun. The cat exudes a cool, laid-back attitude. After a moment, it casually reaches into a small bag, pulls out a cigarette, and lights it. A thin stream of smoke drifts into the salty breeze as the cat takes a slow drag, maintaining its nonchalant pose beneath the clear blue sky." 277 |
278 | Seed: 3549111921 279 |
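For reference, the one-line edit to `Wan2.1/wan/__init__.py` described in the I2V setup above looks like this:

~~~python
# Wan2.1/wan/__init__.py
# Replace the original import:
#   from .image2video import WanI2V
# with the CFG-Zero* variant copied from this repo:
from .image2video_cfg_zero_star import WanI2V
~~~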
282 | 283 | ### 2. Flux 284 | #### a. Text-to-Image Generation 285 | We used **black-forest-labs/FLUX.1-dev** for the following experiment. 286 | Since this model is guidance-distilled, we applied only zero-init from our CFG-Zero* method. 287 | All images below were generated with the same seed on an H100 80G GPU. 288 | 289 | ~~~bash 290 | python models/flux/Guidance_distilled.py 291 | ~~~ 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 307 | 308 |
CFG | CFG-Zero*
304 | Prompt: "a tiny astronaut hatching from an egg on the moon."
305 | Seed: 105297965 306 |
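Since FLUX.1-dev is guidance-distilled, there is no separate unconditional branch to rescale, so only the zero-init half of CFG-Zero* applies here. Below is a minimal sketch of that logic; the names `velocity_pred`, `step_index`, and `zero_init_steps` are illustrative and not tied to the repo's scripts.

~~~python
import torch

def apply_zero_init(velocity_pred: torch.Tensor, step_index: int, zero_init_steps: int = 0) -> torch.Tensor:
    """Zero out the predicted velocity for the first few sampling steps.

    With the convention used in this repo, zero_init_steps = 1 means that
    steps 0 and 1 (the first 2 steps) are zeroed out.
    """
    if step_index <= zero_init_steps:
        return torch.zeros_like(velocity_pred)
    return velocity_pred

# Example: inside a sampling loop you would call
# velocity = apply_zero_init(velocity, i, zero_init_steps=1)
~~~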
#### b. LoRA
We used **black-forest-labs/FLUX.1-dev** with different LoRAs for the following experiment.
Since this model is guidance-distilled, we applied only zero-init from our CFG-Zero* method.
All images below were generated with the same seed on an H100 80G GPU.

~~~bash
python models/flux/infer_lora.py
~~~
CFG | CFG-Zero*
330 | Prompt: "Death Stranding Style. A solitary figure in a futuristic suit with a large, intricate backpack stands on a grassy cliff, gazing at a vast, mist-covered landscape composed of rugged mountains and low valleys beneath a rainy, overcast sky. Raindrops streak softly through the air, and puddles glisten on the uneven ground. Above the horizon, an ethereal, upside-down rainbow arcs downward through the gray clouds — its surreal, inverted shape adding an otherworldly touch to the haunting scene. A soft glow from distant structures illuminates the depth of the valley, enhancing the mysterious atmosphere. The contrast between the rain-soaked greenery and jagged rocky terrain adds texture and detail, amplifying the sense of solitude, exploration, and the anticipation of unknown adventures beyond the horizon."
331 | Seed: 875187112
332 | Lora: https://civitai.com/models/46080/death-stranding 333 |
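The LoRA itself is loaded in the usual diffusers way; a rough sketch with stock diffusers is shown below. The local LoRA path, step count, and guidance value are placeholders, and note that plain `FluxPipeline` does not apply zero-init; `models/flux/infer_lora.py` uses this repo's own pipeline for that part.

~~~python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("path/to/death_stranding_lora.safetensors")  # placeholder path for the LoRA linked above

image = pipe(
    prompt="Death Stranding Style. A solitary figure in a futuristic suit ...",  # use the full prompt from the table above
    num_inference_steps=28,   # placeholder
    guidance_scale=3.5,       # placeholder (distilled guidance)
    generator=torch.Generator("cuda").manual_seed(875187112),
).images[0]
image.save("flux_lora.png")
~~~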
### 3. Hunyuan
We used **hunyuanvideo-community/HunyuanVideo** for the following experiment.
Since this model is guidance-distilled, we applied only zero-init from our CFG-Zero* method.
All videos below were generated with the same seed on an H100 80G GPU.

~~~bash
python models/hunyuan/t2v.py
~~~
CFG | CFG-Zero*
357 | Prompt: "In an ornate, historical hall, a massive tidal wave peaks and begins to crash. A man is surfing, cinematic film shot in 35mm. High quality, high defination."
358 | Seed: 376559893 359 |
362 | 363 | ### 4. SD3 364 | We used **stabilityai/stable-diffusion-3.5-large** for the following experiment. 365 | All images below were generated with the same seed on an H100 80G GPU. 366 | 367 | ~~~bash 368 | python models/sd/infer.py 369 | ~~~ 370 | 371 | 372 | 373 | 374 | 375 | 376 | 377 | 378 | 379 | 380 | 381 | 385 | 386 |
CFG | CFG-Zero*
382 | Prompt: "A capybara holding a sign that reads Hello World"
383 | Seed: 811677707 384 |
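A minimal sketch of what `models/sd/infer.py` does, mirroring the `StableDiffusion3Pipeline` call in `demo.py` from this repo; the guidance scale, step count, and `zero_steps` follow `demo.py`'s defaults and are assumptions about the actual script.

~~~python
import torch
from models.sd.sd3_pipeline import StableDiffusion3Pipeline  # this repo's pipeline with CFG-Zero* support

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    "A capybara holding a sign that reads Hello World",
    guidance_scale=4.0,       # demo.py default, assumed here
    num_inference_steps=28,   # demo.py default, assumed here
    use_cfg_zero_star=True,   # optimized scale st*
    use_zero_init=True,       # zero out the first step(s)
    zero_steps=0,             # zero out only step 0
).images[0]
image.save("sd35_cfg_zero_star.png")
~~~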
### 5. Qwen2.5-Omni
Install the dependencies for Qwen2.5-Omni:
~~~bash
pip install git+https://github.com/huggingface/transformers@f742a644ca32e65758c3adb36225aef1731bd2a8
pip install qwen-omni-utils[decord]
pip install flash-attn --no-build-isolation
~~~

Easy inference with CFG-Zero*:
~~~bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 python models/Qwen2.5/infer.py
~~~

The following audio clips were generated by the script:

| CFG | CFG-Zero⋆ |
|-----|-----------|
| 🔊 [Click to download](assets/Qwen2.5/output-origin.mp3) | 🔊 [Click to download](assets/Qwen2.5/output-ours.mp3) |


### 6. EasyControl
Ghibli-Style Portrait Generation. The number of zero-init steps is set to 1 by default; feel free to try other values.

~~~bash
python models/easycontrol/infer.py
~~~
Source Image | CFG | CFG-Zero*
427 | 428 | ### 7. Cogview4 429 | 430 | ~~~bash 431 | python models/Cogview4/infer.py 432 | ~~~ 433 | 434 | 435 | 436 | 437 | 438 | 439 | 440 | 441 | 442 | 443 | 444 | 448 | 449 |
CFG | CFG-Zero*
445 | Prompt: "A vibrant cherry red sports car sits proudly under the gleaming sun, its polished exterior smooth and flawless, casting a mirror-like reflection. The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, and a set of black, high-gloss racing rims that contrast starkly with the red. A subtle hint of chrome embellishes the grille and exhaust, while the tinted windows suggest a luxurious and private interior. The scene conveys a sense of speed and elegance, the car appearing as if it's about to burst into a sprint along a coastal road, with the ocean's azure waves crashing in the background."
446 | Seed: 42 447 |
### 8. HiDream
Git clone [HiDream](https://github.com/HiDream-ai/HiDream-I1) and replace [hidream_pipeline](https://github.com/HiDream-ai/HiDream-I1/blob/main/hi_diffusers/pipelines/hidream_image/pipeline_hidream_image.py) with our 'models/HiDream/pipeline.py'.

Then modify 'zero_steps' according to the total inference steps (a small helper for choosing it is sketched after the example below).

~~~bash
cd HiDream-I1
python ./inference.py --model_type full
~~~
CFG | CFG-Zero*
472 | Prompt: "A cat holding a sign that says \"Hi-Dreams.ai\"."
473 | Seed: 0 474 |
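Following the usage tip above (zero out roughly 4% of the total steps), `zero_steps` can be derived from the step count. A small helper sketch; the function name is ours, not part of the repo:

~~~python
def pick_zero_steps(num_inference_steps: int, fraction: float = 0.04) -> int:
    """Return the last step index to zero out so that ~`fraction` of all steps are zeroed.

    With the convention used in this repo, zero_steps = k zeroes out steps 0..k,
    so 50 steps with a 4% budget gives zero_steps = 1 (the first 2 steps).
    """
    num_zeroed = max(1, round(num_inference_steps * fraction))
    return num_zeroed - 1

print(pick_zero_steps(50))  # 1 -> steps 0 and 1 are zeroed
print(pick_zero_steps(28))  # 0 -> only the first step is zeroed
~~~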
## Easy Implementation
You can use this script to easily apply our method to any flow-matching-based model.
~~~python
import torch

def optimized_scale(positive_flat, negative_flat):
    # Dot product between the conditional and unconditional predictions
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)

    # Squared norm of the unconditional prediction
    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8

    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    st_star = dot_product / squared_norm
    return st_star

# Get the velocity predictions from your model
noise_pred_uncond, noise_pred_text = model(...)
batch_size = noise_pred_text.shape[0]
positive = noise_pred_text.view(batch_size, -1)
negative = noise_pred_uncond.view(batch_size, -1)

# Calculate the optimized scale
st_star = optimized_scale(positive, negative)

# Reshape for broadcasting
st_star = st_star.view(batch_size, *([1] * (len(noise_pred_text.shape) - 1)))

# Perform CFG-Zero* sampling
if sample_step == 0:
    # Zero init (optionally zero out the first few steps, e.g. ~4% of the total)
    noise_pred = noise_pred_uncond * 0.
else:
    # Optimized scale
    noise_pred = noise_pred_uncond * st_star + \
        guidance_scale * (noise_pred_text - noise_pred_uncond * st_star)
~~~

## BibTex
```
@misc{fan2025cfgzerostar,
      title={CFG-Zero*: Improved Classifier-Free Guidance for Flow Matching Models},
      author={Weichen Fan and Amber Yijia Zheng and Raymond A. Yeh and Ziwei Liu},
      year={2025},
      eprint={2503.18886},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2503.18886},
}
```

## 🔑 License

This code is licensed under Apache-2.0. The framework is fully open for academic research and also allows commercial usage.


## Disclaimer

We disclaim responsibility for user-generated content. The model was not trained to realistically represent people or events, so using it to generate such content is beyond its capabilities. Generating pornographic, violent, or gory content, or content that demeans or harms people or their environment, culture, or religion, is prohibited. Users are solely liable for their actions. The project contributors are not legally affiliated with, nor accountable for, users' behavior. Use the generative model responsibly, adhering to ethical and legal standards.
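As a quick end-to-end check of the snippet above, the following self-contained example applies `optimized_scale` and zero-init to dummy velocity tensors; the shapes, guidance scale, and step count are purely illustrative.

~~~python
import torch

def optimized_scale(positive_flat, negative_flat):
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
    return dot_product / squared_norm

batch_size, guidance_scale, zero_steps = 2, 5.0, 1
noise_pred_text = torch.randn(batch_size, 16, 32, 32)    # stand-in for the conditional velocity
noise_pred_uncond = torch.randn(batch_size, 16, 32, 32)  # stand-in for the unconditional velocity

for sample_step in range(4):
    st_star = optimized_scale(
        noise_pred_text.view(batch_size, -1),
        noise_pred_uncond.view(batch_size, -1),
    ).view(batch_size, 1, 1, 1)

    if sample_step <= zero_steps:
        # Zero init for the first steps
        noise_pred = torch.zeros_like(noise_pred_text)
    else:
        # CFG with the optimized scale st*
        noise_pred = noise_pred_uncond * st_star + guidance_scale * (
            noise_pred_text - noise_pred_uncond * st_star
        )
    print(sample_step, float(noise_pred.abs().mean()))
~~~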
534 | -------------------------------------------------------------------------------- /assets/16_comparison.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/16_comparison.gif -------------------------------------------------------------------------------- /assets/1_comparison.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/1_comparison.gif -------------------------------------------------------------------------------- /assets/3_comparison.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/3_comparison.gif -------------------------------------------------------------------------------- /assets/4_comparison.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/4_comparison.gif -------------------------------------------------------------------------------- /assets/7_comparison.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/7_comparison.gif -------------------------------------------------------------------------------- /assets/8_comparison.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/8_comparison.gif -------------------------------------------------------------------------------- /assets/Cogview4/cfg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/Cogview4/cfg.png -------------------------------------------------------------------------------- /assets/Cogview4/ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/Cogview4/ours.png -------------------------------------------------------------------------------- /assets/HiDream/cat_cfg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/HiDream/cat_cfg.png -------------------------------------------------------------------------------- /assets/HiDream/cat_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/HiDream/cat_ours.png -------------------------------------------------------------------------------- /assets/Qwen2.5/output-origin.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/Qwen2.5/output-origin.mp3 -------------------------------------------------------------------------------- 
/assets/Qwen2.5/output-ours.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/Qwen2.5/output-ours.mp3 -------------------------------------------------------------------------------- /assets/easycontrol/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/easycontrol/image.webp -------------------------------------------------------------------------------- /assets/easycontrol/image_CFG.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/easycontrol/image_CFG.webp -------------------------------------------------------------------------------- /assets/easycontrol/image_CFG_zero_star.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/easycontrol/image_CFG_zero_star.webp -------------------------------------------------------------------------------- /assets/flux/image_cfg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/flux/image_cfg.png -------------------------------------------------------------------------------- /assets/flux/image_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/flux/image_ours.png -------------------------------------------------------------------------------- /assets/flux/lora/image_cfg_ds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/flux/lora/image_cfg_ds.png -------------------------------------------------------------------------------- /assets/flux/lora/image_ours_ds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/flux/lora/image_ours_ds.png -------------------------------------------------------------------------------- /assets/hunyuan/376559893_output_cfg.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/hunyuan/376559893_output_cfg.gif -------------------------------------------------------------------------------- /assets/hunyuan/376559893_output_ours.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/hunyuan/376559893_output_ours.gif -------------------------------------------------------------------------------- /assets/repo_teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/repo_teaser.jpg 
-------------------------------------------------------------------------------- /assets/sd3/output_cfg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/sd3/output_cfg.png -------------------------------------------------------------------------------- /assets/sd3/output_ours.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/sd3/output_ours.png -------------------------------------------------------------------------------- /assets/wan2.1/1270611998_base.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/wan2.1/1270611998_base.gif -------------------------------------------------------------------------------- /assets/wan2.1/1270611998_ours.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/wan2.1/1270611998_ours.gif -------------------------------------------------------------------------------- /assets/wan2.1/1306980124_base.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/wan2.1/1306980124_base.gif -------------------------------------------------------------------------------- /assets/wan2.1/1306980124_ours.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/wan2.1/1306980124_ours.gif -------------------------------------------------------------------------------- /assets/wan2.1/1322140014_base.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/wan2.1/1322140014_base.gif -------------------------------------------------------------------------------- /assets/wan2.1/1322140014_ours.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/wan2.1/1322140014_ours.gif -------------------------------------------------------------------------------- /assets/wan2.1/158241056_base.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/wan2.1/158241056_base.gif -------------------------------------------------------------------------------- /assets/wan2.1/158241056_ours.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/wan2.1/158241056_ours.gif -------------------------------------------------------------------------------- /assets/wan2.1/I2V_CFG.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/wan2.1/I2V_CFG.gif -------------------------------------------------------------------------------- /assets/wan2.1/I2V_Ours.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/wan2.1/I2V_Ours.gif -------------------------------------------------------------------------------- /assets/wan2.1/i2v-14B_832_480_cfg_3549111921.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/wan2.1/i2v-14B_832_480_cfg_3549111921.gif -------------------------------------------------------------------------------- /assets/wan2.1/i2v-14B_832_480_ours_3549111921.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/wan2.1/i2v-14B_832_480_ours_3549111921.gif -------------------------------------------------------------------------------- /assets/wan2.1/i2v_input.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/assets/wan2.1/i2v_input.JPG -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from models.sd.sd3_pipeline import StableDiffusion3Pipeline 3 | import torch 4 | import random 5 | import numpy as np 6 | import os 7 | import gc 8 | import tempfile 9 | import imageio 10 | from diffusers import AutoencoderKLWan 11 | from models.wan.wan_pipeline import WanPipeline 12 | from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler 13 | from PIL import Image 14 | from diffusers.utils import export_to_video 15 | import spaces 16 | 17 | os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "3000" 18 | 19 | def set_seed(seed): 20 | random.seed(seed) 21 | os.environ['PYTHONHASHSEED'] = str(seed) 22 | np.random.seed(seed) 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed(seed) 25 | 26 | # Model paths 27 | model_paths = { 28 | "sd3.5": "stabilityai/stable-diffusion-3.5-large", 29 | "sd3": "stabilityai/stable-diffusion-3-medium-diffusers", 30 | "wan-t2v": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" 31 | } 32 | 33 | # Global variable for current model 34 | current_model = None 35 | 36 | # Folder to save video outputs 37 | OUTPUT_DIR = "generated_videos" 38 | os.makedirs(OUTPUT_DIR, exist_ok=True) 39 | 40 | def load_model(model_name): 41 | global current_model 42 | if current_model is not None: 43 | del current_model # Delete the old model 44 | torch.cuda.empty_cache() # Free GPU memory 45 | gc.collect() # Force garbage collection 46 | 47 | if "wan-t2v" in model_name: 48 | vae = AutoencoderKLWan.from_pretrained(model_paths[model_name], subfolder="vae", torch_dtype=torch.float32) 49 | scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=8.0) 50 | current_model = WanPipeline.from_pretrained(model_paths[model_name], vae=vae, torch_dtype=torch.bfloat16).to("cuda") 51 | current_model.scheduler = scheduler 52 | else: 53 | current_model = 
StableDiffusion3Pipeline.from_pretrained(model_paths[model_name], torch_dtype=torch.bfloat16).to("cuda") 54 | 55 | return current_model 56 | 57 | @spaces.GPU(duration=2000) 58 | def generate_content(prompt, model_name, guidance_scale=7.5, num_inference_steps=50, use_cfg_zero_star=True, use_zero_init=True, zero_steps=0, seed=None, compare_mode=False): 59 | model = load_model(model_name) 60 | if seed is None: 61 | seed = random.randint(0, 2**32 - 1) 62 | set_seed(seed) 63 | 64 | is_video_model = "wan-t2v" in model_name 65 | 66 | negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" 67 | 68 | if is_video_model: 69 | if True: 70 | set_seed(seed) 71 | video1_frames = model( 72 | prompt=prompt, 73 | negative_prompt=negative_prompt, 74 | height=480, 75 | width=832, 76 | num_frames=81, 77 | guidance_scale=guidance_scale, 78 | use_cfg_zero_star=True, 79 | use_zero_init=True, 80 | zero_steps=zero_steps 81 | ).frames[0] 82 | video1_path = os.path.join(OUTPUT_DIR, f"{seed}_CFG-Zero-Star.mp4") 83 | export_to_video(video1_frames, video1_path, fps=16) 84 | 85 | set_seed(seed) 86 | video2_frames = model( 87 | prompt=prompt, 88 | negative_prompt=negative_prompt, 89 | height=480, 90 | width=832, 91 | num_frames=81, 92 | guidance_scale=guidance_scale, 93 | use_cfg_zero_star=False, 94 | use_zero_init=False, 95 | zero_steps=0 96 | ).frames[0] 97 | video2_path = os.path.join(OUTPUT_DIR, f"{seed}_CFG.mp4") 98 | export_to_video(video2_frames, video2_path, fps=16) 99 | 100 | return None, None, video1_path, video2_path, seed 101 | 102 | if compare_mode: 103 | set_seed(seed) 104 | image1 = model( 105 | prompt, 106 | guidance_scale=guidance_scale, 107 | num_inference_steps=num_inference_steps, 108 | use_cfg_zero_star=True, 109 | use_zero_init=use_zero_init, 110 | zero_steps=zero_steps 111 | ).images[0] 112 | 113 | set_seed(seed) 114 | image2 = model( 115 | prompt, 116 | guidance_scale=guidance_scale, 117 | num_inference_steps=num_inference_steps, 118 | use_cfg_zero_star=False, 119 | use_zero_init=use_zero_init, 120 | zero_steps=zero_steps 121 | ).images[0] 122 | 123 | return image1, image2, None, None, seed 124 | else: 125 | image = model( 126 | prompt, 127 | guidance_scale=guidance_scale, 128 | num_inference_steps=num_inference_steps, 129 | use_cfg_zero_star=use_cfg_zero_star, 130 | use_zero_init=use_zero_init, 131 | zero_steps=zero_steps 132 | ).images[0] 133 | if use_cfg_zero_star: 134 | return image, None, None, None, seed 135 | else: 136 | return None, image, None, None, seed 137 | 138 | # Gradio UI 139 | demo = gr.Interface( 140 | fn=generate_content, 141 | inputs=[ 142 | gr.Textbox(value="A spooky haunted mansion on a hill silhouetted by a full moon.", label="Enter your prompt"), 143 | gr.Dropdown(choices=list(model_paths.keys()), label="Choose Model"), 144 | gr.Slider(1, 20, value=4.0, step=0.5, label="Guidance Scale"), 145 | gr.Slider(10, 100, value=28, step=5, label="Inference Steps"), 146 | gr.Checkbox(value=True, label="Use CFG Zero Star"), 147 | gr.Checkbox(value=True, label="Use Zero Init"), 148 | gr.Slider(0, 20, value=0, step=1, label="Zero out steps"), 149 | gr.Number(value=42, label="Seed (Leave blank for random)"), 150 | gr.Checkbox(value=True, 
label="Compare Mode") 151 | ], 152 | outputs=[ 153 | gr.Image(type="pil", label="CFG-Zero* Image"), 154 | gr.Image(type="pil", label="CFG Image"), 155 | gr.Video(label="CFG-Zero* Video"), 156 | gr.Video(label="CFG Video"), 157 | gr.Textbox(label="Used Seed") 158 | ], 159 | title="CFG-Zero*: Improved Classifier-Free Guidance for Flow Matching Models", 160 | ) 161 | 162 | demo.launch(server_name="127.0.0.1", server_port=7860) 163 | 164 | -------------------------------------------------------------------------------- /models/Cogview4/infer.py: -------------------------------------------------------------------------------- 1 | # from diffusers import CogView4Pipeline 2 | from pipeline import CogView4Pipeline 3 | import torch 4 | 5 | import numpy as np 6 | import random 7 | import os 8 | 9 | def set_seed(seed): 10 | random.seed(seed) 11 | os.environ['PYTHONHASHSEED'] = str(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | 16 | 17 | 18 | pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16) 19 | 20 | # Open it for reduce GPU memory usage 21 | pipe.enable_model_cpu_offload() 22 | pipe.vae.enable_slicing() 23 | pipe.vae.enable_tiling() 24 | 25 | prompt = "A vibrant cherry red sports car sits proudly under the gleaming sun, its polished exterior smooth and flawless, casting a mirror-like reflection. The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, and a set of black, high-gloss racing rims that contrast starkly with the red. A subtle hint of chrome embellishes the grille and exhaust, while the tinted windows suggest a luxurious and private interior. The scene conveys a sense of speed and elegance, the car appearing as if it's about to burst into a sprint along a coastal road, with the ocean's azure waves crashing in the background." 
26 | 27 | set_seed(42) 28 | image = pipe( 29 | prompt=prompt, 30 | guidance_scale=3.5, 31 | num_images_per_prompt=1, 32 | num_inference_steps=50, 33 | width=1024, 34 | height=1024, 35 | use_cfg_zero_star=True, 36 | use_zero_init=True, 37 | zero_steps=1 38 | ).images[0] 39 | 40 | image.save("cogview4_ours.png") 41 | 42 | set_seed(42) 43 | image = pipe( 44 | prompt=prompt, 45 | guidance_scale=3.5, 46 | num_images_per_prompt=1, 47 | num_inference_steps=50, 48 | width=1024, 49 | height=1024, 50 | use_cfg_zero_star=False, 51 | use_zero_init=False, 52 | ).images[0] 53 | 54 | image.save("cogview4_cfg.png") 55 | 56 | -------------------------------------------------------------------------------- /models/HiDream/pipeline.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Any, Callable, Dict, List, Optional, Union 3 | import math 4 | import einops 5 | import torch 6 | from transformers import ( 7 | CLIPTextModelWithProjection, 8 | CLIPTokenizer, 9 | T5EncoderModel, 10 | T5Tokenizer, 11 | LlamaForCausalLM, 12 | PreTrainedTokenizerFast 13 | ) 14 | 15 | from diffusers.image_processor import VaeImageProcessor 16 | from diffusers.loaders import FromSingleFileMixin 17 | from diffusers.models.autoencoders import AutoencoderKL 18 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 19 | from diffusers.utils import ( 20 | USE_PEFT_BACKEND, 21 | is_torch_xla_available, 22 | logging, 23 | ) 24 | from diffusers.utils.torch_utils import randn_tensor 25 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 26 | from .pipeline_output import HiDreamImagePipelineOutput 27 | from ...models.transformers.transformer_hidream_image import HiDreamImageTransformer2DModel 28 | from ...schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler 29 | 30 | @torch.cuda.amp.autocast(dtype=torch.float32) 31 | def optimized_scale(positive_flat, negative_flat): 32 | 33 | # Calculate dot production 34 | dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) 35 | 36 | # Squared norm of uncondition 37 | squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 38 | 39 | # st_star = v_cond^T * v_uncond / ||v_uncond||^2 40 | st_star = dot_product / squared_norm 41 | 42 | return st_star 43 | 44 | if is_torch_xla_available(): 45 | import torch_xla.core.xla_model as xm 46 | 47 | XLA_AVAILABLE = True 48 | else: 49 | XLA_AVAILABLE = False 50 | 51 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 52 | 53 | # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift 54 | def calculate_shift( 55 | image_seq_len, 56 | base_seq_len: int = 256, 57 | max_seq_len: int = 4096, 58 | base_shift: float = 0.5, 59 | max_shift: float = 1.15, 60 | ): 61 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len) 62 | b = base_shift - m * base_seq_len 63 | mu = image_seq_len * m + b 64 | return mu 65 | 66 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 67 | def retrieve_timesteps( 68 | scheduler, 69 | num_inference_steps: Optional[int] = None, 70 | device: Optional[Union[str, torch.device]] = None, 71 | timesteps: Optional[List[int]] = None, 72 | sigmas: Optional[List[float]] = None, 73 | **kwargs, 74 | ): 75 | r""" 76 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles 77 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 
78 | 79 | Args: 80 | scheduler (`SchedulerMixin`): 81 | The scheduler to get timesteps from. 82 | num_inference_steps (`int`): 83 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` 84 | must be `None`. 85 | device (`str` or `torch.device`, *optional*): 86 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 87 | timesteps (`List[int]`, *optional*): 88 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, 89 | `num_inference_steps` and `sigmas` must be `None`. 90 | sigmas (`List[float]`, *optional*): 91 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, 92 | `num_inference_steps` and `timesteps` must be `None`. 93 | 94 | Returns: 95 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the 96 | second element is the number of inference steps. 97 | """ 98 | if timesteps is not None and sigmas is not None: 99 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") 100 | if timesteps is not None: 101 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 102 | if not accepts_timesteps: 103 | raise ValueError( 104 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 105 | f" timestep schedules. Please check whether you are using the correct scheduler." 106 | ) 107 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) 108 | timesteps = scheduler.timesteps 109 | num_inference_steps = len(timesteps) 110 | elif sigmas is not None: 111 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) 112 | if not accept_sigmas: 113 | raise ValueError( 114 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" 115 | f" sigmas schedules. Please check whether you are using the correct scheduler." 
116 | ) 117 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) 118 | timesteps = scheduler.timesteps 119 | num_inference_steps = len(timesteps) 120 | else: 121 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) 122 | timesteps = scheduler.timesteps 123 | return timesteps, num_inference_steps 124 | 125 | class HiDreamImagePipeline(DiffusionPipeline, FromSingleFileMixin): 126 | model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->image_encoder->transformer->vae" 127 | _optional_components = ["image_encoder", "feature_extractor"] 128 | _callback_tensor_inputs = ["latents", "prompt_embeds"] 129 | 130 | def __init__( 131 | self, 132 | scheduler: FlowMatchEulerDiscreteScheduler, 133 | vae: AutoencoderKL, 134 | text_encoder: CLIPTextModelWithProjection, 135 | tokenizer: CLIPTokenizer, 136 | text_encoder_2: CLIPTextModelWithProjection, 137 | tokenizer_2: CLIPTokenizer, 138 | text_encoder_3: T5EncoderModel, 139 | tokenizer_3: T5Tokenizer, 140 | text_encoder_4: LlamaForCausalLM, 141 | tokenizer_4: PreTrainedTokenizerFast, 142 | ): 143 | super().__init__() 144 | 145 | self.register_modules( 146 | vae=vae, 147 | text_encoder=text_encoder, 148 | text_encoder_2=text_encoder_2, 149 | text_encoder_3=text_encoder_3, 150 | text_encoder_4=text_encoder_4, 151 | tokenizer=tokenizer, 152 | tokenizer_2=tokenizer_2, 153 | tokenizer_3=tokenizer_3, 154 | tokenizer_4=tokenizer_4, 155 | scheduler=scheduler, 156 | ) 157 | self.vae_scale_factor = ( 158 | 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 159 | ) 160 | # HiDreamImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible 161 | # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this 162 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) 163 | self.default_sample_size = 128 164 | self.tokenizer_4.pad_token = self.tokenizer_4.eos_token 165 | 166 | def _get_t5_prompt_embeds( 167 | self, 168 | prompt: Union[str, List[str]] = None, 169 | num_images_per_prompt: int = 1, 170 | max_sequence_length: int = 128, 171 | device: Optional[torch.device] = None, 172 | dtype: Optional[torch.dtype] = None, 173 | ): 174 | device = device or self._execution_device 175 | dtype = dtype or self.text_encoder_3.dtype 176 | 177 | prompt = [prompt] if isinstance(prompt, str) else prompt 178 | batch_size = len(prompt) 179 | 180 | text_inputs = self.tokenizer_3( 181 | prompt, 182 | padding="max_length", 183 | max_length=min(max_sequence_length, self.tokenizer_3.model_max_length), 184 | truncation=True, 185 | add_special_tokens=True, 186 | return_tensors="pt", 187 | ) 188 | text_input_ids = text_inputs.input_ids 189 | attention_mask = text_inputs.attention_mask 190 | untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids 191 | 192 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 193 | removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1]) 194 | logger.warning( 195 | "The following part of your input was truncated because `max_sequence_length` is set to " 196 | f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}" 197 | ) 198 | 199 | prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0] 200 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 201 | _, seq_len, _ = prompt_embeds.shape 202 | 203 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 204 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 205 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 206 | return prompt_embeds 207 | 208 | def _get_clip_prompt_embeds( 209 | self, 210 | tokenizer, 211 | text_encoder, 212 | prompt: Union[str, List[str]], 213 | num_images_per_prompt: int = 1, 214 | max_sequence_length: int = 128, 215 | device: Optional[torch.device] = None, 216 | dtype: Optional[torch.dtype] = None, 217 | ): 218 | device = device or self._execution_device 219 | dtype = dtype or text_encoder.dtype 220 | 221 | prompt = [prompt] if isinstance(prompt, str) else prompt 222 | batch_size = len(prompt) 223 | 224 | text_inputs = tokenizer( 225 | prompt, 226 | padding="max_length", 227 | max_length=min(max_sequence_length, 218), 228 | truncation=True, 229 | return_tensors="pt", 230 | ) 231 | 232 | text_input_ids = text_inputs.input_ids 233 | untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 234 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 235 | removed_text = tokenizer.batch_decode(untruncated_ids[:, 218 - 1 : -1]) 236 | logger.warning( 237 | "The following part of your input was truncated because CLIP can only handle sequences up to" 238 | f" {218} tokens: {removed_text}" 239 | ) 240 | prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) 241 | 242 | # Use pooled output of CLIPTextModel 243 | prompt_embeds = 
prompt_embeds[0] 244 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 245 | 246 | # duplicate text embeddings for each generation per prompt, using mps friendly method 247 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) 248 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) 249 | 250 | return prompt_embeds 251 | 252 | def _get_llama3_prompt_embeds( 253 | self, 254 | prompt: Union[str, List[str]] = None, 255 | num_images_per_prompt: int = 1, 256 | max_sequence_length: int = 128, 257 | device: Optional[torch.device] = None, 258 | dtype: Optional[torch.dtype] = None, 259 | ): 260 | device = device or self._execution_device 261 | dtype = dtype or self.text_encoder_4.dtype 262 | 263 | prompt = [prompt] if isinstance(prompt, str) else prompt 264 | batch_size = len(prompt) 265 | 266 | text_inputs = self.tokenizer_4( 267 | prompt, 268 | padding="max_length", 269 | max_length=min(max_sequence_length, self.tokenizer_4.model_max_length), 270 | truncation=True, 271 | add_special_tokens=True, 272 | return_tensors="pt", 273 | ) 274 | text_input_ids = text_inputs.input_ids 275 | attention_mask = text_inputs.attention_mask 276 | untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids 277 | 278 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 279 | removed_text = self.tokenizer_4.batch_decode(untruncated_ids[:, min(max_sequence_length, self.tokenizer_4.model_max_length) - 1 : -1]) 280 | logger.warning( 281 | "The following part of your input was truncated because `max_sequence_length` is set to " 282 | f" {min(max_sequence_length, self.tokenizer_4.model_max_length)} tokens: {removed_text}" 283 | ) 284 | 285 | outputs = self.text_encoder_4( 286 | text_input_ids.to(device), 287 | attention_mask=attention_mask.to(device), 288 | output_hidden_states=True, 289 | output_attentions=True 290 | ) 291 | 292 | prompt_embeds = outputs.hidden_states[1:] 293 | prompt_embeds = torch.stack(prompt_embeds, dim=0) 294 | _, _, seq_len, dim = prompt_embeds.shape 295 | 296 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 297 | prompt_embeds = prompt_embeds.repeat(1, 1, num_images_per_prompt, 1) 298 | prompt_embeds = prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim) 299 | return prompt_embeds 300 | 301 | def encode_prompt( 302 | self, 303 | prompt: Union[str, List[str]], 304 | prompt_2: Union[str, List[str]], 305 | prompt_3: Union[str, List[str]], 306 | prompt_4: Union[str, List[str]], 307 | device: Optional[torch.device] = None, 308 | dtype: Optional[torch.dtype] = None, 309 | num_images_per_prompt: int = 1, 310 | do_classifier_free_guidance: bool = True, 311 | negative_prompt: Optional[Union[str, List[str]]] = None, 312 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 313 | negative_prompt_3: Optional[Union[str, List[str]]] = None, 314 | negative_prompt_4: Optional[Union[str, List[str]]] = None, 315 | prompt_embeds: Optional[List[torch.FloatTensor]] = None, 316 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 317 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 318 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 319 | max_sequence_length: int = 128, 320 | lora_scale: Optional[float] = None, 321 | ): 322 | prompt = [prompt] if isinstance(prompt, str) else prompt 323 | if prompt is not None: 324 | batch_size = len(prompt) 325 | else: 326 
| batch_size = prompt_embeds.shape[0] 327 | 328 | prompt_embeds, pooled_prompt_embeds = self._encode_prompt( 329 | prompt = prompt, 330 | prompt_2 = prompt_2, 331 | prompt_3 = prompt_3, 332 | prompt_4 = prompt_4, 333 | device = device, 334 | dtype = dtype, 335 | num_images_per_prompt = num_images_per_prompt, 336 | prompt_embeds = prompt_embeds, 337 | pooled_prompt_embeds = pooled_prompt_embeds, 338 | max_sequence_length = max_sequence_length, 339 | ) 340 | 341 | if do_classifier_free_guidance and negative_prompt_embeds is None: 342 | negative_prompt = negative_prompt or "" 343 | negative_prompt_2 = negative_prompt_2 or negative_prompt 344 | negative_prompt_3 = negative_prompt_3 or negative_prompt 345 | negative_prompt_4 = negative_prompt_4 or negative_prompt 346 | 347 | # normalize str to list 348 | negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt 349 | negative_prompt_2 = ( 350 | batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 351 | ) 352 | negative_prompt_3 = ( 353 | batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 354 | ) 355 | negative_prompt_4 = ( 356 | batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4 357 | ) 358 | 359 | if prompt is not None and type(prompt) is not type(negative_prompt): 360 | raise TypeError( 361 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 362 | f" {type(prompt)}." 363 | ) 364 | elif batch_size != len(negative_prompt): 365 | raise ValueError( 366 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 367 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 368 | " the batch size of `prompt`." 
369 | ) 370 | 371 | negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt( 372 | prompt = negative_prompt, 373 | prompt_2 = negative_prompt_2, 374 | prompt_3 = negative_prompt_3, 375 | prompt_4 = negative_prompt_4, 376 | device = device, 377 | dtype = dtype, 378 | num_images_per_prompt = num_images_per_prompt, 379 | prompt_embeds = negative_prompt_embeds, 380 | pooled_prompt_embeds = negative_pooled_prompt_embeds, 381 | max_sequence_length = max_sequence_length, 382 | ) 383 | return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds 384 | 385 | def _encode_prompt( 386 | self, 387 | prompt: Union[str, List[str]], 388 | prompt_2: Union[str, List[str]], 389 | prompt_3: Union[str, List[str]], 390 | prompt_4: Union[str, List[str]], 391 | device: Optional[torch.device] = None, 392 | dtype: Optional[torch.dtype] = None, 393 | num_images_per_prompt: int = 1, 394 | prompt_embeds: Optional[List[torch.FloatTensor]] = None, 395 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 396 | max_sequence_length: int = 128, 397 | ): 398 | device = device or self._execution_device 399 | 400 | if prompt_embeds is None: 401 | prompt_2 = prompt_2 or prompt 402 | prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 403 | 404 | prompt_3 = prompt_3 or prompt 405 | prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 406 | 407 | prompt_4 = prompt_4 or prompt 408 | prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4 409 | 410 | pooled_prompt_embeds_1 = self._get_clip_prompt_embeds( 411 | self.tokenizer, 412 | self.text_encoder, 413 | prompt = prompt, 414 | num_images_per_prompt = num_images_per_prompt, 415 | max_sequence_length = max_sequence_length, 416 | device = device, 417 | dtype = dtype, 418 | ) 419 | 420 | pooled_prompt_embeds_2 = self._get_clip_prompt_embeds( 421 | self.tokenizer_2, 422 | self.text_encoder_2, 423 | prompt = prompt_2, 424 | num_images_per_prompt = num_images_per_prompt, 425 | max_sequence_length = max_sequence_length, 426 | device = device, 427 | dtype = dtype, 428 | ) 429 | 430 | pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1) 431 | 432 | t5_prompt_embeds = self._get_t5_prompt_embeds( 433 | prompt = prompt_3, 434 | num_images_per_prompt = num_images_per_prompt, 435 | max_sequence_length = max_sequence_length, 436 | device = device, 437 | dtype = dtype 438 | ) 439 | llama3_prompt_embeds = self._get_llama3_prompt_embeds( 440 | prompt = prompt_4, 441 | num_images_per_prompt = num_images_per_prompt, 442 | max_sequence_length = max_sequence_length, 443 | device = device, 444 | dtype = dtype 445 | ) 446 | prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds] 447 | 448 | return prompt_embeds, pooled_prompt_embeds 449 | 450 | def enable_vae_slicing(self): 451 | r""" 452 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to 453 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 454 | """ 455 | self.vae.enable_slicing() 456 | 457 | def disable_vae_slicing(self): 458 | r""" 459 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to 460 | computing decoding in one step. 461 | """ 462 | self.vae.disable_slicing() 463 | 464 | def enable_vae_tiling(self): 465 | r""" 466 | Enable tiled VAE decoding. 
When this option is enabled, the VAE will split the input tensor into tiles to 467 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow 468 | processing larger images. 469 | """ 470 | self.vae.enable_tiling() 471 | 472 | def disable_vae_tiling(self): 473 | r""" 474 | Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to 475 | computing decoding in one step. 476 | """ 477 | self.vae.disable_tiling() 478 | 479 | def prepare_latents( 480 | self, 481 | batch_size, 482 | num_channels_latents, 483 | height, 484 | width, 485 | dtype, 486 | device, 487 | generator, 488 | latents=None, 489 | ): 490 | # VAE applies 8x compression on images but we must also account for packing which requires 491 | # latent height and width to be divisible by 2. 492 | height = 2 * (int(height) // (self.vae_scale_factor * 2)) 493 | width = 2 * (int(width) // (self.vae_scale_factor * 2)) 494 | 495 | shape = (batch_size, num_channels_latents, height, width) 496 | 497 | if latents is None: 498 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 499 | else: 500 | if latents.shape != shape: 501 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") 502 | latents = latents.to(device) 503 | return latents 504 | 505 | @property 506 | def guidance_scale(self): 507 | return self._guidance_scale 508 | 509 | @property 510 | def do_classifier_free_guidance(self): 511 | return self._guidance_scale > 1 512 | 513 | @property 514 | def joint_attention_kwargs(self): 515 | return self._joint_attention_kwargs 516 | 517 | @property 518 | def num_timesteps(self): 519 | return self._num_timesteps 520 | 521 | @property 522 | def interrupt(self): 523 | return self._interrupt 524 | 525 | @torch.no_grad() 526 | def __call__( 527 | self, 528 | prompt: Union[str, List[str]] = None, 529 | prompt_2: Optional[Union[str, List[str]]] = None, 530 | prompt_3: Optional[Union[str, List[str]]] = None, 531 | prompt_4: Optional[Union[str, List[str]]] = None, 532 | height: Optional[int] = None, 533 | width: Optional[int] = None, 534 | num_inference_steps: int = 50, 535 | sigmas: Optional[List[float]] = None, 536 | guidance_scale: float = 5.0, 537 | negative_prompt: Optional[Union[str, List[str]]] = None, 538 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 539 | negative_prompt_3: Optional[Union[str, List[str]]] = None, 540 | negative_prompt_4: Optional[Union[str, List[str]]] = None, 541 | num_images_per_prompt: Optional[int] = 1, 542 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 543 | latents: Optional[torch.FloatTensor] = None, 544 | prompt_embeds: Optional[torch.FloatTensor] = None, 545 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 546 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 547 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 548 | output_type: Optional[str] = "pil", 549 | return_dict: bool = True, 550 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 551 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 552 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 553 | max_sequence_length: int = 128, 554 | use_cfg_zero_star: Optional[bool] = True, 555 | use_zero_init: Optional[bool] = True, 556 | zero_steps: Optional[int] = 1, 557 | ): 558 | height = height or self.default_sample_size * self.vae_scale_factor 559 | width = width or 
self.default_sample_size * self.vae_scale_factor 560 | 561 | division = self.vae_scale_factor * 2 562 | S_max = (self.default_sample_size * self.vae_scale_factor) ** 2 563 | scale = S_max / (width * height) 564 | scale = math.sqrt(scale) 565 | width, height = int(width * scale // division * division), int(height * scale // division * division) 566 | 567 | self._guidance_scale = guidance_scale 568 | self._joint_attention_kwargs = joint_attention_kwargs 569 | self._interrupt = False 570 | 571 | # 2. Define call parameters 572 | if prompt is not None and isinstance(prompt, str): 573 | batch_size = 1 574 | elif prompt is not None and isinstance(prompt, list): 575 | batch_size = len(prompt) 576 | else: 577 | batch_size = prompt_embeds.shape[0] 578 | 579 | device = self._execution_device 580 | 581 | lora_scale = ( 582 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None 583 | ) 584 | ( 585 | prompt_embeds, 586 | negative_prompt_embeds, 587 | pooled_prompt_embeds, 588 | negative_pooled_prompt_embeds, 589 | ) = self.encode_prompt( 590 | prompt=prompt, 591 | prompt_2=prompt_2, 592 | prompt_3=prompt_3, 593 | prompt_4=prompt_4, 594 | negative_prompt=negative_prompt, 595 | negative_prompt_2=negative_prompt_2, 596 | negative_prompt_3=negative_prompt_3, 597 | negative_prompt_4=negative_prompt_4, 598 | do_classifier_free_guidance=self.do_classifier_free_guidance, 599 | prompt_embeds=prompt_embeds, 600 | negative_prompt_embeds=negative_prompt_embeds, 601 | pooled_prompt_embeds=pooled_prompt_embeds, 602 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 603 | device=device, 604 | num_images_per_prompt=num_images_per_prompt, 605 | max_sequence_length=max_sequence_length, 606 | lora_scale=lora_scale, 607 | ) 608 | 609 | if self.do_classifier_free_guidance: 610 | prompt_embeds_arr = [] 611 | for n, p in zip(negative_prompt_embeds, prompt_embeds): 612 | if len(n.shape) == 3: 613 | prompt_embeds_arr.append(torch.cat([n, p], dim=0)) 614 | else: 615 | prompt_embeds_arr.append(torch.cat([n, p], dim=1)) 616 | prompt_embeds = prompt_embeds_arr 617 | pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) 618 | 619 | # 4. Prepare latent variables 620 | num_channels_latents = self.transformer.config.in_channels 621 | latents = self.prepare_latents( 622 | batch_size * num_images_per_prompt, 623 | num_channels_latents, 624 | height, 625 | width, 626 | pooled_prompt_embeds.dtype, 627 | device, 628 | generator, 629 | latents, 630 | ) 631 | 632 | if latents.shape[-2] != latents.shape[-1]: 633 | B, C, H, W = latents.shape 634 | pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size 635 | 636 | img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1) 637 | img_ids = torch.zeros(pH, pW, 3) 638 | img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None] 639 | img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :] 640 | img_ids = img_ids.reshape(pH * pW, -1) 641 | img_ids_pad = torch.zeros(self.transformer.max_seq, 3) 642 | img_ids_pad[:pH*pW, :] = img_ids 643 | 644 | img_sizes = img_sizes.unsqueeze(0).to(latents.device) 645 | img_ids = img_ids_pad.unsqueeze(0).to(latents.device) 646 | if self.do_classifier_free_guidance: 647 | img_sizes = img_sizes.repeat(2 * B, 1) 648 | img_ids = img_ids.repeat(2 * B, 1, 1) 649 | else: 650 | img_sizes = img_ids = None 651 | 652 | # 5. 
Prepare timesteps 653 | mu = calculate_shift(self.transformer.max_seq) 654 | scheduler_kwargs = {"mu": mu} 655 | if isinstance(self.scheduler, FlowUniPCMultistepScheduler): 656 | self.scheduler.set_timesteps(num_inference_steps, device=device, shift=math.exp(mu)) 657 | timesteps = self.scheduler.timesteps 658 | else: 659 | timesteps, num_inference_steps = retrieve_timesteps( 660 | self.scheduler, 661 | num_inference_steps, 662 | device, 663 | sigmas=sigmas, 664 | **scheduler_kwargs, 665 | ) 666 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 667 | self._num_timesteps = len(timesteps) 668 | 669 | # 6. Denoising loop 670 | with self.progress_bar(total=num_inference_steps) as progress_bar: 671 | for i, t in enumerate(timesteps): 672 | if self.interrupt: 673 | continue 674 | 675 | # expand the latents if we are doing classifier free guidance 676 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents 677 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 678 | timestep = t.expand(latent_model_input.shape[0]) 679 | 680 | if latent_model_input.shape[-2] != latent_model_input.shape[-1]: 681 | B, C, H, W = latent_model_input.shape 682 | patch_size = self.transformer.config.patch_size 683 | pH, pW = H // patch_size, W // patch_size 684 | out = torch.zeros( 685 | (B, C, self.transformer.max_seq, patch_size * patch_size), 686 | dtype=latent_model_input.dtype, 687 | device=latent_model_input.device 688 | ) 689 | latent_model_input = einops.rearrange(latent_model_input, 'B C (H p1) (W p2) -> B C (H W) (p1 p2)', p1=patch_size, p2=patch_size) 690 | out[:, :, 0:pH*pW] = latent_model_input 691 | latent_model_input = out 692 | 693 | noise_pred = self.transformer( 694 | hidden_states = latent_model_input, 695 | timesteps = timestep, 696 | encoder_hidden_states = prompt_embeds, 697 | pooled_embeds = pooled_prompt_embeds, 698 | img_sizes = img_sizes, 699 | img_ids = img_ids, 700 | return_dict = False, 701 | )[0] 702 | noise_pred = -noise_pred 703 | 704 | # perform guidance 705 | if self.do_classifier_free_guidance: 706 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 707 | 708 | 709 | if use_cfg_zero_star: 710 | positive_flat = noise_pred_text.view(batch_size, -1) 711 | negative_flat = noise_pred_uncond.view(batch_size, -1) 712 | 713 | alpha = optimized_scale(positive_flat,negative_flat) 714 | alpha = alpha.view(batch_size, *([1] * (len(noise_pred_text.shape) - 1))) 715 | alpha = alpha.to(positive_flat.dtype) 716 | 717 | if (i <= zero_steps) and use_zero_init: 718 | noise_pred = noise_pred_text*0. 719 | else: 720 | noise_pred = noise_pred_uncond * alpha + guidance_scale * (noise_pred_text - noise_pred_uncond * alpha) 721 | else: 722 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) 723 | else: 724 | if (i <= zero_steps) and use_zero_init: 725 | noise_pred = noise_pred*0. 726 | 727 | # compute the previous noisy sample x_t -> x_t-1 728 | latents_dtype = latents.dtype 729 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 730 | 731 | if latents.dtype != latents_dtype: 732 | if torch.backends.mps.is_available(): 733 | # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 734 | latents = latents.to(latents_dtype) 735 | 736 | if callback_on_step_end is not None: 737 | callback_kwargs = {} 738 | for k in callback_on_step_end_tensor_inputs: 739 | callback_kwargs[k] = locals()[k] 740 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 741 | 742 | latents = callback_outputs.pop("latents", latents) 743 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 744 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 745 | 746 | # call the callback, if provided 747 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 748 | progress_bar.update() 749 | 750 | if XLA_AVAILABLE: 751 | xm.mark_step() 752 | 753 | if output_type == "latent": 754 | image = latents 755 | 756 | else: 757 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor 758 | 759 | image = self.vae.decode(latents, return_dict=False)[0] 760 | image = self.image_processor.postprocess(image, output_type=output_type) 761 | 762 | # Offload all models 763 | self.maybe_free_model_hooks() 764 | 765 | if not return_dict: 766 | return (image,) 767 | 768 | return HiDreamImagePipelineOutput(images=image) 769 | -------------------------------------------------------------------------------- /models/Qwen2.5/infer.py: -------------------------------------------------------------------------------- 1 | import soundfile as sf 2 | 3 | from transformers import Qwen2_5OmniProcessor 4 | from qw_model import Qwen2_5OmniModel 5 | from qwen_omni_utils import process_mm_info 6 | import torch 7 | import os 8 | 9 | # default: Load the model on the available device(s) 10 | # model = Qwen2_5OmniModel.from_pretrained("Qwen/Qwen2.5-Omni-7B", torch_dtype="auto", device_map="auto") 11 | 12 | # We recommend enabling flash_attention_2 for better acceleration and memory saving. 
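# Note: `attn_implementation="flash_attention_2"` assumes the `flash-attn` package is
# installed and that the model is loaded in fp16/bf16 on a GPU that supports
# FlashAttention-2; otherwise, fall back to the default loading shown above.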
13 | model = Qwen2_5OmniModel.from_pretrained( 14 | "Qwen/Qwen2.5-Omni-7B", 15 | torch_dtype="auto", 16 | device_map="auto", 17 | attn_implementation="flash_attention_2", 18 | ) 19 | 20 | processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B") 21 | 22 | conversation = [ 23 | { 24 | "role": "system", 25 | "content": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.", 26 | }, 27 | { 28 | "role": "user", 29 | "content": [ 30 | {"type": "video", "video": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-Omni/draw.mp4"}, 31 | ], 32 | }, 33 | ] 34 | 35 | # set use audio in video 36 | USE_AUDIO_IN_VIDEO = True 37 | 38 | # Preparation for inference 39 | text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False) 40 | audios, images, videos = process_mm_info(conversation, use_audio_in_video=USE_AUDIO_IN_VIDEO) 41 | inputs = processor(text=text, audios=audios, images=images, videos=videos, return_tensors="pt", padding=True, use_audio_in_video=USE_AUDIO_IN_VIDEO) 42 | inputs = inputs.to(model.device).to(model.dtype) 43 | 44 | # Inference: Generation of the output text and audio 45 | text_ids, audio = model.generate(**inputs, use_audio_in_video=USE_AUDIO_IN_VIDEO) 46 | 47 | text = processor.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) 48 | print(text) 49 | os.makedirs("output",exist_ok=True) 50 | sf.write( 51 | "output/output.wav", 52 | audio.reshape(-1).detach().cpu().numpy(), 53 | samplerate=24000, 54 | ) 55 | 56 | -------------------------------------------------------------------------------- /models/easycontrol/infer.py: -------------------------------------------------------------------------------- 1 | import spaces 2 | import os 3 | import json 4 | import time 5 | import torch 6 | from PIL import Image 7 | from tqdm import tqdm 8 | import gradio as gr 9 | 10 | from safetensors.torch import save_file 11 | from src.pipeline import FluxPipeline 12 | from src.transformer_flux import FluxTransformer2DModel 13 | from src.lora_helper import set_single_lora, set_multi_lora, unset_lora 14 | 15 | from huggingface_hub import hf_hub_download 16 | hf_hub_download(repo_id="Xiaojiu-Z/EasyControl", filename="models/Ghibli.safetensors", local_dir="./checkpoints/models/") 17 | 18 | # Initialize the image processor 19 | base_path = "black-forest-labs/FLUX.1-dev" 20 | lora_base_path = "checkpoints/models/models" 21 | 22 | pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16) 23 | transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16) 24 | pipe.transformer = transformer 25 | pipe.to("cuda") 26 | 27 | def clear_cache(transformer): 28 | for name, attn_processor in transformer.attn_processors.items(): 29 | attn_processor.bank_kv.clear() 30 | 31 | # Define the Gradio interface 32 | @spaces.GPU() 33 | def dual_condition_generate_image(prompt, spatial_img, height, width, seed, control_type, zero_steps): 34 | # Set the control type 35 | if control_type == "Ghibli": 36 | lora_path = os.path.join(lora_base_path, "Ghibli.safetensors") 37 | set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512) 38 | 39 | # Process the image 40 | spatial_imgs = [spatial_img] if spatial_img else [] 41 | 42 | # Image with use_zero_init=True 43 | image_true = pipe( 44 | prompt, 45 | height=int(height), 46 | width=int(width), 47 | 
guidance_scale=3.5, 48 | num_inference_steps=25, 49 | max_sequence_length=512, 50 | generator=torch.Generator("cpu").manual_seed(seed), 51 | subject_images=[], 52 | spatial_images=spatial_imgs, 53 | cond_size=512, 54 | use_zero_init=True, 55 | zero_steps=int(zero_steps) 56 | ).images[0] 57 | clear_cache(pipe.transformer) 58 | 59 | # Image with use_zero_init=False 60 | image_false = pipe( 61 | prompt, 62 | height=int(height), 63 | width=int(width), 64 | guidance_scale=3.5, 65 | num_inference_steps=25, 66 | max_sequence_length=512, 67 | generator=torch.Generator("cpu").manual_seed(seed), 68 | subject_images=[], 69 | spatial_images=spatial_imgs, 70 | cond_size=512, 71 | use_zero_init=False 72 | ).images[0] 73 | clear_cache(pipe.transformer) 74 | 75 | return image_true, image_false 76 | 77 | # Define the Gradio interface components 78 | control_types = ["Ghibli"] 79 | 80 | # Create the Gradio Blocks interface 81 | with gr.Blocks() as demo: 82 | gr.Markdown("# Ghibli Studio Control Image Generation with EasyControl") 83 | gr.Markdown("The model is trained on **only 100 real Asian faces** paired with **GPT-4o-generated Ghibli-style counterparts**, and it preserves facial features while applying the iconic anime aesthetic.") 84 | gr.Markdown("Generate images using EasyControl with Ghibli control LoRAs.(Due to hardware constraints, only low-resolution images can be generated. For high-resolution (1024+), please set up your own environment.)") 85 | 86 | gr.Markdown("**[Attention!!]**:The recommended prompts for using Ghibli Control LoRA should include the trigger words: Ghibli Studio style, Charming hand-drawn anime-style illustration") 87 | gr.Markdown("😊😊If you like this demo, please give us a star (github: [EasyControl](https://github.com/Xiaojiu-z/EasyControl))") 88 | 89 | with gr.Tab("Ghibli Condition Generation"): 90 | with gr.Row(): 91 | with gr.Column(): 92 | prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, Charming hand-drawn anime-style illustration") 93 | spatial_img = gr.Image(label="Ghibli Image", type="pil") 94 | height = gr.Slider(minimum=256, maximum=1024, step=64, label="Height", value=768) 95 | width = gr.Slider(minimum=256, maximum=1024, step=64, label="Width", value=768) 96 | seed = gr.Number(label="Seed", value=42) 97 | zero_steps = gr.Number(label="Zero Init Steps", value=1) 98 | control_type = gr.Dropdown(choices=control_types, label="Control Type") 99 | single_generate_btn = gr.Button("Generate Image") 100 | with gr.Column(): 101 | image_with_zero_init = gr.Image(label="Image CFG-Zero*") 102 | image_without_zero_init = gr.Image(label="Image CFG") 103 | 104 | # Link the buttons to the functions 105 | single_generate_btn.click( 106 | dual_condition_generate_image, 107 | inputs=[prompt, spatial_img, height, width, seed, control_type, zero_steps], 108 | outputs=[image_with_zero_init, image_without_zero_init] 109 | ) 110 | 111 | # Launch the Gradio app 112 | demo.queue().launch() 113 | 114 | -------------------------------------------------------------------------------- /models/easycontrol/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeichenFan/CFG-Zero-star/3162be1fba5dd0129ac8423ad6919d928f420a8d/models/easycontrol/src/__init__.py -------------------------------------------------------------------------------- /models/easycontrol/src/layers_cache.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import math 3 | from 
typing import Callable, List, Optional, Tuple, Union 4 | from einops import rearrange 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | from diffusers.models.attention_processor import Attention 10 | 11 | class LoRALinearLayer(nn.Module): 12 | def __init__( 13 | self, 14 | in_features: int, 15 | out_features: int, 16 | rank: int = 4, 17 | network_alpha: Optional[float] = None, 18 | device: Optional[Union[torch.device, str]] = None, 19 | dtype: Optional[torch.dtype] = None, 20 | cond_width=512, 21 | cond_height=512, 22 | number=0, 23 | n_loras=1 24 | ): 25 | super().__init__() 26 | self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) 27 | self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) 28 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 29 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning 30 | self.network_alpha = network_alpha 31 | self.rank = rank 32 | self.out_features = out_features 33 | self.in_features = in_features 34 | 35 | nn.init.normal_(self.down.weight, std=1 / rank) 36 | nn.init.zeros_(self.up.weight) 37 | 38 | self.cond_height = cond_height 39 | self.cond_width = cond_width 40 | self.number = number 41 | self.n_loras = n_loras 42 | 43 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 44 | orig_dtype = hidden_states.dtype 45 | dtype = self.down.weight.dtype 46 | 47 | #### 48 | batch_size = hidden_states.shape[0] 49 | cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64 50 | block_size = hidden_states.shape[1] - cond_size * self.n_loras 51 | shape = (batch_size, hidden_states.shape[1], 3072) 52 | mask = torch.ones(shape, device=hidden_states.device, dtype=dtype) 53 | mask[:, :block_size+self.number*cond_size, :] = 0 54 | mask[:, block_size+(self.number+1)*cond_size:, :] = 0 55 | hidden_states = mask * hidden_states 56 | #### 57 | 58 | down_hidden_states = self.down(hidden_states.to(dtype)) 59 | up_hidden_states = self.up(down_hidden_states) 60 | 61 | if self.network_alpha is not None: 62 | up_hidden_states *= self.network_alpha / self.rank 63 | 64 | return up_hidden_states.to(orig_dtype) 65 | 66 | 67 | class MultiSingleStreamBlockLoraProcessor(nn.Module): 68 | def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1): 69 | super().__init__() 70 | # Initialize a list to store the LoRA layers 71 | self.n_loras = n_loras 72 | self.cond_width = cond_width 73 | self.cond_height = cond_height 74 | 75 | self.q_loras = nn.ModuleList([ 76 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 77 | for i in range(n_loras) 78 | ]) 79 | self.k_loras = nn.ModuleList([ 80 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 81 | for i in range(n_loras) 82 | ]) 83 | self.v_loras = nn.ModuleList([ 84 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 85 | for i in range(n_loras) 86 | ]) 87 | self.lora_weights = lora_weights 88 | self.bank_attn = None 89 | self.bank_kv = [] 90 | 91 | 92 | def __call__(self, 93 | attn: Attention, 94 | 
hidden_states: torch.FloatTensor, 95 | encoder_hidden_states: torch.FloatTensor = None, 96 | attention_mask: Optional[torch.FloatTensor] = None, 97 | image_rotary_emb: Optional[torch.Tensor] = None, 98 | use_cond = False 99 | ) -> torch.FloatTensor: 100 | 101 | batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 102 | scaled_seq_len = hidden_states.shape[1] 103 | cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64 104 | block_size = scaled_seq_len - cond_size * self.n_loras 105 | scaled_cond_size = cond_size 106 | scaled_block_size = block_size 107 | 108 | if len(self.bank_kv)== 0: 109 | cache = True 110 | else: 111 | cache = False 112 | 113 | if cache: 114 | query = attn.to_q(hidden_states) 115 | key = attn.to_k(hidden_states) 116 | value = attn.to_v(hidden_states) 117 | for i in range(self.n_loras): 118 | query = query + self.lora_weights[i] * self.q_loras[i](hidden_states) 119 | key = key + self.lora_weights[i] * self.k_loras[i](hidden_states) 120 | value = value + self.lora_weights[i] * self.v_loras[i](hidden_states) 121 | 122 | inner_dim = key.shape[-1] 123 | head_dim = inner_dim // attn.heads 124 | 125 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 126 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 127 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 128 | 129 | self.bank_kv.append(key[:, :, scaled_block_size:, :]) 130 | self.bank_kv.append(value[:, :, scaled_block_size:, :]) 131 | 132 | if attn.norm_q is not None: 133 | query = attn.norm_q(query) 134 | if attn.norm_k is not None: 135 | key = attn.norm_k(key) 136 | 137 | if image_rotary_emb is not None: 138 | from diffusers.models.embeddings import apply_rotary_emb 139 | query = apply_rotary_emb(query, image_rotary_emb) 140 | key = apply_rotary_emb(key, image_rotary_emb) 141 | 142 | num_cond_blocks = self.n_loras 143 | mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device) 144 | mask[ :scaled_block_size, :] = 0 # First block_size row 145 | for i in range(num_cond_blocks): 146 | start = i * scaled_cond_size + scaled_block_size 147 | end = (i + 1) * scaled_cond_size + scaled_block_size 148 | mask[start:end, start:end] = 0 # Diagonal blocks 149 | mask = mask * -1e20 150 | mask = mask.to(query.dtype) 151 | 152 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask) 153 | self.bank_attn = hidden_states[:, :, scaled_block_size:, :] 154 | 155 | else: 156 | query = attn.to_q(hidden_states) 157 | key = attn.to_k(hidden_states) 158 | value = attn.to_v(hidden_states) 159 | 160 | inner_dim = query.shape[-1] 161 | head_dim = inner_dim // attn.heads 162 | 163 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 164 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 165 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 166 | 167 | key = torch.concat([key[:, :, :scaled_block_size, :], self.bank_kv[0]], dim=-2) 168 | value = torch.concat([value[:, :, :scaled_block_size, :], self.bank_kv[1]], dim=-2) 169 | 170 | if attn.norm_q is not None: 171 | query = attn.norm_q(query) 172 | if attn.norm_k is not None: 173 | key = attn.norm_k(key) 174 | 175 | if image_rotary_emb is not None: 176 | from diffusers.models.embeddings import apply_rotary_emb 177 | query = apply_rotary_emb(query, image_rotary_emb) 178 | key = apply_rotary_emb(key, image_rotary_emb) 179 | 
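# Cache-reuse path: the condition-image keys/values computed on the first call are stored
# in `self.bank_kv` and concatenated back in above, so attention here only needs queries
# for the non-condition tokens; the cached condition attention output (`self.bank_attn`)
# is re-appended after the scaled-dot-product attention below.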
180 | query = query[:, :, :scaled_block_size, :] 181 | 182 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None) 183 | hidden_states = torch.concat([hidden_states, self.bank_attn], dim=-2) 184 | 185 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 186 | hidden_states = hidden_states.to(query.dtype) 187 | 188 | cond_hidden_states = hidden_states[:, block_size:,:] 189 | hidden_states = hidden_states[:, : block_size,:] 190 | 191 | return hidden_states if not use_cond else (hidden_states, cond_hidden_states) 192 | 193 | 194 | class MultiDoubleStreamBlockLoraProcessor(nn.Module): 195 | def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1): 196 | super().__init__() 197 | 198 | # Initialize a list to store the LoRA layers 199 | self.n_loras = n_loras 200 | self.cond_width = cond_width 201 | self.cond_height = cond_height 202 | self.q_loras = nn.ModuleList([ 203 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 204 | for i in range(n_loras) 205 | ]) 206 | self.k_loras = nn.ModuleList([ 207 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 208 | for i in range(n_loras) 209 | ]) 210 | self.v_loras = nn.ModuleList([ 211 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 212 | for i in range(n_loras) 213 | ]) 214 | self.proj_loras = nn.ModuleList([ 215 | LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras) 216 | for i in range(n_loras) 217 | ]) 218 | self.lora_weights = lora_weights 219 | self.bank_attn = None 220 | self.bank_kv = [] 221 | 222 | 223 | def __call__(self, 224 | attn: Attention, 225 | hidden_states: torch.FloatTensor, 226 | encoder_hidden_states: torch.FloatTensor = None, 227 | attention_mask: Optional[torch.FloatTensor] = None, 228 | image_rotary_emb: Optional[torch.Tensor] = None, 229 | use_cond=False, 230 | ) -> torch.FloatTensor: 231 | 232 | batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 233 | cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64 234 | block_size = hidden_states.shape[1] - cond_size * self.n_loras 235 | scaled_seq_len = encoder_hidden_states.shape[1] + hidden_states.shape[1] 236 | scaled_cond_size = cond_size 237 | scaled_block_size = scaled_seq_len - scaled_cond_size * self.n_loras 238 | 239 | # `context` projections. 
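# The text (encoder) stream gets its own Q/K/V projections here and is concatenated with
# the image-stream projections before attention; the hidden size is hard-coded to 3072,
# matching the FLUX.1 defaults in transformer_flux.py (24 heads x 128 dims per head).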
240 | inner_dim = 3072 241 | head_dim = inner_dim // attn.heads 242 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) 243 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) 244 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) 245 | 246 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( 247 | batch_size, -1, attn.heads, head_dim 248 | ).transpose(1, 2) 249 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( 250 | batch_size, -1, attn.heads, head_dim 251 | ).transpose(1, 2) 252 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( 253 | batch_size, -1, attn.heads, head_dim 254 | ).transpose(1, 2) 255 | 256 | if attn.norm_added_q is not None: 257 | encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) 258 | if attn.norm_added_k is not None: 259 | encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) 260 | 261 | if len(self.bank_kv)== 0: 262 | cache = True 263 | else: 264 | cache = False 265 | 266 | if cache: 267 | 268 | query = attn.to_q(hidden_states) 269 | key = attn.to_k(hidden_states) 270 | value = attn.to_v(hidden_states) 271 | for i in range(self.n_loras): 272 | query = query + self.lora_weights[i] * self.q_loras[i](hidden_states) 273 | key = key + self.lora_weights[i] * self.k_loras[i](hidden_states) 274 | value = value + self.lora_weights[i] * self.v_loras[i](hidden_states) 275 | 276 | inner_dim = key.shape[-1] 277 | head_dim = inner_dim // attn.heads 278 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 279 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 280 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 281 | 282 | 283 | self.bank_kv.append(key[:, :, block_size:, :]) 284 | self.bank_kv.append(value[:, :, block_size:, :]) 285 | 286 | if attn.norm_q is not None: 287 | query = attn.norm_q(query) 288 | if attn.norm_k is not None: 289 | key = attn.norm_k(key) 290 | 291 | # attention 292 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) 293 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) 294 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) 295 | 296 | if image_rotary_emb is not None: 297 | from diffusers.models.embeddings import apply_rotary_emb 298 | query = apply_rotary_emb(query, image_rotary_emb) 299 | key = apply_rotary_emb(key, image_rotary_emb) 300 | 301 | num_cond_blocks = self.n_loras 302 | mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device) 303 | mask[ :scaled_block_size, :] = 0 # First block_size row 304 | for i in range(num_cond_blocks): 305 | start = i * scaled_cond_size + scaled_block_size 306 | end = (i + 1) * scaled_cond_size + scaled_block_size 307 | mask[start:end, start:end] = 0 # Diagonal blocks 308 | mask = mask * -1e20 309 | mask = mask.to(query.dtype) 310 | 311 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask) 312 | self.bank_attn = hidden_states[:, :, scaled_block_size:, :] 313 | 314 | else: 315 | query = attn.to_q(hidden_states) 316 | key = attn.to_k(hidden_states) 317 | value = attn.to_v(hidden_states) 318 | 319 | inner_dim = query.shape[-1] 320 | head_dim = inner_dim // attn.heads 321 | 322 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 323 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 
324 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 325 | 326 | key = torch.concat([key[:, :, :block_size, :], self.bank_kv[0]], dim=-2) 327 | value = torch.concat([value[:, :, :block_size, :], self.bank_kv[1]], dim=-2) 328 | 329 | if attn.norm_q is not None: 330 | query = attn.norm_q(query) 331 | if attn.norm_k is not None: 332 | key = attn.norm_k(key) 333 | 334 | # attention 335 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) 336 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) 337 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) 338 | 339 | if image_rotary_emb is not None: 340 | from diffusers.models.embeddings import apply_rotary_emb 341 | query = apply_rotary_emb(query, image_rotary_emb) 342 | key = apply_rotary_emb(key, image_rotary_emb) 343 | 344 | query = query[:, :, :scaled_block_size, :] 345 | 346 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None) 347 | hidden_states = torch.concat([hidden_states, self.bank_attn], dim=-2) 348 | 349 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 350 | hidden_states = hidden_states.to(query.dtype) 351 | 352 | encoder_hidden_states, hidden_states = ( 353 | hidden_states[:, : encoder_hidden_states.shape[1]], 354 | hidden_states[:, encoder_hidden_states.shape[1] :], 355 | ) 356 | 357 | # Linear projection (with LoRA weight applied to each proj layer) 358 | hidden_states = attn.to_out[0](hidden_states) 359 | for i in range(self.n_loras): 360 | hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states) 361 | # dropout 362 | hidden_states = attn.to_out[1](hidden_states) 363 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states) 364 | 365 | cond_hidden_states = hidden_states[:, block_size:,:] 366 | hidden_states = hidden_states[:, :block_size,:] 367 | 368 | return (hidden_states, encoder_hidden_states, cond_hidden_states) if use_cond else (encoder_hidden_states, hidden_states) -------------------------------------------------------------------------------- /models/easycontrol/src/lora_helper.py: -------------------------------------------------------------------------------- 1 | from diffusers.models.attention_processor import FluxAttnProcessor2_0 2 | from safetensors import safe_open 3 | import re 4 | import torch 5 | from .layers_cache import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor 6 | 7 | device = "cuda" 8 | 9 | def load_safetensors(path): 10 | tensors = {} 11 | with safe_open(path, framework="pt", device="cpu") as f: 12 | for key in f.keys(): 13 | tensors[key] = f.get_tensor(key) 14 | return tensors 15 | 16 | def get_lora_rank(checkpoint): 17 | for k in checkpoint.keys(): 18 | if k.endswith(".down.weight"): 19 | return checkpoint[k].shape[0] 20 | 21 | def load_checkpoint(local_path): 22 | if local_path is not None: 23 | if '.safetensors' in local_path: 24 | print(f"Loading .safetensors checkpoint from {local_path}") 25 | checkpoint = load_safetensors(local_path) 26 | else: 27 | print(f"Loading checkpoint from {local_path}") 28 | checkpoint = torch.load(local_path, map_location='cpu') 29 | return checkpoint 30 | 31 | def update_model_with_lora(checkpoint, lora_weights, transformer, cond_size): 32 | number = len(lora_weights) 33 | ranks = [get_lora_rank(checkpoint) for _ in range(number)] 34 | lora_attn_procs = {} 35 | double_blocks_idx = list(range(19)) 36 | single_blocks_idx = 
list(range(38)) 37 | for name, attn_processor in transformer.attn_processors.items(): 38 | match = re.search(r'\.(\d+)\.', name) 39 | if match: 40 | layer_index = int(match.group(1)) 41 | 42 | if name.startswith("transformer_blocks") and layer_index in double_blocks_idx: 43 | 44 | lora_state_dicts = {} 45 | for key, value in checkpoint.items(): 46 | # Match based on the layer index in the key (assuming the key contains layer index) 47 | if re.search(r'\.(\d+)\.', key): 48 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1)) 49 | if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"): 50 | lora_state_dicts[key] = value 51 | 52 | lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor( 53 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number 54 | ) 55 | 56 | # Load the weights from the checkpoint dictionary into the corresponding layers 57 | for n in range(number): 58 | lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None) 59 | lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None) 60 | lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None) 61 | lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None) 62 | lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None) 63 | lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None) 64 | lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None) 65 | lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None) 66 | lora_attn_procs[name].to(device) 67 | 68 | elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx: 69 | 70 | lora_state_dicts = {} 71 | for key, value in checkpoint.items(): 72 | # Match based on the layer index in the key (assuming the key contains layer index) 73 | if re.search(r'\.(\d+)\.', key): 74 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1)) 75 | if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"): 76 | lora_state_dicts[key] = value 77 | 78 | lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor( 79 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number 80 | ) 81 | # Load the weights from the checkpoint dictionary into the corresponding layers 82 | for n in range(number): 83 | lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None) 84 | lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None) 85 | lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None) 86 | lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None) 87 | lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None) 88 | lora_attn_procs[name].v_loras[n].up.weight.data = 
lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None) 89 | lora_attn_procs[name].to(device) 90 | else: 91 | lora_attn_procs[name] = FluxAttnProcessor2_0() 92 | 93 | transformer.set_attn_processor(lora_attn_procs) 94 | 95 | 96 | def update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size): 97 | ck_number = len(checkpoints) 98 | cond_lora_number = [len(ls) for ls in lora_weights] 99 | cond_number = sum(cond_lora_number) 100 | ranks = [get_lora_rank(checkpoint) for checkpoint in checkpoints] 101 | multi_lora_weight = [] 102 | for ls in lora_weights: 103 | for n in ls: 104 | multi_lora_weight.append(n) 105 | 106 | lora_attn_procs = {} 107 | double_blocks_idx = list(range(19)) 108 | single_blocks_idx = list(range(38)) 109 | for name, attn_processor in transformer.attn_processors.items(): 110 | match = re.search(r'\.(\d+)\.', name) 111 | if match: 112 | layer_index = int(match.group(1)) 113 | 114 | if name.startswith("transformer_blocks") and layer_index in double_blocks_idx: 115 | lora_state_dicts = [{} for _ in range(ck_number)] 116 | for idx, checkpoint in enumerate(checkpoints): 117 | for key, value in checkpoint.items(): 118 | # Match based on the layer index in the key (assuming the key contains layer index) 119 | if re.search(r'\.(\d+)\.', key): 120 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1)) 121 | if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"): 122 | lora_state_dicts[idx][key] = value 123 | 124 | lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor( 125 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number 126 | ) 127 | 128 | # Load the weights from the checkpoint dictionary into the corresponding layers 129 | num = 0 130 | for idx in range(ck_number): 131 | for n in range(cond_lora_number[idx]): 132 | lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None) 133 | lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None) 134 | lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None) 135 | lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None) 136 | lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None) 137 | lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None) 138 | lora_attn_procs[name].proj_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.down.weight', None) 139 | lora_attn_procs[name].proj_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.up.weight', None) 140 | lora_attn_procs[name].to(device) 141 | num += 1 142 | 143 | elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx: 144 | 145 | lora_state_dicts = [{} for _ in range(ck_number)] 146 | for idx, checkpoint in enumerate(checkpoints): 147 | for key, value in checkpoint.items(): 148 | # Match based on the layer index in the key (assuming the key contains layer index) 149 | if re.search(r'\.(\d+)\.', key): 150 | checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1)) 151 | if checkpoint_layer_index == layer_index and 
key.startswith("single_transformer_blocks"): 152 | lora_state_dicts[idx][key] = value 153 | 154 | lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor( 155 | dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number 156 | ) 157 | # Load the weights from the checkpoint dictionary into the corresponding layers 158 | num = 0 159 | for idx in range(ck_number): 160 | for n in range(cond_lora_number[idx]): 161 | lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None) 162 | lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None) 163 | lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None) 164 | lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None) 165 | lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None) 166 | lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None) 167 | lora_attn_procs[name].to(device) 168 | num += 1 169 | 170 | else: 171 | lora_attn_procs[name] = FluxAttnProcessor2_0() 172 | 173 | transformer.set_attn_processor(lora_attn_procs) 174 | 175 | 176 | def set_single_lora(transformer, local_path, lora_weights=[], cond_size=512): 177 | checkpoint = load_checkpoint(local_path) 178 | update_model_with_lora(checkpoint, lora_weights, transformer, cond_size) 179 | 180 | def set_multi_lora(transformer, local_paths, lora_weights=[[]], cond_size=512): 181 | checkpoints = [load_checkpoint(local_path) for local_path in local_paths] 182 | update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size) 183 | 184 | def unset_lora(transformer): 185 | lora_attn_procs = {} 186 | for name, attn_processor in transformer.attn_processors.items(): 187 | lora_attn_procs[name] = FluxAttnProcessor2_0() 188 | transformer.set_attn_processor(lora_attn_procs) 189 | 190 | 191 | ''' 192 | unset_lora(pipe.transformer) 193 | lora_path = "./lora.safetensors" 194 | lora_weights = [1, 1] 195 | set_lora(pipe.transformer, local_path=lora_path, lora_weights=lora_weights, cond_size=512) 196 | ''' -------------------------------------------------------------------------------- /models/easycontrol/src/transformer_flux.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from diffusers.configuration_utils import ConfigMixin, register_to_config 9 | from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin 10 | from diffusers.models.attention import FeedForward 11 | from diffusers.models.attention_processor import ( 12 | Attention, 13 | AttentionProcessor, 14 | FluxAttnProcessor2_0, 15 | FluxAttnProcessor2_0_NPU, 16 | FusedFluxAttnProcessor2_0, 17 | ) 18 | from diffusers.models.modeling_utils import ModelMixin 19 | from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle 20 | from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers 21 | from 
diffusers.utils.import_utils import is_torch_npu_available 22 | from diffusers.utils.torch_utils import maybe_allow_in_graph 23 | from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed 24 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 25 | 26 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 27 | 28 | @maybe_allow_in_graph 29 | class FluxSingleTransformerBlock(nn.Module): 30 | 31 | def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): 32 | super().__init__() 33 | self.mlp_hidden_dim = int(dim * mlp_ratio) 34 | 35 | self.norm = AdaLayerNormZeroSingle(dim) 36 | self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) 37 | self.act_mlp = nn.GELU(approximate="tanh") 38 | self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) 39 | 40 | if is_torch_npu_available(): 41 | processor = FluxAttnProcessor2_0_NPU() 42 | else: 43 | processor = FluxAttnProcessor2_0() 44 | self.attn = Attention( 45 | query_dim=dim, 46 | cross_attention_dim=None, 47 | dim_head=attention_head_dim, 48 | heads=num_attention_heads, 49 | out_dim=dim, 50 | bias=True, 51 | processor=processor, 52 | qk_norm="rms_norm", 53 | eps=1e-6, 54 | pre_only=True, 55 | ) 56 | 57 | def forward( 58 | self, 59 | hidden_states: torch.Tensor, 60 | cond_hidden_states: torch.Tensor, 61 | temb: torch.Tensor, 62 | cond_temb: torch.Tensor, 63 | image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 64 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 65 | ) -> torch.Tensor: 66 | use_cond = cond_hidden_states is not None 67 | 68 | residual = hidden_states 69 | norm_hidden_states, gate = self.norm(hidden_states, emb=temb) 70 | mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) 71 | 72 | if use_cond: 73 | residual_cond = cond_hidden_states 74 | norm_cond_hidden_states, cond_gate = self.norm(cond_hidden_states, emb=cond_temb) 75 | mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_cond_hidden_states)) 76 | 77 | norm_hidden_states_concat = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2) 78 | 79 | joint_attention_kwargs = joint_attention_kwargs or {} 80 | attn_output = self.attn( 81 | hidden_states=norm_hidden_states_concat, 82 | image_rotary_emb=image_rotary_emb, 83 | use_cond=use_cond, 84 | **joint_attention_kwargs, 85 | ) 86 | if use_cond: 87 | attn_output, cond_attn_output = attn_output 88 | 89 | hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) 90 | gate = gate.unsqueeze(1) 91 | hidden_states = gate * self.proj_out(hidden_states) 92 | hidden_states = residual + hidden_states 93 | 94 | if use_cond: 95 | condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2) 96 | cond_gate = cond_gate.unsqueeze(1) 97 | condition_latents = cond_gate * self.proj_out(condition_latents) 98 | condition_latents = residual_cond + condition_latents 99 | 100 | if hidden_states.dtype == torch.float16: 101 | hidden_states = hidden_states.clip(-65504, 65504) 102 | 103 | return hidden_states, condition_latents if use_cond else None 104 | 105 | 106 | @maybe_allow_in_graph 107 | class FluxTransformerBlock(nn.Module): 108 | def __init__( 109 | self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 110 | ): 111 | super().__init__() 112 | 113 | self.norm1 = AdaLayerNormZero(dim) 114 | 115 | self.norm1_context = AdaLayerNormZero(dim) 116 | 117 | if hasattr(F, 
"scaled_dot_product_attention"): 118 | processor = FluxAttnProcessor2_0() 119 | else: 120 | raise ValueError( 121 | "The current PyTorch version does not support the `scaled_dot_product_attention` function." 122 | ) 123 | self.attn = Attention( 124 | query_dim=dim, 125 | cross_attention_dim=None, 126 | added_kv_proj_dim=dim, 127 | dim_head=attention_head_dim, 128 | heads=num_attention_heads, 129 | out_dim=dim, 130 | context_pre_only=False, 131 | bias=True, 132 | processor=processor, 133 | qk_norm=qk_norm, 134 | eps=eps, 135 | ) 136 | 137 | self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 138 | self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") 139 | 140 | self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) 141 | self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") 142 | 143 | # let chunk size default to None 144 | self._chunk_size = None 145 | self._chunk_dim = 0 146 | 147 | def forward( 148 | self, 149 | hidden_states: torch.Tensor, 150 | cond_hidden_states: torch.Tensor, 151 | encoder_hidden_states: torch.Tensor, 152 | temb: torch.Tensor, 153 | cond_temb: torch.Tensor, 154 | image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 155 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 156 | ) -> Tuple[torch.Tensor, torch.Tensor]: 157 | use_cond = cond_hidden_states is not None 158 | 159 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) 160 | if use_cond: 161 | ( 162 | norm_cond_hidden_states, 163 | cond_gate_msa, 164 | cond_shift_mlp, 165 | cond_scale_mlp, 166 | cond_gate_mlp, 167 | ) = self.norm1(cond_hidden_states, emb=cond_temb) 168 | 169 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( 170 | encoder_hidden_states, emb=temb 171 | ) 172 | 173 | norm_hidden_states = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2) 174 | 175 | joint_attention_kwargs = joint_attention_kwargs or {} 176 | # Attention. 177 | attention_outputs = self.attn( 178 | hidden_states=norm_hidden_states, 179 | encoder_hidden_states=norm_encoder_hidden_states, 180 | image_rotary_emb=image_rotary_emb, 181 | use_cond=use_cond, 182 | **joint_attention_kwargs, 183 | ) 184 | 185 | attn_output, context_attn_output = attention_outputs[:2] 186 | cond_attn_output = attention_outputs[2] if use_cond else None 187 | 188 | # Process attention outputs for the `hidden_states`. 
189 | attn_output = gate_msa.unsqueeze(1) * attn_output 190 | hidden_states = hidden_states + attn_output 191 | 192 | if use_cond: 193 | cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output 194 | cond_hidden_states = cond_hidden_states + cond_attn_output 195 | 196 | norm_hidden_states = self.norm2(hidden_states) 197 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 198 | 199 | if use_cond: 200 | norm_cond_hidden_states = self.norm2(cond_hidden_states) 201 | norm_cond_hidden_states = ( 202 | norm_cond_hidden_states * (1 + cond_scale_mlp[:, None]) 203 | + cond_shift_mlp[:, None] 204 | ) 205 | 206 | ff_output = self.ff(norm_hidden_states) 207 | ff_output = gate_mlp.unsqueeze(1) * ff_output 208 | hidden_states = hidden_states + ff_output 209 | 210 | if use_cond: 211 | cond_ff_output = self.ff(norm_cond_hidden_states) 212 | cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output 213 | cond_hidden_states = cond_hidden_states + cond_ff_output 214 | 215 | # Process attention outputs for the `encoder_hidden_states`. 216 | 217 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output 218 | encoder_hidden_states = encoder_hidden_states + context_attn_output 219 | 220 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) 221 | norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] 222 | 223 | context_ff_output = self.ff_context(norm_encoder_hidden_states) 224 | encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output 225 | if encoder_hidden_states.dtype == torch.float16: 226 | encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) 227 | 228 | return encoder_hidden_states, hidden_states, cond_hidden_states if use_cond else None 229 | 230 | 231 | class FluxTransformer2DModel( 232 | ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin 233 | ): 234 | _supports_gradient_checkpointing = True 235 | _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] 236 | 237 | @register_to_config 238 | def __init__( 239 | self, 240 | patch_size: int = 1, 241 | in_channels: int = 64, 242 | out_channels: Optional[int] = None, 243 | num_layers: int = 19, 244 | num_single_layers: int = 38, 245 | attention_head_dim: int = 128, 246 | num_attention_heads: int = 24, 247 | joint_attention_dim: int = 4096, 248 | pooled_projection_dim: int = 768, 249 | guidance_embeds: bool = False, 250 | axes_dims_rope: Tuple[int] = (16, 56, 56), 251 | ): 252 | super().__init__() 253 | self.out_channels = out_channels or in_channels 254 | self.inner_dim = num_attention_heads * attention_head_dim 255 | 256 | self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) 257 | 258 | text_time_guidance_cls = ( 259 | CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings 260 | ) 261 | self.time_text_embed = text_time_guidance_cls( 262 | embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim 263 | ) 264 | 265 | self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) 266 | self.x_embedder = nn.Linear(in_channels, self.inner_dim) 267 | 268 | self.transformer_blocks = nn.ModuleList( 269 | [ 270 | FluxTransformerBlock( 271 | dim=self.inner_dim, 272 | num_attention_heads=num_attention_heads, 273 | attention_head_dim=attention_head_dim, 274 | ) 275 | for _ in range(num_layers) 276 | ] 277 | ) 278 | 279 | 
self.single_transformer_blocks = nn.ModuleList( 280 | [ 281 | FluxSingleTransformerBlock( 282 | dim=self.inner_dim, 283 | num_attention_heads=num_attention_heads, 284 | attention_head_dim=attention_head_dim, 285 | ) 286 | for _ in range(num_single_layers) 287 | ] 288 | ) 289 | 290 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) 291 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) 292 | 293 | self.gradient_checkpointing = False 294 | 295 | @property 296 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 297 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 298 | r""" 299 | Returns: 300 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 301 | indexed by its weight name. 302 | """ 303 | # set recursively 304 | processors = {} 305 | 306 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 307 | if hasattr(module, "get_processor"): 308 | processors[f"{name}.processor"] = module.get_processor() 309 | 310 | for sub_name, child in module.named_children(): 311 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 312 | 313 | return processors 314 | 315 | for name, module in self.named_children(): 316 | fn_recursive_add_processors(name, module, processors) 317 | 318 | return processors 319 | 320 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 321 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 322 | r""" 323 | Sets the attention processor to use to compute attention. 324 | 325 | Parameters: 326 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 327 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 328 | for **all** `Attention` layers. 329 | 330 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 331 | processor. This is strongly recommended when setting trainable attention processors. 332 | 333 | """ 334 | count = len(self.attn_processors.keys()) 335 | 336 | if isinstance(processor, dict) and len(processor) != count: 337 | raise ValueError( 338 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 339 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 340 | ) 341 | 342 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 343 | if hasattr(module, "set_processor"): 344 | if not isinstance(processor, dict): 345 | module.set_processor(processor) 346 | else: 347 | module.set_processor(processor.pop(f"{name}.processor")) 348 | 349 | for sub_name, child in module.named_children(): 350 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 351 | 352 | for name, module in self.named_children(): 353 | fn_recursive_attn_processor(name, module, processor) 354 | 355 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 356 | def fuse_qkv_projections(self): 357 | """ 358 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) 359 | are fused. 
For cross-attention modules, key and value projection matrices are fused. 360 | 361 | 362 | 363 | This API is 🧪 experimental. 364 | 365 | 366 | """ 367 | self.original_attn_processors = None 368 | 369 | for _, attn_processor in self.attn_processors.items(): 370 | if "Added" in str(attn_processor.__class__.__name__): 371 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") 372 | 373 | self.original_attn_processors = self.attn_processors 374 | 375 | for module in self.modules(): 376 | if isinstance(module, Attention): 377 | module.fuse_projections(fuse=True) 378 | 379 | self.set_attn_processor(FusedFluxAttnProcessor2_0()) 380 | 381 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 382 | def unfuse_qkv_projections(self): 383 | """Disables the fused QKV projection if enabled. 384 | 385 | 386 | 387 | This API is 🧪 experimental. 388 | 389 | 390 | 391 | """ 392 | if self.original_attn_processors is not None: 393 | self.set_attn_processor(self.original_attn_processors) 394 | 395 | def _set_gradient_checkpointing(self, module, value=False): 396 | if hasattr(module, "gradient_checkpointing"): 397 | module.gradient_checkpointing = value 398 | 399 | def forward( 400 | self, 401 | hidden_states: torch.Tensor, 402 | cond_hidden_states: torch.Tensor = None, 403 | encoder_hidden_states: torch.Tensor = None, 404 | pooled_projections: torch.Tensor = None, 405 | timestep: torch.LongTensor = None, 406 | img_ids: torch.Tensor = None, 407 | txt_ids: torch.Tensor = None, 408 | guidance: torch.Tensor = None, 409 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 410 | controlnet_block_samples=None, 411 | controlnet_single_block_samples=None, 412 | return_dict: bool = True, 413 | controlnet_blocks_repeat: bool = False, 414 | ) -> Union[torch.Tensor, Transformer2DModelOutput]: 415 | if cond_hidden_states is not None: 416 | use_condition = True 417 | else: 418 | use_condition = False 419 | 420 | if joint_attention_kwargs is not None: 421 | joint_attention_kwargs = joint_attention_kwargs.copy() 422 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 423 | else: 424 | lora_scale = 1.0 425 | 426 | if USE_PEFT_BACKEND: 427 | # weight the lora layers by setting `lora_scale` for each PEFT layer 428 | scale_lora_layers(self, lora_scale) 429 | else: 430 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: 431 | logger.warning( 432 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 433 | ) 434 | 435 | hidden_states = self.x_embedder(hidden_states) 436 | cond_hidden_states = self.x_embedder(cond_hidden_states) 437 | 438 | timestep = timestep.to(hidden_states.dtype) * 1000 439 | if guidance is not None: 440 | guidance = guidance.to(hidden_states.dtype) * 1000 441 | else: 442 | guidance = None 443 | 444 | temb = ( 445 | self.time_text_embed(timestep, pooled_projections) 446 | if guidance is None 447 | else self.time_text_embed(timestep, guidance, pooled_projections) 448 | ) 449 | 450 | cond_temb = ( 451 | self.time_text_embed(torch.ones_like(timestep) * 0, pooled_projections) 452 | if guidance is None 453 | else self.time_text_embed( 454 | torch.ones_like(timestep) * 0, guidance, pooled_projections 455 | ) 456 | ) 457 | 458 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 459 | 460 | if txt_ids.ndim == 3: 461 | logger.warning( 462 | "Passing `txt_ids` 3d torch.Tensor is deprecated." 
463 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 464 | ) 465 | txt_ids = txt_ids[0] 466 | if img_ids.ndim == 3: 467 | logger.warning( 468 | "Passing `img_ids` 3d torch.Tensor is deprecated." 469 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 470 | ) 471 | img_ids = img_ids[0] 472 | 473 | ids = torch.cat((txt_ids, img_ids), dim=0) 474 | image_rotary_emb = self.pos_embed(ids) 475 | 476 | if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: 477 | ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") 478 | ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) 479 | joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) 480 | 481 | for index_block, block in enumerate(self.transformer_blocks): 482 | if torch.is_grad_enabled() and self.gradient_checkpointing: 483 | 484 | def create_custom_forward(module, return_dict=None): 485 | def custom_forward(*inputs): 486 | if return_dict is not None: 487 | return module(*inputs, return_dict=return_dict) 488 | else: 489 | return module(*inputs) 490 | 491 | return custom_forward 492 | 493 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 494 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( 495 | create_custom_forward(block), 496 | hidden_states, 497 | encoder_hidden_states, 498 | temb, 499 | image_rotary_emb, 500 | cond_temb=cond_temb if use_condition else None, 501 | cond_hidden_states=cond_hidden_states if use_condition else None, 502 | **ckpt_kwargs, 503 | ) 504 | 505 | else: 506 | encoder_hidden_states, hidden_states, cond_hidden_states = block( 507 | hidden_states=hidden_states, 508 | encoder_hidden_states=encoder_hidden_states, 509 | cond_hidden_states=cond_hidden_states if use_condition else None, 510 | temb=temb, 511 | cond_temb=cond_temb if use_condition else None, 512 | image_rotary_emb=image_rotary_emb, 513 | joint_attention_kwargs=joint_attention_kwargs, 514 | ) 515 | 516 | # controlnet residual 517 | if controlnet_block_samples is not None: 518 | interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) 519 | interval_control = int(np.ceil(interval_control)) 520 | # For Xlabs ControlNet. 
521 | if controlnet_blocks_repeat: 522 | hidden_states = ( 523 | hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] 524 | ) 525 | else: 526 | hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] 527 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 528 | 529 | for index_block, block in enumerate(self.single_transformer_blocks): 530 | if torch.is_grad_enabled() and self.gradient_checkpointing: 531 | 532 | def create_custom_forward(module, return_dict=None): 533 | def custom_forward(*inputs): 534 | if return_dict is not None: 535 | return module(*inputs, return_dict=return_dict) 536 | else: 537 | return module(*inputs) 538 | 539 | return custom_forward 540 | 541 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 542 | hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint( 543 | create_custom_forward(block), 544 | hidden_states, 545 | temb, 546 | image_rotary_emb, 547 | cond_temb=cond_temb if use_condition else None, 548 | cond_hidden_states=cond_hidden_states if use_condition else None, 549 | **ckpt_kwargs, 550 | ) 551 | 552 | else: 553 | hidden_states, cond_hidden_states = block( 554 | hidden_states=hidden_states, 555 | cond_hidden_states=cond_hidden_states if use_condition else None, 556 | temb=temb, 557 | cond_temb=cond_temb if use_condition else None, 558 | image_rotary_emb=image_rotary_emb, 559 | joint_attention_kwargs=joint_attention_kwargs, 560 | ) 561 | 562 | # controlnet residual 563 | if controlnet_single_block_samples is not None: 564 | interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) 565 | interval_control = int(np.ceil(interval_control)) 566 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( 567 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] 568 | + controlnet_single_block_samples[index_block // interval_control] 569 | ) 570 | 571 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
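# Only the image-token portion of the joint sequence is kept from here on; the final AdaLayerNormContinuous and output projection below operate on image tokens only.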
572 | 573 | hidden_states = self.norm_out(hidden_states, temb) 574 | output = self.proj_out(hidden_states) 575 | 576 | if USE_PEFT_BACKEND: 577 | # remove `lora_scale` from each PEFT layer 578 | unscale_lora_layers(self, lora_scale) 579 | 580 | if not return_dict: 581 | return (output,) 582 | 583 | return Transformer2DModelOutput(sample=output) -------------------------------------------------------------------------------- /models/flux/Guidance_distilled.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pipeline import FluxPipeline 3 | import random 4 | import os 5 | import numpy as np 6 | 7 | os.makedirs("output",exist_ok=True) 8 | 9 | def set_seed(seed): 10 | random.seed(seed) 11 | os.environ['PYTHONHASHSEED'] = str(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | 16 | pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) 17 | pipe.enable_model_cpu_offload() 18 | 19 | prompt = "a tiny astronaut hatching from an egg on the moon" 20 | 21 | seed = random.randint(0, 2**32 - 1) 22 | print('seed: ',seed) 23 | 24 | set_seed(seed) 25 | out = pipe( 26 | prompt=prompt, 27 | guidance_scale=3.5, 28 | height=768, 29 | width=1360, 30 | num_inference_steps=50, 31 | use_zero_init=True, 32 | zero_steps=0, 33 | ).images[0] 34 | out.save("output/image_ours.png") 35 | 36 | set_seed(seed) 37 | out = pipe( 38 | prompt=prompt, 39 | guidance_scale=3.5, 40 | height=768, 41 | width=1360, 42 | num_inference_steps=50, 43 | use_zero_init=False, 44 | zero_steps=0, 45 | ).images[0] 46 | out.save("output/image_cfg.png") 47 | 48 | -------------------------------------------------------------------------------- /models/flux/infer_lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pipeline import FluxPipeline 3 | import random 4 | import os 5 | import numpy as np 6 | 7 | os.makedirs("output",exist_ok=True) 8 | 9 | def set_seed(seed): 10 | random.seed(seed) 11 | os.environ['PYTHONHASHSEED'] = str(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | 16 | pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) 17 | 18 | # Black Myth: Wukong 19 | pipe.load_lora_weights('Shakker-Labs/FLUX.1-dev-LoRA-collections', weight_name='FLUX-dev-lora-Black_Myth_Wukong_hyperrealism_v1.safetensors') 20 | pipe.fuse_lora(lora_scale=1.2) 21 | pipe.to("cuda") 22 | 23 | prompt = "aiyouxiketang, a man in armor with a beard and a beard" 24 | 25 | seed = random.randint(0, 2**32 - 1) 26 | print('seed: ',seed) 27 | 28 | set_seed(seed) 29 | 30 | image = pipe( 31 | prompt, 32 | num_inference_steps=25, 33 | guidance_scale=5.0, 34 | use_zero_init=False, 35 | zero_steps=0, 36 | ).images[0] 37 | image.save("output/image_cfg.png") 38 | 39 | set_seed(seed) 40 | 41 | image = pipe( 42 | prompt, 43 | num_inference_steps=25, 44 | guidance_scale=5.0, 45 | use_zero_init=True, 46 | zero_steps=0, 47 | ).images[0] 48 | image.save("output/image_ours.png") 49 | -------------------------------------------------------------------------------- /models/hunyuan/t2v.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import HunyuanVideoTransformer3DModel 3 | from pipeline import HunyuanVideoPipeline 4 | from diffusers.utils import export_to_video 5 | import random 6 | import os 7 | import numpy as np 8 | 
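# The script below renders the same prompt twice from a single random seed: once with plain CFG (use_zero_init=False) and once with CFG-Zero*'s zero-init applied to the earliest steps (use_zero_init=True, zero_steps=1), so the two exported MP4s can be compared directly.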
9 | os.makedirs("output",exist_ok=True) 10 | 11 | def set_seed(seed): 12 | random.seed(seed) 13 | os.environ['PYTHONHASHSEED'] = str(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | 18 | model_id = "hunyuanvideo-community/HunyuanVideo" 19 | transformer = HunyuanVideoTransformer3DModel.from_pretrained( 20 | model_id, subfolder="transformer", torch_dtype=torch.bfloat16 21 | ) 22 | pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) 23 | pipe.vae.enable_tiling() 24 | pipe.to("cuda") 25 | 26 | seed = random.randint(0, 2**32 - 1) 27 | print('seed: ',seed) 28 | 29 | set_seed(seed) 30 | output = pipe( 31 | prompt="In an ornate, historical hall, a massive tidal wave peaks and begins to crash. A man is surfing, cinematic film shot in 35mm. High quality, high defination.", 32 | height=720, 33 | width=1280, 34 | num_frames=61, 35 | num_inference_steps=50, 36 | use_zero_init=False, 37 | zero_steps=0, 38 | ).frames[0] 39 | export_to_video(output, f"output/{seed}_output_cfg.mp4", fps=15) 40 | 41 | set_seed(seed) 42 | output = pipe( 43 | prompt="In an ornate, historical hall, a massive tidal wave peaks and begins to crash. A man is surfing, cinematic film shot in 35mm. High quality, high defination.", 44 | height=720, 45 | width=1280, 46 | num_frames=61, 47 | num_inference_steps=50, 48 | use_zero_init=True, 49 | zero_steps=1, 50 | ).frames[0] 51 | export_to_video(output, f"output/{seed}_output_ours.mp4", fps=15) 52 | -------------------------------------------------------------------------------- /models/sd/infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sd3_pipeline import StableDiffusion3Pipeline 3 | 4 | 5 | import numpy as np 6 | import random 7 | import os 8 | 9 | def set_seed(seed): 10 | random.seed(seed) 11 | os.environ['PYTHONHASHSEED'] = str(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | 16 | 17 | pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16) 18 | 19 | pipe = pipe.to("cuda") 20 | 21 | seed = random.randint(0, 2**32 - 1) 22 | print('seed: ',seed) 23 | 24 | set_seed(seed) 25 | 26 | image = pipe( 27 | "A capybara holding a sign that reads Hello World", 28 | num_inference_steps=28, 29 | guidance_scale=3.5, 30 | use_cfg_zero_star=True, 31 | use_zero_init=True, 32 | zero_steps=0 33 | ).images[0] 34 | image.save("output/output_ours.png") 35 | 36 | set_seed(seed) 37 | 38 | image = pipe( 39 | "A capybara holding a sign that reads Hello World", 40 | num_inference_steps=28, 41 | guidance_scale=3.5, 42 | use_cfg_zero_star=False, 43 | use_zero_init=False, 44 | zero_steps=0 45 | ).images[0] 46 | image.save("output/output_cfg.png") 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /models/wan/T2V_infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.utils import export_to_video 3 | from diffusers import AutoencoderKLWan#, WanPipeline 4 | from wan_pipeline import WanPipeline 5 | from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler 6 | 7 | import random 8 | import os 9 | import numpy as np 10 | 11 | 12 | def set_seed(seed): 13 | random.seed(seed) 14 | os.environ['PYTHONHASHSEED'] = str(seed) 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | 
torch.cuda.manual_seed(seed) 18 | 19 | model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" 20 | vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) 21 | flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P 22 | scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift) 23 | pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) 24 | pipe.scheduler = scheduler 25 | pipe.to("cuda") 26 | 27 | prompt = "The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from it’s tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds." 28 | negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" 29 | 30 | os.makedirs("output",exist_ok=True) 31 | 32 | seed = random.randint(0, 2**32 - 1) 33 | print('seed: ',seed) 34 | set_seed(seed) 35 | output = pipe( 36 | prompt=prompt, 37 | negative_prompt=negative_prompt, 38 | height=480, 39 | width=832, 40 | num_frames=81, 41 | guidance_scale=5.0, 42 | use_cfg_zero_star=False, 43 | use_zero_init=False, 44 | zero_steps=0 45 | ).frames[0] 46 | export_to_video(output, f"output/{seed}_base.mp4", fps=15) 47 | 48 | set_seed(seed) 49 | output = pipe( 50 | prompt=prompt, 51 | negative_prompt=negative_prompt, 52 | height=480, 53 | width=832, 54 | num_frames=81, 55 | guidance_scale=5.0, 56 | use_cfg_zero_star=True, 57 | use_zero_init=True, 58 | zero_steps=1 59 | ).frames[0] 60 | export_to_video(output, f"output/{seed}_ours.mp4", fps=15) 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /models/wan/image2video_cfg_zero_star.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
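# This module follows the original Wan2.1 image-to-video generation loop; the CFG-Zero* additions are the optimized_scale helper below and the zero-init branch inside the sampling loop, which zeroes the prediction while i <= zero_init_steps before guidance with the optimized scale takes over.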
2 | import gc 3 | import logging 4 | import math 5 | import os 6 | import random 7 | import sys 8 | import types 9 | from contextlib import contextmanager 10 | from functools import partial 11 | 12 | import numpy as np 13 | import torch 14 | import torch.cuda.amp as amp 15 | import torch.distributed as dist 16 | import torchvision.transforms.functional as TF 17 | from tqdm import tqdm 18 | 19 | from .distributed.fsdp import shard_model 20 | from .modules.clip import CLIPModel 21 | from .modules.model import WanModel 22 | from .modules.t5 import T5EncoderModel 23 | from .modules.vae import WanVAE 24 | from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, 25 | get_sampling_sigmas, retrieve_timesteps) 26 | from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler 27 | 28 | @torch.cuda.amp.autocast(dtype=torch.float32) 29 | def optimized_scale(positive_flat, negative_flat): 30 | 31 | # Calculate the dot product between the conditional and unconditional predictions 32 | dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) 33 | 34 | # Squared norm of the unconditional prediction (the 1e-8 avoids division by zero) 35 | squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 36 | 37 | # st_star = v_cond^T * v_uncond / ||v_uncond||^2 38 | st_star = dot_product / squared_norm 39 | 40 | return st_star 41 | 42 | class WanI2V: 43 | 44 | def __init__( 45 | self, 46 | config, 47 | checkpoint_dir, 48 | device_id=0, 49 | rank=0, 50 | t5_fsdp=False, 51 | dit_fsdp=False, 52 | use_usp=False, 53 | t5_cpu=False, 54 | init_on_cpu=True, 55 | ): 56 | r""" 57 | Initializes the image-to-video generation model components. 58 | 59 | Args: 60 | config (EasyDict): 61 | Object containing model parameters initialized from config.py 62 | checkpoint_dir (`str`): 63 | Path to directory containing model checkpoints 64 | device_id (`int`, *optional*, defaults to 0): 65 | Id of target GPU device 66 | rank (`int`, *optional*, defaults to 0): 67 | Process rank for distributed training 68 | t5_fsdp (`bool`, *optional*, defaults to False): 69 | Enable FSDP sharding for T5 model 70 | dit_fsdp (`bool`, *optional*, defaults to False): 71 | Enable FSDP sharding for DiT model 72 | use_usp (`bool`, *optional*, defaults to False): 73 | Enable distribution strategy of USP. 74 | t5_cpu (`bool`, *optional*, defaults to False): 75 | Whether to place T5 model on CPU. Only works without t5_fsdp. 76 | init_on_cpu (`bool`, *optional*, defaults to True): 77 | Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
78 | """ 79 | self.device = torch.device(f"cuda:{device_id}") 80 | self.config = config 81 | self.rank = rank 82 | self.use_usp = use_usp 83 | self.t5_cpu = t5_cpu 84 | 85 | self.num_train_timesteps = config.num_train_timesteps 86 | self.param_dtype = config.param_dtype 87 | 88 | shard_fn = partial(shard_model, device_id=device_id) 89 | self.text_encoder = T5EncoderModel( 90 | text_len=config.text_len, 91 | dtype=config.t5_dtype, 92 | device=torch.device('cpu'), 93 | checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), 94 | tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), 95 | shard_fn=shard_fn if t5_fsdp else None, 96 | ) 97 | 98 | self.vae_stride = config.vae_stride 99 | self.patch_size = config.patch_size 100 | self.vae = WanVAE( 101 | vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), 102 | device=self.device) 103 | 104 | self.clip = CLIPModel( 105 | dtype=config.clip_dtype, 106 | device=self.device, 107 | checkpoint_path=os.path.join(checkpoint_dir, 108 | config.clip_checkpoint), 109 | tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer)) 110 | 111 | logging.info(f"Creating WanModel from {checkpoint_dir}") 112 | self.model = WanModel.from_pretrained(checkpoint_dir) 113 | self.model.eval().requires_grad_(False) 114 | 115 | if t5_fsdp or dit_fsdp or use_usp: 116 | init_on_cpu = False 117 | 118 | if use_usp: 119 | from xfuser.core.distributed import \ 120 | get_sequence_parallel_world_size 121 | 122 | from .distributed.xdit_context_parallel import (usp_attn_forward, 123 | usp_dit_forward) 124 | for block in self.model.blocks: 125 | block.self_attn.forward = types.MethodType( 126 | usp_attn_forward, block.self_attn) 127 | self.model.forward = types.MethodType(usp_dit_forward, self.model) 128 | self.sp_size = get_sequence_parallel_world_size() 129 | else: 130 | self.sp_size = 1 131 | 132 | if dist.is_initialized(): 133 | dist.barrier() 134 | if dit_fsdp: 135 | self.model = shard_fn(self.model) 136 | else: 137 | if not init_on_cpu: 138 | self.model.to(self.device) 139 | 140 | self.sample_neg_prompt = config.sample_neg_prompt 141 | 142 | def generate(self, 143 | input_prompt, 144 | img, 145 | max_area=720 * 1280, 146 | frame_num=81, 147 | shift=5.0, 148 | sample_solver='unipc', 149 | sampling_steps=40, 150 | guide_scale=5.0, 151 | n_prompt="", 152 | seed=-1, 153 | offload_model=True, 154 | zero_init_steps=0): 155 | r""" 156 | Generates video frames from input image and text prompt using diffusion process. 157 | 158 | Args: 159 | input_prompt (`str`): 160 | Text prompt for content generation. 161 | img (PIL.Image.Image): 162 | Input image tensor. Shape: [3, H, W] 163 | max_area (`int`, *optional*, defaults to 720*1280): 164 | Maximum pixel area for latent space calculation. Controls video resolution scaling 165 | frame_num (`int`, *optional*, defaults to 81): 166 | How many frames to sample from a video. The number should be 4n+1 167 | shift (`float`, *optional*, defaults to 5.0): 168 | Noise schedule shift parameter. Affects temporal dynamics 169 | [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. 170 | sample_solver (`str`, *optional*, defaults to 'unipc'): 171 | Solver used to sample the video. 172 | sampling_steps (`int`, *optional*, defaults to 40): 173 | Number of diffusion sampling steps. Higher values improve quality but slow generation 174 | guide_scale (`float`, *optional*, defaults 5.0): 175 | Classifier-free guidance scale. Controls prompt adherence vs. 
creativity 176 | n_prompt (`str`, *optional*, defaults to ""): 177 | Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` 178 | seed (`int`, *optional*, defaults to -1): 179 | Random seed for noise generation. If -1, use random seed 180 | offload_model (`bool`, *optional*, defaults to True): 181 | If True, offloads models to CPU during generation to save VRAM 182 | 183 | Returns: 184 | torch.Tensor: 185 | Generated video frames tensor. Dimensions: (C, N H, W) where: 186 | - C: Color channels (3 for RGB) 187 | - N: Number of frames (81) 188 | - H: Frame height (from max_area) 189 | - W: Frame width from max_area) 190 | """ 191 | img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) 192 | 193 | F = frame_num 194 | h, w = img.shape[1:] 195 | aspect_ratio = h / w 196 | lat_h = round( 197 | np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // 198 | self.patch_size[1] * self.patch_size[1]) 199 | lat_w = round( 200 | np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // 201 | self.patch_size[2] * self.patch_size[2]) 202 | h = lat_h * self.vae_stride[1] 203 | w = lat_w * self.vae_stride[2] 204 | 205 | max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // ( 206 | self.patch_size[1] * self.patch_size[2]) 207 | max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size 208 | 209 | seed = seed if seed >= 0 else random.randint(0, sys.maxsize) 210 | seed_g = torch.Generator(device=self.device) 211 | seed_g.manual_seed(seed) 212 | noise = torch.randn( 213 | 16, 214 | 21, 215 | lat_h, 216 | lat_w, 217 | dtype=torch.float32, 218 | generator=seed_g, 219 | device=self.device) 220 | 221 | msk = torch.ones(1, 81, lat_h, lat_w, device=self.device) 222 | msk[:, 1:] = 0 223 | msk = torch.concat([ 224 | torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] 225 | ], 226 | dim=1) 227 | msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) 228 | msk = msk.transpose(1, 2)[0] 229 | 230 | if n_prompt == "": 231 | n_prompt = self.sample_neg_prompt 232 | 233 | # preprocess 234 | if not self.t5_cpu: 235 | self.text_encoder.model.to(self.device) 236 | context = self.text_encoder([input_prompt], self.device) 237 | context_null = self.text_encoder([n_prompt], self.device) 238 | if offload_model: 239 | self.text_encoder.model.cpu() 240 | else: 241 | context = self.text_encoder([input_prompt], torch.device('cpu')) 242 | context_null = self.text_encoder([n_prompt], torch.device('cpu')) 243 | context = [t.to(self.device) for t in context] 244 | context_null = [t.to(self.device) for t in context_null] 245 | 246 | self.clip.model.to(self.device) 247 | clip_context = self.clip.visual([img[:, None, :, :]]) 248 | if offload_model: 249 | self.clip.model.cpu() 250 | 251 | y = self.vae.encode([ 252 | torch.concat([ 253 | torch.nn.functional.interpolate( 254 | img[None].cpu(), size=(h, w), mode='bicubic').transpose( 255 | 0, 1), 256 | torch.zeros(3, 80, h, w) 257 | ], 258 | dim=1).to(self.device) 259 | ])[0] 260 | y = torch.concat([msk, y]) 261 | 262 | @contextmanager 263 | def noop_no_sync(): 264 | yield 265 | 266 | no_sync = getattr(self.model, 'no_sync', noop_no_sync) 267 | 268 | # evaluation mode 269 | with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): 270 | 271 | if sample_solver == 'unipc': 272 | sample_scheduler = FlowUniPCMultistepScheduler( 273 | num_train_timesteps=self.num_train_timesteps, 274 | shift=1, 275 | use_dynamic_shifting=False) 276 | sample_scheduler.set_timesteps( 277 | sampling_steps, device=self.device, 
shift=shift) 278 | timesteps = sample_scheduler.timesteps 279 | elif sample_solver == 'dpm++': 280 | sample_scheduler = FlowDPMSolverMultistepScheduler( 281 | num_train_timesteps=self.num_train_timesteps, 282 | shift=1, 283 | use_dynamic_shifting=False) 284 | sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) 285 | timesteps, _ = retrieve_timesteps( 286 | sample_scheduler, 287 | device=self.device, 288 | sigmas=sampling_sigmas) 289 | else: 290 | raise NotImplementedError("Unsupported solver.") 291 | 292 | # sample videos 293 | latent = noise 294 | 295 | arg_c = { 296 | 'context': [context[0]], 297 | 'clip_fea': clip_context, 298 | 'seq_len': max_seq_len, 299 | 'y': [y], 300 | } 301 | 302 | arg_null = { 303 | 'context': context_null, 304 | 'clip_fea': clip_context, 305 | 'seq_len': max_seq_len, 306 | 'y': [y], 307 | } 308 | 309 | if offload_model: 310 | torch.cuda.empty_cache() 311 | 312 | self.model.to(self.device) 313 | for i, t in enumerate(tqdm(timesteps)): 314 | latent_model_input = [latent.to(self.device)] 315 | timestep = [t] 316 | 317 | timestep = torch.stack(timestep).to(self.device) 318 | 319 | noise_pred_cond = self.model( 320 | latent_model_input, t=timestep, **arg_c)[0].to( 321 | torch.device('cpu') if offload_model else self.device) 322 | if offload_model: 323 | torch.cuda.empty_cache() 324 | noise_pred_uncond = self.model( 325 | latent_model_input, t=timestep, **arg_null)[0].to( 326 | torch.device('cpu') if offload_model else self.device) 327 | if offload_model: 328 | torch.cuda.empty_cache() 329 | 330 | batch_size = noise_pred_cond.shape[0] 331 | positive_flat = noise_pred_cond.view(batch_size, -1) 332 | negative_flat = noise_pred_uncond.view(batch_size, -1) 333 | 334 | alpha = optimized_scale(positive_flat,negative_flat) 335 | alpha = alpha.view(batch_size, *([1] * (len(noise_pred_cond.shape) - 1))) 336 | alpha = alpha.to(noise_pred_cond.dtype) 337 | 338 | if i <= zero_init_steps: 339 | noise_pred = noise_pred_cond*0. 340 | else: 341 | noise_pred = noise_pred_uncond * alpha + guide_scale * (noise_pred_cond - noise_pred_uncond * alpha) 342 | 343 | 344 | latent = latent.to( 345 | torch.device('cpu') if offload_model else self.device) 346 | 347 | temp_x0 = sample_scheduler.step( 348 | noise_pred.unsqueeze(0), 349 | t, 350 | latent.unsqueeze(0), 351 | return_dict=False, 352 | generator=seed_g)[0] 353 | latent = temp_x0.squeeze(0) 354 | 355 | x0 = [latent.to(self.device)] 356 | del latent_model_input, timestep 357 | 358 | if offload_model: 359 | self.model.cpu() 360 | torch.cuda.empty_cache() 361 | 362 | if self.rank == 0: 363 | videos = self.vae.decode(x0) 364 | 365 | del noise, latent 366 | del sample_scheduler 367 | if offload_model: 368 | gc.collect() 369 | torch.cuda.synchronize() 370 | if dist.is_initialized(): 371 | dist.barrier() 372 | 373 | return videos[0] if self.rank == 0 else None 374 | 375 | -------------------------------------------------------------------------------- /models/wan/wan_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import html 16 | from typing import Any, Callable, Dict, List, Optional, Union 17 | 18 | import ftfy 19 | import regex as re 20 | import torch 21 | from transformers import AutoTokenizer, UMT5EncoderModel 22 | 23 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 24 | from diffusers.loaders import WanLoraLoaderMixin 25 | from diffusers.models import AutoencoderKLWan, WanTransformer3DModel 26 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 27 | from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring 28 | from diffusers.utils.torch_utils import randn_tensor 29 | from diffusers.video_processor import VideoProcessor 30 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 31 | from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput 32 | 33 | 34 | if is_torch_xla_available(): 35 | import torch_xla.core.xla_model as xm 36 | 37 | XLA_AVAILABLE = True 38 | else: 39 | XLA_AVAILABLE = False 40 | 41 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 42 | 43 | 44 | EXAMPLE_DOC_STRING = """ 45 | Examples: 46 | ```python 47 | >>> import torch 48 | >>> from diffusers.utils import export_to_video 49 | >>> from diffusers import AutoencoderKLWan, WanPipeline 50 | >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler 51 | 52 | >>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers 53 | >>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" 54 | >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) 55 | >>> pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) 56 | >>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P 57 | >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) 58 | >>> pipe.to("cuda") 59 | 60 | >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." 61 | >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" 62 | 63 | >>> output = pipe( 64 | ... prompt=prompt, 65 | ... negative_prompt=negative_prompt, 66 | ... height=720, 67 | ... width=1280, 68 | ... num_frames=81, 69 | ... guidance_scale=5.0, 70 | ... 
).frames[0] 71 | >>> export_to_video(output, "output.mp4", fps=16) 72 | ``` 73 | """ 74 | @torch.cuda.amp.autocast(dtype=torch.float32) 75 | def optimized_scale(positive_flat, negative_flat): 76 | 77 | # Calculate the dot product between the conditional and unconditional predictions 78 | dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) 79 | 80 | # Squared norm of the unconditional prediction (the 1e-8 avoids division by zero) 81 | squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 82 | 83 | # st_star = v_cond^T * v_uncond / ||v_uncond||^2 84 | st_star = dot_product / squared_norm 85 | 86 | return st_star 87 | 88 | def basic_clean(text): 89 | text = ftfy.fix_text(text) 90 | text = html.unescape(html.unescape(text)) 91 | return text.strip() 92 | 93 | 94 | def whitespace_clean(text): 95 | text = re.sub(r"\s+", " ", text) 96 | text = text.strip() 97 | return text 98 | 99 | 100 | def prompt_clean(text): 101 | text = whitespace_clean(basic_clean(text)) 102 | return text 103 | 104 | 105 | class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin): 106 | r""" 107 | Pipeline for text-to-video generation using Wan. 108 | 109 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 110 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 111 | 112 | Args: 113 | tokenizer ([`T5Tokenizer`]): 114 | Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), 115 | specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. 116 | text_encoder ([`T5EncoderModel`]): 117 | [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically 118 | the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. 119 | transformer ([`WanTransformer3DModel`]): 120 | Conditional Transformer to denoise the input latents. 121 | scheduler ([`UniPCMultistepScheduler`]): 122 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 123 | vae ([`AutoencoderKLWan`]): 124 | Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
125 | """ 126 | 127 | model_cpu_offload_seq = "text_encoder->transformer->vae" 128 | _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] 129 | 130 | def __init__( 131 | self, 132 | tokenizer: AutoTokenizer, 133 | text_encoder: UMT5EncoderModel, 134 | transformer: WanTransformer3DModel, 135 | vae: AutoencoderKLWan, 136 | scheduler: FlowMatchEulerDiscreteScheduler, 137 | ): 138 | super().__init__() 139 | 140 | self.register_modules( 141 | vae=vae, 142 | text_encoder=text_encoder, 143 | tokenizer=tokenizer, 144 | transformer=transformer, 145 | scheduler=scheduler, 146 | ) 147 | 148 | self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 149 | self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 150 | self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) 151 | 152 | def _get_t5_prompt_embeds( 153 | self, 154 | prompt: Union[str, List[str]] = None, 155 | num_videos_per_prompt: int = 1, 156 | max_sequence_length: int = 226, 157 | device: Optional[torch.device] = None, 158 | dtype: Optional[torch.dtype] = None, 159 | ): 160 | device = device or self._execution_device 161 | dtype = dtype or self.text_encoder.dtype 162 | 163 | prompt = [prompt] if isinstance(prompt, str) else prompt 164 | prompt = [prompt_clean(u) for u in prompt] 165 | batch_size = len(prompt) 166 | 167 | text_inputs = self.tokenizer( 168 | prompt, 169 | padding="max_length", 170 | max_length=max_sequence_length, 171 | truncation=True, 172 | add_special_tokens=True, 173 | return_attention_mask=True, 174 | return_tensors="pt", 175 | ) 176 | text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask 177 | seq_lens = mask.gt(0).sum(dim=1).long() 178 | 179 | prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state 180 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 181 | prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] 182 | prompt_embeds = torch.stack( 183 | [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 184 | ) 185 | 186 | # duplicate text embeddings for each generation per prompt, using mps friendly method 187 | _, seq_len, _ = prompt_embeds.shape 188 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) 189 | prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) 190 | 191 | return prompt_embeds 192 | 193 | def encode_prompt( 194 | self, 195 | prompt: Union[str, List[str]], 196 | negative_prompt: Optional[Union[str, List[str]]] = None, 197 | do_classifier_free_guidance: bool = True, 198 | num_videos_per_prompt: int = 1, 199 | prompt_embeds: Optional[torch.Tensor] = None, 200 | negative_prompt_embeds: Optional[torch.Tensor] = None, 201 | max_sequence_length: int = 226, 202 | device: Optional[torch.device] = None, 203 | dtype: Optional[torch.dtype] = None, 204 | ): 205 | r""" 206 | Encodes the prompt into text encoder hidden states. 207 | 208 | Args: 209 | prompt (`str` or `List[str]`, *optional*): 210 | prompt to be encoded 211 | negative_prompt (`str` or `List[str]`, *optional*): 212 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 213 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 214 | less than `1`). 
215 | do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): 216 | Whether to use classifier free guidance or not. 217 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 218 | Number of videos that should be generated per prompt. torch device to place the resulting embeddings on 219 | prompt_embeds (`torch.Tensor`, *optional*): 220 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 221 | provided, text embeddings will be generated from `prompt` input argument. 222 | negative_prompt_embeds (`torch.Tensor`, *optional*): 223 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 224 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 225 | argument. 226 | device: (`torch.device`, *optional*): 227 | torch device 228 | dtype: (`torch.dtype`, *optional*): 229 | torch dtype 230 | """ 231 | device = device or self._execution_device 232 | 233 | prompt = [prompt] if isinstance(prompt, str) else prompt 234 | if prompt is not None: 235 | batch_size = len(prompt) 236 | else: 237 | batch_size = prompt_embeds.shape[0] 238 | 239 | if prompt_embeds is None: 240 | prompt_embeds = self._get_t5_prompt_embeds( 241 | prompt=prompt, 242 | num_videos_per_prompt=num_videos_per_prompt, 243 | max_sequence_length=max_sequence_length, 244 | device=device, 245 | dtype=dtype, 246 | ) 247 | 248 | if do_classifier_free_guidance and negative_prompt_embeds is None: 249 | negative_prompt = negative_prompt or "" 250 | negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt 251 | 252 | if prompt is not None and type(prompt) is not type(negative_prompt): 253 | raise TypeError( 254 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 255 | f" {type(prompt)}." 256 | ) 257 | elif batch_size != len(negative_prompt): 258 | raise ValueError( 259 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 260 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 261 | " the batch size of `prompt`." 262 | ) 263 | 264 | negative_prompt_embeds = self._get_t5_prompt_embeds( 265 | prompt=negative_prompt, 266 | num_videos_per_prompt=num_videos_per_prompt, 267 | max_sequence_length=max_sequence_length, 268 | device=device, 269 | dtype=dtype, 270 | ) 271 | 272 | return prompt_embeds, negative_prompt_embeds 273 | 274 | def check_inputs( 275 | self, 276 | prompt, 277 | negative_prompt, 278 | height, 279 | width, 280 | prompt_embeds=None, 281 | negative_prompt_embeds=None, 282 | callback_on_step_end_tensor_inputs=None, 283 | ): 284 | if height % 16 != 0 or width % 16 != 0: 285 | raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") 286 | 287 | if callback_on_step_end_tensor_inputs is not None and not all( 288 | k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs 289 | ): 290 | raise ValueError( 291 | f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" 292 | ) 293 | 294 | if prompt is not None and prompt_embeds is not None: 295 | raise ValueError( 296 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 297 | " only forward one of the two." 
298 | ) 299 | elif negative_prompt is not None and negative_prompt_embeds is not None: 300 | raise ValueError( 301 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" 302 | " only forward one of the two." 303 | ) 304 | elif prompt is None and prompt_embeds is None: 305 | raise ValueError( 306 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 307 | ) 308 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 309 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 310 | elif negative_prompt is not None and ( 311 | not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) 312 | ): 313 | raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") 314 | 315 | def prepare_latents( 316 | self, 317 | batch_size: int, 318 | num_channels_latents: int = 16, 319 | height: int = 480, 320 | width: int = 832, 321 | num_frames: int = 81, 322 | dtype: Optional[torch.dtype] = None, 323 | device: Optional[torch.device] = None, 324 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 325 | latents: Optional[torch.Tensor] = None, 326 | ) -> torch.Tensor: 327 | if latents is not None: 328 | return latents.to(device=device, dtype=dtype) 329 | 330 | num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 331 | shape = ( 332 | batch_size, 333 | num_channels_latents, 334 | num_latent_frames, 335 | int(height) // self.vae_scale_factor_spatial, 336 | int(width) // self.vae_scale_factor_spatial, 337 | ) 338 | if isinstance(generator, list) and len(generator) != batch_size: 339 | raise ValueError( 340 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 341 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
342 | ) 343 | 344 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 345 | return latents 346 | 347 | @property 348 | def guidance_scale(self): 349 | return self._guidance_scale 350 | 351 | @property 352 | def do_classifier_free_guidance(self): 353 | return self._guidance_scale > 1.0 354 | 355 | @property 356 | def num_timesteps(self): 357 | return self._num_timesteps 358 | 359 | @property 360 | def current_timestep(self): 361 | return self._current_timestep 362 | 363 | @property 364 | def interrupt(self): 365 | return self._interrupt 366 | 367 | @property 368 | def attention_kwargs(self): 369 | return self._attention_kwargs 370 | 371 | @torch.no_grad() 372 | @replace_example_docstring(EXAMPLE_DOC_STRING) 373 | def __call__( 374 | self, 375 | prompt: Union[str, List[str]] = None, 376 | negative_prompt: Union[str, List[str]] = None, 377 | height: int = 480, 378 | width: int = 832, 379 | num_frames: int = 81, 380 | num_inference_steps: int = 50, 381 | guidance_scale: float = 5.0, 382 | num_videos_per_prompt: Optional[int] = 1, 383 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 384 | latents: Optional[torch.Tensor] = None, 385 | prompt_embeds: Optional[torch.Tensor] = None, 386 | negative_prompt_embeds: Optional[torch.Tensor] = None, 387 | output_type: Optional[str] = "np", 388 | return_dict: bool = True, 389 | attention_kwargs: Optional[Dict[str, Any]] = None, 390 | callback_on_step_end: Optional[ 391 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] 392 | ] = None, 393 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 394 | max_sequence_length: int = 512, 395 | use_cfg_zero_star: Optional[bool] = False, 396 | use_zero_init: Optional[bool] = True, 397 | zero_steps: Optional[int] = 0, 398 | ): 399 | r""" 400 | The call function to the pipeline for generation. 401 | 402 | Args: 403 | prompt (`str` or `List[str]`, *optional*): 404 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 405 | instead. 406 | height (`int`, defaults to `480`): 407 | The height in pixels of the generated image. 408 | width (`int`, defaults to `832`): 409 | The width in pixels of the generated image. 410 | num_frames (`int`, defaults to `81`): 411 | The number of frames in the generated video. 412 | num_inference_steps (`int`, defaults to `50`): 413 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 414 | expense of slower inference. 415 | guidance_scale (`float`, defaults to `5.0`): 416 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 417 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 418 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 419 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 420 | usually at the expense of lower image quality. 421 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 422 | The number of images to generate per prompt. 423 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 424 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 425 | generation deterministic. 426 | latents (`torch.Tensor`, *optional*): 427 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 428 | generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents 429 | tensor is generated by sampling using the supplied random `generator`. 430 | prompt_embeds (`torch.Tensor`, *optional*): 431 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 432 | provided, text embeddings are generated from the `prompt` input argument. 433 | output_type (`str`, *optional*, defaults to `"pil"`): 434 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 435 | return_dict (`bool`, *optional*, defaults to `True`): 436 | Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. 437 | attention_kwargs (`dict`, *optional*): 438 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 439 | `self.processor` in 440 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 441 | callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): 442 | A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of 443 | each denoising step during the inference. with the following arguments: `callback_on_step_end(self: 444 | DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a 445 | list of all tensors as specified by `callback_on_step_end_tensor_inputs`. 446 | callback_on_step_end_tensor_inputs (`List`, *optional*): 447 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 448 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 449 | `._callback_tensor_inputs` attribute of your pipeline class. 450 | autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): 451 | The dtype to use for the torch.amp.autocast. 452 | 453 | Examples: 454 | 455 | Returns: 456 | [`~WanPipelineOutput`] or `tuple`: 457 | If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where 458 | the first element is a list with the generated images and the second element is a list of `bool`s 459 | indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. 460 | """ 461 | 462 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): 463 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs 464 | 465 | # 1. Check inputs. Raise error if not correct 466 | self.check_inputs( 467 | prompt, 468 | negative_prompt, 469 | height, 470 | width, 471 | prompt_embeds, 472 | negative_prompt_embeds, 473 | callback_on_step_end_tensor_inputs, 474 | ) 475 | 476 | self._guidance_scale = guidance_scale 477 | self._attention_kwargs = attention_kwargs 478 | self._current_timestep = None 479 | self._interrupt = False 480 | 481 | device = self._execution_device 482 | 483 | # 2. Define call parameters 484 | if prompt is not None and isinstance(prompt, str): 485 | batch_size = 1 486 | elif prompt is not None and isinstance(prompt, list): 487 | batch_size = len(prompt) 488 | else: 489 | batch_size = prompt_embeds.shape[0] 490 | 491 | # 3. 
Encode input prompt 492 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 493 | prompt=prompt, 494 | negative_prompt=negative_prompt, 495 | do_classifier_free_guidance=self.do_classifier_free_guidance, 496 | num_videos_per_prompt=num_videos_per_prompt, 497 | prompt_embeds=prompt_embeds, 498 | negative_prompt_embeds=negative_prompt_embeds, 499 | max_sequence_length=max_sequence_length, 500 | device=device, 501 | ) 502 | 503 | transformer_dtype = self.transformer.dtype 504 | prompt_embeds = prompt_embeds.to(transformer_dtype) 505 | if negative_prompt_embeds is not None: 506 | negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) 507 | 508 | # 4. Prepare timesteps 509 | self.scheduler.set_timesteps(num_inference_steps, device=device) 510 | timesteps = self.scheduler.timesteps 511 | 512 | # 5. Prepare latent variables 513 | num_channels_latents = self.transformer.config.in_channels 514 | latents = self.prepare_latents( 515 | batch_size * num_videos_per_prompt, 516 | num_channels_latents, 517 | height, 518 | width, 519 | num_frames, 520 | torch.float32, 521 | device, 522 | generator, 523 | latents, 524 | ) 525 | 526 | # 6. Denoising loop 527 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 528 | self._num_timesteps = len(timesteps) 529 | 530 | with self.progress_bar(total=num_inference_steps) as progress_bar: 531 | for i, t in enumerate(timesteps): 532 | if self.interrupt: 533 | continue 534 | 535 | self._current_timestep = t 536 | latent_model_input = latents.to(transformer_dtype) 537 | timestep = t.expand(latents.shape[0]) 538 | 539 | noise_pred = self.transformer( 540 | hidden_states=latent_model_input, 541 | timestep=timestep, 542 | encoder_hidden_states=prompt_embeds, 543 | attention_kwargs=attention_kwargs, 544 | return_dict=False, 545 | )[0] 546 | 547 | if self.do_classifier_free_guidance: 548 | noise_pred_uncond = self.transformer( 549 | hidden_states=latent_model_input, 550 | timestep=timestep, 551 | encoder_hidden_states=negative_prompt_embeds, 552 | attention_kwargs=attention_kwargs, 553 | return_dict=False, 554 | )[0] 555 | 556 | noise_pred_text = noise_pred 557 | if use_cfg_zero_star: 558 | positive_flat = noise_pred_text.view(batch_size, -1) 559 | negative_flat = noise_pred_uncond.view(batch_size, -1) 560 | 561 | alpha = optimized_scale(positive_flat,negative_flat) 562 | alpha = alpha.view(batch_size, *([1] * (len(noise_pred_text.shape) - 1))) 563 | alpha = alpha.to(noise_pred_text.dtype) 564 | 565 | if (i <= zero_steps) and use_zero_init: 566 | noise_pred = noise_pred_text*0. 
567 | else: 568 | noise_pred = noise_pred_uncond * alpha + guidance_scale * (noise_pred_text - noise_pred_uncond * alpha) 569 | else: 570 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 571 | 572 | 573 | # compute the previous noisy sample x_t -> x_t-1 574 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 575 | 576 | if callback_on_step_end is not None: 577 | callback_kwargs = {} 578 | for k in callback_on_step_end_tensor_inputs: 579 | callback_kwargs[k] = locals()[k] 580 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 581 | 582 | latents = callback_outputs.pop("latents", latents) 583 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 584 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 585 | 586 | # call the callback, if provided 587 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 588 | progress_bar.update() 589 | 590 | if XLA_AVAILABLE: 591 | xm.mark_step() 592 | 593 | self._current_timestep = None 594 | 595 | if not output_type == "latent": 596 | latents = latents.to(self.vae.dtype) 597 | latents_mean = ( 598 | torch.tensor(self.vae.config.latents_mean) 599 | .view(1, self.vae.config.z_dim, 1, 1, 1) 600 | .to(latents.device, latents.dtype) 601 | ) 602 | latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( 603 | latents.device, latents.dtype 604 | ) 605 | latents = latents / latents_std + latents_mean 606 | video = self.vae.decode(latents, return_dict=False)[0] 607 | video = self.video_processor.postprocess_video(video, output_type=output_type) 608 | else: 609 | video = latents 610 | 611 | # Offload all models 612 | self.maybe_free_model_hooks() 613 | 614 | if not return_dict: 615 | return (video,) 616 | 617 | return WanPipelineOutput(frames=video) 618 | 619 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | git+https://github.com/huggingface/diffusers.git@main 3 | transformers 4 | gradio 5 | click 6 | einops 7 | moviepy 8 | sentencepiece 9 | Pillow 10 | ftfy 11 | spaces 12 | protobuf 13 | peft 14 | -------------------------------------------------------------------------------- /tools/convert_to_gif.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | 4 | # Set your input folder path 5 | input_folder = "./output" # change this 6 | 7 | for filename in os.listdir(input_folder): 8 | if filename.lower().endswith(".mp4"): 9 | mp4_path = os.path.join(input_folder, filename) 10 | gif_path = os.path.splitext(mp4_path)[0] + ".gif" 11 | 12 | print(f"Converting {filename} to GIF...") 13 | 14 | subprocess.run([ 15 | "ffmpeg", "-i", mp4_path, 16 | "-vf", "fps=10", # only set frame rate, no resizing 17 | "-c:v", "gif", gif_path 18 | ]) 19 | 20 | print(f"Saved: {gif_path}") 21 | 22 | --------------------------------------------------------------------------------
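
For readers who only want the guidance math, the snippet below is a minimal, self-contained sketch of the CFG-Zero* update shared by the pipelines above. `optimized_scale` mirrors the helper defined in `models/wan/wan_pipeline.py` and `models/wan/image2video_cfg_zero_star.py`; the `cfg_zero_star_step` wrapper, the toy tensor shapes, and the default argument values are illustrative assumptions rather than code taken from this repository.

```python
import torch


def optimized_scale(positive_flat: torch.Tensor, negative_flat: torch.Tensor) -> torch.Tensor:
    """Per-sample projection scale s* = <v_cond, v_uncond> / ||v_uncond||^2."""
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
    return dot_product / squared_norm


def cfg_zero_star_step(noise_pred_cond, noise_pred_uncond, guidance_scale, step,
                       zero_steps=0, use_zero_init=True):
    """Combine conditional/unconditional predictions the way the pipelines above do."""
    batch_size = noise_pred_cond.shape[0]
    positive_flat = noise_pred_cond.view(batch_size, -1)
    negative_flat = noise_pred_uncond.view(batch_size, -1)

    alpha = optimized_scale(positive_flat, negative_flat)
    alpha = alpha.view(batch_size, *([1] * (noise_pred_cond.dim() - 1))).to(noise_pred_cond.dtype)

    # Zero-init: return an all-zero prediction for the earliest sampling step(s).
    if use_zero_init and step <= zero_steps:
        return torch.zeros_like(noise_pred_cond)

    # CFG with the unconditional branch rescaled by alpha.
    return noise_pred_uncond * alpha + guidance_scale * (noise_pred_cond - noise_pred_uncond * alpha)


if __name__ == "__main__":
    # Toy tensors shaped like Wan video latents (B, C, frames, H/8, W/8); values are random.
    cond = torch.randn(1, 16, 21, 60, 104)
    uncond = torch.randn(1, 16, 21, 60, 104)
    print(cfg_zero_star_step(cond, uncond, guidance_scale=5.0, step=0).abs().max())  # zero at step 0
    print(cfg_zero_star_step(cond, uncond, guidance_scale=5.0, step=5).shape)        # torch.Size([1, 16, 21, 60, 104])
```

With `zero_steps=0` only the very first step is zeroed (`step <= zero_steps`), matching the flux and SD3 scripts above; the Wan and Hunyuan scripts pass `zero_steps=1` at 50 inference steps.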