├── .github └── workflows │ └── publish.yml ├── LICENSE ├── README.md ├── __init__.py ├── images ├── MBW_Layers_showcase.webp ├── Quant_Nodes_showcase.webp ├── SDNext_Merge_showcase.webp ├── VAE_Merge_showcase.webp └── VAE_Repeat_showcase.webp ├── merge.py ├── merge_PermSpec.py ├── merge_PermSpec_SDXL.py ├── merge_methods.py ├── merge_presets.py ├── merge_rebasin.py ├── merge_utils.py ├── pyproject.toml ├── quant_nodes.py ├── sdnextmerge_nodes.py └── vae_merge.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | jobs: 12 | publish-node: 13 | name: Publish Custom Node to registry 14 | runs-on: ubuntu-latest 15 | # if this is a forked repository. Skipping the workflow. 16 | if: github.event.repository.fork == false 17 | steps: 18 | - name: Check out code 19 | uses: actions/checkout@v4 20 | - name: Publish Custom Node 21 | uses: Comfy-Org/publish-node-action@main 22 | with: 23 | ## Add your own personal access token to your Github Repository secrets and reference it here. 24 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. 
By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. 
For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 
83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 
117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 
146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 
180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 
216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 
246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. 
If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 
309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 
336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 
360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. 
If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 
428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 
486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. 
If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 
613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | Copyright (C) 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TechNodes 2 | ComfyUI nodes for merging, testing and more. 3 | 4 | 5 | ## Installation 6 | Inside the `ComfyUI/custom_nodes` directory, run: 7 | 8 | ``` 9 | git clone https://github.com/TechnoByteJS/ComfyUI-TechNodes --depth 1 10 | ``` 11 | 12 | ## SDNext Merge 13 | The merger from [SD.Next](https://github.com/vladmandic/automatic) (based on [meh](https://github.com/s1dlx/meh)) ported to ComfyUI, with [Re-Basin](https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion) built-in. 
14 | 15 | ![SDNext Merge Showcase](images/SDNext_Merge_showcase.webp) 16 | 17 | ## VAE Merge 18 | A node that lets you merge VAEs using multiple methods (with support for individual blocks), and adjust the brightness or contrast. 19 | 20 | ![VAE Merge Showcase](images/VAE_Merge_showcase.webp) 21 | 22 | ## MBW Layers 23 | Allows for advanced merging by adjusting the alpha of each U-Net block individually, with binary versions that make it easy to extract specific layers. 24 | 25 | ![MBW Layers Showcase](images/MBW_Layers_showcase.webp) 26 | 27 | ## Repeat VAE 28 | A node that encodes and decodes an image with a VAE a specified number of times, useful for testing and comparing the performance of different VAEs. 29 | 30 | ![Repeat VAE Showcase](images/VAE_Repeat_showcase.webp) 31 | 32 | ## Quantization 33 | Quantize the U-Net, CLIP, or VAE to the specified number of bits. 34 | > Note: This is purely experimental; there are no speed or storage benefits from this. 35 | 36 | ![Quant Nodes Showcase](images/Quant_Nodes_showcase.webp) 37 | 38 | ### Credits 39 | To create these nodes, I used code from: 40 | - [SD.Next](https://github.com/vladmandic/automatic) 41 | - [meh](https://github.com/s1dlx/meh) 42 | - [ComfyUI](https://github.com/comfyanonymous/ComfyUI) 43 | - [VAE-BlessUp](https://github.com/sALTaccount/VAE-BlessUp) 44 | 45 | Thank you [Kybalico](https://github.com/kybalico/) and [NovaZone](https://civitai.com/user/nova1337) for helping me test and providing suggestions!
✨ -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .sdnextmerge_nodes import * 2 | from .vae_merge import * 3 | from .quant_nodes import * 4 | 5 | NODE_CLASS_MAPPINGS = { 6 | "SDNext Merge": SDNextMerge, 7 | "VAE Merge": VAEMerge, 8 | 9 | "SD1 MBW Layers": SD1_MBWLayers, 10 | "SD1 MBW Layers Binary": SD1_MBWLayers_Binary, 11 | "SDXL MBW Layers": SDXL_MBWLayers, 12 | "SDXL MBW Layers Binary": SDXL_MBWLayers_Binary, 13 | "MBW Layers String": MBWLayers_String, 14 | 15 | "VAERepeat": VAERepeat, 16 | 17 | "ModelQuant": ModelQuant, 18 | "ClipQuant": ClipQuant, 19 | "VAEQuant": VAEQuant, 20 | } 21 | 22 | NODE_DISPLAY_NAME_MAPPINGS = { 23 | "SDNext Merge": "SDNext Merge", 24 | "VAE Merge": "VAE Merge", 25 | 26 | "SD1 MBW Layers": "SD1 MBW Layers", 27 | "SD1 MBW Layers Binary": "SD1 MBW Layers Binary", 28 | "SDXL MBW Layers": "SDXL MBW Layers", 29 | "SDXL MBW Layers Binary": "SDXL MBW Layers Binary", 30 | "MBW Layers String": "MBW Layers String", 31 | 32 | "VAERepeat": "Repeat VAE", 33 | 34 | "ModelQuant": "ModelQuant", 35 | "ClipQuant": "ClipQuant", 36 | "VAEQuant": "VAEQuant", 37 | } 38 | 39 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 40 | -------------------------------------------------------------------------------- /images/MBW_Layers_showcase.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TechnoByteJS/ComfyUI-TechNodes/038d32cd28751618ae41c1c8233f7aec026e1288/images/MBW_Layers_showcase.webp -------------------------------------------------------------------------------- /images/Quant_Nodes_showcase.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TechnoByteJS/ComfyUI-TechNodes/038d32cd28751618ae41c1c8233f7aec026e1288/images/Quant_Nodes_showcase.webp 
-------------------------------------------------------------------------------- /images/SDNext_Merge_showcase.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TechnoByteJS/ComfyUI-TechNodes/038d32cd28751618ae41c1c8233f7aec026e1288/images/SDNext_Merge_showcase.webp -------------------------------------------------------------------------------- /images/VAE_Merge_showcase.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TechnoByteJS/ComfyUI-TechNodes/038d32cd28751618ae41c1c8233f7aec026e1288/images/VAE_Merge_showcase.webp -------------------------------------------------------------------------------- /images/VAE_Repeat_showcase.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TechnoByteJS/ComfyUI-TechNodes/038d32cd28751618ae41c1c8233f7aec026e1288/images/VAE_Repeat_showcase.webp -------------------------------------------------------------------------------- /merge.py: -------------------------------------------------------------------------------- 1 | import os 2 | from concurrent.futures import ThreadPoolExecutor 3 | from contextlib import contextmanager 4 | from typing import Dict, Optional, Tuple, Set 5 | import safetensors.torch 6 | import torch 7 | from . 
import merge_methods 8 | from .merge_utils import WeightClass 9 | from .merge_rebasin import ( 10 | apply_permutation, 11 | update_model_a, 12 | weight_matching, 13 | ) 14 | from .merge_PermSpec import sdunet_permutation_spec 15 | from .merge_PermSpec_SDXL import sdxl_permutation_spec 16 | 17 | from tqdm import tqdm 18 | 19 | import comfy.utils 20 | import comfy.model_management 21 | 22 | MAX_TOKENS = 77 23 | 24 | 25 | KEY_POSITION_IDS = ".".join( 26 | [ 27 | "cond_stage_model", 28 | "transformer", 29 | "text_model", 30 | "embeddings", 31 | "position_ids", 32 | ] 33 | ) 34 | 35 | 36 | def fix_clip(model: Dict) -> Dict: 37 | if KEY_POSITION_IDS in model.keys(): 38 | model[KEY_POSITION_IDS] = torch.tensor( 39 | [list(range(MAX_TOKENS))], 40 | dtype=torch.int64, 41 | device=model[KEY_POSITION_IDS].device, 42 | ) 43 | 44 | return model 45 | 46 | 47 | def prune_sd_model(model: Dict, keyset: Set) -> Dict: 48 | keys = list(model.keys()) 49 | for k in keys: 50 | if ( 51 | not k.startswith("model.diffusion_model.") # UNET 52 | # and not k.startswith("first_stage_model.") # VAE 53 | and not k.startswith("cond_stage_model.") # CLIP 54 | and not k.startswith("conditioner.embedders.") # SDXL CLIP 55 | ) or k not in keyset: 56 | del model[k] 57 | return model 58 | 59 | 60 | def restore_sd_model(original_model: Dict, merged_model: Dict) -> Dict: 61 | for k in original_model: 62 | if k not in merged_model: 63 | merged_model[k] = original_model[k] 64 | return merged_model 65 | 66 | def load_thetas( 67 | model_paths: Dict[str, os.PathLike], 68 | should_prune: bool, 69 | target_device: torch.device, 70 | precision: str, 71 | ) -> Dict: 72 | """ 73 | Load and process model parameters from given paths. 
74 | 75 | Args: 76 | model_paths: Dictionary of model names and their file paths 77 | should_prune: Flag to determine if models should be pruned 78 | target_device: The device to load the models onto 79 | precision: The precision to use for the model parameters 80 | 81 | Returns: 82 | Dictionary of processed model parameters 83 | """ 84 | # Load model parameters from files 85 | model_params = { 86 | model_name: comfy.utils.load_torch_file(model_path) 87 | for model_name, model_path in model_paths.items() 88 | } 89 | 90 | if should_prune: 91 | # Find common keys across all models 92 | common_keys = set.intersection(*[set(model.keys()) for model in model_params.values() if len(model.keys())]) 93 | # Prune models to keep only common parameters 94 | model_params = { 95 | model_name: prune_sd_model(model, common_keys) 96 | for model_name, model in model_params.items() 97 | } 98 | 99 | # Process each model's parameters 100 | for model_name, model in model_params.items(): 101 | for param_name, param_tensor in model.items(): 102 | if precision == "fp16": 103 | # Convert to half precision and move to target device 104 | model_params[model_name].update({param_name: param_tensor.to(target_device).half()}) 105 | else: 106 | # Move to target device maintaining original precision 107 | model_params[model_name].update({param_name: param_tensor.to(target_device)}) 108 | 109 | print("Models loaded successfully") 110 | return model_params 111 | 112 | def merge_models( 113 | models: Dict[str, os.PathLike], 114 | merge_mode: str, 115 | precision: str = "fp16", 116 | weights_clip: bool = False, 117 | device: torch.device = None, 118 | work_device: torch.device = None, 119 | prune: bool = False, 120 | threads: int = 4, 121 | optional_model_a = None, 122 | optional_clip_a = None, 123 | optional_model_b = None, 124 | optional_clip_b = None, 125 | optional_model_c = None, 126 | optional_clip_c = None, 127 | **kwargs, 128 | ) -> Dict: 129 | print("Alpha:") 130 | print(kwargs["alpha"]) 131 | 
132 | if models == { }: 133 | thetas = { } 134 | else: 135 | thetas = load_thetas(models, prune, device, precision) 136 | 137 | if "model_a" not in thetas: 138 | thetas["model_a"] = {} 139 | 140 | if "model_b" not in thetas: 141 | thetas["model_b"] = {} 142 | 143 | if optional_model_a is not None: 144 | key_patches = optional_model_a.get_key_patches() 145 | for key in key_patches: 146 | if "diffusion_model." in key: 147 | thetas["model_a"]["model." + key] = key_patches[key][0] 148 | 149 | if optional_clip_a is not None: 150 | key_patches = optional_clip_a.get_key_patches() 151 | for key in key_patches: 152 | if "transformer." in key and "text_projection" not in key: 153 | thetas["model_a"][key.replace("clip_l", "cond_stage_model")] = key_patches[key][0] 154 | 155 | if optional_model_b is not None: 156 | key_patches = optional_model_b.get_key_patches() 157 | for key in key_patches: 158 | if "diffusion_model." in key: 159 | thetas["model_b"]["model." + key] = key_patches[key][0] 160 | 161 | if optional_clip_b is not None: 162 | key_patches = optional_clip_b.get_key_patches() 163 | for key in key_patches: 164 | if "transformer." in key and "text_projection" not in key: 165 | thetas["model_b"][key.replace("clip_l", "cond_stage_model")] = key_patches[key][0] 166 | 167 | if optional_model_c is not None: 168 | if "model_c" not in thetas: 169 | thetas["model_c"] = {} 170 | key_patches = optional_model_c.get_key_patches() 171 | for key in key_patches: 172 | if "diffusion_model." in key: 173 | thetas["model_c"]["model." + key] = key_patches[key][0] 174 | 175 | if optional_clip_c is not None: 176 | if "model_c" not in thetas: 177 | thetas["model_c"] = {} 178 | key_patches = optional_clip_c.get_key_patches() 179 | for key in key_patches: 180 | if "transformer." 
in key and "text_projection" not in key: 181 | thetas["model_c"][key.replace("clip_l", "cond_stage_model")] = key_patches[key][0] 182 | 183 | print(f'Merge start: models={models.values()} precision={precision} clip={weights_clip} prune={prune} threads={threads}') 184 | weight_matcher = WeightClass(thetas["model_a"], **kwargs) 185 | if kwargs.get("re_basin", False): 186 | merged = rebasin_merge( 187 | thetas, 188 | weight_matcher, 189 | merge_mode, 190 | precision=precision, 191 | weights_clip=weights_clip, 192 | iterations=kwargs.get("re_basin_iterations", 1), 193 | device=device, 194 | work_device=work_device, 195 | threads=threads, 196 | ) 197 | else: 198 | merged = simple_merge( 199 | thetas, 200 | weight_matcher, 201 | merge_mode, 202 | precision=precision, 203 | weights_clip=weights_clip, 204 | device=device, 205 | work_device=work_device, 206 | threads=threads, 207 | ) 208 | 209 | return fix_clip(merged) 210 | 211 | def simple_merge( 212 | thetas: Dict[str, Dict], 213 | weight_matcher: WeightClass, 214 | merge_mode: str, 215 | precision: str = "fp16", 216 | weights_clip: bool = False, 217 | device: torch.device = None, 218 | work_device: torch.device = None, 219 | threads: int = 4, 220 | ) -> Dict: 221 | futures = [] 222 | with tqdm(thetas["model_a"].keys(), desc="Merge") as progress: 223 | with ThreadPoolExecutor(max_workers=threads) as executor: 224 | for key in thetas["model_a"].keys(): 225 | future = executor.submit( 226 | simple_merge_key, 227 | progress, 228 | key, 229 | thetas, 230 | weight_matcher, 231 | merge_mode, 232 | precision, 233 | weights_clip, 234 | device, 235 | work_device, 236 | ) 237 | futures.append(future) 238 | 239 | for res in futures: 240 | res.result() 241 | 242 | if len(thetas["model_b"]) > 0: 243 | print(f'Merge update thetas: keys={len(thetas["model_b"])}') 244 | for key in thetas["model_b"].keys(): 245 | if KEY_POSITION_IDS in key: 246 | continue 247 | if "model" in key and key not in thetas["model_a"].keys(): 248 | 
thetas["model_a"].update({key: thetas["model_b"][key]}) 249 | if precision == "fp16": 250 | thetas["model_a"].update({key: thetas["model_a"][key].half()}) 251 | 252 | return fix_clip(thetas["model_a"]) 253 | 254 | 255 | def rebasin_merge( 256 | thetas: Dict[str, os.PathLike], 257 | weight_matcher: WeightClass, 258 | merge_mode: str, 259 | precision: str = "fp16", 260 | weights_clip: bool = False, 261 | iterations: int = 1, 262 | device: torch.device = None, 263 | work_device: torch.device = None, 264 | threads: int = 1, 265 | ): 266 | # not sure how this does when 3 models are involved... 267 | model_a = thetas["model_a"] 268 | if weight_matcher.SDXL: 269 | perm_spec = sdxl_permutation_spec() 270 | else: 271 | perm_spec = sdunet_permutation_spec() 272 | 273 | for it in range(iterations): 274 | print(f"rebasin: iteration={it+1}") 275 | weight_matcher.set_it(it) 276 | 277 | # normal block merge we already know and love 278 | thetas["model_a"] = simple_merge( 279 | thetas, 280 | weight_matcher, 281 | merge_mode, 282 | precision, 283 | False, 284 | device, 285 | work_device, 286 | threads, 287 | ) 288 | 289 | # find permutations 290 | perm_1, y = weight_matching( 291 | perm_spec, 292 | model_a, 293 | thetas["model_a"], 294 | max_iter=it, 295 | init_perm=None, 296 | usefp16=precision == "fp16", 297 | device=device, 298 | ) 299 | thetas["model_a"] = apply_permutation(perm_spec, perm_1, thetas["model_a"]) 300 | 301 | perm_2, z = weight_matching( 302 | perm_spec, 303 | thetas["model_b"], 304 | thetas["model_a"], 305 | max_iter=it, 306 | init_perm=None, 307 | usefp16=precision == "fp16", 308 | device=device, 309 | ) 310 | 311 | new_alpha = torch.nn.functional.normalize( 312 | torch.sigmoid(torch.Tensor([y, z])), p=1, dim=0 313 | ).tolist()[0] 314 | thetas["model_a"] = update_model_a( 315 | perm_spec, perm_2, thetas["model_a"], new_alpha 316 | ) 317 | 318 | if weights_clip: 319 | clip_thetas = thetas.copy() 320 | clip_thetas["model_a"] = model_a 321 | thetas["model_a"] = 
clip_weights(thetas, thetas["model_a"]) 322 | 323 | return thetas["model_a"] 324 | 325 | 326 | def simple_merge_key(progress, key, thetas, *args, **kwargs): 327 | with merge_key_context(key, thetas, *args, **kwargs) as result: 328 | if result is not None: 329 | thetas["model_a"].update({key: result.detach().clone()}) 330 | progress.update(1) 331 | 332 | 333 | def merge_key( # pylint: disable=inconsistent-return-statements 334 | key: str, 335 | thetas: Dict, 336 | weight_matcher: WeightClass, 337 | merge_mode: str, 338 | precision: str = "fp16", 339 | weights_clip: bool = False, 340 | device: torch.device = None, 341 | work_device: torch.device = None, 342 | ) -> Optional[Tuple[str, Dict]]: 343 | if work_device is None: 344 | work_device = device 345 | 346 | if KEY_POSITION_IDS in key: 347 | return 348 | 349 | for theta in thetas.values(): 350 | if key not in theta.keys(): 351 | return thetas["model_a"][key] 352 | 353 | current_bases = weight_matcher(key) 354 | try: 355 | merge_method = getattr(merge_methods, merge_mode) 356 | except AttributeError as e: 357 | raise ValueError(f"{merge_mode} not implemented, aborting merge!") from e 358 | 359 | merge_args = get_merge_method_args(current_bases, thetas, key, work_device) 360 | 361 | # dealing with pix2pix and inpainting models 362 | if (a_size := merge_args["a"].size()) != (b_size := merge_args["b"].size()): 363 | if a_size[1] > b_size[1]: 364 | merged_key = merge_args["a"] 365 | else: 366 | merged_key = merge_args["b"] 367 | else: 368 | merged_key = merge_method(**merge_args).to(device) 369 | 370 | if weights_clip: 371 | merged_key = clip_weights_key(thetas, merged_key, key) 372 | 373 | if precision == "fp16": 374 | merged_key = merged_key.half() 375 | 376 | return merged_key 377 | 378 | 379 | def clip_weights(thetas, merged): 380 | for k in thetas["model_a"].keys(): 381 | if k in thetas["model_b"].keys(): 382 | merged.update({k: clip_weights_key(thetas, merged[k], k)}) 383 | return merged 384 | 385 | def 
clip_weights_key(thetas, merged_weights, key): 386 | # Determine the device of the merged_weights 387 | device = merged_weights.device 388 | 389 | # Move all tensors to the same device 390 | t0 = thetas["model_a"][key].to(device) 391 | t1 = thetas["model_b"][key].to(device) 392 | 393 | maximums = torch.maximum(t0, t1) 394 | minimums = torch.minimum(t0, t1) 395 | 396 | return torch.minimum(torch.maximum(merged_weights, minimums), maximums) 397 | 398 | @contextmanager 399 | def merge_key_context(*args, **kwargs): 400 | result = merge_key(*args, **kwargs) 401 | try: 402 | yield result 403 | finally: 404 | if result is not None: 405 | del result 406 | 407 | 408 | def get_merge_method_args( 409 | current_bases: Dict, 410 | thetas: Dict, 411 | key: str, 412 | work_device: torch.device, 413 | ) -> Dict: 414 | merge_method_args = { 415 | "a": thetas["model_a"][key].to(work_device), 416 | "b": thetas["model_b"][key].to(work_device), 417 | **current_bases, 418 | } 419 | 420 | if "model_c" in thetas: 421 | merge_method_args["c"] = thetas["model_c"][key].to(work_device) 422 | 423 | return merge_method_args 424 | -------------------------------------------------------------------------------- /merge_PermSpec.py: -------------------------------------------------------------------------------- 1 | from .merge_rebasin import PermutationSpec, permutation_spec_from_axes_to_perm 2 | def sdunet_permutation_spec() -> PermutationSpec: 3 | conv = lambda name, p_in, p_out: { # pylint: disable=unnecessary-lambda-assignment 4 | f"{name}.weight": ( 5 | p_out, 6 | p_in, 7 | ), 8 | f"{name}.bias": (p_out,), 9 | } 10 | norm = lambda name, p: {f"{name}.weight": (p,), f"{name}.bias": (p,)} # pylint: disable=unnecessary-lambda-assignment 11 | dense = ( 12 | lambda name, p_in, p_out, bias=True: { # pylint: disable=unnecessary-lambda-assignment 13 | f"{name}.weight": (p_out, p_in), 14 | f"{name}.bias": (p_out,), 15 | } 16 | if bias 17 | else {f"{name}.weight": (p_out, p_in)} 18 | ) 19 | skip = 
lambda name, p_in, p_out: { # pylint: disable=unnecessary-lambda-assignment 20 | f"{name}": ( 21 | p_out, 22 | p_in, 23 | None, 24 | None, 25 | ) 26 | } 27 | 28 | # Unet Res blocks 29 | easyblock = lambda name, p_in, p_out: { # pylint: disable=unnecessary-lambda-assignment 30 | **norm(f"{name}.in_layers.0", p_in), 31 | **conv(f"{name}.in_layers.2", p_in, f"P_{name}_inner"), 32 | **dense( 33 | f"{name}.emb_layers.1", f"P_{name}_inner2", f"P_{name}_inner3", bias=True 34 | ), 35 | **norm(f"{name}.out_layers.0", f"P_{name}_inner4"), 36 | **conv(f"{name}.out_layers.3", f"P_{name}_inner4", p_out), 37 | } 38 | 39 | # VAE blocks - Unused 40 | easyblock2 = lambda name, p: { # pylint: disable=unnecessary-lambda-assignment, unused-variable # noqa: F841 41 | **norm(f"{name}.norm1", p), 42 | **conv(f"{name}.conv1", p, f"P_{name}_inner"), 43 | **norm(f"{name}.norm2", f"P_{name}_inner"), 44 | **conv(f"{name}.conv2", f"P_{name}_inner", p), 45 | } 46 | 47 | # This is for blocks that use a residual connection, but change the number of channels via a Conv. 
48 | shortcutblock = lambda name, p_in, p_out: { # pylint: disable=unnecessary-lambda-assignment, , unused-variable # noqa: F841 49 | **norm(f"{name}.norm1", p_in), 50 | **conv(f"{name}.conv1", p_in, f"P_{name}_inner"), 51 | **norm(f"{name}.norm2", f"P_{name}_inner"), 52 | **conv(f"{name}.conv2", f"P_{name}_inner", p_out), 53 | **conv(f"{name}.nin_shortcut", p_in, p_out), 54 | **norm(f"{name}.nin_shortcut", p_out), 55 | } 56 | 57 | return permutation_spec_from_axes_to_perm( 58 | { 59 | # Skipped Layers 60 | **skip("betas", None, None), 61 | **skip("alphas_cumprod", None, None), 62 | **skip("alphas_cumprod_prev", None, None), 63 | **skip("sqrt_alphas_cumprod", None, None), 64 | **skip("sqrt_one_minus_alphas_cumprod", None, None), 65 | **skip("log_one_minus_alphas_cumprods", None, None), 66 | **skip("sqrt_recip_alphas_cumprod", None, None), 67 | **skip("sqrt_recipm1_alphas_cumprod", None, None), 68 | **skip("posterior_variance", None, None), 69 | **skip("posterior_log_variance_clipped", None, None), 70 | **skip("posterior_mean_coef1", None, None), 71 | **skip("posterior_mean_coef2", None, None), 72 | **skip("log_one_minus_alphas_cumprod", None, None), 73 | **skip("model_ema.decay", None, None), 74 | **skip("model_ema.num_updates", None, None), 75 | # initial 76 | **dense("model.diffusion_model.time_embed.0", None, "P_bg0", bias=True), 77 | **dense("model.diffusion_model.time_embed.2", "P_bg0", "P_bg1", bias=True), 78 | **conv("model.diffusion_model.input_blocks.0.0", "P_bg2", "P_bg3"), 79 | # input blocks 80 | **easyblock("model.diffusion_model.input_blocks.1.0", "P_bg4", "P_bg5"), 81 | **norm("model.diffusion_model.input_blocks.1.1.norm", "P_bg6"), 82 | **conv("model.diffusion_model.input_blocks.1.1.proj_in", "P_bg6", "P_bg7"), 83 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q", "P_bg8", "P_bg9", bias=False), 84 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k", "P_bg8", "P_bg9", bias=False), 85 | 
**dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v", "P_bg8", "P_bg9", bias=False), 86 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0", "P_bg8", "P_bg9", bias=True), 87 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj", "P_bg10", "P_bg11", bias=True), 88 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2", "P_bg12", "P_bg13", bias=True), 89 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q", "P_bg14", "P_bg15", bias=False), 90 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k", "P_bg16", "P_bg17", bias=False), 91 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v", "P_bg16", "P_bg17", bias=False), 92 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0", "P_bg18", "P_bg19", bias=True), 93 | **norm("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1", "P_bg19"), 94 | **norm("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2", "P_bg19"), 95 | **norm("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3", "P_bg19"), 96 | **conv("model.diffusion_model.input_blocks.1.1.proj_out", "P_bg19", "P_bg20"), 97 | **easyblock("model.diffusion_model.input_blocks.2.0", "P_bg21", "P_bg22"), 98 | **norm("model.diffusion_model.input_blocks.2.1.norm", "P_bg23"), 99 | **conv("model.diffusion_model.input_blocks.2.1.proj_in", "P_bg23", "P_bg24"), 100 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q", "P_bg25", "P_bg26", bias=False), 101 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k", "P_bg25", "P_bg26", bias=False), 102 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v", "P_bg25", "P_bg26", bias=False), 103 | 
**dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0", "P_bg25", "P_bg26", bias=True), 104 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj", "P_bg27", "P_bg28", bias=True), 105 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2", "P_bg29", "P_bg30", bias=True), 106 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q", "P_bg31", "P_bg32", bias=False), 107 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k", "P_bg33", "P_bg34", bias=False), 108 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v", "P_bg33", "P_bg34", bias=False), 109 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0", "P_bg35", "P_bg36", bias=True), 110 | **norm("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1", "P_bg36"), 111 | **norm("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2", "P_bg36"), 112 | **norm("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3", "P_bg36"), 113 | **conv("model.diffusion_model.input_blocks.2.1.proj_out", "P_bg36", "P_bg37"), 114 | **conv("model.diffusion_model.input_blocks.3.0.op", "P_bg38", "P_bg39"), 115 | **easyblock("model.diffusion_model.input_blocks.4.0", "P_bg40", "P_bg41"), 116 | **conv("model.diffusion_model.input_blocks.4.0.skip_connection", "P_bg42", "P_bg43"), 117 | **norm("model.diffusion_model.input_blocks.4.1.norm", "P_bg44"), 118 | **conv("model.diffusion_model.input_blocks.4.1.proj_in", "P_bg44", "P_bg45"), 119 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q", "P_bg46", "P_bg47", bias=False), 120 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k", "P_bg46", "P_bg47", bias=False), 121 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v", "P_bg46", "P_bg47", bias=False), 122 | 
**dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0", "P_bg46", "P_bg47", bias=True), 123 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj", "P_bg48", "P_bg49", bias=True), 124 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2", "P_bg50", "P_bg51", bias=True), 125 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q", "P_bg52", "P_bg53", bias=False), 126 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k", "P_bg54", "P_bg55", bias=False), 127 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v", "P_bg54", "P_bg55", bias=False), 128 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0", "P_bg56", "P_bg57", bias=True), 129 | **norm("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1", "P_bg57"), 130 | **norm("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2", "P_bg57"), 131 | **norm("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3", "P_bg57"), 132 | **conv("model.diffusion_model.input_blocks.4.1.proj_out", "P_bg57", "P_bg58"), 133 | **easyblock("model.diffusion_model.input_blocks.5.0", "P_bg59", "P_bg60"), 134 | **norm("model.diffusion_model.input_blocks.5.1.norm", "P_bg61"), 135 | **conv("model.diffusion_model.input_blocks.5.1.proj_in", "P_bg61", "P_bg62"), 136 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q", "P_bg63", "P_bg64", bias=False), 137 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k", "P_bg63", "P_bg64", bias=False), 138 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v", "P_bg63", "P_bg64", bias=False), 139 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0", "P_bg63", "P_bg64", bias=True), 140 | 
**dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj", "P_bg65", "P_bg66", bias=True), 141 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2", "P_bg67", "P_bg68", bias=True), 142 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q", "P_bg69", "P_bg70", bias=False), 143 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k", "P_bg71", "P_bg72", bias=False), 144 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v", "P_bg71", "P_bg72", bias=False), 145 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0", "P_bg73", "P_bg74", bias=True), 146 | **norm("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1", "P_bg74"), 147 | **norm("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2", "P_bg74"), 148 | **norm("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3", "P_bg74"), 149 | **conv("model.diffusion_model.input_blocks.5.1.proj_out", "P_bg74", "P_bg75"), 150 | **conv("model.diffusion_model.input_blocks.6.0.op", "P_bg76", "P_bg77"), 151 | **easyblock("model.diffusion_model.input_blocks.7.0", "P_bg78", "P_bg79"), 152 | **conv("model.diffusion_model.input_blocks.7.0.skip_connection", "P_bg80", "P_bg81"), 153 | **norm("model.diffusion_model.input_blocks.7.1.norm", "P_bg82"), 154 | **conv("model.diffusion_model.input_blocks.7.1.proj_in", "P_bg82", "P_bg83"), 155 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q", "P_bg84", "P_bg85", bias=False), 156 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k", "P_bg84", "P_bg85", bias=False), 157 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v", "P_bg84", "P_bg85", bias=False), 158 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0", "P_bg84", "P_bg85", bias=True), 159 | 
**dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj", "P_bg86", "P_bg87", bias=True), 160 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2", "P_bg88", "P_bg89", bias=True), 161 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q", "P_bg90", "P_bg91", bias=False), 162 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k", "P_bg92", "P_bg93", bias=False), 163 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v", "P_bg92", "P_bg93", bias=False), 164 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0", "P_bg94", "P_bg95", bias=True), 165 | **norm("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1", "P_bg95"), 166 | **norm("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2", "P_bg95"), 167 | **norm("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3", "P_bg95"), 168 | **conv("model.diffusion_model.input_blocks.7.1.proj_out", "P_bg95", "P_bg96"), 169 | **easyblock("model.diffusion_model.input_blocks.8.0", "P_bg97", "P_bg98"), 170 | **norm("model.diffusion_model.input_blocks.8.1.norm", "P_bg99"), 171 | **conv("model.diffusion_model.input_blocks.8.1.proj_in", "P_bg99", "P_bg100"), 172 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q", "P_bg101", "P_bg102", bias=False), 173 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k", "P_bg101", "P_bg102", bias=False), 174 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v", "P_bg101", "P_bg102", bias=False), 175 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0", "P_bg101", "P_bg102", bias=True), 176 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj", "P_bg103", "P_bg104", bias=True), 177 | 
**dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2", "P_bg105", "P_bg106", bias=True), 178 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q", "P_bg107", "P_bg108", bias=False), 179 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k", "P_bg109", "P_bg110", bias=False), 180 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v", "P_bg109", "P_bg110", bias=False), 181 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0", "P_bg111", "P_bg112", bias=True), 182 | **norm("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1", "P_bg112"), 183 | **norm("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2", "P_bg112"), 184 | **norm("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3", "P_bg112"), 185 | **conv("model.diffusion_model.input_blocks.8.1.proj_out", "P_bg112", "P_bg113"), 186 | **conv("model.diffusion_model.input_blocks.9.0.op", "P_bg114", "P_bg115"), 187 | **easyblock("model.diffusion_model.input_blocks.10.0", "P_bg115", "P_bg116"), 188 | **easyblock("model.diffusion_model.input_blocks.11.0", "P_bg116", "P_bg117"), 189 | # middle blocks 190 | **easyblock("model.diffusion_model.middle_block.0", "P_bg117", "P_bg118"), 191 | **norm("model.diffusion_model.middle_block.1.norm", "P_bg119"), 192 | **conv("model.diffusion_model.middle_block.1.proj_in", "P_bg119", "P_bg120"), 193 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q", "P_bg121", "P_bg122", bias=False), 194 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k", "P_bg121", "P_bg122", bias=False), 195 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v", "P_bg121", "P_bg122", bias=False), 196 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0", "P_bg121", "P_bg122", bias=True), 197 | 
**dense("model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj", "P_bg123", "P_bg124", bias=True), 198 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2", "P_bg125", "P_bg126", bias=True), 199 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q", "P_bg127", "P_bg128", bias=False), 200 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k", "P_bg129", "P_bg130", bias=False), 201 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v", "P_bg129", "P_bg130", bias=False), 202 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0", "P_bg131", "P_bg132", bias=True), 203 | **norm("model.diffusion_model.middle_block.1.transformer_blocks.0.norm1", "P_bg132"), 204 | **norm("model.diffusion_model.middle_block.1.transformer_blocks.0.norm2", "P_bg132"), 205 | **norm("model.diffusion_model.middle_block.1.transformer_blocks.0.norm3", "P_bg132"), 206 | **conv("model.diffusion_model.middle_block.1.proj_out", "P_bg132", "P_bg133"), 207 | **easyblock("model.diffusion_model.middle_block.2", "P_bg134", "P_bg135"), 208 | # output blocks 209 | **easyblock("model.diffusion_model.output_blocks.0.0", "P_bg136", "P_bg137"), 210 | **conv("model.diffusion_model.output_blocks.0.0.skip_connection", "P_bg138", "P_bg139"), 211 | **easyblock("model.diffusion_model.output_blocks.1.0", "P_bg140", "P_bg141"), 212 | **conv("model.diffusion_model.output_blocks.1.0.skip_connection", "P_bg142", "P_bg143"), 213 | **easyblock("model.diffusion_model.output_blocks.2.0", "P_bg144", "P_bg145"), 214 | **conv("model.diffusion_model.output_blocks.2.0.skip_connection", "P_bg146", "P_bg147"), 215 | **conv("model.diffusion_model.output_blocks.2.1.conv", "P_bg148", "P_bg149"), 216 | **easyblock("model.diffusion_model.output_blocks.3.0", "P_bg150", "P_bg151"), 217 | **conv("model.diffusion_model.output_blocks.3.0.skip_connection", "P_bg152", "P_bg153"), 218 | 
**norm("model.diffusion_model.output_blocks.3.1.norm", "P_bg154"), 219 | **conv("model.diffusion_model.output_blocks.3.1.proj_in", "P_bg154", "P_bg155"), 220 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q", "P_bg156", "P_bg157", bias=False), 221 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k", "P_bg156", "P_bg157", bias=False), 222 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v", "P_bg156", "P_bg157", bias=False), 223 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0", "P_bg156", "P_bg157", bias=True), 224 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj", "P_bg158", "P_bg159", bias=True), 225 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2", "P_bg160", "P_bg161", bias=True), 226 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q", "P_bg162", "P_bg163", bias=False), 227 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k", "P_bg164", "P_bg165", bias=False), 228 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v", "P_bg164", "P_bg165", bias=False), 229 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0", "P_bg166", "P_bg167", bias=True), 230 | **norm("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1", "P_bg167"), 231 | **norm("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2", "P_bg167"), 232 | **norm("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3", "P_bg167"), 233 | **conv("model.diffusion_model.output_blocks.3.1.proj_out", "P_bg167", "P_bg168"), 234 | **easyblock("model.diffusion_model.output_blocks.4.0", "P_bg169", "P_bg170"), 235 | **conv("model.diffusion_model.output_blocks.4.0.skip_connection", "P_bg171", "P_bg172"), 236 | 
**norm("model.diffusion_model.output_blocks.4.1.norm", "P_bg173"), 237 | **conv("model.diffusion_model.output_blocks.4.1.proj_in", "P_bg173", "P_bg174"), 238 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q", "P_bg175", "P_bg176", bias=False), 239 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k", "P_bg175", "P_bg176", bias=False), 240 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v", "P_bg175", "P_bg176", bias=False), 241 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0", "P_bg175", "P_bg176", bias=True), 242 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj", "P_bg177", "P_bg178", bias=True), 243 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2", "P_bg179", "P_bg180", bias=True), 244 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q", "P_bg181", "P_bg182", bias=False), 245 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k", "P_bg183", "P_bg184", bias=False), 246 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v", "P_bg183", "P_bg184", bias=False), 247 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0", "P_bg185", "P_bg186", bias=True), 248 | **norm("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1", "P_bg186"), 249 | **norm("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2", "P_bg186"), 250 | **norm("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3", "P_bg186"), 251 | **conv("model.diffusion_model.output_blocks.4.1.proj_out", "P_bg186", "P_bg187"), 252 | **easyblock("model.diffusion_model.output_blocks.5.0", "P_bg188", "P_bg189"), 253 | **conv("model.diffusion_model.output_blocks.5.0.skip_connection", "P_bg190", "P_bg191"), 254 | 
**norm("model.diffusion_model.output_blocks.5.1.norm", "P_bg192"), 255 | **conv("model.diffusion_model.output_blocks.5.1.proj_in", "P_bg192", "P_bg193"), 256 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q", "P_bg194", "P_bg195", bias=False), 257 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k", "P_bg194", "P_bg195", bias=False), 258 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v", "P_bg194", "P_bg195", bias=False), 259 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0", "P_bg194", "P_bg195", bias=True), 260 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj", "P_bg196", "P_bg197", bias=True), 261 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2", "P_bg198", "P_bg199", bias=True), 262 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q", "P_bg200", "P_bg201", bias=False), 263 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k", "P_bg202", "P_bg203", bias=False), 264 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v", "P_bg202", "P_bg203", bias=False), 265 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0", "P_bg204", "P_bg205", bias=True), 266 | **norm("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1", "P_bg205"), 267 | **norm("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2", "P_bg205"), 268 | **norm("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3", "P_bg205"), 269 | **conv("model.diffusion_model.output_blocks.5.1.proj_out", "P_bg205", "P_bg206"), 270 | **conv("model.diffusion_model.output_blocks.5.2.conv", "P_bg206", "P_bg207"), 271 | **easyblock("model.diffusion_model.output_blocks.6.0", "P_bg208", "P_bg209"), 272 | 
**conv("model.diffusion_model.output_blocks.6.0.skip_connection", "P_bg210", "P_bg211"), 273 | **norm("model.diffusion_model.output_blocks.6.1.norm", "P_bg212"), 274 | **conv("model.diffusion_model.output_blocks.6.1.proj_in", "P_bg212", "P_bg213"), 275 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q", "P_bg214", "P_bg215", bias=False), 276 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k", "P_bg214", "P_bg215", bias=False), 277 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v", "P_bg214", "P_bg215", bias=False), 278 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0", "P_bg214", "P_bg215", bias=True), 279 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj", "P_bg216", "P_bg217", bias=True), 280 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2", "P_bg218", "P_bg219", bias=True), 281 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q", "P_bg220", "P_bg221", bias=False), 282 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k", "P_bg222", "P_bg223", bias=False), 283 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v", "P_bg222", "P_bg223", bias=False), 284 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0", "P_bg224", "P_bg225", bias=True), 285 | **norm("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1", "P_bg225"), 286 | **norm("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2", "P_bg225"), 287 | **norm("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3", "P_bg225"), 288 | **conv("model.diffusion_model.output_blocks.6.1.proj_out", "P_bg225", "P_bg226"), 289 | **easyblock("model.diffusion_model.output_blocks.7.0", "P_bg227", "P_bg228"), 290 | 
**conv("model.diffusion_model.output_blocks.7.0.skip_connection", "P_bg229", "P_bg230"), 291 | **norm("model.diffusion_model.output_blocks.7.1.norm", "P_bg231"), 292 | **conv("model.diffusion_model.output_blocks.7.1.proj_in", "P_bg231", "P_bg232"), 293 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q", "P_bg233", "P_bg234", bias=False), 294 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k", "P_bg233", "P_bg234", bias=False), 295 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v", "P_bg233", "P_bg234", bias=False), 296 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0", "P_bg233", "P_bg234", bias=True), 297 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj", "P_bg235", "P_bg236", bias=True), 298 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2", "P_bg237", "P_bg238", bias=True), 299 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q", "P_bg239", "P_bg240", bias=False), 300 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k", "P_bg241", "P_bg242", bias=False), 301 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v", "P_bg241", "P_bg242", bias=False), 302 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0", "P_bg243", "P_bg244", bias=True), 303 | **norm("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1", "P_bg244"), 304 | **norm("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2", "P_bg244"), 305 | **norm("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3", "P_bg244"), 306 | **conv("model.diffusion_model.output_blocks.7.1.proj_out", "P_bg244", "P_bg245"), 307 | **easyblock("model.diffusion_model.output_blocks.8.0", "P_bg246", "P_bg247"), 308 | 
**conv("model.diffusion_model.output_blocks.8.0.skip_connection", "P_bg248", "P_bg249"), 309 | **norm("model.diffusion_model.output_blocks.8.1.norm", "P_bg250"), 310 | **conv("model.diffusion_model.output_blocks.8.1.proj_in", "P_bg250", "P_bg251"), 311 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q", "P_bg252", "P_bg253", bias=False), 312 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k", "P_bg252", "P_bg253", bias=False), 313 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v", "P_bg252", "P_bg253", bias=False), 314 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0", "P_bg252", "P_bg253", bias=True), 315 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj", "P_bg254", "P_bg255", bias=True), 316 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2", "P_bg256", "P_bg257", bias=True), 317 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q", "P_bg258", "P_bg259", bias=False), 318 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k", "P_bg260", "P_bg261", bias=False), 319 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v", "P_bg260", "P_bg261", bias=False), 320 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0", "P_bg262", "P_bg263", bias=True), 321 | **norm("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1", "P_bg263"), 322 | **norm("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2", "P_bg263"), 323 | **norm("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3", "P_bg263"), 324 | **conv("model.diffusion_model.output_blocks.8.1.proj_out", "P_bg263", "P_bg264"), 325 | **conv("model.diffusion_model.output_blocks.8.2.conv", "P_bg265", "P_bg266"), 326 | 
**easyblock("model.diffusion_model.output_blocks.9.0", "P_bg267", "P_bg268"), 327 | **conv("model.diffusion_model.output_blocks.9.0.skip_connection", "P_bg269", "P_bg270"), 328 | **norm("model.diffusion_model.output_blocks.9.1.norm", "P_bg271"), 329 | **conv("model.diffusion_model.output_blocks.9.1.proj_in", "P_bg271", "P_bg272"), 330 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_q", "P_bg273", "P_bg274", bias=False), 331 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_k", "P_bg273", "P_bg274", bias=False), 332 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_v", "P_bg273", "P_bg274", bias=False), 333 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0", "P_bg273", "P_bg274", bias=True), 334 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj", "P_bg275", "P_bg276", bias=True), 335 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2", "P_bg277", "P_bg278", bias=True), 336 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q", "P_bg279", "P_bg280", bias=False), 337 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k", "P_bg281", "P_bg282", bias=False), 338 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v", "P_bg281", "P_bg282", bias=False), 339 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0", "P_bg283", "P_bg284", bias=True), 340 | **norm("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1", "P_bg284"), 341 | **norm("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2", "P_bg284"), 342 | **norm("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3", "P_bg284"), 343 | **conv("model.diffusion_model.output_blocks.9.1.proj_out", "P_bg284", "P_bg285"), 344 | 
**easyblock("model.diffusion_model.output_blocks.10.0", "P_bg286", "P_bg287"), 345 | **conv("model.diffusion_model.output_blocks.10.0.skip_connection", "P_bg288", "P_bg289"), 346 | **norm("model.diffusion_model.output_blocks.10.1.norm", "P_bg290"), 347 | **conv("model.diffusion_model.output_blocks.10.1.proj_in", "P_bg290", "P_bg291"), 348 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_q", "P_bg292", "P_bg293", bias=False), 349 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_k", "P_bg292", "P_bg293", bias=False), 350 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_v", "P_bg292", "P_bg293", bias=False), 351 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0", "P_bg292", "P_bg293", bias=True), 352 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj", "P_b294", "P_bg295", bias=True), 353 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2", "P_bg296", "P_bg297", bias=True), 354 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q", "P_bg298", "P_bg299", bias=False), 355 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k", "P_bg300", "P_bg301", bias=False), 356 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v", "P_bg300", "P_bg301", bias=False), 357 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0", "P_bg302", "P_bg303", bias=True), 358 | **norm("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1", "P_bg303"), 359 | **norm("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2", "P_bg303"), 360 | **norm("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3", "P_bg303"), 361 | **conv("model.diffusion_model.output_blocks.10.1.proj_out", "P_bg303", "P_bg304"), 362 | 
**easyblock("model.diffusion_model.output_blocks.11.0", "P_bg305", "P_bg306"), 363 | **conv("model.diffusion_model.output_blocks.11.0.skip_connection", "P_bg307", "P_bg308"), 364 | **norm("model.diffusion_model.output_blocks.11.1.norm", "P_bg309"), 365 | **conv("model.diffusion_model.output_blocks.11.1.proj_in", "P_bg309", "P_bg310"), 366 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_q", "P_bg311", "P_bg312", bias=False), 367 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_k", "P_bg311", "P_bg312", bias=False), 368 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_v", "P_bg311", "P_bg312", bias=False), 369 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0", "P_bg311", "P_bg312", bias=True), 370 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj", "P_bg313", "P_bg314", bias=True), 371 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2", "P_bg315", "P_bg316", bias=True), 372 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q", "P_bg317", "P_bg318", bias=False), 373 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k", "P_bg319", "P_bg320", bias=False), 374 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v", "P_bg319", "P_bg320", bias=False), 375 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0", "P_bg321", "P_bg322", bias=True), 376 | **norm("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1", "P_bg322"), 377 | **norm("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2", "P_bg322"), 378 | **norm("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3", "P_bg322"), 379 | **conv("model.diffusion_model.output_blocks.11.1.proj_out", "P_bg322", "P_bg323"), 380 | 
**norm("model.diffusion_model.out.0", "P_bg324"), 381 | **conv("model.diffusion_model.out.2", "P_bg325", "P_bg326"), 382 | **skip("cond_stage_model.transformer.text_model.embeddings.position_ids", None, None), 383 | **dense("cond_stage_model.transformer.text_model.embeddings.token_embedding", "P_bg365", "P_bg366", bias=False), 384 | **dense("cond_stage_model.transformer.text_model.embeddings.token_embedding", None, None), 385 | **dense("cond_stage_model.transformer.text_model.embeddings.position_embedding", "P_bg367", "P_bg368", bias=False), 386 | # cond stage text encoder 387 | **dense("cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj", "P_bg369", "P_bg370", bias=True), 388 | **dense("cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj", "P_bg369", "P_bg370", bias=True), 389 | **dense("cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj", "P_bg369", "P_bg370", bias=True), 390 | **dense("cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj", "P_bg369", "P_bg370", bias=True), 391 | **norm("cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1", "P_bg370"), 392 | **dense("cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1", "P_bg370", "P_bg371", bias=True), 393 | **dense("cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2", "P_bg371", "P_bg372", bias=True), 394 | **norm("cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2", "P_bg372"), 395 | **dense("cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj", "P_bg372", "P_bg373", bias=True), 396 | **dense("cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj", "P_bg372", "P_bg373", bias=True), 397 | **dense("cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj", "P_bg372", "P_bg373", bias=True), 398 | **dense("cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj", "P_bg372", "P_bg373", 
bias=True), 399 | **norm("cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1", "P_bg373"), 400 | **dense("cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1", "P_bg373", "P_bg374", bias=True), 401 | **dense("cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2", "P_bg374", "P_bg375", bias=True), 402 | **norm("cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2", "P_bg375"), 403 | **dense("cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj", "P_bg375", "P_bg376", bias=True), 404 | **dense("cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj", "P_bg375", "P_bg376", bias=True), 405 | **dense("cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj", "P_bg375", "P_bg376", bias=True), 406 | **dense("cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj", "P_bg375", "P_bg376", bias=True), 407 | **norm("cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1", "P_bg376"), 408 | **dense("cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1", "P_bg376", "P_bg377", bias=True), 409 | **dense("cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2", "P_bg377", "P_bg378", bias=True), 410 | **norm("cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2", "P_bg378"), 411 | **dense("cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj", "P_bg378", "P_bg379", bias=True), 412 | **dense("cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj", "P_bg378", "P_bg379", bias=True), 413 | **dense("cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj", "P_bg378", "P_bg379", bias=True), 414 | **dense("cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj", "P_bg378", "P_bg379", bias=True), 415 | **norm("cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1", "P_bg379"), 416 | 
**dense("cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1", "P_bg379", "P_bg380", bias=True), 417 | **dense("cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2", "P_bg380", "P_b381", bias=True), 418 | **norm("cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2", "P_bg381"), 419 | **dense("cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj", "P_bg381", "P_bg382", bias=True), 420 | **dense("cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj", "P_bg381", "P_bg382", bias=True), 421 | **dense("cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj", "P_bg381", "P_bg382", bias=True), 422 | **dense("cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj", "P_bg381", "P_bg382", bias=True), 423 | **norm("cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1", "P_bg382"), 424 | **dense("cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1", "P_bg382", "P_bg383", bias=True), 425 | **dense("cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2", "P_bg383", "P_bg384", bias=True), 426 | **norm("cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2", "P_bg384"), 427 | **dense("cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj", "P_bg384", "P_bg385", bias=True), 428 | **dense("cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj", "P_bg384", "P_bg385", bias=True), 429 | **dense("cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj", "P_bg384", "P_bg385", bias=True), 430 | **dense("cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj", "P_bg384", "P_bg385", bias=True), 431 | **norm("cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1", "P_bg385"), 432 | **dense("cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1", "P_bg385", "P_bg386", bias=True), 433 | 
**dense("cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2", "P_bg386", "P_bg387", bias=True), 434 | **norm("cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2", "P_bg387"), 435 | **dense("cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj", "P_bg387", "P_bg388", bias=True), 436 | **dense("cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj", "P_bg387", "P_bg388", bias=True), 437 | **dense("cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj", "P_bg387", "P_bg388", bias=True), 438 | **dense("cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj", "P_bg387", "P_bg388", bias=True), 439 | **norm("cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1", "P_bg389"), 440 | **dense("cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1", "P_bg389", "P_bg390", bias=True), 441 | **dense("cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2", "P_bg390", "P_bg391", bias=True), 442 | **norm("cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2", "P_bg391"), 443 | **dense("cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj", "P_bg391", "P_bg392", bias=True), 444 | **dense("cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj", "P_bg391", "P_bg392", bias=True), 445 | **dense("cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj", "P_bg391", "P_bg392", bias=True), 446 | **dense("cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj", "P_bg391", "P_bg392", bias=True), 447 | **norm("cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1", "P_bg392"), 448 | **dense("cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1", "P_bg392", "P_bg393", bias=True), 449 | **dense("cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2", "P_bg393", "P_bg394", bias=True), 450 | 
**norm("cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2", "P_bg394"), 451 | **dense("cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj", "P_bg394", "P_bg395", bias=True), 452 | **dense("cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj", "P_bg394", "P_bg395", bias=True), 453 | **dense("cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj", "P_bg394", "P_bg395", bias=True), 454 | **dense("cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj", "P_bg394", "P_bg395", bias=True), 455 | **norm("cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1", "P_bg395"), 456 | **dense("cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1", "P_bg395", "P_bg396", bias=True), 457 | **dense("cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2", "P_bg396", "P_bg397", bias=True), 458 | **norm("cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2", "P_bg397"), 459 | **dense("cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj", "P_bg397", "P_bg398", bias=True), 460 | **dense("cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj", "P_bg397", "P_bg398", bias=True), 461 | **dense("cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj", "P_bg397", "P_bg398", bias=True), 462 | **dense("cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj", "P_bg397", "P_bg398", bias=True), 463 | **norm("cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1", "P_bg398"), 464 | **dense("cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1", "P_bg398", "P_bg399", bias=True), 465 | **dense("cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2", "P_bg400", "P_bg401", bias=True), 466 | **norm("cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2", "P_bg401"), 467 | 
**dense("cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj", "P_bg401", "P_bg402", bias=True), 468 | **dense("cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj", "P_bg401", "P_bg402", bias=True), 469 | **dense("cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj", "P_bg401", "P_bg402", bias=True), 470 | **dense("cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj", "P_bg401", "P_bg402", bias=True), 471 | **norm("cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1", "P_bg402"), 472 | **dense("cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1", "P_bg402", "P_bg403", bias=True), 473 | **dense("cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2", "P_bg403", "P_bg404", bias=True), 474 | **norm("cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2", "P_bg404"), 475 | **dense("cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj", "P_bg404", "P_bg405", bias=True), 476 | **dense("cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj", "P_bg404", "P_bg405", bias=True), 477 | **dense("cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj", "P_bg404", "P_bg405", bias=True), 478 | **dense("cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj", "P_bg404", "P_bg405", bias=True), 479 | **norm("cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1", "P_bg405"), 480 | **dense("cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1", "P_bg405", "P_bg406", bias=True), 481 | **dense("cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2", "P_bg406", "P_bg407", bias=True), 482 | **norm("cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2", "P_bg407"), 483 | **norm("cond_stage_model.transformer.text_model.final_layer_norm", "P_bg407"), 484 | } 485 | ) 486 | 
-------------------------------------------------------------------------------- /merge_methods.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | __all__ = [ 8 | "weighted_sum", 9 | "weighted_subtraction", 10 | "tensor_sum", 11 | "add_difference", 12 | "train_difference", 13 | "sum_twice", 14 | "triple_sum", 15 | "euclidean_add_difference", 16 | "multiply_difference", 17 | "top_k_tensor_sum", 18 | "similarity_add_difference", 19 | "distribution_crossover", 20 | "ties_add_difference", 21 | ] 22 | 23 | 24 | EPSILON = 1e-10 # Define a small constant EPSILON to prevent division by zero 25 | 26 | 27 | def weighted_sum(a: Tensor, b: Tensor, alpha: float, **kwargs) -> Tensor: # pylint: disable=unused-argument 28 | """ 29 | Basic Merge: 30 | alpha 0 returns Primary Model 31 | alpha 1 returns Secondary Model 32 | """ 33 | return (1 - alpha) * a + alpha * b 34 | 35 | 36 | def weighted_subtraction(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor: # pylint: disable=unused-argument 37 | """ 38 | The inverse of a Weighted Sum Merge 39 | Returns Primary Model when alpha*beta = 0 40 | High values of alpha*beta are likely to break the merged model 41 | """ 42 | # Adjust beta if both alpha and beta are 1.0 to avoid division by zero 43 | if alpha == 1.0 and beta == 1.0: 44 | beta -= EPSILON 45 | 46 | return (a - alpha * beta * b) / (1 - alpha * beta) 47 | 48 | 49 | def tensor_sum(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor: # pylint: disable=unused-argument 50 | """ 51 | Takes a slice of Secondary Model and pastes it into Primary Model 52 | Alpha sets the width of the slice 53 | Beta sets the start point of the slice 54 | ie Alpha = 0.5 Beta = 0.25 is (ABBA) Alpha = 0.25 Beta = 0 is (BAAA) 55 | """ 56 | if alpha + beta <= 1: 57 | tt = a.clone() 58 | talphas = int(a.shape[0] * beta) 59 | talphae = 
int(a.shape[0] * (alpha + beta)) 60 | tt[talphas:talphae] = b[talphas:talphae].clone() 61 | else: 62 | talphas = int(a.shape[0] * (alpha + beta - 1)) 63 | talphae = int(a.shape[0] * beta) 64 | tt = b.clone() 65 | tt[talphas:talphae] = a[talphas:talphae].clone() 66 | return tt 67 | 68 | 69 | def add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs) -> Tensor: # pylint: disable=unused-argument 70 | """ 71 | Classic Add Difference Merge 72 | """ 73 | return a + alpha * (b - c) 74 | 75 | def train_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs): # pylint: disable=unused-argument 76 | # Based on: https://github.com/hako-mikan/sd-webui-supermerger/blob/843ca282948dbd3fac1246fcb1b66544a371778b/scripts/mergers/mergers.py#L673 77 | 78 | # Calculate the difference between b and c 79 | diff_BC = b - c 80 | 81 | # Early exit if there's no difference 82 | if torch.all(diff_BC == 0): 83 | return a 84 | 85 | # Calculate distances 86 | distance_BC = torch.abs(diff_BC) 87 | distance_BA = torch.abs(b - a) 88 | 89 | # Sum of distances 90 | sum_distances = distance_BC + distance_BA 91 | 92 | # Calculate scale, avoiding division by zero 93 | scale = torch.where(sum_distances != 0, distance_BA / sum_distances, torch.tensor(0.)) 94 | 95 | # Adjust scale sign based on the difference between b and c 96 | sign_scale = torch.sign(diff_BC) 97 | scale = sign_scale * torch.abs(scale) 98 | 99 | # Calculate new difference 100 | new_diff = scale * distance_BC 101 | 102 | # Return updated a 103 | return a + (new_diff * (alpha * 1.8)) 104 | 105 | 106 | def sum_twice(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor: # pylint: disable=unused-argument 107 | """ 108 | Stacked Basic Merge: 109 | Equivalent to Merging Primary and Secondary @ alpha 110 | Then merging the result with Tertiary @ beta 111 | """ 112 | return (1 - beta) * ((1 - alpha) * a + alpha * b) + beta * c 113 | 114 | 115 | def triple_sum(a: Tensor, b: Tensor, c: Tensor, 
alpha: float, beta: float, **kwargs) -> Tensor: # pylint: disable=unused-argument 116 | """ 117 | Weights Secondary and Tertiary at alpha and beta respectively 118 | Fills in the rest with Primary 119 | Expect odd results if alpha + beta > 1 as Primary will be merged with a negative ratio 120 | """ 121 | return (1 - alpha - beta) * a + alpha * b + beta * c 122 | 123 | 124 | def euclidean_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs) -> Tensor: # pylint: disable=unused-argument 125 | """ 126 | Subtract Primary and Secondary from Tertiary 127 | Compare the remainders via Euclidean distance 128 | Add to Tertiary 129 | Note: Slow 130 | """ 131 | a_diff = a.float() - c.float() 132 | b_diff = b.float() - c.float() 133 | a_diff = torch.nan_to_num(a_diff / torch.linalg.norm(a_diff)) 134 | b_diff = torch.nan_to_num(b_diff / torch.linalg.norm(b_diff)) 135 | 136 | distance = (1 - alpha) * a_diff**2 + alpha * b_diff**2 137 | distance = torch.sqrt(distance) 138 | sum_diff = weighted_sum(a.float(), b.float(), alpha) - c.float() 139 | distance = torch.copysign(distance, sum_diff) 140 | 141 | target_norm = torch.linalg.norm(sum_diff) 142 | return c + distance / torch.linalg.norm(distance) * target_norm 143 | 144 | 145 | def multiply_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor: # pylint: disable=unused-argument 146 | """ 147 | Similar to Add Difference but with geometric mean instead of arithmatic mean 148 | """ 149 | diff_a = torch.pow(torch.abs(a.float() - c), (1 - alpha)) 150 | diff_b = torch.pow(torch.abs(b.float() - c), alpha) 151 | difference = torch.copysign(diff_a * diff_b, weighted_sum(a, b, beta) - c) 152 | return c + difference.to(c.dtype) 153 | 154 | 155 | def top_k_tensor_sum(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor: # pylint: disable=unused-argument 156 | """ 157 | Redistributes the largest weights of Secondary Model into Primary Model 158 | """ 159 | a_flat = 
torch.flatten(a) 160 | a_dist = torch.msort(a_flat) 161 | b_indices = torch.argsort(torch.flatten(b), stable=True) 162 | redist_indices = torch.argsort(b_indices) 163 | 164 | start_i, end_i, region_is_inverted = ratio_to_region(alpha, beta, torch.numel(a)) 165 | start_top_k = kth_abs_value(a_dist, start_i) 166 | end_top_k = kth_abs_value(a_dist, end_i) 167 | 168 | indices_mask = (start_top_k < torch.abs(a_dist)) & (torch.abs(a_dist) <= end_top_k) 169 | if region_is_inverted: 170 | indices_mask = ~indices_mask 171 | indices_mask = torch.gather(indices_mask.float(), 0, redist_indices) 172 | 173 | a_redist = torch.gather(a_dist, 0, redist_indices) 174 | a_redist = (1 - indices_mask) * a_flat + indices_mask * a_redist 175 | return a_redist.reshape_as(a) 176 | 177 | 178 | def kth_abs_value(a: Tensor, k: int) -> Tensor: 179 | if k <= 0: 180 | return torch.tensor(-1, device=a.device) 181 | else: 182 | return torch.kthvalue(torch.abs(a.float()), k)[0] 183 | 184 | 185 | def ratio_to_region(width: float, offset: float, n: int) -> Tuple[int, int, bool]: 186 | if width < 0: 187 | offset += width 188 | width = -width 189 | width = min(width, 1) 190 | 191 | if offset < 0: 192 | offset = 1 + offset - int(offset) 193 | offset = math.fmod(offset, 1.0) 194 | 195 | if width + offset <= 1: 196 | inverted = False 197 | start = offset * n 198 | end = (width + offset) * n 199 | else: 200 | inverted = True 201 | start = (width + offset - 1) * n 202 | end = offset * n 203 | 204 | return round(start), round(end), inverted 205 | 206 | 207 | def similarity_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor: # pylint: disable=unused-argument 208 | """ 209 | Weighted Sum where A and B are similar and Add Difference where A and B are dissimilar 210 | """ 211 | threshold = torch.maximum(torch.abs(a), torch.abs(b)) 212 | similarity = ((a * b / threshold**2) + 1) / 2 213 | similarity = torch.nan_to_num(similarity * beta, nan=beta) 214 | 215 | ab_diff = a 
+ alpha * (b - c) 216 | ab_sum = (1 - alpha / 2) * a + (alpha / 2) * b 217 | return (1 - similarity) * ab_diff + similarity * ab_sum 218 | 219 | 220 | def distribution_crossover(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs): # pylint: disable=unused-argument 221 | """ 222 | From the creator: 223 | It's Primary high-passed + Secondary low-passed. Takes the fourrier transform of the weights of 224 | Primary and Secondary when ordered with respect to Tertiary. Split the frequency domain 225 | using a linear function. Alpha is the split frequency and Beta is the inclination of the line. 226 | add everything under the line as the contribution of Primary and everything over the line as the contribution of Secondary 227 | """ 228 | if a.shape == (): 229 | return alpha * a + (1 - alpha) * b 230 | 231 | c_indices = torch.argsort(torch.flatten(c)) 232 | a_dist = torch.gather(torch.flatten(a), 0, c_indices) 233 | b_dist = torch.gather(torch.flatten(b), 0, c_indices) 234 | 235 | a_dft = torch.fft.rfft(a_dist.float()) 236 | b_dft = torch.fft.rfft(b_dist.float()) 237 | 238 | dft_filter = torch.arange(0, torch.numel(a_dft), device=a_dft.device).float() 239 | dft_filter /= torch.numel(a_dft) 240 | if beta > EPSILON: 241 | dft_filter = (dft_filter - alpha) / beta + 1 / 2 242 | dft_filter = torch.clamp(dft_filter, 0.0, 1.0) 243 | else: 244 | dft_filter = (dft_filter >= alpha).float() 245 | 246 | x_dft = (1 - dft_filter) * a_dft + dft_filter * b_dft 247 | x_dist = torch.fft.irfft(x_dft, a_dist.shape[0]) 248 | x_values = torch.gather(x_dist, 0, torch.argsort(c_indices)) 249 | return x_values.reshape_as(a) 250 | 251 | 252 | def ties_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor: # pylint: disable=unused-argument 253 | """ 254 | An implementation of arXiv:2306.01708 255 | """ 256 | deltas = [] 257 | signs = [] 258 | for m in [a, b]: 259 | deltas.append(filter_top_k(m - c, beta)) 260 | 
signs.append(torch.sign(deltas[-1])) 261 | 262 | signs = torch.stack(signs, dim=0) 263 | final_sign = torch.sign(torch.sum(signs, dim=0)) 264 | delta_filters = (signs == final_sign).float() 265 | 266 | res = torch.zeros_like(c, device=c.device) 267 | for delta_filter, delta in zip(delta_filters, deltas): 268 | res += delta_filter * delta 269 | 270 | param_count = torch.sum(delta_filters, dim=0) 271 | return c + alpha * torch.nan_to_num(res / param_count) 272 | 273 | 274 | def filter_top_k(a: Tensor, k: float): 275 | k = max(int((1 - k) * torch.numel(a)), 1) 276 | k_value, _ = torch.kthvalue(torch.abs(a.flatten()).float(), k) 277 | top_k_filter = (torch.abs(a) >= k_value).float() 278 | return a * top_k_filter 279 | -------------------------------------------------------------------------------- /merge_presets.py: -------------------------------------------------------------------------------- 1 | BLOCK_WEIGHTS_PRESETS = { 2 | "GRAD_V": [0, 1, 0.9166666667, 0.8333333333, 0.75, 0.6666666667, 0.5833333333, 0.5, 0.4166666667, 0.3333333333, 0.25, 0.1666666667, 0.0833333333, 0, 0.0833333333, 0.1666666667, 0.25, 0.3333333333, 0.4166666667, 0.5, 0.5833333333, 0.6666666667, 0.75, 0.8333333333, 0.9166666667, 1.0], 3 | "GRAD_A": [0, 0, 0.0833333333, 0.1666666667, 0.25, 0.3333333333, 0.4166666667, 0.5, 0.5833333333, 0.6666666667, 0.75, 0.8333333333, 0.9166666667, 1.0, 0.9166666667, 0.8333333333, 0.75, 0.6666666667, 0.5833333333, 0.5, 0.4166666667, 0.3333333333, 0.25, 0.1666666667, 0.0833333333, 0], 4 | "FLAT_25": [0, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25], 5 | "FLAT_75": [0, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75], 6 | "WRAP08": [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1], 7 | "WRAP12": [0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], 8 | "WRAP14": [0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], 9 | "WRAP16": [0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1], 10 | "MID12_50": [0, 0, 0, 0, 0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0, 0], 11 | "OUT07": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], 12 | "OUT12": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 13 | "OUT12_5": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 14 | "RING08_SOFT": [0, 0, 0, 0, 0, 0, 0.5, 1, 1, 1, 0.5, 0, 0, 0, 0, 0, 0.5, 1, 1, 1, 0.5, 0, 0, 0, 0, 0], 15 | "RING08_5": [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], 16 | "RING10_5": [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], 17 | "RING10_3": [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], 18 | "SMOOTHSTEP": [0, 0, 0.00506365740740741, 0.0196759259259259, 0.04296875, 0.0740740740740741, 0.112123842592593, 0.15625, 0.205584490740741, 0.259259259259259, 0.31640625, 0.376157407407407, 0.437644675925926, 0.5, 0.562355324074074, 0.623842592592592, 0.68359375, 0.740740740740741, 0.794415509259259, 0.84375, 0.887876157407408, 0.925925925925926, 0.95703125, 0.980324074074074, 0.994936342592593, 1], 19 | "REVERSE_SMOOTHSTEP": [0, 1, 0.994936342592593, 0.980324074074074, 0.95703125, 0.925925925925926, 0.887876157407407, 0.84375, 0.794415509259259, 0.740740740740741, 0.68359375, 0.623842592592593, 0.562355324074074, 0.5, 0.437644675925926, 0.376157407407408, 0.31640625, 0.259259259259259, 0.205584490740741, 0.15625, 0.112123842592592, 0.0740740740740742, 0.0429687499999996, 0.0196759259259258, 0.00506365740740744, 0], 20 | "2SMOOTHSTEP": [0, 0, 0.0101273148148148, 0.0393518518518519, 0.0859375, 0.148148148148148, 0.224247685185185, 0.3125, 
0.411168981481482, 0.518518518518519, 0.6328125, 0.752314814814815, 0.875289351851852, 1.0, 0.875289351851852, 0.752314814814815, 0.6328125, 0.518518518518519, 0.411168981481481, 0.3125, 0.224247685185184, 0.148148148148148, 0.0859375, 0.0393518518518512, 0.0101273148148153, 0], 21 | "2R_SMOOTHSTEP": [0, 1, 0.989872685185185, 0.960648148148148, 0.9140625, 0.851851851851852, 0.775752314814815, 0.6875, 0.588831018518519, 0.481481481481481, 0.3671875, 0.247685185185185, 0.124710648148148, 0.0, 0.124710648148148, 0.247685185185185, 0.3671875, 0.481481481481481, 0.588831018518519, 0.6875, 0.775752314814816, 0.851851851851852, 0.9140625, 0.960648148148149, 0.989872685185185, 1], 22 | "3SMOOTHSTEP": [0, 0, 0.0151909722222222, 0.0590277777777778, 0.12890625, 0.222222222222222, 0.336371527777778, 0.46875, 0.616753472222222, 0.777777777777778, 0.94921875, 0.871527777777778, 0.687065972222222, 0.5, 0.312934027777778, 0.128472222222222, 0.0507812500000004, 0.222222222222222, 0.383246527777778, 0.53125, 0.663628472222223, 0.777777777777778, 0.87109375, 0.940972222222222, 0.984809027777777, 1], 23 | "3R_SMOOTHSTEP": [0, 1, 0.984809027777778, 0.940972222222222, 0.87109375, 0.777777777777778, 0.663628472222222, 0.53125, 0.383246527777778, 0.222222222222222, 0.05078125, 0.128472222222222, 0.312934027777778, 0.5, 0.687065972222222, 0.871527777777778, 0.94921875, 0.777777777777778, 0.616753472222222, 0.46875, 0.336371527777777, 0.222222222222222, 0.12890625, 0.0590277777777777, 0.0151909722222232, 0], 24 | "4SMOOTHSTEP": [0, 0, 0.0202546296296296, 0.0787037037037037, 0.171875, 0.296296296296296, 0.44849537037037, 0.625, 0.822337962962963, 0.962962962962963, 0.734375, 0.49537037037037, 0.249421296296296, 0.0, 0.249421296296296, 0.495370370370371, 0.734375000000001, 0.962962962962963, 0.822337962962962, 0.625, 0.448495370370369, 0.296296296296297, 0.171875, 0.0787037037037024, 0.0202546296296307, 0], 25 | "4R_SMOOTHSTEP": [0, 1, 0.97974537037037, 0.921296296296296, 0.828125, 
0.703703703703704, 0.55150462962963, 0.375, 0.177662037037037, 0.0370370370370372, 0.265625, 0.50462962962963, 0.750578703703704, 1.0, 0.750578703703704, 0.504629629629629, 0.265624999999999, 0.0370370370370372, 0.177662037037038, 0.375, 0.551504629629631, 0.703703703703703, 0.828125, 0.921296296296298, 0.979745370370369, 1], 26 | "HALF_SMOOTHSTEP": [0, 0, 0.0196759259259259, 0.0740740740740741, 0.15625, 0.259259259259259, 0.376157407407407, 0.5, 0.623842592592593, 0.740740740740741, 0.84375, 0.925925925925926, 0.980324074074074, 1.0, 0.980324074074074, 0.925925925925926, 0.84375, 0.740740740740741, 0.623842592592593, 0.5, 0.376157407407407, 0.259259259259259, 0.15625, 0.0740740740740741, 0.0196759259259259, 0], 27 | "HALF_R_SMOOTHSTEP": [0, 1, 0.980324074074074, 0.925925925925926, 0.84375, 0.740740740740741, 0.623842592592593, 0.5, 0.376157407407407, 0.259259259259259, 0.15625, 0.0740740740740742, 0.0196759259259256, 0.0, 0.0196759259259256, 0.0740740740740742, 0.15625, 0.259259259259259, 0.376157407407407, 0.5, 0.623842592592593, 0.740740740740741, 0.84375, 0.925925925925926, 0.980324074074074, 1], 28 | "ONE_THIRD_SMOOTHSTEP": [0, 0, 0.04296875, 0.15625, 0.31640625, 0.5, 0.68359375, 0.84375, 0.95703125, 1.0, 0.95703125, 0.84375, 0.68359375, 0.5, 0.31640625, 0.15625, 0.04296875, 0.0, 0.04296875, 0.15625, 0.31640625, 0.5, 0.68359375, 0.84375, 0.95703125, 1], 29 | "ONE_THIRD_R_SMOOTHSTEP": [0, 1, 0.95703125, 0.84375, 0.68359375, 0.5, 0.31640625, 0.15625, 0.04296875, 0.0, 0.04296875, 0.15625, 0.31640625, 0.5, 0.68359375, 0.84375, 0.95703125, 1.0, 0.95703125, 0.84375, 0.68359375, 0.5, 0.31640625, 0.15625, 0.04296875, 0], 30 | "ONE_FOURTH_SMOOTHSTEP": [0, 0, 0.0740740740740741, 0.259259259259259, 0.5, 0.740740740740741, 0.925925925925926, 1.0, 0.925925925925926, 0.740740740740741, 0.5, 0.259259259259259, 0.0740740740740741, 0.0, 0.0740740740740741, 0.259259259259259, 0.5, 0.740740740740741, 0.925925925925926, 1.0, 0.925925925925926, 0.740740740740741, 0.5, 
0.259259259259259, 0.0740740740740741, 0], 31 | "ONE_FOURTH_R_SMOOTHSTEP": [0, 1, 0.925925925925926, 0.740740740740741, 0.5, 0.259259259259259, 0.0740740740740742, 0.0, 0.0740740740740742, 0.259259259259259, 0.5, 0.740740740740741, 0.925925925925926, 1.0, 0.925925925925926, 0.740740740740741, 0.5, 0.259259259259259, 0.0740740740740742, 0.0, 0.0740740740740742, 0.259259259259259, 0.5, 0.740740740740741, 0.925925925925926, 1], 32 | "COSINE": [0, 1, 0.995722430686905, 0.982962913144534, 0.961939766255643, 0.933012701892219, 0.896676670145617, 0.853553390593274, 0.80438071450436, 0.75, 0.691341716182545, 0.62940952255126, 0.565263096110026, 0.5, 0.434736903889974, 0.37059047744874, 0.308658283817455, 0.25, 0.195619285495639, 0.146446609406726, 0.103323329854382, 0.0669872981077805, 0.0380602337443566, 0.0170370868554658, 0.00427756931309475, 0], 33 | "REVERSE_COSINE": [0, 0, 0.00427756931309475, 0.0170370868554659, 0.0380602337443566, 0.0669872981077808, 0.103323329854383, 0.146446609406726, 0.19561928549564, 0.25, 0.308658283817455, 0.37059047744874, 0.434736903889974, 0.5, 0.565263096110026, 0.62940952255126, 0.691341716182545, 0.75, 0.804380714504361, 0.853553390593274, 0.896676670145618, 0.933012701892219, 0.961939766255643, 0.982962913144534, 0.995722430686905, 1], 34 | "CUBIC_HERMITE": [0, 0, 0.157576195987654, 0.28491512345679, 0.384765625, 0.459876543209877, 0.512996720679012, 0.546875, 0.564260223765432, 0.567901234567901, 0.560546875, 0.544945987654321, 0.523847415123457, 0.5, 0.476152584876543, 0.455054012345679, 0.439453125, 0.432098765432099, 0.435739776234568, 0.453125, 0.487003279320987, 0.540123456790124, 0.615234375, 0.71508487654321, 0.842423804012347, 1], 35 | "REVERSE_CUBIC_HERMITE": [0, 1, 0.842423804012346, 0.71508487654321, 0.615234375, 0.540123456790123, 0.487003279320988, 0.453125, 0.435739776234568, 0.432098765432099, 0.439453125, 0.455054012345679, 0.476152584876543, 0.5, 0.523847415123457, 0.544945987654321, 0.560546875, 0.567901234567901, 
0.564260223765432, 0.546875, 0.512996720679013, 0.459876543209876, 0.384765625, 0.28491512345679, 0.157576195987653, 0], 36 | "FAKE_REVERSE_CUBIC_HERMITE": [0, 1, 0.842423804012346, 0.71508487654321, 0.615234375, 0.540123456790123, 0.487003279320988, 0.453125, 0.435739776234568, 0.432098765432099, 0.439453125, 0.455054012345679, 0.476152584876543, 0.5, 0.523847415123457, 0.544945987654321, 0.560546875, 0.567901234567901, 0.564260223765432, 0.546875, 0.512996720679013, 0.459876543209876, 0.384765625, 0.28491512345679, 0.157576195987653, 0], 37 | "LOW_OFFSET_CUBIC_HERMITE": [0, 0, 0.099515938464506, 0.1628809799382715, 0.2123209635416665, 0.249228395061729, 0.274995780285494, 0.291015625, 0.298680434992284, 0.2993827160493825, 0.294514973958333, 0.285469714506173, 0.273639443479938, 0.261513611593364, 0.24938777970679, 0.245727237654321, 0.23763671875, 0.222901234567901, 0.224305796682099, 0.234635416666667, 0.247675106095678, 0.273209876543211, 0.312024739583333, 0.360904706790124, 0.422634789737655, 0.5], 38 | "ALL_A": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 39 | "ALL_B": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 40 | } 41 | 42 | 43 | SDXL_BLOCK_WEIGHTS_PRESETS = { 44 | "SDXL_GRAD_V": [0, 1.0, 0.888889, 0.777778, 0.666667, 0.555556, 0.444444, 0.333333, 0.222222, 0.111111, 0.0, 0.111111, 0.222222, 0.333333, 0.444444, 0.555556, 0.666667, 0.777778, 0.888889, 1.0], 45 | "SDXL_GRAD_A": [0, 0.0, 0.111111, 0.222222, 0.333333, 0.444444, 0.555556, 0.666667, 0.777778, 0.888889, 1.0, 0.888889, 0.777778, 0.666667, 0.555556, 0.444444, 0.333333, 0.222222, 0.111111, 0.0], 46 | "SDXL_FLAT_25": [0, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25], 47 | "SDXL_FLAT_75": [0, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75], 48 | "SDXL_WRAP08": [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 1, 1, 1, 1], 49 | "SDXL_WRAP12": [0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], 50 | "SDXL_WRAP14": [0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], 51 | "SDXL_OUT07": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], 52 | "SDXL_SMOOTHSTEP": [0, 0, 0.008916, 0.034294, 0.074074, 0.126200, 0.188615, 0.259259, 0.336077, 0.417010, 0.500000, 0.582990, 0.663923, 0.740741, 0.811385, 0.873800, 0.925926, 0.965706, 0.991084, 1], 53 | "SDXL_REVERSE_SMOOTHSTEP": [0, 1, 0.991084, 0.965706, 0.925926, 0.873800, 0.811385, 0.740741, 0.663923, 0.582990, 0.500000, 0.417010, 0.336077, 0.259259, 0.188615, 0.126200, 0.074074, 0.034294, 0.008916, 0], 54 | "SDXL_HALF_SMOOTHSTEP": [0, 0, 0.034294, 0.126200, 0.259259, 0.417010, 0.582990, 0.740741, 0.873800, 0.965706, 1, 0.965706, 0.873800, 0.740741, 0.582990, 0.417010, 0.259259, 0.126200, 0.034294, 0], 55 | "SDXL_HALF_R_SMOOTHSTEP": [0, 1, 0.965706, 0.873800, 0.740741, 0.582990, 0.417010, 0.259259, 0.126200, 0.034294, 0, 0.034294, 0.126200, 0.259259, 0.417010, 0.582990, 0.740741, 0.873800, 0.965706, 1], 56 | "SDXL_ONE_THIRD_SMOOTHSTEP": [0, 0, 0.074074, 0.259259, 0.500000, 0.740741, 0.925926, 1, 0.907407, 0.592593, 0, 0.592593, 0.907407, 1, 0.925926, 0.740741, 0.500000, 0.259259, 0.074074, 0], 57 | "SDXL_ONE_THIRD_R_SMOOTHSTEP": [0, 1, 0.925926, 0.740741, 0.500000, 0.259259, 0.074074, 0, 0.092593, 0.407407, 1, 0.407407, 0.092593, 0, 0.074074, 0.259259, 0.500000, 0.740741, 0.925926, 1], 58 | "SDXL_COSINE": [0, 1, 0.992404, 0.969846, 0.933013, 0.883022, 0.821394, 0.750000, 0.671010, 0.586824, 0.500000, 0.413176, 0.328990, 0.250000, 0.178606, 0.116978, 0.066987, 0.030154, 0.007596, 0], 59 | "SDXL_REVERSE_COSINE": [0, 0, 0.007596, 0.030154, 0.066987, 0.116978, 0.178606, 0.250000, 0.328990, 0.413176, 0.500000, 0.586824, 0.671010, 0.750000, 0.821394, 0.883022, 0.933013, 0.969846, 0.992404, 1], 60 | "SDXL_CUBIC_HERMITE": [0, 0, 0.268023, 0.461058, 0.588477, 0.659656, 0.683966, 0.670782, 0.629477, 
0.569425, 0.500000, 0.430575, 0.370523, 0.329218, 0.316034, 0.340344, 0.411523, 0.538942, 0.731977, 1], 61 | "SDXL_REVERSE_CUBIC_HERMITE": [0, 1, 0.731977, 0.538942, 0.411523, 0.340344, 0.316034, 0.329218, 0.370523, 0.430575, 0.500000, 0.569425, 0.629477, 0.670782, 0.683966, 0.659656, 0.588477, 0.461058, 0.268023, 0], 62 | "SDXL_ALL_A": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 63 | "SDXL_ALL_B": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 64 | } 65 | -------------------------------------------------------------------------------- /merge_rebasin.py: -------------------------------------------------------------------------------- 1 | # https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion 2 | from collections import defaultdict 3 | from random import shuffle 4 | from typing import NamedTuple 5 | import torch 6 | from scipy.optimize import linear_sum_assignment 7 | 8 | 9 | SPECIAL_KEYS = [ 10 | "first_stage_model.decoder.norm_out.weight", 11 | "first_stage_model.decoder.norm_out.bias", 12 | "first_stage_model.encoder.norm_out.weight", 13 | "first_stage_model.encoder.norm_out.bias", 14 | "model.diffusion_model.out.0.weight", 15 | "model.diffusion_model.out.0.bias", 16 | ] 17 | 18 | 19 | class PermutationSpec(NamedTuple): 20 | perm_to_axes: dict 21 | axes_to_perm: dict 22 | 23 | 24 | def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec: 25 | perm_to_axes = defaultdict(list) 26 | for wk, axis_perms in axes_to_perm.items(): 27 | for axis, perm in enumerate(axis_perms): 28 | if perm is not None: 29 | perm_to_axes[perm].append((wk, axis)) 30 | return PermutationSpec(perm_to_axes=dict(perm_to_axes), axes_to_perm=axes_to_perm) 31 | 32 | 33 | def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None): 34 | """Get parameter `k` from `params`, with the permutations applied.""" 35 | try: 36 | w = params[k] 37 | except KeyError: 38 | # If the key is not found in params, 
return None or handle it as needed 39 | return None 40 | 41 | try: 42 | axes_to_perm = ps.axes_to_perm[k] 43 | except KeyError: 44 | # If the key is not found in axes_to_perm, return the original parameter 45 | return w 46 | 47 | for axis, p in enumerate(axes_to_perm): 48 | # Skip the axis we're trying to permute. 49 | if axis == except_axis: 50 | continue 51 | 52 | # None indicates that there is no permutation relevant to that axis. 53 | if p: 54 | try: 55 | w = torch.index_select(w, axis, perm[p].int()) 56 | except KeyError: 57 | # If the permutation key is not found, continue to the next axis 58 | continue 59 | 60 | return w 61 | 62 | 63 | def apply_permutation(ps: PermutationSpec, perm, params): 64 | """Apply a `perm` to `params`.""" 65 | return {k: get_permuted_param(ps, perm, k, params) for k in params.keys()} 66 | 67 | 68 | def update_model_a(ps: PermutationSpec, perm, model_a, new_alpha): 69 | for k in model_a: 70 | try: 71 | perm_params = get_permuted_param( 72 | ps, perm, k, model_a 73 | ) 74 | model_a[k] = model_a[k] * (1 - new_alpha) + new_alpha * perm_params 75 | except RuntimeError: # dealing with pix2pix and inpainting models 76 | continue 77 | return model_a 78 | 79 | 80 | def inner_matching( 81 | n, 82 | ps, 83 | p, 84 | params_a, 85 | params_b, 86 | usefp16, 87 | progress, 88 | number, 89 | linear_sum, 90 | perm, 91 | device, 92 | ): 93 | A = torch.zeros((n, n), dtype=torch.float16) if usefp16 else torch.zeros((n, n)) 94 | A = A.to(device) 95 | 96 | for wk, axis in ps.perm_to_axes[p]: 97 | w_a = params_a[wk] 98 | w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis) 99 | w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1)).to(device) 100 | w_b = torch.moveaxis(w_b, axis, 0).reshape((n, -1)).T.to(device) 101 | 102 | if usefp16: 103 | w_a = w_a.half().to(device) 104 | w_b = w_b.half().to(device) 105 | 106 | try: 107 | A += torch.matmul(w_a, w_b) 108 | except RuntimeError: 109 | A += torch.matmul(torch.dequantize(w_a), 
torch.dequantize(w_b)) 110 | 111 | A = A.cpu() 112 | ri, ci = linear_sum_assignment(A.detach().numpy(), maximize=True) 113 | A = A.to(device) 114 | 115 | assert (torch.tensor(ri) == torch.arange(len(ri))).all() 116 | 117 | eye_tensor = torch.eye(n).to(device) 118 | 119 | oldL = torch.vdot( 120 | torch.flatten(A).float(), torch.flatten(eye_tensor[perm[p].long()]) 121 | ) 122 | newL = torch.vdot(torch.flatten(A).float(), torch.flatten(eye_tensor[ci, :])) 123 | 124 | if usefp16: 125 | oldL = oldL.half() 126 | newL = newL.half() 127 | 128 | if newL - oldL != 0: 129 | linear_sum += abs((newL - oldL).item()) 130 | number += 1 131 | print("Merge Rebasin permutation: {p}={newL-oldL}") 132 | 133 | progress = progress or newL > oldL + 1e-12 134 | 135 | perm[p] = torch.Tensor(ci).to(device) 136 | 137 | return linear_sum, number, perm, progress 138 | 139 | 140 | def weight_matching( 141 | ps: PermutationSpec, 142 | params_a, 143 | params_b, 144 | max_iter=1, 145 | init_perm=None, 146 | usefp16=False, 147 | device="cpu", 148 | ): 149 | perm_sizes = { 150 | p: params_a[axes[0][0]].shape[axes[0][1]] 151 | for p, axes in ps.perm_to_axes.items() 152 | if axes[0][0] in params_a.keys() 153 | } 154 | perm = {} 155 | perm = ( 156 | {p: torch.arange(n).to(device) for p, n in perm_sizes.items()} 157 | if init_perm is None 158 | else init_perm 159 | ) 160 | 161 | linear_sum = 0 162 | number = 0 163 | 164 | special_layers = ["P_bg324"] 165 | for _i in range(max_iter): 166 | progress = False 167 | shuffle(special_layers) 168 | for p in special_layers: 169 | n = perm_sizes[p] 170 | linear_sum, number, perm, progress = inner_matching( 171 | n, 172 | ps, 173 | p, 174 | params_a, 175 | params_b, 176 | usefp16, 177 | progress, 178 | number, 179 | linear_sum, 180 | perm, 181 | device, 182 | ) 183 | progress = True 184 | if not progress: 185 | break 186 | 187 | average = linear_sum / number if number > 0 else 0 188 | return perm, average 189 | 
-------------------------------------------------------------------------------- /merge_utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import re 3 | from . import merge_methods 4 | from .merge_presets import BLOCK_WEIGHTS_PRESETS, SDXL_BLOCK_WEIGHTS_PRESETS 5 | 6 | ALL_PRESETS = {} 7 | ALL_PRESETS.update(BLOCK_WEIGHTS_PRESETS) 8 | ALL_PRESETS.update(SDXL_BLOCK_WEIGHTS_PRESETS) 9 | 10 | MERGE_METHODS = dict(inspect.getmembers(merge_methods, inspect.isfunction)) 11 | BETA_METHODS = [ 12 | name 13 | for name, fn in MERGE_METHODS.items() 14 | if "beta" in inspect.getfullargspec(fn)[0] 15 | ] 16 | TRIPLE_METHODS = [ 17 | name 18 | for name, fn in MERGE_METHODS.items() 19 | if "c" in inspect.getfullargspec(fn)[0] 20 | ] 21 | 22 | 23 | def interpolate(values, interp_lambda): 24 | interpolated = [] 25 | for i in range(len(values[0])): 26 | interpolated.append((1 - interp_lambda) * values[0][i] + interp_lambda * values[1][i]) 27 | return interpolated 28 | 29 | 30 | class WeightClass: 31 | def __init__(self, 32 | model_a, 33 | **kwargs, 34 | ): 35 | self.SDXL = "model.diffusion_model.middle_block.1.transformer_blocks.9.norm3.weight" in model_a.keys() 36 | self.NUM_INPUT_BLOCKS = 12 if not self.SDXL else 9 37 | self.NUM_MID_BLOCK = 1 38 | self.NUM_OUTPUT_BLOCKS = 12 if not self.SDXL else 9 39 | self.NUM_TOTAL_BLOCKS = self.NUM_INPUT_BLOCKS + self.NUM_MID_BLOCK + self.NUM_OUTPUT_BLOCKS 40 | self.iterations = kwargs.get("re_basin_iterations", 1) 41 | self.it = 0 42 | self.re_basin = kwargs.get("re_basin", False) 43 | self.ratioDict = {} 44 | for key, value in kwargs.items(): 45 | if isinstance(value, list) or (key.lower() not in ["alpha", "beta"]): 46 | self.ratioDict[key.lower()] = value 47 | else: 48 | self.ratioDict[key.lower()] = [value] 49 | 50 | for key, value in self.ratioDict.items(): 51 | if key in ["alpha", "beta"]: 52 | for i, v in enumerate(value): 53 | if isinstance(v, str) and v.upper() in 
BLOCK_WEIGHTS_PRESETS.keys(): 54 | value[i] = BLOCK_WEIGHTS_PRESETS[v.upper()] 55 | else: 56 | value[i] = [float(x) for x in v.split(",")] if isinstance(v, str) else v 57 | if not isinstance(value[i], list): 58 | value[i] = [value[i]] * (self.NUM_TOTAL_BLOCKS + 1) 59 | if len(value) > 1 and isinstance(value[0], list): 60 | self.ratioDict[key] = interpolate(value, self.ratioDict.get(key + "_lambda", 0)) 61 | else: 62 | self.ratioDict[key] = self.ratioDict[key][0] 63 | 64 | def __call__(self, key, it=0): 65 | current_bases = {} 66 | if "alpha" in self.ratioDict: 67 | current_bases["alpha"] = self.step_weights_and_bases(self.ratioDict["alpha"]) 68 | if "beta" in self.ratioDict: 69 | current_bases["beta"] = self.step_weights_and_bases(self.ratioDict["beta"]) 70 | 71 | weight_index = 0 72 | if "model" in key: 73 | 74 | if "model.diffusion_model." in key: 75 | weight_index = -1 76 | 77 | re_inp = re.compile(r"\.input_blocks\.(\d+)\.") # 12 78 | re_mid = re.compile(r"\.middle_block\.(\d+)\.") # 1 79 | re_out = re.compile(r"\.output_blocks\.(\d+)\.") # 12 80 | 81 | if "time_embed" in key: 82 | weight_index = 0 # before input blocks 83 | elif ".out." 
in key: 84 | weight_index = self.NUM_TOTAL_BLOCKS - 1 # after output blocks 85 | elif m := re_inp.search(key): 86 | weight_index = int(m.groups()[0]) 87 | elif re_mid.search(key): 88 | weight_index = self.NUM_INPUT_BLOCKS 89 | elif m := re_out.search(key): 90 | weight_index = self.NUM_INPUT_BLOCKS + self.NUM_MID_BLOCK + int(m.groups()[0]) 91 | 92 | if weight_index >= self.NUM_TOTAL_BLOCKS: 93 | raise ValueError(f"illegal block index {key}") 94 | 95 | current_bases = {k: w[weight_index] for k, w in current_bases.items()} 96 | return current_bases 97 | 98 | def step_weights_and_bases(self, ratio): 99 | if not self.re_basin: 100 | return ratio 101 | 102 | new_ratio = [ 103 | 1 - (1 - (1 + self.it) * v / self.iterations) / (1 - self.it * v / self.iterations) 104 | if self.it > 0 105 | else v / self.iterations 106 | for v in ratio 107 | ] 108 | return new_ratio 109 | 110 | def set_it(self, it): 111 | self.it = it 112 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-technodes" 3 | description = "ComfyUI nodes for merging, testing and more. SDNext Merge, VAE Merge, MBW Layers, Repeat VAE, Quantization." 
4 | version = "1.0.0" 5 | license = {file = "LICENSE"} 6 | 7 | [project.urls] 8 | Repository = "https://github.com/TechnoByteJS/ComfyUI-TechNodes" 9 | # Used by Comfy Registry https://comfyregistry.org 10 | 11 | [tool.comfy] 12 | PublisherId = "technobyte" 13 | DisplayName = "ComfyUI-TechNodes" 14 | Icon = "" 15 | -------------------------------------------------------------------------------- /quant_nodes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | 4 | import folder_paths 5 | 6 | import comfy_extras.nodes_model_merging 7 | 8 | def quantize_tensor(tensor, num_bits=8, dtype=torch.float16, dequant=True): 9 | """ 10 | Quantizes a tensor to a specified number of bits. 11 | 12 | Args: 13 | tensor (torch.Tensor): The input tensor to be quantized. 14 | num_bits (int): The number of bits to use for quantization (default: 8). 15 | dtype(torch.dtype): The datatype to use for the output (default: torch.float16). 16 | dequant (bool): Whether to dequantize or not (default: true). 17 | 18 | Returns: 19 | torch.Tensor: The quantized tensor. 
20 | """ 21 | # Determine the minimum and maximum values of the tensor 22 | min_val = tensor.min() 23 | max_val = tensor.max() 24 | 25 | # Calculate the scale factor and zero point 26 | qmin = 0 27 | qmax = 2 ** num_bits - 1 28 | scale = (max_val - min_val) / (qmax - qmin) 29 | zero_point = qmin - torch.round(min_val / scale) 30 | 31 | # Quantize the tensor 32 | quantized_tensor = torch.round(tensor / scale + zero_point) 33 | quantized_tensor = torch.clamp(quantized_tensor, qmin, qmax) 34 | 35 | # Convert the quantized tensor to the datatype 36 | dequantized_tensor = quantized_tensor.to(dtype) 37 | 38 | if dequant: 39 | # De-quantize the tensor 40 | dequantized_tensor = (dequantized_tensor - zero_point) * scale 41 | 42 | return dequantized_tensor 43 | 44 | def quantize_model(model, in_bits, mid_bits, out_bits, dtype=torch.float16, dequant=True): 45 | # Clone the base model to create a new one 46 | quantized_model = model.clone() 47 | 48 | # Get the key patches from the model with the prefix "diffusion_model." 49 | key_patches = quantized_model.get_key_patches("diffusion_model.") 50 | 51 | # Iterate over each key patch in the patches 52 | for key in key_patches: 53 | if ".input_" in key: 54 | num_bits = in_bits 55 | elif ".middle_" in key: 56 | num_bits = mid_bits 57 | elif ".output_" in key: 58 | num_bits = out_bits 59 | else: 60 | num_bits = 8 61 | 62 | quantized_tensor = quantize_tensor(key_patches[key][0], num_bits, dtype, dequant) 63 | quantized_model.add_patches({key: (quantized_tensor,)}, 1, 0) 64 | 65 | # Return the quantized model 66 | return quantized_model 67 | 68 | def quantize_clip(clip, bits, dtype=torch.float16, dequant=True): 69 | # Clone the base model to create a new one 70 | quantized_clip = clip.clone() 71 | 72 | # Get the key patches from the model with the prefix "diffusion_model." 
73 | key_patches = quantized_clip.get_key_patches() 74 | 75 | # Iterate over each key patch in the patches 76 | for key in key_patches: 77 | quantized_tensor = quantize_tensor(key_patches[key][0], bits, dtype, dequant) 78 | quantized_clip.add_patches({key: (quantized_tensor,)}, 1, 0) 79 | 80 | # Return the quantized model 81 | return quantized_clip 82 | 83 | def quantize_vae(vae, bits, dtype=torch.float16, dequant=True): 84 | # Create a clone of the VAE model 85 | quantized_vae = copy.deepcopy(vae) 86 | 87 | # Get the state dictionary from the clone 88 | state_dict = quantized_vae.first_stage_model.state_dict() 89 | 90 | # Iterate over each key-value pair in the state dictionary 91 | for key, value in state_dict.items(): 92 | state_dict[key] = quantize_tensor(value, bits, dtype, dequant) 93 | 94 | # Load the quantized state dictionary back into the clone 95 | quantized_vae.first_stage_model.load_state_dict(state_dict) 96 | 97 | # Return the quantized clone 98 | return quantized_vae 99 | 100 | class ModelQuant: 101 | @classmethod 102 | def INPUT_TYPES(cls): 103 | return { 104 | "required": { 105 | "model": ["MODEL"], 106 | "in_bits": ("INT", {"default": 8, "min": 1, "max": 8}), 107 | "mid_bits": ("INT", {"default": 8, "min": 1, "max": 8}), 108 | "out_bits": ("INT", {"default": 8, "min": 1, "max": 8}), 109 | } 110 | } 111 | 112 | RETURN_TYPES = ["MODEL"] 113 | FUNCTION = "quant_model" 114 | 115 | CATEGORY = "TechNodes/quantization" 116 | 117 | def quant_model(self, model, in_bits, mid_bits, out_bits): 118 | return [quantize_model(model, in_bits, mid_bits, out_bits)] 119 | 120 | 121 | class ClipQuant: 122 | @classmethod 123 | def INPUT_TYPES(cls): 124 | return { 125 | "required": { 126 | "clip": ["CLIP"], 127 | "bits": ("INT", {"default": 8, "min": 1, "max": 8}), 128 | } 129 | } 130 | 131 | RETURN_TYPES = ["CLIP"] 132 | FUNCTION = "quant_clip" 133 | 134 | CATEGORY = "TechNodes/quantization" 135 | 136 | def quant_clip(self, clip, bits): 137 | return 
[quantize_clip(clip, bits)] 138 | 139 | 140 | class VAEQuant: 141 | @classmethod 142 | def INPUT_TYPES(cls): 143 | return { 144 | "required": { 145 | "vae": ["VAE"], 146 | "bits": ("INT", {"default": 8, "min": 1, "max": 8}), 147 | } 148 | } 149 | 150 | RETURN_TYPES = ["VAE"] 151 | FUNCTION = "quant_vae" 152 | 153 | CATEGORY = "TechNodes/quantization" 154 | 155 | def quant_vae(self, vae, bits): 156 | return [quantize_vae(vae, bits)] 157 | -------------------------------------------------------------------------------- /sdnextmerge_nodes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import folder_paths 3 | from typing import Dict, Tuple, List 4 | from collections import OrderedDict 5 | import ast 6 | 7 | import comfy.sd 8 | import comfy.utils 9 | import comfy.model_detection 10 | 11 | from .merge import * 12 | 13 | mbw_presets = ([ 14 | "none", 15 | "GRAD_V", 16 | "GRAD_A", 17 | "FLAT_25", 18 | "FLAT_75", 19 | "WRAP08", 20 | "WRAP12", 21 | "WRAP14", 22 | "WRAP16", 23 | "MID12_50", 24 | "OUT07", 25 | "OUT12", 26 | "OUT12_5", 27 | "RING08_SOFT", 28 | "RING08_5", 29 | "RING10_5", 30 | "RING10_3", 31 | "SMOOTHSTEP", 32 | "REVERSE_SMOOTHSTEP", 33 | "2SMOOTHSTEP", 34 | "2R_SMOOTHSTEP", 35 | "3SMOOTHSTEP", 36 | "3R_SMOOTHSTEP", 37 | "4SMOOTHSTEP", 38 | "4R_SMOOTHSTEP", 39 | "HALF_SMOOTHSTEP", 40 | "HALF_R_SMOOTHSTEP", 41 | "ONE_THIRD_SMOOTHSTEP", 42 | "ONE_THIRD_R_SMOOTHSTEP", 43 | "ONE_FOURTH_SMOOTHSTEP", 44 | "ONE_FOURTH_R_SMOOTHSTEP", 45 | "COSINE", 46 | "REVERSE_COSINE", 47 | "CUBIC_HERMITE", 48 | "REVERSE_CUBIC_HERMITE", 49 | "FAKE_REVERSE_CUBIC_HERMITE", 50 | "LOW_OFFSET_CUBIC_HERMITE", 51 | "ALL_A", 52 | "ALL_B", 53 | ], {"default": "none"}) 54 | 55 | class SDNextMerge: 56 | @classmethod 57 | def INPUT_TYPES(cls): 58 | return { 59 | "optional": { 60 | "optional_model_a": ["MODEL"], 61 | "optional_clip_a": ["CLIP"], 62 | 63 | "optional_model_b": ["MODEL"], 64 | "optional_clip_b": ["CLIP"], 65 | 66 | 
"optional_model_c": ["MODEL"], 67 | "optional_clip_c": ["CLIP"], 68 | 69 | "optional_mbw_layers_alpha": ["MBW_LAYERS"], 70 | }, 71 | "required": { 72 | "model_a": (["none"] + folder_paths.get_filename_list("checkpoints"), {"multiline": False}), 73 | "model_b": (["none"] + folder_paths.get_filename_list("checkpoints"), {"multiline": False}), 74 | "model_c": (["none"] + folder_paths.get_filename_list("checkpoints"), {"multiline": False}), 75 | "merge_mode": ([ 76 | "weighted_sum", 77 | "weighted_subtraction", 78 | "tensor_sum", 79 | "add_difference", 80 | "train_difference", 81 | "sum_twice", 82 | "triple_sum", 83 | "euclidean_add_difference", 84 | "multiply_difference", 85 | "top_k_tensor_sum", 86 | "similarity_add_difference", 87 | "distribution_crossover", 88 | "ties_add_difference", 89 | ],), 90 | "precision": (["fp16", "original"],), 91 | "weights_clip": ("BOOLEAN", {"default": True}), 92 | "mem_device": (["cuda", "cpu"],), 93 | "work_device": (["cuda", "cpu"],), 94 | "threads": ("INT", {"default": 4, "min": 1, "max": 24}), 95 | "mbw_preset_alpha": mbw_presets, 96 | "alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), 97 | "beta": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}), 98 | "re_basin": ("BOOLEAN", {"default": False}), 99 | "re_basin_iterations": ("INT", {"default": 5, "min": 1, "max": 25}) 100 | } 101 | } 102 | 103 | RETURN_TYPES = ["MODEL", "CLIP"] 104 | FUNCTION = "merge" 105 | 106 | CATEGORY = "TechNodes/merging" 107 | 108 | # The main merge function 109 | def merge(self, model_a, model_b, model_c, merge_mode, precision, weights_clip, mem_device, work_device, threads, mbw_preset_alpha, alpha, beta, re_basin, re_basin_iterations, optional_model_a = None, optional_clip_a = None, optional_model_b = None, optional_clip_b = None, optional_model_c = None, optional_clip_c = None, optional_mbw_layers_alpha = None): 110 | 111 | if model_a == "none" and optional_model_a is None: 112 | raise ValueError("Need either 
model_a or optional_model_a!") 113 | 114 | if model_b == "none" and optional_model_b is None: 115 | raise ValueError("Need either model_b or optional_model_b!") 116 | 117 | if model_a == "none" and optional_clip_a is None: 118 | raise ValueError("Need either model_a or optional_clip_a!") 119 | 120 | if model_b == "none" and optional_clip_b is None: 121 | raise ValueError("Need either model_b or optional_clip_b!") 122 | 123 | models = { } 124 | 125 | if model_a != "none": 126 | if optional_model_a is None or optional_clip_a is None: 127 | models['model_a'] = folder_paths.get_full_path("checkpoints", model_a) 128 | 129 | if model_b != "none": 130 | if optional_model_b is None or optional_clip_b is None: 131 | models['model_b'] = folder_paths.get_full_path("checkpoints", model_b) 132 | 133 | # Add model C if the merge method needs it 134 | if merge_mode in ["add_difference", "train_difference", "sum_twice", "triple_sum", "euclidean_add_difference", "multiply_difference", "similarity_add_difference", "distribution_crossover", "ties_add_difference"]: 135 | if model_c == "none" and optional_model_c is None: 136 | raise ValueError("Need either model_c or optional_model_c!") 137 | 138 | if model_c == "none" and optional_clip_c is None: 139 | raise ValueError("Need either model_c or optional_clip_c!") 140 | 141 | if model_c != "none": 142 | if optional_model_c is None or optional_clip_c is None: 143 | models['model_c'] = folder_paths.get_full_path("checkpoints", model_c) 144 | 145 | # Devices 146 | device = torch.device(mem_device) 147 | work_device = torch.device(work_device) 148 | 149 | # Merge Arguments 150 | kwargs = { 151 | 'alpha': alpha, 152 | 'beta': beta, 153 | 're_basin': re_basin, 154 | 're_basin_iterations': re_basin_iterations 155 | } 156 | 157 | # If a MBW alpha preset is selected replace the alpha with the preset 158 | if mbw_preset_alpha != "none": 159 | kwargs["alpha"] = [ mbw_preset_alpha ] 160 | 161 | # If a MBW alpha preset is selected replace the alpha 
with the preset 162 | if optional_mbw_layers_alpha is not None: 163 | kwargs["alpha"] = [ optional_mbw_layers_alpha ] 164 | 165 | # Merge the model 166 | merged_model = merge_models(models, merge_mode, precision, weights_clip, device, work_device, True, threads, optional_model_a, optional_clip_a, optional_model_b, optional_clip_b, optional_model_c, optional_clip_c, **kwargs) 167 | 168 | # Get the config and components from the merged model 169 | model_config = comfy.model_detection.model_config_from_unet(merged_model, "model.diffusion_model.") 170 | 171 | # Create UNet 172 | unet = model_config.get_model(merged_model, "model.diffusion_model.", device=device) 173 | unet.load_model_weights(merged_model, "model.diffusion_model.") 174 | 175 | # Create ModelPatcher 176 | model_patcher = comfy.model_patcher.ModelPatcher( 177 | unet, 178 | load_device=comfy.model_management.get_torch_device(), 179 | offload_device=comfy.model_management.unet_offload_device() 180 | ) 181 | 182 | # Create CLIP 183 | clip_sd = model_config.process_clip_state_dict(merged_model) 184 | clip = comfy.sd.CLIP(model_config.clip_target(), embedding_directory=None) 185 | clip.load_sd(clip_sd, full_model=True) 186 | 187 | return (model_patcher, clip) 188 | 189 | class SD1_MBWLayers: 190 | @classmethod 191 | def INPUT_TYPES(cls) -> Dict[str, tuple]: 192 | arg_dict = { } 193 | 194 | argument = ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}) 195 | 196 | for i in range(12): 197 | arg_dict[f"input_blocks.{i}"] = argument 198 | 199 | arg_dict[f"middle_blocks"] = argument 200 | 201 | for i in range(12): 202 | arg_dict[f"output_blocks.{i}"] = argument 203 | 204 | return {"required": arg_dict} 205 | 206 | RETURN_TYPES = ["MBW_LAYERS"] 207 | FUNCTION = "return_layers" 208 | CATEGORY = "TechNodes/merging" 209 | 210 | def return_layers(self, **inputs) -> Dict[str, float]: 211 | return [ list(inputs.values()) ] 212 | 213 | class SD1_MBWLayers_Binary: 214 | @classmethod 215 | def INPUT_TYPES(cls) 
-> Dict[str, tuple]: 216 | arg_dict = { } 217 | 218 | argument = ("BOOLEAN", {"default": False}) 219 | 220 | for i in range(12): 221 | arg_dict[f"input_blocks.{i}"] = argument 222 | 223 | arg_dict[f"middle_blocks"] = argument 224 | 225 | for i in range(12): 226 | arg_dict[f"output_blocks.{i}"] = argument 227 | 228 | return {"required": arg_dict} 229 | 230 | RETURN_TYPES = ["MBW_LAYERS"] 231 | FUNCTION = "return_layers" 232 | CATEGORY = "TechNodes/merging" 233 | 234 | def return_layers(self, **inputs) -> Dict[str, List[int]]: 235 | return [list(int(value) for value in inputs.values())] 236 | 237 | class SDXL_MBWLayers: 238 | @classmethod 239 | def INPUT_TYPES(cls) -> Dict[str, tuple]: 240 | arg_dict = { } 241 | 242 | argument = ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}) 243 | 244 | for i in range(9): 245 | arg_dict[f"input_blocks.{i}"] = argument 246 | 247 | arg_dict[f"middle_blocks"] = argument 248 | 249 | for i in range(9): 250 | arg_dict[f"output_blocks.{i}"] = argument 251 | 252 | return {"required": arg_dict} 253 | 254 | RETURN_TYPES = ["MBW_LAYERS"] 255 | FUNCTION = "return_layers" 256 | CATEGORY = "TechNodes/merging" 257 | 258 | def return_layers(self, **inputs) -> Dict[str, float]: 259 | return [ list(inputs.values()) ] 260 | 261 | class SDXL_MBWLayers_Binary: 262 | @classmethod 263 | def INPUT_TYPES(cls) -> Dict[str, tuple]: 264 | arg_dict = { } 265 | 266 | argument = ("BOOLEAN", {"default": False}) 267 | 268 | for i in range(9): 269 | arg_dict[f"input_blocks.{i}"] = argument 270 | 271 | arg_dict[f"middle_blocks"] = argument 272 | 273 | for i in range(9): 274 | arg_dict[f"output_blocks.{i}"] = argument 275 | 276 | return {"required": arg_dict} 277 | 278 | RETURN_TYPES = ["MBW_LAYERS"] 279 | FUNCTION = "return_layers" 280 | CATEGORY = "TechNodes/merging" 281 | 282 | def return_layers(self, **inputs) -> Dict[str, List[int]]: 283 | return [list(int(value) for value in inputs.values())] 284 | 285 | class MBWLayers_String: 286 | 
@classmethod 287 | def INPUT_TYPES(cls): 288 | return { 289 | "required": { 290 | "mbw_layers": ("STRING", {"multiline": True, "default": "[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]"} ) 291 | } 292 | } 293 | 294 | RETURN_TYPES = ["MBW_LAYERS"] 295 | FUNCTION = "return_layers" 296 | CATEGORY = "TechNodes/merging" 297 | 298 | def return_layers(self, mbw_layers): 299 | return [ ast.literal_eval(mbw_layers) ] 300 | 301 | class VAERepeat: 302 | @classmethod 303 | def INPUT_TYPES(s): 304 | return { 305 | "required": { 306 | "images": ["IMAGE"], 307 | "vae": ["VAE"], 308 | "count": ["INT", {"default": 4, "min": 1, "max": 1000000}], 309 | } 310 | } 311 | RETURN_TYPES = ["IMAGE"] 312 | FUNCTION = "recode" 313 | 314 | CATEGORY = "TechNodes/latent" 315 | 316 | def recode(self, vae, images, count): 317 | for x in range(count): 318 | latent = { "samples": vae.encode(images[:,:,:,:3]) } 319 | images = vae.decode(latent["samples"]) 320 | return [images] -------------------------------------------------------------------------------- /vae_merge.py: -------------------------------------------------------------------------------- 1 | import os 2 | import folder_paths 3 | 4 | from tqdm import tqdm 5 | import torch 6 | import safetensors.torch 7 | 8 | import comfy.sd 9 | import comfy.utils 10 | 11 | from . 
import merge_methods 12 | 13 | from torch import nn 14 | 15 | def merge_state_dict(sd_a, sd_b, sd_c, alpha, beta, weights, mode): 16 | def get_alpha(key): 17 | try: 18 | filtered = sorted( 19 | [x for x in weights.keys() if key.startswith(x)], key=len, reverse=True 20 | ) 21 | if len(filtered) < 1: 22 | return alpha 23 | return weights[filtered[0]] 24 | except: 25 | return alpha 26 | 27 | ckpt_keys = ( 28 | sd_a.keys() & sd_b.keys() 29 | if sd_c is None 30 | else sd_a.keys() & sd_b.keys() & sd_c.keys() 31 | ) 32 | 33 | for key in tqdm(ckpt_keys): 34 | current_alpha = get_alpha(key) if weights is not None else alpha 35 | 36 | if mode == "weighted_sum": 37 | sd_a[key] = merge_methods.weighted_sum(a = sd_a[key], b = sd_b[key], alpha = current_alpha) 38 | elif mode == "weighted_subtraction": 39 | sd_a[key] = merge_methods.weighted_subtraction(a = sd_a[key], b = sd_b[key], alpha = current_alpha, beta=beta) 40 | elif mode == "tensor_sum": 41 | sd_a[key] = merge_methods.tensor_sum(a = sd_a[key], b = sd_b[key], alpha = current_alpha, beta=beta) 42 | elif mode == "add_difference": 43 | assert sd_c is not None, "vae_c is undefined" 44 | sd_a[key] = merge_methods.add_difference(a = sd_a[key], b = sd_b[key], c = sd_c[key], alpha = current_alpha) 45 | elif mode == "sum_twice": 46 | assert sd_c is not None, "vae_c is undefined" 47 | sd_a[key] = merge_methods.sum_twice(a = sd_a[key], b = sd_b[key], c = sd_c[key], alpha = current_alpha, beta = beta) 48 | elif mode == "triple_sum": 49 | assert sd_c is not None, "vae_c is undefined" 50 | sd_a[key] = merge_methods.triple_sum(a = sd_a[key], b = sd_b[key], c = sd_c[key], alpha = current_alpha, beta = beta) 51 | elif mode == "euclidean_add_difference": 52 | assert sd_c is not None, "vae_c is undefined" 53 | sd_a[key] = merge_methods.euclidean_add_difference(a = sd_a[key], b = sd_b[key], c = sd_c[key], alpha = current_alpha) 54 | elif mode == "multiply_difference": 55 | assert sd_c is not None, "vae_c is undefined" 56 | sd_a[key] = 
merge_methods.multiply_difference(a = sd_a[key], b = sd_b[key], c = sd_c[key], alpha = current_alpha, beta = beta) 57 | elif mode == "top_k_tensor_sum": 58 | sd_a[key] = merge_methods.top_k_tensor_sum(a = sd_a[key], b = sd_b[key], alpha = current_alpha, beta=beta) 59 | elif mode == "similarity_add_difference": 60 | assert sd_c is not None, "vae_c is undefined" 61 | sd_a[key] = merge_methods.similarity_add_difference(a = sd_a[key], b = sd_b[key], c = sd_c[key], alpha = current_alpha, beta = beta) 62 | elif mode == "distribution_crossover": 63 | assert sd_c is not None, "vae_c is undefined" 64 | sd_a[key] = merge_methods.distribution_crossover(a = sd_a[key], b = sd_b[key], c = sd_c[key], alpha = current_alpha, beta = beta) 65 | elif mode == "ties_add_difference": 66 | assert sd_c is not None, "vae_c is undefined" 67 | sd_a[key] = merge_methods.ties_add_difference(a = sd_a[key], b = sd_b[key], c = sd_c[key], alpha = current_alpha, beta = beta) 68 | 69 | return sd_a 70 | 71 | class VAEMerge: 72 | @classmethod 73 | def INPUT_TYPES(cls): 74 | return { 75 | "required": { 76 | "vae_a": ("VAE",), 77 | "vae_b": ("VAE",), 78 | "merge_mode": ([ 79 | "weighted_sum", 80 | "weighted_subtraction", 81 | "tensor_sum", 82 | "add_difference", 83 | "sum_twice", 84 | "triple_sum", 85 | "euclidean_add_difference", 86 | "multiply_difference", 87 | "top_k_tensor_sum", 88 | "similarity_add_difference", 89 | "distribution_crossover", 90 | "ties_add_difference", 91 | ],), 92 | "alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), 93 | "beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), 94 | "brightness": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.01}), 95 | "contrast": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.01}), 96 | "use_blocks": ("BOOLEAN", {"default": False}), 97 | "block_conv_out": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), 98 | "block_norm_out": ("FLOAT", {"default": 0.5, 
"min": 0.0, "max": 1.0, "step": 0.01}), 99 | "block_0": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), 100 | "block_1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), 101 | "block_2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), 102 | "block_3": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), 103 | "block_mid": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), 104 | "block_conv_in": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), 105 | "block_quant_conv": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), 106 | }, 107 | "optional": { 108 | "vae_c": ("VAE",), 109 | } 110 | } 111 | 112 | RETURN_TYPES = ["VAE"] 113 | FUNCTION = "merge_vae" 114 | 115 | CATEGORY = "TechNodes/merging" 116 | 117 | def merge_vae(self, vae_a, vae_b, merge_mode, alpha, beta, brightness, contrast, use_blocks, block_conv_out, block_norm_out, block_0, block_1, block_2, block_3, block_mid, block_conv_in, block_quant_conv, vae_c=None): 118 | vae_a_model = vae_a.first_stage_model.state_dict() 119 | vae_b_model = vae_b.first_stage_model.state_dict() 120 | vae_c_model = None 121 | if merge_mode in ["add_difference", "sum_twice", "triple_sum", "euclidean_add_difference", "multiply_difference", "similarity_add_difference", "distribution_crossover", "ties_add_difference"]: 122 | vae_c_model = vae_c.first_stage_model.state_dict() 123 | 124 | weights = { 125 | 'encoder.conv_out': block_conv_out, 126 | 'encoder.norm_out': block_norm_out, 127 | 'encoder.down.0': block_0, 128 | 'encoder.down.1': block_1, 129 | 'encoder.down.2': block_2, 130 | 'encoder.down.3': block_3, 131 | 'encoder.mid': block_mid, 132 | 'encoder.conv_in': block_conv_in, 133 | 'quant_conv': block_quant_conv, 134 | 'decoder.conv_out': block_conv_out, 135 | 'decoder.norm_out': block_norm_out, 136 | 'decoder.up.0': block_0, 137 | 'decoder.up.1': block_1, 138 | 'decoder.up.2': block_2, 139 | 'decoder.up.3': 
block_3, 140 | 'decoder.mid': block_mid, 141 | 'decoder.conv_in': block_conv_in, 142 | 'post_quant_conv': block_quant_conv 143 | } 144 | 145 | if(not use_blocks): 146 | weights = {} 147 | 148 | merged_vae = merge_state_dict(vae_a_model, vae_b_model, vae_c_model, alpha, beta, weights, mode=merge_mode) 149 | 150 | merged_vae["decoder.conv_out.bias"] = nn.Parameter(merged_vae["decoder.conv_out.bias"] + brightness) 151 | 152 | merged_vae["decoder.conv_out.weight"] = nn.Parameter(merged_vae["decoder.conv_out.weight"] + contrast / 40) 153 | 154 | comfy_vae = comfy.sd.VAE(merged_vae) 155 | 156 | return (comfy_vae,) --------------------------------------------------------------------------------