├── .github └── workflows │ └── publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── assets ├── adapter.png ├── basic_conversion.png ├── checkpoint_converter.png ├── conversion+lcm_lora.png ├── conversion+lora.png ├── lcm+controlnet.png ├── lcm_converter.png ├── loader+lcm_lora.png ├── lora_loader.png ├── sampler.png ├── sdxl_conversion.png ├── snake.png ├── unet+sampler+checkpoint.png ├── unet+sampler+clip+vae.png ├── unet+sampler+controlnet.png └── unet_loader.png ├── coreml_suite ├── __init__.py ├── config.py ├── controlnet.py ├── converter.py ├── latents.py ├── lcm │ ├── __init__.py │ ├── converter.py │ ├── nodes.py │ ├── unet.py │ └── utils.py ├── logger.py ├── models.py └── nodes.py ├── pyproject.toml ├── requirements.txt └── tests ├── integration ├── __init__.py ├── test_basic_conversion_1_5.py └── workflows │ └── e2e-1.5-basic-conversion.json └── unit ├── __init__.py ├── test_chunks.py └── test_controlnet.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | permissions: 11 | issues: write 12 | 13 | jobs: 14 | publish-node: 15 | name: Publish Custom Node to registry 16 | runs-on: ubuntu-latest 17 | if: ${{ github.repository_owner == 'aszc-dev' }} 18 | steps: 19 | - name: Check out code 20 | uses: actions/checkout@v4 21 | - name: Publish Custom Node 22 | uses: Comfy-Org/publish-node-action@v1 23 | with: 24 | ## Add your own personal access token to your Github Repository secrets and reference it here. 25 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | playground/ 2 | experiments/ 3 | __pycache__/ 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 
28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 
91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 
150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 
216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. 
You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. 
Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. 
If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. 
If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 
633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | Copyright (C) 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Core ML Suite for ComfyUI 2 | 3 | ## Overview 4 | 5 | Welcome! In this repository you'll find a set of custom nodes for [ComfyUI](https://github.com/comfyanonymous/ComfyUI) 6 | that allows you to use Core ML models in your ComfyUI workflows. 7 | These models are designed to leverage the Apple Neural Engine (ANE) on Apple Silicon (M1/M2) machines, 8 | thereby enhancing your workflows and improving performance. 9 | 10 | If you're not sure how to obtain these models, you can download them 11 | [here](https://huggingface.co/coreml-community) or convert your own models using 12 | [coremltools](https://github.com/apple/ml-stable-diffusion). 13 | 14 | In simple terms, think of Core ML models as a tool that can help your ComfyUI work faster and more efficiently. 15 | For instance, during my tests on an M2 Pro 32GB machine, 16 | the use of Core ML models sped up the generation of 512x512 images by a factor 17 | of approximately 1.5 to 2 times. 18 | 19 | ## Getting Started 20 | 21 | To start using custom nodes in your ComfyUI, follow these simple steps: 22 | 23 | 1. Clone or download this repository: You can do this directly into the custom_nodes directory of your ComfyUI. 24 | 2. Install the dependencies: You'll need to use a package manager like pip to do this. 25 | 26 | That's it! 
You're now ready to start enhancing your ComfyUI workflows with Core ML models. 27 | 28 | - Check [Installation](#installation) for more details on installation. 29 | - Check [How to use](#how-to-use) for more details on how to use the custom nodes. 30 | - Check [Example Workflows](#example-workflows) for some example workflows. 31 | 32 | ## Glossary 33 | 34 | - **Core ML**: A machine learning framework developed by Apple. It's used to run machine learning models on Apple 35 | devices. 36 | - **Core ML Model**: A machine learning model that can be run on Apple devices using Core ML. 37 | - **mlmodelc**: A compiled Core ML model. This is the recommended format for Core ML models. 38 | - **mlpackage**: A Core ML model packaged in a directory. This is the default format for Core ML models. 39 | - **ANE**: Apple Neural Engine. A hardware accelerator for machine learning tasks on Apple devices. 40 | - **Compute Unit**: A Core ML option that allows you to specify the hardware on which the model should run. 41 | - **CPU_AND_ANE**: A Core ML compute unit option that allows the model to run on both the CPU and ANE. This is the 42 | default option. 43 | - **CPU_AND_GPU**: A Core ML compute unit option that allows the model to run on both the CPU and GPU. 44 | - **CPU_ONLY**: A Core ML compute unit option that allows the model to run on the CPU only. 45 | - **ALL**: A Core ML compute unit option that allows the model to run on all available hardware. 46 | - **CLIP**: Contrastive Language-Image Pre-training. A model that learns visual concepts from natural language 47 | supervision. It's used as a text encoder in Stable Diffusion. 48 | - **VAE**: Variational Autoencoder. A model that learns a latent representation of images. It's used as a prior in 49 | Stable Diffusion. 50 | - **Checkpoint**: A file that contains the weights of a model. It's used to load models in Stable Diffusion. 51 | - **LCM**: [Latent Consistency Model](https://latent-consistency-models.github.io/). A type of model designed to 52 | generate images with as few steps as possible. 53 | 54 | > [!NOTE] 55 | > Note on Compute Units: 56 | > For the model to run on the ANE, the model must be converted with the `--attention-implementation SPLIT_EINSUM` 57 | > option. 58 | > Models converted with `--attention-implementation ORIGINAL` will run on GPU instead of ANE. 59 | 60 | ## Features 61 | 62 | These custom nodes come with a host of features, including: 63 | 64 | - Loading Core ML Unet models 65 | - Support for ControlNet 66 | - Support for ANE (Apple Neural Engine) 67 | - Support for CPU and GPU 68 | - Support for `mlmodelc` and `mlpackage` files 69 | - Support for SDXL models 70 | - Support for LCM models 71 | - Support for LoRAs 72 | - SD1.5 -> Core ML conversion 73 | - SDXL -> Core ML conversion 74 | - LCM -> Core ML conversion 75 | 76 | > [!NOTE] 77 | > Please note that using Core ML models can take a bit longer to load initially. 78 | > For the best experience, I recommend using the compiled models 79 | > (.mlmodelc files) instead of the .mlpackage files. 80 | 81 | > [!NOTE] 82 | > This repository will continue to be updated with more nodes and features over time. 83 | 84 | ## Installation 85 | 86 | ### Using ComfyUI-Manager 87 | 88 | The easiest way to install the custom nodes is to use the ComfyUI-Manager. You can find the installation instructions 89 | [here](https://github.com/ltdrdata/ComfyUI-Manager#installation). 
Once you've installed the ComfyUI-Manager, you can 90 | install the custom nodes by following these steps: 91 | 92 | - Open the ComfyUI-Manager by clicking the `Manager` button in the ComfyUI toolbar. 93 | - Click the `Install Custom Nodes` button. 94 | - Search for `Core ML` and click the `Install` button. 95 | - Restart ComfyUI. 96 | 97 | ### Manual Installation 98 | 99 | 1. Clone this repository into the custom_nodes directory of your ComfyUI. If you're not sure how to do this, you can 100 | download the repository as a zip file and extract it into the same directory. 101 | ```bash 102 | cd /path/to/comfyui/custom_nodes 103 | git clone https://github.com/aszc-dev/ComfyUI-CoreMLSuite.git 104 | ``` 105 | 2. Next, install the required dependencies using pip or another package manager: 106 | 107 | ```bash 108 | cd /path/to/comfyui/custom_nodes/ComfyUI-CoreMLSuite 109 | pip install -r requirements.txt 110 | ``` 111 | 112 | ## How to use 113 | 114 | Once you've installed the custom nodes, you can start using them in your ComfyUI workflows. 115 | To do this, you need to add the nodes to your workflow. You can do this by right-clicking on the workflow canvas and 116 | selecting the nodes from the list of available nodes (the nodes are in the `Core ML Suite` category). 117 | You can also double-click the canvas and use the search bar to find the nodes. The list of available nodes is given 118 | below. 119 | 120 | ### Available Nodes 121 | 122 | #### Core ML UNet Loader (`CoreMLUnetLoader`) 123 | 124 | ![CoreMLUnetLoader](./assets/unet_loader.png?raw=true) 125 | 126 | This node allows you to load a Core ML UNet model and use it in your ComfyUI workflow. Place the converted 127 | .mlpackage or .mlmodelc file in ComfyUI's `models/unet` directory and use the node to load the model. The output of the 128 | node is a `coreml_model` object that can be used with the Core ML Sampler. 129 | 130 | - **Inputs**: 131 | - **model_name**: The name of the model to load. This should be the name of the .mlpackage or .mlmodelc file. 132 | - **compute_unit**: The hardware on which the model should run. This can be one of the following: 133 | - `CPU_AND_ANE`: The model will run on both the CPU and ANE. This is the default option. It works best with 134 | models 135 | converted with `--attention-implementation SPLIT_EINSUM` or `--attention-implementation SPLIT_EINSUM_V2`. 136 | - `CPU_AND_GPU`: The model will run on both the CPU and GPU. It works best with models converted with 137 | `--attention-implementation ORIGINAL`. 138 | - `CPU_ONLY`: The model will run on the CPU only. 139 | - `ALL`: The model will run on all available hardware. 140 | - **Outputs**: 141 | - **coreml_model**: A Core ML model that can be used with the Core ML Sampler. 142 | 143 | #### Core ML Sampler (`CoreMLSampler`) 144 | 145 | ![CoreMLSampler](./assets/sampler.png?raw=true) 146 | 147 | This node allows you to generate images using a Core ML model. The node takes a Core ML model as input and outputs a 148 | latent image similar to the latent image output by the KSampler. This means that you can use the 149 | resulting latent as you normally would in your workflow. 150 | 151 | - **Inputs**: 152 | - **coreml_model**: The Core ML model to use for sampling. This should be the output of the Core ML UNet Loader. 153 | - **latent_image** [optional]: The latent image to use for sampling. If provided, should be of the same size as the 154 | input of the Core ML model. If not provided, the node will create a latent suitable for the Core ML model used. 
155 | Useful in img2img workflows. 156 | - ... _(the rest of the inputs are the same as the KSampler)_ 157 | - **Outputs**: 158 | - **LATENT**: The latent image output by the Core ML model. This can be decoded using a VAE Decoder or used as input 159 | to the next node in your workflow. 160 | 161 | #### Checkpoint Converter 162 | 163 | ![CoreMLConverter](./assets/checkpoint_converter.png?raw=true) 164 | 165 | You can use this node to convert any **SD1.5** based checkpoint to a Core ML model. The converted model is stored in the 166 | `models/unet` directory and can be used with the `Core ML UNet Loader`. The conversion parameters are encoded in 167 | the model name, so if the model already exists, the node will not convert it again. 168 | 169 | - **Inputs**: 170 | - **ckpt_name**: The name of the checkpoint to convert. This should be the name of the checkpoint file stored in the 171 | `models/checkpoints` directory. 172 | - **model_version**: Whether the model is based on SD1.5 or SDXL. 173 | - **height**: The desired height of the image generated by the model. The default is 512. Must be a multiple of 8. 174 | - **width**: The desired width of the image generated by the model. The default is 512. Must be a multiple of 8. 175 | - **batch_size**: The batch size of generated images. If you're planning to generate batches of images, you can try 176 | increasing this value to speed up the generation process. The default is 1. 177 | - **attention_implementation**: The attention implementation used when converting the model. Choose SPLIT_EINSUM or 178 | SPLIT_EINSUM_V2 for better ANE support. Choose ORIGINAL for better GPU support. 179 | - **compute_unit**: The hardware on which the model should run. This is used only when loading the model and doesn't 180 | affect the conversion process. 181 | - **controlnet_support**: For the model to support ControlNet, it must be converted with this option set to True. 182 | The 183 | default is False. 184 | - **lora_params** [optional]: Optional LoRA names and weights. If provided, the model will be converted with LoRA(s) 185 | baked in. More on loading LoRAs below. 186 | - **Outputs**: 187 | - **coreml_model**: The converted Core ML model that can be used with Core ML Sampler. 188 | 189 | > [!NOTE] 190 | > Some models use a custom config .yaml file. If you're using such a model, you'll need to place the config file in the 191 | > `models/configs` directory. The config file should be named the same as the checkpoint file. For example, if the 192 | > checkpoint file is named `juggernaut_aftermath.safetensors`, the config file should be 193 | > named `juggernaut_aftermath.yaml`. 194 | > The config file will be automatically loaded during conversion. 195 | 196 | > [!NOTE] 197 | > For now, the converter relies heavily on the model name to determine the conversion parameters. This means that if 198 | > you change the model name, the node will convert the model again. Other than that, if you find the name too long or 199 | > confusing, you can change it to anything you want. 200 | 201 | #### LoRA Loader 202 | 203 | ![LoRALoader](./assets/lora_loader.png?raw=true) 204 | 205 | This node allows you to load LoRAs and bake them into a model. Since this is a workaround (as model weights can't be 206 | modified 207 | after conversion), there are a few caveats to keep in mind: 208 | 209 | - The LoRA weights and _strength_model_ parameter are baked into the model. This means that you can't change them 210 | after conversion.
This also means that you need to convert the model again if you want to change the LoRA weights. 211 | - Loading LoRA affects CLIP, which is not a part of the Core ML workflow, so you'll need to load CLIP separately, 212 | either using `CLIPLoader` or `CheckpointLoaderSimple`. (See [example workflows](#example-workflows) for more details.) 213 | - After conversion, if you want to load the model using `CoreMLUnetLoader`, you'll need to apply the same LoRAs to 214 | CLIP manually. (See [example workflows](#example-workflows) for more details.) 215 | - The LoRA names are encoded in the model name. This means that if you change the name of the LoRA file, 216 | you'll need to change the model name as well, or the node will convert the model again. (Model strength is not 217 | encoded, so if you want to change it, you'll need to delete the converted model manually.) 218 | - The _strength_clip_ parameter only affects the CLIP model and is not baked into the converted model. This means that 219 | you can change it after conversion. 220 | 221 | - **Inputs**: 222 | - **lora_name**: The name of the LoRA to load. 223 | - **strength_model**: The strength of the LoRA model. 224 | - **strength_clip**: The strength of the LoRA CLIP. 225 | - **lora_params** [optional]: Optional output from other LoRA Loaders. 226 | - **clip**: The CLIP model to use with the LoRA. This can be either the output of the 227 | `CLIPLoader`/`CheckpointLoaderSimple` or other LoRA Loaders. 228 | - **Outputs**: 229 | - **lora_params**: The LoRA parameters that can be passed to the Core ML Converter or other LoRA Loaders. 230 | - **CLIP**: The CLIP model with LoRA applied. 231 | 232 | #### LCM Converter 233 | 234 | ![LCMConverter](./assets/lcm_converter.png?raw=true) 235 | 236 | This node converts the [SimianLuo/LCM_Dreamshaper_v7](https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7) model to Core 237 | ML. The converted model is stored in the `models/unet` directory and can be used with the Core ML UNet Loader. The 238 | conversion parameters are encoded in the model name, so if the model already exists, the node will not convert it again. 239 | 240 | - **Inputs**: 241 | - **height**: The desired height of the image generated by the model. The default is 512. Must be a multiple of 8. 242 | - **width**: The desired width of the image generated by the model. The default is 512. Must be a multiple of 8. 243 | - **batch_size**: The batch size of generated images. If you're planning to generate batches of images, you can try 244 | increasing this value to speed up the generation process. The default is 1. 245 | - **compute_unit**: The hardware on which the model should run. This is used only when loading the model and 246 | doesn't affect the conversion process. 247 | - **controlnet_support**: For the model to support ControlNet, it must be converted with this option set to True. 248 | The default is False. 249 | 250 | > [!NOTE] 251 | > The conversion process can take a while, so please be patient. 252 | 253 | > [!NOTE] 254 | > When using the LCM model with Core ML Sampler, please set _sampler_name_ to `lcm` and _scheduler_ to `sgm_uniform`. 255 | 256 | #### Core ML Adapter (Experimental) (`CoreMLModelAdapter`) 257 | 258 | ![CoreMLModelAdapter](./assets/adapter.png?raw=true) 259 | 260 | This node allows you to use a Core ML model as a standard ComfyUI model. This is an experimental node and may not work with 261 | all models and nodes. Please use with caution and pay attention to the expected inputs of the model.
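If you're not sure which inputs a given Core ML model expects before wiring it through the adapter, you can inspect its input descriptions with `coremltools`. The snippet below is only a rough sketch: the file name is hypothetical, and the exact input names and shapes depend on how the model was converted.

```python
import coremltools as ct

# Hypothetical path -- point this at the converted model you intend to wrap.
MODEL_PATH = "models/unet/sd15_512x512_cn.mlpackage"

# skip_model_load reads the model's metadata without compiling/loading the model.
mlmodel = ct.models.MLModel(MODEL_PATH, skip_model_load=True)

for inp in mlmodel.get_spec().description.input:
    # Multi-array inputs (latents, timestep, text embeddings, ControlNet residuals)
    # report their expected shapes here.
    print(inp.name, tuple(inp.type.multiArrayType.shape))
```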
262 | 263 | - **Input**: 264 | - **coreml_model**: The Core ML model to use as a ComfyUI model. 265 | - **Output**: 266 | - **MODEL**: The Core ML model wrapped in a ComfyUI model. 267 | 268 | > [!NOTE] 269 | > While this approach allows you to use Core ML models with many ComfyUI nodes (both standard and custom), the 270 | > expected inputs of the model will not be checked, which may cause errors. Please make sure to use a model compatible 271 | > with the expected parameters. 272 | 273 | ### Example Workflows 274 | 275 | > [!NOTE] 276 | > The models used are just an example. Feel free to experiment with different models and see what works best for you. 277 | 278 | #### Basic txt2img with Core ML UNet loader 279 | 280 | This is a basic txt2img workflow that uses the Core ML UNet loader to load a model. The CLIP and VAE models 281 | are loaded using the standard ComfyUI nodes. In the first example, the text encoder (CLIP) and VAE models are loaded 282 | separately. In the second example, the text encoder and VAE models are loaded from the checkpoint file. Note that you 283 | can use any CLIP or VAE model as long as it's compatible with Stable Diffusion v1.5. 284 | 285 | 1. **Loading text encoder (CLIP) and VAE models separately** 286 | - This workflow uses CLIP and VAE models available 287 | [here](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/text_encoder/model.safetensors) and 288 | [here](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/diffusion_pytorch_model.safetensors). 289 | Once downloaded, place the models in the`models/clip` and `models/vae` directories respectively. 290 | - The Core ML UNet model is available 291 | [here](https://huggingface.co/coreml-community/coreml-stable-diffusion-v1-5_cn/blob/main/split_einsum/stable-diffusion-_v1-5_split-einsum_cn.zip). 292 | Once downloaded, place the model in the `models/unet` directory. 293 | ![coreml-unet+clip+vae](./assets/unet+sampler+clip+vae.png?raw=true) 294 | 2. **Loading text encoder (CLIP) and VAE models from checkpoint file** 295 | - This workflow loads the CLIP and VAE models from the checkpoint file available 296 | [here](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors). 297 | Once downloaded, place the model in the`models/checkpoints` directory. 298 | - The Core ML UNet model is available 299 | [here](https://huggingface.co/coreml-community/coreml-stable-diffusion-v1-5_cn/blob/main/split_einsum/stable-diffusion-_v1-5_split-einsum_cn.zip). 300 | Once downloaded, place the model in the `models/unet` directory. 301 | ![coreml-unet+checkpoint](./assets/unet+sampler+checkpoint.png?raw=true) 302 | 303 | #### ControlNet with Core ML UNet loader 304 | 305 | This workflow uses the Core ML UNet loader to load a Core ML UNet model that supports ControlNet. The ControlNet is 306 | being loaded using the standard ComfyUI nodes. Please refer to 307 | the [basic txt2img workflow](#basic-txt2img-with-core-ml-unet-loader) for more details on how to load the CLIP and VAE 308 | models. 309 | The ControlNet model used in this workflow is available 310 | [here](https://huggingface.co/lllyasviel/control_v11p_sd15_scribble/blob/main/diffusion_pytorch_model.fp16.safetensors). 311 | Once downloaded, place the model in the `models/controlnet` directory. 312 | ![coreml-unet+controlnet](./assets/unet+sampler+controlnet.png?raw=true) 313 | 314 | #### Checkpoint conversion 315 | 316 | This workflow uses the Checkpoint Converter to convert the checkpoint file. 
See 317 | the [Checkpoint Converter](#checkpoint-converter) description for more details. 318 | 319 | ![checkpoint-converter](./assets/basic_conversion.png?raw=true) 320 | 321 | #### Checkpoint conversion with LoRA 322 | 323 | This workflow uses the Checkpoint Converter to convert a checkpoint file with LoRA applied. See the 324 | [LoRA Loader](#lora-loader) description to read more about the caveats of using LoRA. 325 | 326 | ![checkpoint-converter+lora](./assets/conversion+lora.png?raw=true) 327 | 328 | #### LCM LoRA conversion 329 | 330 | Please note that you can use multiple LoRAs with the same model. To do this, you'll need to chain multiple LoRA Loaders. 331 | > [!IMPORTANT] 332 | > In this example, the model is passed through the adapter and `ModelSamplingDiscrete` nodes to a standard ComfyUI 333 | > KSampler (not the Core ML Sampler). `ModelSamplingDiscrete` is required to sample models with LCM LoRAs properly. 334 | 335 | ![multiple-loras](./assets/conversion+lcm_lora.png?raw=true) 336 | 337 | #### Loader with LoRAs 338 | 339 | This workflow uses the Core ML UNet Loader to load a model with LoRAs. The CLIP must be loaded separately and passed 340 | through the same LoRA nodes as during conversion. See the [LoRA Loader](#lora-loader) description to read more about the 341 | caveats of using LoRA. Since _lora_name_ and _strength_model_ are baked into the model, it is not necessary to pass 342 | them as inputs to the loader. 343 | > [!IMPORTANT] 344 | > In this example, the model is passed through the adapter and `ModelSamplingDiscrete` nodes to a standard ComfyUI 345 | > KSampler (not the Core ML Sampler). `ModelSamplingDiscrete` is required to sample models with LCM LoRAs properly. 346 | 347 | ![loader+lora](./assets/loader+lcm_lora.png?raw=true) 348 | 349 | #### LCM conversion with ControlNet 350 | 351 | This workflow uses the LCM Converter to 352 | convert the [SimianLuo/LCM_Dreamshaper_v7](https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7) 353 | model to Core ML. The converted model can then be used with or without ControlNet to generate images. 354 | ![lcm+controlnet](./assets/lcm+controlnet.png?raw=true) 355 | 356 | #### SDXL Base + Refiner conversion 357 | 358 | This is a basic workflow for SDXL. You can add LoRAs and ControlNets the same way as in the previous examples. 359 | You can also skip the refiner step. 360 | 361 | The models used in this workflow are available at the following links: 362 | 363 | - [Base model + text_encoder (clip) + text_encoder_2 (clip2)](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) 364 | - [Refiner model](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0) 365 | - [VAE](https://huggingface.co/stabilityai/sdxl-vae) 366 | 367 | > [!IMPORTANT] 368 | > SDXL on the ANE is not supported. If loading the model gets stuck, please try using CPU_AND_GPU or CPU_ONLY. 369 | > For best results, use the ORIGINAL attention implementation. 370 | 371 | ![sdxl](./assets/sdxl_conversion.png?raw=true) 372 | 373 | ## Limitations 374 | 375 | - Core ML models are fixed in terms of their input and output shapes.[^1] 376 | This means you'll need to use latent images of the same size as the input of the model (512x512 is the default for 377 | SD1.5). 378 | However, you can convert the model to a different input size using the tools available 379 | in the [apple/ml-stable-diffusion](https://github.com/apple/ml-stable-diffusion) repository or this suite's own converter (see the sketch below this list). 380 | - SD2.1 models are not supported. 
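To make the fixed-shape limitation concrete, here is a hedged sketch of converting an SD1.5 checkpoint at a non-default resolution using this suite's own converter module (`coreml_suite/converter.py`, included later in this repository). The checkpoint path and output name are placeholders, and the snippet assumes it runs inside a ComfyUI environment with this node pack installed; the Checkpoint Converter node performs equivalent steps from the UI, and the conversion scripts in the apple/ml-stable-diffusion repository offer a similar route.

```python
# Sketch only: convert an SD1.5 checkpoint to a 512x768 (WxH) Core ML UNet.
# Paths and the output name are examples, not part of the official docs.
from coreml_suite.config import ModelVersion
from coreml_suite.converter import convert, get_out_path, compile_model

height, width = 768, 512                   # image size; must be multiples of 8
out_name = f"sd15_1x{width}x{height}"      # name encodes the conversion parameters
out_path = get_out_path("unet", out_name)  # lands in the models/unet directory

convert(
    ckpt_path="models/checkpoints/v1-5-pruned-emaonly.safetensors",  # example path
    model_version=ModelVersion.SD15,
    unet_out_path=out_path,
    batch_size=1,
    sample_size=(height // 8, width // 8),  # latent size is 1/8 of the pixel size
)
compile_model(out_path, out_name, "unet")   # compiles the package to .mlmodelc
```

Latents fed to the Core ML Sampler then have to match the new 512x768 size (and the batch size used above).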
381 | 382 | [^1]: 383 | Unless [EnumeratedShapes](https://apple.github.io/coremltools/docs-guides/source/flexible-inputs.html#select-from-predetermined-shapes) 384 | is used during conversion. Needs more testing. 385 | 386 | ## Support 387 | 388 | I'm here to help! If you have any questions or suggestions, don't hesitate to open an issue and I'll do my best 389 | to assist you. 390 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.dirname(__file__)) 5 | 6 | from coreml_suite.nodes import ( 7 | CoreMLLoaderUNet, 8 | CoreMLSampler, 9 | CoreMLSamplerAdvanced, 10 | CoreMLModelAdapter, 11 | CoreMLConverter, 12 | COREML_LOAD_LORA, 13 | ) 14 | from coreml_suite.lcm import ( 15 | COREML_CONVERT_LCM, 16 | ) 17 | 18 | NODE_CLASS_MAPPINGS = { 19 | "CoreMLUNetLoader": CoreMLLoaderUNet, 20 | "CoreMLSampler": CoreMLSampler, 21 | "CoreMLSamplerAdvanced": CoreMLSamplerAdvanced, 22 | "CoreMLModelAdapter": CoreMLModelAdapter, 23 | "Core ML LoRA Loader": COREML_LOAD_LORA, 24 | "Core ML Converter": CoreMLConverter, 25 | "Core ML LCM Converter": COREML_CONVERT_LCM, 26 | } 27 | NODE_DISPLAY_NAME_MAPPINGS = { 28 | "CoreMLUNetLoader": "Load Core ML UNet", 29 | "CoreMLSampler": "Core ML Sampler", 30 | "CoreMLSamplerAdvanced": "Core ML Sampler (Advanced)", 31 | "CoreMLModelAdapter": "Core ML Adapter (Experimental)", 32 | "Core ML LoRA Loader": "Load LoRA to use with Core ML", 33 | "Core ML Converter": "Convert Checkpoint to Core ML", 34 | "Core ML LCM Converter": "Convert LCM to Core ML", 35 | } 36 | -------------------------------------------------------------------------------- /assets/adapter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aszc-dev/ComfyUI-CoreMLSuite/7678a07ed551092893a00eb2850dafca1d81566c/assets/adapter.png -------------------------------------------------------------------------------- /assets/basic_conversion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aszc-dev/ComfyUI-CoreMLSuite/7678a07ed551092893a00eb2850dafca1d81566c/assets/basic_conversion.png -------------------------------------------------------------------------------- /assets/checkpoint_converter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aszc-dev/ComfyUI-CoreMLSuite/7678a07ed551092893a00eb2850dafca1d81566c/assets/checkpoint_converter.png -------------------------------------------------------------------------------- /assets/conversion+lcm_lora.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aszc-dev/ComfyUI-CoreMLSuite/7678a07ed551092893a00eb2850dafca1d81566c/assets/conversion+lcm_lora.png -------------------------------------------------------------------------------- /assets/conversion+lora.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aszc-dev/ComfyUI-CoreMLSuite/7678a07ed551092893a00eb2850dafca1d81566c/assets/conversion+lora.png -------------------------------------------------------------------------------- /assets/lcm+controlnet.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aszc-dev/ComfyUI-CoreMLSuite/7678a07ed551092893a00eb2850dafca1d81566c/assets/lcm+controlnet.png -------------------------------------------------------------------------------- /assets/lcm_converter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aszc-dev/ComfyUI-CoreMLSuite/7678a07ed551092893a00eb2850dafca1d81566c/assets/lcm_converter.png -------------------------------------------------------------------------------- /assets/loader+lcm_lora.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aszc-dev/ComfyUI-CoreMLSuite/7678a07ed551092893a00eb2850dafca1d81566c/assets/loader+lcm_lora.png -------------------------------------------------------------------------------- /assets/lora_loader.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aszc-dev/ComfyUI-CoreMLSuite/7678a07ed551092893a00eb2850dafca1d81566c/assets/lora_loader.png -------------------------------------------------------------------------------- /assets/sampler.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aszc-dev/ComfyUI-CoreMLSuite/7678a07ed551092893a00eb2850dafca1d81566c/assets/sampler.png -------------------------------------------------------------------------------- /assets/sdxl_conversion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aszc-dev/ComfyUI-CoreMLSuite/7678a07ed551092893a00eb2850dafca1d81566c/assets/sdxl_conversion.png -------------------------------------------------------------------------------- /assets/snake.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aszc-dev/ComfyUI-CoreMLSuite/7678a07ed551092893a00eb2850dafca1d81566c/assets/snake.png -------------------------------------------------------------------------------- /assets/unet+sampler+checkpoint.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aszc-dev/ComfyUI-CoreMLSuite/7678a07ed551092893a00eb2850dafca1d81566c/assets/unet+sampler+checkpoint.png -------------------------------------------------------------------------------- /assets/unet+sampler+clip+vae.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aszc-dev/ComfyUI-CoreMLSuite/7678a07ed551092893a00eb2850dafca1d81566c/assets/unet+sampler+clip+vae.png -------------------------------------------------------------------------------- /assets/unet+sampler+controlnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aszc-dev/ComfyUI-CoreMLSuite/7678a07ed551092893a00eb2850dafca1d81566c/assets/unet+sampler+controlnet.png -------------------------------------------------------------------------------- /assets/unet_loader.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aszc-dev/ComfyUI-CoreMLSuite/7678a07ed551092893a00eb2850dafca1d81566c/assets/unet_loader.png -------------------------------------------------------------------------------- /coreml_suite/__init__.py: -------------------------------------------------------------------------------- 1 | class COREML_NODE: 2 | 
CATEGORY = "Core ML Suite" 3 | -------------------------------------------------------------------------------- /coreml_suite/config.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import torch 4 | 5 | from comfy import supported_models_base 6 | from comfy import latent_formats 7 | from comfy.model_detection import convert_config 8 | 9 | 10 | class ModelVersion(Enum): 11 | SD15 = "sd15" 12 | SDXL = "sdxl" 13 | SDXL_REFINER = "sdxl_refiner" 14 | LCM = "lcm" 15 | 16 | 17 | config_map = { 18 | ModelVersion.SD15: { 19 | "use_checkpoint": False, 20 | "image_size": 32, 21 | "out_channels": 4, 22 | "use_spatial_transformer": True, 23 | "legacy": False, 24 | "adm_in_channels": None, 25 | "dtype": torch.float16, 26 | "in_channels": 4, 27 | "model_channels": 320, 28 | "num_res_blocks": 2, 29 | "attention_resolutions": [1, 2, 4], 30 | "transformer_depth": [1, 1, 1, 0], 31 | "channel_mult": [1, 2, 4, 4], 32 | "transformer_depth_middle": 1, 33 | "use_linear_in_transformer": False, 34 | "context_dim": 768, 35 | "num_heads": 8, 36 | "disable_unet_model_creation": True, 37 | }, 38 | ModelVersion.SDXL: { 39 | "use_checkpoint": False, 40 | "image_size": 32, 41 | "out_channels": 4, 42 | "use_spatial_transformer": True, 43 | "legacy": False, 44 | "num_classes": "sequential", 45 | "adm_in_channels": 2816, 46 | "dtype": torch.float16, 47 | "in_channels": 4, 48 | "model_channels": 320, 49 | "num_res_blocks": 2, 50 | "attention_resolutions": [2, 4], 51 | "transformer_depth": [0, 2, 10], 52 | "channel_mult": [1, 2, 4], 53 | "transformer_depth_middle": 10, 54 | "use_linear_in_transformer": True, 55 | "context_dim": 2048, 56 | "num_head_channels": 64, 57 | "disable_unet_model_creation": True, 58 | }, 59 | ModelVersion.SDXL_REFINER: { 60 | "use_checkpoint": False, 61 | "image_size": 32, 62 | "out_channels": 4, 63 | "use_spatial_transformer": True, 64 | "legacy": False, 65 | "num_classes": "sequential", 66 | "adm_in_channels": 2560, 67 | "dtype": torch.float16, 68 | "in_channels": 4, 69 | "model_channels": 384, 70 | "num_res_blocks": 2, 71 | "attention_resolutions": [2, 4], 72 | "transformer_depth": [0, 4, 4, 0], 73 | "channel_mult": [1, 2, 4, 4], 74 | "transformer_depth_middle": 4, 75 | "use_linear_in_transformer": True, 76 | "context_dim": 1280, 77 | "num_head_channels": 64, 78 | "disable_unet_model_creation": True, 79 | }, 80 | } 81 | 82 | latent_format_map = { 83 | ModelVersion.SD15: latent_formats.SD15, 84 | ModelVersion.SDXL: latent_formats.SDXL, 85 | ModelVersion.SDXL_REFINER: latent_formats.SDXL, 86 | } 87 | 88 | 89 | def get_model_config(model_version: ModelVersion): 90 | unet_config = convert_config(config_map[model_version]) 91 | config = supported_models_base.BASE(unet_config) 92 | config.latent_format = latent_format_map[model_version]() 93 | return config 94 | 95 | 96 | def unet_config_from_diffusers_unet(state_dict): 97 | match = {} 98 | attention_resolutions = [] 99 | 100 | attn_res = 1 101 | for i in range(5): 102 | k = "down_blocks.{}.attentions.1.transformer_blocks.0.attn2.to_k.weight".format( 103 | i 104 | ) 105 | if k in state_dict: 106 | match["context_dim"] = state_dict[k].shape[1] 107 | attention_resolutions.append(attn_res) 108 | attn_res *= 2 109 | 110 | match["attention_resolutions"] = attention_resolutions 111 | 112 | match["model_channels"] = state_dict["conv_in.weight"].shape[0] 113 | match["in_channels"] = state_dict["conv_in.weight"].shape[1] 114 | match["adm_in_channels"] = None 115 | if 
"class_embedding.linear_1.weight" in state_dict: 116 | match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[ 117 | 1 118 | ] 119 | elif "add_embedding.linear_1.weight" in state_dict: 120 | match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1] 121 | 122 | print(match) 123 | -------------------------------------------------------------------------------- /coreml_suite/controlnet.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | from math import ceil 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from coreml_suite.latents import chunk_batch 8 | 9 | 10 | def expand_inputs(inputs): 11 | expanded = inputs.copy() 12 | for k, v in inputs.items(): 13 | if isinstance(v, np.ndarray): 14 | expanded[k] = np.concatenate([v] * 2) if v.shape[0] == 1 else v 15 | elif isinstance(v, torch.Tensor): 16 | expanded[k] = torch.cat([v] * 2) if v.shape[0] == 1 else v 17 | elif isinstance(v, list): 18 | expanded[k] = v * 2 if len(v) == 1 else v 19 | elif isinstance(v, dict): 20 | expand_inputs(v) 21 | return expanded 22 | 23 | 24 | def extract_residual_kwargs(expected_inputs, control): 25 | if "additional_residual_0" not in expected_inputs.keys(): 26 | return {} 27 | if control is None: 28 | return no_control(expected_inputs) 29 | 30 | residual_kwargs = { 31 | "additional_residual_{}".format(i): r.cpu().numpy().astype(np.float16) 32 | for i, r in enumerate(chain(control["output"], control["middle"])) 33 | } 34 | return residual_kwargs 35 | 36 | 37 | def no_control(expected_inputs): 38 | shapes_dict = { 39 | k: v["shape"] for k, v in expected_inputs.items() if k.startswith("additional") 40 | } 41 | residual_kwargs = { 42 | k: torch.zeros(*shape).cpu().numpy().astype(dtype=np.float16) 43 | for k, shape in shapes_dict.items() 44 | } 45 | return residual_kwargs 46 | 47 | 48 | def chunk_control(cn, target_size): 49 | if cn is None: 50 | return [None] * target_size 51 | 52 | num_chunks = ceil(cn["output"][0].shape[0] / target_size) 53 | 54 | out = [{"output": [], "middle": []} for _ in range(num_chunks)] 55 | 56 | for k, v in cn.items(): 57 | for i, x in enumerate(v): 58 | chunks = chunk_batch(x, (target_size, *x.shape[1:])) 59 | for j, chunk in enumerate(chunks): 60 | out[j][k].append(chunk) 61 | 62 | return out 63 | -------------------------------------------------------------------------------- /coreml_suite/converter.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import shutil 4 | import time 5 | from typing import Union 6 | 7 | import coremltools as ct 8 | import numpy as np 9 | import python_coreml_stable_diffusion.unet 10 | import torch 11 | from diffusers import ( 12 | StableDiffusionPipeline, 13 | LatentConsistencyModelPipeline, 14 | StableDiffusionXLPipeline, 15 | ) 16 | from python_coreml_stable_diffusion.unet import ( 17 | UNet2DConditionModel, 18 | UNet2DConditionModelXL, 19 | AttentionImplementations, 20 | ) 21 | 22 | from coreml_suite.config import ModelVersion 23 | from coreml_suite.lcm.unet import UNet2DConditionModelLCM 24 | from coreml_suite.logger import logger 25 | from folder_paths import get_folder_paths 26 | 27 | 28 | class StableDiffusionLCMPipeline(LatentConsistencyModelPipeline): 29 | pass 30 | 31 | 32 | MODEL_TYPE_TO_UNET_CLS = { 33 | ModelVersion.SD15: UNet2DConditionModel, 34 | ModelVersion.SDXL: UNet2DConditionModelXL, 35 | ModelVersion.LCM: UNet2DConditionModelLCM, 36 | } 37 | 38 | 
MODEL_TYPE_TO_PIPE_CLS = { 39 | ModelVersion.SD15: StableDiffusionPipeline, 40 | ModelVersion.SDXL: StableDiffusionXLPipeline, 41 | ModelVersion.LCM: StableDiffusionLCMPipeline, 42 | } 43 | 44 | 45 | def get_unet(model_type: ModelVersion, ref_pipe): 46 | ref_unet = ref_pipe.unet 47 | 48 | unet_cls = MODEL_TYPE_TO_UNET_CLS[model_type] 49 | cml_unet = unet_cls.from_config(ref_unet.config).eval() 50 | cml_unet.load_state_dict(ref_unet.state_dict(), strict=False) 51 | 52 | return cml_unet 53 | 54 | 55 | def get_encoder_hidden_states_shape(ref_pipe, batch_size): 56 | text_encoder = ( 57 | ref_pipe.text_encoder_2 58 | if hasattr(ref_pipe, "text_encoder_2") 59 | else ref_pipe.text_encoder 60 | ) 61 | 62 | text_token_sequence_length = text_encoder.config.max_position_embeddings 63 | hidden_size = (text_encoder.config.hidden_size,) 64 | 65 | encoder_hidden_states_shape = ( 66 | batch_size, 67 | ref_pipe.unet.config.cross_attention_dim or hidden_size, 68 | 1, 69 | text_token_sequence_length, 70 | ) 71 | 72 | return encoder_hidden_states_shape 73 | 74 | 75 | def get_coreml_inputs(sample_inputs): 76 | coreml_sample_unet_inputs = { 77 | k: v.numpy().astype(np.float16) for k, v in sample_inputs.items() 78 | } 79 | return [ 80 | ct.TensorType( 81 | name=k, 82 | shape=v.shape, 83 | dtype=v.numpy().dtype if isinstance(v, torch.Tensor) else v.dtype, 84 | ) 85 | for k, v in coreml_sample_unet_inputs.items() 86 | ] 87 | 88 | 89 | def load_coreml_model(out_path): 90 | logger.info(f"Loading model from {out_path}") 91 | 92 | start = time.time() 93 | coreml_model = ct.models.MLModel(out_path) 94 | logger.info(f"Loading {out_path} took {time.time() - start:.1f} seconds") 95 | 96 | return coreml_model 97 | 98 | 99 | def convert_to_coreml( 100 | submodule_name, torchscript_module, sample_inputs, output_names, out_path 101 | ): 102 | if os.path.exists(out_path): 103 | logger.info(f"Skipping export because {out_path} already exists") 104 | coreml_model = load_coreml_model(out_path) 105 | else: 106 | logger.info(f"Converting {submodule_name} to CoreML..") 107 | coreml_model = ct.convert( 108 | torchscript_module, 109 | convert_to="mlprogram", 110 | minimum_deployment_target=ct.target.macOS13, 111 | inputs=sample_inputs, 112 | outputs=[ 113 | ct.TensorType(name=name, dtype=np.float32) for name in output_names 114 | ], 115 | skip_model_load=True, 116 | ) 117 | 118 | del torchscript_module 119 | gc.collect() 120 | 121 | return coreml_model 122 | 123 | 124 | def get_out_path(submodule_name, model_name): 125 | fname = f"{model_name}_{submodule_name}.mlpackage" 126 | unet_path = get_folder_paths(submodule_name)[0] 127 | out_path = os.path.join(unet_path, fname) 128 | return out_path 129 | 130 | 131 | def compile_coreml_model(source_model_path, output_dir, final_name): 132 | """Compiles Core ML models using the coremlcompiler utility from Xcode toolchain""" 133 | target_path = os.path.join(output_dir, f"{final_name}.mlmodelc") 134 | if os.path.exists(target_path): 135 | logger.warning(f"Found existing compiled model at {target_path}! 
Skipping..") 136 | return target_path 137 | 138 | logger.info(f"Compiling {source_model_path}") 139 | source_model_name = os.path.basename(os.path.splitext(source_model_path)[0]) 140 | 141 | os.system(f"xcrun coremlcompiler compile {source_model_path} {output_dir}") 142 | compiled_output = os.path.join(output_dir, f"{source_model_name}.mlmodelc") 143 | shutil.move(compiled_output, target_path) 144 | 145 | return target_path 146 | 147 | 148 | def get_sample_input(batch_size, encoder_hidden_states_shape, sample_shape, scheduler): 149 | sample_unet_inputs = dict( 150 | [ 151 | ("sample", torch.rand(*sample_shape)), 152 | ( 153 | "timestep", 154 | torch.tensor([scheduler.timesteps[0].item()] * batch_size).to( 155 | torch.float32 156 | ), 157 | ), 158 | ("encoder_hidden_states", torch.rand(*encoder_hidden_states_shape)), 159 | ] 160 | ) 161 | return sample_unet_inputs 162 | 163 | 164 | def lcm_inputs(sample_unet_inputs): 165 | batch_size = sample_unet_inputs["sample"].shape[0] 166 | return {"timestep_cond": torch.randn(batch_size, 256).to(torch.float32)} 167 | 168 | 169 | def sdxl_inputs(sample_unet_inputs, ref_pipe): 170 | sample_shape = sample_unet_inputs["sample"].shape 171 | batch_size = sample_shape[0] 172 | h = sample_shape[2] * 8 173 | w = sample_shape[3] * 8 174 | original_size = (h, w) 175 | crops_coords_top_left = (0, 0) 176 | 177 | is_refiner = ( 178 | hasattr(ref_pipe.config, "requires_aesthetics_score") 179 | and ref_pipe.config.requires_aesthetics_score 180 | ) 181 | 182 | if is_refiner: 183 | aesthetic_score = (6.0,) 184 | time_ids_list = list(original_size + crops_coords_top_left + aesthetic_score) 185 | else: 186 | target_size = (h, w) 187 | time_ids_list = list(original_size + crops_coords_top_left + target_size) 188 | 189 | time_ids = torch.tensor(time_ids_list).repeat(batch_size, 1).to(torch.int64) 190 | text_embeds_shape = (batch_size, ref_pipe.text_encoder_2.config.hidden_size) 191 | 192 | return { 193 | "time_ids": time_ids, 194 | "text_embeds": torch.randn(*text_embeds_shape).to(torch.float32), 195 | } 196 | 197 | 198 | def get_inputs_spec(inputs): 199 | inputs_spec = {k: (v.shape, v.dtype) for k, v in inputs.items()} 200 | return inputs_spec 201 | 202 | 203 | def add_cnet_support(sample_shape, reference_unet): 204 | from python_coreml_stable_diffusion.unet import calculate_conv2d_output_shape 205 | 206 | additional_residuals_shapes = [] 207 | 208 | batch_size = sample_shape[0] 209 | h, w = sample_shape[2:] 210 | 211 | # conv_in 212 | out_h, out_w = calculate_conv2d_output_shape( 213 | h, 214 | w, 215 | reference_unet.conv_in, 216 | ) 217 | additional_residuals_shapes.append( 218 | (batch_size, reference_unet.conv_in.out_channels, out_h, out_w) 219 | ) 220 | 221 | # down_blocks 222 | for down_block in reference_unet.down_blocks: 223 | additional_residuals_shapes += [ 224 | (batch_size, resnet.out_channels, out_h, out_w) 225 | for resnet in down_block.resnets 226 | ] 227 | if hasattr(down_block, "downsamplers") and down_block.downsamplers is not None: 228 | for downsampler in down_block.downsamplers: 229 | out_h, out_w = calculate_conv2d_output_shape( 230 | out_h, out_w, downsampler.conv 231 | ) 232 | additional_residuals_shapes.append( 233 | ( 234 | batch_size, 235 | down_block.downsamplers[-1].conv.out_channels, 236 | out_h, 237 | out_w, 238 | ) 239 | ) 240 | 241 | # mid_block 242 | additional_residuals_shapes.append( 243 | (batch_size, reference_unet.mid_block.resnets[-1].out_channels, out_h, out_w) 244 | ) 245 | 246 | additional_inputs = {} 247 | for i, shape in 
enumerate(additional_residuals_shapes): 248 | sample_residual_input = torch.rand(*shape) 249 | additional_inputs[f"additional_residual_{i}"] = sample_residual_input 250 | 251 | return additional_inputs 252 | 253 | 254 | def convert_unet( 255 | ref_pipe, 256 | model_version: ModelVersion, 257 | unet_out_path: str, 258 | batch_size: int = 1, 259 | sample_size: tuple[int, int] = (64, 64), 260 | controlnet_support: bool = False, 261 | ): 262 | coreml_unet = get_unet(model_version, ref_pipe) 263 | ref_unet = ref_pipe.unet 264 | 265 | sample_shape = ( 266 | batch_size, # B 267 | ref_unet.config.in_channels, # C 268 | sample_size[0], # H 269 | sample_size[1], # W 270 | ) 271 | 272 | encoder_hidden_states_shape = get_encoder_hidden_states_shape(ref_pipe, batch_size) 273 | 274 | scheduler = ref_pipe.scheduler 275 | scheduler.set_timesteps(50) 276 | 277 | sample_inputs = get_sample_input( 278 | batch_size, encoder_hidden_states_shape, sample_shape, scheduler 279 | ) 280 | 281 | if model_version == ModelVersion.LCM: 282 | sample_inputs |= lcm_inputs(sample_inputs) 283 | 284 | if model_version == ModelVersion.SDXL: 285 | sample_inputs |= sdxl_inputs(sample_inputs, ref_pipe) 286 | 287 | if controlnet_support: 288 | sample_inputs |= add_cnet_support(sample_shape, ref_unet) 289 | 290 | sample_inputs_spec = get_inputs_spec(sample_inputs) 291 | 292 | logger.info(f"Sample UNet inputs spec: {sample_inputs_spec}") 293 | logger.info("JIT tracing..") 294 | traced_unet = torch.jit.trace( 295 | coreml_unet, example_inputs=list(sample_inputs.values()) 296 | ) 297 | logger.info("Done.") 298 | 299 | coreml_sample_inputs = get_coreml_inputs(sample_inputs) 300 | 301 | coreml_unet = convert_to_coreml( 302 | "unet", traced_unet, coreml_sample_inputs, ["noise_pred"], unet_out_path 303 | ) 304 | 305 | del traced_unet 306 | gc.collect() 307 | 308 | coreml_unet.save(unet_out_path) 309 | logger.info(f"Saved unet into {unet_out_path}") 310 | 311 | 312 | def convert( 313 | ckpt_path: str, 314 | model_version: ModelVersion, 315 | unet_out_path: str, 316 | batch_size: int = 1, 317 | sample_size: tuple[int, int] = (64, 64), 318 | controlnet_support: bool = False, 319 | lora_weights: list[tuple[Union[str, os.PathLike], float]] = None, 320 | attn_impl: str = AttentionImplementations.SPLIT_EINSUM.name, 321 | config_path: str = None, 322 | ): 323 | if os.path.exists(unet_out_path): 324 | logger.info(f"Found existing model at {unet_out_path}! 
Skipping..") 325 | return 326 | 327 | python_coreml_stable_diffusion.unet.ATTENTION_IMPLEMENTATION_IN_EFFECT = ( 328 | AttentionImplementations(attn_impl) 329 | ) 330 | 331 | ref_pipe = get_pipeline(ckpt_path, config_path, model_version) 332 | 333 | for i, lora_weight in enumerate(lora_weights or []): 334 | lora_path, strength = lora_weight 335 | adapter_name = f"lora_{i}" 336 | ref_pipe.load_lora_weights(lora_path, adapter_name=adapter_name) 337 | ref_pipe.set_adapters([adapter_name], adapter_weights=[strength]) 338 | ref_pipe.fuse_lora() 339 | 340 | convert_unet( 341 | ref_pipe, 342 | model_version, 343 | unet_out_path, 344 | batch_size, 345 | sample_size, 346 | controlnet_support, 347 | ) 348 | 349 | 350 | def get_pipeline(ckpt_path, config_path, model_version): 351 | pipe_cls = MODEL_TYPE_TO_PIPE_CLS[model_version] 352 | ref_pipe = pipe_cls.from_single_file(ckpt_path, original_config_file=config_path) 353 | return ref_pipe 354 | 355 | 356 | def compile_model(out_path, out_name, submodule_name): 357 | # Compile the model 358 | target_path = compile_coreml_model( 359 | out_path, get_folder_paths(submodule_name)[0], f"{out_name}_{submodule_name}" 360 | ) 361 | logger.info(f"Compiled {out_path} to {target_path}") 362 | return target_path 363 | -------------------------------------------------------------------------------- /coreml_suite/latents.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def chunk_batch(input_tensor, target_shape): 5 | if input_tensor.shape == target_shape: 6 | return [input_tensor] 7 | 8 | batch_size = input_tensor.shape[0] 9 | target_batch_size = target_shape[0] 10 | 11 | num_chunks = batch_size // target_batch_size 12 | if num_chunks == 0: 13 | padding = torch.zeros(target_batch_size - batch_size, *target_shape[1:]).to( 14 | input_tensor.device 15 | ) 16 | return [torch.cat((input_tensor, padding), dim=0)] 17 | 18 | mod = batch_size % target_batch_size 19 | if mod != 0: 20 | chunks = list(torch.chunk(input_tensor[:-mod], num_chunks)) 21 | padding = torch.zeros(target_batch_size - mod, *target_shape[1:]).to( 22 | input_tensor.device 23 | ) 24 | padded = torch.cat((input_tensor[-mod:], padding), dim=0) 25 | chunks.append(padded) 26 | return chunks 27 | 28 | chunks = list(torch.chunk(input_tensor, num_chunks)) 29 | return chunks 30 | 31 | 32 | def merge_chunks(chunks, orig_shape): 33 | merged = torch.cat(chunks, dim=0) 34 | if merged.shape == orig_shape: 35 | return merged 36 | return merged[: orig_shape[0]] 37 | -------------------------------------------------------------------------------- /coreml_suite/lcm/__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import COREML_CONVERT_LCM 2 | 3 | __all__ = ["COREML_CONVERT_LCM"] 4 | -------------------------------------------------------------------------------- /coreml_suite/lcm/converter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import logging 4 | import time 5 | import gc 6 | 7 | import numpy as np 8 | import torch 9 | from diffusers import UNet2DConditionModel, LCMScheduler 10 | from diffusers.loaders import LoraLoaderMixin 11 | 12 | from comfy.model_management import get_torch_device 13 | from coreml_suite.lcm.unet import UNet2DConditionModelLCM 14 | 15 | from transformers import CLIPTextModel 16 | import coremltools as ct 17 | 18 | from folder_paths import get_folder_paths 19 | 20 | logging.basicConfig() 21 | logger 
= logging.getLogger(__name__) 22 | logger.setLevel(logging.DEBUG) 23 | 24 | MODEL_VERSION = "SimianLuo/LCM_Dreamshaper_v7" 25 | MODEL_NAME = MODEL_VERSION.split("/")[-1] + "_4k" 26 | 27 | import python_coreml_stable_diffusion.unet as unet 28 | 29 | unet.ATTENTION_IMPLEMENTATION_IN_EFFECT = unet.AttentionImplementations.SPLIT_EINSUM 30 | 31 | 32 | def get_unets(): 33 | ref_unet = UNet2DConditionModel.from_pretrained( 34 | MODEL_VERSION, 35 | subfolder="unet", 36 | device_map=None, 37 | low_cpu_mem_usage=False, 38 | ) 39 | 40 | cml_unet = UNet2DConditionModelLCM.from_config(ref_unet.config).eval() 41 | cml_unet.load_state_dict(ref_unet.state_dict(), strict=False) 42 | 43 | return cml_unet, ref_unet 44 | 45 | 46 | def get_encoder_hidden_states_shape(unet_config, batch_size): 47 | text_encoder = CLIPTextModel.from_pretrained( 48 | MODEL_VERSION, subfolder="text_encoder" 49 | ) 50 | 51 | text_token_sequence_length = text_encoder.config.max_position_embeddings 52 | hidden_size = (text_encoder.config.hidden_size,) 53 | 54 | encoder_hidden_states_shape = ( 55 | batch_size, 56 | unet_config.cross_attention_dim or hidden_size, 57 | 1, 58 | text_token_sequence_length, 59 | ) 60 | 61 | return encoder_hidden_states_shape 62 | 63 | 64 | def get_scheduler(): 65 | scheduler = LCMScheduler.from_pretrained(MODEL_VERSION, subfolder="scheduler") 66 | scheduler.set_timesteps(50, get_torch_device(), 50) 67 | return scheduler 68 | 69 | 70 | def get_coreml_inputs(sample_inputs): 71 | coreml_sample_unet_inputs = { 72 | k: v.numpy().astype(np.float16) for k, v in sample_inputs.items() 73 | } 74 | return [ 75 | ct.TensorType( 76 | name=k, 77 | shape=v.shape, 78 | dtype=v.numpy().dtype if isinstance(v, torch.Tensor) else v.dtype, 79 | ) 80 | for k, v in coreml_sample_unet_inputs.items() 81 | ] 82 | 83 | 84 | def load_coreml_model(out_path): 85 | logger.info(f"Loading model from {out_path}") 86 | 87 | start = time.time() 88 | coreml_model = ct.models.MLModel(out_path) 89 | logger.info(f"Loading {out_path} took {time.time() - start:.1f} seconds") 90 | 91 | return coreml_model 92 | 93 | 94 | def convert_to_coreml( 95 | submodule_name, torchscript_module, sample_inputs, output_names, out_path 96 | ): 97 | if os.path.exists(out_path): 98 | logger.info(f"Skipping export because {out_path} already exists") 99 | coreml_model = load_coreml_model(out_path) 100 | else: 101 | logger.info(f"Converting {submodule_name} to CoreML..") 102 | coreml_model = ct.convert( 103 | torchscript_module, 104 | convert_to="mlprogram", 105 | minimum_deployment_target=ct.target.macOS13, 106 | inputs=sample_inputs, 107 | outputs=[ 108 | ct.TensorType(name=name, dtype=np.float32) for name in output_names 109 | ], 110 | skip_model_load=True, 111 | ) 112 | 113 | del torchscript_module 114 | gc.collect() 115 | 116 | return coreml_model 117 | 118 | 119 | def get_out_path(submodule_name, model_name): 120 | fname = f"{model_name}_{submodule_name}.mlpackage" 121 | unet_path = get_folder_paths(submodule_name)[0] 122 | out_path = os.path.join(unet_path, fname) 123 | return out_path 124 | 125 | 126 | def compile_coreml_model(source_model_path, output_dir, final_name): 127 | """Compiles Core ML models using the coremlcompiler utility from Xcode toolchain""" 128 | target_path = os.path.join(output_dir, f"{final_name}.mlmodelc") 129 | if os.path.exists(target_path): 130 | logger.warning(f"Found existing compiled model at {target_path}! 
Skipping..") 131 | return target_path 132 | 133 | logger.info(f"Compiling {source_model_path}") 134 | source_model_name = os.path.basename(os.path.splitext(source_model_path)[0]) 135 | 136 | os.system(f"xcrun coremlcompiler compile {source_model_path} {output_dir}") 137 | compiled_output = os.path.join(output_dir, f"{source_model_name}.mlmodelc") 138 | shutil.move(compiled_output, target_path) 139 | 140 | return target_path 141 | 142 | 143 | def get_sample_input(batch_size, encoder_hidden_states_shape, sample_shape, scheduler): 144 | sample_unet_inputs = dict( 145 | [ 146 | ("sample", torch.rand(*sample_shape)), 147 | ( 148 | "timestep", 149 | torch.tensor([scheduler.timesteps[0].item()] * batch_size).to( 150 | torch.float32 151 | ), 152 | ), 153 | ("encoder_hidden_states", torch.rand(*encoder_hidden_states_shape)), 154 | ("timestep_cond", torch.randn(batch_size, 256).to(torch.float32)), 155 | ] 156 | ) 157 | return sample_unet_inputs 158 | 159 | 160 | def get_unet_inputs_spec(sample_unet_inputs): 161 | sample_unet_inputs_spec = { 162 | k: (v.shape, v.dtype) for k, v in sample_unet_inputs.items() 163 | } 164 | return sample_unet_inputs_spec 165 | 166 | 167 | def add_cnet_support(sample_shape, reference_unet): 168 | from python_coreml_stable_diffusion.unet import calculate_conv2d_output_shape 169 | 170 | additional_residuals_shapes = [] 171 | 172 | batch_size = sample_shape[0] 173 | h, w = sample_shape[2:] 174 | 175 | # conv_in 176 | out_h, out_w = calculate_conv2d_output_shape( 177 | h, 178 | w, 179 | reference_unet.conv_in, 180 | ) 181 | additional_residuals_shapes.append( 182 | (batch_size, reference_unet.conv_in.out_channels, out_h, out_w) 183 | ) 184 | 185 | # down_blocks 186 | for down_block in reference_unet.down_blocks: 187 | additional_residuals_shapes += [ 188 | (batch_size, resnet.out_channels, out_h, out_w) 189 | for resnet in down_block.resnets 190 | ] 191 | if hasattr(down_block, "downsamplers") and down_block.downsamplers is not None: 192 | for downsampler in down_block.downsamplers: 193 | out_h, out_w = calculate_conv2d_output_shape( 194 | out_h, out_w, downsampler.conv 195 | ) 196 | additional_residuals_shapes.append( 197 | ( 198 | batch_size, 199 | down_block.downsamplers[-1].conv.out_channels, 200 | out_h, 201 | out_w, 202 | ) 203 | ) 204 | 205 | # mid_block 206 | additional_residuals_shapes.append( 207 | (batch_size, reference_unet.mid_block.resnets[-1].out_channels, out_h, out_w) 208 | ) 209 | 210 | additional_inputs = {} 211 | for i, shape in enumerate(additional_residuals_shapes): 212 | sample_residual_input = torch.rand(*shape) 213 | additional_inputs[f"additional_residual_{i}"] = sample_residual_input 214 | 215 | return additional_inputs 216 | 217 | 218 | def convert( 219 | out_path: str, 220 | batch_size: int = 1, 221 | sample_size: tuple[int, int] = (64, 64), 222 | controlnet_support: bool = False, 223 | lora_paths: list[str] = None, 224 | ): 225 | lora_paths = lora_paths or [] 226 | coreml_unet, ref_unet = get_unets() 227 | 228 | for lora_path in lora_paths: 229 | lora_sd, network_alphas = LoraLoaderMixin.lora_state_dict(lora_path) 230 | LoraLoaderMixin.load_lora_into_unet(lora_sd, network_alphas, ref_unet) 231 | ref_unet.fuse_lora() 232 | 233 | sample_shape = ( 234 | batch_size, # B 235 | ref_unet.config.in_channels, # C 236 | sample_size[0], # H 237 | sample_size[1], # W 238 | ) 239 | 240 | encoder_hidden_states_shape = get_encoder_hidden_states_shape( 241 | ref_unet.config, batch_size 242 | ) 243 | 244 | scheduler = get_scheduler() 245 | 246 | sample_inputs = 
get_sample_input( 247 | batch_size, encoder_hidden_states_shape, sample_shape, scheduler 248 | ) 249 | 250 | if controlnet_support: 251 | sample_inputs |= add_cnet_support(sample_shape, ref_unet) 252 | 253 | sample_inputs_spec = get_unet_inputs_spec(sample_inputs) 254 | 255 | logger.info(f"Sample UNet inputs spec: {sample_inputs_spec}") 256 | logger.info("JIT tracing..") 257 | traced_unet = torch.jit.trace( 258 | coreml_unet, example_inputs=list(sample_inputs.values()) 259 | ) 260 | logger.info("Done.") 261 | 262 | coreml_sample_inputs = get_coreml_inputs(sample_inputs) 263 | 264 | coreml_unet = convert_to_coreml( 265 | "unet", traced_unet, coreml_sample_inputs, ["noise_pred"], out_path 266 | ) 267 | 268 | del traced_unet 269 | gc.collect() 270 | 271 | coreml_unet.save(out_path) 272 | logger.info(f"Saved unet into {out_path}") 273 | 274 | 275 | def compile_model(out_path, out_name): 276 | # Compile the model 277 | target_path = compile_coreml_model( 278 | out_path, get_folder_paths("unet")[0], f"{out_name}_unet" 279 | ) 280 | logger.info(f"Compiled {out_path} to {target_path}") 281 | return target_path 282 | 283 | 284 | if __name__ == "__main__": 285 | h = 512 286 | w = 512 287 | sample_size = (h // 8, w // 8) 288 | batch_size = 4 289 | 290 | cn_support_str = "_cn" if True else "" 291 | 292 | out_name = f"{MODEL_NAME}_{batch_size}x{w}x{h}{cn_support_str}" 293 | 294 | out_path = get_out_path("unet", f"{out_name}") 295 | if not os.path.exists(out_path): 296 | convert(out_path=out_path, sample_size=sample_size, batch_size=batch_size) 297 | compile_model(out_path=out_path, out_name=out_name) 298 | -------------------------------------------------------------------------------- /coreml_suite/lcm/nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from coremltools import ComputeUnit 4 | from python_coreml_stable_diffusion.coreml_model import CoreMLModel 5 | 6 | from coreml_suite import COREML_NODE 7 | from coreml_suite.lcm import converter as lcm_converter 8 | 9 | 10 | class COREML_CONVERT_LCM(COREML_NODE): 11 | """Converts a LCM model to Core ML.""" 12 | 13 | @classmethod 14 | def INPUT_TYPES(cls): 15 | return { 16 | "required": { 17 | "height": ("INT", {"default": 512, "min": 512, "max": 768, "step": 8}), 18 | "width": ("INT", {"default": 512, "min": 512, "max": 768, "step": 8}), 19 | "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}), 20 | "compute_unit": ( 21 | [ 22 | ComputeUnit.CPU_AND_NE.name, 23 | ComputeUnit.CPU_AND_GPU.name, 24 | ComputeUnit.ALL.name, 25 | ComputeUnit.CPU_ONLY.name, 26 | ], 27 | ), 28 | "controlnet_support": ("BOOLEAN", {"default": False}), 29 | } 30 | } 31 | 32 | RETURN_TYPES = ("COREML_UNET",) 33 | RETURN_NAMES = ("coreml_model",) 34 | FUNCTION = "convert" 35 | 36 | def convert(self, height, width, batch_size, compute_unit, controlnet_support): 37 | """Converts a LCM model to Core ML. 38 | 39 | Args: 40 | height (int): Height of the target image. 41 | width (int): Width of the target image. 42 | batch_size (int): Batch size. 43 | compute_unit (str): Compute unit to use when loading the model. 44 | 45 | Returns: 46 | coreml_model: The converted Core ML model. 47 | 48 | The converted model is also saved to "models/unet" directory and 49 | can be loaded with the "LCMCoreMLLoaderUNet" node. 
50 | """ 51 | h = height 52 | w = width 53 | sample_size = (h // 8, w // 8) 54 | batch_size = batch_size 55 | cn_support_str = "_cn" if controlnet_support else "" 56 | 57 | out_name = f"{lcm_converter.MODEL_NAME}_{batch_size}x{w}x{h}{cn_support_str}" 58 | 59 | out_path = lcm_converter.get_out_path("unet", f"{out_name}") 60 | 61 | if not os.path.exists(out_path): 62 | lcm_converter.convert( 63 | out_path=out_path, 64 | sample_size=sample_size, 65 | batch_size=batch_size, 66 | controlnet_support=controlnet_support, 67 | ) 68 | target_path = lcm_converter.compile_model(out_path=out_path, out_name=out_name) 69 | 70 | return (CoreMLModel(target_path, compute_unit, "compiled"),) 71 | -------------------------------------------------------------------------------- /coreml_suite/lcm/unet.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from python_coreml_stable_diffusion.unet import UNet2DConditionModel, TimestepEmbedding 3 | 4 | 5 | class UNet2DConditionModelLCM(UNet2DConditionModel): 6 | def __init__( 7 | self, 8 | time_cond_proj_dim=None, 9 | **kwargs, 10 | ): 11 | super().__init__(**kwargs) 12 | timestep_input_dim = self.config.block_out_channels[0] 13 | time_embed_dim = self.config.block_out_channels[0] * 4 14 | 15 | time_embedding = TimestepEmbedding( 16 | timestep_input_dim, time_embed_dim, cond_proj_dim=time_cond_proj_dim 17 | ) 18 | self.time_embedding = time_embedding 19 | 20 | @overrides(check_signature=False) 21 | def forward( 22 | self, 23 | sample, 24 | timestep, 25 | encoder_hidden_states, 26 | timestep_cond, 27 | *additional_residuals, 28 | ): 29 | # 0. Project (or look-up) time embeddings 30 | t_emb = self.time_proj(timestep) 31 | emb = self.time_embedding(t_emb, timestep_cond) 32 | 33 | # 1. center input if necessary 34 | if self.config.center_input_sample: 35 | sample = 2 * sample - 1.0 36 | 37 | # 2. pre-process 38 | sample = self.conv_in(sample) 39 | 40 | # 3. down 41 | down_block_res_samples = (sample,) 42 | for downsample_block in self.down_blocks: 43 | if ( 44 | hasattr(downsample_block, "attentions") 45 | and downsample_block.attentions is not None 46 | ): 47 | sample, res_samples = downsample_block( 48 | hidden_states=sample, 49 | temb=emb, 50 | encoder_hidden_states=encoder_hidden_states, 51 | ) 52 | else: 53 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 54 | 55 | down_block_res_samples += res_samples 56 | 57 | if additional_residuals: 58 | new_down_block_res_samples = () 59 | for i, down_block_res_sample in enumerate(down_block_res_samples): 60 | down_block_res_sample = down_block_res_sample + additional_residuals[i] 61 | new_down_block_res_samples += (down_block_res_sample,) 62 | down_block_res_samples = new_down_block_res_samples 63 | 64 | # 4. mid 65 | sample = self.mid_block( 66 | sample, emb, encoder_hidden_states=encoder_hidden_states 67 | ) 68 | 69 | if additional_residuals: 70 | sample = sample + additional_residuals[-1] 71 | 72 | # 5. 
up 73 | for upsample_block in self.up_blocks: 74 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 75 | down_block_res_samples = down_block_res_samples[ 76 | : -len(upsample_block.resnets) 77 | ] 78 | 79 | if ( 80 | hasattr(upsample_block, "attentions") 81 | and upsample_block.attentions is not None 82 | ): 83 | sample = upsample_block( 84 | hidden_states=sample, 85 | temb=emb, 86 | res_hidden_states_tuple=res_samples, 87 | encoder_hidden_states=encoder_hidden_states, 88 | ) 89 | else: 90 | sample = upsample_block( 91 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples 92 | ) 93 | 94 | # 6. post-process 95 | sample = self.conv_norm_out(sample) 96 | sample = self.conv_act(sample) 97 | sample = self.conv_out(sample) 98 | 99 | return (sample,) 100 | -------------------------------------------------------------------------------- /coreml_suite/lcm/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from comfy.model_management import get_torch_device 4 | from comfy_extras.nodes_model_advanced import ModelSamplingDiscreteDistilled, LCM 5 | 6 | 7 | def is_lcm(coreml_model): 8 | return "timestep_cond" in coreml_model.expected_inputs 9 | 10 | 11 | def get_w_embedding(w, embedding_dim=512, dtype=torch.float32): 12 | assert len(w.shape) == 1 13 | w = w * 1000.0 14 | 15 | half_dim = embedding_dim // 2 16 | emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) 17 | emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) 18 | emb = w.to(dtype)[:, None] * emb[None, :] 19 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 20 | if embedding_dim % 2 == 1: # zero pad 21 | emb = torch.nn.functional.pad(emb, (0, 1)) 22 | assert emb.shape == (w.shape[0], embedding_dim) 23 | return emb 24 | 25 | 26 | def model_function_wrapper(w_embedding): 27 | def wrapper(model_function, params): 28 | x = params["input"] 29 | t = params["timestep"] 30 | c = params["c"] 31 | 32 | context = c.get("c_crossattn") 33 | 34 | if context is None: 35 | return torch.zeros_like(x) 36 | 37 | return model_function(x, t, **c, timestep_cond=w_embedding) 38 | 39 | return wrapper 40 | 41 | 42 | def lcm_patch(model): 43 | m = model.clone() 44 | sampling_type = LCM 45 | sampling_base = ModelSamplingDiscreteDistilled 46 | 47 | class ModelSamplingAdvanced(sampling_base, sampling_type): 48 | pass 49 | 50 | model_sampling = ModelSamplingAdvanced() 51 | m.add_object_patch("model_sampling", model_sampling) 52 | 53 | return m 54 | 55 | 56 | def add_lcm_model_options(model_patcher, cfg, latent_image): 57 | mp = model_patcher.clone() 58 | 59 | latent = latent_image["samples"].to(get_torch_device()) 60 | batch_size = latent.shape[0] 61 | dtype = latent.dtype 62 | device = get_torch_device() 63 | 64 | w = torch.tensor(cfg).repeat(batch_size) 65 | w_embedding = get_w_embedding(w, embedding_dim=256).to(device=device, dtype=dtype) 66 | 67 | model_options = { 68 | "model_function_wrapper": model_function_wrapper(w_embedding), 69 | "sampler_cfg_function": lambda x: x["cond"].to(device), 70 | } 71 | mp.model_options |= model_options 72 | 73 | return mp 74 | -------------------------------------------------------------------------------- /coreml_suite/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.basicConfig() 4 | logger = logging.getLogger(__name__) 5 | logger.setLevel(logging.INFO) 6 | -------------------------------------------------------------------------------- 
/coreml_suite/models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from comfy import model_base 5 | from comfy.model_management import get_torch_device 6 | from comfy.model_patcher import ModelPatcher 7 | from coreml_suite.config import get_model_config, ModelVersion 8 | from coreml_suite.controlnet import extract_residual_kwargs, chunk_control 9 | from coreml_suite.latents import chunk_batch, merge_chunks 10 | from coreml_suite.lcm.utils import is_lcm 11 | from coreml_suite.logger import logger 12 | 13 | 14 | class CoreMLModelWrapper: 15 | def __init__(self, coreml_model): 16 | self.coreml_model = coreml_model 17 | self.dtype = torch.float16 18 | 19 | def __call__(self, x, t, context, control, transformer_options=None, **kwargs): 20 | inputs = CoreMLInputs(x, t, context, control, **kwargs) 21 | input_list = inputs.chunks(self.expected_inputs) 22 | 23 | chunked_out = [ 24 | self.get_torch_outputs( 25 | self.coreml_model(**input_kwargs.coreml_kwargs(self.expected_inputs)), 26 | x.device, 27 | ) 28 | for input_kwargs in input_list 29 | ] 30 | merged_out = merge_chunks(chunked_out, x.shape) 31 | 32 | return merged_out 33 | 34 | @staticmethod 35 | def get_torch_outputs(model_output, device): 36 | return torch.from_numpy(model_output["noise_pred"]).to(device) 37 | 38 | @property 39 | def expected_inputs(self): 40 | return self.coreml_model.expected_inputs 41 | 42 | @property 43 | def is_lcm(self): 44 | return is_lcm(self.coreml_model) 45 | 46 | @property 47 | def is_sdxl_base(self): 48 | return is_sdxl_base(self.coreml_model) 49 | 50 | @property 51 | def is_sdxl_refiner(self): 52 | return is_sdxl_refiner(self.coreml_model) 53 | 54 | @property 55 | def config(self): 56 | if self.is_sdxl_base: 57 | return get_model_config(ModelVersion.SDXL) 58 | 59 | if self.is_sdxl_refiner: 60 | return get_model_config(ModelVersion.SDXL_REFINER) 61 | 62 | return get_model_config(ModelVersion.SD15) 63 | 64 | 65 | class CoreMLModelWrapperLCM(CoreMLModelWrapper): 66 | def __init__(self, coreml_model): 67 | super().__init__(coreml_model) 68 | self.config = None 69 | 70 | 71 | class CoreMLInputs: 72 | def __init__(self, x, t, context, control, **kwargs): 73 | self.x = x 74 | self.t = t 75 | self.context = context 76 | self.control = control 77 | self.time_ids = kwargs.get("time_ids") 78 | self.text_embeds = kwargs.get("text_embeds") 79 | self.ts_cond = kwargs.get("timestep_cond") 80 | 81 | def coreml_kwargs(self, expected_inputs): 82 | sample = self.x.cpu().numpy().astype(np.float16) 83 | 84 | context = self.context.cpu().numpy().astype(np.float16) 85 | context = context.transpose(0, 2, 1)[:, :, None, :] 86 | 87 | t = self.t.cpu().numpy().astype(np.float16) 88 | 89 | model_input_kwargs = { 90 | "sample": sample, 91 | "encoder_hidden_states": context, 92 | "timestep": t, 93 | } 94 | residual_kwargs = extract_residual_kwargs(expected_inputs, self.control) 95 | model_input_kwargs |= residual_kwargs 96 | 97 | # LCM 98 | if self.ts_cond is not None: 99 | model_input_kwargs["timestep_cond"] = ( 100 | self.ts_cond.cpu().numpy().astype(np.float16) 101 | ) 102 | 103 | # SDXL 104 | if "text_embeds" in expected_inputs: 105 | model_input_kwargs["text_embeds"] = ( 106 | self.text_embeds.cpu().numpy().astype(np.float16) 107 | ) 108 | if "time_ids" in expected_inputs: 109 | model_input_kwargs["time_ids"] = ( 110 | self.time_ids.cpu().numpy().astype(np.float16) 111 | ) 112 | 113 | return model_input_kwargs 114 | 115 | def chunks(self, 
expected_inputs): 116 | sample_shape = expected_inputs["sample"]["shape"] 117 | timestep_shape = expected_inputs["timestep"]["shape"] 118 | hidden_shape = expected_inputs["encoder_hidden_states"]["shape"] 119 | context_shape = (hidden_shape[0], hidden_shape[3], hidden_shape[1]) 120 | 121 | chunked_x = chunk_batch(self.x, sample_shape) 122 | ts = list(torch.full((len(chunked_x), timestep_shape[0]), self.t[0])) 123 | chunked_context = chunk_batch(self.context, context_shape) 124 | 125 | chunked_control = [None] * len(chunked_x) 126 | if self.control is not None: 127 | chunked_control = chunk_control(self.control, sample_shape[0]) 128 | 129 | chunked_ts_cond = [None] * len(chunked_x) 130 | if self.ts_cond is not None: 131 | ts_cond_shape = expected_inputs["timestep_cond"]["shape"] 132 | chunked_ts_cond = chunk_batch(self.ts_cond, ts_cond_shape) 133 | 134 | chunked_time_ids = [None] * len(chunked_x) 135 | if expected_inputs.get("time_ids") is not None: 136 | time_ids_shape = expected_inputs["time_ids"]["shape"] 137 | if self.time_ids is None: 138 | self.time_ids = torch.zeros(len(chunked_x), *time_ids_shape[1:]).to( 139 | self.x.device 140 | ) 141 | chunked_time_ids = chunk_batch(self.time_ids, time_ids_shape) 142 | 143 | chunked_text_embeds = [None] * len(chunked_x) 144 | if expected_inputs.get("text_embeds") is not None: 145 | text_embeds_shape = expected_inputs["text_embeds"]["shape"] 146 | if self.text_embeds is None: 147 | self.text_embeds = torch.zeros( 148 | len(chunked_x), *text_embeds_shape[1:] 149 | ).to(self.x.device) 150 | chunked_text_embeds = chunk_batch(self.text_embeds, text_embeds_shape) 151 | 152 | return [ 153 | CoreMLInputs( 154 | x, 155 | t, 156 | context, 157 | control, 158 | timestep_cond=ts_cond, 159 | time_ids=time_ids, 160 | text_embeds=text_embeds, 161 | ) 162 | for x, t, context, control, ts_cond, time_ids, text_embeds in zip( 163 | chunked_x, 164 | ts, 165 | chunked_context, 166 | chunked_control, 167 | chunked_ts_cond, 168 | chunked_time_ids, 169 | chunked_text_embeds, 170 | ) 171 | ] 172 | 173 | 174 | def is_sdxl(coreml_model): 175 | return ( 176 | "time_ids" in coreml_model.expected_inputs 177 | and "text_embeds" in coreml_model.expected_inputs 178 | ) 179 | 180 | 181 | def is_sdxl_base(coreml_model): 182 | return ( 183 | is_sdxl(coreml_model) 184 | and coreml_model.expected_inputs["time_ids"]["shape"][1] == 6 185 | ) 186 | 187 | 188 | def is_sdxl_refiner(coreml_model): 189 | return ( 190 | is_sdxl(coreml_model) 191 | and coreml_model.expected_inputs["time_ids"]["shape"][1] == 5 192 | ) 193 | 194 | 195 | def sdxl_model_function_wrapper(time_ids, text_embeds, refiner=False): 196 | def wrapper(model_function, params): 197 | x = params["input"] 198 | t = params["timestep"] 199 | c = params["c"] 200 | 201 | context = c.get("c_crossattn") 202 | 203 | if context is None: 204 | return torch.zeros_like(x) 205 | 206 | if refiner and context is not None: 207 | # converted refiner accepts only g clip 208 | c["c_crossattn"] = context[:, :, 768:] 209 | 210 | return model_function(x, t, **c, time_ids=time_ids, text_embeds=text_embeds) 211 | 212 | return wrapper 213 | 214 | 215 | def add_sdxl_model_options(model_patcher, positive, negative): 216 | mp = model_patcher.clone() 217 | 218 | pos_dict = positive[0][1] 219 | neg_dict = negative[0][1] 220 | 221 | pos_pooled = pos_dict["pooled_output"] 222 | neg_pooled = neg_dict["pooled_output"] 223 | 224 | pos_time_ids = [ 225 | pos_dict.get("height", 768), 226 | pos_dict.get("width", 768), 227 | pos_dict.get("crop_h", 0), 228 | 
pos_dict.get("crop_w", 0), 229 | ] 230 | 231 | neg_time_ids = [ 232 | neg_dict.get("height", 768), 233 | neg_dict.get("width", 768), 234 | neg_dict.get("crop_h", 0), 235 | neg_dict.get("crop_w", 0), 236 | ] 237 | 238 | if model_patcher.model.diffusion_model.is_sdxl_base: 239 | pos_time_ids += [ 240 | pos_dict.get("target_height", 768), 241 | pos_dict.get("target_width", 768), 242 | ] 243 | 244 | neg_time_ids += [ 245 | neg_dict.get("target_height", 768), 246 | neg_dict.get("target_width", 768), 247 | ] 248 | 249 | is_refiner = model_patcher.model.diffusion_model.is_sdxl_refiner 250 | if is_refiner: 251 | pos_time_ids += [ 252 | pos_dict.get("aesthetic_score", 6), 253 | ] 254 | 255 | neg_time_ids += [ 256 | neg_dict.get("aesthetic_score", 2.5), 257 | ] 258 | 259 | time_ids = torch.tensor([pos_time_ids, neg_time_ids]) 260 | text_embeds = torch.cat((pos_pooled, neg_pooled)) 261 | 262 | model_options = { 263 | "model_function_wrapper": sdxl_model_function_wrapper( 264 | time_ids, text_embeds, is_refiner 265 | ), 266 | } 267 | mp.model_options |= model_options 268 | 269 | return mp 270 | 271 | 272 | def get_latent_image(coreml_model, latent_image): 273 | if latent_image is not None: 274 | return latent_image 275 | 276 | logger.warning("No latent image provided, using empty tensor.") 277 | expected = coreml_model.expected_inputs["sample"]["shape"] 278 | batch_size = max(expected[0] // 2, 1) 279 | latent_image = {"samples": torch.zeros(batch_size, *expected[1:])} 280 | return latent_image 281 | 282 | 283 | def get_model_patcher(coreml_model): 284 | wrapped_model = CoreMLModelWrapper(coreml_model) 285 | 286 | if wrapped_model.is_sdxl_base: 287 | model = model_base.SDXL(wrapped_model.config, device=get_torch_device()) 288 | elif wrapped_model.is_sdxl_refiner: 289 | model = model_base.SDXLRefiner(wrapped_model.config, device=get_torch_device()) 290 | else: 291 | model = model_base.BaseModel(wrapped_model.config, device=get_torch_device()) 292 | 293 | model.diffusion_model = wrapped_model 294 | model_patcher = ModelPatcher(model, get_torch_device(), None) 295 | return model_patcher 296 | -------------------------------------------------------------------------------- /coreml_suite/nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from coremltools import ComputeUnit 4 | from python_coreml_stable_diffusion.coreml_model import CoreMLModel 5 | from python_coreml_stable_diffusion.unet import AttentionImplementations 6 | 7 | import folder_paths 8 | from coreml_suite import COREML_NODE 9 | from coreml_suite import converter 10 | from coreml_suite.config import ModelVersion 11 | from coreml_suite.lcm.utils import add_lcm_model_options, lcm_patch, is_lcm 12 | from coreml_suite.logger import logger 13 | from nodes import KSampler, LoraLoader, KSamplerAdvanced 14 | 15 | from coreml_suite.models import ( 16 | add_sdxl_model_options, 17 | is_sdxl, 18 | get_model_patcher, 19 | get_latent_image, 20 | ) 21 | 22 | 23 | class CoreMLSampler(COREML_NODE, KSampler): 24 | @classmethod 25 | def INPUT_TYPES(s): 26 | old_required = KSampler.INPUT_TYPES()["required"].copy() 27 | old_required.pop("model") 28 | old_required.pop("negative") 29 | old_required.pop("latent_image") 30 | new_required = {"coreml_model": ("COREML_UNET",)} 31 | return { 32 | "required": new_required | old_required, 33 | "optional": {"negative": ("CONDITIONING",), "latent_image": ("LATENT",)}, 34 | } 35 | 36 | def sample( 37 | self, 38 | coreml_model, 39 | seed, 40 | steps, 41 | cfg, 42 | 
sampler_name, 43 | scheduler, 44 | positive, 45 | negative=None, 46 | latent_image=None, 47 | denoise=1.0, 48 | ): 49 | model_patcher = get_model_patcher(coreml_model) 50 | latent_image = get_latent_image(coreml_model, latent_image) 51 | 52 | if is_lcm(coreml_model): 53 | negative = [[None, {}]] 54 | positive[0][1]["control_apply_to_uncond"] = False 55 | model_patcher = add_lcm_model_options(model_patcher, cfg, latent_image) 56 | model_patcher = lcm_patch(model_patcher) 57 | else: 58 | assert ( 59 | negative is not None 60 | ), "Negative conditioning is optional only for LCM models." 61 | 62 | if is_sdxl(coreml_model): 63 | model_patcher = add_sdxl_model_options(model_patcher, positive, negative) 64 | 65 | return super().sample( 66 | model_patcher, 67 | seed, 68 | steps, 69 | cfg, 70 | sampler_name, 71 | scheduler, 72 | positive, 73 | negative, 74 | latent_image, 75 | denoise, 76 | ) 77 | 78 | 79 | class CoreMLSamplerAdvanced(COREML_NODE, KSamplerAdvanced): 80 | @classmethod 81 | def INPUT_TYPES(s): 82 | old_required = KSamplerAdvanced.INPUT_TYPES()["required"].copy() 83 | old_required.pop("model") 84 | old_required.pop("negative") 85 | old_required.pop("latent_image") 86 | new_required = {"coreml_model": ("COREML_UNET",)} 87 | return { 88 | "required": new_required | old_required, 89 | "optional": {"negative": ("CONDITIONING",), "latent_image": ("LATENT",)}, 90 | } 91 | 92 | def sample( 93 | self, 94 | coreml_model, 95 | add_noise, 96 | noise_seed, 97 | steps, 98 | cfg, 99 | sampler_name, 100 | scheduler, 101 | positive, 102 | start_at_step, 103 | end_at_step, 104 | return_with_leftover_noise, 105 | negative=None, 106 | latent_image=None, 107 | denoise=1.0, 108 | ): 109 | model_patcher = get_model_patcher(coreml_model) 110 | latent_image = get_latent_image(coreml_model, latent_image) 111 | 112 | if is_lcm(coreml_model): 113 | negative = [[None, {}]] 114 | positive[0][1]["control_apply_to_uncond"] = False 115 | model_patcher = add_lcm_model_options(model_patcher, cfg, latent_image) 116 | model_patcher = lcm_patch(model_patcher) 117 | else: 118 | assert ( 119 | negative is not None 120 | ), "Negative conditioning is optional only for LCM models." 
121 | 122 | if is_sdxl(coreml_model): 123 | model_patcher = add_sdxl_model_options(model_patcher, positive, negative) 124 | 125 | return super().sample( 126 | model_patcher, 127 | add_noise, 128 | noise_seed, 129 | steps, 130 | cfg, 131 | sampler_name, 132 | scheduler, 133 | positive, 134 | negative, 135 | latent_image, 136 | start_at_step, 137 | end_at_step, 138 | return_with_leftover_noise, 139 | denoise, 140 | ) 141 | 142 | 143 | class CoreMLLoader(COREML_NODE): 144 | PACKAGE_DIRNAME = "" 145 | 146 | @classmethod 147 | def INPUT_TYPES(s): 148 | return { 149 | "required": { 150 | "coreml_name": (list(s.coreml_filenames().keys()),), 151 | "compute_unit": ( 152 | [ 153 | ComputeUnit.CPU_AND_NE.name, 154 | ComputeUnit.CPU_AND_GPU.name, 155 | ComputeUnit.ALL.name, 156 | ComputeUnit.CPU_ONLY.name, 157 | ], 158 | ), 159 | } 160 | } 161 | 162 | FUNCTION = "load" 163 | 164 | @classmethod 165 | def coreml_filenames(cls): 166 | extensions = (".mlmodelc", ".mlpackage") 167 | all_paths = folder_paths.get_filename_list_(cls.PACKAGE_DIRNAME)[1] 168 | coreml_paths = folder_paths.filter_files_extensions(all_paths, extensions) 169 | 170 | return {os.path.split(p)[-1]: p for p in coreml_paths} 171 | 172 | def load(self, coreml_name, compute_unit): 173 | logger.info(f"Loading {coreml_name} to {compute_unit}") 174 | 175 | coreml_path = self.coreml_filenames()[coreml_name] 176 | 177 | sources = "compiled" if coreml_name.endswith(".mlmodelc") else "packages" 178 | 179 | return (CoreMLModel(coreml_path, compute_unit, sources),) 180 | 181 | 182 | class CoreMLLoaderUNet(CoreMLLoader): 183 | PACKAGE_DIRNAME = "unet" 184 | RETURN_TYPES = ("COREML_UNET",) 185 | RETURN_NAMES = ("coreml_model",) 186 | 187 | 188 | class CoreMLModelAdapter(COREML_NODE): 189 | """ 190 | Adapter Node to use CoreML models as Comfy models. This is an experimental 191 | feature and may not work as expected. 
192 | """ 193 | 194 | @classmethod 195 | def INPUT_TYPES(s): 196 | return { 197 | "required": { 198 | "coreml_model": ("COREML_UNET",), 199 | } 200 | } 201 | 202 | RETURN_TYPES = ("MODEL",) 203 | 204 | FUNCTION = "wrap" 205 | CATEGORY = "Core ML Suite" 206 | 207 | def wrap(self, coreml_model): 208 | model_patcher = get_model_patcher(coreml_model) 209 | return (model_patcher,) 210 | 211 | 212 | class CoreMLConverter(COREML_NODE): 213 | """Converts a LCM model to Core ML.""" 214 | 215 | @classmethod 216 | def INPUT_TYPES(cls): 217 | return { 218 | "required": { 219 | "ckpt_name": (folder_paths.get_filename_list("checkpoints"),), 220 | "model_version": ( 221 | [ 222 | ModelVersion.SD15.name, 223 | ModelVersion.SDXL.name, 224 | ], 225 | ), 226 | "height": ("INT", {"default": 512, "min": 256, "max": 2048, "step": 8}), 227 | "width": ("INT", {"default": 512, "min": 256, "max": 2048, "step": 8}), 228 | "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}), 229 | "attention_implementation": ( 230 | [ 231 | AttentionImplementations.SPLIT_EINSUM.name, 232 | AttentionImplementations.SPLIT_EINSUM_V2.name, 233 | AttentionImplementations.ORIGINAL.name, 234 | ], 235 | ), 236 | "compute_unit": ( 237 | [ 238 | ComputeUnit.CPU_AND_NE.name, 239 | ComputeUnit.CPU_AND_GPU.name, 240 | ComputeUnit.ALL.name, 241 | ComputeUnit.CPU_ONLY.name, 242 | ], 243 | ), 244 | "controlnet_support": ("BOOLEAN", {"default": False}), 245 | }, 246 | "optional": { 247 | "lora_params": ("LORA_PARAMS",), 248 | }, 249 | } 250 | 251 | RETURN_TYPES = ("COREML_UNET",) 252 | RETURN_NAMES = ("coreml_model",) 253 | FUNCTION = "convert" 254 | 255 | def convert( 256 | self, 257 | ckpt_name, 258 | model_version, 259 | height, 260 | width, 261 | batch_size, 262 | attention_implementation, 263 | compute_unit, 264 | controlnet_support, 265 | lora_params=None, 266 | ): 267 | """Converts a LCM model to Core ML. 268 | 269 | Args: 270 | height (int): Height of the target image. 271 | width (int): Width of the target image. 272 | batch_size (int): Batch size. 273 | compute_unit (str): Compute unit to use when loading the model. 274 | 275 | Returns: 276 | coreml_model: The converted Core ML model. 277 | 278 | The converted model is also saved to "models/unet" directory and 279 | can be loaded with the "LCMCoreMLLoaderUNet" node. 
280 | """ 281 | model_version = ModelVersion[model_version] 282 | 283 | lora_params = lora_params or {} 284 | lora_params = [(k, v[0]) for k, v in lora_params.items()] 285 | lora_params = sorted(lora_params, key=lambda lora: lora[0]) 286 | lora_weights = [(self.lora_path(lora[0]), lora[1]) for lora in lora_params] 287 | 288 | h = height 289 | w = width 290 | sample_size = (h // 8, w // 8) 291 | batch_size = batch_size 292 | cn_support_str = "_cn" if controlnet_support else "" 293 | lora_str = ( 294 | "_" + "_".join(lora_param[0].split(".")[0] for lora_param in lora_params) 295 | if lora_params 296 | else "" 297 | ) 298 | 299 | attn_str = ( 300 | "_" 301 | + {"SPLIT_EINSUM": "se", "SPLIT_EINSUM_V2": "se2", "ORIGINAL": "orig"}[ 302 | attention_implementation 303 | ] 304 | ) 305 | 306 | out_name = f"{ckpt_name.split('.')[0]}{lora_str}_{batch_size}x{w}x{h}{cn_support_str}{attn_str}" 307 | out_name = out_name.replace(" ", "_") 308 | 309 | logger.info(f"Converting {ckpt_name} to {out_name}") 310 | logger.info(f"Batch size: {batch_size}") 311 | logger.info(f"Width: {w}, Height: {h}") 312 | logger.info(f"ControlNet support: {controlnet_support}") 313 | logger.info(f"Attention implementation: {attention_implementation}") 314 | 315 | if lora_params: 316 | logger.info(f"LoRAs used:") 317 | for lora_param in lora_params: 318 | logger.info(f" {lora_param[0]} - strength: {lora_param[1]}") 319 | 320 | unet_out_path = converter.get_out_path("unet", f"{out_name}") 321 | ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) 322 | 323 | config_filename = ckpt_name.split(".")[0] + ".yaml" 324 | config_path = folder_paths.get_full_path("configs", config_filename) 325 | if config_path: 326 | logger.info(f"Using config file {config_path}") 327 | 328 | converter.convert( 329 | ckpt_path=ckpt_path, 330 | model_version=model_version, 331 | unet_out_path=unet_out_path, 332 | sample_size=sample_size, 333 | batch_size=batch_size, 334 | controlnet_support=controlnet_support, 335 | lora_weights=lora_weights, 336 | attn_impl=attention_implementation, 337 | config_path=config_path, 338 | ) 339 | unet_target_path = converter.compile_model( 340 | out_path=unet_out_path, out_name=out_name, submodule_name="unet" 341 | ) 342 | 343 | return (CoreMLModel(unet_target_path, compute_unit, "compiled"),) 344 | 345 | @staticmethod 346 | def lora_path(lora_name): 347 | return folder_paths.get_full_path("loras", lora_name) 348 | 349 | 350 | class COREML_LOAD_LORA(COREML_NODE, LoraLoader): 351 | @classmethod 352 | def INPUT_TYPES(s): 353 | required = LoraLoader.INPUT_TYPES()["required"].copy() 354 | required.pop("model") 355 | return { 356 | "required": required, 357 | "optional": {"lora_params": ("LORA_PARAMS",)}, 358 | } 359 | 360 | RETURN_TYPES = ("CLIP", "LORA_PARAMS") 361 | RETURN_NAMES = ("CLIP", "lora_params") 362 | 363 | def load_lora( 364 | self, clip, lora_name, strength_model, strength_clip, lora_params=None 365 | ): 366 | _, lora_clip = super().load_lora( 367 | None, clip, lora_name, strength_model, strength_clip 368 | ) 369 | 370 | lora_params = lora_params or {} 371 | lora_params[lora_name] = (strength_model, strength_clip) 372 | 373 | return lora_clip, lora_params 374 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-coremlsuite" 3 | description = "This extension contains a set of custom nodes for ComfyUI that allow you to use Core ML models in your 
ComfyUI workflows." 4 | version = "1.0.1" 5 | license = { file = "LICENSE" } 6 | dependencies = ["git+https://github.com/apple/ml-stable-diffusion.git", "coremltools>=7.1", "overrides", "diffusers>=0.22", "peft>=0.6.2", "omegaconf>=2.3"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/aszc-dev/ComfyUI-CoreMLSuite" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "aszc-dev" 14 | DisplayName = "ComfyUI-CoreMLSuite" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/apple/ml-stable-diffusion.git 2 | coremltools>=7.1 3 | overrides 4 | diffusers>=0.22 5 | peft>=0.6.2 6 | omegaconf>=2.3 7 | -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aszc-dev/ComfyUI-CoreMLSuite/7678a07ed551092893a00eb2850dafca1d81566c/tests/integration/__init__.py -------------------------------------------------------------------------------- /tests/integration/test_basic_conversion_1_5.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import pytest 5 | import requests 6 | import torch 7 | from PIL import Image 8 | import numpy as np 9 | 10 | from folder_paths import get_save_image_path, get_output_directory 11 | 12 | IMAGE_PREFIX = "E2E-1.5" 13 | IMAGE_PREFIX_CML = f"{IMAGE_PREFIX}-CoreML" 14 | IMAGE_PREFIX_MPS = f"{IMAGE_PREFIX}-MPS" 15 | 16 | 17 | class OutputImageRepository: 18 | def __init__(self, name_prefix): 19 | self.name_prefix = name_prefix 20 | 21 | def list_images(self): 22 | full_output_folder, _, _, _, _ = get_save_image_path( 23 | self.name_prefix, get_output_directory(), 512, 512 24 | ) 25 | return full_output_folder, os.listdir(full_output_folder) 26 | 27 | def delete_images(self): 28 | full_output_folder, images = self.list_images() 29 | for image in images: 30 | os.remove(os.path.join(full_output_folder, image)) 31 | 32 | def get_latest_image(self, prefix): 33 | full_output_folder, images = self.list_images() 34 | for image in sorted(images, reverse=True): 35 | if image.startswith(prefix): 36 | return os.path.join(full_output_folder, image) 37 | return None 38 | 39 | 40 | @pytest.fixture(scope="function") 41 | def output_image_repository(): 42 | repo = OutputImageRepository(IMAGE_PREFIX) 43 | yield repo 44 | repo.delete_images() 45 | 46 | 47 | def test_basic_conversion_1_5(output_image_repository): 48 | with open("tests/integration/workflows/e2e-1.5-basic-conversion.json") as f: 49 | prompt = json.load(f) 50 | prompt = randomize_seed_in_prompt(prompt) 51 | queue_prompt(prompt) 52 | 53 | coreml_img_path = output_image_repository.get_latest_image(IMAGE_PREFIX_CML) 54 | mps_img_path = output_image_repository.get_latest_image(IMAGE_PREFIX_MPS) 55 | 56 | coreml_image = Image.open(coreml_img_path) 57 | mps_image = Image.open(mps_img_path) 58 | 59 | assert psnr(np.array(coreml_image), np.array(mps_image)) > 25 60 | 61 | 62 | def psnr(img1, img2): 63 | mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)  # cast to float so uint8 differences cannot wrap when squared 64 | if mse == 0: 65 | return 100 66 | PIXEL_MAX = 255.0 67 | return 20 * np.log10(PIXEL_MAX / np.sqrt(mse)) 68 | 69 | 70 | def queue_prompt(prompt: dict): 71 | p = {"prompt": prompt} 72 | data = json.dumps(p).encode("utf-8") 73 | req = 
requests.post("http://localhost:8188/prompt", data=data) 74 | assert req.status_code == 200 75 | while True: 76 | req = requests.get("http://localhost:8188/prompt") 77 | if req.json()["exec_info"]["queue_remaining"] == 0: 78 | break 79 | 80 | 81 | def randomize_seed_in_prompt(prompt): 82 | seed = torch.random.seed() 83 | prompt["3"]["inputs"]["seed"] = seed 84 | prompt["11"]["inputs"]["seed"] = seed 85 | return prompt 86 | -------------------------------------------------------------------------------- /tests/integration/workflows/e2e-1.5-basic-conversion.json: -------------------------------------------------------------------------------- 1 | { 2 | "3": { 3 | "inputs": { 4 | "seed": 0, 5 | "steps": 20, 6 | "cfg": 8, 7 | "sampler_name": "dpmpp_2m", 8 | "scheduler": "karras", 9 | "denoise": 1, 10 | "model": [ 11 | "4", 12 | 0 13 | ], 14 | "positive": [ 15 | "6", 16 | 0 17 | ], 18 | "negative": [ 19 | "7", 20 | 0 21 | ], 22 | "latent_image": [ 23 | "5", 24 | 0 25 | ] 26 | }, 27 | "class_type": "KSampler", 28 | "_meta": { 29 | "title": "KSampler" 30 | } 31 | }, 32 | "4": { 33 | "inputs": { 34 | "ckpt_name": "dreamshaper_8.safetensors" 35 | }, 36 | "class_type": "CheckpointLoaderSimple", 37 | "_meta": { 38 | "title": "Load Checkpoint" 39 | } 40 | }, 41 | "5": { 42 | "inputs": { 43 | "width": 512, 44 | "height": 512, 45 | "batch_size": 1 46 | }, 47 | "class_type": "EmptyLatentImage", 48 | "_meta": { 49 | "title": "Empty Latent Image" 50 | } 51 | }, 52 | "6": { 53 | "inputs": { 54 | "text": "beautiful scenery nature glass bottle landscape, purple galaxy bottle", 55 | "clip": [ 56 | "4", 57 | 1 58 | ] 59 | }, 60 | "class_type": "CLIPTextEncode", 61 | "_meta": { 62 | "title": "CLIP Text Encode (Prompt)" 63 | } 64 | }, 65 | "7": { 66 | "inputs": { 67 | "text": "text, watermark", 68 | "clip": [ 69 | "4", 70 | 1 71 | ] 72 | }, 73 | "class_type": "CLIPTextEncode", 74 | "_meta": { 75 | "title": "CLIP Text Encode (Prompt)" 76 | } 77 | }, 78 | "8": { 79 | "inputs": { 80 | "samples": [ 81 | "3", 82 | 0 83 | ], 84 | "vae": [ 85 | "4", 86 | 2 87 | ] 88 | }, 89 | "class_type": "VAEDecode", 90 | "_meta": { 91 | "title": "VAE Decode" 92 | } 93 | }, 94 | "9": { 95 | "inputs": { 96 | "filename_prefix": "E2E-1.5-MPS", 97 | "images": [ 98 | "8", 99 | 0 100 | ] 101 | }, 102 | "class_type": "SaveImage", 103 | "_meta": { 104 | "title": "Save Image" 105 | } 106 | }, 107 | "10": { 108 | "inputs": { 109 | "ckpt_name": "dreamshaper_8.safetensors", 110 | "model_version": "SD15", 111 | "height": 512, 112 | "width": 512, 113 | "batch_size": 1, 114 | "attention_implementation": "SPLIT_EINSUM", 115 | "compute_unit": "CPU_AND_NE", 116 | "controlnet_support": false 117 | }, 118 | "class_type": "Core ML Converter", 119 | "_meta": { 120 | "title": "Convert Checkpoint to Core ML" 121 | } 122 | }, 123 | "11": { 124 | "inputs": { 125 | "seed": 0, 126 | "steps": 20, 127 | "cfg": 8, 128 | "sampler_name": "dpmpp_2m", 129 | "scheduler": "karras", 130 | "denoise": 1, 131 | "coreml_model": [ 132 | "10", 133 | 0 134 | ], 135 | "positive": [ 136 | "6", 137 | 0 138 | ], 139 | "negative": [ 140 | "7", 141 | 0 142 | ], 143 | "latent_image": [ 144 | "5", 145 | 0 146 | ] 147 | }, 148 | "class_type": "CoreMLSampler", 149 | "_meta": { 150 | "title": "Core ML Sampler" 151 | } 152 | }, 153 | "13": { 154 | "inputs": { 155 | "samples": [ 156 | "11", 157 | 0 158 | ], 159 | "vae": [ 160 | "4", 161 | 2 162 | ] 163 | }, 164 | "class_type": "VAEDecode", 165 | "_meta": { 166 | "title": "VAE Decode" 167 | } 168 | }, 169 | "14": { 170 | "inputs": { 171 | 
"filename_prefix": "E2E-1.5-CoreML", 172 | "images": [ 173 | "13", 174 | 0 175 | ] 176 | }, 177 | "class_type": "SaveImage", 178 | "_meta": { 179 | "title": "Save Image" 180 | } 181 | } 182 | } -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aszc-dev/ComfyUI-CoreMLSuite/7678a07ed551092893a00eb2850dafca1d81566c/tests/unit/__init__.py -------------------------------------------------------------------------------- /tests/unit/test_chunks.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | 5 | from comfy.model_management import get_torch_device 6 | from coreml_suite.latents import chunk_batch, merge_chunks 7 | from coreml_suite.controlnet import chunk_control 8 | from coreml_suite.models import ( 9 | CoreMLInputs, 10 | ) 11 | from coreml_suite.config import get_model_config 12 | 13 | 14 | @pytest.fixture 15 | def expected_inputs(): 16 | expected = { 17 | "sample": {"shape": (2, 4, 64, 64)}, 18 | "timestep": {"shape": (2,)}, 19 | "timestep_cond": {"shape": (2, 256)}, 20 | "encoder_hidden_states": {"shape": (2, 768, 1, 77)}, 21 | "additional_residual_0": {"shape": (2, 320, 64, 64)}, 22 | "additional_residual_1": {"shape": (2, 640, 32, 32)}, 23 | } 24 | return expected 25 | 26 | 27 | @pytest.fixture 28 | def model_config(): 29 | return get_model_config() 30 | 31 | 32 | @pytest.mark.parametrize("batch_size", [1, 2, 4, 5, 9]) 33 | def test_batch_chunking(batch_size): 34 | latent_image = torch.randn(batch_size, 4, 64, 64).to(get_torch_device()) 35 | target_shape = (4, 4, 64, 64) 36 | 37 | chunked = chunk_batch(latent_image, target_shape) 38 | 39 | for chunk in chunked: 40 | assert chunk.shape == target_shape 41 | 42 | if batch_size % target_shape[0] != 0: 43 | assert chunked[-1][batch_size % target_shape[0] :].sum() == 0 44 | 45 | 46 | @pytest.mark.parametrize("batch_size", [1, 2, 4, 5, 9]) 47 | def test_merge_chunks(batch_size): 48 | input_tensor = torch.randn(batch_size, 4, 64, 64).to(get_torch_device()) 49 | target_shape = (4, 4, 64, 64) 50 | chunked = chunk_batch(input_tensor, target_shape) 51 | 52 | merged = merge_chunks(chunked, input_tensor.shape) 53 | 54 | assert merged.shape == input_tensor.shape 55 | assert torch.equal(input_tensor, merged) 56 | 57 | 58 | @pytest.fixture 59 | def inputs(): 60 | x = torch.randn(1, 4, 64, 64).to(get_torch_device()) 61 | t = torch.randn([1]).to(get_torch_device()) 62 | c_crossattn = torch.randn(1, 77, 768).to(get_torch_device()) 63 | control = { 64 | "output": [ 65 | torch.randn(1, 320, 64, 64).to(get_torch_device()), 66 | torch.randn(1, 640, 32, 32).to(get_torch_device()), 67 | ], 68 | } 69 | timestep_cond = torch.randn(1, 256).to(get_torch_device()) 70 | 71 | return CoreMLInputs(x, t, c_crossattn, control, timestep_cond=timestep_cond) 72 | 73 | 74 | @pytest.mark.parametrize( 75 | "b, target_size, num_chunks", 76 | [ 77 | (1, 2, 1), 78 | (1, 1, 1), 79 | (2, 2, 1), 80 | (3, 2, 2), 81 | (4, 2, 2), 82 | (5, 3, 2), 83 | (9, 4, 3), 84 | ], 85 | ) 86 | def test_chunking_controlnet(b, target_size, num_chunks): 87 | cn = { 88 | "output": [ 89 | torch.randn(b, 320, 64, 64).to(get_torch_device()), 90 | torch.randn(b, 640, 32, 32).to(get_torch_device()), 91 | ], 92 | "middle": [ 93 | torch.randn(b, 1280, 8, 8).to(get_torch_device()), 94 | ], 95 | } 96 | 97 | chunked = chunk_control(cn, target_size) 98 | 99 | assert 
len(chunked) == num_chunks 100 | for chunk in chunked: 101 | assert chunk["output"][0].shape == (target_size, 320, 64, 64) 102 | assert chunk["output"][1].shape == (target_size, 640, 32, 32) 103 | assert chunk["middle"][0].shape == (target_size, 1280, 8, 8) 104 | 105 | 106 | def test_chunking_no_control(): 107 | cn = None 108 | target_size = 2 109 | 110 | chunked = chunk_control(cn, target_size) 111 | 112 | assert chunked == [None, None] 113 | 114 | 115 | def test_chunking_inputs(expected_inputs, inputs): 116 | chunked = inputs.chunks(expected_inputs) 117 | 118 | assert len(chunked) == 1 119 | 120 | assert chunked[0].x.shape == (2, 4, 64, 64) 121 | assert chunked[0].t.shape == (2,) 122 | assert chunked[0].context.shape == (2, 77, 768) 123 | assert chunked[0].control["output"][0].shape == (2, 320, 64, 64) 124 | assert chunked[0].control["output"][1].shape == (2, 640, 32, 32) 125 | assert chunked[0].ts_cond.shape == (2, 256) 126 | -------------------------------------------------------------------------------- /tests/unit/test_controlnet.py: -------------------------------------------------------------------------------- 1 | from coreml_suite.controlnet import no_control 2 | 3 | 4 | def test_no_control(): 5 | expected_inputs = { 6 | "additional_residual_0": {"shape": (2, 2, 2)}, 7 | "additional_residual_1": {"shape": (2, 4, 4)}, 8 | "additional_residual_2": {"shape": (2, 8, 8)}, 9 | } 10 | 11 | residual_kwargs = no_control(expected_inputs) 12 | 13 | assert len(residual_kwargs) == 3 14 | assert residual_kwargs["additional_residual_0"].shape == (2, 2, 2) 15 | assert residual_kwargs["additional_residual_1"].shape == (2, 4, 4) 16 | assert residual_kwargs["additional_residual_2"].shape == (2, 8, 8) 17 | --------------------------------------------------------------------------------
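A note on the PSNR threshold used in tests/integration/test_basic_conversion_1_5.py above: for 8-bit images, the "psnr(...) > 25" assertion corresponds to a mean squared error of about 206, since 20 * log10(255 / sqrt(206)) is roughly 25 dB. The snippet below is a minimal, self-contained sketch of that comparison in the same spirit as the test's psnr() helper; the random arrays are hypothetical stand-ins for the Core ML and MPS renders, not part of the test suite.

import numpy as np

PIXEL_MAX = 255.0


def psnr(img1, img2):
    # Cast to float64 first: squaring a uint8 difference of 16 or more would wrap modulo 256.
    diff = img1.astype(np.float64) - img2.astype(np.float64)
    mse = np.mean(diff**2)
    if mse == 0:
        return 100.0  # identical images; treat as "very high", as the test does
    return 20 * np.log10(PIXEL_MAX / np.sqrt(mse))


# Hypothetical stand-ins for the two rendered images.
rng = np.random.default_rng(0)
base = rng.integers(0, 256, size=(512, 512, 3), dtype=np.uint8)
noise = rng.integers(-10, 11, size=base.shape, dtype=np.int16)
other = np.clip(base.astype(np.int16) + noise, 0, 255).astype(np.uint8)

value = psnr(base, other)
print(f"PSNR: {value:.2f} dB")  # roughly 32 dB for noise bounded by +/-10
assert value > 25

The float cast is the important detail: squared uint8 differences wrap around modulo 256, which silently deflates the MSE and inflates the reported PSNR.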