├── .github
│   └── workflows
│       └── publish.yml
├── LICENSE
├── README.md
├── __init__.py
├── images
│   ├── MBW_Layers_showcase.webp
│   ├── Quant_Nodes_showcase.webp
│   ├── SDNext_Merge_showcase.webp
│   ├── VAE_Merge_showcase.webp
│   └── VAE_Repeat_showcase.webp
├── merge.py
├── merge_PermSpec.py
├── merge_PermSpec_SDXL.py
├── merge_methods.py
├── merge_presets.py
├── merge_rebasin.py
├── merge_utils.py
├── pyproject.toml
├── quant_nodes.py
├── sdnextmerge_nodes.py
└── vae_merge.py
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 | workflow_dispatch:
4 | push:
5 | branches:
6 | - main
7 | - master
8 | paths:
9 | - "pyproject.toml"
10 |
11 | jobs:
12 | publish-node:
13 | name: Publish Custom Node to registry
14 | runs-on: ubuntu-latest
15 |     # Skip this workflow if the repository is a fork.
16 | if: github.event.repository.fork == false
17 | steps:
18 | - name: Check out code
19 | uses: actions/checkout@v4
20 | - name: Publish Custom Node
21 | uses: Comfy-Org/publish-node-action@main
22 | with:
23 | ## Add your own personal access token to your Github Repository secrets and reference it here.
24 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
25 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |     <one line to give the program's name and a brief idea of what it does.>
635 |     Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 |     <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TechNodes
2 | ComfyUI nodes for merging, testing and more.
3 |
4 |
5 | ## Installation
6 | Inside the `ComfyUI/custom_nodes` directory, run:
7 |
8 | ```
9 | git clone https://github.com/TechnoByteJS/ComfyUI-TechNodes --depth 1
10 | ```
11 |
12 | ## SDNext Merge
13 | The merger from [SD.Next](https://github.com/vladmandic/automatic) (based on [meh](https://github.com/s1dlx/meh)) ported to ComfyUI, with [Re-Basin](https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion) built-in.
14 |
15 | 
16 |
17 | ## VAE Merge
18 | A node that lets you merge VAEs using multiple methods (and support for individual blocks), and adjust the brightness or contrast.
19 |
20 | 
21 |
22 | ## MBW Layers
23 | Allows for advanced merging by adjusting the alpha of each U-Net block individually, with binary versions that make it easy to extract specific layers.
24 |
25 | 
26 |
27 | ## Repeat VAE
28 | A node that encodes and decodes an image with a VAE a specified number of times, useful for testing and comparing the performance of different VAEs.
29 |
30 | 
31 |
32 | ## Quantization
33 | Quantize the U-Net, CLIP, or VAE to the specified number of bits.
34 | > Note: This is purely experimental, there are no speed or storage benefits from this.
35 |
36 | 
37 |
38 | ### Credits
39 | To create these nodes, I used code from:
40 | - [SD.Next](https://github.com/vladmandic/automatic)
41 | - [meh](https://github.com/s1dlx/meh)
42 | - [ComfyUI](https://github.com/comfyanonymous/ComfyUI)
43 | - [VAE-BlessUp](https://github.com/sALTaccount/VAE-BlessUp)
44 |
45 | Thank you [Kybalico](https://github.com/kybalico/) and [NovaZone](https://civitai.com/user/nova1337) for helping me test, and providing suggestions! ✨
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from .sdnextmerge_nodes import *
2 | from .vae_merge import *
3 | from .quant_nodes import *
4 |
# Registry read by ComfyUI at load time: maps internal node IDs to the node
# classes brought in by the star imports above.
NODE_CLASS_MAPPINGS = {
    "SDNext Merge": SDNextMerge,
    "VAE Merge": VAEMerge,

    "SD1 MBW Layers": SD1_MBWLayers,
    "SD1 MBW Layers Binary": SD1_MBWLayers_Binary,
    "SDXL MBW Layers": SDXL_MBWLayers,
    "SDXL MBW Layers Binary": SDXL_MBWLayers_Binary,
    "MBW Layers String": MBWLayers_String,

    "VAERepeat": VAERepeat,

    "ModelQuant": ModelQuant,
    "ClipQuant": ClipQuant,
    "VAEQuant": VAEQuant,
}

# Human-readable labels shown in the ComfyUI node picker, keyed by the same
# node IDs as NODE_CLASS_MAPPINGS (note "VAERepeat" displays as "Repeat VAE").
NODE_DISPLAY_NAME_MAPPINGS = {
    "SDNext Merge": "SDNext Merge",
    "VAE Merge": "VAE Merge",

    "SD1 MBW Layers": "SD1 MBW Layers",
    "SD1 MBW Layers Binary": "SD1 MBW Layers Binary",
    "SDXL MBW Layers": "SDXL MBW Layers",
    "SDXL MBW Layers Binary": "SDXL MBW Layers Binary",
    "MBW Layers String": "MBW Layers String",

    "VAERepeat": "Repeat VAE",

    "ModelQuant": "ModelQuant",
    "ClipQuant": "ClipQuant",
    "VAEQuant": "VAEQuant",
}

# Public API of the package: the two registries ComfyUI looks for.
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
40 |
--------------------------------------------------------------------------------
/images/MBW_Layers_showcase.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TechnoByteJS/ComfyUI-TechNodes/038d32cd28751618ae41c1c8233f7aec026e1288/images/MBW_Layers_showcase.webp
--------------------------------------------------------------------------------
/images/Quant_Nodes_showcase.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TechnoByteJS/ComfyUI-TechNodes/038d32cd28751618ae41c1c8233f7aec026e1288/images/Quant_Nodes_showcase.webp
--------------------------------------------------------------------------------
/images/SDNext_Merge_showcase.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TechnoByteJS/ComfyUI-TechNodes/038d32cd28751618ae41c1c8233f7aec026e1288/images/SDNext_Merge_showcase.webp
--------------------------------------------------------------------------------
/images/VAE_Merge_showcase.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TechnoByteJS/ComfyUI-TechNodes/038d32cd28751618ae41c1c8233f7aec026e1288/images/VAE_Merge_showcase.webp
--------------------------------------------------------------------------------
/images/VAE_Repeat_showcase.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TechnoByteJS/ComfyUI-TechNodes/038d32cd28751618ae41c1c8233f7aec026e1288/images/VAE_Repeat_showcase.webp
--------------------------------------------------------------------------------
/merge.py:
--------------------------------------------------------------------------------
1 | import os
2 | from concurrent.futures import ThreadPoolExecutor
3 | from contextlib import contextmanager
4 | from typing import Dict, Optional, Tuple, Set
5 | import safetensors.torch
6 | import torch
7 | from . import merge_methods
8 | from .merge_utils import WeightClass
9 | from .merge_rebasin import (
10 | apply_permutation,
11 | update_model_a,
12 | weight_matching,
13 | )
14 | from .merge_PermSpec import sdunet_permutation_spec
15 | from .merge_PermSpec_SDXL import sdxl_permutation_spec
16 |
17 | from tqdm import tqdm
18 |
19 | import comfy.utils
20 | import comfy.model_management
21 |
22 | MAX_TOKENS = 77
23 |
24 |
# Flattened state-dict key of the CLIP position-ids tensor; this tensor is
# rebuilt by fix_clip and skipped during merging.
KEY_POSITION_IDS = (
    "cond_stage_model"
    ".transformer"
    ".text_model"
    ".embeddings"
    ".position_ids"
)
34 |
35 |
def fix_clip(model: Dict) -> Dict:
    """Reset the CLIP position-ids tensor to the canonical 0..MAX_TOKENS-1 row.

    Mutates and returns *model*; a no-op when the key is absent.
    """
    if KEY_POSITION_IDS in model:
        existing = model[KEY_POSITION_IDS]
        model[KEY_POSITION_IDS] = torch.arange(
            MAX_TOKENS, dtype=torch.int64, device=existing.device
        ).unsqueeze(0)
    return model
45 |
46 |
def prune_sd_model(model: Dict, keyset: Set) -> Dict:
    """Delete every weight that is not a UNet/CLIP key present in *keyset*.

    Mutates and returns *model*.
    """
    kept_prefixes = (
        "model.diffusion_model.",  # UNET
        # "first_stage_model.",    # VAE (intentionally not kept)
        "cond_stage_model.",       # CLIP
        "conditioner.embedders.",  # SDXL CLIP
    )
    for key in list(model.keys()):
        if key not in keyset or not key.startswith(kept_prefixes):
            del model[key]
    return model
58 |
59 |
def restore_sd_model(original_model: Dict, merged_model: Dict) -> Dict:
    """Copy keys present only in the original back into the merged dict.

    Existing merged entries win; returns the (mutated) merged dict.
    """
    missing_keys = (k for k in original_model if k not in merged_model)
    for k in missing_keys:
        merged_model[k] = original_model[k]
    return merged_model
65 |
def load_thetas(
    model_paths: Dict[str, os.PathLike],
    should_prune: bool,
    target_device: torch.device,
    precision: str,
) -> Dict:
    """
    Load and process model parameters from given paths.

    Args:
        model_paths: Dictionary of model names and their file paths
        should_prune: Flag to determine if models should be pruned
        target_device: The device to load the models onto
        precision: The precision to use for the model parameters ("fp16"
            converts every tensor to half precision)

    Returns:
        Dictionary of processed model parameters
    """
    # Load model parameters from files
    model_params = {
        model_name: comfy.utils.load_torch_file(model_path)
        for model_name, model_path in model_paths.items()
    }

    if should_prune:
        # Find keys common to all non-empty models
        common_keys = set.intersection(
            *[set(model.keys()) for model in model_params.values() if len(model.keys())]
        )
        # Prune models to keep only common parameters
        model_params = {
            model_name: prune_sd_model(model, common_keys)
            for model_name, model in model_params.items()
        }

    # Move every tensor to the target device; hoist the precision check out of
    # the hot loop, and assign directly instead of building a throwaway dict
    # per parameter for dict.update().
    to_half = precision == "fp16"
    for model in model_params.values():
        for param_name, param_tensor in model.items():
            moved = param_tensor.to(target_device)
            model[param_name] = moved.half() if to_half else moved

    print("Models loaded successfully")
    return model_params
111 |
def _inject_unet_patches(theta: Dict, patcher) -> None:
    """Copy UNet tensors from a ComfyUI model patcher into a theta dict,
    prefixing keys with "model." to match checkpoint naming."""
    key_patches = patcher.get_key_patches()
    for key, patch in key_patches.items():
        if "diffusion_model." in key:
            theta["model." + key] = patch[0]


def _inject_clip_patches(theta: Dict, patcher) -> None:
    """Copy CLIP text-encoder tensors from a ComfyUI clip patcher into a
    theta dict, renaming "clip_l" keys to "cond_stage_model" (the
    text_projection key is excluded, as in the original stanzas)."""
    key_patches = patcher.get_key_patches()
    for key, patch in key_patches.items():
        if "transformer." in key and "text_projection" not in key:
            theta[key.replace("clip_l", "cond_stage_model")] = patch[0]


def merge_models(
    models: Dict[str, os.PathLike],
    merge_mode: str,
    precision: str = "fp16",
    weights_clip: bool = False,
    device: torch.device = None,
    work_device: torch.device = None,
    prune: bool = False,
    threads: int = 4,
    optional_model_a = None,
    optional_clip_a = None,
    optional_model_b = None,
    optional_clip_b = None,
    optional_model_c = None,
    optional_clip_c = None,
    **kwargs,
) -> Dict:
    """Top-level merge entry point.

    Args:
        models: Mapping of slot names ("model_a"/"model_b"/"model_c") to
            checkpoint paths; may be empty when patcher objects are supplied.
        merge_mode: Name of a function in merge_methods applied per key.
        precision: "fp16" converts merged tensors to half precision.
        weights_clip: Clamp merged weights between the donors' extremes.
        device: Device tensors are loaded onto.
        work_device: Device used for the arithmetic (defaults to device).
        prune: Drop non-common / non-UNet/CLIP keys before merging.
        threads: Worker threads for the per-key merge.
        optional_model_*/optional_clip_*: ComfyUI patcher objects whose key
            patches are injected into the corresponding theta dict.
        **kwargs: Forwarded to WeightClass (must contain "alpha"; may contain
            "re_basin", "re_basin_iterations", ...).

    Returns:
        The merged state dict with CLIP position_ids regenerated.
    """
    print("Alpha:")
    print(kwargs["alpha"])

    thetas = load_thetas(models, prune, device, precision) if models else {}

    thetas.setdefault("model_a", {})
    thetas.setdefault("model_b", {})

    # The six optional inputs below were previously six copy-pasted stanzas;
    # each is now a call to one of the two helpers above.
    if optional_model_a is not None:
        _inject_unet_patches(thetas["model_a"], optional_model_a)
    if optional_clip_a is not None:
        _inject_clip_patches(thetas["model_a"], optional_clip_a)
    if optional_model_b is not None:
        _inject_unet_patches(thetas["model_b"], optional_model_b)
    if optional_clip_b is not None:
        _inject_clip_patches(thetas["model_b"], optional_clip_b)
    if optional_model_c is not None:
        _inject_unet_patches(thetas.setdefault("model_c", {}), optional_model_c)
    if optional_clip_c is not None:
        _inject_clip_patches(thetas.setdefault("model_c", {}), optional_clip_c)

    print(f'Merge start: models={models.values()} precision={precision} clip={weights_clip} prune={prune} threads={threads}')
    weight_matcher = WeightClass(thetas["model_a"], **kwargs)
    if kwargs.get("re_basin", False):
        merged = rebasin_merge(
            thetas,
            weight_matcher,
            merge_mode,
            precision=precision,
            weights_clip=weights_clip,
            iterations=kwargs.get("re_basin_iterations", 1),
            device=device,
            work_device=work_device,
            threads=threads,
        )
    else:
        merged = simple_merge(
            thetas,
            weight_matcher,
            merge_mode,
            precision=precision,
            weights_clip=weights_clip,
            device=device,
            work_device=work_device,
            threads=threads,
        )

    return fix_clip(merged)
210 |
def simple_merge(
    thetas: Dict[str, Dict],
    weight_matcher: WeightClass,
    merge_mode: str,
    precision: str = "fp16",
    weights_clip: bool = False,
    device: torch.device = None,
    work_device: torch.device = None,
    threads: int = 4,
) -> Dict:
    """Merge every key of model_a across a thread pool, then adopt any
    model_b-only "model" keys, and return the fixed-up model_a dict."""
    with tqdm(thetas["model_a"].keys(), desc="Merge") as progress:
        with ThreadPoolExecutor(max_workers=threads) as executor:
            pending = [
                executor.submit(
                    simple_merge_key,
                    progress,
                    key,
                    thetas,
                    weight_matcher,
                    merge_mode,
                    precision,
                    weights_clip,
                    device,
                    work_device,
                )
                for key in thetas["model_a"].keys()
            ]

        # Re-raise any exception that happened inside a worker.
        for job in pending:
            job.result()

    theta_b = thetas["model_b"]
    if theta_b:
        print(f'Merge update thetas: keys={len(theta_b)}')
        for key, tensor in theta_b.items():
            if KEY_POSITION_IDS in key:
                continue
            if "model" in key and key not in thetas["model_a"]:
                thetas["model_a"][key] = tensor.half() if precision == "fp16" else tensor

    return fix_clip(thetas["model_a"])
253 |
254 |
def rebasin_merge(
    thetas: Dict[str, os.PathLike],
    weight_matcher: WeightClass,
    merge_mode: str,
    precision: str = "fp16",
    weights_clip: bool = False,
    iterations: int = 1,
    device: torch.device = None,
    work_device: torch.device = None,
    threads: int = 1,
):
    """Iterative merge with re-basin permutation alignment.

    Each iteration block-merges into model_a, finds the permutation aligning
    the result to the donors, and blends via update_model_a.

    NOTE(review): behaviour with 3 models is untested (carried over from the
    original implementation: "not sure how this does when 3 models are
    involved...").
    """
    # Handle on the model_a dict used as a clipping bound at the end.
    # (simple_merge mutates this dict in place, so after the first iteration
    # it no longer holds the pristine pre-merge weights.)
    model_a = thetas["model_a"]
    if weight_matcher.SDXL:
        perm_spec = sdxl_permutation_spec()
    else:
        perm_spec = sdunet_permutation_spec()

    for it in range(iterations):
        print(f"rebasin: iteration={it+1}")
        weight_matcher.set_it(it)

        # normal block merge we already know and love
        thetas["model_a"] = simple_merge(
            thetas,
            weight_matcher,
            merge_mode,
            precision,
            False,  # clipping (if requested) is applied once, after all iterations
            device,
            work_device,
            threads,
        )

        # find permutations: first align merged result to model_a ...
        perm_1, y = weight_matching(
            perm_spec,
            model_a,
            thetas["model_a"],
            max_iter=it,
            init_perm=None,
            usefp16=precision == "fp16",
            device=device,
        )
        thetas["model_a"] = apply_permutation(perm_spec, perm_1, thetas["model_a"])

        # ... then to model_b.
        perm_2, z = weight_matching(
            perm_spec,
            thetas["model_b"],
            thetas["model_a"],
            max_iter=it,
            init_perm=None,
            usefp16=precision == "fp16",
            device=device,
        )

        # Blend towards model_b proportionally to the two matching scores.
        new_alpha = torch.nn.functional.normalize(
            torch.sigmoid(torch.Tensor([y, z])), p=1, dim=0
        ).tolist()[0]
        thetas["model_a"] = update_model_a(
            perm_spec, perm_2, thetas["model_a"], new_alpha
        )

    if weights_clip:
        # BUGFIX: pass clip_thetas (whose "model_a" is the donor snapshot) to
        # clip_weights.  The original code built clip_thetas but then passed
        # `thetas`, whose "model_a" IS the merged result, so the clip bounds
        # were computed against the merged weights themselves.
        clip_thetas = thetas.copy()
        clip_thetas["model_a"] = model_a
        thetas["model_a"] = clip_weights(clip_thetas, thetas["model_a"])

    return thetas["model_a"]
324 |
325 |
def simple_merge_key(progress, key, thetas, *args, **kwargs):
    """Merge one key and store the detached result back into model_a,
    advancing the shared progress bar either way."""
    with merge_key_context(key, thetas, *args, **kwargs) as merged:
        if merged is not None:
            thetas["model_a"][key] = merged.detach().clone()
        progress.update(1)
331 |
332 |
def merge_key(  # pylint: disable=inconsistent-return-statements
    key: str,
    thetas: Dict,
    weight_matcher: WeightClass,
    merge_mode: str,
    precision: str = "fp16",
    weights_clip: bool = False,
    device: torch.device = None,
    work_device: torch.device = None,
) -> Optional[Tuple[str, Dict]]:
    """Merge a single tensor key across the loaded models.

    Returns the merged tensor, model_a's tensor unchanged when any model
    lacks the key, or None for the CLIP position-ids key (regenerated later
    by fix_clip).
    """
    work_device = device if work_device is None else work_device

    if KEY_POSITION_IDS in key:
        return None

    # A key missing from any model passes through from model_a untouched.
    if any(key not in theta for theta in thetas.values()):
        return thetas["model_a"][key]

    current_bases = weight_matcher(key)
    try:
        merge_method = getattr(merge_methods, merge_mode)
    except AttributeError as e:
        raise ValueError(f"{merge_mode} not implemented, aborting merge!") from e

    merge_args = get_merge_method_args(current_bases, thetas, key, work_device)

    a_shape = merge_args["a"].size()
    b_shape = merge_args["b"].size()
    if a_shape != b_shape:
        # dealing with pix2pix and inpainting models: keep the wider tensor
        merged_key = merge_args["a"] if a_shape[1] > b_shape[1] else merge_args["b"]
    else:
        merged_key = merge_method(**merge_args).to(device)

    if weights_clip:
        merged_key = clip_weights_key(thetas, merged_key, key)

    if precision == "fp16":
        merged_key = merged_key.half()

    return merged_key
377 |
378 |
def clip_weights(thetas, merged):
    """Clamp every key shared by model_a and model_b inside the merged dict;
    mutates and returns *merged*."""
    shared_keys = (k for k in thetas["model_a"] if k in thetas["model_b"])
    for k in shared_keys:
        merged[k] = clip_weights_key(thetas, merged[k], k)
    return merged
384 |
def clip_weights_key(thetas, merged_weights, key):
    """Clamp merged weights elementwise into [min(a, b), max(a, b)] for *key*."""
    # Bring both donors onto the device of the merged tensor.
    target = merged_weights.device
    weights_a = thetas["model_a"][key].to(target)
    weights_b = thetas["model_b"][key].to(target)

    lower = torch.minimum(weights_a, weights_b)
    upper = torch.maximum(weights_a, weights_b)

    return torch.minimum(torch.maximum(merged_weights, lower), upper)
397 |
@contextmanager
def merge_key_context(*args, **kwargs):
    """Yield merge_key's result, dropping the local reference on exit."""
    merged = merge_key(*args, **kwargs)
    try:
        yield merged
    finally:
        if merged is not None:
            del merged
406 |
407 |
def get_merge_method_args(
    current_bases: Dict,
    thetas: Dict,
    key: str,
    work_device: torch.device,
) -> Dict:
    """Assemble kwargs for a merge method: tensors "a"/"b" (and "c" when a
    third model is loaded) on the work device, plus the base weights."""
    method_args = {
        "a": thetas["model_a"][key].to(work_device),
        "b": thetas["model_b"][key].to(work_device),
    }
    # current_bases is applied last, so its entries would take precedence
    # over "a"/"b" on a key collision — same precedence as before.
    method_args.update(current_bases)

    if "model_c" in thetas:
        method_args["c"] = thetas["model_c"][key].to(work_device)

    return method_args
424 |
--------------------------------------------------------------------------------
/merge_PermSpec.py:
--------------------------------------------------------------------------------
1 | from .merge_rebasin import PermutationSpec, permutation_spec_from_axes_to_perm
2 | def sdunet_permutation_spec() -> PermutationSpec:
3 | conv = lambda name, p_in, p_out: { # pylint: disable=unnecessary-lambda-assignment
4 | f"{name}.weight": (
5 | p_out,
6 | p_in,
7 | ),
8 | f"{name}.bias": (p_out,),
9 | }
10 | norm = lambda name, p: {f"{name}.weight": (p,), f"{name}.bias": (p,)} # pylint: disable=unnecessary-lambda-assignment
11 | dense = (
12 | lambda name, p_in, p_out, bias=True: { # pylint: disable=unnecessary-lambda-assignment
13 | f"{name}.weight": (p_out, p_in),
14 | f"{name}.bias": (p_out,),
15 | }
16 | if bias
17 | else {f"{name}.weight": (p_out, p_in)}
18 | )
19 | skip = lambda name, p_in, p_out: { # pylint: disable=unnecessary-lambda-assignment
20 | f"{name}": (
21 | p_out,
22 | p_in,
23 | None,
24 | None,
25 | )
26 | }
27 |
28 | # Unet Res blocks
29 | easyblock = lambda name, p_in, p_out: { # pylint: disable=unnecessary-lambda-assignment
30 | **norm(f"{name}.in_layers.0", p_in),
31 | **conv(f"{name}.in_layers.2", p_in, f"P_{name}_inner"),
32 | **dense(
33 | f"{name}.emb_layers.1", f"P_{name}_inner2", f"P_{name}_inner3", bias=True
34 | ),
35 | **norm(f"{name}.out_layers.0", f"P_{name}_inner4"),
36 | **conv(f"{name}.out_layers.3", f"P_{name}_inner4", p_out),
37 | }
38 |
39 | # VAE blocks - Unused
40 | easyblock2 = lambda name, p: { # pylint: disable=unnecessary-lambda-assignment, unused-variable # noqa: F841
41 | **norm(f"{name}.norm1", p),
42 | **conv(f"{name}.conv1", p, f"P_{name}_inner"),
43 | **norm(f"{name}.norm2", f"P_{name}_inner"),
44 | **conv(f"{name}.conv2", f"P_{name}_inner", p),
45 | }
46 |
47 | # This is for blocks that use a residual connection, but change the number of channels via a Conv.
48 | shortcutblock = lambda name, p_in, p_out: { # pylint: disable=unnecessary-lambda-assignment, , unused-variable # noqa: F841
49 | **norm(f"{name}.norm1", p_in),
50 | **conv(f"{name}.conv1", p_in, f"P_{name}_inner"),
51 | **norm(f"{name}.norm2", f"P_{name}_inner"),
52 | **conv(f"{name}.conv2", f"P_{name}_inner", p_out),
53 | **conv(f"{name}.nin_shortcut", p_in, p_out),
54 | **norm(f"{name}.nin_shortcut", p_out),
55 | }
56 |
57 | return permutation_spec_from_axes_to_perm(
58 | {
59 | # Skipped Layers
60 | **skip("betas", None, None),
61 | **skip("alphas_cumprod", None, None),
62 | **skip("alphas_cumprod_prev", None, None),
63 | **skip("sqrt_alphas_cumprod", None, None),
64 | **skip("sqrt_one_minus_alphas_cumprod", None, None),
65 | **skip("log_one_minus_alphas_cumprods", None, None),
66 | **skip("sqrt_recip_alphas_cumprod", None, None),
67 | **skip("sqrt_recipm1_alphas_cumprod", None, None),
68 | **skip("posterior_variance", None, None),
69 | **skip("posterior_log_variance_clipped", None, None),
70 | **skip("posterior_mean_coef1", None, None),
71 | **skip("posterior_mean_coef2", None, None),
72 | **skip("log_one_minus_alphas_cumprod", None, None),
73 | **skip("model_ema.decay", None, None),
74 | **skip("model_ema.num_updates", None, None),
75 | # initial
76 | **dense("model.diffusion_model.time_embed.0", None, "P_bg0", bias=True),
77 | **dense("model.diffusion_model.time_embed.2", "P_bg0", "P_bg1", bias=True),
78 | **conv("model.diffusion_model.input_blocks.0.0", "P_bg2", "P_bg3"),
79 | # input blocks
80 | **easyblock("model.diffusion_model.input_blocks.1.0", "P_bg4", "P_bg5"),
81 | **norm("model.diffusion_model.input_blocks.1.1.norm", "P_bg6"),
82 | **conv("model.diffusion_model.input_blocks.1.1.proj_in", "P_bg6", "P_bg7"),
83 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q", "P_bg8", "P_bg9", bias=False),
84 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k", "P_bg8", "P_bg9", bias=False),
85 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v", "P_bg8", "P_bg9", bias=False),
86 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0", "P_bg8", "P_bg9", bias=True),
87 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj", "P_bg10", "P_bg11", bias=True),
88 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff.net.2", "P_bg12", "P_bg13", bias=True),
89 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q", "P_bg14", "P_bg15", bias=False),
90 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k", "P_bg16", "P_bg17", bias=False),
91 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v", "P_bg16", "P_bg17", bias=False),
92 | **dense("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0", "P_bg18", "P_bg19", bias=True),
93 | **norm("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm1", "P_bg19"),
94 | **norm("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm2", "P_bg19"),
95 | **norm("model.diffusion_model.input_blocks.1.1.transformer_blocks.0.norm3", "P_bg19"),
96 | **conv("model.diffusion_model.input_blocks.1.1.proj_out", "P_bg19", "P_bg20"),
97 | **easyblock("model.diffusion_model.input_blocks.2.0", "P_bg21", "P_bg22"),
98 | **norm("model.diffusion_model.input_blocks.2.1.norm", "P_bg23"),
99 | **conv("model.diffusion_model.input_blocks.2.1.proj_in", "P_bg23", "P_bg24"),
100 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q", "P_bg25", "P_bg26", bias=False),
101 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k", "P_bg25", "P_bg26", bias=False),
102 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v", "P_bg25", "P_bg26", bias=False),
103 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0", "P_bg25", "P_bg26", bias=True),
104 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj", "P_bg27", "P_bg28", bias=True),
105 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff.net.2", "P_bg29", "P_bg30", bias=True),
106 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q", "P_bg31", "P_bg32", bias=False),
107 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k", "P_bg33", "P_bg34", bias=False),
108 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v", "P_bg33", "P_bg34", bias=False),
109 | **dense("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0", "P_bg35", "P_bg36", bias=True),
110 | **norm("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm1", "P_bg36"),
111 | **norm("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm2", "P_bg36"),
112 | **norm("model.diffusion_model.input_blocks.2.1.transformer_blocks.0.norm3", "P_bg36"),
113 | **conv("model.diffusion_model.input_blocks.2.1.proj_out", "P_bg36", "P_bg37"),
114 | **conv("model.diffusion_model.input_blocks.3.0.op", "P_bg38", "P_bg39"),
115 | **easyblock("model.diffusion_model.input_blocks.4.0", "P_bg40", "P_bg41"),
116 | **conv("model.diffusion_model.input_blocks.4.0.skip_connection", "P_bg42", "P_bg43"),
117 | **norm("model.diffusion_model.input_blocks.4.1.norm", "P_bg44"),
118 | **conv("model.diffusion_model.input_blocks.4.1.proj_in", "P_bg44", "P_bg45"),
119 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q", "P_bg46", "P_bg47", bias=False),
120 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k", "P_bg46", "P_bg47", bias=False),
121 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v", "P_bg46", "P_bg47", bias=False),
122 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0", "P_bg46", "P_bg47", bias=True),
123 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj", "P_bg48", "P_bg49", bias=True),
124 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff.net.2", "P_bg50", "P_bg51", bias=True),
125 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q", "P_bg52", "P_bg53", bias=False),
126 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k", "P_bg54", "P_bg55", bias=False),
127 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v", "P_bg54", "P_bg55", bias=False),
128 | **dense("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0", "P_bg56", "P_bg57", bias=True),
129 | **norm("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm1", "P_bg57"),
130 | **norm("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm2", "P_bg57"),
131 | **norm("model.diffusion_model.input_blocks.4.1.transformer_blocks.0.norm3", "P_bg57"),
132 | **conv("model.diffusion_model.input_blocks.4.1.proj_out", "P_bg57", "P_bg58"),
133 | **easyblock("model.diffusion_model.input_blocks.5.0", "P_bg59", "P_bg60"),
134 | **norm("model.diffusion_model.input_blocks.5.1.norm", "P_bg61"),
135 | **conv("model.diffusion_model.input_blocks.5.1.proj_in", "P_bg61", "P_bg62"),
136 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q", "P_bg63", "P_bg64", bias=False),
137 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k", "P_bg63", "P_bg64", bias=False),
138 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v", "P_bg63", "P_bg64", bias=False),
139 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0", "P_bg63", "P_bg64", bias=True),
140 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj", "P_bg65", "P_bg66", bias=True),
141 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff.net.2", "P_bg67", "P_bg68", bias=True),
142 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q", "P_bg69", "P_bg70", bias=False),
143 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k", "P_bg71", "P_bg72", bias=False),
144 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v", "P_bg71", "P_bg72", bias=False),
145 | **dense("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0", "P_bg73", "P_bg74", bias=True),
146 | **norm("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm1", "P_bg74"),
147 | **norm("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm2", "P_bg74"),
148 | **norm("model.diffusion_model.input_blocks.5.1.transformer_blocks.0.norm3", "P_bg74"),
149 | **conv("model.diffusion_model.input_blocks.5.1.proj_out", "P_bg74", "P_bg75"),
150 | **conv("model.diffusion_model.input_blocks.6.0.op", "P_bg76", "P_bg77"),
151 | **easyblock("model.diffusion_model.input_blocks.7.0", "P_bg78", "P_bg79"),
152 | **conv("model.diffusion_model.input_blocks.7.0.skip_connection", "P_bg80", "P_bg81"),
153 | **norm("model.diffusion_model.input_blocks.7.1.norm", "P_bg82"),
154 | **conv("model.diffusion_model.input_blocks.7.1.proj_in", "P_bg82", "P_bg83"),
155 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q", "P_bg84", "P_bg85", bias=False),
156 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k", "P_bg84", "P_bg85", bias=False),
157 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v", "P_bg84", "P_bg85", bias=False),
158 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0", "P_bg84", "P_bg85", bias=True),
159 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj", "P_bg86", "P_bg87", bias=True),
160 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff.net.2", "P_bg88", "P_bg89", bias=True),
161 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q", "P_bg90", "P_bg91", bias=False),
162 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k", "P_bg92", "P_bg93", bias=False),
163 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v", "P_bg92", "P_bg93", bias=False),
164 | **dense("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0", "P_bg94", "P_bg95", bias=True),
165 | **norm("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm1", "P_bg95"),
166 | **norm("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm2", "P_bg95"),
167 | **norm("model.diffusion_model.input_blocks.7.1.transformer_blocks.0.norm3", "P_bg95"),
168 | **conv("model.diffusion_model.input_blocks.7.1.proj_out", "P_bg95", "P_bg96"),
169 | **easyblock("model.diffusion_model.input_blocks.8.0", "P_bg97", "P_bg98"),
170 | **norm("model.diffusion_model.input_blocks.8.1.norm", "P_bg99"),
171 | **conv("model.diffusion_model.input_blocks.8.1.proj_in", "P_bg99", "P_bg100"),
172 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q", "P_bg101", "P_bg102", bias=False),
173 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k", "P_bg101", "P_bg102", bias=False),
174 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v", "P_bg101", "P_bg102", bias=False),
175 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0", "P_bg101", "P_bg102", bias=True),
176 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj", "P_bg103", "P_bg104", bias=True),
177 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff.net.2", "P_bg105", "P_bg106", bias=True),
178 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q", "P_bg107", "P_bg108", bias=False),
179 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k", "P_bg109", "P_bg110", bias=False),
180 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v", "P_bg109", "P_bg110", bias=False),
181 | **dense("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0", "P_bg111", "P_bg112", bias=True),
182 | **norm("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm1", "P_bg112"),
183 | **norm("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm2", "P_bg112"),
184 | **norm("model.diffusion_model.input_blocks.8.1.transformer_blocks.0.norm3", "P_bg112"),
185 | **conv("model.diffusion_model.input_blocks.8.1.proj_out", "P_bg112", "P_bg113"),
186 | **conv("model.diffusion_model.input_blocks.9.0.op", "P_bg114", "P_bg115"),
187 | **easyblock("model.diffusion_model.input_blocks.10.0", "P_bg115", "P_bg116"),
188 | **easyblock("model.diffusion_model.input_blocks.11.0", "P_bg116", "P_bg117"),
189 | # middle blocks
190 | **easyblock("model.diffusion_model.middle_block.0", "P_bg117", "P_bg118"),
191 | **norm("model.diffusion_model.middle_block.1.norm", "P_bg119"),
192 | **conv("model.diffusion_model.middle_block.1.proj_in", "P_bg119", "P_bg120"),
193 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q", "P_bg121", "P_bg122", bias=False),
194 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_k", "P_bg121", "P_bg122", bias=False),
195 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_v", "P_bg121", "P_bg122", bias=False),
196 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_out.0", "P_bg121", "P_bg122", bias=True),
197 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.0.proj", "P_bg123", "P_bg124", bias=True),
198 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.ff.net.2", "P_bg125", "P_bg126", bias=True),
199 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_q", "P_bg127", "P_bg128", bias=False),
200 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k", "P_bg129", "P_bg130", bias=False),
201 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v", "P_bg129", "P_bg130", bias=False),
202 | **dense("model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_out.0", "P_bg131", "P_bg132", bias=True),
203 | **norm("model.diffusion_model.middle_block.1.transformer_blocks.0.norm1", "P_bg132"),
204 | **norm("model.diffusion_model.middle_block.1.transformer_blocks.0.norm2", "P_bg132"),
205 | **norm("model.diffusion_model.middle_block.1.transformer_blocks.0.norm3", "P_bg132"),
206 | **conv("model.diffusion_model.middle_block.1.proj_out", "P_bg132", "P_bg133"),
207 | **easyblock("model.diffusion_model.middle_block.2", "P_bg134", "P_bg135"),
208 | # output blocks
209 | **easyblock("model.diffusion_model.output_blocks.0.0", "P_bg136", "P_bg137"),
210 | **conv("model.diffusion_model.output_blocks.0.0.skip_connection", "P_bg138", "P_bg139"),
211 | **easyblock("model.diffusion_model.output_blocks.1.0", "P_bg140", "P_bg141"),
212 | **conv("model.diffusion_model.output_blocks.1.0.skip_connection", "P_bg142", "P_bg143"),
213 | **easyblock("model.diffusion_model.output_blocks.2.0", "P_bg144", "P_bg145"),
214 | **conv("model.diffusion_model.output_blocks.2.0.skip_connection", "P_bg146", "P_bg147"),
215 | **conv("model.diffusion_model.output_blocks.2.1.conv", "P_bg148", "P_bg149"),
216 | **easyblock("model.diffusion_model.output_blocks.3.0", "P_bg150", "P_bg151"),
217 | **conv("model.diffusion_model.output_blocks.3.0.skip_connection", "P_bg152", "P_bg153"),
218 | **norm("model.diffusion_model.output_blocks.3.1.norm", "P_bg154"),
219 | **conv("model.diffusion_model.output_blocks.3.1.proj_in", "P_bg154", "P_bg155"),
220 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_q", "P_bg156", "P_bg157", bias=False),
221 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_k", "P_bg156", "P_bg157", bias=False),
222 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_v", "P_bg156", "P_bg157", bias=False),
223 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn1.to_out.0", "P_bg156", "P_bg157", bias=True),
224 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.0.proj", "P_bg158", "P_bg159", bias=True),
225 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff.net.2", "P_bg160", "P_bg161", bias=True),
226 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_q", "P_bg162", "P_bg163", bias=False),
227 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k", "P_bg164", "P_bg165", bias=False),
228 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v", "P_bg164", "P_bg165", bias=False),
229 | **dense("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_out.0", "P_bg166", "P_bg167", bias=True),
230 | **norm("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm1", "P_bg167"),
231 | **norm("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm2", "P_bg167"),
232 | **norm("model.diffusion_model.output_blocks.3.1.transformer_blocks.0.norm3", "P_bg167"),
233 | **conv("model.diffusion_model.output_blocks.3.1.proj_out", "P_bg167", "P_bg168"),
234 | **easyblock("model.diffusion_model.output_blocks.4.0", "P_bg169", "P_bg170"),
235 | **conv("model.diffusion_model.output_blocks.4.0.skip_connection", "P_bg171", "P_bg172"),
236 | **norm("model.diffusion_model.output_blocks.4.1.norm", "P_bg173"),
237 | **conv("model.diffusion_model.output_blocks.4.1.proj_in", "P_bg173", "P_bg174"),
238 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_q", "P_bg175", "P_bg176", bias=False),
239 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_k", "P_bg175", "P_bg176", bias=False),
240 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_v", "P_bg175", "P_bg176", bias=False),
241 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn1.to_out.0", "P_bg175", "P_bg176", bias=True),
242 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.0.proj", "P_bg177", "P_bg178", bias=True),
243 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff.net.2", "P_bg179", "P_bg180", bias=True),
244 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_q", "P_bg181", "P_bg182", bias=False),
245 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_k", "P_bg183", "P_bg184", bias=False),
246 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_v", "P_bg183", "P_bg184", bias=False),
247 | **dense("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.attn2.to_out.0", "P_bg185", "P_bg186", bias=True),
248 | **norm("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm1", "P_bg186"),
249 | **norm("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm2", "P_bg186"),
250 | **norm("model.diffusion_model.output_blocks.4.1.transformer_blocks.0.norm3", "P_bg186"),
251 | **conv("model.diffusion_model.output_blocks.4.1.proj_out", "P_bg186", "P_bg187"),
252 | **easyblock("model.diffusion_model.output_blocks.5.0", "P_bg188", "P_bg189"),
253 | **conv("model.diffusion_model.output_blocks.5.0.skip_connection", "P_bg190", "P_bg191"),
254 | **norm("model.diffusion_model.output_blocks.5.1.norm", "P_bg192"),
255 | **conv("model.diffusion_model.output_blocks.5.1.proj_in", "P_bg192", "P_bg193"),
256 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_q", "P_bg194", "P_bg195", bias=False),
257 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_k", "P_bg194", "P_bg195", bias=False),
258 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_v", "P_bg194", "P_bg195", bias=False),
259 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn1.to_out.0", "P_bg194", "P_bg195", bias=True),
260 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.0.proj", "P_bg196", "P_bg197", bias=True),
261 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff.net.2", "P_bg198", "P_bg199", bias=True),
262 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_q", "P_bg200", "P_bg201", bias=False),
263 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_k", "P_bg202", "P_bg203", bias=False),
264 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_v", "P_bg202", "P_bg203", bias=False),
265 | **dense("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.attn2.to_out.0", "P_bg204", "P_bg205", bias=True),
266 | **norm("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm1", "P_bg205"),
267 | **norm("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm2", "P_bg205"),
268 | **norm("model.diffusion_model.output_blocks.5.1.transformer_blocks.0.norm3", "P_bg205"),
269 | **conv("model.diffusion_model.output_blocks.5.1.proj_out", "P_bg205", "P_bg206"),
270 | **conv("model.diffusion_model.output_blocks.5.2.conv", "P_bg206", "P_bg207"),
271 | **easyblock("model.diffusion_model.output_blocks.6.0", "P_bg208", "P_bg209"),
272 | **conv("model.diffusion_model.output_blocks.6.0.skip_connection", "P_bg210", "P_bg211"),
273 | **norm("model.diffusion_model.output_blocks.6.1.norm", "P_bg212"),
274 | **conv("model.diffusion_model.output_blocks.6.1.proj_in", "P_bg212", "P_bg213"),
275 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_q", "P_bg214", "P_bg215", bias=False),
276 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_k", "P_bg214", "P_bg215", bias=False),
277 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_v", "P_bg214", "P_bg215", bias=False),
278 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn1.to_out.0", "P_bg214", "P_bg215", bias=True),
279 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.0.proj", "P_bg216", "P_bg217", bias=True),
280 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff.net.2", "P_bg218", "P_bg219", bias=True),
281 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_q", "P_bg220", "P_bg221", bias=False),
282 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_k", "P_bg222", "P_bg223", bias=False),
283 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_v", "P_bg222", "P_bg223", bias=False),
284 | **dense("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.attn2.to_out.0", "P_bg224", "P_bg225", bias=True),
285 | **norm("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm1", "P_bg225"),
286 | **norm("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm2", "P_bg225"),
287 | **norm("model.diffusion_model.output_blocks.6.1.transformer_blocks.0.norm3", "P_bg225"),
288 | **conv("model.diffusion_model.output_blocks.6.1.proj_out", "P_bg225", "P_bg226"),
289 | **easyblock("model.diffusion_model.output_blocks.7.0", "P_bg227", "P_bg228"),
290 | **conv("model.diffusion_model.output_blocks.7.0.skip_connection", "P_bg229", "P_bg230"),
291 | **norm("model.diffusion_model.output_blocks.7.1.norm", "P_bg231"),
292 | **conv("model.diffusion_model.output_blocks.7.1.proj_in", "P_bg231", "P_bg232"),
293 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_q", "P_bg233", "P_bg234", bias=False),
294 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k", "P_bg233", "P_bg234", bias=False),
295 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_v", "P_bg233", "P_bg234", bias=False),
296 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_out.0", "P_bg233", "P_bg234", bias=True),
297 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.0.proj", "P_bg235", "P_bg236", bias=True),
298 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff.net.2", "P_bg237", "P_bg238", bias=True),
299 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_q", "P_bg239", "P_bg240", bias=False),
300 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_k", "P_bg241", "P_bg242", bias=False),
301 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_v", "P_bg241", "P_bg242", bias=False),
302 | **dense("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn2.to_out.0", "P_bg243", "P_bg244", bias=True),
303 | **norm("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm1", "P_bg244"),
304 | **norm("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm2", "P_bg244"),
305 | **norm("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.norm3", "P_bg244"),
306 | **conv("model.diffusion_model.output_blocks.7.1.proj_out", "P_bg244", "P_bg245"),
307 | **easyblock("model.diffusion_model.output_blocks.8.0", "P_bg246", "P_bg247"),
308 | **conv("model.diffusion_model.output_blocks.8.0.skip_connection", "P_bg248", "P_bg249"),
309 | **norm("model.diffusion_model.output_blocks.8.1.norm", "P_bg250"),
310 | **conv("model.diffusion_model.output_blocks.8.1.proj_in", "P_bg250", "P_bg251"),
311 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_q", "P_bg252", "P_bg253", bias=False),
312 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_k", "P_bg252", "P_bg253", bias=False),
313 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_v", "P_bg252", "P_bg253", bias=False),
314 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn1.to_out.0", "P_bg252", "P_bg253", bias=True),
315 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.0.proj", "P_bg254", "P_bg255", bias=True),
316 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff.net.2", "P_bg256", "P_bg257", bias=True),
317 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_q", "P_bg258", "P_bg259", bias=False),
318 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_k", "P_bg260", "P_bg261", bias=False),
319 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_v", "P_bg260", "P_bg261", bias=False),
320 | **dense("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.attn2.to_out.0", "P_bg262", "P_bg263", bias=True),
321 | **norm("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm1", "P_bg263"),
322 | **norm("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm2", "P_bg263"),
323 | **norm("model.diffusion_model.output_blocks.8.1.transformer_blocks.0.norm3", "P_bg263"),
324 | **conv("model.diffusion_model.output_blocks.8.1.proj_out", "P_bg263", "P_bg264"),
325 | **conv("model.diffusion_model.output_blocks.8.2.conv", "P_bg265", "P_bg266"),
326 | **easyblock("model.diffusion_model.output_blocks.9.0", "P_bg267", "P_bg268"),
327 | **conv("model.diffusion_model.output_blocks.9.0.skip_connection", "P_bg269", "P_bg270"),
328 | **norm("model.diffusion_model.output_blocks.9.1.norm", "P_bg271"),
329 | **conv("model.diffusion_model.output_blocks.9.1.proj_in", "P_bg271", "P_bg272"),
330 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_q", "P_bg273", "P_bg274", bias=False),
331 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_k", "P_bg273", "P_bg274", bias=False),
332 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_v", "P_bg273", "P_bg274", bias=False),
333 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn1.to_out.0", "P_bg273", "P_bg274", bias=True),
334 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.0.proj", "P_bg275", "P_bg276", bias=True),
335 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff.net.2", "P_bg277", "P_bg278", bias=True),
336 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_q", "P_bg279", "P_bg280", bias=False),
337 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k", "P_bg281", "P_bg282", bias=False),
338 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v", "P_bg281", "P_bg282", bias=False),
339 | **dense("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_out.0", "P_bg283", "P_bg284", bias=True),
340 | **norm("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm1", "P_bg284"),
341 | **norm("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm2", "P_bg284"),
342 | **norm("model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3", "P_bg284"),
343 | **conv("model.diffusion_model.output_blocks.9.1.proj_out", "P_bg284", "P_bg285"),
344 | **easyblock("model.diffusion_model.output_blocks.10.0", "P_bg286", "P_bg287"),
345 | **conv("model.diffusion_model.output_blocks.10.0.skip_connection", "P_bg288", "P_bg289"),
346 | **norm("model.diffusion_model.output_blocks.10.1.norm", "P_bg290"),
347 | **conv("model.diffusion_model.output_blocks.10.1.proj_in", "P_bg290", "P_bg291"),
348 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_q", "P_bg292", "P_bg293", bias=False),
349 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_k", "P_bg292", "P_bg293", bias=False),
350 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_v", "P_bg292", "P_bg293", bias=False),
351 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn1.to_out.0", "P_bg292", "P_bg293", bias=True),
352 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.0.proj", "P_b294", "P_bg295", bias=True),
353 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff.net.2", "P_bg296", "P_bg297", bias=True),
354 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_q", "P_bg298", "P_bg299", bias=False),
355 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_k", "P_bg300", "P_bg301", bias=False),
356 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_v", "P_bg300", "P_bg301", bias=False),
357 | **dense("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.attn2.to_out.0", "P_bg302", "P_bg303", bias=True),
358 | **norm("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm1", "P_bg303"),
359 | **norm("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm2", "P_bg303"),
360 | **norm("model.diffusion_model.output_blocks.10.1.transformer_blocks.0.norm3", "P_bg303"),
361 | **conv("model.diffusion_model.output_blocks.10.1.proj_out", "P_bg303", "P_bg304"),
362 | **easyblock("model.diffusion_model.output_blocks.11.0", "P_bg305", "P_bg306"),
363 | **conv("model.diffusion_model.output_blocks.11.0.skip_connection", "P_bg307", "P_bg308"),
364 | **norm("model.diffusion_model.output_blocks.11.1.norm", "P_bg309"),
365 | **conv("model.diffusion_model.output_blocks.11.1.proj_in", "P_bg309", "P_bg310"),
366 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_q", "P_bg311", "P_bg312", bias=False),
367 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_k", "P_bg311", "P_bg312", bias=False),
368 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_v", "P_bg311", "P_bg312", bias=False),
369 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn1.to_out.0", "P_bg311", "P_bg312", bias=True),
370 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.0.proj", "P_bg313", "P_bg314", bias=True),
371 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff.net.2", "P_bg315", "P_bg316", bias=True),
372 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_q", "P_bg317", "P_bg318", bias=False),
373 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_k", "P_bg319", "P_bg320", bias=False),
374 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_v", "P_bg319", "P_bg320", bias=False),
375 | **dense("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.attn2.to_out.0", "P_bg321", "P_bg322", bias=True),
376 | **norm("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1", "P_bg322"),
377 | **norm("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm2", "P_bg322"),
378 | **norm("model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm3", "P_bg322"),
379 | **conv("model.diffusion_model.output_blocks.11.1.proj_out", "P_bg322", "P_bg323"),
380 | **norm("model.diffusion_model.out.0", "P_bg324"),
381 | **conv("model.diffusion_model.out.2", "P_bg325", "P_bg326"),
382 | **skip("cond_stage_model.transformer.text_model.embeddings.position_ids", None, None),
383 | **dense("cond_stage_model.transformer.text_model.embeddings.token_embedding", "P_bg365", "P_bg366", bias=False),
384 | **dense("cond_stage_model.transformer.text_model.embeddings.token_embedding", None, None),  # NOTE(review): duplicate key — this None,None spread overrides the P_bg365/P_bg366 entry above (later ** wins), so the token embedding is effectively left unpermuted; confirm which mapping was intended
385 | **dense("cond_stage_model.transformer.text_model.embeddings.position_embedding", "P_bg367", "P_bg368", bias=False),
386 | # cond stage text encoder
387 | **dense("cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj", "P_bg369", "P_bg370", bias=True),
388 | **dense("cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj", "P_bg369", "P_bg370", bias=True),
389 | **dense("cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj", "P_bg369", "P_bg370", bias=True),
390 | **dense("cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj", "P_bg369", "P_bg370", bias=True),
391 | **norm("cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1", "P_bg370"),
392 | **dense("cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1", "P_bg370", "P_bg371", bias=True),
393 | **dense("cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2", "P_bg371", "P_bg372", bias=True),
394 | **norm("cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2", "P_bg372"),
395 | **dense("cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj", "P_bg372", "P_bg373", bias=True),
396 | **dense("cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj", "P_bg372", "P_bg373", bias=True),
397 | **dense("cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj", "P_bg372", "P_bg373", bias=True),
398 | **dense("cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj", "P_bg372", "P_bg373", bias=True),
399 | **norm("cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1", "P_bg373"),
400 | **dense("cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1", "P_bg373", "P_bg374", bias=True),
401 | **dense("cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2", "P_bg374", "P_bg375", bias=True),
402 | **norm("cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2", "P_bg375"),
403 | **dense("cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj", "P_bg375", "P_bg376", bias=True),
404 | **dense("cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj", "P_bg375", "P_bg376", bias=True),
405 | **dense("cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj", "P_bg375", "P_bg376", bias=True),
406 | **dense("cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj", "P_bg375", "P_bg376", bias=True),
407 | **norm("cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1", "P_bg376"),
408 | **dense("cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1", "P_bg376", "P_bg377", bias=True),
409 | **dense("cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2", "P_bg377", "P_bg378", bias=True),
410 | **norm("cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2", "P_bg378"),
411 | **dense("cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj", "P_bg378", "P_bg379", bias=True),
412 | **dense("cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj", "P_bg378", "P_bg379", bias=True),
413 | **dense("cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj", "P_bg378", "P_bg379", bias=True),
414 | **dense("cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj", "P_bg378", "P_bg379", bias=True),
415 | **norm("cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1", "P_bg379"),
416 | **dense("cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1", "P_bg379", "P_bg380", bias=True),
417 | **dense("cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2", "P_bg380", "P_b381", bias=True),
418 | **norm("cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2", "P_bg381"),
419 | **dense("cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj", "P_bg381", "P_bg382", bias=True),
420 | **dense("cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj", "P_bg381", "P_bg382", bias=True),
421 | **dense("cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj", "P_bg381", "P_bg382", bias=True),
422 | **dense("cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj", "P_bg381", "P_bg382", bias=True),
423 | **norm("cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1", "P_bg382"),
424 | **dense("cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1", "P_bg382", "P_bg383", bias=True),
425 | **dense("cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2", "P_bg383", "P_bg384", bias=True),
426 | **norm("cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2", "P_bg384"),
427 | **dense("cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj", "P_bg384", "P_bg385", bias=True),
428 | **dense("cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj", "P_bg384", "P_bg385", bias=True),
429 | **dense("cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj", "P_bg384", "P_bg385", bias=True),
430 | **dense("cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj", "P_bg384", "P_bg385", bias=True),
431 | **norm("cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1", "P_bg385"),
432 | **dense("cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1", "P_bg385", "P_bg386", bias=True),
433 | **dense("cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2", "P_bg386", "P_bg387", bias=True),
434 | **norm("cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2", "P_bg387"),
435 | **dense("cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj", "P_bg387", "P_bg388", bias=True),
436 | **dense("cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj", "P_bg387", "P_bg388", bias=True),
437 | **dense("cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj", "P_bg387", "P_bg388", bias=True),
438 | **dense("cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj", "P_bg387", "P_bg388", bias=True),
439 | **norm("cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1", "P_bg389"),
440 | **dense("cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1", "P_bg389", "P_bg390", bias=True),
441 | **dense("cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2", "P_bg390", "P_bg391", bias=True),
442 | **norm("cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2", "P_bg391"),
443 | **dense("cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj", "P_bg391", "P_bg392", bias=True),
444 | **dense("cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj", "P_bg391", "P_bg392", bias=True),
445 | **dense("cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj", "P_bg391", "P_bg392", bias=True),
446 | **dense("cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj", "P_bg391", "P_bg392", bias=True),
447 | **norm("cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1", "P_bg392"),
448 | **dense("cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1", "P_bg392", "P_bg393", bias=True),
449 | **dense("cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2", "P_bg393", "P_bg394", bias=True),
450 | **norm("cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2", "P_bg394"),
451 | **dense("cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj", "P_bg394", "P_bg395", bias=True),
452 | **dense("cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj", "P_bg394", "P_bg395", bias=True),
453 | **dense("cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj", "P_bg394", "P_bg395", bias=True),
454 | **dense("cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj", "P_bg394", "P_bg395", bias=True),
455 | **norm("cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1", "P_bg395"),
456 | **dense("cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1", "P_bg395", "P_bg396", bias=True),
457 | **dense("cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2", "P_bg396", "P_bg397", bias=True),
458 | **norm("cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2", "P_bg397"),
459 | **dense("cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj", "P_bg397", "P_bg398", bias=True),
460 | **dense("cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj", "P_bg397", "P_bg398", bias=True),
461 | **dense("cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj", "P_bg397", "P_bg398", bias=True),
462 | **dense("cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj", "P_bg397", "P_bg398", bias=True),
463 | **norm("cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1", "P_bg398"),
464 | **dense("cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1", "P_bg398", "P_bg399", bias=True),
465 | **dense("cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2", "P_bg400", "P_bg401", bias=True),
466 | **norm("cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2", "P_bg401"),
467 | **dense("cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj", "P_bg401", "P_bg402", bias=True),
468 | **dense("cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj", "P_bg401", "P_bg402", bias=True),
469 | **dense("cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj", "P_bg401", "P_bg402", bias=True),
470 | **dense("cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj", "P_bg401", "P_bg402", bias=True),
471 | **norm("cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1", "P_bg402"),
472 | **dense("cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1", "P_bg402", "P_bg403", bias=True),
473 | **dense("cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2", "P_bg403", "P_bg404", bias=True),
474 | **norm("cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2", "P_bg404"),
475 | **dense("cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj", "P_bg404", "P_bg405", bias=True),
476 | **dense("cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj", "P_bg404", "P_bg405", bias=True),
477 | **dense("cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj", "P_bg404", "P_bg405", bias=True),
478 | **dense("cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj", "P_bg404", "P_bg405", bias=True),
479 | **norm("cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1", "P_bg405"),
480 | **dense("cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1", "P_bg405", "P_bg406", bias=True),
481 | **dense("cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2", "P_bg406", "P_bg407", bias=True),
482 | **norm("cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2", "P_bg407"),
483 | **norm("cond_stage_model.transformer.text_model.final_layer_norm", "P_bg407"),
484 | }
485 | )
486 |
--------------------------------------------------------------------------------
/merge_methods.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Tuple
3 |
4 | import torch
5 | from torch import Tensor
6 |
7 | __all__ = [
8 | "weighted_sum",
9 | "weighted_subtraction",
10 | "tensor_sum",
11 | "add_difference",
12 | "train_difference",
13 | "sum_twice",
14 | "triple_sum",
15 | "euclidean_add_difference",
16 | "multiply_difference",
17 | "top_k_tensor_sum",
18 | "similarity_add_difference",
19 | "distribution_crossover",
20 | "ties_add_difference",
21 | ]
22 |
23 |
24 | EPSILON = 1e-10 # Define a small constant EPSILON to prevent division by zero
25 |
26 |
def weighted_sum(a: Tensor, b: Tensor, alpha: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Linear interpolation between the two models:
    alpha 0 returns Primary Model (a),
    alpha 1 returns Secondary Model (b).
    """
    return a * (1 - alpha) + b * alpha
34 |
35 |
def weighted_subtraction(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    The inverse of a Weighted Sum Merge:
    solves `a == (1 - alpha*beta) * x + alpha*beta * b` for x.

    Returns Primary Model when alpha*beta == 0.
    High values of alpha*beta are likely to break the merged model.
    """
    ab = alpha * beta
    # Nudge the product away from 1 to avoid division by zero. The original
    # code only handled alpha == 1.0 and beta == 1.0, but any combination
    # whose product is exactly 1 (e.g. alpha=2.0, beta=0.5) hits the same
    # singularity; for alpha == beta == 1.0 the result is unchanged.
    if ab == 1.0:
        ab -= EPSILON

    return (a - ab * b) / (1 - ab)
47 |
48 |
def tensor_sum(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Splices a contiguous slice of Secondary Model into Primary Model.
    Alpha sets the width of the slice and Beta its starting point; the
    slice wraps around the end when alpha + beta > 1.
    ie Alpha = 0.5 Beta = 0.25 is (ABBA) Alpha = 0.25 Beta = 0 is (BAAA)
    """
    dim = a.shape[0]
    if alpha + beta <= 1:
        # Non-wrapping case: overwrite the [beta, alpha + beta) band of a
        # with the corresponding rows of b.
        lo = int(dim * beta)
        hi = int(dim * (alpha + beta))
        out = a.clone()
        out[lo:hi] = b[lo:hi].clone()
    else:
        # Wrapping case: the slice spills past the end, so start from b and
        # restore a only in the interior band that the slice does not cover.
        lo = int(dim * (alpha + beta - 1))
        hi = int(dim * beta)
        out = b.clone()
        out[lo:hi] = a[lo:hi].clone()
    return out
67 |
68 |
def add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Classic Add Difference Merge:
    adds the (b - c) delta onto a, scaled by alpha.
    """
    delta = b - c
    return a + alpha * delta
74 |
def train_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    "Train Difference" merge: adds the (b - c) delta onto a, scaled
    per-element by how far a already is from b relative to the total
    movement, then boosted by alpha * 1.8.

    Returns a unchanged when b == c.
    """
    # Based on: https://github.com/hako-mikan/sd-webui-supermerger/blob/843ca282948dbd3fac1246fcb1b66544a371778b/scripts/mergers/mergers.py#L673

    # Calculate the difference between b and c
    diff_BC = b - c

    # Early exit if there's no difference
    if torch.all(diff_BC == 0):
        return a

    # Calculate distances
    distance_BC = torch.abs(diff_BC)
    distance_BA = torch.abs(b - a)

    # Sum of distances
    sum_distances = distance_BC + distance_BA

    # Calculate scale, avoiding division by zero. zeros_like keeps the
    # fallback on the same device/dtype as the inputs; the previous
    # torch.tensor(0.) was a CPU scalar and could fail for CUDA tensors.
    scale = torch.where(
        sum_distances != 0,
        distance_BA / sum_distances,
        torch.zeros_like(sum_distances),
    )

    # Adjust scale sign based on the difference between b and c
    sign_scale = torch.sign(diff_BC)
    scale = sign_scale * torch.abs(scale)

    # Calculate new difference
    new_diff = scale * distance_BC

    # Return updated a
    return a + (new_diff * (alpha * 1.8))
104 |
105 |
def sum_twice(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Stacked Basic Merge:
    first interpolates Primary and Secondary at alpha,
    then interpolates that result with Tertiary at beta.
    """
    inner = a * (1 - alpha) + b * alpha
    return inner * (1 - beta) + c * beta
113 |
114 |
def triple_sum(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Weights Secondary at alpha and Tertiary at beta, filling in the
    remainder (1 - alpha - beta) with Primary.
    Expect odd results if alpha + beta > 1, since Primary is then merged
    with a negative ratio.
    """
    primary_ratio = 1 - alpha - beta
    return primary_ratio * a + alpha * b + beta * c
122 |
123 |
def euclidean_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Treats (a - c) and (b - c) as direction vectors, blends their unit
    forms with a Euclidean (root-of-weighted-squares) combination, then
    rescales the result to the norm of the plain weighted-sum delta
    before adding it back onto Tertiary.
    Note: Slow
    """
    delta_a = a.float() - c.float()
    delta_b = b.float() - c.float()
    # Normalize each delta to unit length; nan_to_num guards the
    # zero-delta case where the norm is 0.
    delta_a = torch.nan_to_num(delta_a / torch.linalg.norm(delta_a))
    delta_b = torch.nan_to_num(delta_b / torch.linalg.norm(delta_b))

    blended = torch.sqrt((1 - alpha) * delta_a**2 + alpha * delta_b**2)
    reference = weighted_sum(a.float(), b.float(), alpha) - c.float()
    # Take the sign pattern of the ordinary weighted-sum delta.
    blended = torch.copysign(blended, reference)

    # Rescale to the magnitude of the ordinary weighted-sum delta.
    target_norm = torch.linalg.norm(reference)
    return c + blended / torch.linalg.norm(blended) * target_norm
143 |
144 |
def multiply_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Similar to Add Difference but combines the two deltas with a
    geometric mean instead of an arithmetic one; the sign comes from
    the weighted-sum delta at beta.
    """
    pow_a = torch.pow(torch.abs(a.float() - c), (1 - alpha))
    pow_b = torch.pow(torch.abs(b.float() - c), alpha)
    sign_source = weighted_sum(a, b, beta) - c
    merged_delta = torch.copysign(pow_a * pow_b, sign_source)
    return c + merged_delta.to(c.dtype)
153 |
154 |
def top_k_tensor_sum(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Redistributes the largest weights of Secondary Model into Primary Model

    alpha selects the width and beta the offset of the magnitude band of
    a's sorted values that gets redistributed (see ratio_to_region).
    """
    # Flatten a and sort its values; the sorted values are the candidates
    # for redistribution.
    a_flat = torch.flatten(a)
    a_dist = torch.msort(a_flat)
    # argsort of b's argsort yields each element's rank within b, so a's
    # sorted values are laid out following b's value ordering.
    b_indices = torch.argsort(torch.flatten(b), stable=True)
    redist_indices = torch.argsort(b_indices)

    # Convert (alpha, beta) into a [start_i, end_i) index band over the
    # sorted values (inverted means the band wraps around).
    start_i, end_i, region_is_inverted = ratio_to_region(alpha, beta, torch.numel(a))
    # Magnitude thresholds delimiting the band; kth_abs_value returns -1
    # for k <= 0, which makes the lower bound include everything.
    start_top_k = kth_abs_value(a_dist, start_i)
    end_top_k = kth_abs_value(a_dist, end_i)

    # Mask of sorted values whose magnitude falls inside the band.
    indices_mask = (start_top_k < torch.abs(a_dist)) & (torch.abs(a_dist) <= end_top_k)
    if region_is_inverted:
        indices_mask = ~indices_mask
    # Reorder the mask to match b's value ordering.
    indices_mask = torch.gather(indices_mask.float(), 0, redist_indices)

    # Where masked, take the redistributed (b-ordered) values of a;
    # elsewhere keep a's original values.
    a_redist = torch.gather(a_dist, 0, redist_indices)
    a_redist = (1 - indices_mask) * a_flat + indices_mask * a_redist
    return a_redist.reshape_as(a)
176 |
177 |
def kth_abs_value(a: Tensor, k: int) -> Tensor:
    """Return the k-th smallest absolute value of `a` (1-based), or -1 for k <= 0."""
    if k > 0:
        return torch.kthvalue(torch.abs(a.float()), k)[0]
    return torch.tensor(-1, device=a.device)
183 |
184 |
def ratio_to_region(width: float, offset: float, n: int) -> Tuple[int, int, bool]:
    """
    Convert a (width, offset) ratio pair into integer [start, end) indices
    over n elements.

    Returns (start, end, inverted); inverted=True means the region wraps
    around and the returned span is its complement.
    """
    # A negative width extends the region to the left of the offset.
    if width < 0:
        offset += width
        width = -width
    width = min(width, 1)

    # Fold the offset into [0, 1).
    if offset < 0:
        offset = 1 + offset - int(offset)
    offset = math.fmod(offset, 1.0)

    if width + offset <= 1:
        region_start = offset * n
        region_end = (width + offset) * n
        inverted = False
    else:
        # The region wraps past the end; report the uncovered middle span.
        region_start = (width + offset - 1) * n
        region_end = offset * n
        inverted = True

    return round(region_start), round(region_end), inverted
205 |
206 |
def similarity_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    Per-element blend of two merge modes: Add Difference where a and b
    disagree, Weighted Sum where they agree (weighted by beta).
    """
    magnitude = torch.maximum(torch.abs(a), torch.abs(b))
    similarity = ((a * b / magnitude**2) + 1) / 2
    # 0/0 elements produce NaN; treat them as fully similar scaled by beta.
    similarity = torch.nan_to_num(similarity * beta, nan=beta)

    diff_part = a + alpha * (b - c)
    sum_part = (1 - alpha / 2) * a + (alpha / 2) * b
    return (1 - similarity) * diff_part + similarity * sum_part
218 |
219 |
def distribution_crossover(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs):  # pylint: disable=unused-argument
    """
    From the creator:
    It's Primary high-passed + Secondary low-passed. Takes the Fourier transform of the weights of
    Primary and Secondary when ordered with respect to Tertiary. Split the frequency domain
    using a linear function. Alpha is the split frequency and Beta is the inclination of the line.
    add everything under the line as the contribution of Primary and everything over the line as the contribution of Secondary
    """
    # Scalar tensors have no frequency content to split.
    # NOTE(review): this branch weights a by alpha (not 1 - alpha), the
    # opposite of weighted_sum's convention — confirm it is intentional.
    if a.shape == ():
        return alpha * a + (1 - alpha) * b

    # Order a's and b's values according to c's value ordering.
    c_indices = torch.argsort(torch.flatten(c))
    a_dist = torch.gather(torch.flatten(a), 0, c_indices)
    b_dist = torch.gather(torch.flatten(b), 0, c_indices)

    # Real FFT of both ordered value sequences.
    a_dft = torch.fft.rfft(a_dist.float())
    b_dft = torch.fft.rfft(b_dist.float())

    # Normalized frequency ramp in [0, 1).
    dft_filter = torch.arange(0, torch.numel(a_dft), device=a_dft.device).float()
    dft_filter /= torch.numel(a_dft)
    if beta > EPSILON:
        # Linear crossover centered at alpha with slope 1/beta, clamped to [0, 1].
        dft_filter = (dft_filter - alpha) / beta + 1 / 2
        dft_filter = torch.clamp(dft_filter, 0.0, 1.0)
    else:
        # Beta ~ 0 degenerates to a hard cutoff at frequency alpha.
        dft_filter = (dft_filter >= alpha).float()

    # Blend spectra (low frequencies from a, high from b), invert the FFT,
    # then undo the c-based ordering.
    x_dft = (1 - dft_filter) * a_dft + dft_filter * b_dft
    x_dist = torch.fft.irfft(x_dft, a_dist.shape[0])
    x_values = torch.gather(x_dist, 0, torch.argsort(c_indices))
    return x_values.reshape_as(a)
250 |
251 |
def ties_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
    """
    An implementation of arXiv:2306.01708 (TIES merging):
    trim each delta to its top magnitudes (beta), elect a per-element
    majority sign, then average only the deltas that agree with it.
    """
    trimmed = [filter_top_k(m - c, beta) for m in [a, b]]
    trimmed_signs = torch.stack([torch.sign(t) for t in trimmed], dim=0)

    # Per-element majority sign across both trimmed deltas.
    elected_sign = torch.sign(torch.sum(trimmed_signs, dim=0))
    agreement = (trimmed_signs == elected_sign).float()

    merged = torch.zeros_like(c, device=c.device)
    for mask, delta in zip(agreement, trimmed):
        merged += mask * delta

    # Average over the number of agreeing deltas; nan_to_num handles
    # elements where none agree (0 / 0).
    contributors = torch.sum(agreement, dim=0)
    return c + alpha * torch.nan_to_num(merged / contributors)
272 |
273 |
def filter_top_k(a: Tensor, k: float):
    """Zero out all but the top (k * numel) elements of `a` by magnitude.

    k is a fraction; at least one element is always kept.
    """
    cutoff_rank = max(int((1 - k) * torch.numel(a)), 1)
    threshold, _ = torch.kthvalue(torch.abs(a.flatten()).float(), cutoff_rank)
    keep_mask = (torch.abs(a) >= threshold).float()
    return a * keep_mask
279 |
--------------------------------------------------------------------------------
/merge_presets.py:
--------------------------------------------------------------------------------
# Per-layer merge-ratio presets for SD1.x-style block-weight merging.
# Every preset holds 26 values in [0, 1].
# NOTE(review): presumably ordered base, 12 IN blocks, MID, 12 OUT blocks —
# confirm against the node that consumes these presets.
BLOCK_WEIGHTS_PRESETS = {
    "GRAD_V": [0, 1, 0.9166666667, 0.8333333333, 0.75, 0.6666666667, 0.5833333333, 0.5, 0.4166666667, 0.3333333333, 0.25, 0.1666666667, 0.0833333333, 0, 0.0833333333, 0.1666666667, 0.25, 0.3333333333, 0.4166666667, 0.5, 0.5833333333, 0.6666666667, 0.75, 0.8333333333, 0.9166666667, 1.0],
    "GRAD_A": [0, 0, 0.0833333333, 0.1666666667, 0.25, 0.3333333333, 0.4166666667, 0.5, 0.5833333333, 0.6666666667, 0.75, 0.8333333333, 0.9166666667, 1.0, 0.9166666667, 0.8333333333, 0.75, 0.6666666667, 0.5833333333, 0.5, 0.4166666667, 0.3333333333, 0.25, 0.1666666667, 0.0833333333, 0],
    "FLAT_25": [0, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25],
    "FLAT_75": [0, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75],
    "WRAP08": [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
    "WRAP12": [0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
    "WRAP14": [0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
    "WRAP16": [0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
    "MID12_50": [0, 0, 0, 0, 0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0, 0],
    "OUT07": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
    "OUT12": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    "OUT12_5": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    "RING08_SOFT": [0, 0, 0, 0, 0, 0, 0.5, 1, 1, 1, 0.5, 0, 0, 0, 0, 0, 0.5, 1, 1, 1, 0.5, 0, 0, 0, 0, 0],
    "RING08_5": [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
    "RING10_5": [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
    "RING10_3": [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
    "SMOOTHSTEP": [0, 0, 0.00506365740740741, 0.0196759259259259, 0.04296875, 0.0740740740740741, 0.112123842592593, 0.15625, 0.205584490740741, 0.259259259259259, 0.31640625, 0.376157407407407, 0.437644675925926, 0.5, 0.562355324074074, 0.623842592592592, 0.68359375, 0.740740740740741, 0.794415509259259, 0.84375, 0.887876157407408, 0.925925925925926, 0.95703125, 0.980324074074074, 0.994936342592593, 1],
    "REVERSE_SMOOTHSTEP": [0, 1, 0.994936342592593, 0.980324074074074, 0.95703125, 0.925925925925926, 0.887876157407407, 0.84375, 0.794415509259259, 0.740740740740741, 0.68359375, 0.623842592592593, 0.562355324074074, 0.5, 0.437644675925926, 0.376157407407408, 0.31640625, 0.259259259259259, 0.205584490740741, 0.15625, 0.112123842592592, 0.0740740740740742, 0.0429687499999996, 0.0196759259259258, 0.00506365740740744, 0],
    "2SMOOTHSTEP": [0, 0, 0.0101273148148148, 0.0393518518518519, 0.0859375, 0.148148148148148, 0.224247685185185, 0.3125, 0.411168981481482, 0.518518518518519, 0.6328125, 0.752314814814815, 0.875289351851852, 1.0, 0.875289351851852, 0.752314814814815, 0.6328125, 0.518518518518519, 0.411168981481481, 0.3125, 0.224247685185184, 0.148148148148148, 0.0859375, 0.0393518518518512, 0.0101273148148153, 0],
    "2R_SMOOTHSTEP": [0, 1, 0.989872685185185, 0.960648148148148, 0.9140625, 0.851851851851852, 0.775752314814815, 0.6875, 0.588831018518519, 0.481481481481481, 0.3671875, 0.247685185185185, 0.124710648148148, 0.0, 0.124710648148148, 0.247685185185185, 0.3671875, 0.481481481481481, 0.588831018518519, 0.6875, 0.775752314814816, 0.851851851851852, 0.9140625, 0.960648148148149, 0.989872685185185, 1],
    "3SMOOTHSTEP": [0, 0, 0.0151909722222222, 0.0590277777777778, 0.12890625, 0.222222222222222, 0.336371527777778, 0.46875, 0.616753472222222, 0.777777777777778, 0.94921875, 0.871527777777778, 0.687065972222222, 0.5, 0.312934027777778, 0.128472222222222, 0.0507812500000004, 0.222222222222222, 0.383246527777778, 0.53125, 0.663628472222223, 0.777777777777778, 0.87109375, 0.940972222222222, 0.984809027777777, 1],
    "3R_SMOOTHSTEP": [0, 1, 0.984809027777778, 0.940972222222222, 0.87109375, 0.777777777777778, 0.663628472222222, 0.53125, 0.383246527777778, 0.222222222222222, 0.05078125, 0.128472222222222, 0.312934027777778, 0.5, 0.687065972222222, 0.871527777777778, 0.94921875, 0.777777777777778, 0.616753472222222, 0.46875, 0.336371527777777, 0.222222222222222, 0.12890625, 0.0590277777777777, 0.0151909722222232, 0],
    "4SMOOTHSTEP": [0, 0, 0.0202546296296296, 0.0787037037037037, 0.171875, 0.296296296296296, 0.44849537037037, 0.625, 0.822337962962963, 0.962962962962963, 0.734375, 0.49537037037037, 0.249421296296296, 0.0, 0.249421296296296, 0.495370370370371, 0.734375000000001, 0.962962962962963, 0.822337962962962, 0.625, 0.448495370370369, 0.296296296296297, 0.171875, 0.0787037037037024, 0.0202546296296307, 0],
    "4R_SMOOTHSTEP": [0, 1, 0.97974537037037, 0.921296296296296, 0.828125, 0.703703703703704, 0.55150462962963, 0.375, 0.177662037037037, 0.0370370370370372, 0.265625, 0.50462962962963, 0.750578703703704, 1.0, 0.750578703703704, 0.504629629629629, 0.265624999999999, 0.0370370370370372, 0.177662037037038, 0.375, 0.551504629629631, 0.703703703703703, 0.828125, 0.921296296296298, 0.979745370370369, 1],
    "HALF_SMOOTHSTEP": [0, 0, 0.0196759259259259, 0.0740740740740741, 0.15625, 0.259259259259259, 0.376157407407407, 0.5, 0.623842592592593, 0.740740740740741, 0.84375, 0.925925925925926, 0.980324074074074, 1.0, 0.980324074074074, 0.925925925925926, 0.84375, 0.740740740740741, 0.623842592592593, 0.5, 0.376157407407407, 0.259259259259259, 0.15625, 0.0740740740740741, 0.0196759259259259, 0],
    "HALF_R_SMOOTHSTEP": [0, 1, 0.980324074074074, 0.925925925925926, 0.84375, 0.740740740740741, 0.623842592592593, 0.5, 0.376157407407407, 0.259259259259259, 0.15625, 0.0740740740740742, 0.0196759259259256, 0.0, 0.0196759259259256, 0.0740740740740742, 0.15625, 0.259259259259259, 0.376157407407407, 0.5, 0.623842592592593, 0.740740740740741, 0.84375, 0.925925925925926, 0.980324074074074, 1],
    "ONE_THIRD_SMOOTHSTEP": [0, 0, 0.04296875, 0.15625, 0.31640625, 0.5, 0.68359375, 0.84375, 0.95703125, 1.0, 0.95703125, 0.84375, 0.68359375, 0.5, 0.31640625, 0.15625, 0.04296875, 0.0, 0.04296875, 0.15625, 0.31640625, 0.5, 0.68359375, 0.84375, 0.95703125, 1],
    "ONE_THIRD_R_SMOOTHSTEP": [0, 1, 0.95703125, 0.84375, 0.68359375, 0.5, 0.31640625, 0.15625, 0.04296875, 0.0, 0.04296875, 0.15625, 0.31640625, 0.5, 0.68359375, 0.84375, 0.95703125, 1.0, 0.95703125, 0.84375, 0.68359375, 0.5, 0.31640625, 0.15625, 0.04296875, 0],
    "ONE_FOURTH_SMOOTHSTEP": [0, 0, 0.0740740740740741, 0.259259259259259, 0.5, 0.740740740740741, 0.925925925925926, 1.0, 0.925925925925926, 0.740740740740741, 0.5, 0.259259259259259, 0.0740740740740741, 0.0, 0.0740740740740741, 0.259259259259259, 0.5, 0.740740740740741, 0.925925925925926, 1.0, 0.925925925925926, 0.740740740740741, 0.5, 0.259259259259259, 0.0740740740740741, 0],
    "ONE_FOURTH_R_SMOOTHSTEP": [0, 1, 0.925925925925926, 0.740740740740741, 0.5, 0.259259259259259, 0.0740740740740742, 0.0, 0.0740740740740742, 0.259259259259259, 0.5, 0.740740740740741, 0.925925925925926, 1.0, 0.925925925925926, 0.740740740740741, 0.5, 0.259259259259259, 0.0740740740740742, 0.0, 0.0740740740740742, 0.259259259259259, 0.5, 0.740740740740741, 0.925925925925926, 1],
    "COSINE": [0, 1, 0.995722430686905, 0.982962913144534, 0.961939766255643, 0.933012701892219, 0.896676670145617, 0.853553390593274, 0.80438071450436, 0.75, 0.691341716182545, 0.62940952255126, 0.565263096110026, 0.5, 0.434736903889974, 0.37059047744874, 0.308658283817455, 0.25, 0.195619285495639, 0.146446609406726, 0.103323329854382, 0.0669872981077805, 0.0380602337443566, 0.0170370868554658, 0.00427756931309475, 0],
    "REVERSE_COSINE": [0, 0, 0.00427756931309475, 0.0170370868554659, 0.0380602337443566, 0.0669872981077808, 0.103323329854383, 0.146446609406726, 0.19561928549564, 0.25, 0.308658283817455, 0.37059047744874, 0.434736903889974, 0.5, 0.565263096110026, 0.62940952255126, 0.691341716182545, 0.75, 0.804380714504361, 0.853553390593274, 0.896676670145618, 0.933012701892219, 0.961939766255643, 0.982962913144534, 0.995722430686905, 1],
    "CUBIC_HERMITE": [0, 0, 0.157576195987654, 0.28491512345679, 0.384765625, 0.459876543209877, 0.512996720679012, 0.546875, 0.564260223765432, 0.567901234567901, 0.560546875, 0.544945987654321, 0.523847415123457, 0.5, 0.476152584876543, 0.455054012345679, 0.439453125, 0.432098765432099, 0.435739776234568, 0.453125, 0.487003279320987, 0.540123456790124, 0.615234375, 0.71508487654321, 0.842423804012347, 1],
    "REVERSE_CUBIC_HERMITE": [0, 1, 0.842423804012346, 0.71508487654321, 0.615234375, 0.540123456790123, 0.487003279320988, 0.453125, 0.435739776234568, 0.432098765432099, 0.439453125, 0.455054012345679, 0.476152584876543, 0.5, 0.523847415123457, 0.544945987654321, 0.560546875, 0.567901234567901, 0.564260223765432, 0.546875, 0.512996720679013, 0.459876543209876, 0.384765625, 0.28491512345679, 0.157576195987653, 0],
    # NOTE(review): identical values to REVERSE_CUBIC_HERMITE above — confirm
    # whether this duplicate is intentional.
    "FAKE_REVERSE_CUBIC_HERMITE": [0, 1, 0.842423804012346, 0.71508487654321, 0.615234375, 0.540123456790123, 0.487003279320988, 0.453125, 0.435739776234568, 0.432098765432099, 0.439453125, 0.455054012345679, 0.476152584876543, 0.5, 0.523847415123457, 0.544945987654321, 0.560546875, 0.567901234567901, 0.564260223765432, 0.546875, 0.512996720679013, 0.459876543209876, 0.384765625, 0.28491512345679, 0.157576195987653, 0],
    "LOW_OFFSET_CUBIC_HERMITE": [0, 0, 0.099515938464506, 0.1628809799382715, 0.2123209635416665, 0.249228395061729, 0.274995780285494, 0.291015625, 0.298680434992284, 0.2993827160493825, 0.294514973958333, 0.285469714506173, 0.273639443479938, 0.261513611593364, 0.24938777970679, 0.245727237654321, 0.23763671875, 0.222901234567901, 0.224305796682099, 0.234635416666667, 0.247675106095678, 0.273209876543211, 0.312024739583333, 0.360904706790124, 0.422634789737655, 0.5],
    # Constant presets: keep Primary everywhere / take Secondary everywhere.
    "ALL_A": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    "ALL_B": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
}
41 |
42 |
# SDXL variants of the block-weight presets; every preset holds 20 values
# (SDXL's UNet exposes fewer mergeable blocks than SD1.x).
# NOTE(review): confirm the exact block-to-index mapping against the node
# that consumes these presets.
SDXL_BLOCK_WEIGHTS_PRESETS = {
    "SDXL_GRAD_V": [0, 1.0, 0.888889, 0.777778, 0.666667, 0.555556, 0.444444, 0.333333, 0.222222, 0.111111, 0.0, 0.111111, 0.222222, 0.333333, 0.444444, 0.555556, 0.666667, 0.777778, 0.888889, 1.0],
    "SDXL_GRAD_A": [0, 0.0, 0.111111, 0.222222, 0.333333, 0.444444, 0.555556, 0.666667, 0.777778, 0.888889, 1.0, 0.888889, 0.777778, 0.666667, 0.555556, 0.444444, 0.333333, 0.222222, 0.111111, 0.0],
    "SDXL_FLAT_25": [0, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25],
    "SDXL_FLAT_75": [0, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75],
    "SDXL_WRAP08": [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
    "SDXL_WRAP12": [0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
    "SDXL_WRAP14": [0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
    "SDXL_OUT07": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
    "SDXL_SMOOTHSTEP": [0, 0, 0.008916, 0.034294, 0.074074, 0.126200, 0.188615, 0.259259, 0.336077, 0.417010, 0.500000, 0.582990, 0.663923, 0.740741, 0.811385, 0.873800, 0.925926, 0.965706, 0.991084, 1],
    "SDXL_REVERSE_SMOOTHSTEP": [0, 1, 0.991084, 0.965706, 0.925926, 0.873800, 0.811385, 0.740741, 0.663923, 0.582990, 0.500000, 0.417010, 0.336077, 0.259259, 0.188615, 0.126200, 0.074074, 0.034294, 0.008916, 0],
    "SDXL_HALF_SMOOTHSTEP": [0, 0, 0.034294, 0.126200, 0.259259, 0.417010, 0.582990, 0.740741, 0.873800, 0.965706, 1, 0.965706, 0.873800, 0.740741, 0.582990, 0.417010, 0.259259, 0.126200, 0.034294, 0],
    "SDXL_HALF_R_SMOOTHSTEP": [0, 1, 0.965706, 0.873800, 0.740741, 0.582990, 0.417010, 0.259259, 0.126200, 0.034294, 0, 0.034294, 0.126200, 0.259259, 0.417010, 0.582990, 0.740741, 0.873800, 0.965706, 1],
    "SDXL_ONE_THIRD_SMOOTHSTEP": [0, 0, 0.074074, 0.259259, 0.500000, 0.740741, 0.925926, 1, 0.907407, 0.592593, 0, 0.592593, 0.907407, 1, 0.925926, 0.740741, 0.500000, 0.259259, 0.074074, 0],
    "SDXL_ONE_THIRD_R_SMOOTHSTEP": [0, 1, 0.925926, 0.740741, 0.500000, 0.259259, 0.074074, 0, 0.092593, 0.407407, 1, 0.407407, 0.092593, 0, 0.074074, 0.259259, 0.500000, 0.740741, 0.925926, 1],
    "SDXL_COSINE": [0, 1, 0.992404, 0.969846, 0.933013, 0.883022, 0.821394, 0.750000, 0.671010, 0.586824, 0.500000, 0.413176, 0.328990, 0.250000, 0.178606, 0.116978, 0.066987, 0.030154, 0.007596, 0],
    "SDXL_REVERSE_COSINE": [0, 0, 0.007596, 0.030154, 0.066987, 0.116978, 0.178606, 0.250000, 0.328990, 0.413176, 0.500000, 0.586824, 0.671010, 0.750000, 0.821394, 0.883022, 0.933013, 0.969846, 0.992404, 1],
    "SDXL_CUBIC_HERMITE": [0, 0, 0.268023, 0.461058, 0.588477, 0.659656, 0.683966, 0.670782, 0.629477, 0.569425, 0.500000, 0.430575, 0.370523, 0.329218, 0.316034, 0.340344, 0.411523, 0.538942, 0.731977, 1],
    "SDXL_REVERSE_CUBIC_HERMITE": [0, 1, 0.731977, 0.538942, 0.411523, 0.340344, 0.316034, 0.329218, 0.370523, 0.430575, 0.500000, 0.569425, 0.629477, 0.670782, 0.683966, 0.659656, 0.588477, 0.461058, 0.268023, 0],
    # Constant presets: keep Primary everywhere / take Secondary everywhere.
    "SDXL_ALL_A": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    "SDXL_ALL_B": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
}
65 |
--------------------------------------------------------------------------------
/merge_rebasin.py:
--------------------------------------------------------------------------------
1 | # https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion
2 | from collections import defaultdict
3 | from random import shuffle
4 | from typing import NamedTuple
5 | import torch
6 | from scipy.optimize import linear_sum_assignment
7 |
8 |
# State-dict keys for the final normalisation/output layers of the VAE and
# UNet. NOTE(review): not referenced anywhere in this module — presumably
# consumed by the merge driver to handle these layers outside the permutation
# machinery; confirm against merge.py before removing.
SPECIAL_KEYS = [
    "first_stage_model.decoder.norm_out.weight",
    "first_stage_model.decoder.norm_out.bias",
    "first_stage_model.encoder.norm_out.weight",
    "first_stage_model.encoder.norm_out.bias",
    "model.diffusion_model.out.0.weight",
    "model.diffusion_model.out.0.bias",
]
17 |
18 |
class PermutationSpec(NamedTuple):
    """Bidirectional bookkeeping for weight-matching permutations.

    perm_to_axes: maps a permutation name to the list of (weight_key, axis)
        pairs that permutation acts on.
    axes_to_perm: maps a weight key to a sequence of permutation names, one
        per tensor axis (None for axes with no associated permutation).
    """
    perm_to_axes: dict
    axes_to_perm: dict
22 |
23 |
def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:
    """Build a PermutationSpec by inverting an axes->permutation mapping.

    For every weight key, each axis annotated with a permutation name is
    recorded under that name in the inverse map; axes marked None are ignored.
    """
    inverse = {}
    for weight_key, per_axis_perms in axes_to_perm.items():
        for axis_index, perm_name in enumerate(per_axis_perms):
            if perm_name is None:
                continue
            inverse.setdefault(perm_name, []).append((weight_key, axis_index))
    return PermutationSpec(perm_to_axes=inverse, axes_to_perm=axes_to_perm)
31 |
32 |
def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None):
    """Return parameter `k` from `params` with the relevant permutations applied.

    Returns None when `k` is missing from `params`, and the unpermuted tensor
    when `k` has no entry in the spec. The axis `except_axis` is left
    untouched, as is any axis whose permutation entry is falsy or whose
    permutation name is absent from `perm`.
    """
    if k not in params:
        # Key absent from the state dict: nothing to permute.
        return None
    w = params[k]

    try:
        per_axis_perms = ps.axes_to_perm[k]
    except KeyError:
        # No permutation info for this key: hand back the raw parameter.
        return w

    for axis, perm_name in enumerate(per_axis_perms):
        # Leave the axis currently being solved for untouched.
        if axis == except_axis:
            continue
        # A falsy entry means no permutation applies to this axis.
        if not perm_name:
            continue
        # Silently skip permutations that were never initialised.
        if perm_name not in perm:
            continue
        w = torch.index_select(w, axis, perm[perm_name].int())

    return w
61 |
62 |
def apply_permutation(ps: PermutationSpec, perm, params):
    """Return a new dict with `perm` applied to every parameter in `params`."""
    permuted = {}
    for key in params:
        permuted[key] = get_permuted_param(ps, perm, key, params)
    return permuted
66 |
67 |
def update_model_a(ps: PermutationSpec, perm, model_a, new_alpha):
    """Blend each parameter of `model_a` with its permuted self, in place.

    Tensors whose shapes do not line up (e.g. pix2pix / inpainting extra
    channels) make torch raise RuntimeError and are left untouched.
    """
    for key in model_a:
        try:
            permuted = get_permuted_param(ps, perm, key, model_a)
            model_a[key] = model_a[key] * (1 - new_alpha) + new_alpha * permuted
        except RuntimeError:
            # Shape mismatch (pix2pix / inpainting models): skip this tensor.
            continue
    return model_a
78 |
79 |
def inner_matching(
    n,
    ps,
    p,
    params_a,
    params_b,
    usefp16,
    progress,
    number,
    linear_sum,
    perm,
    device,
):
    """Run one linear-assignment step for permutation `p` of size `n`.

    Builds the n x n correlation matrix between model A and the (currently
    permuted) model B over every (weight, axis) pair tied to `p`, solves the
    assignment problem to maximise the correlation, and stores the new
    permutation in `perm[p]`.

    Returns the updated ``(linear_sum, number, perm, progress)`` accumulators;
    `progress` becomes True when the new assignment strictly improves on the
    previous one.
    """
    A = torch.zeros((n, n), dtype=torch.float16) if usefp16 else torch.zeros((n, n))
    A = A.to(device)

    for wk, axis in ps.perm_to_axes[p]:
        w_a = params_a[wk]
        # Model B is permuted on every axis except the one being solved for.
        w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
        w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1)).to(device)
        w_b = torch.moveaxis(w_b, axis, 0).reshape((n, -1)).T.to(device)

        if usefp16:
            w_a = w_a.half().to(device)
            w_b = w_b.half().to(device)

        try:
            A += torch.matmul(w_a, w_b)
        except RuntimeError:
            # Quantized tensors cannot be matmul'd directly; dequantize first.
            A += torch.matmul(torch.dequantize(w_a), torch.dequantize(w_b))

    # scipy's solver needs a CPU numpy array.
    A = A.cpu()
    ri, ci = linear_sum_assignment(A.detach().numpy(), maximize=True)
    A = A.to(device)

    assert (torch.tensor(ri) == torch.arange(len(ri))).all()

    eye_tensor = torch.eye(n).to(device)

    # Objective value under the previous assignment vs. the new one.
    oldL = torch.vdot(
        torch.flatten(A).float(), torch.flatten(eye_tensor[perm[p].long()])
    )
    newL = torch.vdot(torch.flatten(A).float(), torch.flatten(eye_tensor[ci, :]))

    if usefp16:
        oldL = oldL.half()
        newL = newL.half()

    if newL - oldL != 0:
        linear_sum += abs((newL - oldL).item())
        number += 1
        # BUG FIX: the original string lacked the f-prefix, so the literal
        # text "{p}={newL-oldL}" was printed instead of the actual values.
        print(f"Merge Rebasin permutation: {p}={newL - oldL}")

    progress = progress or newL > oldL + 1e-12

    perm[p] = torch.Tensor(ci).to(device)

    return linear_sum, number, perm, progress
138 |
139 |
def weight_matching(
    ps: PermutationSpec,
    params_a,
    params_b,
    max_iter=1,
    init_perm=None,
    usefp16=False,
    device="cpu",
):
    """Iteratively solve for permutations aligning `params_b` with `params_a`.

    Returns (perm, average): the per-permutation index tensors and the
    average absolute objective improvement per accepted update.
    """
    # Size of each permutation = length of the first axis it acts on.
    perm_sizes = {
        p: params_a[axes[0][0]].shape[axes[0][1]]
        for p, axes in ps.perm_to_axes.items()
        if axes[0][0] in params_a.keys()
    }
    perm = {}
    # Start from the identity permutation unless one is supplied.
    perm = (
        {p: torch.arange(n).to(device) for p, n in perm_sizes.items()}
        if init_perm is None
        else init_perm
    )

    linear_sum = 0
    number = 0

    # NOTE(review): only the hard-coded "P_bg324" permutation is optimised,
    # and shuffling a one-element list is a no-op. Presumably this is the
    # single permutation group emitted by the PermSpec builders — confirm
    # against merge_PermSpec.py before generalising.
    special_layers = ["P_bg324"]
    for _i in range(max_iter):
        progress = False
        shuffle(special_layers)
        for p in special_layers:
            n = perm_sizes[p]
            linear_sum, number, perm, progress = inner_matching(
                n,
                ps,
                p,
                params_a,
                params_b,
                usefp16,
                progress,
                number,
                linear_sum,
                perm,
                device,
            )
        # NOTE(review): progress is unconditionally forced True here, so the
        # early-exit below can never trigger and the loop always runs
        # max_iter times — confirm whether this override is intentional.
        progress = True
        if not progress:
            break

    average = linear_sum / number if number > 0 else 0
    return perm, average
189 |
--------------------------------------------------------------------------------
/merge_utils.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import re
3 | from . import merge_methods
4 | from .merge_presets import BLOCK_WEIGHTS_PRESETS, SDXL_BLOCK_WEIGHTS_PRESETS
5 |
6 | ALL_PRESETS = {}
7 | ALL_PRESETS.update(BLOCK_WEIGHTS_PRESETS)
8 | ALL_PRESETS.update(SDXL_BLOCK_WEIGHTS_PRESETS)
9 |
10 | MERGE_METHODS = dict(inspect.getmembers(merge_methods, inspect.isfunction))
11 | BETA_METHODS = [
12 | name
13 | for name, fn in MERGE_METHODS.items()
14 | if "beta" in inspect.getfullargspec(fn)[0]
15 | ]
16 | TRIPLE_METHODS = [
17 | name
18 | for name, fn in MERGE_METHODS.items()
19 | if "c" in inspect.getfullargspec(fn)[0]
20 | ]
21 |
22 |
def interpolate(values, interp_lambda):
    """Element-wise linear blend of values[0] and values[1].

    interp_lambda == 0 returns values[0]; interp_lambda == 1 returns values[1].
    """
    return [
        (1 - interp_lambda) * values[0][i] + interp_lambda * values[1][i]
        for i in range(len(values[0]))
    ]
28 |
29 |
class WeightClass:
    """Resolves per-block merge ratios (alpha/beta) for each state-dict key.

    Scalar, preset-name or comma-separated-string ratios are expanded into
    one value per UNet block; __call__ then picks the value for the block a
    given key belongs to. Also rescales ratios across re-basin iterations.
    """
    def __init__(self,
            model_a,
            **kwargs,
        ):
        # SDXL checkpoints carry ten transformer blocks in middle_block.1;
        # probing for block index 9 distinguishes SDXL from SD1.x.
        self.SDXL = "model.diffusion_model.middle_block.1.transformer_blocks.9.norm3.weight" in model_a.keys()
        self.NUM_INPUT_BLOCKS = 12 if not self.SDXL else 9
        self.NUM_MID_BLOCK = 1
        self.NUM_OUTPUT_BLOCKS = 12 if not self.SDXL else 9
        self.NUM_TOTAL_BLOCKS = self.NUM_INPUT_BLOCKS + self.NUM_MID_BLOCK + self.NUM_OUTPUT_BLOCKS
        # Re-basin iteration state, consumed by step_weights_and_bases().
        self.iterations = kwargs.get("re_basin_iterations", 1)
        self.it = 0
        self.re_basin = kwargs.get("re_basin", False)
        self.ratioDict = {}
        # Wrap scalar alpha/beta in a list so the preset/interpolation logic
        # below can treat every entry uniformly.
        for key, value in kwargs.items():
            if isinstance(value, list) or (key.lower() not in ["alpha", "beta"]):
                self.ratioDict[key.lower()] = value
            else:
                self.ratioDict[key.lower()] = [value]

        for key, value in self.ratioDict.items():
            if key in ["alpha", "beta"]:
                for i, v in enumerate(value):
                    # A string entry is either a preset name or comma-separated floats.
                    # NOTE(review): only the SD1 preset table is consulted here,
                    # even for SDXL models — presumably SDXL preset names are
                    # resolved earlier by the caller; confirm.
                    if isinstance(v, str) and v.upper() in BLOCK_WEIGHTS_PRESETS.keys():
                        value[i] = BLOCK_WEIGHTS_PRESETS[v.upper()]
                    else:
                        value[i] = [float(x) for x in v.split(",")] if isinstance(v, str) else v
                    # Broadcast a scalar ratio to one value per block (+1 spare).
                    if not isinstance(value[i], list):
                        value[i] = [value[i]] * (self.NUM_TOTAL_BLOCKS + 1)
                # Two block lists -> blend them using the "<key>_lambda" factor.
                if len(value) > 1 and isinstance(value[0], list):
                    self.ratioDict[key] = interpolate(value, self.ratioDict.get(key + "_lambda", 0))
                else:
                    self.ratioDict[key] = self.ratioDict[key][0]

    def __call__(self, key, it=0):
        """Return {"alpha": ..., "beta": ...} scalars for the block owning `key`.

        NOTE(review): the `it` parameter is unused — iteration state comes
        from self.it via set_it(); confirm before removing.
        """
        current_bases = {}
        if "alpha" in self.ratioDict:
            current_bases["alpha"] = self.step_weights_and_bases(self.ratioDict["alpha"])
        if "beta" in self.ratioDict:
            current_bases["beta"] = self.step_weights_and_bases(self.ratioDict["beta"])

        # Non-UNet keys (VAE, CLIP, ...) fall through with block index 0.
        weight_index = 0
        if "model" in key:

            if "model.diffusion_model." in key:
                weight_index = -1

            re_inp = re.compile(r"\.input_blocks\.(\d+)\.")  # 12 (9 on SDXL)
            re_mid = re.compile(r"\.middle_block\.(\d+)\.")  # 1
            re_out = re.compile(r"\.output_blocks\.(\d+)\.")  # 12 (9 on SDXL)

            if "time_embed" in key:
                weight_index = 0  # shares the first input block's weight
            elif ".out." in key:
                weight_index = self.NUM_TOTAL_BLOCKS - 1  # shares the last output block's weight
            elif m := re_inp.search(key):
                weight_index = int(m.groups()[0])
            elif re_mid.search(key):
                weight_index = self.NUM_INPUT_BLOCKS
            elif m := re_out.search(key):
                weight_index = self.NUM_INPUT_BLOCKS + self.NUM_MID_BLOCK + int(m.groups()[0])

            # A diffusion_model key matching none of the patterns keeps -1,
            # which indexes the final entry of each ratio list.
            if weight_index >= self.NUM_TOTAL_BLOCKS:
                raise ValueError(f"illegal block index {key}")

        current_bases = {k: w[weight_index] for k, w in current_bases.items()}
        return current_bases

    def step_weights_and_bases(self, ratio):
        """Rescale `ratio` so repeated re-basin steps compose to the target blend."""
        if not self.re_basin:
            return ratio

        # Iteration 0 applies ratio/iterations; later iterations compensate
        # for the blending already applied by previous steps.
        new_ratio = [
            1 - (1 - (1 + self.it) * v / self.iterations) / (1 - self.it * v / self.iterations)
            if self.it > 0
            else v / self.iterations
            for v in ratio
        ]
        return new_ratio

    def set_it(self, it):
        # Advance the re-basin iteration counter used by step_weights_and_bases.
        self.it = it
112 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "comfyui-technodes"
3 | description = "ComfyUI nodes for merging, testing and more. SDNext Merge, VAE Merge, MBW Layers, Repeat VAE, Quantization."
4 | version = "1.0.0"
5 | license = {file = "LICENSE"}
6 |
7 | [project.urls]
8 | Repository = "https://github.com/TechnoByteJS/ComfyUI-TechNodes"
9 | # Used by Comfy Registry https://comfyregistry.org
10 |
11 | [tool.comfy]
12 | PublisherId = "technobyte"
13 | DisplayName = "ComfyUI-TechNodes"
14 | Icon = ""
15 |
--------------------------------------------------------------------------------
/quant_nodes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import copy
3 |
4 | import folder_paths
5 |
6 | import comfy_extras.nodes_model_merging
7 |
def quantize_tensor(tensor, num_bits=8, dtype=torch.float16, dequant=True):
    """
    Quantizes a tensor to a specified number of bits.

    Args:
        tensor (torch.Tensor): The input tensor to be quantized.
        num_bits (int): The number of bits to use for quantization (default: 8).
        dtype (torch.dtype): The datatype to use for the output (default: torch.float16).
        dequant (bool): When True, map the quantized levels back to the
            original value range; when False, return the raw integer levels
            in [0, 2**num_bits - 1] (cast to `dtype`).

    Returns:
        torch.Tensor: The quantized (or quantize-dequantized) tensor.
    """
    # Determine the minimum and maximum values of the tensor
    min_val = tensor.min()
    max_val = tensor.max()

    # Quantization grid: qmin..qmax integer levels
    qmin = 0
    qmax = 2 ** num_bits - 1

    # BUG FIX: a constant tensor (max == min) made `scale` zero, and the
    # divisions below produced inf/nan. A constant tensor occupies a single
    # quantization level, so handle it directly.
    if max_val == min_val:
        return tensor.to(dtype) if dequant else torch.zeros_like(tensor, dtype=dtype)

    # Calculate the scale factor and zero point
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = qmin - torch.round(min_val / scale)

    # Quantize the tensor: snap to the nearest level and clamp to the range
    quantized_tensor = torch.round(tensor / scale + zero_point)
    quantized_tensor = torch.clamp(quantized_tensor, qmin, qmax)

    # Convert the quantized tensor to the requested datatype
    dequantized_tensor = quantized_tensor.to(dtype)

    if dequant:
        # De-quantize: map the integer levels back to the original range
        dequantized_tensor = (dequantized_tensor - zero_point) * scale

    return dequantized_tensor
43 |
def quantize_model(model, in_bits, mid_bits, out_bits, dtype=torch.float16, dequant=True):
    """Quantize a model's UNet weights on a cloned patcher.

    Input, middle and output blocks get their own bit depths; every other
    diffusion-model tensor is quantized at 8 bits.
    """
    patched = model.clone()

    # Collect the diffusion-model weights once, then patch each key.
    key_patches = patched.get_key_patches("diffusion_model.")

    for key in key_patches:
        # Pick the bit depth from the UNet section this key belongs to.
        if ".input_" in key:
            bits = in_bits
        elif ".middle_" in key:
            bits = mid_bits
        elif ".output_" in key:
            bits = out_bits
        else:
            bits = 8

        replacement = quantize_tensor(key_patches[key][0], bits, dtype, dequant)
        patched.add_patches({key: (replacement,)}, 1, 0)

    return patched
67 |
def quantize_clip(clip, bits, dtype=torch.float16, dequant=True):
    """Quantize every weight of a CLIP model to `bits` bits on a clone."""
    patched = clip.clone()

    # Every CLIP tensor uses the same bit depth.
    key_patches = patched.get_key_patches()

    for key in key_patches:
        replacement = quantize_tensor(key_patches[key][0], bits, dtype, dequant)
        patched.add_patches({key: (replacement,)}, 1, 0)

    return patched
82 |
def quantize_vae(vae, bits, dtype=torch.float16, dequant=True):
    """Return a deep copy of `vae` with every first-stage weight quantized."""
    duplicate = copy.deepcopy(vae)

    # Quantize every tensor of the copied state dict.
    quantized_sd = {
        name: quantize_tensor(value, bits, dtype, dequant)
        for name, value in duplicate.first_stage_model.state_dict().items()
    }

    # Write the quantized weights back into the copy.
    duplicate.first_stage_model.load_state_dict(quantized_sd)

    return duplicate
99 |
class ModelQuant:
    """ComfyUI node: quantize a UNet with separate bit depths per section."""

    @classmethod
    def INPUT_TYPES(cls):
        # All three sections share the same 1..8 bit integer widget.
        bits_widget = ("INT", {"default": 8, "min": 1, "max": 8})
        return {
            "required": {
                "model": ["MODEL"],
                "in_bits": bits_widget,
                "mid_bits": bits_widget,
                "out_bits": bits_widget,
            }
        }

    RETURN_TYPES = ["MODEL"]
    FUNCTION = "quant_model"

    CATEGORY = "TechNodes/quantization"

    def quant_model(self, model, in_bits, mid_bits, out_bits):
        """Quantize the model and wrap it in ComfyUI's one-element return list."""
        quantized = quantize_model(model, in_bits, mid_bits, out_bits)
        return [quantized]
119 |
120 |
class ClipQuant:
    """ComfyUI node: quantize every CLIP weight to a chosen bit depth."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "clip": ["CLIP"],
                "bits": ("INT", {"default": 8, "min": 1, "max": 8}),
            }
        }

    RETURN_TYPES = ["CLIP"]
    FUNCTION = "quant_clip"

    CATEGORY = "TechNodes/quantization"

    def quant_clip(self, clip, bits):
        """Quantize the CLIP model and wrap it in ComfyUI's return list."""
        quantized = quantize_clip(clip, bits)
        return [quantized]
138 |
139 |
class VAEQuant:
    """ComfyUI node: quantize every first-stage VAE weight."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "vae": ["VAE"],
                "bits": ("INT", {"default": 8, "min": 1, "max": 8}),
            }
        }

    RETURN_TYPES = ["VAE"]
    FUNCTION = "quant_vae"

    CATEGORY = "TechNodes/quantization"

    def quant_vae(self, vae, bits):
        """Quantize the VAE and wrap it in ComfyUI's return list."""
        quantized = quantize_vae(vae, bits)
        return [quantized]
157 |
--------------------------------------------------------------------------------
/sdnextmerge_nodes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import folder_paths
3 | from typing import Dict, Tuple, List
4 | from collections import OrderedDict
5 | import ast
6 |
7 | import comfy.sd
8 | import comfy.utils
9 | import comfy.model_detection
10 |
11 | from .merge import *
12 |
# ComfyUI combo-widget spec (choices, config) listing every merge-block-weight
# preset selectable for alpha; "none" disables preset selection.
# NOTE(review): these names are expected to match keys in
# merge_presets.BLOCK_WEIGHTS_PRESETS / SDXL_BLOCK_WEIGHTS_PRESETS — keep the
# two lists in sync when adding entries.
mbw_presets = ([
    "none",
    "GRAD_V",
    "GRAD_A",
    "FLAT_25",
    "FLAT_75",
    "WRAP08",
    "WRAP12",
    "WRAP14",
    "WRAP16",
    "MID12_50",
    "OUT07",
    "OUT12",
    "OUT12_5",
    "RING08_SOFT",
    "RING08_5",
    "RING10_5",
    "RING10_3",
    "SMOOTHSTEP",
    "REVERSE_SMOOTHSTEP",
    "2SMOOTHSTEP",
    "2R_SMOOTHSTEP",
    "3SMOOTHSTEP",
    "3R_SMOOTHSTEP",
    "4SMOOTHSTEP",
    "4R_SMOOTHSTEP",
    "HALF_SMOOTHSTEP",
    "HALF_R_SMOOTHSTEP",
    "ONE_THIRD_SMOOTHSTEP",
    "ONE_THIRD_R_SMOOTHSTEP",
    "ONE_FOURTH_SMOOTHSTEP",
    "ONE_FOURTH_R_SMOOTHSTEP",
    "COSINE",
    "REVERSE_COSINE",
    "CUBIC_HERMITE",
    "REVERSE_CUBIC_HERMITE",
    "FAKE_REVERSE_CUBIC_HERMITE",
    "LOW_OFFSET_CUBIC_HERMITE",
    "ALL_A",
    "ALL_B",
], {"default": "none"})
54 |
class SDNextMerge:
    """ComfyUI node wrapping the SD-Next model merger.

    Each of the three models may come either from a checkpoint file on disk
    (model_a/b/c combo inputs) or from already-loaded MODEL/CLIP inputs; the
    merged state dict is re-assembled into a ModelPatcher and a CLIP object.
    """
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "optional": {
                "optional_model_a": ["MODEL"],
                "optional_clip_a": ["CLIP"],

                "optional_model_b": ["MODEL"],
                "optional_clip_b": ["CLIP"],

                "optional_model_c": ["MODEL"],
                "optional_clip_c": ["CLIP"],

                "optional_mbw_layers_alpha": ["MBW_LAYERS"],
            },
            "required": {
                "model_a": (["none"] + folder_paths.get_filename_list("checkpoints"), {"multiline": False}),
                "model_b": (["none"] + folder_paths.get_filename_list("checkpoints"), {"multiline": False}),
                "model_c": (["none"] + folder_paths.get_filename_list("checkpoints"), {"multiline": False}),
                "merge_mode": ([
                    "weighted_sum",
                    "weighted_subtraction",
                    "tensor_sum",
                    "add_difference",
                    "train_difference",
                    "sum_twice",
                    "triple_sum",
                    "euclidean_add_difference",
                    "multiply_difference",
                    "top_k_tensor_sum",
                    "similarity_add_difference",
                    "distribution_crossover",
                    "ties_add_difference",
                ],),
                "precision": (["fp16", "original"],),
                "weights_clip": ("BOOLEAN", {"default": True}),
                "mem_device": (["cuda", "cpu"],),
                "work_device": (["cuda", "cpu"],),
                "threads": ("INT", {"default": 4, "min": 1, "max": 24}),
                "mbw_preset_alpha": mbw_presets,
                "alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "beta": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}),
                "re_basin": ("BOOLEAN", {"default": False}),
                "re_basin_iterations": ("INT", {"default": 5, "min": 1, "max": 25})
            }
        }

    RETURN_TYPES = ["MODEL", "CLIP"]
    FUNCTION = "merge"

    CATEGORY = "TechNodes/merging"

    # The main merge function
    def merge(self, model_a, model_b, model_c, merge_mode, precision, weights_clip, mem_device, work_device, threads, mbw_preset_alpha, alpha, beta, re_basin, re_basin_iterations, optional_model_a = None, optional_clip_a = None, optional_model_b = None, optional_clip_b = None, optional_model_c = None, optional_clip_c = None, optional_mbw_layers_alpha = None):
        """Merge models A and B (and C when the mode needs it).

        Returns (ModelPatcher, CLIP) rebuilt from the merged state dict.
        Raises ValueError when a required model/clip is available neither as
        a checkpoint filename nor as a loaded input.
        """
        # Validate that model and clip A/B are available from some source.
        if model_a == "none" and optional_model_a is None:
            raise ValueError("Need either model_a or optional_model_a!")

        if model_b == "none" and optional_model_b is None:
            raise ValueError("Need either model_b or optional_model_b!")

        if model_a == "none" and optional_clip_a is None:
            raise ValueError("Need either model_a or optional_clip_a!")

        if model_b == "none" and optional_clip_b is None:
            raise ValueError("Need either model_b or optional_clip_b!")

        models = { }

        # Resolve checkpoint paths only when the loaded-model inputs do not
        # already cover both the UNet and the CLIP.
        if model_a != "none":
            if optional_model_a is None or optional_clip_a is None:
                models['model_a'] = folder_paths.get_full_path("checkpoints", model_a)

        if model_b != "none":
            if optional_model_b is None or optional_clip_b is None:
                models['model_b'] = folder_paths.get_full_path("checkpoints", model_b)

        # Add model C if the merge method needs it
        if merge_mode in ["add_difference", "train_difference", "sum_twice", "triple_sum", "euclidean_add_difference", "multiply_difference", "similarity_add_difference", "distribution_crossover", "ties_add_difference"]:
            if model_c == "none" and optional_model_c is None:
                raise ValueError("Need either model_c or optional_model_c!")

            if model_c == "none" and optional_clip_c is None:
                raise ValueError("Need either model_c or optional_clip_c!")

            if model_c != "none":
                if optional_model_c is None or optional_clip_c is None:
                    models['model_c'] = folder_paths.get_full_path("checkpoints", model_c)

        # Devices: `device` holds weights between steps, `work_device` computes.
        device = torch.device(mem_device)
        work_device = torch.device(work_device)

        # Merge Arguments
        kwargs = {
            'alpha': alpha,
            'beta': beta,
            're_basin': re_basin,
            're_basin_iterations': re_basin_iterations
        }

        # If a MBW alpha preset is selected replace the alpha with the preset
        if mbw_preset_alpha != "none":
            kwargs["alpha"] = [ mbw_preset_alpha ]

        # A connected MBW layers input overrides both the scalar alpha and
        # the preset selected above.
        if optional_mbw_layers_alpha is not None:
            kwargs["alpha"] = [ optional_mbw_layers_alpha ]

        # Merge the model
        merged_model = merge_models(models, merge_mode, precision, weights_clip, device, work_device, True, threads, optional_model_a, optional_clip_a, optional_model_b, optional_clip_b, optional_model_c, optional_clip_c, **kwargs)

        # Get the config and components from the merged model
        model_config = comfy.model_detection.model_config_from_unet(merged_model, "model.diffusion_model.")

        # Create UNet
        unet = model_config.get_model(merged_model, "model.diffusion_model.", device=device)
        unet.load_model_weights(merged_model, "model.diffusion_model.")

        # Create ModelPatcher
        # NOTE(review): comfy.model_patcher / comfy.model_management are used
        # without an explicit import here — presumably pulled in as a side
        # effect of `import comfy.sd`; confirm.
        model_patcher = comfy.model_patcher.ModelPatcher(
            unet,
            load_device=comfy.model_management.get_torch_device(),
            offload_device=comfy.model_management.unet_offload_device()
        )

        # Create CLIP
        clip_sd = model_config.process_clip_state_dict(merged_model)
        clip = comfy.sd.CLIP(model_config.clip_target(), embedding_directory=None)
        clip.load_sd(clip_sd, full_model=True)

        return (model_patcher, clip)
188 |
class SD1_MBWLayers:
    """ComfyUI node exposing one float weight per SD1.x UNet block
    (12 input + 1 middle + 12 output = 25 sliders) as MBW_LAYERS."""

    @classmethod
    def INPUT_TYPES(cls) -> Dict[str, dict]:
        """Declare one FLOAT slider per UNet block, in UNet order."""
        arg_dict = {}

        argument = ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01})

        for i in range(12):
            arg_dict[f"input_blocks.{i}"] = argument

        # Plain string (the original used a pointless f-string here).
        arg_dict["middle_blocks"] = argument

        for i in range(12):
            arg_dict[f"output_blocks.{i}"] = argument

        return {"required": arg_dict}

    RETURN_TYPES = ["MBW_LAYERS"]
    FUNCTION = "return_layers"
    CATEGORY = "TechNodes/merging"

    def return_layers(self, **inputs) -> List[List[float]]:
        """Collect slider values, in declaration order, as one layer list.

        BUG FIX: the annotations previously claimed Dict return types; the
        node actually returns a one-element list (ComfyUI tuple convention).
        """
        return [list(inputs.values())]
212 |
class SD1_MBWLayers_Binary:
    """ComfyUI node exposing one on/off toggle per SD1.x UNet block
    (12 input + 1 middle + 12 output = 25), emitted as 0/1 MBW_LAYERS."""

    @classmethod
    def INPUT_TYPES(cls) -> Dict[str, dict]:
        """Declare one BOOLEAN toggle per UNet block, in UNet order."""
        arg_dict = {}

        argument = ("BOOLEAN", {"default": False})

        for i in range(12):
            arg_dict[f"input_blocks.{i}"] = argument

        # Plain string (the original used a pointless f-string here).
        arg_dict["middle_blocks"] = argument

        for i in range(12):
            arg_dict[f"output_blocks.{i}"] = argument

        return {"required": arg_dict}

    RETURN_TYPES = ["MBW_LAYERS"]
    FUNCTION = "return_layers"
    CATEGORY = "TechNodes/merging"

    def return_layers(self, **inputs) -> List[List[int]]:
        """Convert toggles to 0/1 ints, in declaration order, as one layer list.

        BUG FIX: the annotations previously claimed Dict return types; the
        node actually returns a one-element list (ComfyUI tuple convention).
        """
        return [[int(value) for value in inputs.values()]]
236 |
class SDXL_MBWLayers:
    """ComfyUI node exposing one float weight per SDXL UNet block
    (9 input + 1 middle + 9 output = 19 sliders) as MBW_LAYERS."""

    @classmethod
    def INPUT_TYPES(cls) -> Dict[str, dict]:
        """Declare one FLOAT slider per UNet block, in UNet order."""
        arg_dict = {}

        argument = ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01})

        for i in range(9):
            arg_dict[f"input_blocks.{i}"] = argument

        # Plain string (the original used a pointless f-string here).
        arg_dict["middle_blocks"] = argument

        for i in range(9):
            arg_dict[f"output_blocks.{i}"] = argument

        return {"required": arg_dict}

    RETURN_TYPES = ["MBW_LAYERS"]
    FUNCTION = "return_layers"
    CATEGORY = "TechNodes/merging"

    def return_layers(self, **inputs) -> List[List[float]]:
        """Collect slider values, in declaration order, as one layer list.

        BUG FIX: the annotations previously claimed Dict return types; the
        node actually returns a one-element list (ComfyUI tuple convention).
        """
        return [list(inputs.values())]
260 |
class SDXL_MBWLayers_Binary:
    """ComfyUI node exposing one on/off toggle per SDXL UNet block
    (9 input + 1 middle + 9 output = 19), emitted as 0/1 MBW_LAYERS."""

    @classmethod
    def INPUT_TYPES(cls) -> Dict[str, dict]:
        """Declare one BOOLEAN toggle per UNet block, in UNet order."""
        arg_dict = {}

        argument = ("BOOLEAN", {"default": False})

        for i in range(9):
            arg_dict[f"input_blocks.{i}"] = argument

        # Plain string (the original used a pointless f-string here).
        arg_dict["middle_blocks"] = argument

        for i in range(9):
            arg_dict[f"output_blocks.{i}"] = argument

        return {"required": arg_dict}

    RETURN_TYPES = ["MBW_LAYERS"]
    FUNCTION = "return_layers"
    CATEGORY = "TechNodes/merging"

    def return_layers(self, **inputs) -> List[List[int]]:
        """Convert toggles to 0/1 ints, in declaration order, as one layer list.

        BUG FIX: the annotations previously claimed Dict return types; the
        node actually returns a one-element list (ComfyUI tuple convention).
        """
        return [[int(value) for value in inputs.values()]]
284 |
class MBWLayers_String:
    """ComfyUI node parsing a Python list literal string into MBW_LAYERS."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mbw_layers": ("STRING", {"multiline": True, "default": "[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]"} )
            }
        }

    RETURN_TYPES = ["MBW_LAYERS"]
    FUNCTION = "return_layers"
    CATEGORY = "TechNodes/merging"

    def return_layers(self, mbw_layers):
        """Safely evaluate the literal and wrap it in ComfyUI's return list."""
        parsed = ast.literal_eval(mbw_layers)
        return [parsed]
300 |
class VAERepeat:
    """ComfyUI node: round-trip images through VAE encode/decode `count` times."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ["IMAGE"],
                "vae": ["VAE"],
                "count": ["INT", {"default": 4, "min": 1, "max": 1000000}],
            }
        }
    RETURN_TYPES = ["IMAGE"]
    FUNCTION = "recode"

    CATEGORY = "TechNodes/latent"

    def recode(self, vae, images, count):
        """Repeatedly encode (first three channels only) and decode the batch."""
        for _ in range(count):
            samples = vae.encode(images[:, :, :, :3])
            images = vae.decode(samples)
        return [images]
--------------------------------------------------------------------------------
/vae_merge.py:
--------------------------------------------------------------------------------
1 | import os
2 | import folder_paths
3 |
4 | from tqdm import tqdm
5 | import torch
6 | import safetensors.torch
7 |
8 | import comfy.sd
9 | import comfy.utils
10 |
11 | from . import merge_methods
12 |
13 | from torch import nn
14 |
def merge_state_dict(sd_a, sd_b, sd_c, alpha, beta, weights, mode):
    """Merge state dict `sd_b` (and optionally `sd_c`) into `sd_a` in place.

    Args:
        sd_a, sd_b, sd_c: VAE state dicts; sd_c may be None for two-model modes.
        alpha, beta: global merge ratios.
        weights: optional {key_prefix: alpha} overrides; the longest matching
            prefix wins. None or {} disables per-block overrides.
        mode: name of a merge method (see the dispatch table below).

    Returns:
        sd_a, mutated with the merged tensors. An unrecognised `mode` leaves
        sd_a unchanged (same net effect as the original if/elif chain).
    """
    def get_alpha(key):
        # BUG FIX: the original used a bare `except:` that silently swallowed
        # every error (including KeyboardInterrupt); the lookup is explicit now.
        if not weights:
            return alpha
        matches = [prefix for prefix in weights if key.startswith(prefix)]
        if not matches:
            return alpha
        # The longest matching prefix is the most specific override.
        return weights[max(matches, key=len)]

    # Only merge keys present in every participating state dict.
    ckpt_keys = (
        sd_a.keys() & sd_b.keys()
        if sd_c is None
        else sd_a.keys() & sd_b.keys() & sd_c.keys()
    )

    # mode -> (merge method, needs model c, needs beta)
    dispatch = {
        "weighted_sum": (merge_methods.weighted_sum, False, False),
        "weighted_subtraction": (merge_methods.weighted_subtraction, False, True),
        "tensor_sum": (merge_methods.tensor_sum, False, True),
        "add_difference": (merge_methods.add_difference, True, False),
        "sum_twice": (merge_methods.sum_twice, True, True),
        "triple_sum": (merge_methods.triple_sum, True, True),
        "euclidean_add_difference": (merge_methods.euclidean_add_difference, True, False),
        "multiply_difference": (merge_methods.multiply_difference, True, True),
        "top_k_tensor_sum": (merge_methods.top_k_tensor_sum, False, True),
        "similarity_add_difference": (merge_methods.similarity_add_difference, True, True),
        "distribution_crossover": (merge_methods.distribution_crossover, True, True),
        "ties_add_difference": (merge_methods.ties_add_difference, True, True),
    }

    spec = dispatch.get(mode)
    if spec is None:
        return sd_a
    method, needs_c, needs_beta = spec
    if needs_c:
        # Hoisted out of the loop: the original asserted per key.
        assert sd_c is not None, "vae_c is undefined"

    for key in tqdm(ckpt_keys):
        current_alpha = get_alpha(key)

        call_kwargs = {"a": sd_a[key], "b": sd_b[key], "alpha": current_alpha}
        if needs_c:
            call_kwargs["c"] = sd_c[key]
        if needs_beta:
            call_kwargs["beta"] = beta
        sd_a[key] = method(**call_kwargs)

    return sd_a
70 |
class VAEMerge:
    """ComfyUI node: merge two (optionally three) VAEs with per-block weights
    and a post-merge brightness/contrast tweak of the decoder output layer."""
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "vae_a": ("VAE",),
                "vae_b": ("VAE",),
                "merge_mode": ([
                    "weighted_sum",
                    "weighted_subtraction",
                    "tensor_sum",
                    "add_difference",
                    "sum_twice",
                    "triple_sum",
                    "euclidean_add_difference",
                    "multiply_difference",
                    "top_k_tensor_sum",
                    "similarity_add_difference",
                    "distribution_crossover",
                    "ties_add_difference",
                ],),
                "alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "brightness": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.01}),
                "contrast": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.01}),
                "use_blocks": ("BOOLEAN", {"default": False}),
                "block_conv_out": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "block_norm_out": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "block_0": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "block_1": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "block_2": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "block_3": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "block_mid": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "block_conv_in": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "block_quant_conv": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
            "optional": {
                "vae_c": ("VAE",),
            }
        }

    RETURN_TYPES = ["VAE"]
    FUNCTION = "merge_vae"

    CATEGORY = "TechNodes/merging"

    def merge_vae(self, vae_a, vae_b, merge_mode, alpha, beta, brightness, contrast, use_blocks, block_conv_out, block_norm_out, block_0, block_1, block_2, block_3, block_mid, block_conv_in, block_quant_conv, vae_c=None):
        """Merge the first-stage state dicts and rebuild a comfy VAE object."""
        vae_a_model = vae_a.first_stage_model.state_dict()
        vae_b_model = vae_b.first_stage_model.state_dict()
        vae_c_model = None
        # Three-model modes require vae_c.
        # NOTE(review): if vae_c is left unconnected for one of these modes
        # this raises AttributeError on None — consider a clearer ValueError.
        if merge_mode in ["add_difference", "sum_twice", "triple_sum", "euclidean_add_difference", "multiply_difference", "similarity_add_difference", "distribution_crossover", "ties_add_difference"]:
            vae_c_model = vae_c.first_stage_model.state_dict()

        # Per-prefix alpha overrides: encoder and decoder share one slider
        # per depth (down.N pairs with up.N).
        weights = {
            'encoder.conv_out': block_conv_out,
            'encoder.norm_out': block_norm_out,
            'encoder.down.0': block_0,
            'encoder.down.1': block_1,
            'encoder.down.2': block_2,
            'encoder.down.3': block_3,
            'encoder.mid': block_mid,
            'encoder.conv_in': block_conv_in,
            'quant_conv': block_quant_conv,
            'decoder.conv_out': block_conv_out,
            'decoder.norm_out': block_norm_out,
            'decoder.up.0': block_0,
            'decoder.up.1': block_1,
            'decoder.up.2': block_2,
            'decoder.up.3': block_3,
            'decoder.mid': block_mid,
            'decoder.conv_in': block_conv_in,
            'post_quant_conv': block_quant_conv
        }

        # Empty weights dict disables per-block overrides (global alpha only).
        if(not use_blocks):
            weights = {}

        merged_vae = merge_state_dict(vae_a_model, vae_b_model, vae_c_model, alpha, beta, weights, mode=merge_mode)

        # Shift the decoder output bias to brighten/darken decoded images.
        merged_vae["decoder.conv_out.bias"] = nn.Parameter(merged_vae["decoder.conv_out.bias"] + brightness)

        # Offset the decoder output weights to alter contrast; the /40 divisor
        # is an empirical scaling — presumably tuned by eye, confirm.
        merged_vae["decoder.conv_out.weight"] = nn.Parameter(merged_vae["decoder.conv_out.weight"] + contrast / 40)

        comfy_vae = comfy.sd.VAE(merged_vae)

        return (comfy_vae,)
--------------------------------------------------------------------------------