├── .gitattributes
├── .github
│   └── workflows
│       └── publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── examples
│   ├── lumina_next_i2i_example_01.json
│   ├── lumina_next_t2i_composition_example_01.json
│   └── lumina_next_t2i_example_01.json
├── lumina_models
│   ├── __init__.py
│   ├── components.py
│   └── nextdit.py
├── nodes.py
├── pyproject.toml
├── requirements.txt
└── transport.py
/.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | jobs: 12 | publish-node: 13 | name: Publish Custom Node to registry 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Check out code 17 | uses: actions/checkout@v4 18 | - name: Publish Custom Node 19 | uses: Comfy-Org/publish-node-action@main 20 | with: 21 | ## Add your own personal access token to your Github Repository secrets and reference it here. 22 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *pyc 3 | .vscode 4 | __pycache__ 5 | *.egg-info 6 | *.bak 7 | checkpoints 8 | results 9 | backup -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jukka Seppänen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WORK IN PROGRESS 2 | 3 | # Installation 4 | - Clone this repo into `custom_nodes` folder.
5 | - Install dependencies: `pip install -r requirements.txt` 6 | or, if you use the portable install, run this in the ComfyUI_windows_portable folder: 7 | 8 | `python_embeded\python.exe -m pip install -r ComfyUI\custom_nodes\ComfyUI-LuminaWrapper\requirements.txt` 9 | 10 | ## Note: Sampling is slow without `flash_attn`! 11 | 12 | For Linux users this is simply a matter of `pip install flash_attn`. 13 | 14 | However, doing the same on Windows will most likely fail if you do not have a build environment set up, and even if you do, the build can take an hour. 15 | An alternative for Windows is to use pre-built wheels (they have to match your Python environment) from here: 16 | https://github.com/bdashore3/flash-attention/releases 17 | 18 | If flash_attn is not installed, the attention code falls back to torch SDP attention, which is at least twice as slow and more memory hungry. 19 | 20 | ## Text encoder setup 21 | 22 | Lumina-Next uses Google's Gemma-2b LLM: https://huggingface.co/google/gemma-2b 23 | To download it you need to consent to their terms. This means having a Hugging Face account and requesting access (access is granted instantly once you do). 24 | 25 | Either download it yourself to `ComfyUI/models/LLM/gemma-2b` (the gguf file is not needed) or let the node download it automatically. 26 | 27 | ## Lumina models 28 | 29 | The nodes support the Lumina-Next text-to-image models: 30 | 31 | https://huggingface.co/Alpha-VLLM/Lumina-Next-SFT 32 | 33 | https://huggingface.co/Alpha-VLLM/Lumina-Next-T2I 34 | 35 | They are automatically downloaded to `ComfyUI/models/lumina`. 36 | 37 | # Examples 38 | The workflows are included in the examples folder. 39 | ![image](https://github.com/kijai/ComfyUI-LuminaWrapper/assets/40791699/d1efae46-590a-441e-ad42-9590062b3837) 40 | 41 | ![lumina_composition_example](https://github.com/kijai/ComfyUI-LuminaWrapper/assets/40791699/99603330-903a-444f-a23f-3ac0f332e73e) 42 | 43 | ![lumina_i2i_example](https://github.com/kijai/ComfyUI-LuminaWrapper/assets/40791699/680c032e-b700-4ec4-9484-977710228043) 44 | 45 | 46 | Original repo: 47 | 48 | https://github.com/Alpha-VLLM/Lumina-T2X 49 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] -------------------------------------------------------------------------------- /examples/lumina_next_i2i_example_01.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 33, 3 | "last_link_id": 58, 4 | "nodes": [ 5 | { 6 | "id": 15, 7 | "type": "DownloadAndLoadGemmaModel", 8 | "pos": [ 9 | -317, 10 | 271 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 82 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "gemma_model", 22 | "type": "GEMMAODEL", 23 | "links": [ 24 | 25 25 | ], 26 | "shape": 3 27 | } 28 | ], 29 | "properties": { 30 | "Node name for S&R": "DownloadAndLoadGemmaModel" 31 | }, 32 | "widgets_values": [ 33 | "bf16" 34 | ] 35 | }, 36 | { 37 | "id": 2, 38 | "type": "DownloadAndLoadLuminaModel", 39 | "pos": [ 40 | -317, 41 | 140 42 | ], 43 | "size": { 44 | "0": 303.01300048828125, 45 | "1": 82 46 | }, 47 | "flags": {}, 48 | "order": 1, 49 | "mode": 0, 50 | "outputs": [ 51 | { 52 | "name": "lumina_model", 53 | "type": "LUMINAMODEL", 54 | "links": [ 55 | 28 56 | ], 57 | "shape": 3 58 | } 59 | ],
60 | "properties": { 61 | "Node name for S&R": "DownloadAndLoadLuminaModel" 62 | }, 63 | "widgets_values": [ 64 | "Alpha-VLLM/Lumina-Next-SFT", 65 | "bf16" 66 | ] 67 | }, 68 | { 69 | "id": 10, 70 | "type": "VAELoader", 71 | "pos": [ 72 | -319, 73 | 36 74 | ], 75 | "size": { 76 | "0": 315, 77 | "1": 58 78 | }, 79 | "flags": {}, 80 | "order": 2, 81 | "mode": 0, 82 | "outputs": [ 83 | { 84 | "name": "VAE", 85 | "type": "VAE", 86 | "links": [ 87 | 53 88 | ], 89 | "shape": 3, 90 | "slot_index": 0 91 | } 92 | ], 93 | "properties": { 94 | "Node name for S&R": "VAELoader" 95 | }, 96 | "widgets_values": [ 97 | "sdxl_vae.safetensors" 98 | ] 99 | }, 100 | { 101 | "id": 29, 102 | "type": "GetNode", 103 | "pos": [ 104 | 945, 105 | 188 106 | ], 107 | "size": { 108 | "0": 210, 109 | "1": 58 110 | }, 111 | "flags": { 112 | "collapsed": true 113 | }, 114 | "order": 3, 115 | "mode": 0, 116 | "outputs": [ 117 | { 118 | "name": "VAE", 119 | "type": "VAE", 120 | "links": [ 121 | 54 122 | ] 123 | } 124 | ], 125 | "title": "Get_VAE", 126 | "properties": {}, 127 | "widgets_values": [ 128 | "VAE" 129 | ], 130 | "color": "#322", 131 | "bgcolor": "#533" 132 | }, 133 | { 134 | "id": 21, 135 | "type": "VAEEncode", 136 | "pos": [ 137 | -322, 138 | 434 139 | ], 140 | "size": { 141 | "0": 210, 142 | "1": 46 143 | }, 144 | "flags": { 145 | "collapsed": true 146 | }, 147 | "order": 8, 148 | "mode": 0, 149 | "inputs": [ 150 | { 151 | "name": "pixels", 152 | "type": "IMAGE", 153 | "link": 48 154 | }, 155 | { 156 | "name": "vae", 157 | "type": "VAE", 158 | "link": 55, 159 | "slot_index": 1 160 | } 161 | ], 162 | "outputs": [ 163 | { 164 | "name": "LATENT", 165 | "type": "LATENT", 166 | "links": [ 167 | 50 168 | ], 169 | "shape": 3, 170 | "slot_index": 0 171 | } 172 | ], 173 | "properties": { 174 | "Node name for S&R": "VAEEncode" 175 | } 176 | }, 177 | { 178 | "id": 32, 179 | "type": "GetNode", 180 | "pos": [ 181 | -469, 182 | 432 183 | ], 184 | "size": { 185 | "0": 210, 186 | "1": 58 187 | }, 188 | "flags": { 189 | "collapsed": true 190 | }, 191 | "order": 4, 192 | "mode": 0, 193 | "outputs": [ 194 | { 195 | "name": "VAE", 196 | "type": "VAE", 197 | "links": [ 198 | 55 199 | ] 200 | } 201 | ], 202 | "title": "Get_VAE", 203 | "properties": {}, 204 | "widgets_values": [ 205 | "VAE" 206 | ], 207 | "color": "#322", 208 | "bgcolor": "#533" 209 | }, 210 | { 211 | "id": 20, 212 | "type": "LoadImage", 213 | "pos": [ 214 | -501, 215 | 548 216 | ], 217 | "size": [ 218 | 310.7085632324219, 219 | 378.0042419433594 220 | ], 221 | "flags": {}, 222 | "order": 5, 223 | "mode": 0, 224 | "outputs": [ 225 | { 226 | "name": "IMAGE", 227 | "type": "IMAGE", 228 | "links": [ 229 | 47 230 | ], 231 | "shape": 3, 232 | "slot_index": 0 233 | }, 234 | { 235 | "name": "MASK", 236 | "type": "MASK", 237 | "links": null, 238 | "shape": 3 239 | } 240 | ], 241 | "properties": { 242 | "Node name for S&R": "LoadImage" 243 | }, 244 | "widgets_values": [ 245 | "ComfyUI_temp_goygz_00446_ (10).png", 246 | "image" 247 | ] 248 | }, 249 | { 250 | "id": 13, 251 | "type": "LuminaGemmaTextEncode", 252 | "pos": [ 253 | 160, 254 | 197 255 | ], 256 | "size": { 257 | "0": 400, 258 | "1": 200 259 | }, 260 | "flags": {}, 261 | "order": 11, 262 | "mode": 0, 263 | "inputs": [ 264 | { 265 | "name": "gemma_model", 266 | "type": "GEMMAODEL", 267 | "link": 25, 268 | "slot_index": 0 269 | }, 270 | { 271 | "name": "latent", 272 | "type": "LATENT", 273 | "link": 52 274 | } 275 | ], 276 | "outputs": [ 277 | { 278 | "name": "lumina_embeds", 279 | "type": "LUMINATEMBED", 280 | "links": [ 
281 | 29 282 | ], 283 | "shape": 3, 284 | "slot_index": 0 285 | } 286 | ], 287 | "properties": { 288 | "Node name for S&R": "LuminaGemmaTextEncode" 289 | }, 290 | "widgets_values": [ 291 | "high quality photograph of a mechanical robot toad", 292 | "bad quality, nsfw", 293 | true 294 | ] 295 | }, 296 | { 297 | "id": 27, 298 | "type": "VHS_DuplicateLatents", 299 | "pos": [ 300 | -163, 301 | 407 302 | ], 303 | "size": [ 304 | 299.2461547851558, 305 | 78 306 | ], 307 | "flags": {}, 308 | "order": 10, 309 | "mode": 0, 310 | "inputs": [ 311 | { 312 | "name": "latents", 313 | "type": "LATENT", 314 | "link": 50 315 | } 316 | ], 317 | "outputs": [ 318 | { 319 | "name": "LATENT", 320 | "type": "LATENT", 321 | "links": [ 322 | 52, 323 | 56 324 | ], 325 | "shape": 3, 326 | "slot_index": 0 327 | }, 328 | { 329 | "name": "count", 330 | "type": "INT", 331 | "links": null, 332 | "shape": 3 333 | } 334 | ], 335 | "properties": { 336 | "Node name for S&R": "VHS_DuplicateLatents" 337 | }, 338 | "widgets_values": { 339 | "multiply_by": 4 340 | } 341 | }, 342 | { 343 | "id": 17, 344 | "type": "LuminaT2ISampler", 345 | "pos": [ 346 | 600, 347 | 121 348 | ], 349 | "size": { 350 | "0": 315, 351 | "1": 338 352 | }, 353 | "flags": {}, 354 | "order": 13, 355 | "mode": 0, 356 | "inputs": [ 357 | { 358 | "name": "lumina_model", 359 | "type": "LUMINAMODEL", 360 | "link": 28 361 | }, 362 | { 363 | "name": "lumina_embeds", 364 | "type": "LUMINATEMBED", 365 | "link": 29 366 | }, 367 | { 368 | "name": "latent", 369 | "type": "LATENT", 370 | "link": 58 371 | } 372 | ], 373 | "outputs": [ 374 | { 375 | "name": "samples", 376 | "type": "LATENT", 377 | "links": [ 378 | 33 379 | ], 380 | "shape": 3, 381 | "slot_index": 0 382 | } 383 | ], 384 | "properties": { 385 | "Node name for S&R": "LuminaT2ISampler" 386 | }, 387 | "widgets_values": [ 388 | 143, 389 | "fixed", 390 | 25, 391 | 4, 392 | false, 393 | false, 394 | 0.3, 395 | 4, 396 | "midpoint", 397 | true, 398 | 0.5 399 | ] 400 | }, 401 | { 402 | "id": 33, 403 | "type": "Reroute", 404 | "pos": [ 405 | 460, 406 | 410 407 | ], 408 | "size": [ 409 | 90.4, 410 | 26 411 | ], 412 | "flags": {}, 413 | "order": 12, 414 | "mode": 0, 415 | "inputs": [ 416 | { 417 | "name": "", 418 | "type": "*", 419 | "link": 56 420 | } 421 | ], 422 | "outputs": [ 423 | { 424 | "name": "LATENT", 425 | "type": "LATENT", 426 | "links": [ 427 | 58 428 | ], 429 | "slot_index": 0 430 | } 431 | ], 432 | "properties": { 433 | "showOutputText": true, 434 | "horizontal": false 435 | } 436 | }, 437 | { 438 | "id": 28, 439 | "type": "SetNode", 440 | "pos": [ 441 | 17, 442 | 67 443 | ], 444 | "size": { 445 | "0": 210, 446 | "1": 58 447 | }, 448 | "flags": { 449 | "collapsed": true 450 | }, 451 | "order": 6, 452 | "mode": 0, 453 | "inputs": [ 454 | { 455 | "name": "VAE", 456 | "type": "VAE", 457 | "link": 53 458 | } 459 | ], 460 | "outputs": [ 461 | { 462 | "name": "*", 463 | "type": "*", 464 | "links": null 465 | } 466 | ], 467 | "title": "Set_VAE", 468 | "properties": { 469 | "previousName": "VAE" 470 | }, 471 | "widgets_values": [ 472 | "VAE" 473 | ], 474 | "color": "#322", 475 | "bgcolor": "#533" 476 | }, 477 | { 478 | "id": 26, 479 | "type": "PreviewImage", 480 | "pos": [ 481 | 452, 482 | 515 483 | ], 484 | "size": [ 485 | 447.5752929687501, 486 | 467.20423812866215 487 | ], 488 | "flags": {}, 489 | "order": 9, 490 | "mode": 0, 491 | "inputs": [ 492 | { 493 | "name": "images", 494 | "type": "IMAGE", 495 | "link": 49 496 | } 497 | ], 498 | "title": "InputImage", 499 | "properties": { 500 | "Node name for S&R": 
"PreviewImage" 501 | } 502 | }, 503 | { 504 | "id": 9, 505 | "type": "VAEDecode", 506 | "pos": [ 507 | 1080, 508 | 149 509 | ], 510 | "size": { 511 | "0": 210, 512 | "1": 46 513 | }, 514 | "flags": { 515 | "collapsed": true 516 | }, 517 | "order": 14, 518 | "mode": 0, 519 | "inputs": [ 520 | { 521 | "name": "samples", 522 | "type": "LATENT", 523 | "link": 33 524 | }, 525 | { 526 | "name": "vae", 527 | "type": "VAE", 528 | "link": 54, 529 | "slot_index": 1 530 | } 531 | ], 532 | "outputs": [ 533 | { 534 | "name": "IMAGE", 535 | "type": "IMAGE", 536 | "links": [ 537 | 14 538 | ], 539 | "shape": 3, 540 | "slot_index": 0 541 | } 542 | ], 543 | "properties": { 544 | "Node name for S&R": "VAEDecode" 545 | } 546 | }, 547 | { 548 | "id": 11, 549 | "type": "PreviewImage", 550 | "pos": [ 551 | 938, 552 | 242 553 | ], 554 | "size": [ 555 | 700.5826416015625, 556 | 729.5712127685547 557 | ], 558 | "flags": {}, 559 | "order": 15, 560 | "mode": 0, 561 | "inputs": [ 562 | { 563 | "name": "images", 564 | "type": "IMAGE", 565 | "link": 14 566 | } 567 | ], 568 | "properties": { 569 | "Node name for S&R": "PreviewImage" 570 | } 571 | }, 572 | { 573 | "id": 25, 574 | "type": "ImageResizeKJ", 575 | "pos": [ 576 | -177, 577 | 560 578 | ], 579 | "size": [ 580 | 315, 581 | 242 582 | ], 583 | "flags": {}, 584 | "order": 7, 585 | "mode": 0, 586 | "inputs": [ 587 | { 588 | "name": "image", 589 | "type": "IMAGE", 590 | "link": 47 591 | }, 592 | { 593 | "name": "get_image_size", 594 | "type": "IMAGE", 595 | "link": null 596 | }, 597 | { 598 | "name": "width_input", 599 | "type": "INT", 600 | "link": null, 601 | "widget": { 602 | "name": "width_input" 603 | } 604 | }, 605 | { 606 | "name": "height_input", 607 | "type": "INT", 608 | "link": null, 609 | "widget": { 610 | "name": "height_input" 611 | } 612 | } 613 | ], 614 | "outputs": [ 615 | { 616 | "name": "IMAGE", 617 | "type": "IMAGE", 618 | "links": [ 619 | 48, 620 | 49 621 | ], 622 | "shape": 3, 623 | "slot_index": 0 624 | }, 625 | { 626 | "name": "width", 627 | "type": "INT", 628 | "links": null, 629 | "shape": 3 630 | }, 631 | { 632 | "name": "height", 633 | "type": "INT", 634 | "links": null, 635 | "shape": 3, 636 | "slot_index": 2 637 | } 638 | ], 639 | "properties": { 640 | "Node name for S&R": "ImageResizeKJ" 641 | }, 642 | "widgets_values": [ 643 | 1024, 644 | 1024, 645 | "lanczos", 646 | true, 647 | 64, 648 | 0, 649 | 0 650 | ] 651 | } 652 | ], 653 | "links": [ 654 | [ 655 | 14, 656 | 9, 657 | 0, 658 | 11, 659 | 0, 660 | "IMAGE" 661 | ], 662 | [ 663 | 25, 664 | 15, 665 | 0, 666 | 13, 667 | 0, 668 | "GEMMAODEL" 669 | ], 670 | [ 671 | 28, 672 | 2, 673 | 0, 674 | 17, 675 | 0, 676 | "LUMINAMODEL" 677 | ], 678 | [ 679 | 29, 680 | 13, 681 | 0, 682 | 17, 683 | 1, 684 | "LUMINATEMBED" 685 | ], 686 | [ 687 | 33, 688 | 17, 689 | 0, 690 | 9, 691 | 0, 692 | "LATENT" 693 | ], 694 | [ 695 | 47, 696 | 20, 697 | 0, 698 | 25, 699 | 0, 700 | "IMAGE" 701 | ], 702 | [ 703 | 48, 704 | 25, 705 | 0, 706 | 21, 707 | 0, 708 | "IMAGE" 709 | ], 710 | [ 711 | 49, 712 | 25, 713 | 0, 714 | 26, 715 | 0, 716 | "IMAGE" 717 | ], 718 | [ 719 | 50, 720 | 21, 721 | 0, 722 | 27, 723 | 0, 724 | "LATENT" 725 | ], 726 | [ 727 | 52, 728 | 27, 729 | 0, 730 | 13, 731 | 1, 732 | "LATENT" 733 | ], 734 | [ 735 | 53, 736 | 10, 737 | 0, 738 | 28, 739 | 0, 740 | "*" 741 | ], 742 | [ 743 | 54, 744 | 29, 745 | 0, 746 | 9, 747 | 1, 748 | "VAE" 749 | ], 750 | [ 751 | 55, 752 | 32, 753 | 0, 754 | 21, 755 | 1, 756 | "VAE" 757 | ], 758 | [ 759 | 56, 760 | 27, 761 | 0, 762 | 33, 763 | 0, 764 | "*" 765 | ], 766 | [ 
767 | 58, 768 | 33, 769 | 0, 770 | 17, 771 | 2, 772 | "LATENT" 773 | ] 774 | ], 775 | "groups": [], 776 | "config": {}, 777 | "extra": { 778 | "ds": { 779 | "scale": 1, 780 | "offset": { 781 | "0": 603.9016723632812, 782 | "1": 37.74279022216797 783 | } 784 | } 785 | }, 786 | "version": 0.4 787 | } -------------------------------------------------------------------------------- /examples/lumina_next_t2i_composition_example_01.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 24, 3 | "last_link_id": 41, 4 | "nodes": [ 5 | { 6 | "id": 2, 7 | "type": "DownloadAndLoadLuminaModel", 8 | "pos": [ 9 | -311, 10 | 140 11 | ], 12 | "size": { 13 | "0": 303.01300048828125, 14 | "1": 82 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "lumina_model", 22 | "type": "LUMINAMODEL", 23 | "links": [ 24 | 28 25 | ], 26 | "shape": 3 27 | } 28 | ], 29 | "properties": { 30 | "Node name for S&R": "DownloadAndLoadLuminaModel" 31 | }, 32 | "widgets_values": [ 33 | "Alpha-VLLM/Lumina-Next-SFT", 34 | "bf16" 35 | ] 36 | }, 37 | { 38 | "id": 15, 39 | "type": "DownloadAndLoadGemmaModel", 40 | "pos": [ 41 | -317, 42 | 271 43 | ], 44 | "size": { 45 | "0": 315, 46 | "1": 82 47 | }, 48 | "flags": {}, 49 | "order": 1, 50 | "mode": 0, 51 | "outputs": [ 52 | { 53 | "name": "gemma_model", 54 | "type": "GEMMAODEL", 55 | "links": [ 56 | 35 57 | ], 58 | "shape": 3 59 | } 60 | ], 61 | "properties": { 62 | "Node name for S&R": "DownloadAndLoadGemmaModel" 63 | }, 64 | "widgets_values": [ 65 | "bf16" 66 | ] 67 | }, 68 | { 69 | "id": 9, 70 | "type": "VAEDecode", 71 | "pos": [ 72 | 830, 73 | 149 74 | ], 75 | "size": { 76 | "0": 210, 77 | "1": 46 78 | }, 79 | "flags": {}, 80 | "order": 11, 81 | "mode": 0, 82 | "inputs": [ 83 | { 84 | "name": "samples", 85 | "type": "LATENT", 86 | "link": 33 87 | }, 88 | { 89 | "name": "vae", 90 | "type": "VAE", 91 | "link": 13, 92 | "slot_index": 1 93 | } 94 | ], 95 | "outputs": [ 96 | { 97 | "name": "IMAGE", 98 | "type": "IMAGE", 99 | "links": [ 100 | 14 101 | ], 102 | "shape": 3, 103 | "slot_index": 0 104 | } 105 | ], 106 | "properties": { 107 | "Node name for S&R": "VAEDecode" 108 | } 109 | }, 110 | { 111 | "id": 10, 112 | "type": "VAELoader", 113 | "pos": [ 114 | 817, 115 | 241 116 | ], 117 | "size": { 118 | "0": 315, 119 | "1": 58 120 | }, 121 | "flags": { 122 | "collapsed": false 123 | }, 124 | "order": 2, 125 | "mode": 0, 126 | "outputs": [ 127 | { 128 | "name": "VAE", 129 | "type": "VAE", 130 | "links": [ 131 | 13 132 | ], 133 | "shape": 3 134 | } 135 | ], 136 | "properties": { 137 | "Node name for S&R": "VAELoader" 138 | }, 139 | "widgets_values": [ 140 | "sdxl_vae.safetensors" 141 | ] 142 | }, 143 | { 144 | "id": 6, 145 | "type": "EmptyLatentImage", 146 | "pos": [ 147 | -331, 148 | 436 149 | ], 150 | "size": { 151 | "0": 315, 152 | "1": 106 153 | }, 154 | "flags": {}, 155 | "order": 3, 156 | "mode": 0, 157 | "outputs": [ 158 | { 159 | "name": "LATENT", 160 | "type": "LATENT", 161 | "links": [ 162 | 31 163 | ], 164 | "shape": 3, 165 | "slot_index": 0 166 | } 167 | ], 168 | "properties": { 169 | "Node name for S&R": "EmptyLatentImage" 170 | }, 171 | "widgets_values": [ 172 | 2048, 173 | 512, 174 | 1 175 | ] 176 | }, 177 | { 178 | "id": 17, 179 | "type": "LuminaT2ISampler", 180 | "pos": [ 181 | 488, 182 | 143 183 | ], 184 | "size": { 185 | "0": 315, 186 | "1": 314 187 | }, 188 | "flags": {}, 189 | "order": 10, 190 | "mode": 0, 191 | "inputs": [ 192 | { 193 | "name": "lumina_model", 194 | 
"type": "LUMINAMODEL", 195 | "link": 28 196 | }, 197 | { 198 | "name": "lumina_embeds", 199 | "type": "LUMINATEMBED", 200 | "link": 34 201 | }, 202 | { 203 | "name": "latent", 204 | "type": "LATENT", 205 | "link": 32 206 | } 207 | ], 208 | "outputs": [ 209 | { 210 | "name": "samples", 211 | "type": "LATENT", 212 | "links": [ 213 | 33 214 | ], 215 | "shape": 3, 216 | "slot_index": 0 217 | } 218 | ], 219 | "properties": { 220 | "Node name for S&R": "LuminaT2ISampler" 221 | }, 222 | "widgets_values": [ 223 | 127, 224 | "fixed", 225 | 30, 226 | 5, 227 | false, 228 | false, 229 | 0.3, 230 | 6, 231 | "midpoint", 232 | true 233 | ] 234 | }, 235 | { 236 | "id": 11, 237 | "type": "PreviewImage", 238 | "pos": [ 239 | -300, 240 | 784 241 | ], 242 | "size": [ 243 | 1392.8533270941648, 244 | 376.1636668147139 245 | ], 246 | "flags": {}, 247 | "order": 12, 248 | "mode": 0, 249 | "inputs": [ 250 | { 251 | "name": "images", 252 | "type": "IMAGE", 253 | "link": 14 254 | } 255 | ], 256 | "properties": { 257 | "Node name for S&R": "PreviewImage" 258 | } 259 | }, 260 | { 261 | "id": 21, 262 | "type": "LuminaTextAreaAppend", 263 | "pos": [ 264 | -247, 265 | 597 266 | ], 267 | "size": [ 268 | 274.01493615024333, 269 | 144.12341662973722 270 | ], 271 | "flags": {}, 272 | "order": 4, 273 | "mode": 0, 274 | "inputs": [ 275 | { 276 | "name": "prev_prompt", 277 | "type": "LUMINAAREAPROMPT", 278 | "link": null 279 | } 280 | ], 281 | "outputs": [ 282 | { 283 | "name": "lumina_area_prompt", 284 | "type": "LUMINAAREAPROMPT", 285 | "links": [ 286 | 37 287 | ], 288 | "shape": 3, 289 | "slot_index": 0 290 | } 291 | ], 292 | "properties": { 293 | "Node name for S&R": "LuminaTextAreaAppend" 294 | }, 295 | "widgets_values": [ 296 | "waterfall", 297 | 1, 298 | 1 299 | ] 300 | }, 301 | { 302 | "id": 22, 303 | "type": "LuminaTextAreaAppend", 304 | "pos": [ 305 | 98, 306 | 598 307 | ], 308 | "size": [ 309 | 253.60000610351562, 310 | 143.32892002240646 311 | ], 312 | "flags": {}, 313 | "order": 6, 314 | "mode": 0, 315 | "inputs": [ 316 | { 317 | "name": "prev_prompt", 318 | "type": "LUMINAAREAPROMPT", 319 | "link": 37 320 | } 321 | ], 322 | "outputs": [ 323 | { 324 | "name": "lumina_area_prompt", 325 | "type": "LUMINAAREAPROMPT", 326 | "links": [ 327 | 38 328 | ], 329 | "shape": 3, 330 | "slot_index": 0 331 | } 332 | ], 333 | "properties": { 334 | "Node name for S&R": "LuminaTextAreaAppend" 335 | }, 336 | "widgets_values": [ 337 | "river", 338 | 1, 339 | 2 340 | ] 341 | }, 342 | { 343 | "id": 24, 344 | "type": "LuminaTextAreaAppend", 345 | "pos": [ 346 | 739, 347 | 590 348 | ], 349 | "size": { 350 | "0": 253.60000610351562, 351 | "1": 145.5343475341797 352 | }, 353 | "flags": {}, 354 | "order": 8, 355 | "mode": 0, 356 | "inputs": [ 357 | { 358 | "name": "prev_prompt", 359 | "type": "LUMINAAREAPROMPT", 360 | "link": 40 361 | } 362 | ], 363 | "outputs": [ 364 | { 365 | "name": "lumina_area_prompt", 366 | "type": "LUMINAAREAPROMPT", 367 | "links": [ 368 | 41 369 | ], 370 | "shape": 3, 371 | "slot_index": 0 372 | } 373 | ], 374 | "properties": { 375 | "Node name for S&R": "LuminaTextAreaAppend" 376 | }, 377 | "widgets_values": [ 378 | "red maple trees", 379 | 1, 380 | 4 381 | ] 382 | }, 383 | { 384 | "id": 23, 385 | "type": "LuminaTextAreaAppend", 386 | "pos": [ 387 | 411, 388 | 594 389 | ], 390 | "size": [ 391 | 253.60000610351562, 392 | 145.53435067229907 393 | ], 394 | "flags": {}, 395 | "order": 7, 396 | "mode": 0, 397 | "inputs": [ 398 | { 399 | "name": "prev_prompt", 400 | "type": "LUMINAAREAPROMPT", 401 | "link": 38 402 | } 403 
| ], 404 | "outputs": [ 405 | { 406 | "name": "lumina_area_prompt", 407 | "type": "LUMINAAREAPROMPT", 408 | "links": [ 409 | 40 410 | ], 411 | "shape": 3, 412 | "slot_index": 0 413 | } 414 | ], 415 | "properties": { 416 | "Node name for S&R": "LuminaTextAreaAppend" 417 | }, 418 | "widgets_values": [ 419 | "majestic stag", 420 | 1, 421 | 3 422 | ] 423 | }, 424 | { 425 | "id": 20, 426 | "type": "LuminaGemmaTextEncodeArea", 427 | "pos": [ 428 | 43, 429 | 222 430 | ], 431 | "size": { 432 | "0": 405.5999755859375, 433 | "1": 200 434 | }, 435 | "flags": {}, 436 | "order": 9, 437 | "mode": 0, 438 | "inputs": [ 439 | { 440 | "name": "gemma_model", 441 | "type": "GEMMAODEL", 442 | "link": 35, 443 | "slot_index": 0 444 | }, 445 | { 446 | "name": "lumina_area_prompt", 447 | "type": "LUMINAAREAPROMPT", 448 | "link": 41, 449 | "slot_index": 1 450 | } 451 | ], 452 | "outputs": [ 453 | { 454 | "name": "lumina_embeds", 455 | "type": "LUMINATEMBED", 456 | "links": [ 457 | 34 458 | ], 459 | "shape": 3, 460 | "slot_index": 0 461 | } 462 | ], 463 | "properties": { 464 | "Node name for S&R": "LuminaGemmaTextEncodeArea" 465 | }, 466 | "widgets_values": [ 467 | "highly detailed high quality digital fantasy art illustration with cinematic lighting", 468 | "bad quality, unaesthetic, nsfw", 469 | true 470 | ] 471 | }, 472 | { 473 | "id": 19, 474 | "type": "Reroute", 475 | "pos": [ 476 | 342, 477 | 437 478 | ], 479 | "size": [ 480 | 90.4, 481 | 26 482 | ], 483 | "flags": {}, 484 | "order": 5, 485 | "mode": 0, 486 | "inputs": [ 487 | { 488 | "name": "", 489 | "type": "*", 490 | "link": 31 491 | } 492 | ], 493 | "outputs": [ 494 | { 495 | "name": "LATENT", 496 | "type": "LATENT", 497 | "links": [ 498 | 32 499 | ], 500 | "slot_index": 0 501 | } 502 | ], 503 | "properties": { 504 | "showOutputText": true, 505 | "horizontal": false 506 | } 507 | } 508 | ], 509 | "links": [ 510 | [ 511 | 13, 512 | 10, 513 | 0, 514 | 9, 515 | 1, 516 | "VAE" 517 | ], 518 | [ 519 | 14, 520 | 9, 521 | 0, 522 | 11, 523 | 0, 524 | "IMAGE" 525 | ], 526 | [ 527 | 28, 528 | 2, 529 | 0, 530 | 17, 531 | 0, 532 | "LUMINAMODEL" 533 | ], 534 | [ 535 | 31, 536 | 6, 537 | 0, 538 | 19, 539 | 0, 540 | "*" 541 | ], 542 | [ 543 | 32, 544 | 19, 545 | 0, 546 | 17, 547 | 2, 548 | "LATENT" 549 | ], 550 | [ 551 | 33, 552 | 17, 553 | 0, 554 | 9, 555 | 0, 556 | "LATENT" 557 | ], 558 | [ 559 | 34, 560 | 20, 561 | 0, 562 | 17, 563 | 1, 564 | "LUMINATEMBED" 565 | ], 566 | [ 567 | 35, 568 | 15, 569 | 0, 570 | 20, 571 | 0, 572 | "GEMMAODEL" 573 | ], 574 | [ 575 | 37, 576 | 21, 577 | 0, 578 | 22, 579 | 0, 580 | "LUMINAAREAPROMPT" 581 | ], 582 | [ 583 | 38, 584 | 22, 585 | 0, 586 | 23, 587 | 0, 588 | "LUMINAAREAPROMPT" 589 | ], 590 | [ 591 | 40, 592 | 23, 593 | 0, 594 | 24, 595 | 0, 596 | "LUMINAAREAPROMPT" 597 | ], 598 | [ 599 | 41, 600 | 24, 601 | 0, 602 | 20, 603 | 1, 604 | "LUMINAAREAPROMPT" 605 | ] 606 | ], 607 | "groups": [], 608 | "config": {}, 609 | "extra": { 610 | "ds": { 611 | "scale": 1.015255979947714, 612 | "offset": [ 613 | 630.0291517726646, 614 | -79.89796117514369 615 | ] 616 | } 617 | }, 618 | "version": 0.4 619 | } -------------------------------------------------------------------------------- /examples/lumina_next_t2i_example_01.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 19, 3 | "last_link_id": 33, 4 | "nodes": [ 5 | { 6 | "id": 2, 7 | "type": "DownloadAndLoadLuminaModel", 8 | "pos": [ 9 | -311, 10 | 140 11 | ], 12 | "size": { 13 | "0": 303.01300048828125, 14 | "1": 82 15 | }, 16 | 
"flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "lumina_model", 22 | "type": "LUMINAMODEL", 23 | "links": [ 24 | 28 25 | ], 26 | "shape": 3 27 | } 28 | ], 29 | "properties": { 30 | "Node name for S&R": "DownloadAndLoadLuminaModel" 31 | }, 32 | "widgets_values": [ 33 | "Alpha-VLLM/Lumina-Next-SFT", 34 | "bf16" 35 | ] 36 | }, 37 | { 38 | "id": 13, 39 | "type": "LuminaGemmaTextEncode", 40 | "pos": [ 41 | 48, 42 | 227 43 | ], 44 | "size": { 45 | "0": 400, 46 | "1": 200 47 | }, 48 | "flags": {}, 49 | "order": 5, 50 | "mode": 0, 51 | "inputs": [ 52 | { 53 | "name": "gemma_model", 54 | "type": "GEMMAODEL", 55 | "link": 25, 56 | "slot_index": 0 57 | }, 58 | { 59 | "name": "latent", 60 | "type": "LATENT", 61 | "link": 20 62 | } 63 | ], 64 | "outputs": [ 65 | { 66 | "name": "lumina_embeds", 67 | "type": "LUMINATEMBED", 68 | "links": [ 69 | 29 70 | ], 71 | "shape": 3, 72 | "slot_index": 0 73 | } 74 | ], 75 | "properties": { 76 | "Node name for S&R": "LuminaGemmaTextEncode" 77 | }, 78 | "widgets_values": [ 79 | "high quality photograph of a woman laying on grass, she's wearing a blue dress, top down view, her hair is on fire", 80 | "bad quality, drawing, illustration, nsfw", 81 | false 82 | ] 83 | }, 84 | { 85 | "id": 19, 86 | "type": "Reroute", 87 | "pos": [ 88 | 361, 89 | 444 90 | ], 91 | "size": [ 92 | 90.4, 93 | 26 94 | ], 95 | "flags": {}, 96 | "order": 4, 97 | "mode": 0, 98 | "inputs": [ 99 | { 100 | "name": "", 101 | "type": "*", 102 | "link": 31 103 | } 104 | ], 105 | "outputs": [ 106 | { 107 | "name": "LATENT", 108 | "type": "LATENT", 109 | "links": [ 110 | 32 111 | ], 112 | "slot_index": 0 113 | } 114 | ], 115 | "properties": { 116 | "showOutputText": true, 117 | "horizontal": false 118 | } 119 | }, 120 | { 121 | "id": 6, 122 | "type": "EmptyLatentImage", 123 | "pos": [ 124 | -315, 125 | 440 126 | ], 127 | "size": { 128 | "0": 315, 129 | "1": 106 130 | }, 131 | "flags": {}, 132 | "order": 1, 133 | "mode": 0, 134 | "outputs": [ 135 | { 136 | "name": "LATENT", 137 | "type": "LATENT", 138 | "links": [ 139 | 20, 140 | 31 141 | ], 142 | "shape": 3, 143 | "slot_index": 0 144 | } 145 | ], 146 | "properties": { 147 | "Node name for S&R": "EmptyLatentImage" 148 | }, 149 | "widgets_values": [ 150 | 1024, 151 | 1536, 152 | 1 153 | ] 154 | }, 155 | { 156 | "id": 17, 157 | "type": "LuminaT2ISampler", 158 | "pos": [ 159 | 488, 160 | 143 161 | ], 162 | "size": { 163 | "0": 315, 164 | "1": 314 165 | }, 166 | "flags": {}, 167 | "order": 6, 168 | "mode": 0, 169 | "inputs": [ 170 | { 171 | "name": "lumina_model", 172 | "type": "LUMINAMODEL", 173 | "link": 28 174 | }, 175 | { 176 | "name": "lumina_embeds", 177 | "type": "LUMINATEMBED", 178 | "link": 29 179 | }, 180 | { 181 | "name": "latent", 182 | "type": "LATENT", 183 | "link": 32 184 | } 185 | ], 186 | "outputs": [ 187 | { 188 | "name": "samples", 189 | "type": "LATENT", 190 | "links": [ 191 | 33 192 | ], 193 | "shape": 3, 194 | "slot_index": 0 195 | } 196 | ], 197 | "properties": { 198 | "Node name for S&R": "LuminaT2ISampler" 199 | }, 200 | "widgets_values": [ 201 | 123, 202 | "fixed", 203 | 25, 204 | 4, 205 | false, 206 | false, 207 | 0.3, 208 | 4, 209 | "midpoint", 210 | false 211 | ] 212 | }, 213 | { 214 | "id": 9, 215 | "type": "VAEDecode", 216 | "pos": [ 217 | 834, 218 | 139 219 | ], 220 | "size": { 221 | "0": 210, 222 | "1": 46 223 | }, 224 | "flags": {}, 225 | "order": 7, 226 | "mode": 0, 227 | "inputs": [ 228 | { 229 | "name": "samples", 230 | "type": "LATENT", 231 | "link": 33 232 | }, 233 | { 234 | "name": 
"vae", 235 | "type": "VAE", 236 | "link": 13, 237 | "slot_index": 1 238 | } 239 | ], 240 | "outputs": [ 241 | { 242 | "name": "IMAGE", 243 | "type": "IMAGE", 244 | "links": [ 245 | 14 246 | ], 247 | "shape": 3, 248 | "slot_index": 0 249 | } 250 | ], 251 | "properties": { 252 | "Node name for S&R": "VAEDecode" 253 | } 254 | }, 255 | { 256 | "id": 10, 257 | "type": "VAELoader", 258 | "pos": [ 259 | 488, 260 | 485 261 | ], 262 | "size": { 263 | "0": 315, 264 | "1": 58 265 | }, 266 | "flags": {}, 267 | "order": 2, 268 | "mode": 0, 269 | "outputs": [ 270 | { 271 | "name": "VAE", 272 | "type": "VAE", 273 | "links": [ 274 | 13 275 | ], 276 | "shape": 3 277 | } 278 | ], 279 | "properties": { 280 | "Node name for S&R": "VAELoader" 281 | }, 282 | "widgets_values": [ 283 | "sdxl_vae.safetensors" 284 | ] 285 | }, 286 | { 287 | "id": 11, 288 | "type": "PreviewImage", 289 | "pos": [ 290 | 839, 291 | 233 292 | ], 293 | "size": { 294 | "0": 459.1072998046875, 295 | "1": 714.3753662109375 296 | }, 297 | "flags": {}, 298 | "order": 8, 299 | "mode": 0, 300 | "inputs": [ 301 | { 302 | "name": "images", 303 | "type": "IMAGE", 304 | "link": 14 305 | } 306 | ], 307 | "properties": { 308 | "Node name for S&R": "PreviewImage" 309 | } 310 | }, 311 | { 312 | "id": 15, 313 | "type": "DownloadAndLoadGemmaModel", 314 | "pos": [ 315 | -317, 316 | 271 317 | ], 318 | "size": { 319 | "0": 315, 320 | "1": 82 321 | }, 322 | "flags": {}, 323 | "order": 3, 324 | "mode": 0, 325 | "outputs": [ 326 | { 327 | "name": "gemma_model", 328 | "type": "GEMMAODEL", 329 | "links": [ 330 | 25 331 | ], 332 | "shape": 3 333 | } 334 | ], 335 | "properties": { 336 | "Node name for S&R": "DownloadAndLoadGemmaModel" 337 | }, 338 | "widgets_values": [ 339 | "bf16", 340 | "text_encode" 341 | ] 342 | } 343 | ], 344 | "links": [ 345 | [ 346 | 13, 347 | 10, 348 | 0, 349 | 9, 350 | 1, 351 | "VAE" 352 | ], 353 | [ 354 | 14, 355 | 9, 356 | 0, 357 | 11, 358 | 0, 359 | "IMAGE" 360 | ], 361 | [ 362 | 20, 363 | 6, 364 | 0, 365 | 13, 366 | 1, 367 | "LATENT" 368 | ], 369 | [ 370 | 25, 371 | 15, 372 | 0, 373 | 13, 374 | 0, 375 | "GEMMAODEL" 376 | ], 377 | [ 378 | 28, 379 | 2, 380 | 0, 381 | 17, 382 | 0, 383 | "LUMINAMODEL" 384 | ], 385 | [ 386 | 29, 387 | 13, 388 | 0, 389 | 17, 390 | 1, 391 | "LUMINATEMBED" 392 | ], 393 | [ 394 | 31, 395 | 6, 396 | 0, 397 | 19, 398 | 0, 399 | "*" 400 | ], 401 | [ 402 | 32, 403 | 19, 404 | 0, 405 | 17, 406 | 2, 407 | "LATENT" 408 | ], 409 | [ 410 | 33, 411 | 17, 412 | 0, 413 | 9, 414 | 0, 415 | "LATENT" 416 | ] 417 | ], 418 | "groups": [], 419 | "config": {}, 420 | "extra": { 421 | "ds": { 422 | "scale": 1.1167815779424815, 423 | "offset": [ 424 | 519.1369343547321, 425 | 30.539049178911156 426 | ] 427 | } 428 | }, 429 | "version": 0.4 430 | } -------------------------------------------------------------------------------- /lumina_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .nextdit import NextDiT_2B_GQA_patch2, NextDiT_2B_patch2 2 | -------------------------------------------------------------------------------- /lumina_models/components.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | try: 8 | from apex.normalization import FusedRMSNorm as RMSNorm 9 | except ImportError: 10 | warnings.warn("Cannot import apex RMSNorm, trying flash_attn RMSNorm") 11 | try: 12 | from flash_attn.ops.triton.layer_norm import RMSNorm 13 | except: 14 | try: 15 | from 
flash_attn.ops.rms_norm import RMSNorm 16 | except ImportError: 17 | warnings.warn("Cannot import flash_attn RMSNorm, falling back to PyTorch RMSNorm") 18 | 19 | class RMSNorm(torch.nn.Module): 20 | def __init__(self, dim: int, eps: float = 1e-6): 21 | """ 22 | Initialize the RMSNorm normalization layer. 23 | 24 | Args: 25 | dim (int): The dimension of the input tensor. 26 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 27 | 28 | Attributes: 29 | eps (float): A small value added to the denominator for numerical stability. 30 | weight (nn.Parameter): Learnable scaling parameter. 31 | 32 | """ 33 | super().__init__() 34 | self.eps = eps 35 | self.weight = nn.Parameter(torch.ones(dim)) 36 | 37 | def _norm(self, x): 38 | """ 39 | Apply the RMSNorm normalization to the input tensor. 40 | 41 | Args: 42 | x (torch.Tensor): The input tensor. 43 | 44 | Returns: 45 | torch.Tensor: The normalized tensor. 46 | 47 | """ 48 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 49 | 50 | def forward(self, x): 51 | """ 52 | Forward pass through the RMSNorm layer. 53 | 54 | Args: 55 | x (torch.Tensor): The input tensor. 56 | 57 | Returns: 58 | torch.Tensor: The output tensor after applying RMSNorm. 59 | 60 | """ 61 | output = self._norm(x.float()).type_as(x) 62 | return output * self.weight 63 | -------------------------------------------------------------------------------- /lumina_models/nextdit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | 12 | import functools 13 | import math 14 | from typing import List, Optional, Tuple 15 | 16 | try: 17 | from flash_attn import flash_attn_varlen_func 18 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 19 | FLASH_ATTN_AVAILABLE = True 20 | except: 21 | FLASH_ATTN_AVAILABLE = False 22 | import torch 23 | import torch.nn as nn 24 | import torch.nn.functional as F 25 | 26 | from .components import RMSNorm 27 | 28 | import comfy.model_management 29 | import comfy.ops 30 | ops = comfy.ops.manual_cast 31 | 32 | device = comfy.model_management.get_torch_device() 33 | cast_device = comfy.model_management.get_autocast_device(device) 34 | 35 | def modulate(x, scale): 36 | return x * (1 + scale.unsqueeze(1)) 37 | 38 | 39 | ############################################################################# 40 | # Embedding Layers for Timesteps and Class Labels # 41 | ############################################################################# 42 | 43 | 44 | class TimestepEmbedder(nn.Module): 45 | """ 46 | Embeds scalar timesteps into vector representations. 
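    A timestep t is first expanded into `frequency_embedding_size` sinusoidal
    features (see `timestep_embedding`) and then projected to `hidden_size` by a
    two-layer MLP with a SiLU activation in between.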
47 | """ 48 | 49 | def __init__(self, hidden_size, frequency_embedding_size=256): 50 | super().__init__() 51 | self.mlp = nn.Sequential( 52 | ops.Linear( 53 | frequency_embedding_size, 54 | hidden_size, 55 | bias=True, 56 | ), 57 | nn.SiLU(), 58 | ops.Linear( 59 | hidden_size, 60 | hidden_size, 61 | bias=True, 62 | ), 63 | ) 64 | nn.init.normal_(self.mlp[0].weight, std=0.02) 65 | nn.init.zeros_(self.mlp[0].bias) 66 | nn.init.normal_(self.mlp[2].weight, std=0.02) 67 | nn.init.zeros_(self.mlp[2].bias) 68 | 69 | self.frequency_embedding_size = frequency_embedding_size 70 | 71 | @staticmethod 72 | def timestep_embedding(t, dim, max_period=10000): 73 | """ 74 | Create sinusoidal timestep embeddings. 75 | :param t: a 1-D Tensor of N indices, one per batch element. 76 | These may be fractional. 77 | :param dim: the dimension of the output. 78 | :param max_period: controls the minimum frequency of the embeddings. 79 | :return: an (N, D) Tensor of positional embeddings. 80 | """ 81 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 82 | half = dim // 2 83 | freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( 84 | device=t.device 85 | ) 86 | args = t[:, None].float() * freqs[None] 87 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 88 | if dim % 2: 89 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 90 | return embedding 91 | 92 | def forward(self, t): 93 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 94 | t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype)) 95 | return t_emb 96 | 97 | 98 | ############################################################################# 99 | # Core NextDiT Model # 100 | ############################################################################# 101 | 102 | 103 | class Attention(nn.Module): 104 | """Multi-head attention module.""" 105 | 106 | def __init__( 107 | self, 108 | dim: int, 109 | n_heads: int, 110 | n_kv_heads: Optional[int], 111 | qk_norm: bool, 112 | y_dim: int, 113 | ): 114 | """ 115 | Initialize the Attention module. 116 | 117 | Args: 118 | dim (int): Number of input dimensions. 119 | n_heads (int): Number of heads. 120 | n_kv_heads (Optional[int]): Number of kv heads, if using GQA. 
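            qk_norm (bool): Whether to apply LayerNorm to the query/key
                projections before attention.
            y_dim (int): Dimension of the text (caption) features used by the
                gated cross-attention branch; 0 disables that branch.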
121 | 122 | """ 123 | super().__init__() 124 | self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads 125 | self.n_local_heads = n_heads 126 | self.n_local_kv_heads = self.n_kv_heads 127 | self.n_rep = self.n_local_heads // self.n_local_kv_heads 128 | self.head_dim = dim // n_heads 129 | 130 | self.wq = ops.Linear( 131 | dim, 132 | n_heads * self.head_dim, 133 | bias=False, 134 | ) 135 | nn.init.xavier_uniform_(self.wq.weight) 136 | self.wk = ops.Linear( 137 | dim, 138 | self.n_kv_heads * self.head_dim, 139 | bias=False, 140 | ) 141 | nn.init.xavier_uniform_(self.wk.weight) 142 | self.wv = ops.Linear( 143 | dim, 144 | self.n_kv_heads * self.head_dim, 145 | bias=False, 146 | ) 147 | nn.init.xavier_uniform_(self.wv.weight) 148 | if y_dim > 0: 149 | self.wk_y = ops.Linear( 150 | y_dim, 151 | self.n_kv_heads * self.head_dim, 152 | bias=False, 153 | ) 154 | nn.init.xavier_uniform_(self.wk_y.weight) 155 | self.wv_y = ops.Linear( 156 | y_dim, 157 | self.n_kv_heads * self.head_dim, 158 | bias=False, 159 | ) 160 | nn.init.xavier_uniform_(self.wv_y.weight) 161 | self.gate = nn.Parameter(torch.zeros([self.n_local_heads])) 162 | 163 | self.wo = ops.Linear( 164 | n_heads * self.head_dim, 165 | dim, 166 | bias=False, 167 | ) 168 | nn.init.xavier_uniform_(self.wo.weight) 169 | 170 | if qk_norm: 171 | self.q_norm = ops.LayerNorm(self.n_local_heads * self.head_dim) 172 | self.k_norm = ops.LayerNorm(self.n_local_kv_heads * self.head_dim) 173 | if y_dim > 0: 174 | self.ky_norm = ops.LayerNorm(self.n_local_kv_heads * self.head_dim) 175 | else: 176 | self.ky_norm = nn.Identity() 177 | else: 178 | self.q_norm = self.k_norm = nn.Identity() 179 | self.ky_norm = nn.Identity() 180 | 181 | # for proportional attention computation 182 | self.base_seqlen = None 183 | self.proportional_attn = False 184 | 185 | @staticmethod 186 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 187 | """ 188 | Reshape frequency tensor for broadcasting it with another tensor. 189 | 190 | This function reshapes the frequency tensor to have the same shape as 191 | the target tensor 'x' for the purpose of broadcasting the frequency 192 | tensor during element-wise operations. 193 | 194 | Args: 195 | freqs_cis (torch.Tensor): Frequency tensor to be reshaped. 196 | x (torch.Tensor): Target tensor for broadcasting compatibility. 197 | 198 | Returns: 199 | torch.Tensor: Reshaped frequency tensor. 200 | 201 | Raises: 202 | AssertionError: If the frequency tensor doesn't match the expected 203 | shape. 204 | AssertionError: If the target tensor 'x' doesn't have the expected 205 | number of dimensions. 206 | """ 207 | ndim = x.ndim 208 | assert 0 <= 1 < ndim 209 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 210 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 211 | return freqs_cis.view(*shape) 212 | 213 | @staticmethod 214 | def apply_rotary_emb( 215 | x_in: torch.Tensor, 216 | freqs_cis: torch.Tensor, 217 | ) -> torch.Tensor: 218 | """ 219 | Apply rotary embeddings to input tensors using the given frequency 220 | tensor. 221 | 222 | This function applies rotary embeddings to the given query 'xq' and 223 | key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The 224 | input tensors are reshaped as complex numbers, and the frequency tensor 225 | is reshaped for broadcasting compatibility. The resulting tensors 226 | contain rotary embeddings and are returned as real tensors. 227 | 228 | Args: 229 | x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings. 
230 | freqs_cis (torch.Tensor): Precomputed frequency tensor for complex 231 | exponentials. 232 | 233 | Returns: 234 | Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor 235 | and key tensor with rotary embeddings. 236 | """ 237 | with torch.amp.autocast(cast_device, enabled=False): 238 | x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) 239 | freqs_cis = freqs_cis.unsqueeze(2) 240 | x_out = torch.view_as_real(x * freqs_cis).flatten(3) 241 | return x_out.type_as(x_in) 242 | 243 | # copied from huggingface modeling_llama.py 244 | def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): 245 | def _get_unpad_data(attention_mask): 246 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 247 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 248 | max_seqlen_in_batch = seqlens_in_batch.max().item() 249 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) 250 | return ( 251 | indices, 252 | cu_seqlens, 253 | max_seqlen_in_batch, 254 | ) 255 | 256 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) 257 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape 258 | 259 | key_layer = index_first_axis( 260 | key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), 261 | indices_k, 262 | ) 263 | value_layer = index_first_axis( 264 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), 265 | indices_k, 266 | ) 267 | if query_length == kv_seq_len: 268 | query_layer = index_first_axis( 269 | query_layer.reshape(batch_size * kv_seq_len, self.n_local_heads, head_dim), 270 | indices_k, 271 | ) 272 | cu_seqlens_q = cu_seqlens_k 273 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 274 | indices_q = indices_k 275 | elif query_length == 1: 276 | max_seqlen_in_batch_q = 1 277 | cu_seqlens_q = torch.arange( 278 | batch_size + 1, dtype=torch.int32, device=query_layer.device 279 | ) # There is a memcpy here, that is very bad. 280 | indices_q = cu_seqlens_q[:-1] 281 | query_layer = query_layer.squeeze(1) 282 | else: 283 | # The -q_len: slice assumes left padding. 
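            # Keep only the last `query_length` columns of the mask, then let
            # flash_attn's unpad_input pack the remaining (non-padded) query tokens
            # so flash_attn_varlen_func can run on variable-length sequences.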
284 | attention_mask = attention_mask[:, -query_length:] 285 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) 286 | 287 | return ( 288 | query_layer, 289 | key_layer, 290 | value_layer, 291 | indices_q, 292 | (cu_seqlens_q, cu_seqlens_k), 293 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 294 | ) 295 | 296 | def forward( 297 | self, 298 | x: torch.Tensor, 299 | x_mask: torch.Tensor, 300 | freqs_cis: torch.Tensor, 301 | y: torch.Tensor, 302 | y_mask: torch.Tensor, 303 | region_mask: Optional[torch.Tensor] = None, 304 | ) -> torch.Tensor: 305 | """ 306 | 307 | Args: 308 | x: 309 | x_mask: 310 | freqs_cis: 311 | y: 312 | y_mask: 313 | 314 | Returns: 315 | 316 | """ 317 | bsz, seqlen, _ = x.shape 318 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 319 | dtype = xq.dtype 320 | 321 | xq = self.q_norm(xq) 322 | xk = self.k_norm(xk) 323 | 324 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) 325 | xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) 326 | xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) 327 | 328 | xq = Attention.apply_rotary_emb(xq, freqs_cis=freqs_cis) 329 | xk = Attention.apply_rotary_emb(xk, freqs_cis=freqs_cis) 330 | 331 | xq, xk = xq.to(dtype), xk.to(dtype) 332 | 333 | if self.proportional_attn: 334 | softmax_scale = math.sqrt(math.log(seqlen, self.base_seqlen) / self.head_dim) 335 | else: 336 | softmax_scale = math.sqrt(1 / self.head_dim) 337 | 338 | if dtype in [torch.float16, torch.bfloat16] and FLASH_ATTN_AVAILABLE: 339 | # begin var_len flash attn 340 | ( 341 | query_states, 342 | key_states, 343 | value_states, 344 | indices_q, 345 | cu_seq_lens, 346 | max_seq_lens, 347 | ) = self._upad_input(xq, xk, xv, x_mask, seqlen) 348 | 349 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 350 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 351 | 352 | attn_output_unpad = flash_attn_varlen_func( 353 | query_states, 354 | key_states, 355 | value_states, 356 | cu_seqlens_q=cu_seqlens_q, 357 | cu_seqlens_k=cu_seqlens_k, 358 | max_seqlen_q=max_seqlen_in_batch_q, 359 | max_seqlen_k=max_seqlen_in_batch_k, 360 | dropout_p=0.0, 361 | causal=False, 362 | softmax_scale=softmax_scale, 363 | ) 364 | output = pad_input(attn_output_unpad, indices_q, bsz, seqlen) 365 | # end var_len_flash_attn 366 | 367 | else: 368 | n_rep = self.n_local_heads // self.n_local_kv_heads 369 | if n_rep >= 1: 370 | xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) 371 | xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) 372 | output = ( 373 | F.scaled_dot_product_attention( 374 | xq.permute(0, 2, 1, 3), 375 | xk.permute(0, 2, 1, 3), 376 | xv.permute(0, 2, 1, 3), 377 | attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1), 378 | scale=softmax_scale, 379 | ) 380 | .permute(0, 2, 1, 3) 381 | .to(dtype) 382 | ) 383 | 384 | if hasattr(self, "wk_y"): 385 | if x.shape[0] < 3: 386 | num_y = y.shape[0] 387 | xq = torch.cat([xq[0].unsqueeze(0).repeat(num_y - 1, 1, 1, 1), xq[-1].unsqueeze(0)], dim=0) 388 | yk = self.ky_norm(self.wk_y(y)).view(num_y, -1, self.n_local_kv_heads, self.head_dim) 389 | yv = self.wv_y(y).view(num_y, -1, self.n_local_kv_heads, self.head_dim) 390 | y_mask_in = y_mask.view(num_y, 1, 1, -1).repeat(1, self.n_local_heads, seqlen, 1) 391 | if region_mask is not None: 392 | region_mask_in = region_mask.view(num_y, 1, seqlen, 1).repeat( 393 | 1, self.n_local_heads, 1, y_mask_in.shape[-1] 394 | ) 395 | y_mask_in = y_mask_in & region_mask_in 396 | 
n_rep = self.n_local_heads // self.n_local_kv_heads 397 | if n_rep >= 1: 398 | yk = yk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) 399 | yv = yv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) 400 | output_y = F.scaled_dot_product_attention( 401 | xq.permute(0, 2, 1, 3), 402 | yk.permute(0, 2, 1, 3), 403 | yv.permute(0, 2, 1, 3), 404 | y_mask_in, 405 | ).permute(0, 2, 1, 3) 406 | output_y = torch.nan_to_num(output_y) 407 | output_y = output_y * self.gate.tanh().view(1, 1, -1, 1) 408 | output_y_cond = torch.sum(output_y[:-1], dim=0, keepdim=True) 409 | output_y_uncond = torch.sum(output_y[-1:], dim=0, keepdim=True) 410 | output_y = torch.cat([output_y_cond, output_y_uncond], dim=0) 411 | output = output + output_y 412 | else: 413 | # Original behavior 414 | yk = self.ky_norm(self.wk_y(y)).view(bsz, -1, self.n_local_kv_heads, self.head_dim) 415 | yv = self.wv_y(y).view(bsz, -1, self.n_local_kv_heads, self.head_dim) 416 | n_rep = self.n_local_heads // self.n_local_kv_heads 417 | if n_rep >= 1: 418 | yk = yk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) 419 | yv = yv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) 420 | output_y = F.scaled_dot_product_attention( 421 | xq.permute(0, 2, 1, 3), 422 | yk.permute(0, 2, 1, 3), 423 | yv.permute(0, 2, 1, 3), 424 | y_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seqlen, -1), 425 | ).permute(0, 2, 1, 3) 426 | output_y = output_y * self.gate.tanh().view(1, 1, -1, 1) 427 | output = output + output_y 428 | 429 | output = output.flatten(-2) 430 | 431 | return self.wo(output) 432 | 433 | 434 | class FeedForward(nn.Module): 435 | def __init__( 436 | self, 437 | dim: int, 438 | hidden_dim: int, 439 | multiple_of: int, 440 | ffn_dim_multiplier: Optional[float], 441 | ): 442 | """ 443 | Initialize the FeedForward module. 444 | 445 | Args: 446 | dim (int): Input dimension. 447 | hidden_dim (int): Hidden dimension of the feedforward layer. 448 | multiple_of (int): Value to ensure hidden dimension is a multiple 449 | of this value. 450 | ffn_dim_multiplier (float, optional): Custom multiplier for hidden 451 | dimension. Defaults to None. 452 | 453 | Attributes: 454 | w1 (ops.Linear): Linear transformation for the first 455 | layer. 456 | w2 (ops.Linear): Linear transformation for the second layer. 457 | w3 (ops.Linear): Linear transformation for the third 458 | layer. 
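        The forward pass computes the SwiGLU-style gating
        w2(silu(w1(x)) * w3(x)).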
459 | 460 | """ 461 | super().__init__() 462 | hidden_dim = int(2 * hidden_dim / 3) 463 | # custom dim factor multiplier 464 | if ffn_dim_multiplier is not None: 465 | hidden_dim = int(ffn_dim_multiplier * hidden_dim) 466 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 467 | 468 | self.w1 = ops.Linear( 469 | dim, 470 | hidden_dim, 471 | bias=False, 472 | ) 473 | nn.init.xavier_uniform_(self.w1.weight) 474 | self.w2 = ops.Linear( 475 | hidden_dim, 476 | dim, 477 | bias=False, 478 | ) 479 | nn.init.xavier_uniform_(self.w2.weight) 480 | self.w3 = ops.Linear( 481 | dim, 482 | hidden_dim, 483 | bias=False, 484 | ) 485 | nn.init.xavier_uniform_(self.w3.weight) 486 | 487 | # @torch.compile 488 | def _forward_silu_gating(self, x1, x3): 489 | return F.silu(x1) * x3 490 | 491 | def forward(self, x): 492 | return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) 493 | 494 | 495 | class TransformerBlock(nn.Module): 496 | def __init__( 497 | self, 498 | layer_id: int, 499 | dim: int, 500 | n_heads: int, 501 | n_kv_heads: int, 502 | multiple_of: int, 503 | ffn_dim_multiplier: float, 504 | norm_eps: float, 505 | qk_norm: bool, 506 | y_dim: int, 507 | ) -> None: 508 | """ 509 | Initialize a TransformerBlock. 510 | 511 | Args: 512 | layer_id (int): Identifier for the layer. 513 | dim (int): Embedding dimension of the input features. 514 | n_heads (int): Number of attention heads. 515 | n_kv_heads (Optional[int]): Number of attention heads in key and 516 | value features (if using GQA), or set to None for the same as 517 | query. 518 | multiple_of (int): 519 | ffn_dim_multiplier (float): 520 | norm_eps (float): 521 | 522 | Attributes: 523 | n_heads (int): Number of attention heads. 524 | dim (int): Dimension size of the model. 525 | head_dim (int): Dimension size of each attention head. 526 | attention (Attention): Attention module. 527 | feed_forward (FeedForward): FeedForward module. 528 | layer_id (int): Identifier for the layer. 529 | attention_norm (RMSNorm): Layer normalization for attention output. 530 | ffn_norm (RMSNorm): Layer normalization for feedforward output. 531 | 532 | """ 533 | super().__init__() 534 | self.dim = dim 535 | self.head_dim = dim // n_heads 536 | self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, y_dim) 537 | self.feed_forward = FeedForward( 538 | dim=dim, 539 | hidden_dim=4 * dim, 540 | multiple_of=multiple_of, 541 | ffn_dim_multiplier=ffn_dim_multiplier, 542 | ) 543 | self.layer_id = layer_id 544 | self.attention_norm1 = RMSNorm(dim, eps=norm_eps) 545 | self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) 546 | 547 | self.attention_norm2 = RMSNorm(dim, eps=norm_eps) 548 | self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) 549 | 550 | self.adaLN_modulation = nn.Sequential( 551 | nn.SiLU(), 552 | ops.Linear( 553 | min(dim, 1024), 554 | 4 * dim, 555 | bias=True, 556 | ), 557 | ) 558 | nn.init.zeros_(self.adaLN_modulation[1].weight) 559 | nn.init.zeros_(self.adaLN_modulation[1].bias) 560 | 561 | self.attention_y_norm = RMSNorm(y_dim, eps=norm_eps) 562 | 563 | def forward( 564 | self, 565 | x: torch.Tensor, 566 | x_mask: torch.Tensor, 567 | freqs_cis: torch.Tensor, 568 | y: torch.Tensor, 569 | y_mask: torch.Tensor, 570 | adaln_input: Optional[torch.Tensor] = None, 571 | region_mask: Optional[torch.Tensor] = None, 572 | ): 573 | """ 574 | Perform a forward pass through the TransformerBlock. 575 | 576 | Args: 577 | x (torch.Tensor): Input tensor. 578 | freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. 
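            x_mask (torch.Tensor): Token mask for x.
            y (torch.Tensor): Caption (text encoder) features.
            y_mask (torch.Tensor): Token mask for y.
            adaln_input (Optional[torch.Tensor]): Conditioning vector driving the
                adaLN modulation; the unmodulated pre-norm path is used when None.
            region_mask (Optional[torch.Tensor]): Optional mask restricting which
                image tokens attend to each caption (area prompting).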
579 | 580 | Returns: 581 | torch.Tensor: Output tensor after applying attention and 582 | feedforward layers. 583 | 584 | """ 585 | if adaln_input is not None: 586 | scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) 587 | 588 | x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( 589 | self.attention( 590 | modulate(self.attention_norm1(x), scale_msa), 591 | x_mask, 592 | freqs_cis, 593 | self.attention_y_norm(y), 594 | y_mask, 595 | region_mask, 596 | ) 597 | ) 598 | x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( 599 | self.feed_forward( 600 | modulate(self.ffn_norm1(x), scale_mlp), 601 | ) 602 | ) 603 | 604 | else: 605 | x = x + self.attention_norm2( 606 | self.attention( 607 | self.attention_norm1(x), 608 | x_mask, 609 | freqs_cis, 610 | self.attention_y_norm(y), 611 | y_mask, 612 | region_mask, 613 | ) 614 | ) 615 | x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) 616 | 617 | return x 618 | 619 | 620 | class FinalLayer(nn.Module): 621 | """ 622 | The final layer of NextDiT. 623 | """ 624 | 625 | def __init__(self, hidden_size, patch_size, out_channels): 626 | super().__init__() 627 | self.norm_final = ops.LayerNorm( 628 | hidden_size, 629 | elementwise_affine=False, 630 | eps=1e-6, 631 | ) 632 | self.linear = ops.Linear( 633 | hidden_size, 634 | patch_size * patch_size * out_channels, 635 | bias=True, 636 | ) 637 | nn.init.zeros_(self.linear.weight) 638 | nn.init.zeros_(self.linear.bias) 639 | 640 | self.adaLN_modulation = nn.Sequential( 641 | nn.SiLU(), 642 | ops.Linear( 643 | min(hidden_size, 1024), 644 | hidden_size, 645 | bias=True, 646 | ), 647 | ) 648 | nn.init.zeros_(self.adaLN_modulation[1].weight) 649 | nn.init.zeros_(self.adaLN_modulation[1].bias) 650 | 651 | def forward(self, x, c): 652 | scale = self.adaLN_modulation(c) 653 | # shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 654 | x = modulate(self.norm_final(x), scale) 655 | x = self.linear(x) 656 | return x 657 | 658 | 659 | class NextDiT(nn.Module): 660 | """ 661 | Diffusion model with a Transformer backbone. 
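    Latent images are split into patch_size x patch_size patches, linearly
    embedded, processed by TransformerBlocks with 2D rotary position embeddings
    and adaLN conditioning on the timestep and caption embeddings, and projected
    back to latent patches by the final layer.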
662 | """ 663 | 664 | def __init__( 665 | self, 666 | patch_size: int = 2, 667 | in_channels: int = 4, 668 | dim: int = 4096, 669 | n_layers: int = 32, 670 | n_heads: int = 32, 671 | n_kv_heads: Optional[int] = None, 672 | multiple_of: int = 256, 673 | ffn_dim_multiplier: Optional[float] = None, 674 | norm_eps: float = 1e-5, 675 | learn_sigma: bool = True, 676 | qk_norm: bool = False, 677 | cap_feat_dim: int = 5120, 678 | scale_factor: float = 1.0, 679 | ) -> None: 680 | super().__init__() 681 | self.learn_sigma = learn_sigma 682 | self.in_channels = in_channels 683 | self.out_channels = in_channels * 2 if learn_sigma else in_channels 684 | self.patch_size = patch_size 685 | 686 | self.x_embedder = ops.Linear( 687 | in_features=patch_size * patch_size * in_channels, 688 | out_features=dim, 689 | bias=True, 690 | ) 691 | nn.init.xavier_uniform_(self.x_embedder.weight) 692 | nn.init.constant_(self.x_embedder.bias, 0.0) 693 | 694 | self.t_embedder = TimestepEmbedder(min(dim, 1024)) 695 | self.cap_embedder = nn.Sequential( 696 | ops.LayerNorm(cap_feat_dim), 697 | ops.Linear( 698 | cap_feat_dim, 699 | min(dim, 1024), 700 | bias=True, 701 | ), 702 | ) 703 | nn.init.zeros_(self.cap_embedder[1].weight) 704 | nn.init.zeros_(self.cap_embedder[1].bias) 705 | 706 | self.layers = nn.ModuleList( 707 | [ 708 | TransformerBlock( 709 | layer_id, 710 | dim, 711 | n_heads, 712 | n_kv_heads, 713 | multiple_of, 714 | ffn_dim_multiplier, 715 | norm_eps, 716 | qk_norm, 717 | cap_feat_dim, 718 | ) 719 | for layer_id in range(n_layers) 720 | ] 721 | ) 722 | self.final_layer = FinalLayer(dim, patch_size, self.out_channels) 723 | 724 | assert (dim // n_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4" 725 | self.freqs_cis = NextDiT.precompute_freqs_cis( 726 | dim // n_heads, 727 | 384, 728 | scale_factor=scale_factor, 729 | ) 730 | self.dim = dim 731 | self.n_heads = n_heads 732 | self.scale_factor = scale_factor 733 | self.pad_token = nn.Parameter(torch.empty(dim)) 734 | nn.init.normal_(self.pad_token, std=0.02) 735 | 736 | def unpatchify(self, x: torch.Tensor, img_size: List[Tuple[int, int]], return_tensor=False) -> List[torch.Tensor]: 737 | """ 738 | x: (N, T, patch_size**2 * C) 739 | imgs: (N, H, W, C) 740 | """ 741 | pH = pW = self.patch_size 742 | if return_tensor: 743 | H, W = img_size[0] 744 | B = x.size(0) 745 | L = (H // pH) * (W // pW) 746 | x = x[:, :L].view(B, H // pH, W // pW, pH, pW, self.out_channels) 747 | x = x.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3) 748 | return x 749 | else: 750 | imgs = [] 751 | for i in range(x.size(0)): 752 | H, W = img_size[i] 753 | L = (H // pH) * (W // pW) 754 | imgs.append( 755 | x[i][:L] 756 | .view(H // pH, W // pW, pH, pW, self.out_channels) 757 | .permute(4, 0, 2, 1, 3) 758 | .flatten(3, 4) 759 | .flatten(1, 2) 760 | ) 761 | return imgs 762 | 763 | def patchify_and_embed( 764 | self, x: List[torch.Tensor] | torch.Tensor 765 | ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: 766 | self.freqs_cis = self.freqs_cis.to(x[0].device) 767 | if isinstance(x, torch.Tensor): 768 | pH = pW = self.patch_size 769 | B, C, H, W = x.size() 770 | x = x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 1, 3, 5).flatten(3) 771 | x = self.x_embedder(x) 772 | x = x.flatten(1, 2) 773 | 774 | mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device) 775 | 776 | return ( 777 | x, 778 | mask, 779 | [(H, W)] * B, 780 | self.freqs_cis[: H // pH, : W // pW].flatten(0, 1).unsqueeze(0), 781 | ) 782 | else: 783 | pH = pW = 
self.patch_size 784 | x_embed = [] 785 | freqs_cis = [] 786 | img_size = [] 787 | l_effective_seq_len = [] 788 | 789 | for img in x: 790 | C, H, W = img.size() 791 | item_freqs_cis = self.freqs_cis[: H // pH, : W // pW] 792 | freqs_cis.append(item_freqs_cis.flatten(0, 1)) 793 | img_size.append((H, W)) 794 | img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 0, 2, 4).flatten(2) 795 | img = self.x_embedder(img) 796 | img = img.flatten(0, 1) 797 | l_effective_seq_len.append(len(img)) 798 | x_embed.append(img) 799 | 800 | max_seq_len = max(l_effective_seq_len) 801 | mask = torch.zeros(len(x), max_seq_len, dtype=torch.int32, device=x[0].device) 802 | padded_x_embed = [] 803 | padded_freqs_cis = [] 804 | for i, (item_embed, item_freqs_cis, item_seq_len) in enumerate( 805 | zip(x_embed, freqs_cis, l_effective_seq_len) 806 | ): 807 | item_embed = torch.cat( 808 | [ 809 | item_embed, 810 | self.pad_token.view(1, -1).expand(max_seq_len - item_seq_len, -1), 811 | ], 812 | dim=0, 813 | ) 814 | item_freqs_cis = torch.cat( 815 | [ 816 | item_freqs_cis, 817 | item_freqs_cis[-1:].expand(max_seq_len - item_seq_len, -1), 818 | ], 819 | dim=0, 820 | ) 821 | padded_x_embed.append(item_embed) 822 | padded_freqs_cis.append(item_freqs_cis) 823 | mask[i][:item_seq_len] = 1 824 | 825 | x_embed = torch.stack(padded_x_embed, dim=0) 826 | freqs_cis = torch.stack(padded_freqs_cis, dim=0) 827 | return x_embed, mask, img_size, freqs_cis 828 | 829 | def forward( 830 | self, x, t, cap_feats, cap_mask, global_cap_feats=None, global_cap_mask=None, h_split_num=1, w_split_num=1 831 | ): 832 | """ 833 | Forward pass of NextDiT. 834 | t: (N,) tensor of diffusion timesteps 835 | y: (N,) tensor of class labels 836 | """ 837 | B, C, H, W = x.size() 838 | x_is_tensor = isinstance(x, torch.Tensor) 839 | x, mask, img_size, freqs_cis = self.patchify_and_embed(x) 840 | freqs_cis = freqs_cis.to(x.device) 841 | 842 | t = self.t_embedder(t) # (N, D) 843 | cap_mask_float = global_cap_mask.float().unsqueeze(-1) 844 | cap_feats_pool = (global_cap_feats * cap_mask_float).sum(dim=1) / cap_mask_float.sum(dim=1) 845 | cap_feats_pool = cap_feats_pool.to(cap_feats) 846 | cap_emb = self.cap_embedder(cap_feats_pool) 847 | adaln_input = t + cap_emb 848 | 849 | region_mask = torch.zeros( 850 | cap_feats.shape[0], H // self.patch_size, W // self.patch_size, dtype=torch.float, device=x.device 851 | ) 852 | h_patch_size, w_patch_size = H // h_split_num // self.patch_size, W // w_split_num // self.patch_size 853 | for h_split in range(h_split_num): 854 | for w_split in range(w_split_num): 855 | region_id = (h_split + 1) * (w_split + 1) - 1 856 | region_mask[ 857 | region_id, 858 | h_patch_size * h_split : h_patch_size * (h_split + 1), 859 | w_patch_size * w_split : w_patch_size * (w_split + 1), 860 | ] = 1 861 | region_mask[-1, :, :] = 1 862 | 863 | region_mask = region_mask.flatten(1, 2) 864 | region_mask = region_mask > 0.5 865 | 866 | cap_mask = cap_mask.bool() 867 | for layer in self.layers: 868 | x = layer(x, mask, freqs_cis, cap_feats, cap_mask, adaln_input=adaln_input, region_mask=region_mask) 869 | 870 | x = self.final_layer(x, adaln_input) 871 | x = self.unpatchify(x, img_size, return_tensor=x_is_tensor) 872 | if self.learn_sigma: 873 | if x_is_tensor: 874 | x, _ = x.chunk(2, dim=1) 875 | else: 876 | x = [_.chunk(2, dim=0)[0] for _ in x] 877 | return x 878 | 879 | def forward_with_cfg( 880 | self, 881 | x, 882 | t, 883 | cap_feats, 884 | cap_mask, 885 | cfg_scale, 886 | scale_factor=1.0, 887 | scale_watershed=1.0, 888 | base_seqlen: 
Optional[int] = None, 889 | proportional_attn: bool = False, 890 | global_cap_feats=None, 891 | global_cap_mask=None, 892 | h_split_num=1, 893 | w_split_num=1 894 | ): 895 | """ 896 | Forward pass of NextDiT, but also batches the unconditional forward pass 897 | for classifier-free guidance. 898 | """ 899 | # # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 900 | self.freqs_cis = NextDiT.precompute_freqs_cis( 901 | self.dim // self.n_heads, 902 | 384, 903 | scale_factor=scale_factor, 904 | scale_watershed=scale_watershed, 905 | timestep=t[0].item(), 906 | ) 907 | 908 | if proportional_attn: 909 | assert base_seqlen is not None 910 | for layer in self.layers: 911 | layer.attention.base_seqlen = base_seqlen 912 | layer.attention.proportional_attn = proportional_attn 913 | else: 914 | for layer in self.layers: 915 | layer.attention.base_seqlen = None 916 | layer.attention.proportional_attn = proportional_attn 917 | 918 | half = x[: len(x) // 2] 919 | combined = torch.cat([half, half], dim=0) 920 | model_out = self(combined, t, cap_feats, cap_mask, global_cap_feats, global_cap_mask, h_split_num, w_split_num) 921 | # For exact reproducibility reasons, we apply classifier-free guidance on only 922 | # three channels by default. The standard approach to cfg applies it to all channels. 923 | # This can be done by uncommenting the following line and commenting-out the line following that. 924 | # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 925 | eps, rest = model_out[:, :3], model_out[:, 3:] 926 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 927 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 928 | eps = torch.cat([half_eps, half_eps], dim=0) 929 | 930 | return torch.cat([eps, rest], dim=1) 931 | 932 | @staticmethod 933 | def precompute_freqs_cis( 934 | dim: int, 935 | end: int, 936 | theta: float = 10000.0, 937 | scale_factor: float = 1.0, 938 | scale_watershed: float = 1.0, 939 | timestep: float = 1.0, 940 | ): 941 | """ 942 | Precompute the frequency tensor for complex exponentials (cis) with 943 | given dimensions. 944 | 945 | This function calculates a frequency tensor with complex exponentials 946 | using the given dimension 'dim' and the end index 'end'. The 'theta' 947 | parameter scales the frequencies. The returned tensor contains complex 948 | values in complex64 data type. 949 | 950 | Args: 951 | dim (int): Dimension of the frequency tensor. 952 | end (int): End index for precomputing frequencies. 953 | theta (float, optional): Scaling factor for frequency computation. 954 | Defaults to 10000.0. 955 | 956 | Returns: 957 | torch.Tensor: Precomputed frequency tensor with complex 958 | exponentials. 
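        Note:
            scale_factor, scale_watershed and timestep implement the
            resolution-extrapolation scheme used at sampling time: when
            timestep < scale_watershed the computed frequencies are divided
            by scale_factor (linear/positional scaling); otherwise theta is
            multiplied by scale_factor instead (NTK-style scaling).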
959 | """ 960 | 961 | if timestep < scale_watershed: 962 | linear_factor = scale_factor 963 | ntk_factor = 1.0 964 | else: 965 | linear_factor = 1.0 966 | ntk_factor = scale_factor 967 | 968 | theta = theta * ntk_factor 969 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float().cuda() / dim)) / linear_factor 970 | 971 | timestep = torch.arange(end, device=freqs.device, dtype=torch.float) # type: ignore 972 | 973 | freqs = torch.outer(timestep, freqs).float() # type: ignore 974 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 975 | 976 | freqs_cis_h = freqs_cis.view(end, 1, dim // 4, 1).repeat(1, end, 1, 1) 977 | freqs_cis_w = freqs_cis.view(1, end, dim // 4, 1).repeat(end, 1, 1, 1) 978 | freqs_cis = torch.cat([freqs_cis_h, freqs_cis_w], dim=-1).flatten(2) 979 | 980 | return freqs_cis 981 | 982 | def parameter_count(self) -> int: 983 | total_params = 0 984 | 985 | def _recursive_count_params(module): 986 | nonlocal total_params 987 | for param in module.parameters(recurse=False): 988 | total_params += param.numel() 989 | for submodule in module.children(): 990 | _recursive_count_params(submodule) 991 | 992 | _recursive_count_params(self) 993 | return total_params 994 | 995 | def get_fsdp_wrap_module_list(self) -> List[nn.Module]: 996 | return list(self.layers) 997 | 998 | 999 | ############################################################################# 1000 | # NextDiT Configs # 1001 | ############################################################################# 1002 | def NextDiT_2B_patch2(**kwargs): 1003 | return NextDiT(patch_size=2, dim=2304, n_layers=24, n_heads=32, **kwargs) 1004 | 1005 | 1006 | def NextDiT_2B_GQA_patch2(**kwargs): 1007 | return NextDiT(patch_size=2, dim=2304, n_layers=24, n_heads=32, n_kv_heads=8, **kwargs) 1008 | -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import sys 4 | import math 5 | import gc 6 | 7 | import comfy.model_management as mm 8 | from comfy.utils import ProgressBar, load_torch_file 9 | 10 | import folder_paths 11 | 12 | script_directory = os.path.dirname(os.path.abspath(__file__)) 13 | sys.path.append(script_directory) 14 | 15 | import lumina_models 16 | from transport import ODE 17 | from transformers import AutoModel, AutoTokenizer, GemmaForCausalLM 18 | from argparse import Namespace 19 | 20 | from contextlib import nullcontext 21 | try: 22 | from accelerate import init_empty_weights 23 | from accelerate.utils import set_module_tensor_to_device 24 | is_accelerate_available = True 25 | except: 26 | pass 27 | 28 | try: 29 | from flash_attn import flash_attn_varlen_func 30 | FLASH_ATTN_AVAILABLE = True 31 | print("Flash Attention is available") 32 | except: 33 | FLASH_ATTN_AVAILABLE = False 34 | print("LuminaWrapper: WARNING! 
Flash Attention is not available, using much slower torch SDP attention") 35 | 36 | class DownloadAndLoadLuminaModel: 37 | @classmethod 38 | def INPUT_TYPES(s): 39 | return {"required": { 40 | "model": ( 41 | [ 42 | 'Alpha-VLLM/Lumina-Next-SFT', 43 | 'Alpha-VLLM/Lumina-Next-T2I' 44 | ], 45 | { 46 | "default": 'Alpha-VLLM/Lumina-Next-SFT' 47 | }), 48 | "precision": ([ 'bf16','fp32'], 49 | { 50 | "default": 'bf16' 51 | }), 52 | }, 53 | } 54 | 55 | RETURN_TYPES = ("LUMINAMODEL",) 56 | RETURN_NAMES = ("lumina_model",) 57 | FUNCTION = "loadmodel" 58 | CATEGORY = "LuminaWrapper" 59 | 60 | def loadmodel(self, model, precision): 61 | device = mm.get_torch_device() 62 | offload_device = mm.unet_offload_device() 63 | dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] 64 | 65 | model_name = model.rsplit('/', 1)[-1] 66 | model_path = os.path.join(folder_paths.models_dir, "lumina", model_name) 67 | safetensors_path = os.path.join(model_path, "consolidated.00-of-01.safetensors") 68 | 69 | if not os.path.exists(safetensors_path): 70 | print(f"Downloading Lumina model to: {model_path}") 71 | from huggingface_hub import snapshot_download 72 | snapshot_download(repo_id=model, 73 | ignore_patterns=['*ema*', '*.pth'], 74 | local_dir=model_path, 75 | local_dir_use_symlinks=False) 76 | 77 | #train_args = torch.load(os.path.join(model_path, "model_args.pth")) 78 | 79 | train_args = Namespace( 80 | model='NextDiT_2B_GQA_patch2', 81 | image_size=1024, 82 | vae='sdxl', 83 | precision='bf16', 84 | grad_precision='fp32', 85 | grad_clip=2.0, 86 | wd=0.0, 87 | qk_norm=True, 88 | model_parallel_size=1 89 | ) 90 | 91 | with (init_empty_weights() if is_accelerate_available else nullcontext()): 92 | model = lumina_models.__dict__[train_args.model](qk_norm=train_args.qk_norm, cap_feat_dim=2048) 93 | model.eval().to(dtype) 94 | 95 | sd = load_torch_file(safetensors_path) 96 | if is_accelerate_available: 97 | for key in sd: 98 | set_module_tensor_to_device(model, key, dtype=dtype, device=offload_device, value=sd[key]) 99 | else: 100 | model.load_state_dict(sd, strict=True) 101 | del sd 102 | mm.soft_empty_cache() 103 | 104 | lumina_model = { 105 | 'model': model, 106 | 'train_args': train_args, 107 | 'dtype': dtype 108 | } 109 | 110 | return (lumina_model,) 111 | 112 | class DownloadAndLoadGemmaModel: 113 | @classmethod 114 | def INPUT_TYPES(s): 115 | return {"required": { 116 | "precision": ([ 'bf16','fp32'], 117 | { 118 | "default": 'bf16' 119 | }), 120 | }, 121 | } 122 | 123 | RETURN_TYPES = ("GEMMAODEL",) 124 | RETURN_NAMES = ("gemma_model",) 125 | FUNCTION = "loadmodel" 126 | CATEGORY = "LuminaWrapper" 127 | 128 | def loadmodel(self, precision): 129 | device = mm.get_torch_device() 130 | dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] 131 | 132 | gemma_path = os.path.join(folder_paths.models_dir, "LLM", "gemma-2b") 133 | 134 | if not os.path.exists(gemma_path): 135 | print(f"Downloading Gemma model to: {gemma_path}") 136 | from huggingface_hub import snapshot_download 137 | snapshot_download(repo_id="alpindale/gemma-2b", 138 | local_dir=gemma_path, 139 | ignore_patterns=['*gguf*'], 140 | local_dir_use_symlinks=False) 141 | 142 | tokenizer = AutoTokenizer.from_pretrained(gemma_path) 143 | tokenizer.padding_side = "right" 144 | 145 | attn_implementation = "flash_attention_2" if FLASH_ATTN_AVAILABLE and precision != "fp32" else "sdpa" 146 | print(f"Gemma attention mode: {attn_implementation}") 147 | 148 | #model_class = AutoModel if mode == 
'text_encode' else GemmaForCausalLM 149 | model_class = GemmaForCausalLM 150 | text_encoder = model_class.from_pretrained( 151 | gemma_path, 152 | torch_dtype=dtype, 153 | device_map=device, 154 | attn_implementation=attn_implementation, 155 | ).eval() 156 | 157 | gemma_model = { 158 | 'tokenizer': tokenizer, 159 | 'text_encoder': text_encoder, 160 | } 161 | 162 | return (gemma_model,) 163 | 164 | class LuminaGemmaTextEncode: 165 | @classmethod 166 | def INPUT_TYPES(s): 167 | return { 168 | "required": { 169 | "gemma_model": ("GEMMAODEL", ), 170 | "latent": ("LATENT", ), 171 | "prompt": ("STRING", {"multiline": True, "default": "",}), 172 | "n_prompt": ("STRING", {"multiline": True, "default": "",}), 173 | }, 174 | "optional": { 175 | "keep_model_loaded": ("BOOLEAN", {"default": False}), 176 | } 177 | } 178 | 179 | RETURN_TYPES = ("LUMINATEMBED",) 180 | RETURN_NAMES =("lumina_embeds",) 181 | FUNCTION = "encode" 182 | CATEGORY = "LuminaWrapper" 183 | 184 | def encode(self, gemma_model, latent, prompt, n_prompt, keep_model_loaded=False): 185 | device = mm.get_torch_device() 186 | offload_device = mm.unet_offload_device() 187 | mm.unload_all_models() 188 | mm.soft_empty_cache() 189 | 190 | tokenizer = gemma_model['tokenizer'] 191 | text_encoder = gemma_model['text_encoder'] 192 | text_encoder.to(device) 193 | 194 | B = latent["samples"].shape[0] 195 | prompts = [prompt] * B + [n_prompt] * B 196 | 197 | text_inputs = tokenizer( 198 | prompts, 199 | padding=True, 200 | pad_to_multiple_of=8, 201 | max_length=256, 202 | truncation=True, 203 | return_tensors="pt", 204 | ) 205 | 206 | text_input_ids = text_inputs.input_ids 207 | prompt_masks = text_inputs.attention_mask.to(device) 208 | 209 | prompt_embeds = text_encoder( 210 | input_ids=text_input_ids.to(device), 211 | attention_mask=prompt_masks.to(device), 212 | output_hidden_states=True, 213 | ).hidden_states[-2] 214 | 215 | if not keep_model_loaded: 216 | print("Offloading text encoder...") 217 | text_encoder.to(offload_device) 218 | mm.soft_empty_cache() 219 | gc.collect() 220 | lumina_embeds = { 221 | 'prompt_embeds': prompt_embeds, 222 | 'prompt_masks': prompt_masks, 223 | } 224 | 225 | return (lumina_embeds,) 226 | 227 | class LuminaTextAreaAppend: 228 | @classmethod 229 | def INPUT_TYPES(s): 230 | return { 231 | "required": { 232 | 233 | "prompt": ("STRING", {"multiline": True, "default": "",}), 234 | "row": ("INT", {"default": 1, "min": 1, "max": 8, "step": 1}), 235 | "column": ("INT", {"default": 1, "min": 1, "max": 8, "step": 1}), 236 | }, 237 | "optional": { 238 | "prev_prompt": ("LUMINAAREAPROMPT", ), 239 | } 240 | } 241 | 242 | RETURN_TYPES = ("LUMINAAREAPROMPT",) 243 | RETURN_NAMES =("lumina_area_prompt",) 244 | FUNCTION = "process" 245 | CATEGORY = "LuminaWrapper" 246 | 247 | def process(self, prompt, row, column, prev_prompt=None): 248 | prompt_entry = { 249 | 'prompt': prompt, 250 | 'row': row, 251 | 'column': column 252 | } 253 | 254 | if prev_prompt is not None: 255 | prompt_list = prev_prompt + [prompt_entry] 256 | else: 257 | prompt_list = [prompt_entry] 258 | 259 | return (prompt_list,) 260 | 261 | class LuminaGemmaTextEncodeArea: 262 | @classmethod 263 | def INPUT_TYPES(s): 264 | return { 265 | "required": { 266 | "gemma_model": ("GEMMAODEL", ), 267 | "lumina_area_prompt": ("LUMINAAREAPROMPT",), 268 | "append_prompt": ("STRING", {"multiline": True, "default": "",}), 269 | "n_prompt": ("STRING", {"multiline": True, "default": "",}), 270 | 271 | }, 272 | "optional": { 273 | "keep_model_loaded": ("BOOLEAN", {"default": 
False}), 274 | } 275 | } 276 | 277 | RETURN_TYPES = ("LUMINATEMBED",) 278 | RETURN_NAMES =("lumina_embeds",) 279 | FUNCTION = "encode" 280 | CATEGORY = "LuminaWrapper" 281 | 282 | def encode(self, gemma_model, lumina_area_prompt, append_prompt, n_prompt, keep_model_loaded=False): 283 | device = mm.get_torch_device() 284 | offload_device = mm.unet_offload_device() 285 | 286 | tokenizer = gemma_model['tokenizer'] 287 | text_encoder = gemma_model['text_encoder'] 288 | text_encoder.to(device) 289 | 290 | prompt_list = [entry['prompt'] + "," + append_prompt for entry in lumina_area_prompt] 291 | global_prompt = " ".join(prompt_list) 292 | prompts = prompt_list + [n_prompt] + [global_prompt] 293 | print("prompts: ", prompts) 294 | 295 | text_inputs = tokenizer( 296 | prompts, 297 | padding=True, 298 | pad_to_multiple_of=8, 299 | max_length=256, 300 | truncation=True, 301 | return_tensors="pt", 302 | ) 303 | 304 | text_input_ids = text_inputs.input_ids 305 | prompt_masks = text_inputs.attention_mask.to(device) 306 | 307 | prompt_embeds = text_encoder( 308 | input_ids=text_input_ids.to(device), 309 | attention_mask=prompt_masks.to(device), 310 | output_hidden_states=True, 311 | ).hidden_states[-2] 312 | if not keep_model_loaded: 313 | print("Offloading text encoder...") 314 | text_encoder.to(offload_device) 315 | mm.soft_empty_cache() 316 | gc.collect() 317 | lumina_embeds = { 318 | 'prompt_embeds': prompt_embeds, 319 | 'prompt_masks': prompt_masks, 320 | 'lumina_area_prompt': lumina_area_prompt 321 | } 322 | 323 | return (lumina_embeds,) 324 | 325 | class GemmaSampler: 326 | @classmethod 327 | def INPUT_TYPES(s): 328 | return { 329 | "required": { 330 | "gemma_model": ("GEMMAODEL", ), 331 | "prompt": ("STRING", {"multiline": True, "default": "",}), 332 | "max_length": ("INT", {"default": 128, "min": 1, "max": 512, "step": 1}), 333 | "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}), 334 | "do_sample": ("BOOLEAN", {"default": True}), 335 | "early_stopping": ("BOOLEAN", {"default": False}), 336 | "top_k": ("INT", {"default": 50, "min": 0, "max": 100, "step": 1}), 337 | "top_p": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.01}), 338 | "repetition_penalty": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), 339 | "length_penalty": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), 340 | }, 341 | "optional": { 342 | "keep_model_loaded": ("BOOLEAN", {"default": False}), 343 | } 344 | } 345 | 346 | RETURN_TYPES = ("STRING",) 347 | RETURN_NAMES =("string",) 348 | FUNCTION = "process" 349 | CATEGORY = "LuminaWrapper" 350 | 351 | def process(self, gemma_model, prompt, max_length, temperature, do_sample, top_k, top_p, repetition_penalty, 352 | length_penalty, early_stopping, keep_model_loaded=False): 353 | device = mm.get_torch_device() 354 | offload_device = mm.unet_offload_device() 355 | 356 | mm.unload_all_models() 357 | mm.soft_empty_cache() 358 | 359 | tokenizer = gemma_model['tokenizer'] 360 | model = gemma_model['text_encoder'] 361 | model.to(device) 362 | 363 | text_inputs = tokenizer( 364 | prompt, 365 | return_tensors="pt", 366 | ) 367 | 368 | text_input_ids = text_inputs.input_ids.to(device) 369 | 370 | result = model.generate( 371 | text_input_ids, 372 | max_length=max_length, 373 | temperature=temperature, 374 | do_sample=do_sample, 375 | early_stopping=early_stopping, 376 | top_k=top_k, 377 | top_p=top_p, 378 | repetition_penalty=repetition_penalty, 379 | length_penalty=length_penalty, 380 | ) 381 | decoded = 
tokenizer.batch_decode(result, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 382 | 383 | print(decoded) 384 | 385 | if not keep_model_loaded: 386 | print("Offloading text encoder...") 387 | model.to(offload_device) 388 | mm.soft_empty_cache() 389 | gc.collect() 390 | 391 | return (decoded,) 392 | 393 | class LuminaT2ISampler: 394 | @classmethod 395 | def INPUT_TYPES(s): 396 | return { 397 | "required": { 398 | "lumina_model": ("LUMINAMODEL", ), 399 | "lumina_embeds": ("LUMINATEMBED", ), 400 | "latent": ("LATENT", ), 401 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), 402 | "steps": ("INT", {"default": 25, "min": 1, "max": 200, "step": 1}), 403 | "cfg": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 20.0, "step": 0.01}), 404 | "proportional_attn": ("BOOLEAN", {"default": False}), 405 | "do_extrapolation": ("BOOLEAN", {"default": False}), 406 | "scaling_watershed": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}), 407 | "t_shift": ("INT", {"default": 4, "min": 1, "max": 20, "step": 1}), 408 | "solver": ( 409 | [ 410 | 'euler', 411 | 'midpoint', 412 | 'rk4', 413 | ], 414 | { 415 | "default": 'midpoint' 416 | }), 417 | }, 418 | "optional": { 419 | "keep_model_loaded": ("BOOLEAN", {"default": False}), 420 | "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), 421 | } 422 | } 423 | 424 | RETURN_TYPES = ("LATENT",) 425 | RETURN_NAMES =("samples",) 426 | FUNCTION = "process" 427 | CATEGORY = "LuminaWrapper" 428 | 429 | def process(self, lumina_model, lumina_embeds, latent, seed, steps, cfg, proportional_attn, solver, t_shift, 430 | do_extrapolation, scaling_watershed, strength=1.0, keep_model_loaded=False): 431 | device = mm.get_torch_device() 432 | offload_device = mm.unet_offload_device() 433 | 434 | model = lumina_model['model'] 435 | dtype = lumina_model['dtype'] 436 | 437 | vae_scaling_factor = 0.13025 #SDXL scaling factor 438 | 439 | x1 = latent["samples"].clone() * vae_scaling_factor 440 | 441 | ode = ODE(steps, solver, t_shift, strength) 442 | 443 | B = x1.shape[0] 444 | W = x1.shape[3] * 8 445 | H = x1.shape[2] * 8 446 | 447 | z = torch.zeros_like(x1) 448 | 449 | for i in range(B): 450 | torch.manual_seed(seed + i) 451 | z[i] = torch.randn_like(x1[i]) 452 | z[i] = z[i] * (1 - ode.t[0]) + x1[i] * ode.t[0] 453 | 454 | #torch.random.manual_seed(int(seed)) 455 | #z = torch.randn([1, 4, z.shape[2], z.shape[3]], device=device) 456 | 457 | z = z.repeat(2, 1, 1, 1) 458 | z = z.to(dtype).to(device) 459 | 460 | train_args = lumina_model['train_args'] 461 | 462 | cap_feats=lumina_embeds['prompt_embeds'] 463 | cap_mask=lumina_embeds['prompt_masks'] 464 | 465 | #calculate splits from prompt dict 466 | if 'lumina_area_prompt' in lumina_embeds: 467 | unique_rows = {entry['row'] for entry in lumina_embeds['lumina_area_prompt']} 468 | unique_columns = {entry['column'] for entry in lumina_embeds['lumina_area_prompt']} 469 | 470 | horizontal_splits = len(unique_columns) 471 | vertical_splits = len(unique_rows) 472 | print(f"Horizontal splits: {horizontal_splits} Vertical splits: {vertical_splits}") 473 | is_split=True 474 | else: 475 | horizontal_splits = 1 476 | vertical_splits = 1 477 | is_split=False 478 | 479 | model_kwargs = dict( 480 | cap_feats=cap_feats[:-1] if is_split else cap_feats, 481 | cap_mask=cap_mask[:-1] if is_split else cap_mask, 482 | global_cap_feats=cap_feats[-1:] if is_split else cap_feats, 483 | global_cap_mask=cap_mask[-1:] if is_split else cap_mask, 484 | cfg_scale=cfg, 485 | 
h_split_num=int(vertical_splits), 486 | w_split_num=int(horizontal_splits), 487 | ) 488 | if proportional_attn: 489 | model_kwargs["proportional_attn"] = True 490 | model_kwargs["base_seqlen"] = (train_args.image_size // 16) ** 2 491 | else: 492 | model_kwargs["proportional_attn"] = False 493 | model_kwargs["base_seqlen"] = None 494 | 495 | if do_extrapolation: 496 | model_kwargs["scale_factor"] = math.sqrt(W * H / train_args.image_size**2) 497 | model_kwargs["scale_watershed"] = scaling_watershed 498 | else: 499 | model_kwargs["scale_factor"] = 1.0 500 | model_kwargs["scale_watershed"] = 1.0 501 | 502 | def offload_model(): 503 | print("Offloading Lumina model...") 504 | model.to(offload_device) 505 | mm.soft_empty_cache() 506 | gc.collect() 507 | 508 | #inference 509 | model.to(device) 510 | try: 511 | samples = ode.sample(z, model.forward_with_cfg, **model_kwargs)[-1] 512 | except Exception as e: 513 | if not keep_model_loaded: 514 | offload_model() 515 | print(e) 516 | raise mm.InterruptProcessingException() 517 | 518 | if not keep_model_loaded: 519 | offload_model() 520 | 521 | samples = samples[:len(samples) // 2] 522 | samples = samples / vae_scaling_factor 523 | 524 | return ({'samples': samples},) 525 | 526 | NODE_CLASS_MAPPINGS = { 527 | "LuminaT2ISampler": LuminaT2ISampler, 528 | "DownloadAndLoadLuminaModel": DownloadAndLoadLuminaModel, 529 | "DownloadAndLoadGemmaModel": DownloadAndLoadGemmaModel, 530 | "LuminaGemmaTextEncode": LuminaGemmaTextEncode, 531 | "LuminaGemmaTextEncodeArea": LuminaGemmaTextEncodeArea, 532 | "LuminaTextAreaAppend": LuminaTextAreaAppend, 533 | "GemmaSampler": GemmaSampler 534 | } 535 | NODE_DISPLAY_NAME_MAPPINGS = { 536 | "LuminaT2ISampler": "Lumina T2I Sampler", 537 | "DownloadAndLoadLuminaModel": "DownloadAndLoadLuminaModel", 538 | "DownloadAndLoadGemmaModel": "DownloadAndLoadGemmaModel", 539 | "LuminaGemmaTextEncode": "Lumina Gemma Text Encode", 540 | "LuminaGemmaTextEncodeArea": "Lumina Gemma Text Encode Area", 541 | "LuminaTextAreaAppend": "Lumina Text Area Append", 542 | "GemmaSampler": "Gemma Sampler" 543 | } -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-luminawrapper" 3 | description = "ComfyUI wrapper nodes for Lumina models" 4 | version = "1.0.1" 5 | license = "MIT" 6 | dependencies = ["torchdiffeq", "accelerate", "tqdm", "transformers>=4.38.0"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/kijai/ComfyUI-LuminaWrapper" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "kijai" 14 | DisplayName = "ComfyUI-LuminaWrapper" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchdiffeq 2 | accelerate 3 | tqdm 4 | transformers>=4.38.0 -------------------------------------------------------------------------------- /transport.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torchdiffeq import odeint 3 | from comfy.utils import ProgressBar 4 | from tqdm import tqdm 5 | 6 | def sample(x1): 7 | """Sampling x0 & t based on shape of x1 (if needed) 8 | Args: 9 | x1 - data point; [batch, *dim] 10 | """ 11 | if isinstance(x1, (list, tuple)): 12 | x0 = [th.randn_like(img_start) for img_start in x1] 13 | else: 14 | x0 = 
th.randn_like(x1) 15 | 16 | t = th.rand((len(x1),)) 17 | t = t.to(x1[0]) 18 | return t, x0, x1 19 | 20 | 21 | def training_losses(model, x1, model_kwargs=None): 22 | """Loss for training the score model 23 | Args: 24 | - model: backbone model; could be score, noise, or velocity 25 | - x1: datapoint 26 | - model_kwargs: additional arguments for the model 27 | """ 28 | if model_kwargs == None: 29 | model_kwargs = {} 30 | 31 | B = len(x1) 32 | 33 | t, x0, x1 = sample(x1) 34 | if isinstance(x1, (list, tuple)): 35 | xt = [t[i] * x1[i] + (1 - t[i]) * x0[i] for i in range(B)] 36 | ut = [x1[i] - x0[i] for i in range(B)] 37 | else: 38 | dims = [1] * (len(x1.size()) - 1) 39 | t_ = t.view(t.size(0), *dims) 40 | xt = t_ * x1 + (1 - t_) * x0 41 | ut = x1 - x0 42 | 43 | model_output = model(xt, t, **model_kwargs) 44 | 45 | terms = {} 46 | 47 | if isinstance(x1, (list, tuple)): 48 | terms["loss"] = th.stack( 49 | [((ut[i] - model_output[i]) ** 2).mean() for i in range(B)], 50 | dim=0, 51 | ) 52 | else: 53 | terms["loss"] = ((model_output - ut) ** 2).mean(dim=list(range(1, ut.ndim))) 54 | 55 | return terms 56 | 57 | 58 | class ODE: 59 | """ODE solver class""" 60 | 61 | def __init__( 62 | self, 63 | num_steps, 64 | sampler_type="euler", 65 | time_shifting_factor=None, 66 | strength=1.0, 67 | t0=0.0, 68 | t1=1.0, 69 | use_sd3=False, 70 | 71 | ): 72 | if use_sd3: 73 | self.t = th.linspace(t1, t0, num_steps) 74 | if time_shifting_factor: 75 | self.t = (time_shifting_factor * self.t) / (1 + (time_shifting_factor - 1) * self.t) 76 | else: 77 | self.t = th.linspace(t0, t1, num_steps) 78 | if time_shifting_factor: 79 | self.t = self.t / (self.t + time_shifting_factor - time_shifting_factor * self.t) 80 | 81 | if strength != 1.0: 82 | self.t = self.t[int(num_steps * (1 - strength)):] 83 | 84 | self.use_sd3 = use_sd3 85 | self.sampler_type = sampler_type 86 | if self.sampler_type == "euler": 87 | total_steps = len(self.t) 88 | else: 89 | total_steps = (len(self.t) * 2) - 2 90 | self.comfy_pbar = ProgressBar(total_steps) 91 | self.pbar = tqdm(total = total_steps, desc='ODE Sampling') 92 | 93 | def sample(self, x, model, **model_kwargs): 94 | device = x[0].device if isinstance(x, tuple) else x.device 95 | 96 | if not self.use_sd3: 97 | def _fn(t, x): 98 | t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t 99 | model_output = model(x, t, **model_kwargs) 100 | self.pbar.update(1) 101 | self.comfy_pbar.update(1) 102 | return model_output 103 | else: 104 | cfg_scale = model_kwargs["cfg_scale"] 105 | model_kwargs.pop("cfg_scale") 106 | def _fn(t, x): 107 | t = th.ones(x.size(0)).to(device) * t * 1000 108 | half_x = x[:len(x) // 2] 109 | x = th.cat([half_x, half_x], dim=0) 110 | model_output = model(hidden_states=x, timestep=t, **model_kwargs)[0] 111 | uncond, cond = model_output.chunk(2, dim=0) 112 | model_output = uncond + cfg_scale * (cond - uncond) 113 | model_output = th.cat([model_output, model_output], dim=0) 114 | return model_output 115 | 116 | t = self.t.to(device) 117 | samples = odeint(_fn, x, t, method=self.sampler_type) 118 | self.pbar.close() 119 | return samples 120 | --------------------------------------------------------------------------------
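A minimal usage sketch for the ODE sampler above (not part of the repository: `toy_velocity` is a hypothetical stand-in for `model.forward_with_cfg`, which is what the LuminaT2ISampler node actually passes in, and running it requires ComfyUI's `comfy` package on the path because transport.py imports its ProgressBar):

import torch
from transport import ODE

def toy_velocity(x, t, **kwargs):
    # stand-in for NextDiT.forward_with_cfg: any callable returning a velocity field works
    return -x

z = torch.randn(1, 4, 64, 64)                               # latent-shaped starting noise
ode = ODE(num_steps=25, sampler_type="midpoint", time_shifting_factor=4)
trajectory = ode.sample(z, toy_velocity)                    # odeint over t in [0, 1]
final_latent = trajectory[-1]                               # last element is the sampled latent
print(final_latent.shape)                                   # torch.Size([1, 4, 64, 64])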