├── .github └── FUNDING.yml ├── .gitignore ├── LICENSE ├── README.md ├── javascript └── bboxHint.js ├── scripts ├── tilediffusion.py ├── tileglobal.py └── tilevae.py ├── tile_methods ├── abstractdiffusion.py ├── demofusion.py ├── mixtureofdiffusers.py └── multidiffusion.py └── tile_utils ├── attn.py ├── typing.py └── utils.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: pkuliyi2015 # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # meta 2 | .vscode/ 3 | __pycache__/ 4 | .DS_Store 5 | 6 | # settings 7 | region_configs/ 8 | 9 | # test images 10 | deflicker/input_frames/* 11 | 12 | # test features 13 | deflicker/* 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial-ShareAlike 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 
30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 58 | Public License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 63 | ("Public License"). To the extent this Public License may be 64 | interpreted as a contract, You are granted the Licensed Rights in 65 | consideration of Your acceptance of these terms and conditions, and the 66 | Licensor grants You such rights in consideration of benefits the 67 | Licensor receives from making the Licensed Material available under 68 | these terms and conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. BY-NC-SA Compatible License means a license listed at 88 | creativecommons.org/compatiblelicenses, approved by Creative 89 | Commons as essentially the equivalent of this Public License. 90 | 91 | d. 
Copyright and Similar Rights means copyright and/or similar rights 92 | closely related to copyright including, without limitation, 93 | performance, broadcast, sound recording, and Sui Generis Database 94 | Rights, without regard to how the rights are labeled or 95 | categorized. For purposes of this Public License, the rights 96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 97 | Rights. 98 | 99 | e. Effective Technological Measures means those measures that, in the 100 | absence of proper authority, may not be circumvented under laws 101 | fulfilling obligations under Article 11 of the WIPO Copyright 102 | Treaty adopted on December 20, 1996, and/or similar international 103 | agreements. 104 | 105 | f. Exceptions and Limitations means fair use, fair dealing, and/or 106 | any other exception or limitation to Copyright and Similar Rights 107 | that applies to Your use of the Licensed Material. 108 | 109 | g. License Elements means the license attributes listed in the name 110 | of a Creative Commons Public License. The License Elements of this 111 | Public License are Attribution, NonCommercial, and ShareAlike. 112 | 113 | h. Licensed Material means the artistic or literary work, database, 114 | or other material to which the Licensor applied this Public 115 | License. 116 | 117 | i. Licensed Rights means the rights granted to You subject to the 118 | terms and conditions of this Public License, which are limited to 119 | all Copyright and Similar Rights that apply to Your use of the 120 | Licensed Material and that the Licensor has authority to license. 121 | 122 | j. Licensor means the individual(s) or entity(ies) granting rights 123 | under this Public License. 124 | 125 | k. NonCommercial means not primarily intended for or directed towards 126 | commercial advantage or monetary compensation. For purposes of 127 | this Public License, the exchange of the Licensed Material for 128 | other material subject to Copyright and Similar Rights by digital 129 | file-sharing or similar means is NonCommercial provided there is 130 | no payment of monetary compensation in connection with the 131 | exchange. 132 | 133 | l. Share means to provide material to the public by any means or 134 | process that requires permission under the Licensed Rights, such 135 | as reproduction, public display, public performance, distribution, 136 | dissemination, communication, or importation, and to make material 137 | available to the public including in ways that members of the 138 | public may access the material from a place and at a time 139 | individually chosen by them. 140 | 141 | m. Sui Generis Database Rights means rights other than copyright 142 | resulting from Directive 96/9/EC of the European Parliament and of 143 | the Council of 11 March 1996 on the legal protection of databases, 144 | as amended and/or succeeded, as well as other essentially 145 | equivalent rights anywhere in the world. 146 | 147 | n. You means the individual or entity exercising the Licensed Rights 148 | under this Public License. Your has a corresponding meaning. 149 | 150 | 151 | Section 2 -- Scope. 152 | 153 | a. License grant. 154 | 155 | 1. Subject to the terms and conditions of this Public License, 156 | the Licensor hereby grants You a worldwide, royalty-free, 157 | non-sublicensable, non-exclusive, irrevocable license to 158 | exercise the Licensed Rights in the Licensed Material to: 159 | 160 | a. 
reproduce and Share the Licensed Material, in whole or 161 | in part, for NonCommercial purposes only; and 162 | 163 | b. produce, reproduce, and Share Adapted Material for 164 | NonCommercial purposes only. 165 | 166 | 2. Exceptions and Limitations. For the avoidance of doubt, where 167 | Exceptions and Limitations apply to Your use, this Public 168 | License does not apply, and You do not need to comply with 169 | its terms and conditions. 170 | 171 | 3. Term. The term of this Public License is specified in Section 172 | 6(a). 173 | 174 | 4. Media and formats; technical modifications allowed. The 175 | Licensor authorizes You to exercise the Licensed Rights in 176 | all media and formats whether now known or hereafter created, 177 | and to make technical modifications necessary to do so. The 178 | Licensor waives and/or agrees not to assert any right or 179 | authority to forbid You from making technical modifications 180 | necessary to exercise the Licensed Rights, including 181 | technical modifications necessary to circumvent Effective 182 | Technological Measures. For purposes of this Public License, 183 | simply making modifications authorized by this Section 2(a) 184 | (4) never produces Adapted Material. 185 | 186 | 5. Downstream recipients. 187 | 188 | a. Offer from the Licensor -- Licensed Material. Every 189 | recipient of the Licensed Material automatically 190 | receives an offer from the Licensor to exercise the 191 | Licensed Rights under the terms and conditions of this 192 | Public License. 193 | 194 | b. Additional offer from the Licensor -- Adapted Material. 195 | Every recipient of Adapted Material from You 196 | automatically receives an offer from the Licensor to 197 | exercise the Licensed Rights in the Adapted Material 198 | under the conditions of the Adapter's License You apply. 199 | 200 | c. No downstream restrictions. You may not offer or impose 201 | any additional or different terms or conditions on, or 202 | apply any Effective Technological Measures to, the 203 | Licensed Material if doing so restricts exercise of the 204 | Licensed Rights by any recipient of the Licensed 205 | Material. 206 | 207 | 6. No endorsement. Nothing in this Public License constitutes or 208 | may be construed as permission to assert or imply that You 209 | are, or that Your use of the Licensed Material is, connected 210 | with, or sponsored, endorsed, or granted official status by, 211 | the Licensor or others designated to receive attribution as 212 | provided in Section 3(a)(1)(A)(i). 213 | 214 | b. Other rights. 215 | 216 | 1. Moral rights, such as the right of integrity, are not 217 | licensed under this Public License, nor are publicity, 218 | privacy, and/or other similar personality rights; however, to 219 | the extent possible, the Licensor waives and/or agrees not to 220 | assert any such rights held by the Licensor to the limited 221 | extent necessary to allow You to exercise the Licensed 222 | Rights, but not otherwise. 223 | 224 | 2. Patent and trademark rights are not licensed under this 225 | Public License. 226 | 227 | 3. To the extent possible, the Licensor waives any right to 228 | collect royalties from You for the exercise of the Licensed 229 | Rights, whether directly or through a collecting society 230 | under any voluntary or waivable statutory or compulsory 231 | licensing scheme. 
In all other cases the Licensor expressly 232 | reserves any right to collect such royalties, including when 233 | the Licensed Material is used other than for NonCommercial 234 | purposes. 235 | 236 | 237 | Section 3 -- License Conditions. 238 | 239 | Your exercise of the Licensed Rights is expressly made subject to the 240 | following conditions. 241 | 242 | a. Attribution. 243 | 244 | 1. If You Share the Licensed Material (including in modified 245 | form), You must: 246 | 247 | a. retain the following if it is supplied by the Licensor 248 | with the Licensed Material: 249 | 250 | i. identification of the creator(s) of the Licensed 251 | Material and any others designated to receive 252 | attribution, in any reasonable manner requested by 253 | the Licensor (including by pseudonym if 254 | designated); 255 | 256 | ii. a copyright notice; 257 | 258 | iii. a notice that refers to this Public License; 259 | 260 | iv. a notice that refers to the disclaimer of 261 | warranties; 262 | 263 | v. a URI or hyperlink to the Licensed Material to the 264 | extent reasonably practicable; 265 | 266 | b. indicate if You modified the Licensed Material and 267 | retain an indication of any previous modifications; and 268 | 269 | c. indicate the Licensed Material is licensed under this 270 | Public License, and include the text of, or the URI or 271 | hyperlink to, this Public License. 272 | 273 | 2. You may satisfy the conditions in Section 3(a)(1) in any 274 | reasonable manner based on the medium, means, and context in 275 | which You Share the Licensed Material. For example, it may be 276 | reasonable to satisfy the conditions by providing a URI or 277 | hyperlink to a resource that includes the required 278 | information. 279 | 3. If requested by the Licensor, You must remove any of the 280 | information required by Section 3(a)(1)(A) to the extent 281 | reasonably practicable. 282 | 283 | b. ShareAlike. 284 | 285 | In addition to the conditions in Section 3(a), if You Share 286 | Adapted Material You produce, the following conditions also apply. 287 | 288 | 1. The Adapter's License You apply must be a Creative Commons 289 | license with the same License Elements, this version or 290 | later, or a BY-NC-SA Compatible License. 291 | 292 | 2. You must include the text of, or the URI or hyperlink to, the 293 | Adapter's License You apply. You may satisfy this condition 294 | in any reasonable manner based on the medium, means, and 295 | context in which You Share Adapted Material. 296 | 297 | 3. You may not offer or impose any additional or different terms 298 | or conditions on, or apply any Effective Technological 299 | Measures to, Adapted Material that restrict exercise of the 300 | rights granted under the Adapter's License You apply. 301 | 302 | 303 | Section 4 -- Sui Generis Database Rights. 304 | 305 | Where the Licensed Rights include Sui Generis Database Rights that 306 | apply to Your use of the Licensed Material: 307 | 308 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 309 | to extract, reuse, reproduce, and Share all or a substantial 310 | portion of the contents of the database for NonCommercial purposes 311 | only; 312 | 313 | b. if You include all or a substantial portion of the database 314 | contents in a database in which You have Sui Generis Database 315 | Rights, then the database in which You have Sui Generis Database 316 | Rights (but not its individual contents) is Adapted Material, 317 | including for purposes of Section 3(b); and 318 | 319 | c. 
You must comply with the conditions in Section 3(a) if You Share 320 | all or a substantial portion of the contents of the database. 321 | 322 | For the avoidance of doubt, this Section 4 supplements and does not 323 | replace Your obligations under this Public License where the Licensed 324 | Rights include other Copyright and Similar Rights. 325 | 326 | 327 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 328 | 329 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 330 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 331 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 332 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 333 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 334 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 335 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 336 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 337 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 338 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 339 | 340 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 341 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 342 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 343 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 344 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 345 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 346 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 347 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 348 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 349 | 350 | c. The disclaimer of warranties and limitation of liability provided 351 | above shall be interpreted in a manner that, to the extent 352 | possible, most closely approximates an absolute disclaimer and 353 | waiver of all liability. 354 | 355 | 356 | Section 6 -- Term and Termination. 357 | 358 | a. This Public License applies for the term of the Copyright and 359 | Similar Rights licensed here. However, if You fail to comply with 360 | this Public License, then Your rights under this Public License 361 | terminate automatically. 362 | 363 | b. Where Your right to use the Licensed Material has terminated under 364 | Section 6(a), it reinstates: 365 | 366 | 1. automatically as of the date the violation is cured, provided 367 | it is cured within 30 days of Your discovery of the 368 | violation; or 369 | 370 | 2. upon express reinstatement by the Licensor. 371 | 372 | For the avoidance of doubt, this Section 6(b) does not affect any 373 | right the Licensor may have to seek remedies for Your violations 374 | of this Public License. 375 | 376 | c. For the avoidance of doubt, the Licensor may also offer the 377 | Licensed Material under separate terms or conditions or stop 378 | distributing the Licensed Material at any time; however, doing so 379 | will not terminate this Public License. 380 | 381 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 382 | License. 383 | 384 | 385 | Section 7 -- Other Terms and Conditions. 386 | 387 | a. The Licensor shall not be bound by any additional or different 388 | terms or conditions communicated by You unless expressly agreed. 389 | 390 | b. 
Any arrangements, understandings, or agreements regarding the 391 | Licensed Material not stated herein are separate from and 392 | independent of the terms and conditions of this Public License. 393 | 394 | 395 | Section 8 -- Interpretation. 396 | 397 | a. For the avoidance of doubt, this Public License does not, and 398 | shall not be interpreted to, reduce, limit, restrict, or impose 399 | conditions on any use of the Licensed Material that could lawfully 400 | be made without permission under this Public License. 401 | 402 | b. To the extent possible, if any provision of this Public License is 403 | deemed unenforceable, it shall be automatically reformed to the 404 | minimum extent necessary to make it enforceable. If the provision 405 | cannot be reformed, it shall be severed from this Public License 406 | without affecting the enforceability of the remaining terms and 407 | conditions. 408 | 409 | c. No term or condition of this Public License will be waived and no 410 | failure to comply consented to unless expressly agreed to by the 411 | Licensor. 412 | 413 | d. Nothing in this Public License constitutes or may be interpreted 414 | as a limitation upon, or waiver of, any privileges and immunities 415 | that apply to the Licensor or You, including from the legal 416 | processes of any jurisdiction or authority. 417 | 418 | ======================================================================= 419 | 420 | Creative Commons is not a party to its public 421 | licenses. Notwithstanding, Creative Commons may elect to apply one of 422 | its public licenses to material it publishes and in those instances 423 | will be considered the “Licensor.” The text of the Creative Commons 424 | public licenses is dedicated to the public domain under the CC0 Public 425 | Domain Dedication. Except for the limited purpose of indicating that 426 | material is shared under a Creative Commons public license or as 427 | otherwise permitted by the Creative Commons policies published at 428 | creativecommons.org/policies, Creative Commons does not authorize the 429 | use of the trademark "Creative Commons" or any other trademark or logo 430 | of Creative Commons without its prior written consent including, 431 | without limitation, in connection with any unauthorized modifications 432 | to any of its public licenses or any other arrangements, 433 | understandings, or agreements concerning use of licensed material. For 434 | the avoidance of doubt, this paragraph does not form part of the 435 | public licenses. 436 | 437 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tiled Diffusion & VAE extension for sd-webui 2 | 3 | [![CC BY-NC-SA 4.0][cc-by-nc-sa-shield]][cc-by-nc-sa] 4 | 5 | This extension is licensed under [CC BY-NC-SA](https://creativecommons.org/licenses/by-nc-sa/4.0/), everyone is FREE of charge to access, use, modify and redistribute with the same license. 6 | **You cannot use versions after AOE 2023.3.28 for commercial sales (only refers to code of this repo, the derived artworks are NOT restricted).** 7 | 8 | 由于部分无良商家销售WebUI,捆绑本插件做卖点收取智商税,本仓库的许可证已修改为 [CC BY-NC-SA](https://creativecommons.org/licenses/by-nc-sa/4.0/),任何人都可以自由获取、使用、修改、以相同协议重分发本插件。 9 | **自许可证修改之日(AOE 2023.3.28)起,之后的版本禁止用于商业贩售 (不可贩售本仓库代码,但衍生的艺术创作内容物不受此限制)。** 10 | 11 | If you like the project, please give me a star! 
⭐ 12 | 13 | [![ko-fi](https://ko-fi.com/img/githubbutton_sm.svg)](https://ko-fi.com/pkuliyi2015) 14 | 15 | **** 16 | 17 | 18 | The extension helps you to **generate or upscale large images (≥2K) with limited VRAM (≤6GB)** via the following techniques: 19 | 20 | - Reproduced SOTA Tiled Diffusion methods 21 | - [Mixture of Diffusers](https://github.com/albarji/mixture-of-diffusers) 22 | - [MultiDiffusion](https://multidiffusion.github.io) 23 | - [Demofusion](https://github.com/PRIS-CV/DemoFusion) 24 | - Our original Tiled VAE method 25 | - My original Tiled Noise Inversion method 26 | 27 | 28 | ### Features 29 | 30 | - Core 31 | - [x] [Tiled VAE](#tiled-vae) 32 | - [x] [Tiled Diffusion: txt2img generation for ultra-large image](#tiled-diff-txt2img) 33 | - [x] [Tiled Diffusion: img2img upscaling for image detail enhancement](#tiled-diff-img2img) 34 | - [x] [Regional Prompt Control](#region-prompt-control) 35 | - [x] [Tiled Noise Inversion](#tiled-noise-inversion) 36 | - Advanced 37 | - [x] [ControlNet support]() 38 | - [x] [StableSR support](https://github.com/pkuliyi2015/sd-webui-stablesr) 39 | - [x] [SDXL support](experimental) 40 | - [x] [Demofusion support]() 41 | 42 | 👉 在 [wiki](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/wiki) 页面查看详细的文档和样例,以及由 [@PotatoBananaApple](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/discussions/120) 制作的 [快速入门教程](https://civitai.com/models/34726) 43 | 👉 Find detailed documentation & examples at our [wiki](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/wiki), and quickstart [Tutorial](https://civitai.com/models/34726) by [@PotatoBananaApple](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/discussions/120) 🎉 44 | 45 | 46 | ### Examples 47 | 48 | ⚪ Txt2img: generating ultra-large images 49 | 50 | `prompt: masterpiece, best quality, highres, city skyline, night.` 51 | 52 | ![panorama](https://github.com/pkuliyi2015/multidiffusion-img-demo/blob/master/city_panorama.jpeg?raw=true) 53 | 54 | ⚪ Img2img: upcaling for detail enhancement 55 | 56 | | original | x4 upscale | 57 | | :-: | :-: | 58 | | ![lowres](https://github.com/pkuliyi2015/multidiffusion-img-demo/blob/master/lowres.jpg?raw=true) | ![highres](https://github.com/pkuliyi2015/multidiffusion-img-demo/blob/master/highres.jpeg?raw=true) | 59 | 60 | ⚪ Regional Prompt Control 61 | 62 | | region setting | output1 | output2 | 63 | | :-: | :-: | :-: | 64 | | ![MultiCharacterRegions](https://github.com/pkuliyi2015/multidiffusion-img-demo/blob/master/multicharacter.png?raw=true) | ![MultiCharacter](https://github.com/pkuliyi2015/multidiffusion-img-demo/blob/master/multicharacter.jpeg?raw=true) | ![MultiCharacter](https://github.com/pkuliyi2015/multidiffusion-img-demo/blob/master/multicharacter2.jpeg?raw=true) | 65 | | ![FullBodyRegions](https://github.com/pkuliyi2015/multidiffusion-img-demo/blob/master/fullbody_regions.png?raw=true) | ![FullBody](https://github.com/pkuliyi2015/multidiffusion-img-demo/blob/master/fullbody.jpeg?raw=true) | ![FullBody2](https://github.com/pkuliyi2015/multidiffusion-img-demo/blob/master/fullbody2.jpeg?raw=true) | 66 | 67 | ⚪ ControlNet support 68 | 69 | | original | with canny | 70 | | :-: | :-: | 71 | | ![Your Name](https://github.com/pkuliyi2015/multidiffusion-img-demo/blob/master/yourname_canny.jpeg?raw=true) | ![Your Name](https://github.com/pkuliyi2015/multidiffusion-img-demo/blob/master/yourname.jpeg?raw=true) 72 | 73 | | | 重绘 “清明上河图” | 74 | | :-: | :-: | 75 | | original | 
![ancient city](https://github.com/pkuliyi2015/multidiffusion-img-demo/blob/master/ancient_city_origin_compressed.jpeg?raw=true) | 76 | | processed | ![ancient city](https://github.com/pkuliyi2015/multidiffusion-img-demo/blob/master/ancient_city_compressed.jpeg?raw=true) | 77 | 78 | ⚪ DemoFusion support 79 | 80 | | original | x3 upscale | 81 | | :-: | :-: | 82 | | ![demo-example](https://github.com/Jaylen-Lee/image-demo/blob/main/example.png?raw=true) | ![demo-result](https://github.com/Jaylen-Lee/image-demo/blob/main/3.png?raw=true) | 83 | 84 | 85 | ### License 86 | 87 | Great thanks to all the contributors! 🎉🎉🎉 88 | This work is licensed under [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License][cc-by-nc-sa]. 89 | 90 | [![CC BY-NC-SA 4.0][cc-by-nc-sa-image]][cc-by-nc-sa] 91 | [![CC BY-NC-SA 4.0][cc-by-nc-sa-shield]][cc-by-nc-sa] 92 | 93 | [cc-by-nc-sa]: http://creativecommons.org/licenses/by-nc-sa/4.0/ 94 | [cc-by-nc-sa-image]: https://licensebuttons.net/l/by-nc-sa/4.0/88x31.png 95 | [cc-by-nc-sa-shield]: https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg 96 | -------------------------------------------------------------------------------- /javascript/bboxHint.js: -------------------------------------------------------------------------------- 1 | const BBOX_MAX_NUM = 16; 2 | const BBOX_WARNING_SIZE = 1280; 3 | const DEFAULT_X = 0.4; 4 | const DEFAULT_Y = 0.4; 5 | const DEFAULT_H = 0.2; 6 | const DEFAULT_W = 0.2; 7 | 8 | // ref: https://html-color.codes/ 9 | const COLOR_MAP = [ 10 | ['#ff0000', 'rgba(255, 0, 0, 0.3)'], // red 11 | ['#ff9900', 'rgba(255, 153, 0, 0.3)'], // orange 12 | ['#ffff00', 'rgba(255, 255, 0, 0.3)'], // yellow 13 | ['#33cc33', 'rgba(51, 204, 51, 0.3)'], // green 14 | ['#33cccc', 'rgba(51, 204, 204, 0.3)'], // indigo 15 | ['#0066ff', 'rgba(0, 102, 255, 0.3)'], // blue 16 | ['#6600ff', 'rgba(102, 0, 255, 0.3)'], // purple 17 | ['#cc00cc', 'rgba(204, 0, 204, 0.3)'], // dark pink 18 | ['#ff6666', 'rgba(255, 102, 102, 0.3)'], // light red 19 | ['#ffcc66', 'rgba(255, 204, 102, 0.3)'], // light orange 20 | ['#99cc00', 'rgba(153, 204, 0, 0.3)'], // lime green 21 | ['#00cc99', 'rgba(0, 204, 153, 0.3)'], // teal 22 | ['#0099cc', 'rgba(0, 153, 204, 0.3)'], // steel blue 23 | ['#9933cc', 'rgba(153, 51, 204, 0.3)'], // lavender 24 | ['#ff3399', 'rgba(255, 51, 153, 0.3)'], // hot pink 25 | ['#996633', 'rgba(153, 102, 51, 0.3)'], // brown 26 | ]; 27 | 28 | const RESIZE_BORDER = 5; 29 | const MOVE_BORDER = 5; 30 | 31 | const t2i_bboxes = new Array(BBOX_MAX_NUM).fill(null); 32 | const i2i_bboxes = new Array(BBOX_MAX_NUM).fill(null); 33 | 34 | // ↓↓↓ called from gradio ↓↓↓ 35 | 36 | function onCreateT2IRefClick(overwrite) { 37 | let width, height; 38 | if (overwrite) { 39 | const overwriteInputs = gradioApp().querySelectorAll('#MD-overwrite-width-t2i input, #MD-overwrite-height-t2i input'); 40 | width = parseInt(overwriteInputs[0].value); 41 | height = parseInt(overwriteInputs[2].value); 42 | } else { 43 | const sizeInputs = gradioApp().querySelectorAll('#txt2img_width input, #txt2img_height input'); 44 | width = parseInt(sizeInputs[0].value); 45 | height = parseInt(sizeInputs[2].value); 46 | } 47 | 48 | if (isNaN(width)) width = 512; 49 | if (isNaN(height)) height = 512; 50 | 51 | // Concat it to string to bypass the gradio bug 52 | // 向黑恶势力低头 53 | return width.toString() + 'x' + height.toString(); 54 | } 55 | 56 | function onCreateI2IRefClick() { 57 | const canvas = gradioApp().querySelector('#img2img_image img'); 58 | return 
canvas.src; 59 | } 60 | 61 | function onBoxEnableClick(is_t2i, idx, enable) { 62 | let canvas = null; 63 | let bboxes = null; 64 | let locator = null; 65 | if (is_t2i) { 66 | locator = () => gradioApp().querySelector('#MD-bbox-ref-t2i'); 67 | bboxes = t2i_bboxes; 68 | } else { 69 | locator = () => gradioApp().querySelector('#MD-bbox-ref-i2i'); 70 | bboxes = i2i_bboxes; 71 | } 72 | ref_div = locator(); 73 | canvas = ref_div.querySelector('img'); 74 | if (!canvas) { return false; } 75 | 76 | if (enable) { 77 | // Check if the bounding box already exists 78 | if (!bboxes[idx]) { 79 | // Initialize bounding box 80 | const bbox = [DEFAULT_X, DEFAULT_Y, DEFAULT_W, DEFAULT_H]; 81 | const colorMap = COLOR_MAP[idx % COLOR_MAP.length]; 82 | const div = document.createElement('div'); 83 | div.id = 'MD-bbox-' + (is_t2i ? 't2i-' : 'i2i-') + idx; 84 | div.style.left = '0px'; 85 | div.style.top = '0px'; 86 | div.style.width = '0px'; 87 | div.style.height = '0px'; 88 | div.style.position = 'absolute'; 89 | div.style.border = '2px solid ' + colorMap[0]; 90 | div.style.background = colorMap[1]; 91 | div.style.zIndex = '900'; 92 | div.style.display = 'none'; 93 | // A text tip to warn the user if bbox is too large 94 | const tip = document.createElement('span'); 95 | tip.id = 'MD-tip-' + (is_t2i ? 't2i-' : 'i2i-') + idx; 96 | tip.style.left = '50%'; 97 | tip.style.top = '50%'; 98 | tip.style.position = 'absolute'; 99 | tip.style.transform = 'translate(-50%, -50%)'; 100 | tip.style.fontSize = '12px'; 101 | tip.style.fontWeight = 'bold'; 102 | tip.style.textAlign = 'center'; 103 | tip.style.color = colorMap[0]; 104 | tip.style.zIndex = '901'; 105 | tip.style.display = 'none'; 106 | tip.innerHTML = 'Warning: Region very large!
Take care of VRAM usage!'; 107 | div.appendChild(tip); 108 | 109 | div.addEventListener('mousedown', function (e) { 110 | if (e.button === 0) { onBoxMouseDown(e, is_t2i, idx); } 111 | }); 112 | div.addEventListener('mousemove', function (e) { 113 | updateCursorStyle(e, is_t2i, idx); 114 | }); 115 | 116 | const shower = function() { // insert to DOM if necessary 117 | if (!gradioApp().querySelector('#' + div.id)) { 118 | locator().appendChild(div); 119 | } 120 | } 121 | bboxes[idx] = [div, bbox, shower]; 122 | } 123 | 124 | // Show the bounding box 125 | displayBox(canvas, is_t2i, bboxes[idx]); 126 | return true; 127 | } else { 128 | if (!bboxes[idx]) { return false; } 129 | const [div, bbox, shower] = bboxes[idx]; 130 | div.style.display = 'none'; 131 | } 132 | return false; 133 | } 134 | 135 | function onBoxChange(is_t2i, idx, what, v) { 136 | // This function handles all the changes of the bounding box 137 | // Including the rendering and python slider update 138 | let bboxes = null; 139 | let canvas = null; 140 | if (is_t2i) { 141 | bboxes = t2i_bboxes; 142 | canvas = gradioApp().querySelector('#MD-bbox-ref-t2i img'); 143 | } else { 144 | bboxes = i2i_bboxes; 145 | canvas = gradioApp().querySelector('#MD-bbox-ref-i2i img'); 146 | } 147 | if (!bboxes[idx] || !canvas) { 148 | switch (what) { 149 | case 'x': return DEFAULT_X; 150 | case 'y': return DEFAULT_Y; 151 | case 'w': return DEFAULT_W; 152 | case 'h': return DEFAULT_H; 153 | } 154 | } 155 | const [div, bbox, shower] = bboxes[idx]; 156 | if (div.style.display === 'none') { return v; } 157 | 158 | // parse trigger 159 | switch (what) { 160 | case 'x': bbox[0] = v; break; 161 | case 'y': bbox[1] = v; break; 162 | case 'w': bbox[2] = v; break; 163 | case 'h': bbox[3] = v; break; 164 | } 165 | displayBox(canvas, is_t2i, bboxes[idx]); 166 | return v; 167 | } 168 | 169 | // ↓↓↓ called from js ↓↓↓ 170 | 171 | function getSeedInfo(is_t2i, id, current_seed) { 172 | const info_id = is_t2i ? 
'#html_info_txt2img' : '#html_info_img2img'; 173 | const info_div = gradioApp().querySelector(info_id); 174 | try{ 175 | current_seed = parseInt(current_seed); 176 | } catch(e) { 177 | current_seed = -1; 178 | } 179 | if (!info_div) return current_seed; 180 | let info = info_div.innerHTML; 181 | if (!info) return current_seed; 182 | // remove all html tags 183 | info = info.replace(/<[^>]*>/g, ''); 184 | // Find a json string 'region control:' in the info 185 | // get its index 186 | idx = info.indexOf('Region control'); 187 | if (idx == -1) return current_seed; 188 | // get the json string (detect the bracket) 189 | // find the first '{' 190 | let start_idx = info.indexOf('{', idx); 191 | let bracket = 1; 192 | let end_idx = start_idx + 1; 193 | while (bracket > 0 && end_idx < info.length) { 194 | if (info[end_idx] == '{') bracket++; 195 | if (info[end_idx] == '}') bracket--; 196 | end_idx++; 197 | } 198 | if (bracket > 0) { 199 | return current_seed; 200 | } 201 | // get the json string 202 | let json_str = info.substring(start_idx, end_idx); 203 | // replace the single quote to double quote 204 | json_str = json_str.replace(/'/g, '"'); 205 | // replace python True to javascript true, False to false 206 | json_str = json_str.replace(/True/g, 'true'); 207 | // parse the json string 208 | let json = JSON.parse(json_str); 209 | // get the seed if the region id is in the json 210 | const region_id = 'Region ' + id.toString(); 211 | if (!(region_id in json)) return current_seed; 212 | const region = json[region_id]; 213 | if (!('seed' in region)) return current_seed; 214 | let seed = region['seed']; 215 | try{ 216 | seed = parseInt(seed); 217 | } catch(e) { 218 | return current_seed; 219 | } 220 | return seed; 221 | } 222 | 223 | function displayBox(canvas, is_t2i, bbox_info) { 224 | // check null input 225 | const [div, bbox, shower] = bbox_info; 226 | const [x, y, w, h] = bbox; 227 | if (!canvas || !div || x == null || y == null || w == null || h == null) { return; } 228 | 229 | // client: canvas widget display size 230 | // natural: content image real size 231 | let vpScale = Math.min(canvas.clientWidth / canvas.naturalWidth, canvas.clientHeight / canvas.naturalHeight); 232 | let canvasCenterX = canvas.clientWidth / 2; 233 | let canvasCenterY = canvas.clientHeight / 2; 234 | let scaledX = canvas.naturalWidth * vpScale; 235 | let scaledY = canvas.naturalHeight * vpScale; 236 | let viewRectLeft = canvasCenterX - scaledX / 2; 237 | let viewRectRight = canvasCenterX + scaledX / 2; 238 | let viewRectTop = canvasCenterY - scaledY / 2; 239 | let viewRectDown = canvasCenterY + scaledY / 2; 240 | 241 | let xDiv = viewRectLeft + scaledX * x; 242 | let yDiv = viewRectTop + scaledY * y; 243 | let wDiv = Math.min(scaledX * w, viewRectRight - xDiv); 244 | let hDiv = Math.min(scaledY * h, viewRectDown - yDiv); 245 | 246 | // Calculate warning bbox size 247 | let upscalerFactor = 1.0; 248 | if (!is_t2i) { 249 | const upscalerInput = parseFloat(gradioApp().querySelector('#MD-i2i-upscaler-factor input').value); 250 | if (!isNaN(upscalerInput)) upscalerFactor = upscalerInput; 251 | } 252 | let maxSize = BBOX_WARNING_SIZE / upscalerFactor * vpScale; 253 | let maxW = maxSize / scaledX; 254 | let maxH = maxSize / scaledY; 255 | if (w > maxW || h > maxH) { 256 | div.querySelector('span').style.display = 'block'; 257 | } else { 258 | div.querySelector('span').style.display = 'none'; 259 | } 260 | 261 | // update
when not equal 262 | div.style.left = xDiv + 'px'; 263 | div.style.top = yDiv + 'px'; 264 | div.style.width = wDiv + 'px'; 265 | div.style.height = hDiv + 'px'; 266 | div.style.display = 'block'; 267 | 268 | // insert it to DOM if not appear 269 | shower(); 270 | } 271 | 272 | function onBoxMouseDown(e, is_t2i, idx) { 273 | let bboxes = null; 274 | let canvas = null; 275 | if (is_t2i) { 276 | bboxes = t2i_bboxes; 277 | canvas = gradioApp().querySelector('#MD-bbox-ref-t2i img'); 278 | } else { 279 | bboxes = i2i_bboxes; 280 | canvas = gradioApp().querySelector('#MD-bbox-ref-i2i img'); 281 | } 282 | // Get the bounding box 283 | if (!canvas || !bboxes[idx]) { return; } 284 | const [div, bbox, shower] = bboxes[idx]; 285 | 286 | // Check if the click is inside the bounding box 287 | const boxRect = div.getBoundingClientRect(); 288 | let mouseX = e.clientX; 289 | let mouseY = e.clientY; 290 | 291 | const resizeLeft = mouseX >= boxRect.left && mouseX <= boxRect.left + RESIZE_BORDER; 292 | const resizeRight = mouseX >= boxRect.right - RESIZE_BORDER && mouseX <= boxRect.right; 293 | const resizeTop = mouseY >= boxRect.top && mouseY <= boxRect.top + RESIZE_BORDER; 294 | const resizeBottom = mouseY >= boxRect.bottom - RESIZE_BORDER && mouseY <= boxRect.bottom; 295 | 296 | const moveHorizontal = mouseX >= boxRect.left + MOVE_BORDER && mouseX <= boxRect.right - MOVE_BORDER; 297 | const moveVertical = mouseY >= boxRect.top + MOVE_BORDER && mouseY <= boxRect.bottom - MOVE_BORDER; 298 | 299 | if (!resizeLeft && !resizeRight && !resizeTop && !resizeBottom && !moveHorizontal && !moveVertical) { return; } 300 | 301 | const horizontalPivot = resizeLeft ? bbox[0] + bbox[2] : bbox[0]; 302 | const verticalPivot = resizeTop ? bbox[1] + bbox[3] : bbox[1]; 303 | 304 | // Canvas can be regarded as invariant during the drag operation 305 | // Calculate in advance to reduce overhead 306 | 307 | // Calculate viewport scale based on the current canvas size and the natural image size 308 | let vpScale = Math.min(canvas.clientWidth / canvas.naturalWidth, canvas.clientHeight / canvas.naturalHeight); 309 | let vpOffset = canvas.getBoundingClientRect(); 310 | 311 | // Calculate scaled dimensions of the canvas 312 | let scaledX = canvas.naturalWidth * vpScale; 313 | let scaledY = canvas.naturalHeight * vpScale; 314 | 315 | // Calculate the canvas center and view rectangle coordinates 316 | let canvasCenterX = (vpOffset.left + window.scrollX) + canvas.clientWidth / 2; 317 | let canvasCenterY = (vpOffset.top + window.scrollY) + canvas.clientHeight / 2; 318 | let viewRectLeft = canvasCenterX - scaledX / 2 - window.scrollX; 319 | let viewRectRight = canvasCenterX + scaledX / 2 - window.scrollX; 320 | let viewRectTop = canvasCenterY - scaledY / 2 - window.scrollY; 321 | let viewRectDown = canvasCenterY + scaledY / 2 - window.scrollY; 322 | 323 | mouseX = Math.min(Math.max(mouseX, viewRectLeft), viewRectRight); 324 | mouseY = Math.min(Math.max(mouseY, viewRectTop), viewRectDown); 325 | 326 | const accordion = gradioApp().querySelector(`#MD-accordion-${is_t2i ? 
't2i' : 'i2i'}-${idx}`); 327 | 328 | // Move or resize the bounding box on mousemove 329 | function onMouseMove(e) { 330 | // Prevent selecting anything irrelevant 331 | e.preventDefault(); 332 | 333 | // Get the new mouse position 334 | let newMouseX = e.clientX; 335 | let newMouseY = e.clientY; 336 | 337 | // clamp the mouse position to the view rectangle 338 | newMouseX = Math.min(Math.max(newMouseX, viewRectLeft), viewRectRight); 339 | newMouseY = Math.min(Math.max(newMouseY, viewRectTop), viewRectDown); 340 | 341 | // Calculate the mouse movement delta 342 | const dx = (newMouseX - mouseX) / scaledX; 343 | const dy = (newMouseY - mouseY) / scaledY; 344 | 345 | // Update the mouse position 346 | mouseX = newMouseX; 347 | mouseY = newMouseY; 348 | 349 | // if no move just return 350 | if (dx === 0 && dy === 0) { return; } 351 | 352 | // Update the mouse position 353 | let [x, y, w, h] = bbox; 354 | if (moveHorizontal && moveVertical) { 355 | // If moving the bounding box 356 | x = Math.min(Math.max(x + dx, 0), 1 - w); 357 | y = Math.min(Math.max(y + dy, 0), 1 - h); 358 | } else { 359 | // If resizing the bounding box 360 | if (resizeLeft || resizeRight) { 361 | if (x < horizontalPivot) { 362 | if (dx <= w) { 363 | // If still within the left side of the pivot 364 | x = x + dx; 365 | w = w - dx; 366 | } else { 367 | // If crossing the pivot 368 | w = dx - w; 369 | x = horizontalPivot; 370 | } 371 | } else { 372 | if (w + dx < 0) { 373 | // If still within the right side of the pivot 374 | x = horizontalPivot + w + dx; 375 | w = - dx - w; 376 | } else { 377 | // If crossing the pivot 378 | x = horizontalPivot; 379 | w = w + dx; 380 | } 381 | } 382 | 383 | // Clamp the bounding box to the image 384 | if (x < 0) { 385 | w = w + x; 386 | x = 0; 387 | } else if (x + w > 1) { 388 | w = 1 - x; 389 | } 390 | } 391 | // Same as above, but for the vertical axis 392 | if (resizeTop || resizeBottom) { 393 | if (y < verticalPivot) { 394 | if (dy <= h) { 395 | y = y + dy; 396 | h = h - dy; 397 | } else { 398 | h = dy - h; 399 | y = verticalPivot; 400 | } 401 | } else { 402 | if (h + dy < 0) { 403 | y = verticalPivot + h + dy; 404 | h = - dy - h; 405 | } else { 406 | y = verticalPivot; 407 | h = h + dy; 408 | } 409 | } 410 | if (y < 0) { 411 | h = h + y; 412 | y = 0; 413 | } else if (y + h > 1) { 414 | h = 1 - y; 415 | } 416 | } 417 | } 418 | const [div, old_bbox, _] = bboxes[idx]; 419 | 420 | // If all the values are the same, just return 421 | if (old_bbox[0] === x && old_bbox[1] === y && old_bbox[2] === w && old_bbox[3] === h) { return; } 422 | // else update the bbox 423 | const event = new Event('input'); 424 | const coords = [x, y, w, h]; 425 | // The querySelector is not very efficient, so we query it once and reuse it 426 | // caching will result gradio bugs that stucks bbox and cannot move & drag 427 | const sliderIds = ['x', 'y', 'w', 'h']; 428 | // We try to select the input sliders 429 | const sliderSelectors = sliderIds.map(id => `#MD-${is_t2i ? 
't2i' : 'i2i'}-${idx}-${id} input`).join(', '); 430 | let sliderInputs = accordion.querySelectorAll(sliderSelectors); 431 | if (sliderInputs.length == 0) { 432 | // If we failed, the accordion is probably closed and sliders are removed in the dom, so we open it 433 | accordion.querySelector('.label-wrap').click(); 434 | // and try again 435 | sliderInputs = accordion.querySelectorAll(sliderSelectors); 436 | // If we still failed, we just return 437 | if (sliderInputs.length == 0) { return; } 438 | } 439 | for (let i = 0; i < 4; i++) { 440 | if (old_bbox[i] !== coords[i]) { 441 | sliderInputs[2*i].value = coords[i]; 442 | sliderInputs[2*i].dispatchEvent(event); 443 | } 444 | } 445 | } 446 | 447 | // Remove the mousemove and mouseup event listeners 448 | function onMouseUp() { 449 | document.removeEventListener('mousemove', onMouseMove); 450 | document.removeEventListener('mouseup', onMouseUp); 451 | } 452 | 453 | // Add the event listeners 454 | document.addEventListener('mousemove', onMouseMove); 455 | document.addEventListener('mouseup', onMouseUp); 456 | } 457 | 458 | function updateCursorStyle(e, is_t2i, idx) { 459 | // This function changes the cursor style when hovering over the bounding box 460 | const bboxes = is_t2i ? t2i_bboxes : i2i_bboxes; 461 | if (!bboxes[idx]) return; 462 | 463 | const div = bboxes[idx][0]; 464 | const boxRect = div.getBoundingClientRect(); 465 | const mouseX = e.clientX; 466 | const mouseY = e.clientY; 467 | 468 | const resizeLeft = mouseX >= boxRect.left && mouseX <= boxRect.left + RESIZE_BORDER; 469 | const resizeRight = mouseX >= boxRect.right - RESIZE_BORDER && mouseX <= boxRect.right; 470 | const resizeTop = mouseY >= boxRect.top && mouseY <= boxRect.top + RESIZE_BORDER; 471 | const resizeBottom = mouseY >= boxRect.bottom - RESIZE_BORDER && mouseY <= boxRect.bottom; 472 | 473 | if ((resizeLeft && resizeTop) || (resizeRight && resizeBottom)) { 474 | div.style.cursor = 'nwse-resize'; 475 | } else if ((resizeLeft && resizeBottom) || (resizeRight && resizeTop)) { 476 | div.style.cursor = 'nesw-resize'; 477 | } else if (resizeLeft || resizeRight) { 478 | div.style.cursor = 'ew-resize'; 479 | } else if (resizeTop || resizeBottom) { 480 | div.style.cursor = 'ns-resize'; 481 | } else { 482 | div.style.cursor = 'move'; 483 | } 484 | } 485 | 486 | // ↓↓↓ auto called event listeners ↓↓↓ 487 | 488 | function updateBoxes(is_t2i) { 489 | // This function redraw all bounding boxes 490 | let bboxes = null; 491 | let canvas = null; 492 | if (is_t2i) { 493 | bboxes = t2i_bboxes; 494 | canvas = gradioApp().querySelector('#MD-bbox-ref-t2i img'); 495 | } else { 496 | bboxes = i2i_bboxes; 497 | canvas = gradioApp().querySelector('#MD-bbox-ref-i2i img'); 498 | } 499 | if (!canvas) return; 500 | 501 | for (let idx = 0; idx < bboxes.length; idx++) { 502 | if (!bboxes[idx]) continue; 503 | const [div, bbox, shower] = bboxes[idx]; 504 | if (div.style.display === 'none') { return; } 505 | 506 | displayBox(canvas, is_t2i, bboxes[idx]); 507 | } 508 | } 509 | 510 | window.addEventListener('resize', _ => { 511 | updateBoxes(true); 512 | updateBoxes(false); 513 | }); 514 | 515 | // ======== Gradio Bug Fix ======== 516 | // For Gradio versions > 3.16.0 and < 3.29.0, the accordion DOM will be deleted when it is closed. 517 | // We need to judge the versions and listen to the accordion open event, rerender the bbox at that time. 518 | // This silly bug fix is only for compatibility, we recommend to update the gradio version to 3.29.0 or higher. 
519 | try { 520 | const GRADIO_VERSIONS = window.gradio_config["version"].split("."); 521 | const gradio_major_version = parseInt(GRADIO_VERSIONS[0]); 522 | const gradio_minor_version = parseInt(GRADIO_VERSIONS[1]); 523 | if (gradio_major_version == 3 && gradio_minor_version > 16 && gradio_minor_version < 29) { 524 | let listener = e => { 525 | if (!e) { return; } 526 | if (!e.target) { return; } 527 | if (!e.target.classList) { return; } 528 | if (!e.target.classList.contains('label-wrap')) { return; } 529 | for (let tab of ['t2i', 'i2i']) { 530 | const div = gradioApp().querySelector('#MD-bbox-control-' + tab +' div.label-wrap'); 531 | if (!div) { continue; } 532 | updateBoxes(tab === 't2i'); 533 | } 534 | }; 535 | window.addEventListener('DOMNodeInserted', listener); 536 | } 537 | } catch (ignored) { 538 | // If the above code failed, the gradio version shouldn't be in the range of 3.16.0 to 3.29.0, so we just return. 539 | } 540 | // ======== Gradio Bug Fix ======== 541 | -------------------------------------------------------------------------------- /scripts/tilediffusion.py: -------------------------------------------------------------------------------- 1 | ''' 2 | # ------------------------------------------------------------------------ 3 | # 4 | # Tiled Diffusion for Automatic1111 WebUI 5 | # 6 | # Introducing revolutionary large image drawing methods: 7 | # MultiDiffusion and Mixture of Diffusers! 8 | # 9 | # Techniques is not originally proposed by me, please refer to 10 | # 11 | # MultiDiffusion: https://multidiffusion.github.io 12 | # Mixture of Diffusers: https://github.com/albarji/mixture-of-diffusers 13 | # 14 | # The script contains a few optimizations including: 15 | # - symmetric tiling bboxes 16 | # - cached tiling weights 17 | # - batched denoising 18 | # - advanced prompt control for each tile 19 | # 20 | # ------------------------------------------------------------------------ 21 | # 22 | # This script hooks into the original sampler and decomposes the latent 23 | # image, sampled separately and run weighted average to merge them back. 24 | # 25 | # Advantages: 26 | # - Allows for super large resolutions (2k~8k) for both txt2img and img2img. 27 | # - The merged output is completely seamless without any post-processing. 28 | # - Training free. No need to train a new model, and you can control the 29 | # text prompt for specific regions. 30 | # 31 | # Drawbacks: 32 | # - Depending on your parameter settings, the process can be very slow, 33 | # especially when overlap is relatively large. 34 | # - The gradient calculation is not compatible with this hack. It 35 | # will break any backward() or torch.autograd.grad() that passes UNet. 36 | # 37 | # How it works: 38 | # 1. The latent image is split into tiles. 39 | # 2. In MultiDiffusion: 40 | # 1. The UNet predicts the noise of each tile. 41 | # 2. The tiles are denoised by the original sampler for one time step. 42 | # 3. The tiles are added together but divided by how many times each pixel is added. 43 | # 3. In Mixture of Diffusers: 44 | # 1. The UNet predicts the noise of each tile 45 | # 2. All noises are fused with a gaussian weight mask. 46 | # 3. The denoiser denoises the whole image for one time step using fused noises. 47 | # 4. Repeat 2-3 until all timesteps are completed. 48 | # 49 | # Enjoy! 50 | # 51 | # @author: LI YI @ Nanyang Technological University - Singapore 52 | # @date: 2023-03-03 53 | # @license: CC BY-NC-SA 4.0 54 | # 55 | # Please give me a star if you like this project! 
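#
# Illustrative sketch only (not the actual implementation of this extension):
# the MultiDiffusion merge step described above is essentially a weighted
# average over tile-wise sampler outputs. All names below (latent, tiles,
# denoise_tile) are hypothetical placeholders.
#
#     import torch
#
#     def multidiffusion_step(latent, tiles, denoise_tile):
#         # latent: (B, C, H, W) latent image; tiles: list of (x, y, w, h) boxes
#         value  = torch.zeros_like(latent)   # sum of denoised tiles
#         weight = torch.zeros_like(latent)   # how many tiles cover each pixel
#         for x, y, w, h in tiles:
#             tile = latent[:, :, y:y+h, x:x+w]
#             value[:, :, y:y+h, x:x+w]  += denoise_tile(tile)  # one sampler step per tile
#             weight[:, :, y:y+h, x:x+w] += 1.0
#         return value / weight.clamp(min=1.0)  # average where tiles overlap
#
# Mixture of Diffusers replaces the uniform +1.0 weight with a per-tile
# Gaussian mask, so overlapping regions blend smoothly rather than being
# hard-averaged.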
56 | # 57 | # ------------------------------------------------------------------------ 58 | ''' 59 | 60 | import os 61 | import json 62 | import torch 63 | import numpy as np 64 | import gradio as gr 65 | 66 | from modules import sd_samplers, images, shared, devices, processing, scripts 67 | from modules.shared import opts 68 | from modules.processing import opt_f, get_fixed_seed 69 | from modules.ui import gr_show 70 | 71 | from tile_methods.abstractdiffusion import AbstractDiffusion 72 | from tile_methods.multidiffusion import MultiDiffusion 73 | from tile_methods.mixtureofdiffusers import MixtureOfDiffusers 74 | from tile_utils.utils import * 75 | if hasattr(opts, 'hypertile_enable_unet'): # webui >= 1.7 76 | from modules.ui_components import InputAccordion 77 | else: 78 | InputAccordion = None 79 | 80 | CFG_PATH = os.path.join(scripts.basedir(), 'region_configs') 81 | BBOX_MAX_NUM = min(getattr(shared.cmd_opts, 'md_max_regions', 8), 16) 82 | 83 | 84 | class Script(scripts.Script): 85 | 86 | def __init__(self): 87 | self.controlnet_script: ModuleType = None 88 | self.stablesr_script: ModuleType = None 89 | self.delegate: AbstractDiffusion = None 90 | self.noise_inverse_cache: NoiseInverseCache = None 91 | 92 | def title(self): 93 | return 'Tiled Diffusion' 94 | 95 | def show(self, is_img2img): 96 | return scripts.AlwaysVisible 97 | 98 | def ui(self, is_img2img): 99 | tab = 't2i' if not is_img2img else 'i2i' 100 | is_t2i = 'true' if not is_img2img else 'false' 101 | uid = lambda name: f'MD-{tab}-{name}' 102 | 103 | with ( 104 | InputAccordion(False, label='Tiled Diffusion', elem_id=uid('enabled')) if InputAccordion 105 | else gr.Accordion('Tiled Diffusion', open=False, elem_id=f'MD-{tab}') 106 | as enabled 107 | ): 108 | with gr.Row(variant='compact') as tab_enable: 109 | if not InputAccordion: 110 | enabled = gr.Checkbox(label='Enable Tiled Diffusion', value=False, elem_id=uid('enabled')) 111 | overwrite_size = gr.Checkbox(label='Overwrite image size', value=False, visible=not is_img2img, elem_id=uid('overwrite-image-size')) 112 | keep_input_size = gr.Checkbox(label='Keep input image size', value=True, visible=is_img2img, elem_id=uid('keep-input-size')) 113 | 114 | with gr.Row(variant='compact', visible=False) as tab_size: 115 | image_width = gr.Slider(minimum=256, maximum=16384, step=16, label='Image width', value=1024, elem_id=f'MD-overwrite-width-{tab}') 116 | image_height = gr.Slider(minimum=256, maximum=16384, step=16, label='Image height', value=1024, elem_id=f'MD-overwrite-height-{tab}') 117 | overwrite_size.change(fn=gr_show, inputs=overwrite_size, outputs=tab_size, show_progress=False) 118 | 119 | with gr.Row(variant='compact') as tab_param: 120 | method = gr.Dropdown(label='Method', choices=[e.value for e in Method], value=Method.MULTI_DIFF.value if is_t2i else Method.MIX_DIFF.value, elem_id=uid('method')) 121 | control_tensor_cpu = gr.Checkbox(label='Move ControlNet tensor to CPU (if applicable)', value=False, elem_id=uid('control-tensor-cpu')) 122 | reset_status = gr.Button(value='Free GPU', variant='tool') 123 | reset_status.click(fn=self.reset_and_gc, show_progress=False) 124 | 125 | with gr.Group() as tab_tile: 126 | with gr.Row(variant='compact'): 127 | tile_width = gr.Slider(minimum=16, maximum=256, step=16, label='Latent tile width', value=96, elem_id=uid('latent-tile-width')) 128 | tile_height = gr.Slider(minimum=16, maximum=256, step=16, label='Latent tile height', value=96, elem_id=uid('latent-tile-height')) 129 | 130 | with gr.Row(variant='compact'): 131 | overlap 
= gr.Slider(minimum=0, maximum=256, step=4, label='Latent tile overlap', value=48 if is_t2i else 8, elem_id=uid('latent-tile-overlap')) 132 | batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Latent tile batch size', value=4, elem_id=uid('latent-tile-batch-size')) 133 | 134 | with gr.Row(variant='compact', visible=is_img2img) as tab_upscale: 135 | upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value='None', elem_id=uid('upscaler-index')) 136 | scale_factor = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label='Scale Factor', value=2.0, elem_id=uid('upscaler-factor')) 137 | 138 | with gr.Accordion('Noise Inversion', open=True, visible=is_img2img) as tab_noise_inv: 139 | with gr.Row(variant='compact'): 140 | noise_inverse = gr.Checkbox(label='Enable Noise Inversion', value=False, elem_id=uid('noise-inverse')) 141 | noise_inverse_steps = gr.Slider(minimum=1, maximum=200, step=1, label='Inversion steps', value=10, elem_id=uid('noise-inverse-steps')) 142 | gr.HTML('
Please test on small images before actual upscale. Default params require denoise <= 0.6
') 143 | with gr.Row(variant='compact'): 144 | noise_inverse_retouch = gr.Slider(minimum=1, maximum=100, step=0.1, label='Retouch', value=1, elem_id=uid('noise-inverse-retouch')) 145 | noise_inverse_renoise_strength = gr.Slider(minimum=0, maximum=2, step=0.01, label='Renoise strength', value=1, elem_id=uid('noise-inverse-renoise-strength')) 146 | noise_inverse_renoise_kernel = gr.Slider(minimum=2, maximum=512, step=1, label='Renoise kernel size', value=64, elem_id=uid('noise-inverse-renoise-kernel')) 147 | 148 | # The control includes txt2img and img2img, we use t2i and i2i to distinguish them 149 | with gr.Group(elem_id=f'MD-bbox-control-{tab}') as tab_bbox: 150 | with gr.Accordion('Region Prompt Control', open=False): 151 | with gr.Row(variant='compact'): 152 | enable_bbox_control = gr.Checkbox(label='Enable Control', value=False, elem_id=uid('enable-bbox-control')) 153 | draw_background = gr.Checkbox(label='Draw full canvas background', value=False, elem_id=uid('draw-background')) 154 | causal_layers = gr.Checkbox(label='Causalize layers', value=False, visible=False, elem_id='MD-causal-layers') # NOTE: currently not used 155 | 156 | with gr.Row(variant='compact'): 157 | create_button = gr.Button(value="Create txt2img canvas" if not is_img2img else "From img2img", elem_id='MD-create-canvas') 158 | 159 | bbox_controls: List[Component] = [] # control set for each bbox 160 | with gr.Row(variant='compact'): 161 | ref_image = gr.Image(label='Ref image (for conviently locate regions)', image_mode=None, elem_id=f'MD-bbox-ref-{tab}', interactive=True) 162 | if not is_img2img: 163 | # gradio has a serious bug: it cannot accept multiple inputs when you use both js and fn. 164 | # to workaround this, we concat the inputs into a single string and parse it in js 165 | def create_t2i_ref(string): 166 | w, h = [int(x) for x in string.split('x')] 167 | w = max(w, opt_f) 168 | h = max(h, opt_f) 169 | return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255 170 | create_button.click( 171 | fn=create_t2i_ref, 172 | inputs=overwrite_size, 173 | outputs=ref_image, 174 | _js='onCreateT2IRefClick', 175 | show_progress=False) 176 | else: 177 | create_button.click(fn=None, outputs=ref_image, _js='onCreateI2IRefClick', show_progress=False) 178 | 179 | with gr.Row(variant='compact'): 180 | cfg_name = gr.Textbox(label='Custom Config File', value='config.json', elem_id=uid('cfg-name')) 181 | cfg_dump = gr.Button(value='💾 Save', variant='tool') 182 | cfg_load = gr.Button(value='⚙️ Load', variant='tool') 183 | 184 | with gr.Row(variant='compact'): 185 | cfg_tip = gr.HTML(value='', visible=False) 186 | 187 | for i in range(BBOX_MAX_NUM): 188 | # Only when displaying & png generate info we use index i+1, in other cases we use i 189 | with gr.Accordion(f'Region {i+1}', open=False, elem_id=f'MD-accordion-{tab}-{i}'): 190 | with gr.Row(variant='compact'): 191 | e = gr.Checkbox(label=f'Enable Region {i+1}', value=False, elem_id=f'MD-bbox-{tab}-{i}-enable') 192 | e.change(fn=None, inputs=e, outputs=e, _js=f'e => onBoxEnableClick({is_t2i}, {i}, e)', show_progress=False) 193 | 194 | blend_mode = gr.Dropdown(label='Type', choices=[e.value for e in BlendMode], value=BlendMode.BACKGROUND.value, elem_id=f'MD-{tab}-{i}-blend-mode') 195 | feather_ratio = gr.Slider(label='Feather', value=0.2, minimum=0, maximum=1, step=0.05, visible=False, elem_id=f'MD-{tab}-{i}-feather') 196 | 197 | blend_mode.change(fn=lambda x: gr_show(x==BlendMode.FOREGROUND.value), inputs=blend_mode, outputs=feather_ratio, show_progress=False) 198 | 199 | with 
gr.Row(variant='compact'): 200 | x = gr.Slider(label='x', value=0.4, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-{tab}-{i}-x') 201 | y = gr.Slider(label='y', value=0.4, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-{tab}-{i}-y') 202 | 203 | with gr.Row(variant='compact'): 204 | w = gr.Slider(label='w', value=0.2, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-{tab}-{i}-w') 205 | h = gr.Slider(label='h', value=0.2, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-{tab}-{i}-h') 206 | 207 | x.change(fn=None, inputs=x, outputs=x, _js=f'v => onBoxChange({is_t2i}, {i}, "x", v)', show_progress=False) 208 | y.change(fn=None, inputs=y, outputs=y, _js=f'v => onBoxChange({is_t2i}, {i}, "y", v)', show_progress=False) 209 | w.change(fn=None, inputs=w, outputs=w, _js=f'v => onBoxChange({is_t2i}, {i}, "w", v)', show_progress=False) 210 | h.change(fn=None, inputs=h, outputs=h, _js=f'v => onBoxChange({is_t2i}, {i}, "h", v)', show_progress=False) 211 | 212 | prompt = gr.Text(show_label=False, placeholder=f'Prompt, will append to your {tab} prompt', max_lines=2, elem_id=f'MD-{tab}-{i}-prompt') 213 | neg_prompt = gr.Text(show_label=False, placeholder='Negative Prompt, will also be appended', max_lines=1, elem_id=f'MD-{tab}-{i}-neg-prompt') 214 | with gr.Row(variant='compact'): 215 | seed = gr.Number(label='Seed', value=-1, visible=True, elem_id=f'MD-{tab}-{i}-seed') 216 | random_seed = gr.Button(value='🎲', variant='tool', elem_id=f'MD-{tab}-{i}-random_seed') 217 | reuse_seed = gr.Button(value='♻️', variant='tool', elem_id=f'MD-{tab}-{i}-reuse_seed') 218 | random_seed.click(fn=lambda: -1, outputs=seed, show_progress=False) 219 | reuse_seed.click(fn=None, inputs=seed, outputs=seed, _js=f'e => getSeedInfo({is_t2i}, {i+1}, e)', show_progress=False) 220 | 221 | control = [e, x, y, w, h, prompt, neg_prompt, blend_mode, feather_ratio, seed] 222 | assert len(control) == NUM_BBOX_PARAMS 223 | bbox_controls.extend(control) 224 | 225 | # NOTE: dynamically hard coded!! 
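# The ten widgets collected into `control` above are flattened into `bbox_controls` in a fixed order
# and later regrouped per region by `build_bbox_settings()` from tile_utils/utils.py. A minimal
# standalone sketch of that regrouping, assuming a namedtuple whose fields mirror the `control`
# list above (the field names, NUM_BBOX_PARAMS == 10, and the enable-filter are illustrative
# assumptions, not the actual helper):
#
#   from collections import namedtuple
#   from typing import Dict, Sequence
#
#   SketchBBoxSettings = namedtuple('SketchBBoxSettings',
#       ['enable', 'x', 'y', 'w', 'h', 'prompt', 'neg_prompt', 'blend_mode', 'feather_ratio', 'seed'])
#   N_PARAMS = len(SketchBBoxSettings._fields)   # one record per region, 10 values each
#
#   def sketch_build_bbox_settings(flat_controls: Sequence) -> Dict[int, SketchBBoxSettings]:
#       """Group the flat widget values into one settings record per enabled region."""
#       settings = {}
#       for i in range(0, len(flat_controls), N_PARAMS):
#           s = SketchBBoxSettings(*flat_controls[i:i + N_PARAMS])
#           if s.enable:                          # the real helper may filter differently
#               settings[i // N_PARAMS] = s
#       return settings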
226 | load_regions_js = ''' 227 | function onBoxChangeAll(ref_image, cfg_name, ...args) { 228 | const is_t2i = %s; 229 | const n_bbox = %d; 230 | const n_ctrl = %d; 231 | for (let i=0; i 0 281 | if is_img2img: # img2img, TODO: replace with `images.resize_image()` 282 | idx = [x.name for x in shared.sd_upscalers].index(upscaler_name) 283 | upscaler = shared.sd_upscalers[idx] 284 | init_img = p.init_images[0] 285 | init_img = images.flatten(init_img, opts.img2img_background_color) 286 | if upscaler.name != "None": 287 | print(f"[Tiled Diffusion] upscaling image with {upscaler.name}...") 288 | image = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path) 289 | p.extra_generation_params["Tiled Diffusion upscaler"] = upscaler.name 290 | p.extra_generation_params["Tiled Diffusion scale factor"] = scale_factor 291 | # For webui folder based batch processing, the length of init_images is not 1 292 | # We need to replace all images with the upsampled one 293 | for i in range(len(p.init_images)): 294 | p.init_images[i] = image 295 | else: 296 | image = init_img 297 | 298 | # decide final canvas size 299 | if keep_input_size: 300 | p.width = image.width 301 | p.height = image.height 302 | elif upscaler.name != "None": 303 | p.width = int(scale_factor * p.width_original_md) 304 | p.height = int(scale_factor * p.height_original_md) 305 | elif overwrite_size: # txt2img 306 | p.width = image_width 307 | p.height = image_height 308 | 309 | ''' sanitiy check ''' 310 | chks = [ 311 | splitable(p.width, p.height, tile_width, tile_height, overlap), 312 | enable_bbox_control, 313 | is_img2img and noise_inverse, 314 | ] 315 | if not any(chks): 316 | print("[Tiled Diffusion] ignore tiling when there's only 1 tile or nothing to do :)") 317 | return 318 | 319 | bbox_settings = build_bbox_settings(bbox_control_states) if enable_bbox_control else {} 320 | 321 | if 'png info': 322 | info = {} 323 | p.extra_generation_params["Tiled Diffusion"] = info 324 | 325 | info['Method'] = method 326 | info['Tile tile width'] = tile_width 327 | info['Tile tile height'] = tile_height 328 | info['Tile Overlap'] = overlap 329 | info['Tile batch size'] = tile_batch_size 330 | 331 | if is_img2img: 332 | if upscaler.name != "None": 333 | info['Upscaler'] = upscaler.name 334 | info['Upscale factor'] = scale_factor 335 | if keep_input_size: 336 | info['Keep input size'] = keep_input_size 337 | if noise_inverse: 338 | info['NoiseInv'] = noise_inverse 339 | info['NoiseInv Steps'] = noise_inverse_steps 340 | info['NoiseInv Retouch'] = noise_inverse_retouch 341 | info['NoiseInv Renoise strength'] = noise_inverse_renoise_strength 342 | info['NoiseInv Kernel size'] = noise_inverse_renoise_kernel 343 | 344 | ''' ControlNet hackin ''' 345 | try: 346 | from scripts.cldm import ControlNet 347 | 348 | for script in p.scripts.scripts + p.scripts.alwayson_scripts: 349 | if hasattr(script, "latest_network") and script.title().lower() == "controlnet": 350 | self.controlnet_script = script 351 | print("[Tiled Diffusion] ControlNet found, support is enabled.") 352 | break 353 | except ImportError: 354 | pass 355 | 356 | ''' StableSR hackin ''' 357 | for script in p.scripts.scripts: 358 | if hasattr(script, "stablesr_model") and script.title().lower() == "stablesr": 359 | if script.stablesr_model is not None: 360 | self.stablesr_script = script 361 | print("[Tiled Diffusion] StableSR found, support is enabled.") 362 | break 363 | 364 | ''' hijack inner APIs, see unhijack in reset() ''' 365 | Script.create_sampler_original_md = 
sd_samplers.create_sampler 366 | sd_samplers.create_sampler = lambda name, model: self.create_sampler_hijack( 367 | name, model, p, Method(method), 368 | tile_width, tile_height, overlap, tile_batch_size, 369 | noise_inverse, noise_inverse_steps, noise_inverse_retouch, 370 | noise_inverse_renoise_strength, noise_inverse_renoise_kernel, 371 | control_tensor_cpu, 372 | enable_bbox_control, draw_background, causal_layers, 373 | bbox_settings, 374 | ) 375 | 376 | if enable_bbox_control: 377 | region_info = { f'Region {i+1}': v._asdict() for i, v in bbox_settings.items() } 378 | info["Region control"] = region_info 379 | Script.create_random_tensors_original_md = processing.create_random_tensors 380 | processing.create_random_tensors = lambda *args, **kwargs: self.create_random_tensors_hijack( 381 | bbox_settings, region_info, 382 | *args, **kwargs, 383 | ) 384 | 385 | def postprocess_batch(self, p: Processing, enabled, *args, **kwargs): 386 | if not enabled: return 387 | 388 | if self.delegate is not None: self.delegate.reset_controlnet_tensors() 389 | 390 | def postprocess(self, p: Processing, processed, enabled, *args): 391 | if not enabled: return 392 | 393 | # unhijack & unhook 394 | self.reset() 395 | 396 | # restore canvas size settings 397 | if hasattr(p, 'init_images') and hasattr(p, 'init_images_original_md'): 398 | p.init_images.clear() # NOTE: do NOT change the list object, compatible with shallow copy of XYZ-plot 399 | p.init_images.extend(p.init_images_original_md) 400 | del p.init_images_original_md 401 | p.width = p.width_original_md ; del p.width_original_md 402 | p.height = p.height_original_md ; del p.height_original_md 403 | 404 | # clean up noise inverse latent for folder-based processing 405 | if hasattr(p, 'noise_inverse_latent'): 406 | del p.noise_inverse_latent 407 | 408 | ''' ↓↓↓ inner API hijack ↓↓↓ ''' 409 | 410 | def create_sampler_hijack( 411 | self, name: str, model: LatentDiffusion, p: Processing, method: Method, 412 | tile_width: int, tile_height: int, overlap: int, tile_batch_size: int, 413 | noise_inverse: bool, noise_inverse_steps: int, noise_inverse_retouch:float, 414 | noise_inverse_renoise_strength: float, noise_inverse_renoise_kernel: int, 415 | control_tensor_cpu: bool, 416 | enable_bbox_control: bool, draw_background: bool, causal_layers: bool, 417 | bbox_settings: Dict[int, BBoxSettings] 418 | ): 419 | 420 | if self.delegate is not None: 421 | # samplers are stateless, we reuse it if possible 422 | if self.delegate.sampler_name == name: 423 | # before we reuse the sampler, we refresh the control tensor 424 | # so that we are compatible with ControlNet batch processing 425 | if self.controlnet_script: 426 | self.delegate.prepare_controlnet_tensors(refresh=True) 427 | return self.delegate.sampler_raw 428 | else: 429 | self.reset() 430 | 431 | flag_noise_inverse = hasattr(p, "init_images") and len(p.init_images) > 0 and noise_inverse 432 | if flag_noise_inverse: 433 | print('warn: noise inversion only supports the "Euler" sampler, switch to it sliently...') 434 | name = 'Euler' 435 | p.sampler_name = 'Euler' 436 | if name is None: print('>> name is empty') 437 | if model is None: print('>> model is empty') 438 | sampler = Script.create_sampler_original_md(name, model) 439 | if method == Method.MULTI_DIFF: delegate_cls = MultiDiffusion 440 | elif method == Method.MIX_DIFF: delegate_cls = MixtureOfDiffusers 441 | else: raise NotImplementedError(f"Method {method} not implemented.") 442 | 443 | # delegate hacks into the `sampler` with context of `p` 444 | 
delegate = delegate_cls(p, sampler) 445 | 446 | # setup **optional** supports through `init_*`, make everything relatively pluggable!! 447 | if flag_noise_inverse: 448 | get_cache_callback = self.noise_inverse_get_cache 449 | set_cache_callback = lambda x0, xt, prompts: self.noise_inverse_set_cache(p, x0, xt, prompts, noise_inverse_steps, noise_inverse_retouch) 450 | delegate.init_noise_inverse(noise_inverse_steps, noise_inverse_retouch, get_cache_callback, set_cache_callback, noise_inverse_renoise_strength, noise_inverse_renoise_kernel) 451 | if not enable_bbox_control or draw_background: 452 | delegate.init_grid_bbox(tile_width, tile_height, overlap, tile_batch_size) 453 | if enable_bbox_control: 454 | delegate.init_custom_bbox(bbox_settings, draw_background, causal_layers) 455 | if self.controlnet_script: 456 | delegate.init_controlnet(self.controlnet_script, control_tensor_cpu) 457 | if self.stablesr_script: 458 | delegate.init_stablesr(self.stablesr_script) 459 | 460 | # init everything done, perform sanity check & pre-computations 461 | delegate.init_done() 462 | # hijack the behaviours 463 | delegate.hook() 464 | 465 | self.delegate = delegate 466 | 467 | info = ', '.join([ 468 | f"{method.value} hooked into {name!r} sampler", 469 | f"Tile size: {delegate.tile_h}x{delegate.tile_w}", 470 | f"Tile count: {delegate.num_tiles}", 471 | f"Batch size: {delegate.tile_bs}", 472 | f"Tile batches: {len(delegate.batched_bboxes)}", 473 | ]) 474 | exts = [ 475 | "NoiseInv" if flag_noise_inverse else None, 476 | "RegionCtrl" if enable_bbox_control else None, 477 | "ContrlNet" if self.controlnet_script else None, 478 | "StableSR" if self.stablesr_script else None, 479 | ] 480 | ext_info = ', '.join([e for e in exts if e]) 481 | if ext_info: ext_info = f' (ext: {ext_info})' 482 | print(info + ext_info) 483 | 484 | return delegate.sampler_raw 485 | 486 | def create_random_tensors_hijack( 487 | self, bbox_settings: Dict, region_info: Dict, 488 | shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None, 489 | ): 490 | org_random_tensors = Script.create_random_tensors_original_md(shape, seeds, subseeds, subseed_strength, seed_resize_from_h, seed_resize_from_w, p) 491 | height, width = shape[1], shape[2] 492 | background_noise = torch.zeros_like(org_random_tensors) 493 | background_noise_count = torch.zeros((1, 1, height, width), device=org_random_tensors.device) 494 | foreground_noise = torch.zeros_like(org_random_tensors) 495 | foreground_noise_count = torch.zeros((1, 1, height, width), device=org_random_tensors.device) 496 | 497 | for i, v in bbox_settings.items(): 498 | seed = get_fixed_seed(v.seed) 499 | x, y, w, h = v.x, v.y, v.w, v.h 500 | # convert to pixel 501 | x = int(x * width) 502 | y = int(y * height) 503 | w = math.ceil(w * width) 504 | h = math.ceil(h * height) 505 | # clamp 506 | x = max(0, x) 507 | y = max(0, y) 508 | w = min(width - x, w) 509 | h = min(height - y, h) 510 | # create random tensor 511 | torch.manual_seed(seed) 512 | rand_tensor = torch.randn((1, org_random_tensors.shape[1], h, w), device=devices.cpu) 513 | if BlendMode(v.blend_mode) == BlendMode.BACKGROUND: 514 | background_noise [:, :, y:y+h, x:x+w] += rand_tensor.to(background_noise.device) 515 | background_noise_count[:, :, y:y+h, x:x+w] += 1 516 | elif BlendMode(v.blend_mode) == BlendMode.FOREGROUND: 517 | foreground_noise [:, :, y:y+h, x:x+w] += rand_tensor.to(foreground_noise.device) 518 | foreground_noise_count[:, :, y:y+h, x:x+w] += 1 519 | else: 520 | raise 
NotImplementedError 521 | region_info['Region ' + str(i+1)]['seed'] = seed 522 | 523 | # average 524 | background_noise = torch.where(background_noise_count > 1, background_noise / background_noise_count, background_noise) 525 | foreground_noise = torch.where(foreground_noise_count > 1, foreground_noise / foreground_noise_count, foreground_noise) 526 | # paste two layers to original random tensor 527 | org_random_tensors = torch.where(background_noise_count > 0, background_noise, org_random_tensors) 528 | org_random_tensors = torch.where(foreground_noise_count > 0, foreground_noise, org_random_tensors) 529 | return org_random_tensors 530 | 531 | ''' ↓↓↓ helper methods ↓↓↓ ''' 532 | 533 | def dump_regions(self, cfg_name, *bbox_controls): 534 | if not cfg_name: return gr_value(f'Config file name cannot be empty.', visible=True) 535 | 536 | bbox_settings = build_bbox_settings(bbox_controls) 537 | data = {'bbox_controls': [v._asdict() for v in bbox_settings.values()]} 538 | 539 | if not os.path.exists(CFG_PATH): os.makedirs(CFG_PATH) 540 | fp = os.path.join(CFG_PATH, cfg_name) 541 | with open(fp, 'w', encoding='utf-8') as fh: 542 | json.dump(data, fh, indent=2, ensure_ascii=False) 543 | 544 | return gr_value(f'Config saved to {fp}.', visible=True) 545 | 546 | def load_regions(self, ref_image, cfg_name, *bbox_controls): 547 | if ref_image is None: 548 | return [gr_value(v) for v in bbox_controls] + [gr_value(f'Please create or upload a ref image first.', visible=True)] 549 | fp = os.path.join(CFG_PATH, cfg_name) 550 | if not os.path.exists(fp): 551 | return [gr_value(v) for v in bbox_controls] + [gr_value(f'Config {fp} not found.', visible=True)] 552 | 553 | try: 554 | with open(fp, 'r', encoding='utf-8') as fh: 555 | data = json.load(fh) 556 | except Exception as e: 557 | return [gr_value(v) for v in bbox_controls] + [gr_value(f'Failed to load config {fp}: {e}', visible=True)] 558 | 559 | num_boxes = len(data['bbox_controls']) 560 | data_list = [] 561 | for i in range(BBOX_MAX_NUM): 562 | if i < num_boxes: 563 | for k in BBoxSettings._fields: 564 | if k in data['bbox_controls'][i]: 565 | data_list.append(data['bbox_controls'][i][k]) 566 | else: 567 | data_list.append(None) 568 | else: 569 | data_list.extend(DEFAULT_BBOX_SETTINGS) 570 | 571 | return [gr_value(v) for v in data_list] + [gr_value(f'Config loaded from {fp}.', visible=True)] 572 | 573 | def noise_inverse_set_cache(self, p: ProcessingImg2Img, x0: Tensor, xt: Tensor, prompts: List[str], steps: int, retouch:float): 574 | self.noise_inverse_cache = NoiseInverseCache(p.sd_model.sd_model_hash, x0, xt, steps, retouch, prompts) 575 | 576 | def noise_inverse_get_cache(self): 577 | return self.noise_inverse_cache 578 | 579 | def reset(self): 580 | ''' unhijack inner APIs, see hijack in process() ''' 581 | if hasattr(Script, "create_sampler_original_md"): 582 | sd_samplers.create_sampler = Script.create_sampler_original_md 583 | del Script.create_sampler_original_md 584 | if hasattr(Script, "create_random_tensors_original_md"): 585 | processing.create_random_tensors = Script.create_random_tensors_original_md 586 | del Script.create_random_tensors_original_md 587 | MultiDiffusion .unhook() 588 | MixtureOfDiffusers.unhook() 589 | self.delegate = None 590 | 591 | def reset_and_gc(self): 592 | self.reset() 593 | self.noise_inverse_cache = None 594 | 595 | import gc; gc.collect() 596 | devices.torch_gc() 597 | 598 | try: 599 | import os 600 | import psutil 601 | mem = psutil.Process(os.getpid()).memory_info() 602 | print(f'[Mem] rss: 
{mem.rss/2**30:.3f} GB, vms: {mem.vms/2**30:.3f} GB') 603 | from modules.shared import mem_mon as vram_mon 604 | from modules.memmon import MemUsageMonitor 605 | vram_mon: MemUsageMonitor 606 | free, total = vram_mon.cuda_mem_get_info() 607 | print(f'[VRAM] free: {free/2**30:.3f} GB, total: {total/2**30:.3f} GB') 608 | except: 609 | pass 610 | -------------------------------------------------------------------------------- /scripts/tileglobal.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import gradio as gr 7 | 8 | from modules import sd_samplers, images, shared, devices, processing, scripts, sd_samplers_common, rng 9 | from modules.shared import opts 10 | from modules.processing import opt_f, get_fixed_seed 11 | from modules.ui import gr_show 12 | 13 | from tile_methods.abstractdiffusion import AbstractDiffusion 14 | from tile_methods.demofusion import DemoFusion 15 | from tile_utils.utils import * 16 | from modules.sd_samplers_common import InterruptedException 17 | # import k_diffusion.sampling 18 | if hasattr(opts, 'hypertile_enable_unet'): # webui >= 1.7 19 | from modules.ui_components import InputAccordion 20 | else: 21 | InputAccordion = None 22 | 23 | 24 | CFG_PATH = os.path.join(scripts.basedir(), 'region_configs') 25 | BBOX_MAX_NUM = min(getattr(shared.cmd_opts, 'md_max_regions', 8), 16) 26 | 27 | 28 | def create_infotext_hijack(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=-1, all_negative_prompts=None): 29 | idx = index 30 | if index == -1: 31 | idx = None 32 | text = processing.create_infotext_ori(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch, use_main_prompt, idx, all_negative_prompts) 33 | start_index = text.find("Size") 34 | if start_index != -1: 35 | r_text = f"Size:{p.width_list[index]}x{p.height_list[index]}" 36 | end_index = text.find(",", start_index) 37 | if end_index != -1: 38 | replaced_string = text[:start_index] + r_text + text[end_index:] 39 | return replaced_string 40 | return text 41 | 42 | class Script(scripts.Script): 43 | def __init__(self): 44 | self.controlnet_script: ModuleType = None 45 | self.stablesr_script: ModuleType = None 46 | self.delegate: AbstractDiffusion = None 47 | self.noise_inverse_cache: NoiseInverseCache = None 48 | 49 | def title(self): 50 | return 'demofusion' 51 | 52 | def show(self, is_img2img): 53 | return scripts.AlwaysVisible 54 | 55 | def ui(self, is_img2img): 56 | ext_id = 'demofusion' 57 | tab = f'{ext_id}-t2i' if not is_img2img else f'{ext_id}-i2i' 58 | is_t2i = 'true' if not is_img2img else 'false' 59 | uid = lambda name: f'MD-{tab}-{name}' 60 | 61 | with ( 62 | InputAccordion(False, label='DemoFusion', elem_id=uid('enabled')) if InputAccordion 63 | else gr.Accordion('DemoFusion', open=False, elem_id=f'MD-{tab}') 64 | as enabled 65 | ): 66 | with gr.Row(variant='compact') as tab_enable: 67 | if not InputAccordion: 68 | enabled = gr.Checkbox(label='Enable DemoFusion(Dont open with tilediffusion)', value=False, elem_id=uid('enabled')) 69 | else: 70 | gr.Markdown('(Dont open with tilediffusion)') 71 | random_jitter = gr.Checkbox(label='Random Jitter', value = True, elem_id=uid('random-jitter')) 72 | keep_input_size = gr.Checkbox(label='Keep input-image size', value=False,visible=is_img2img, elem_id=uid('keep-input-size')) 73 | mixture_mode = gr.Checkbox(label='Mixture mode', 
value=False,elem_id=uid('mixture-mode')) 74 | 75 | gaussian_filter = gr.Checkbox(label='Gaussian Filter', value=True, visible=False, elem_id=uid('gaussian')) 76 | 77 | 78 | with gr.Row(variant='compact') as tab_param: 79 | method = gr.Dropdown(label='Method', choices=[Method_2.DEMO_FU.value], value=Method_2.DEMO_FU.value, visible= False, elem_id=uid('method')) 80 | control_tensor_cpu = gr.Checkbox(label='Move ControlNet tensor to CPU (if applicable)', value=False, elem_id=uid('control-tensor-cpu')) 81 | reset_status = gr.Button(value='Free GPU', variant='tool') 82 | reset_status.click(fn=self.reset_and_gc, show_progress=False) 83 | 84 | with gr.Group() as tab_tile: 85 | with gr.Row(variant='compact'): 86 | window_size = gr.Slider(minimum=16, maximum=256, step=16, label='Latent window size', value=128, elem_id=uid('latent-window-size')) 87 | 88 | with gr.Row(variant='compact'): 89 | overlap = gr.Slider(minimum=0, maximum=256, step=4, label='Latent window overlap', value=64, elem_id=uid('latent-tile-overlap')) 90 | batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Latent window batch size', value=4, elem_id=uid('latent-tile-batch-size')) 91 | batch_size_g = gr.Slider(minimum=1, maximum=8, step=1, label='Global window batch size', value=4, elem_id=uid('Global-tile-batch-size')) 92 | with gr.Row(variant='compact', visible=True) as tab_c: 93 | c1 = gr.Slider(minimum=0, maximum=5, step=0.01, label='Cosine Scale 1', value=3, elem_id=f'C1-{tab}') 94 | c2 = gr.Slider(minimum=0, maximum=5, step=0.01, label='Cosine Scale 2', value=1, elem_id=f'C2-{tab}') 95 | c3 = gr.Slider(minimum=0, maximum=5, step=0.01, label='Cosine Scale 3', value=1, elem_id=f'C3-{tab}') 96 | sigma = gr.Slider(minimum=0, maximum=2, step=0.01, label='Sigma', value=0.6, elem_id=f'Sigma-{tab}') 97 | with gr.Group() as tab_denoise: 98 | strength = gr.Slider(minimum=0, maximum=1, step=0.01, value = 0.85,label='Denoising Strength for Substage',visible=not is_img2img, elem_id=f'strength-{tab}') 99 | with gr.Row(variant='compact') as tab_upscale: 100 | scale_factor = gr.Slider(minimum=1.0, maximum=8.0, step=1, label='Scale Factor', value=2.0, elem_id=uid('upscaler-factor')) 101 | 102 | 103 | with gr.Accordion('Noise Inversion', open=True, visible=is_img2img) as tab_noise_inv: 104 | with gr.Row(variant='compact'): 105 | noise_inverse = gr.Checkbox(label='Enable Noise Inversion', value=False, elem_id=uid('noise-inverse')) 106 | noise_inverse_steps = gr.Slider(minimum=1, maximum=200, step=1, label='Inversion steps', value=10, elem_id=uid('noise-inverse-steps')) 107 | gr.HTML('

Please test on small images before the actual upscale. The default params require denoise <= 0.6

') 108 | with gr.Row(variant='compact'): 109 | noise_inverse_retouch = gr.Slider(minimum=1, maximum=100, step=0.1, label='Retouch', value=1, elem_id=uid('noise-inverse-retouch')) 110 | noise_inverse_renoise_strength = gr.Slider(minimum=0, maximum=2, step=0.01, label='Renoise strength', value=1, elem_id=uid('noise-inverse-renoise-strength')) 111 | noise_inverse_renoise_kernel = gr.Slider(minimum=2, maximum=512, step=1, label='Renoise kernel size', value=64, elem_id=uid('noise-inverse-renoise-kernel')) 112 | 113 | # The control includes txt2img and img2img, we use t2i and i2i to distinguish them 114 | 115 | return [ 116 | enabled, method, 117 | keep_input_size, 118 | window_size, overlap, batch_size, 119 | scale_factor, 120 | noise_inverse, noise_inverse_steps, noise_inverse_retouch, noise_inverse_renoise_strength, noise_inverse_renoise_kernel, 121 | control_tensor_cpu, 122 | random_jitter, 123 | c1,c2,c3,gaussian_filter,strength,sigma,batch_size_g,mixture_mode 124 | ] 125 | 126 | 127 | def process(self, p: Processing, 128 | enabled: bool, method: str, 129 | keep_input_size: bool, 130 | window_size:int, overlap: int, tile_batch_size: int, 131 | scale_factor: float, 132 | noise_inverse: bool, noise_inverse_steps: int, noise_inverse_retouch: float, noise_inverse_renoise_strength: float, noise_inverse_renoise_kernel: int, 133 | control_tensor_cpu: bool, 134 | random_jitter:bool, 135 | c1,c2,c3,gaussian_filter,strength,sigma,batch_size_g,mixture_mode 136 | ): 137 | 138 | # unhijack & unhook, in case it broke at last time 139 | self.reset() 140 | p.mixture = mixture_mode 141 | if not mixture_mode: 142 | sigma = sigma/2 143 | if not enabled: return 144 | 145 | ''' upscale ''' 146 | # store canvas size settings 147 | if hasattr(p, "init_images"): 148 | p.init_images_original_md = [img.copy() for img in p.init_images] 149 | p.width_original_md = p.width 150 | p.height_original_md = p.height 151 | p.current_scale_num = 1 152 | p.gaussian_filter = gaussian_filter 153 | p.scale_factor = int(scale_factor) 154 | 155 | is_img2img = hasattr(p, "init_images") and len(p.init_images) > 0 156 | if is_img2img: 157 | init_img = p.init_images[0] 158 | init_img = images.flatten(init_img, opts.img2img_background_color) 159 | image = init_img 160 | if keep_input_size: 161 | p.width = image.width 162 | p.height = image.height 163 | p.width_original_md = p.width 164 | p.height_original_md = p.height 165 | else: #XXX:To adapt to noise inversion, we do not multiply the scale factor here 166 | p.width = p.width_original_md 167 | p.height = p.height_original_md 168 | else: # txt2img 169 | p.width = p.width_original_md 170 | p.height = p.height_original_md 171 | 172 | if 'png info': 173 | info = {} 174 | p.extra_generation_params["Tiled Diffusion"] = info 175 | 176 | info['Method'] = method 177 | info['Window Size'] = window_size 178 | info['Tile Overlap'] = overlap 179 | info['Tile batch size'] = tile_batch_size 180 | info["Global batch size"] = batch_size_g 181 | 182 | if is_img2img: 183 | info['Upscale factor'] = scale_factor 184 | if keep_input_size: 185 | info['Keep input size'] = keep_input_size 186 | if noise_inverse: 187 | info['NoiseInv'] = noise_inverse 188 | info['NoiseInv Steps'] = noise_inverse_steps 189 | info['NoiseInv Retouch'] = noise_inverse_retouch 190 | info['NoiseInv Renoise strength'] = noise_inverse_renoise_strength 191 | info['NoiseInv Kernel size'] = noise_inverse_renoise_kernel 192 | 193 | ''' ControlNet hackin ''' 194 | try: 195 | from scripts.cldm import ControlNet 196 | 197 | for script in 
p.scripts.scripts + p.scripts.alwayson_scripts: 198 | if hasattr(script, "latest_network") and script.title().lower() == "controlnet": 199 | self.controlnet_script = script 200 | print("[Demo Fusion] ControlNet found, support is enabled.") 201 | break 202 | except ImportError: 203 | pass 204 | 205 | ''' StableSR hackin ''' 206 | for script in p.scripts.scripts: 207 | if hasattr(script, "stablesr_model") and script.title().lower() == "stablesr": 208 | if script.stablesr_model is not None: 209 | self.stablesr_script = script 210 | print("[Demo Fusion] StableSR found, support is enabled.") 211 | break 212 | 213 | ''' hijack inner APIs, see unhijack in reset() ''' 214 | Script.create_sampler_original_md = sd_samplers.create_sampler 215 | 216 | sd_samplers.create_sampler = lambda name, model: self.create_sampler_hijack( 217 | name, model, p, Method_2(method), control_tensor_cpu,window_size, noise_inverse, noise_inverse_steps, noise_inverse_retouch, 218 | noise_inverse_renoise_strength, noise_inverse_renoise_kernel, overlap, tile_batch_size,random_jitter,batch_size_g 219 | ) 220 | 221 | 222 | p.sample = lambda conditioning, unconditional_conditioning,seeds, subseeds, subseed_strength, prompts: self.sample_hijack( 223 | conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts,p, is_img2img, 224 | window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength,sigma,batch_size_g) 225 | 226 | processing.create_infotext_ori = processing.create_infotext 227 | 228 | p.width_list = [p.height] 229 | p.height_list = [p.height] 230 | 231 | processing.create_infotext = create_infotext_hijack 232 | ## end 233 | 234 | 235 | def postprocess_batch(self, p: Processing, enabled, *args, **kwargs): 236 | if not enabled: return 237 | 238 | if self.delegate is not None: self.delegate.reset_controlnet_tensors() 239 | 240 | def postprocess_batch_list(self, p, pp, enabled, *args, **kwargs): 241 | if not enabled: return 242 | for idx,image in enumerate(pp.images): 243 | idx_b = idx//p.batch_size 244 | pp.images[idx] = image[:,:image.shape[1]//(p.scale_factor)*(idx_b+1),:image.shape[2]//(p.scale_factor)*(idx_b+1)] 245 | p.seeds = [item for _ in range(p.scale_factor) for item in p.seeds] 246 | p.prompts = [item for _ in range(p.scale_factor) for item in p.prompts] 247 | p.all_negative_prompts = [item for _ in range(p.scale_factor) for item in p.all_negative_prompts] 248 | p.negative_prompts = [item for _ in range(p.scale_factor) for item in p.negative_prompts] 249 | if p.color_corrections != None: 250 | p.color_corrections = [item for _ in range(p.scale_factor) for item in p.color_corrections] 251 | p.width_list = [item*(idx+1) for idx in range(p.scale_factor) for item in [p.width for _ in range(p.batch_size)]] 252 | p.height_list = [item*(idx+1) for idx in range(p.scale_factor) for item in [p.height for _ in range(p.batch_size)]] 253 | return 254 | 255 | def postprocess(self, p: Processing, processed, enabled, *args): 256 | if not enabled: return 257 | # unhijack & unhook 258 | self.reset() 259 | 260 | # restore canvas size settings 261 | if hasattr(p, 'init_images') and hasattr(p, 'init_images_original_md'): 262 | p.init_images.clear() # NOTE: do NOT change the list object, compatible with shallow copy of XYZ-plot 263 | p.init_images.extend(p.init_images_original_md) 264 | del p.init_images_original_md 265 | p.width = p.width_original_md ; del p.width_original_md 266 | p.height = p.height_original_md ; del p.height_original_md 267 | 268 | # clean up noise inverse latent for folder-based 
processing 269 | if hasattr(p, 'noise_inverse_latent'): 270 | del p.noise_inverse_latent 271 | 272 | ''' ↓↓↓ inner API hijack ↓↓↓ ''' 273 | @torch.no_grad() 274 | def sample_hijack(self, conditioning, unconditional_conditioning,seeds, subseeds, subseed_strength, prompts,p,image_ori,window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength,sigma,batch_size_g): 275 | ################################################## Phase Initialization ###################################################### 276 | 277 | if not image_ori: 278 | p.current_step = 0 279 | p.denoising_strength = strength 280 | # p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model) #NOTE:Wrong but very useful. If corrected, please replace with the content with the following lines 281 | # latents = p.rng.next() 282 | 283 | p.sampler = Script.create_sampler_original_md(p.sampler_name, p.sd_model) #scale 284 | x = p.rng.next() 285 | print("### Phase 1 Denoising ###") 286 | latents = p.sampler.sample(p, x, conditioning, unconditional_conditioning, image_conditioning=p.txt2img_image_conditioning(x)) 287 | latents_ = F.pad(latents, (0, latents.shape[3]*(p.scale_factor-1), 0, latents.shape[2]*(p.scale_factor-1))) 288 | res = latents_ 289 | del x 290 | p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model) 291 | starting_scale = 2 292 | else: # img2img 293 | print("### Encoding Real Image ###") 294 | latents = p.init_latent 295 | starting_scale = 1 296 | 297 | 298 | anchor_mean = latents.mean() 299 | anchor_std = latents.std() 300 | 301 | devices.torch_gc() 302 | 303 | ####################################################### Phase Upscaling ##################################################### 304 | p.cosine_scale_1 = c1 305 | p.cosine_scale_2 = c2 306 | p.cosine_scale_3 = c3 307 | self.delegate.sig = sigma 308 | p.latents = latents 309 | for current_scale_num in range(starting_scale, p.scale_factor+1): 310 | p.current_scale_num = current_scale_num 311 | print("### Phase {} Denoising ###".format(current_scale_num)) 312 | p.current_height = p.height_original_md * current_scale_num 313 | p.current_width = p.width_original_md * current_scale_num 314 | 315 | 316 | p.latents = F.interpolate(p.latents, size=(int(p.current_height / opt_f), int(p.current_width / opt_f)), mode='bicubic') 317 | p.rng = rng.ImageRNG(p.latents.shape[1:], p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w) 318 | 319 | 320 | self.delegate.w = int(p.current_width / opt_f) 321 | self.delegate.h = int(p.current_height / opt_f) 322 | self.delegate.get_views(overlap, tile_batch_size,batch_size_g) 323 | 324 | info = ', '.join([ 325 | # f"{method.value} hooked into {name!r} sampler", 326 | f"Tile size: {self.delegate.window_size}", 327 | f"Tile count: {self.delegate.num_tiles}", 328 | f"Batch size: {self.delegate.tile_bs}", 329 | f"Tile batches: {len(self.delegate.batched_bboxes)}", 330 | f"Global batch size: {self.delegate.global_tile_bs}", 331 | f"Global batches: {len(self.delegate.global_batched_bboxes)}", 332 | ]) 333 | 334 | print(info) 335 | 336 | noise = p.rng.next() 337 | if hasattr(p,'initial_noise_multiplier'): 338 | if p.initial_noise_multiplier != 1.0: 339 | p.extra_generation_params["Noise multiplier"] = p.initial_noise_multiplier 340 | noise *= p.initial_noise_multiplier 341 | else: 342 | p.image_conditioning = p.txt2img_image_conditioning(noise) 343 | 344 | p.noise = noise 345 | p.x = p.latents.clone() 346 | p.current_step=0 347 
| 348 | p.latents = p.sampler.sample_img2img(p,p.latents, noise , conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning) 349 | if self.flag_noise_inverse: 350 | self.delegate.sampler_raw.sample_img2img = self.delegate.sample_img2img_original 351 | self.flag_noise_inverse = False 352 | 353 | p.latents = (p.latents - p.latents.mean()) / p.latents.std() * anchor_std + anchor_mean 354 | latents_ = F.pad(p.latents, (0, p.latents.shape[3]//current_scale_num*(p.scale_factor-current_scale_num), 0, p.latents.shape[2]//current_scale_num*(p.scale_factor-current_scale_num))) 355 | if current_scale_num==1: 356 | res = latents_ 357 | else: 358 | res = torch.concatenate((res,latents_),axis=0) 359 | 360 | ######################################################################################################################################### 361 | 362 | return res 363 | 364 | @staticmethod 365 | def callback_hijack(self_sampler,d,p): 366 | p.current_step = d['i'] 367 | 368 | if self_sampler.stop_at is not None and p.current_step > self_sampler.stop_at: 369 | raise InterruptedException 370 | 371 | state.sampling_step = p.current_step 372 | shared.total_tqdm.update() 373 | p.current_step += 1 374 | 375 | 376 | def create_sampler_hijack( 377 | self, name: str, model: LatentDiffusion, p: Processing, method: Method_2, control_tensor_cpu:bool,window_size, noise_inverse: bool, noise_inverse_steps: int, noise_inverse_retouch:float, 378 | noise_inverse_renoise_strength: float, noise_inverse_renoise_kernel: int, overlap:int, tile_batch_size:int, random_jitter:bool,batch_size_g:int 379 | ): 380 | if self.delegate is not None: 381 | # samplers are stateless, we reuse it if possible 382 | if self.delegate.sampler_name == name: 383 | # before we reuse the sampler, we refresh the control tensor 384 | # so that we are compatible with ControlNet batch processing 385 | if self.controlnet_script: 386 | self.delegate.prepare_controlnet_tensors(refresh=True) 387 | return self.delegate.sampler_raw 388 | else: 389 | self.reset() 390 | sd_samplers_common.Sampler.callback_ori = sd_samplers_common.Sampler.callback_state 391 | sd_samplers_common.Sampler.callback_state = lambda self_sampler,d:Script.callback_hijack(self_sampler,d,p) 392 | 393 | self.flag_noise_inverse = hasattr(p, "init_images") and len(p.init_images) > 0 and noise_inverse 394 | flag_noise_inverse = self.flag_noise_inverse 395 | if flag_noise_inverse: 396 | print('warn: noise inversion only supports the "Euler" sampler, switch to it sliently...') 397 | name = 'Euler' 398 | p.sampler_name = 'Euler' 399 | if name is None: print('>> name is empty') 400 | if model is None: print('>> model is empty') 401 | sampler = Script.create_sampler_original_md(name, model) 402 | if method ==Method_2.DEMO_FU: delegate_cls = DemoFusion 403 | else: raise NotImplementedError(f"Method {method} not implemented.") 404 | 405 | delegate = delegate_cls(p, sampler) 406 | delegate.window_size = min(min(window_size,p.width//8),p.height//8) 407 | p.random_jitter = random_jitter 408 | 409 | if flag_noise_inverse: 410 | get_cache_callback = self.noise_inverse_get_cache 411 | set_cache_callback = lambda x0, xt, prompts: self.noise_inverse_set_cache(p, x0, xt, prompts, noise_inverse_steps, noise_inverse_retouch) 412 | delegate.init_noise_inverse(noise_inverse_steps, noise_inverse_retouch, get_cache_callback, set_cache_callback, noise_inverse_renoise_strength, noise_inverse_renoise_kernel) 413 | 414 | # delegate.get_views(overlap,tile_batch_size,batch_size_g) 415 | if 
self.controlnet_script: 416 | delegate.init_controlnet(self.controlnet_script, control_tensor_cpu) 417 | if self.stablesr_script: 418 | delegate.init_stablesr(self.stablesr_script) 419 | 420 | # init everything done, perform sanity check & pre-computations 421 | # hijack the behaviours 422 | delegate.hook() 423 | 424 | self.delegate = delegate 425 | 426 | exts = [ 427 | "ContrlNet" if self.controlnet_script else None, 428 | "StableSR" if self.stablesr_script else None, 429 | ] 430 | ext_info = ', '.join([e for e in exts if e]) 431 | if ext_info: ext_info = f' (ext: {ext_info})' 432 | print(ext_info) 433 | 434 | return delegate.sampler_raw 435 | 436 | def create_random_tensors_hijack( 437 | self, bbox_settings: Dict, region_info: Dict, 438 | shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None, 439 | ): 440 | org_random_tensors = Script.create_random_tensors_original_md(shape, seeds, subseeds, subseed_strength, seed_resize_from_h, seed_resize_from_w, p) 441 | height, width = shape[1], shape[2] 442 | background_noise = torch.zeros_like(org_random_tensors) 443 | background_noise_count = torch.zeros((1, 1, height, width), device=org_random_tensors.device) 444 | foreground_noise = torch.zeros_like(org_random_tensors) 445 | foreground_noise_count = torch.zeros((1, 1, height, width), device=org_random_tensors.device) 446 | 447 | for i, v in bbox_settings.items(): 448 | seed = get_fixed_seed(v.seed) 449 | x, y, w, h = v.x, v.y, v.w, v.h 450 | # convert to pixel 451 | x = int(x * width) 452 | y = int(y * height) 453 | w = math.ceil(w * width) 454 | h = math.ceil(h * height) 455 | # clamp 456 | x = max(0, x) 457 | y = max(0, y) 458 | w = min(width - x, w) 459 | h = min(height - y, h) 460 | # create random tensor 461 | torch.manual_seed(seed) 462 | rand_tensor = torch.randn((1, org_random_tensors.shape[1], h, w), device=devices.cpu) 463 | if BlendMode(v.blend_mode) == BlendMode.BACKGROUND: 464 | background_noise [:, :, y:y+h, x:x+w] += rand_tensor.to(background_noise.device) 465 | background_noise_count[:, :, y:y+h, x:x+w] += 1 466 | elif BlendMode(v.blend_mode) == BlendMode.FOREGROUND: 467 | foreground_noise [:, :, y:y+h, x:x+w] += rand_tensor.to(foreground_noise.device) 468 | foreground_noise_count[:, :, y:y+h, x:x+w] += 1 469 | else: 470 | raise NotImplementedError 471 | region_info['Region ' + str(i+1)]['seed'] = seed 472 | 473 | # average 474 | background_noise = torch.where(background_noise_count > 1, background_noise / background_noise_count, background_noise) 475 | foreground_noise = torch.where(foreground_noise_count > 1, foreground_noise / foreground_noise_count, foreground_noise) 476 | # paste two layers to original random tensor 477 | org_random_tensors = torch.where(background_noise_count > 0, background_noise, org_random_tensors) 478 | org_random_tensors = torch.where(foreground_noise_count > 0, foreground_noise, org_random_tensors) 479 | return org_random_tensors 480 | 481 | ''' ↓↓↓ helper methods ↓↓↓ ''' 482 | 483 | def dump_regions(self, cfg_name, *bbox_controls): 484 | if not cfg_name: return gr_value(f'Config file name cannot be empty.', visible=True) 485 | 486 | bbox_settings = build_bbox_settings(bbox_controls) 487 | data = {'bbox_controls': [v._asdict() for v in bbox_settings.values()]} 488 | 489 | if not os.path.exists(CFG_PATH): os.makedirs(CFG_PATH) 490 | fp = os.path.join(CFG_PATH, cfg_name) 491 | with open(fp, 'w', encoding='utf-8') as fh: 492 | json.dump(data, fh, indent=2, ensure_ascii=False) 493 | 494 | return 
gr_value(f'Config saved to {fp}.', visible=True) 495 | 496 | def load_regions(self, ref_image, cfg_name, *bbox_controls): 497 | if ref_image is None: 498 | return [gr_value(v) for v in bbox_controls] + [gr_value(f'Please create or upload a ref image first.', visible=True)] 499 | fp = os.path.join(CFG_PATH, cfg_name) 500 | if not os.path.exists(fp): 501 | return [gr_value(v) for v in bbox_controls] + [gr_value(f'Config {fp} not found.', visible=True)] 502 | 503 | try: 504 | with open(fp, 'r', encoding='utf-8') as fh: 505 | data = json.load(fh) 506 | except Exception as e: 507 | return [gr_value(v) for v in bbox_controls] + [gr_value(f'Failed to load config {fp}: {e}', visible=True)] 508 | 509 | num_boxes = len(data['bbox_controls']) 510 | data_list = [] 511 | for i in range(BBOX_MAX_NUM): 512 | if i < num_boxes: 513 | for k in BBoxSettings._fields: 514 | if k in data['bbox_controls'][i]: 515 | data_list.append(data['bbox_controls'][i][k]) 516 | else: 517 | data_list.append(None) 518 | else: 519 | data_list.extend(DEFAULT_BBOX_SETTINGS) 520 | 521 | return [gr_value(v) for v in data_list] + [gr_value(f'Config loaded from {fp}.', visible=True)] 522 | 523 | 524 | def noise_inverse_set_cache(self, p: ProcessingImg2Img, x0: Tensor, xt: Tensor, prompts: List[str], steps: int, retouch:float): 525 | self.noise_inverse_cache = NoiseInverseCache(p.sd_model.sd_model_hash, x0, xt, steps, retouch, prompts) 526 | 527 | def noise_inverse_get_cache(self): 528 | return self.noise_inverse_cache 529 | 530 | 531 | def reset(self): 532 | ''' unhijack inner APIs, see hijack in process() ''' 533 | if hasattr(Script, "create_sampler_original_md"): 534 | sd_samplers.create_sampler = Script.create_sampler_original_md 535 | del Script.create_sampler_original_md 536 | if hasattr(Script, "create_random_tensors_original_md"): 537 | processing.create_random_tensors = Script.create_random_tensors_original_md 538 | del Script.create_random_tensors_original_md 539 | if hasattr(sd_samplers_common.Sampler, "callback_ori"): 540 | sd_samplers_common.Sampler.callback_state = sd_samplers_common.Sampler.callback_ori 541 | del sd_samplers_common.Sampler.callback_ori 542 | if hasattr(processing, "create_infotext_ori"): 543 | processing.create_infotext = processing.create_infotext_ori 544 | del processing.create_infotext_ori 545 | DemoFusion.unhook() 546 | self.delegate = None 547 | 548 | def reset_and_gc(self): 549 | self.reset() 550 | self.noise_inverse_cache = None 551 | 552 | import gc; gc.collect() 553 | devices.torch_gc() 554 | 555 | try: 556 | import os 557 | import psutil 558 | mem = psutil.Process(os.getpid()).memory_info() 559 | print(f'[Mem] rss: {mem.rss/2**30:.3f} GB, vms: {mem.vms/2**30:.3f} GB') 560 | from modules.shared import mem_mon as vram_mon 561 | from modules.memmon import MemUsageMonitor 562 | vram_mon: MemUsageMonitor 563 | free, total = vram_mon.cuda_mem_get_info() 564 | print(f'[VRAM] free: {free/2**30:.3f} GB, total: {total/2**30:.3f} GB') 565 | except: 566 | pass 567 | -------------------------------------------------------------------------------- /scripts/tilevae.py: -------------------------------------------------------------------------------- 1 | ''' 2 | # ------------------------------------------------------------------------ 3 | # 4 | # Tiled VAE 5 | # 6 | # Introducing a revolutionary new optimization designed to make 7 | # the VAE work with giant images on limited VRAM! 8 | # Say goodbye to the frustration of OOM and hello to seamless output! 
9 | # 10 | # ------------------------------------------------------------------------ 11 | # 12 | # This script is a wild hack that splits the image into tiles, 13 | # encodes each tile separately, and merges the result back together. 14 | # 15 | # Advantages: 16 | # - The VAE can now work with giant images on limited VRAM 17 | # (~10 GB for 8K images!) 18 | # - The merged output is completely seamless without any post-processing. 19 | # 20 | # Drawbacks: 21 | # - NaNs always appear in for 8k images when you use fp16 (half) VAE 22 | # You must use --no-half-vae to disable half VAE for that giant image. 23 | # - The gradient calculation is not compatible with this hack. It 24 | # will break any backward() or torch.autograd.grad() that passes VAE. 25 | # (But you can still use the VAE to generate training data.) 26 | # 27 | # How it works: 28 | # 1. The image is split into tiles, which are then padded with 11/32 pixels' in the decoder/encoder. 29 | # 2. When Fast Mode is disabled: 30 | # 1. The original VAE forward is decomposed into a task queue and a task worker, which starts to process each tile. 31 | # 2. When GroupNorm is needed, it suspends, stores current GroupNorm mean and var, send everything to RAM, and turns to the next tile. 32 | # 3. After all GroupNorm means and vars are summarized, it applies group norm to tiles and continues. 33 | # 4. A zigzag execution order is used to reduce unnecessary data transfer. 34 | # 3. When Fast Mode is enabled: 35 | # 1. The original input is downsampled and passed to a separate task queue. 36 | # 2. Its group norm parameters are recorded and used by all tiles' task queues. 37 | # 3. Each tile is separately processed without any RAM-VRAM data transfer. 38 | # 4. After all tiles are processed, tiles are written to a result buffer and returned. 39 | # Encoder color fix = only estimate GroupNorm before downsampling, i.e., run in a semi-fast mode. 40 | # 41 | # Enjoy! 42 | # 43 | # @Author: LI YI @ Nanyang Technological University - Singapore 44 | # @Date: 2023-03-02 45 | # @License: CC BY-NC-SA 4.0 46 | # 47 | # Please give me a star if you like this project! 
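# A simplified, self-contained sketch of the GroupNorm merging described in "How it works" above,
# assuming batch size 1 and omitting fp16 overflow handling, the RAM/VRAM shuttling, and the affine
# weight/bias (see GroupNormParam and custom_group_norm below for the actual implementation):
#
#   import torch
#
#   def tile_group_stats(tile, num_groups=32):
#       # per-tile GroupNorm statistics, weighted later by the tile's pixel count
#       _, c, h, w = tile.shape
#       g = tile.view(1, num_groups, c // num_groups, h, w)
#       var, mean = torch.var_mean(g, dim=[0, 2, 3, 4], unbiased=False)
#       return var, mean, h * w
#
#   def merge_and_normalize(tiles, num_groups=32, eps=1e-6):
#       # 1) gather statistics from every tile, weighted by pixel count
#       stats = [tile_group_stats(t, num_groups) for t in tiles]
#       total = sum(px for _, _, px in stats)
#       var  = sum(v * (px / total) for v, _, px in stats)
#       mean = sum(m * (px / total) for _, m, px in stats)
#       # 2) apply the shared statistics to each tile, as if GroupNorm saw the whole image
#       out = []
#       for t in tiles:
#           _, c, h, w = t.shape
#           g = t.view(1, num_groups, c // num_groups, h, w)
#           g = (g - mean.view(1, -1, 1, 1, 1)) / torch.sqrt(var.view(1, -1, 1, 1, 1) + eps)
#           out.append(g.view(1, c, h, w))
#       return out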
48 | # 49 | # ------------------------------------------------------------------------- 50 | ''' 51 | 52 | import gc 53 | import math 54 | from time import time 55 | from tqdm import tqdm 56 | 57 | import torch 58 | import torch.version 59 | import torch.nn.functional as F 60 | import gradio as gr 61 | 62 | import modules.scripts as scripts 63 | import modules.devices as devices 64 | from modules.shared import state, opts 65 | from modules.ui import gr_show 66 | from modules.processing import opt_f 67 | from modules.sd_vae_approx import cheap_approximation 68 | from ldm.modules.diffusionmodules.model import AttnBlock, MemoryEfficientAttnBlock 69 | 70 | from tile_utils.attn import get_attn_func 71 | from tile_utils.typing import Processing 72 | 73 | if hasattr(opts, 'hypertile_enable_unet'): # webui >= 1.7 74 | from modules.ui_components import InputAccordion 75 | else: 76 | InputAccordion = None 77 | 78 | 79 | def get_rcmd_enc_tsize(): 80 | if torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]: 81 | total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20 82 | if total_memory > 16*1000: ENCODER_TILE_SIZE = 3072 83 | elif total_memory > 12*1000: ENCODER_TILE_SIZE = 2048 84 | elif total_memory > 8*1000: ENCODER_TILE_SIZE = 1536 85 | else: ENCODER_TILE_SIZE = 960 86 | else: ENCODER_TILE_SIZE = 512 87 | return ENCODER_TILE_SIZE 88 | 89 | 90 | def get_rcmd_dec_tsize(): 91 | if torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]: 92 | total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20 93 | if total_memory > 30*1000: DECODER_TILE_SIZE = 256 94 | elif total_memory > 16*1000: DECODER_TILE_SIZE = 192 95 | elif total_memory > 12*1000: DECODER_TILE_SIZE = 128 96 | elif total_memory > 8*1000: DECODER_TILE_SIZE = 96 97 | else: DECODER_TILE_SIZE = 64 98 | else: DECODER_TILE_SIZE = 64 99 | return DECODER_TILE_SIZE 100 | 101 | 102 | def inplace_nonlinearity(x): 103 | # Test: fix for Nans 104 | return F.silu(x, inplace=True) 105 | 106 | 107 | def attn2task(task_queue, net): 108 | attn_forward = get_attn_func() 109 | task_queue.append(('store_res', lambda x: x)) 110 | task_queue.append(('pre_norm', net.norm)) 111 | task_queue.append(('attn', lambda x, net=net: attn_forward(net, x))) 112 | task_queue.append(['add_res', None]) 113 | 114 | 115 | def resblock2task(queue, block): 116 | """ 117 | Turn a ResNetBlock into a sequence of tasks and append to the task queue 118 | 119 | @param queue: the target task queue 120 | @param block: ResNetBlock 121 | 122 | """ 123 | if block.in_channels != block.out_channels: 124 | if block.use_conv_shortcut: 125 | queue.append(('store_res', block.conv_shortcut)) 126 | else: 127 | queue.append(('store_res', block.nin_shortcut)) 128 | else: 129 | queue.append(('store_res', lambda x: x)) 130 | queue.append(('pre_norm', block.norm1)) 131 | queue.append(('silu', inplace_nonlinearity)) 132 | queue.append(('conv1', block.conv1)) 133 | queue.append(('pre_norm', block.norm2)) 134 | queue.append(('silu', inplace_nonlinearity)) 135 | queue.append(('conv2', block.conv2)) 136 | queue.append(['add_res', None]) 137 | 138 | 139 | def build_sampling(task_queue, net, is_decoder): 140 | """ 141 | Build the sampling part of a task queue 142 | @param task_queue: the target task queue 143 | @param net: the network 144 | @param is_decoder: currently building decoder or encoder 145 | """ 146 | if is_decoder: 147 | resblock2task(task_queue, net.mid.block_1) 148 | attn2task(task_queue, 
net.mid.attn_1) 149 | resblock2task(task_queue, net.mid.block_2) 150 | resolution_iter = reversed(range(net.num_resolutions)) 151 | block_ids = net.num_res_blocks + 1 152 | condition = 0 153 | module = net.up 154 | func_name = 'upsample' 155 | else: 156 | resolution_iter = range(net.num_resolutions) 157 | block_ids = net.num_res_blocks 158 | condition = net.num_resolutions - 1 159 | module = net.down 160 | func_name = 'downsample' 161 | 162 | for i_level in resolution_iter: 163 | for i_block in range(block_ids): 164 | resblock2task(task_queue, module[i_level].block[i_block]) 165 | if i_level != condition: 166 | task_queue.append((func_name, getattr(module[i_level], func_name))) 167 | 168 | if not is_decoder: 169 | resblock2task(task_queue, net.mid.block_1) 170 | attn2task(task_queue, net.mid.attn_1) 171 | resblock2task(task_queue, net.mid.block_2) 172 | 173 | 174 | def build_task_queue(net, is_decoder): 175 | """ 176 | Build a single task queue for the encoder or decoder 177 | @param net: the VAE decoder or encoder network 178 | @param is_decoder: currently building decoder or encoder 179 | @return: the task queue 180 | """ 181 | task_queue = [] 182 | task_queue.append(('conv_in', net.conv_in)) 183 | 184 | # construct the sampling part of the task queue 185 | # because encoder and decoder share the same architecture, we extract the sampling part 186 | build_sampling(task_queue, net, is_decoder) 187 | 188 | if not is_decoder or not net.give_pre_end: 189 | task_queue.append(('pre_norm', net.norm_out)) 190 | task_queue.append(('silu', inplace_nonlinearity)) 191 | task_queue.append(('conv_out', net.conv_out)) 192 | if is_decoder and net.tanh_out: 193 | task_queue.append(('tanh', torch.tanh)) 194 | 195 | return task_queue 196 | 197 | 198 | def clone_task_queue(task_queue): 199 | """ 200 | Clone a task queue 201 | @param task_queue: the task queue to be cloned 202 | @return: the cloned task queue 203 | """ 204 | return [[item for item in task] for task in task_queue] 205 | 206 | 207 | def get_var_mean(input, num_groups, eps=1e-6): 208 | """ 209 | Get mean and var for group norm 210 | """ 211 | b, c = input.size(0), input.size(1) 212 | channel_in_group = int(c/num_groups) 213 | input_reshaped = input.contiguous().view(1, int(b * num_groups), channel_in_group, *input.size()[2:]) 214 | var, mean = torch.var_mean(input_reshaped, dim=[0, 2, 3, 4], unbiased=False) 215 | return var, mean 216 | 217 | 218 | def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6): 219 | """ 220 | Custom group norm with fixed mean and var 221 | 222 | @param input: input tensor 223 | @param num_groups: number of groups. 
by default, num_groups = 32 224 | @param mean: mean, must be pre-calculated by get_var_mean 225 | @param var: var, must be pre-calculated by get_var_mean 226 | @param weight: weight, should be fetched from the original group norm 227 | @param bias: bias, should be fetched from the original group norm 228 | @param eps: epsilon, by default, eps = 1e-6 to match the original group norm 229 | 230 | @return: normalized tensor 231 | """ 232 | b, c = input.size(0), input.size(1) 233 | channel_in_group = int(c/num_groups) 234 | input_reshaped = input.contiguous().view( 235 | 1, int(b * num_groups), channel_in_group, *input.size()[2:]) 236 | 237 | out = F.batch_norm(input_reshaped, mean.to(input), var.to(input), weight=None, bias=None, training=False, momentum=0, eps=eps) 238 | out = out.view(b, c, *input.size()[2:]) 239 | 240 | # post affine transform 241 | if weight is not None: 242 | out *= weight.view(1, -1, 1, 1) 243 | if bias is not None: 244 | out += bias.view(1, -1, 1, 1) 245 | return out 246 | 247 | 248 | def crop_valid_region(x, input_bbox, target_bbox, is_decoder): 249 | """ 250 | Crop the valid region from the tile 251 | @param x: input tile 252 | @param input_bbox: original input bounding box 253 | @param target_bbox: output bounding box 254 | @param scale: scale factor 255 | @return: cropped tile 256 | """ 257 | padded_bbox = [i * 8 if is_decoder else i//8 for i in input_bbox] 258 | margin = [target_bbox[i] - padded_bbox[i] for i in range(4)] 259 | return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]] 260 | 261 | 262 | # ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓ 263 | 264 | def perfcount(fn): 265 | def wrapper(*args, **kwargs): 266 | ts = time() 267 | 268 | if torch.cuda.is_available(): 269 | torch.cuda.reset_peak_memory_stats(devices.device) 270 | devices.torch_gc() 271 | gc.collect() 272 | 273 | ret = fn(*args, **kwargs) 274 | 275 | devices.torch_gc() 276 | gc.collect() 277 | if torch.cuda.is_available(): 278 | vram = torch.cuda.max_memory_allocated(devices.device) / 2**20 279 | print(f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB') 280 | else: 281 | print(f'[Tiled VAE]: Done in {time() - ts:.3f}s') 282 | 283 | return ret 284 | return wrapper 285 | 286 | # ↑↑↑ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↑↑↑ 287 | 288 | 289 | class GroupNormParam: 290 | 291 | def __init__(self): 292 | self.var_list = [] 293 | self.mean_list = [] 294 | self.pixel_list = [] 295 | self.weight = None 296 | self.bias = None 297 | 298 | def add_tile(self, tile, layer): 299 | var, mean = get_var_mean(tile, 32) 300 | # For giant images, the variance can be larger than max float16 301 | # In this case we create a copy to float32 302 | if var.dtype == torch.float16 and var.isinf().any(): 303 | fp32_tile = tile.float() 304 | var, mean = get_var_mean(fp32_tile, 32) 305 | # ============= DEBUG: test for infinite ============= 306 | # if torch.isinf(var).any(): 307 | # print('var: ', var) 308 | # ==================================================== 309 | self.var_list.append(var) 310 | self.mean_list.append(mean) 311 | self.pixel_list.append( 312 | tile.shape[2]*tile.shape[3]) 313 | if hasattr(layer, 'weight'): 314 | self.weight = layer.weight 315 | self.bias = layer.bias 316 | else: 317 | self.weight = None 318 | self.bias = None 319 | 320 | def summary(self): 321 | """ 322 | summarize the mean and var and return a function 323 | that apply group norm on each tile 324 | """ 325 | if len(self.var_list) == 0: return 
None 326 | 327 | var = torch.vstack(self.var_list) 328 | mean = torch.vstack(self.mean_list) 329 | max_value = max(self.pixel_list) 330 | pixels = torch.tensor(self.pixel_list, dtype=torch.float32, device=devices.device) / max_value 331 | sum_pixels = torch.sum(pixels) 332 | pixels = pixels.unsqueeze(1) / sum_pixels 333 | var = torch.sum(var * pixels, dim=0) 334 | mean = torch.sum(mean * pixels, dim=0) 335 | return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias) 336 | 337 | @staticmethod 338 | def from_tile(tile, norm): 339 | """ 340 | create a function from a single tile without summary 341 | """ 342 | var, mean = get_var_mean(tile, 32) 343 | if var.dtype == torch.float16 and var.isinf().any(): 344 | fp32_tile = tile.float() 345 | var, mean = get_var_mean(fp32_tile, 32) 346 | # if it is a macbook, we need to convert back to float16 347 | if var.device.type == 'mps': 348 | # clamp to avoid overflow 349 | var = torch.clamp(var, 0, 60000) 350 | var = var.half() 351 | mean = mean.half() 352 | if hasattr(norm, 'weight'): 353 | weight = norm.weight 354 | bias = norm.bias 355 | else: 356 | weight = None 357 | bias = None 358 | 359 | def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias): 360 | return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6) 361 | return group_norm_func 362 | 363 | 364 | class VAEHook: 365 | 366 | def __init__(self, net, tile_size, is_decoder:bool, fast_decoder:bool, fast_encoder:bool, color_fix:bool, to_gpu:bool=False): 367 | self.net = net # encoder | decoder 368 | self.tile_size = tile_size 369 | self.is_decoder = is_decoder 370 | self.fast_mode = (fast_encoder and not is_decoder) or (fast_decoder and is_decoder) 371 | self.color_fix = color_fix and not is_decoder 372 | self.to_gpu = to_gpu 373 | self.pad = 11 if is_decoder else 32 # FIXME: magic number 374 | 375 | def __call__(self, x): 376 | original_device = next(self.net.parameters()).device 377 | try: 378 | if self.to_gpu: 379 | self.net = self.net.to(devices.get_optimal_device()) 380 | 381 | B, C, H, W = x.shape 382 | if max(H, W) <= self.pad * 2 + self.tile_size: 383 | print("[Tiled VAE]: the input size is tiny and unnecessary to tile.") 384 | return self.net.original_forward(x) 385 | else: 386 | return self.vae_tile_forward(x) 387 | finally: 388 | self.net = self.net.to(original_device) 389 | 390 | def get_best_tile_size(self, lowerbound, upperbound): 391 | """ 392 | Get the best tile size for GPU memory 393 | """ 394 | divider = 32 395 | while divider >= 2: 396 | remainer = lowerbound % divider 397 | if remainer == 0: 398 | return lowerbound 399 | candidate = lowerbound - remainer + divider 400 | if candidate <= upperbound: 401 | return candidate 402 | divider //= 2 403 | return lowerbound 404 | 405 | def split_tiles(self, h, w): 406 | """ 407 | Tool function to split the image into tiles 408 | @param h: height of the image 409 | @param w: width of the image 410 | @return: tile_input_bboxes, tile_output_bboxes 411 | """ 412 | tile_input_bboxes, tile_output_bboxes = [], [] 413 | tile_size = self.tile_size 414 | pad = self.pad 415 | num_height_tiles = math.ceil((h - 2 * pad) / tile_size) 416 | num_width_tiles = math.ceil((w - 2 * pad) / tile_size) 417 | # If any of the numbers are 0, we let it be 1 418 | # This is to deal with long and thin images 419 | num_height_tiles = max(num_height_tiles, 1) 420 | num_width_tiles = max(num_width_tiles, 1) 421 | 422 | # Suggestions from https://github.com/Kahsolt: auto shrink the tile size 423 | real_tile_height = math.ceil((h - 2 * 
pad) / num_height_tiles) 424 | real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles) 425 | real_tile_height = self.get_best_tile_size(real_tile_height, tile_size) 426 | real_tile_width = self.get_best_tile_size(real_tile_width, tile_size) 427 | 428 | print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' + 429 | f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}') 430 | 431 | for i in range(num_height_tiles): 432 | for j in range(num_width_tiles): 433 | # bbox: [x1, x2, y1, y2] 434 | # the padding is is unnessary for image borders. So we directly start from (32, 32) 435 | input_bbox = [ 436 | pad + j * real_tile_width, 437 | min(pad + (j + 1) * real_tile_width, w), 438 | pad + i * real_tile_height, 439 | min(pad + (i + 1) * real_tile_height, h), 440 | ] 441 | 442 | # if the output bbox is close to the image boundary, we extend it to the image boundary 443 | output_bbox = [ 444 | input_bbox[0] if input_bbox[0] > pad else 0, 445 | input_bbox[1] if input_bbox[1] < w - pad else w, 446 | input_bbox[2] if input_bbox[2] > pad else 0, 447 | input_bbox[3] if input_bbox[3] < h - pad else h, 448 | ] 449 | 450 | # scale to get the final output bbox 451 | output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox] 452 | tile_output_bboxes.append(output_bbox) 453 | 454 | # indistinguishable expand the input bbox by pad pixels 455 | tile_input_bboxes.append([ 456 | max(0, input_bbox[0] - pad), 457 | min(w, input_bbox[1] + pad), 458 | max(0, input_bbox[2] - pad), 459 | min(h, input_bbox[3] + pad), 460 | ]) 461 | 462 | return tile_input_bboxes, tile_output_bboxes 463 | 464 | @torch.no_grad() 465 | def estimate_group_norm(self, z, task_queue, color_fix): 466 | device = z.device 467 | tile = z 468 | last_id = len(task_queue) - 1 469 | while last_id >= 0 and task_queue[last_id][0] != 'pre_norm': 470 | last_id -= 1 471 | if last_id <= 0 or task_queue[last_id][0] != 'pre_norm': 472 | raise ValueError('No group norm found in the task queue') 473 | # estimate until the last group norm 474 | for i in range(last_id + 1): 475 | task = task_queue[i] 476 | if task[0] == 'pre_norm': 477 | group_norm_func = GroupNormParam.from_tile(tile, task[1]) 478 | task_queue[i] = ('apply_norm', group_norm_func) 479 | if i == last_id: 480 | return True 481 | tile = group_norm_func(tile) 482 | elif task[0] == 'store_res': 483 | task_id = i + 1 484 | while task_id < last_id and task_queue[task_id][0] != 'add_res': 485 | task_id += 1 486 | if task_id >= last_id: 487 | continue 488 | task_queue[task_id][1] = task[1](tile) 489 | elif task[0] == 'add_res': 490 | tile += task[1].to(device) 491 | task[1] = None 492 | elif color_fix and task[0] == 'downsample': 493 | for j in range(i, last_id + 1): 494 | if task_queue[j][0] == 'store_res': 495 | task_queue[j] = ('store_res_cpu', task_queue[j][1]) 496 | return True 497 | else: 498 | tile = task[1](tile) 499 | try: 500 | devices.test_for_nans(tile, "vae") 501 | except: 502 | print(f'Nan detected in fast mode estimation. Fast mode disabled.') 503 | return False 504 | 505 | raise IndexError('Should not reach here') 506 | 507 | @perfcount 508 | @torch.no_grad() 509 | def vae_tile_forward(self, z): 510 | """ 511 | Decode a latent vector z into an image in a tiled manner. 
512 | @param z: latent vector 513 | @return: image 514 | """ 515 | device = next(self.net.parameters()).device 516 | dtype = next(self.net.parameters()).dtype 517 | net = self.net 518 | tile_size = self.tile_size 519 | is_decoder = self.is_decoder 520 | 521 | z = z.detach() # detach the input to avoid backprop 522 | 523 | N, height, width = z.shape[0], z.shape[2], z.shape[3] 524 | net.last_z_shape = z.shape 525 | 526 | # Split the input into tiles and build a task queue for each tile 527 | print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}') 528 | 529 | in_bboxes, out_bboxes = self.split_tiles(height, width) 530 | 531 | # Prepare tiles by split the input latents 532 | tiles = [] 533 | for input_bbox in in_bboxes: 534 | tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu() 535 | tiles.append(tile) 536 | 537 | num_tiles = len(tiles) 538 | num_completed = 0 539 | 540 | # Build task queues 541 | single_task_queue = build_task_queue(net, is_decoder) 542 | if self.fast_mode: 543 | # Fast mode: downsample the input image to the tile size, 544 | # then estimate the group norm parameters on the downsampled image 545 | scale_factor = tile_size / max(height, width) 546 | z = z.to(device) 547 | downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact') 548 | # use nearest-exact to keep statictics as close as possible 549 | print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image') 550 | 551 | # ======= Special thanks to @Kahsolt for distribution shift issue ======= # 552 | # The downsampling will heavily distort its mean and std, so we need to recover it. 553 | std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True) 554 | std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True) 555 | downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old 556 | del std_old, mean_old, std_new, mean_new 557 | # occasionally the std_new is too small or too large, which exceeds the range of float16 558 | # so we need to clamp it to max z's range. 
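# (note: torch.clamp_ is the in-place variant; bounding by the original z's min/max keeps the
#  rescaled latent inside an fp16-safe range without altering values that were already in range)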
559 | downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max()) 560 | estimate_task_queue = clone_task_queue(single_task_queue) 561 | if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix): 562 | single_task_queue = estimate_task_queue 563 | del downsampled_z 564 | 565 | task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)] 566 | 567 | # Dummy result 568 | result = None 569 | result_approx = None 570 | try: 571 | with devices.autocast(): 572 | result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu() 573 | except: pass 574 | # Free memory of input latent tensor 575 | del z 576 | 577 | # Task queue execution 578 | pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ") 579 | 580 | # execute the task back and forth when switch tiles so that we always 581 | # keep one tile on the GPU to reduce unnecessary data transfer 582 | forward = True 583 | interrupted = False 584 | #state.interrupted = interrupted 585 | while True: 586 | if state.interrupted: interrupted = True ; break 587 | 588 | group_norm_param = GroupNormParam() 589 | for i in range(num_tiles) if forward else reversed(range(num_tiles)): 590 | if state.interrupted: interrupted = True ; break 591 | 592 | tile = tiles[i].to(device) 593 | input_bbox = in_bboxes[i] 594 | task_queue = task_queues[i] 595 | 596 | interrupted = False 597 | while len(task_queue) > 0: 598 | if state.interrupted: interrupted = True ; break 599 | 600 | # DEBUG: current task 601 | # print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape) 602 | task = task_queue.pop(0) 603 | if task[0] == 'pre_norm': 604 | group_norm_param.add_tile(tile, task[1]) 605 | break 606 | elif task[0] == 'store_res' or task[0] == 'store_res_cpu': 607 | task_id = 0 608 | res = task[1](tile) 609 | if not self.fast_mode or task[0] == 'store_res_cpu': 610 | res = res.cpu() 611 | while task_queue[task_id][0] != 'add_res': 612 | task_id += 1 613 | task_queue[task_id][1] = res 614 | elif task[0] == 'add_res': 615 | tile += task[1].to(device) 616 | task[1] = None 617 | else: 618 | tile = task[1](tile) 619 | pbar.update(1) 620 | 621 | if interrupted: break 622 | 623 | # check for NaNs in the tile. 
624 | # If there are NaNs, we abort the process to save user's time 625 | devices.test_for_nans(tile, "vae") 626 | 627 | if len(task_queue) == 0: 628 | tiles[i] = None 629 | num_completed += 1 630 | if result is None: # NOTE: dim C varies from different cases, can only be inited dynamically 631 | result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False) 632 | result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder) 633 | del tile 634 | elif i == num_tiles - 1 and forward: 635 | forward = False 636 | tiles[i] = tile 637 | elif i == 0 and not forward: 638 | forward = True 639 | tiles[i] = tile 640 | else: 641 | tiles[i] = tile.cpu() 642 | del tile 643 | 644 | if interrupted: break 645 | if num_completed == num_tiles: break 646 | 647 | # insert the group norm task to the head of each task queue 648 | group_norm_func = group_norm_param.summary() 649 | if group_norm_func is not None: 650 | for i in range(num_tiles): 651 | task_queue = task_queues[i] 652 | task_queue.insert(0, ('apply_norm', group_norm_func)) 653 | 654 | # Done! 655 | pbar.close() 656 | return result.to(dtype) if result is not None else result_approx.to(device, dtype=dtype) 657 | 658 | 659 | class Script(scripts.Script): 660 | 661 | def __init__(self): 662 | self.hooked = False 663 | 664 | def title(self): 665 | return "Tiled VAE" 666 | 667 | def show(self, is_img2img): 668 | return scripts.AlwaysVisible 669 | 670 | def ui(self, is_img2img): 671 | tab = 't2i' if not is_img2img else 'i2i' 672 | uid = lambda name: f'MD-{tab}-{name}' 673 | 674 | with ( 675 | InputAccordion(False, label='Tiled VAE', elem_id=f'MDV-{tab}-enabled') if InputAccordion 676 | else gr.Accordion('Tiled VAE', open=False, elem_id=f'MDV-{tab}') 677 | as enabled 678 | ): 679 | with gr.Row() as tab_enable: 680 | if not InputAccordion: 681 | enabled = gr.Checkbox(label='Enable Tiled VAE', value=False, elem_id=uid('enable')) 682 | vae_to_gpu = gr.Checkbox(label='Move VAE to GPU (if possible)', value=True, elem_id=uid('vae2gpu')) 683 | 684 | gr.HTML('

Recommended to set tile sizes as large as possible before getting "CUDA error: out of memory".
') 685 | with gr.Row() as tab_size: 686 | encoder_tile_size = gr.Slider(label='Encoder Tile Size', minimum=256, maximum=4096, step=16, value=get_rcmd_enc_tsize(), elem_id=uid('enc-size')) 687 | decoder_tile_size = gr.Slider(label='Decoder Tile Size', minimum=48, maximum=512, step=16, value=get_rcmd_dec_tsize(), elem_id=uid('dec-size')) 688 | reset = gr.Button(value='↻ Reset', variant='tool') 689 | reset.click(fn=lambda: [get_rcmd_enc_tsize(), get_rcmd_dec_tsize()], outputs=[encoder_tile_size, decoder_tile_size], show_progress=False) 690 | 691 | with gr.Row() as tab_param: 692 | fast_encoder = gr.Checkbox(label='Fast Encoder', value=True, elem_id=uid('fastenc')) 693 | color_fix = gr.Checkbox(label='Fast Encoder Color Fix', value=False, visible=True, elem_id=uid('fastenc-colorfix')) 694 | fast_decoder = gr.Checkbox(label='Fast Decoder', value=True, elem_id=uid('fastdec')) 695 | 696 | fast_encoder.change(fn=gr_show, inputs=fast_encoder, outputs=color_fix, show_progress=False) 697 | 698 | return [ 699 | enabled, 700 | encoder_tile_size, decoder_tile_size, 701 | vae_to_gpu, fast_decoder, fast_encoder, color_fix, 702 | ] 703 | 704 | def process(self, p:Processing, 705 | enabled:bool, 706 | encoder_tile_size:int, decoder_tile_size:int, 707 | vae_to_gpu:bool, fast_decoder:bool, fast_encoder:bool, color_fix:bool 708 | ): 709 | 710 | # for shorthand 711 | vae = p.sd_model.first_stage_model 712 | encoder = vae.encoder 713 | decoder = vae.decoder 714 | 715 | # undo hijack if disabled (in cases last time crashed) 716 | if not enabled: 717 | if self.hooked: 718 | if isinstance(encoder.forward, VAEHook): 719 | encoder.forward.net = None 720 | encoder.forward = encoder.original_forward 721 | if isinstance(decoder.forward, VAEHook): 722 | decoder.forward.net = None 723 | decoder.forward = decoder.original_forward 724 | self.hooked = False 725 | return 726 | 727 | if devices.get_optimal_device_name().startswith('cuda') and vae.device == devices.cpu and not vae_to_gpu: 728 | print("[Tiled VAE] warn: VAE is not on GPU, check 'Move VAE to GPU' if possible.") 729 | 730 | # do hijack 731 | kwargs = { 732 | 'fast_decoder': fast_decoder, 733 | 'fast_encoder': fast_encoder, 734 | 'color_fix': color_fix, 735 | 'to_gpu': vae_to_gpu, 736 | } 737 | 738 | # save original forward (only once) 739 | if not hasattr(encoder, 'original_forward'): setattr(encoder, 'original_forward', encoder.forward) 740 | if not hasattr(decoder, 'original_forward'): setattr(decoder, 'original_forward', decoder.forward) 741 | 742 | self.hooked = True 743 | 744 | encoder.forward = VAEHook(encoder, encoder_tile_size, is_decoder=False, **kwargs) 745 | decoder.forward = VAEHook(decoder, decoder_tile_size, is_decoder=True, **kwargs) 746 | 747 | def postprocess(self, p:Processing, processed, enabled:bool, *args): 748 | if not enabled: return 749 | 750 | vae = p.sd_model.first_stage_model 751 | encoder = vae.encoder 752 | decoder = vae.decoder 753 | if isinstance(encoder.forward, VAEHook): 754 | encoder.forward.net = None 755 | encoder.forward = encoder.original_forward 756 | if isinstance(decoder.forward, VAEHook): 757 | decoder.forward.net = None 758 | decoder.forward = decoder.original_forward 759 | -------------------------------------------------------------------------------- /tile_methods/demofusion.py: -------------------------------------------------------------------------------- 1 | from tile_methods.abstractdiffusion import AbstractDiffusion 2 | from tile_utils.utils import * 3 | import torch.nn.functional as F 4 | import random 5 | from 
copy import deepcopy 6 | import inspect 7 | from modules import sd_samplers_common 8 | 9 | 10 | class DemoFusion(AbstractDiffusion): 11 | """ 12 | DemoFusion Implementation 13 | https://arxiv.org/abs/2311.16973 14 | """ 15 | 16 | def __init__(self, p:Processing, *args, **kwargs): 17 | super().__init__(p, *args, **kwargs) 18 | assert p.sampler_name != 'UniPC', 'Demofusion is not compatible with UniPC!' 19 | 20 | 21 | def hook(self): 22 | steps, self.t_enc = sd_samplers_common.setup_img2img_steps(self.p, None) 23 | 24 | self.sampler.model_wrap_cfg.forward_ori = self.sampler.model_wrap_cfg.forward 25 | self.sampler_forward = self.sampler.model_wrap_cfg.inner_model.forward 26 | self.sampler.model_wrap_cfg.forward = self.forward_one_step 27 | if self.is_kdiff: 28 | self.sampler: KDiffusionSampler 29 | self.sampler.model_wrap_cfg: CFGDenoiserKDiffusion 30 | self.sampler.model_wrap_cfg.inner_model: Union[CompVisDenoiser, CompVisVDenoiser] 31 | else: 32 | self.sampler: CompVisSampler 33 | self.sampler.model_wrap_cfg: CFGDenoiserTimesteps 34 | self.sampler.model_wrap_cfg.inner_model: Union[CompVisTimestepsDenoiser, CompVisTimestepsVDenoiser] 35 | self.timesteps = self.sampler.get_timesteps(self.p, steps) 36 | 37 | @staticmethod 38 | def unhook(): 39 | if hasattr(shared.sd_model, 'apply_model_ori'): 40 | shared.sd_model.apply_model = shared.sd_model.apply_model_ori 41 | del shared.sd_model.apply_model_ori 42 | 43 | def reset_buffer(self, x_in:Tensor): 44 | super().reset_buffer(x_in) 45 | 46 | 47 | 48 | def repeat_tensor(self, x:Tensor, n:int) -> Tensor: 49 | ''' repeat the tensor on it's first dim ''' 50 | if n == 1: return x 51 | B = x.shape[0] 52 | r_dims = len(x.shape) - 1 53 | if B == 1: # batch_size = 1 (not `tile_batch_size`) 54 | shape = [n] + [-1] * r_dims # [N, -1, ...] 55 | return x.expand(shape) # `expand` is much lighter than `tile` 56 | else: 57 | shape = [n] + [1] * r_dims # [N, 1, ...] 
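# (`Tensor.expand` can only broadcast size-1 dimensions, so for B > 1 we must materialize real copies with `repeat`)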
58 | return x.repeat(shape) 59 | 60 | def repeat_cond_dict(self, cond_in:CondDict, bboxes,mode) -> CondDict: 61 | ''' repeat all tensors in cond_dict on it's first dim (for a batch of tiles), returns a new object ''' 62 | # n_repeat 63 | n_rep = len(bboxes) 64 | # txt cond 65 | tcond = self.get_tcond(cond_in) # [B=1, L, D] => [B*N, L, D] 66 | tcond = self.repeat_tensor(tcond, n_rep) 67 | # img cond 68 | icond = self.get_icond(cond_in) 69 | if icond.shape[2:] == (self.h, self.w): # img2img, [B=1, C, H, W] 70 | if mode == 0: 71 | if self.p.random_jitter: 72 | jitter_range = self.jitter_range 73 | icond = F.pad(icond,(jitter_range, jitter_range, jitter_range, jitter_range),'constant',value=0) 74 | icond = torch.cat([icond[bbox.slicer] for bbox in bboxes], dim=0) 75 | else: 76 | icond = torch.cat([icond[:,:,bbox[1]::self.p.current_scale_num,bbox[0]::self.p.current_scale_num] for bbox in bboxes], dim=0) 77 | else: # txt2img, [B=1, C=5, H=1, W=1] 78 | icond = self.repeat_tensor(icond, n_rep) 79 | 80 | # vec cond (SDXL) 81 | vcond = self.get_vcond(cond_in) # [B=1, D] 82 | if vcond is not None: 83 | vcond = self.repeat_tensor(vcond, n_rep) # [B*N, D] 84 | return self.make_cond_dict(cond_in, tcond, icond, vcond) 85 | 86 | 87 | def global_split_bboxes(self): 88 | cols = self.p.current_scale_num 89 | rows = cols 90 | 91 | bbox_list = [] 92 | for row in range(rows): 93 | y = row 94 | for col in range(cols): 95 | x = col 96 | bbox = (x, y) 97 | bbox_list.append(bbox) 98 | 99 | return bbox_list+bbox_list if self.p.mixture else bbox_list 100 | 101 | def split_bboxes_jitter(self,w_l:int, h_l:int, tile_w:int, tile_h:int, overlap:int=16, init_weight:Union[Tensor, float]=1.0) -> Tuple[List[BBox], Tensor]: 102 | cols = math.ceil((w_l - overlap) / (tile_w - overlap)) 103 | rows = math.ceil((h_l - overlap) / (tile_h - overlap)) 104 | if rows==0: 105 | rows=1 106 | if cols == 0: 107 | cols=1 108 | dx = (w_l - tile_w) / (cols - 1) if cols > 1 else 0 109 | dy = (h_l - tile_h) / (rows - 1) if rows > 1 else 0 110 | bbox_list: List[BBox] = [] 111 | self.jitter_range = 0 112 | for row in range(rows): 113 | for col in range(cols): 114 | h = min(int(row * dy), h_l - tile_h) 115 | w = min(int(col * dx), w_l - tile_w) 116 | if self.p.random_jitter: 117 | self.jitter_range = min(max((min(self.w, self.h)-self.stride)//4,0),min(int(self.window_size/2),int(self.overlap/2))) 118 | jitter_range = self.jitter_range 119 | w_jitter = 0 120 | h_jitter = 0 121 | if (w != 0) and (w+tile_w != w_l): 122 | w_jitter = random.randint(-jitter_range, jitter_range) 123 | elif (w == 0) and (w + tile_w != w_l): 124 | w_jitter = random.randint(-jitter_range, 0) 125 | elif (w != 0) and (w + tile_w == w_l): 126 | w_jitter = random.randint(0, jitter_range) 127 | if (h != 0) and (h + tile_h != h_l): 128 | h_jitter = random.randint(-jitter_range, jitter_range) 129 | elif (h == 0) and (h + tile_h != h_l): 130 | h_jitter = random.randint(-jitter_range, 0) 131 | elif (h != 0) and (h + tile_h == h_l): 132 | h_jitter = random.randint(0, jitter_range) 133 | h +=(h_jitter + jitter_range) 134 | w += (w_jitter + jitter_range) 135 | 136 | bbox = BBox(w, h, tile_w, tile_h) 137 | bbox_list.append(bbox) 138 | return bbox_list, None 139 | 140 | @grid_bbox 141 | def get_views(self, overlap:int, tile_bs:int,tile_bs_g:int): 142 | self.enable_grid_bbox = True 143 | self.tile_w = self.window_size 144 | self.tile_h = self.window_size 145 | 146 | self.overlap = max(0, min(overlap, self.window_size - 4)) 147 | 148 | self.stride = max(4,self.window_size - self.overlap) 
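# Illustrative example (hypothetical numbers) for the split below: with window_size=128 and
# overlap=64, stride=64; on a 192x192 latent, split_bboxes_jitter gives
#   cols = rows = ceil((192-64)/(128-64)) = 2,  dx = dy = (192-128)/(2-1) = 64,
# i.e. four 128x128 tiles whose neighbours overlap by 64 latent pixels
# (each additionally shifted by a random +/- jitter_range offset when random_jitter is enabled).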
149 | 150 | # split the latent into overlapped tiles, then batching 151 | # weights basically indicate how many times a pixel is painted 152 | bboxes, _ = self.split_bboxes_jitter(self.w, self.h, self.tile_w, self.tile_h, self.overlap, self.get_tile_weights()) 153 | self.num_tiles = len(bboxes) 154 | self.num_batches = math.ceil(self.num_tiles / tile_bs) 155 | self.tile_bs = math.ceil(len(bboxes) / self.num_batches) # optimal_batch_size 156 | self.batched_bboxes = [bboxes[i*self.tile_bs:(i+1)*self.tile_bs] for i in range(self.num_batches)] 157 | 158 | global_bboxes = self.global_split_bboxes() 159 | self.global_num_tiles = len(global_bboxes) 160 | self.global_num_batches = math.ceil(self.global_num_tiles / tile_bs_g) 161 | self.global_tile_bs = math.ceil(len(global_bboxes) / self.global_num_batches) 162 | self.global_batched_bboxes = [global_bboxes[i*self.global_tile_bs:(i+1)*self.global_tile_bs] for i in range(self.global_num_batches)] 163 | 164 | def gaussian_kernel(self,kernel_size=3, sigma=1.0, channels=3): 165 | x_coord = torch.arange(kernel_size, device=devices.device) 166 | gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2)) 167 | gaussian_1d = gaussian_1d / gaussian_1d.sum() 168 | gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :] 169 | kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1) 170 | 171 | return kernel 172 | 173 | def gaussian_filter(self,latents, kernel_size=3, sigma=1.0): 174 | channels = latents.shape[1] 175 | kernel = self.gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype) 176 | blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels) 177 | 178 | return blurred_latents 179 | 180 | 181 | 182 | ''' ↓↓↓ kernel hijacks ↓↓↓ ''' 183 | @torch.no_grad() 184 | @keep_signature 185 | def forward_one_step(self, x_in, sigma, **kwarg): 186 | if self.is_kdiff: 187 | x_noisy = self.p.x + self.p.noise * sigma[0] 188 | else: 189 | alphas_cumprod = self.p.sd_model.alphas_cumprod 190 | sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[self.timesteps[self.t_enc-self.p.current_step]]) 191 | sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[self.timesteps[self.t_enc-self.p.current_step]]) 192 | x_noisy = self.p.x*sqrt_alpha_cumprod + self.p.noise * sqrt_one_minus_alpha_cumprod 193 | 194 | self.cosine_factor = 0.5 * (1 + torch.cos(torch.pi *torch.tensor(((self.p.current_step + 1) / (self.t_enc+1))))) 195 | 196 | c1 = self.cosine_factor ** self.p.cosine_scale_1 197 | 198 | x_in = x_in*(1 - c1) + x_noisy * c1 199 | 200 | if self.p.random_jitter: 201 | jitter_range = self.jitter_range 202 | else: 203 | jitter_range = 0 204 | x_in_ = F.pad(x_in,(jitter_range, jitter_range, jitter_range, jitter_range),'constant',value=0) 205 | _,_,H,W = x_in.shape 206 | 207 | self.sampler.model_wrap_cfg.inner_model.forward = self.sample_one_step 208 | self.repeat_3 = False 209 | 210 | x_out = self.sampler.model_wrap_cfg.forward_ori(x_in_,sigma, **kwarg) 211 | self.sampler.model_wrap_cfg.inner_model.forward = self.sampler_forward 212 | x_out = x_out[:,:,jitter_range:jitter_range+H,jitter_range:jitter_range+W] 213 | 214 | return x_out 215 | 216 | 217 | @torch.no_grad() 218 | @keep_signature 219 | def sample_one_step(self, x_in, sigma, cond): 220 | assert LatentDiffusion.apply_model 221 | def repeat_func_1(x_tile:Tensor, bboxes,mode=0) -> Tensor: 222 | sigma_tile = self.repeat_tensor(sigma, len(bboxes)) 223 | cond_tile = self.repeat_cond_dict(cond, bboxes,mode) 224 | return 
self.sampler_forward(x_tile, sigma_tile, cond=cond_tile) 225 | 226 | def repeat_func_2(x_tile:Tensor, bboxes,mode=0) -> Tuple[Tensor, Tensor]: 227 | n_rep = len(bboxes) 228 | ts_tile = self.repeat_tensor(sigma, n_rep) 229 | if isinstance(cond, dict): # FIXME: when will enter this branch? 230 | cond_tile = self.repeat_cond_dict(cond, bboxes,mode) 231 | else: 232 | cond_tile = self.repeat_tensor(cond, n_rep) 233 | return self.sampler_forward(x_tile, ts_tile, cond=cond_tile) 234 | 235 | def repeat_func_3(x_tile:Tensor, bboxes,mode=0): 236 | sigma_in_tile = sigma.repeat(len(bboxes)) 237 | cond_out = self.repeat_cond_dict(cond, bboxes,mode) 238 | x_tile_out = shared.sd_model.apply_model(x_tile, sigma_in_tile, cond=cond_out) 239 | return x_tile_out 240 | 241 | if self.repeat_3: 242 | repeat_func = repeat_func_3 243 | self.repeat_3 = False 244 | elif self.is_kdiff: 245 | repeat_func = repeat_func_1 246 | else: 247 | repeat_func = repeat_func_2 248 | N,_,_,_ = x_in.shape 249 | 250 | 251 | self.x_buffer = torch.zeros_like(x_in) 252 | self.weights = torch.zeros_like(x_in) 253 | 254 | for batch_id, bboxes in enumerate(self.batched_bboxes): 255 | if state.interrupted: return x_in 256 | x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0) 257 | x_tile_out = repeat_func(x_tile, bboxes) 258 | # de-batching 259 | for i, bbox in enumerate(bboxes): 260 | self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :] 261 | self.weights[bbox.slicer] += 1 262 | self.weights = torch.where(self.weights == 0, torch.tensor(1), self.weights) #Prevent NaN from appearing in random_jitter mode 263 | 264 | x_local = self.x_buffer/self.weights 265 | 266 | self.x_buffer = torch.zeros_like(self.x_buffer) 267 | self.weights = torch.zeros_like(self.weights) 268 | 269 | std_, mean_ = x_in.std(), x_in.mean() 270 | c3 = 0.99 * self.cosine_factor ** self.p.cosine_scale_3 + 1e-2 271 | if self.p.gaussian_filter: 272 | x_in_g = self.gaussian_filter(x_in, kernel_size=(2*self.p.current_scale_num-1), sigma=self.sig*c3) 273 | x_in_g = (x_in_g - x_in_g.mean()) / x_in_g.std() * std_ + mean_ 274 | 275 | if not hasattr(self.p.sd_model, 'apply_model_ori'): 276 | self.p.sd_model.apply_model_ori = self.p.sd_model.apply_model 277 | self.p.sd_model.apply_model = self.apply_model_hijack 278 | x_global = torch.zeros_like(x_local) 279 | jitter_range = self.jitter_range 280 | end = x_global.shape[3]-jitter_range 281 | 282 | current_num = 0 283 | if self.p.mixture: 284 | for batch_id, bboxes in enumerate(self.global_batched_bboxes): 285 | current_num += len(bboxes) 286 | if current_num > (self.global_num_tiles//2) and (current_num-self.global_tile_bs) < (self.global_num_tiles//2): 287 | res = len(bboxes) - (current_num - self.global_num_tiles//2) 288 | x_in_i = torch.cat([x_in[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] if idx (self.global_num_tiles//2): 290 | x_in_i = torch.cat([x_in_g[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] for bbox in bboxes],dim=0) 291 | else: 292 | x_in_i = torch.cat([x_in[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] for bbox in bboxes],dim=0) 293 | 294 | x_global_i = repeat_func(x_in_i,bboxes,mode=1) 295 | 296 | if current_num > (self.global_num_tiles//2) and (current_num-self.global_tile_bs) < (self.global_num_tiles//2): 297 | for idx,bbox in enumerate(bboxes): 298 | 
x_global[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] += x_global_i[idx*N:(idx+1)*N,:,:,:] 299 | elif current_num > (self.global_num_tiles//2): 300 | for idx,bbox in enumerate(bboxes): 301 | x_global[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] += x_global_i[idx*N:(idx+1)*N,:,:,:] 302 | else: 303 | for idx,bbox in enumerate(bboxes): 304 | x_global[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] += x_global_i[idx*N:(idx+1)*N,:,:,:] 305 | else: 306 | for batch_id, bboxes in enumerate(self.global_batched_bboxes): 307 | x_in_i = torch.cat([x_in_g[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] for bbox in bboxes],dim=0) 308 | x_global_i = repeat_func(x_in_i,bboxes,mode=1) 309 | for idx,bbox in enumerate(bboxes): 310 | x_global[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] += x_global_i[idx*N:(idx+1)*N,:,:,:] 311 | #NOTE According to the original execution process, it would be very strange to use the predicted noise of gaussian latents to predict the denoised data in non Gaussian latents. Why? 312 | if self.p.mixture: 313 | self.x_buffer +=x_global/2 314 | else: 315 | self.x_buffer += x_global 316 | self.weights += 1 317 | 318 | self.p.sd_model.apply_model = self.p.sd_model.apply_model_ori 319 | 320 | x_global = self.x_buffer/self.weights 321 | c2 = self.cosine_factor**self.p.cosine_scale_2 322 | self.x_buffer= x_local*(1-c2)+ x_global*c2 323 | 324 | return self.x_buffer 325 | 326 | 327 | 328 | @torch.no_grad() 329 | @keep_signature 330 | def apply_model_hijack(self, x_in:Tensor, t_in:Tensor, cond:CondDict): 331 | assert LatentDiffusion.apply_model 332 | 333 | x_tile_out = self.p.sd_model.apply_model_ori(x_in,t_in,cond) 334 | return x_tile_out 335 | # NOTE: Using Gaussian Latent to Predict Noise on the Original Latent 336 | # if self.flag == 1: 337 | # x_tile_out = self.p.sd_model.apply_model_ori(x_in,t_in,cond) 338 | # self.x_out_list.append(x_tile_out) 339 | # return x_tile_out 340 | # else: 341 | # self.x_out_idx += 1 342 | # return self.x_out_list[self.x_out_idx] 343 | 344 | 345 | def get_noise(self, x_in:Tensor, sigma_in:Tensor, cond_in:Dict[str, Tensor], step:int) -> Tensor: 346 | # NOTE: The following code is analytically wrong but aesthetically beautiful 347 | cond_in_original = cond_in.copy() 348 | self.repeat_3 = True 349 | self.cosine_factor = 0.5 * (1 + torch.cos(torch.pi *torch.tensor(((self.p.current_step + 1) / (self.t_enc+1))))) 350 | jitter_range = self.jitter_range 351 | _,_,H,W = x_in.shape 352 | x_in_ = F.pad(x_in,(jitter_range, jitter_range, jitter_range, jitter_range),'constant',value=0) 353 | return self.sample_one_step(x_in_, sigma_in, cond_in_original)[:,:,jitter_range:jitter_range+H,jitter_range:jitter_range+W] 354 | -------------------------------------------------------------------------------- /tile_methods/mixtureofdiffusers.py: -------------------------------------------------------------------------------- 1 | from tile_methods.abstractdiffusion import AbstractDiffusion 2 | from tile_utils.utils import * 3 | 4 | 5 | class MixtureOfDiffusers(AbstractDiffusion): 6 | """ 7 | Mixture-of-Diffusers Implementation 8 | https://github.com/albarji/mixture-of-diffusers 9 | """ 10 | 11 | def __init__(self, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | 14 | # weights for 
custom bboxes 15 | self.custom_weights: List[Tensor] = [] 16 | self.get_weight = gaussian_weights 17 | 18 | def hook(self): 19 | if not hasattr(shared.sd_model, 'apply_model_original_md'): 20 | shared.sd_model.apply_model_original_md = shared.sd_model.apply_model 21 | shared.sd_model.apply_model = self.apply_model_hijack 22 | 23 | @staticmethod 24 | def unhook(): 25 | if hasattr(shared.sd_model, 'apply_model_original_md'): 26 | shared.sd_model.apply_model = shared.sd_model.apply_model_original_md 27 | del shared.sd_model.apply_model_original_md 28 | 29 | def init_done(self): 30 | super().init_done() 31 | # The original gaussian weights can be extremely small, so we rescale them for numerical stability 32 | self.rescale_factor = 1 / self.weights 33 | # Meanwhile, we rescale the custom weights in advance to save time of slicing 34 | for bbox_id, bbox in enumerate(self.custom_bboxes): 35 | if bbox.blend_mode == BlendMode.BACKGROUND: 36 | self.custom_weights[bbox_id] *= self.rescale_factor[bbox.slicer] 37 | 38 | @grid_bbox 39 | def get_tile_weights(self) -> Tensor: 40 | # weights for grid bboxes 41 | if not hasattr(self, 'tile_weights'): 42 | self.tile_weights = self.get_weight(self.tile_w, self.tile_h) 43 | return self.tile_weights 44 | 45 | @custom_bbox 46 | def init_custom_bbox(self, *args): 47 | super().init_custom_bbox(*args) 48 | 49 | for bbox in self.custom_bboxes: 50 | if bbox.blend_mode == BlendMode.BACKGROUND: 51 | custom_weights = self.get_weight(bbox.w, bbox.h) 52 | self.weights[bbox.slicer] += custom_weights 53 | self.custom_weights.append(custom_weights.unsqueeze(0).unsqueeze(0)) 54 | else: 55 | self.custom_weights.append(None) 56 | 57 | ''' ↓↓↓ kernel hijacks ↓↓↓ ''' 58 | 59 | @torch.no_grad() 60 | @keep_signature 61 | def apply_model_hijack(self, x_in:Tensor, t_in:Tensor, cond:CondDict, noise_inverse_step:int=-1): 62 | assert LatentDiffusion.apply_model 63 | 64 | # KDiffusion Compatibility for naming 65 | c_in: CondDict = cond 66 | 67 | N, C, H, W = x_in.shape 68 | if (H, W) != (self.h, self.w): 69 | # We don't tile highres, let's just use the original apply_model 70 | self.reset_controlnet_tensors() 71 | return shared.sd_model.apply_model_original_md(x_in, t_in, c_in) 72 | 73 | # clear buffer canvas 74 | self.reset_buffer(x_in) 75 | 76 | # Global sampling 77 | if self.draw_background: 78 | for batch_id, bboxes in enumerate(self.batched_bboxes): # batch_id is the `Latent tile batch size` 79 | if state.interrupted: return x_in 80 | 81 | # batching 82 | x_tile_list = [] 83 | t_tile_list = [] 84 | tcond_tile_list = [] 85 | icond_tile_list = [] 86 | vcond_tile_list = [] 87 | for bbox in bboxes: 88 | x_tile_list.append(x_in[bbox.slicer]) 89 | t_tile_list.append(t_in) 90 | if isinstance(c_in, dict): 91 | # tcond 92 | tcond_tile = self.get_tcond(c_in) # cond, [1, 77, 768] 93 | tcond_tile_list.append(tcond_tile) 94 | # icond: might be dummy for txt2img, latent mask for img2img 95 | icond = self.get_icond(c_in) 96 | if icond.shape[2:] == (self.h, self.w): 97 | icond = icond[bbox.slicer] 98 | icond_tile_list.append(icond) 99 | # vcond: 100 | vcond = self.get_vcond(c_in) 101 | vcond_tile_list.append(vcond) 102 | else: 103 | print('>> [WARN] not supported, make an issue on github!!') 104 | x_tile = torch.cat(x_tile_list, dim=0) # differs each 105 | t_tile = torch.cat(t_tile_list, dim=0) # just repeat 106 | tcond_tile = torch.cat(tcond_tile_list, dim=0) # just repeat 107 | icond_tile = torch.cat(icond_tile_list, dim=0) # differs each 108 | vcond_tile = torch.cat(vcond_tile_list, dim=0) if 
None not in vcond_tile_list else None # just repeat 109 | 110 | c_tile = self.make_cond_dict(c_in, tcond_tile, icond_tile, vcond_tile) 111 | 112 | # controlnet 113 | self.switch_controlnet_tensors(batch_id, N, len(bboxes), is_denoise=True) 114 | 115 | # stablesr 116 | self.switch_stablesr_tensors(batch_id) 117 | 118 | # denoising: here the x is the noise 119 | x_tile_out = shared.sd_model.apply_model_original_md(x_tile, t_tile, c_tile) 120 | 121 | # de-batching 122 | for i, bbox in enumerate(bboxes): 123 | # This weights can be calcluated in advance, but will cost a lot of vram 124 | # when you have many tiles. So we calculate it here. 125 | w = self.tile_weights * self.rescale_factor[bbox.slicer] 126 | self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :] * w 127 | 128 | self.update_pbar() 129 | 130 | # Custom region sampling 131 | x_feather_buffer = None 132 | x_feather_mask = None 133 | x_feather_count = None 134 | if len(self.custom_bboxes) > 0: 135 | for bbox_id, bbox in enumerate(self.custom_bboxes): 136 | if not self.p.disable_extra_networks: 137 | with devices.autocast(): 138 | extra_networks.activate(self.p, bbox.extra_network_data) 139 | 140 | x_tile = x_in[bbox.slicer] 141 | if noise_inverse_step < 0: 142 | x_tile_out = self.custom_apply_model(x_tile, t_in, c_in, bbox_id, bbox) 143 | else: 144 | tcond = Condition.reconstruct_cond(bbox.cond, noise_inverse_step) 145 | icond = self.get_icond(c_in) 146 | if icond.shape[2:] == (self.h, self.w): 147 | icond = icond[bbox.slicer] 148 | vcond = self.get_vcond(c_in) 149 | c_out = self.make_cond_dict(c_in, tcond, icond, vcond) 150 | x_tile_out = shared.sd_model.apply_model(x_tile, t_in, cond=c_out) 151 | 152 | if bbox.blend_mode == BlendMode.BACKGROUND: 153 | self.x_buffer[bbox.slicer] += x_tile_out * self.custom_weights[bbox_id] 154 | elif bbox.blend_mode == BlendMode.FOREGROUND: 155 | if x_feather_buffer is None: 156 | x_feather_buffer = torch.zeros_like(self.x_buffer) 157 | x_feather_mask = torch.zeros((1, 1, H, W), device=self.x_buffer.device) 158 | x_feather_count = torch.zeros((1, 1, H, W), device=self.x_buffer.device) 159 | x_feather_buffer[bbox.slicer] += x_tile_out 160 | x_feather_mask [bbox.slicer] += bbox.feather_mask 161 | x_feather_count [bbox.slicer] += 1 162 | 163 | self.update_pbar() 164 | 165 | if not self.p.disable_extra_networks: 166 | with devices.autocast(): 167 | extra_networks.deactivate(self.p, bbox.extra_network_data) 168 | 169 | x_out = self.x_buffer 170 | if x_feather_buffer is not None: 171 | # Average overlapping feathered regions 172 | x_feather_buffer = torch.where(x_feather_count > 1, x_feather_buffer / x_feather_count, x_feather_buffer) 173 | x_feather_mask = torch.where(x_feather_count > 1, x_feather_mask / x_feather_count, x_feather_mask) 174 | # Weighted average with original x_buffer 175 | x_out = torch.where(x_feather_count > 0, x_out * (1 - x_feather_mask) + x_feather_buffer * x_feather_mask, x_out) 176 | 177 | # For mixture of diffusers, we cannot fill the not denoised area. 178 | # So we just leave it as it is. 
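# (note: unlike MultiDiffusion there is no final `x_buffer / weights` division here; every tile
#  contribution was multiplied by tile_weights * rescale_factor, i.e. gaussian weights that were
#  pre-normalized in init_done, so the accumulated buffer is already a weighted average)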
179 | return x_out 180 | 181 | def custom_apply_model(self, x_in, t_in, c_in, bbox_id, bbox) -> Tensor: 182 | if self.is_kdiff: 183 | return self.kdiff_custom_forward(x_in, t_in, c_in, bbox_id, bbox, forward_func=shared.sd_model.apply_model_original_md) 184 | else: 185 | def forward_func(x, c, ts, unconditional_conditioning, *args, **kwargs) -> Tensor: 186 | # copy from p_sample_ddim in ddim.py 187 | c_in: CondDict = dict() 188 | for k in c: 189 | if isinstance(c[k], list): 190 | c_in[k] = [torch.cat([unconditional_conditioning[k][i], c[k][i]]) for i in range(len(c[k]))] 191 | else: 192 | c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) 193 | self.set_custom_controlnet_tensors(bbox_id, x.shape[0]) 194 | self.set_custom_stablesr_tensors(bbox_id) 195 | return shared.sd_model.apply_model_original_md(x, ts, c_in) 196 | return self.ddim_custom_forward(x_in, c_in, bbox, ts=t_in, forward_func=forward_func) 197 | 198 | @torch.no_grad() 199 | def get_noise(self, x_in:Tensor, sigma_in:Tensor, cond_in:Dict[str, Tensor], step:int) -> Tensor: 200 | return self.apply_model_hijack(x_in, sigma_in, cond=cond_in, noise_inverse_step=step) 201 | -------------------------------------------------------------------------------- /tile_methods/multidiffusion.py: -------------------------------------------------------------------------------- 1 | from tile_methods.abstractdiffusion import AbstractDiffusion 2 | from tile_utils.utils import * 3 | 4 | 5 | class MultiDiffusion(AbstractDiffusion): 6 | """ 7 | Multi-Diffusion Implementation 8 | https://arxiv.org/abs/2302.08113 9 | """ 10 | 11 | def __init__(self, p:Processing, *args, **kwargs): 12 | super().__init__(p, *args, **kwargs) 13 | assert p.sampler_name != 'UniPC', 'MultiDiffusion is not compatible with UniPC!' 14 | 15 | def hook(self): 16 | if self.is_kdiff: 17 | # For K-Diffusion sampler with uniform prompt, we hijack into the inner model for simplicity 18 | # Otherwise, the masked-redraw will break due to the init_latent 19 | self.sampler: KDiffusionSampler 20 | self.sampler.model_wrap_cfg: CFGDenoiserKDiffusion 21 | self.sampler.model_wrap_cfg.inner_model: Union[CompVisDenoiser, CompVisVDenoiser] 22 | self.sampler_forward = self.sampler.model_wrap_cfg.inner_model.forward 23 | self.sampler.model_wrap_cfg.inner_model.forward = self.kdiff_forward 24 | else: 25 | self.sampler: CompVisSampler 26 | self.sampler.model_wrap_cfg: CFGDenoiserTimesteps 27 | self.sampler.model_wrap_cfg.inner_model: Union[CompVisTimestepsDenoiser, CompVisTimestepsVDenoiser] 28 | self.sampler_forward = self.sampler.model_wrap_cfg.inner_model.forward 29 | self.sampler.model_wrap_cfg.inner_model.forward = self.ddim_forward 30 | 31 | @staticmethod 32 | def unhook(): 33 | # no need to unhook MultiDiffusion as it only hook the sampler, 34 | # which will be destroyed after the painting is done 35 | pass 36 | 37 | def reset_buffer(self, x_in:Tensor): 38 | super().reset_buffer(x_in) 39 | 40 | @custom_bbox 41 | def init_custom_bbox(self, *args): 42 | super().init_custom_bbox(*args) 43 | 44 | for bbox in self.custom_bboxes: 45 | if bbox.blend_mode == BlendMode.BACKGROUND: 46 | self.weights[bbox.slicer] += 1.0 47 | 48 | ''' ↓↓↓ kernel hijacks ↓↓↓ ''' 49 | 50 | @torch.no_grad() 51 | @keep_signature 52 | def kdiff_forward(self, x_in:Tensor, sigma_in:Tensor, cond:CondDict) -> Tensor: 53 | assert CompVisDenoiser.forward 54 | assert CompVisVDenoiser.forward 55 | 56 | def org_func(x:Tensor) -> Tensor: 57 | return self.sampler_forward(x, sigma_in, cond=cond) 58 | 59 | def 
repeat_func(x_tile:Tensor, bboxes:List[CustomBBox]) -> Tensor: 60 | # For kdiff sampler, the dim 0 of input x_in is: 61 | # = batch_size * (num_AND + 1) if not an edit model 62 | # = batch_size * (num_AND + 2) otherwise 63 | sigma_tile = self.repeat_tensor(sigma_in, len(bboxes)) 64 | cond_tile = self.repeat_cond_dict(cond, bboxes) 65 | return self.sampler_forward(x_tile, sigma_tile, cond=cond_tile) 66 | 67 | def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox) -> Tensor: 68 | return self.kdiff_custom_forward(x, sigma_in, cond, bbox_id, bbox, self.sampler_forward) 69 | 70 | return self.sample_one_step(x_in, org_func, repeat_func, custom_func) 71 | 72 | @torch.no_grad() 73 | @keep_signature 74 | def ddim_forward(self, x_in:Tensor, ts_in:Tensor, cond:Union[CondDict, Tensor]) -> Tensor: 75 | assert CompVisTimestepsDenoiser.forward 76 | assert CompVisTimestepsVDenoiser.forward 77 | 78 | def org_func(x:Tensor) -> Tensor: 79 | return self.sampler_forward(x, ts_in, cond=cond) 80 | 81 | def repeat_func(x_tile:Tensor, bboxes:List[CustomBBox]) -> Tuple[Tensor, Tensor]: 82 | n_rep = len(bboxes) 83 | ts_tile = self.repeat_tensor(ts_in, n_rep) 84 | if isinstance(cond, dict): # FIXME: when will enter this branch? 85 | cond_tile = self.repeat_cond_dict(cond, bboxes) 86 | else: 87 | cond_tile = self.repeat_tensor(cond, n_rep) 88 | return self.sampler_forward(x_tile, ts_tile, cond=cond_tile) 89 | 90 | def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox) -> Tensor: 91 | # before the final forward, we can set the control tensor 92 | def forward_func(x, *args, **kwargs): 93 | self.set_custom_controlnet_tensors(bbox_id, 2*x.shape[0]) 94 | self.set_custom_stablesr_tensors(bbox_id) 95 | return self.sampler_forward(x, *args, **kwargs) 96 | return self.ddim_custom_forward(x, cond, bbox, ts_in, forward_func) 97 | 98 | return self.sample_one_step(x_in, org_func, repeat_func, custom_func) 99 | 100 | def repeat_tensor(self, x:Tensor, n:int) -> Tensor: 101 | ''' repeat the tensor on it's first dim ''' 102 | if n == 1: return x 103 | B = x.shape[0] 104 | r_dims = len(x.shape) - 1 105 | if B == 1: # batch_size = 1 (not `tile_batch_size`) 106 | shape = [n] + [-1] * r_dims # [N, -1, ...] 107 | return x.expand(shape) # `expand` is much lighter than `tile` 108 | else: 109 | shape = [n] + [1] * r_dims # [N, 1, ...] 
110 | return x.repeat(shape) 111 | 112 | def repeat_cond_dict(self, cond_in:CondDict, bboxes:List[CustomBBox]) -> CondDict: 113 | ''' repeat all tensors in cond_dict on it's first dim (for a batch of tiles), returns a new object ''' 114 | # n_repeat 115 | n_rep = len(bboxes) 116 | # txt cond 117 | tcond = self.get_tcond(cond_in) # [B=1, L, D] => [B*N, L, D] 118 | tcond = self.repeat_tensor(tcond, n_rep) 119 | # img cond 120 | icond = self.get_icond(cond_in) 121 | if icond.shape[2:] == (self.h, self.w): # img2img, [B=1, C, H, W] 122 | icond = torch.cat([icond[bbox.slicer] for bbox in bboxes], dim=0) 123 | else: # txt2img, [B=1, C=5, H=1, W=1] 124 | icond = self.repeat_tensor(icond, n_rep) 125 | # vec cond (SDXL) 126 | vcond = self.get_vcond(cond_in) # [B=1, D] 127 | if vcond is not None: 128 | vcond = self.repeat_tensor(vcond, n_rep) # [B*N, D] 129 | return self.make_cond_dict(cond_in, tcond, icond, vcond) 130 | 131 | def sample_one_step(self, x_in:Tensor, org_func:Callable, repeat_func:Callable, custom_func:Callable) -> Tensor: 132 | ''' 133 | this method splits the whole latent and process in tiles 134 | - x_in: current whole U-Net latent 135 | - org_func: original forward function, when use highres 136 | - repeat_func: one step denoiser for grid tile 137 | - custom_func: one step denoiser for custom tile 138 | ''' 139 | 140 | N, C, H, W = x_in.shape 141 | if (H, W) != (self.h, self.w): 142 | # We don't tile highres, let's just use the original org_func 143 | self.reset_controlnet_tensors() 144 | return org_func(x_in) 145 | 146 | # clear buffer canvas 147 | self.reset_buffer(x_in) 148 | 149 | # Background sampling (grid bbox) 150 | if self.draw_background: 151 | for batch_id, bboxes in enumerate(self.batched_bboxes): 152 | if state.interrupted: return x_in 153 | 154 | # batching 155 | x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0) # [TB, C, TH, TW] 156 | 157 | # controlnet tiling 158 | # FIXME: is_denoise is default to False, however it is set to True in case of MixtureOfDiffusers, why? 
159 | self.switch_controlnet_tensors(batch_id, N, len(bboxes)) 160 | 161 | # stablesr tiling 162 | self.switch_stablesr_tensors(batch_id) 163 | 164 | # compute tiles 165 | x_tile_out = repeat_func(x_tile, bboxes) 166 | for i, bbox in enumerate(bboxes): 167 | self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :] 168 | 169 | # update progress bar 170 | self.update_pbar() 171 | 172 | # Custom region sampling (custom bbox) 173 | x_feather_buffer = None 174 | x_feather_mask = None 175 | x_feather_count = None 176 | if len(self.custom_bboxes) > 0: 177 | for bbox_id, bbox in enumerate(self.custom_bboxes): 178 | if state.interrupted: return x_in 179 | 180 | if not self.p.disable_extra_networks: 181 | with devices.autocast(): 182 | extra_networks.activate(self.p, bbox.extra_network_data) 183 | 184 | x_tile = x_in[bbox.slicer] 185 | 186 | # retrieve original x_in from construncted input 187 | x_tile_out = custom_func(x_tile, bbox_id, bbox) 188 | 189 | if bbox.blend_mode == BlendMode.BACKGROUND: 190 | self.x_buffer[bbox.slicer] += x_tile_out 191 | elif bbox.blend_mode == BlendMode.FOREGROUND: 192 | if x_feather_buffer is None: 193 | x_feather_buffer = torch.zeros_like(self.x_buffer) 194 | x_feather_mask = torch.zeros((1, 1, H, W), device=x_in.device) 195 | x_feather_count = torch.zeros((1, 1, H, W), device=x_in.device) 196 | x_feather_buffer[bbox.slicer] += x_tile_out 197 | x_feather_mask [bbox.slicer] += bbox.feather_mask 198 | x_feather_count [bbox.slicer] += 1 199 | 200 | if not self.p.disable_extra_networks: 201 | with devices.autocast(): 202 | extra_networks.deactivate(self.p, bbox.extra_network_data) 203 | 204 | # update progress bar 205 | self.update_pbar() 206 | 207 | # Averaging background buffer 208 | x_out = torch.where(self.weights > 1, self.x_buffer / self.weights, self.x_buffer) 209 | 210 | # Foreground Feather blending 211 | if x_feather_buffer is not None: 212 | # Average overlapping feathered regions 213 | x_feather_buffer = torch.where(x_feather_count > 1, x_feather_buffer / x_feather_count, x_feather_buffer) 214 | x_feather_mask = torch.where(x_feather_count > 1, x_feather_mask / x_feather_count, x_feather_mask) 215 | # Weighted average with original x_buffer 216 | x_out = torch.where(x_feather_count > 0, x_out * (1 - x_feather_mask) + x_feather_buffer * x_feather_mask, x_out) 217 | 218 | return x_out 219 | 220 | def get_noise(self, x_in:Tensor, sigma_in:Tensor, cond_in:Dict[str, Tensor], step:int) -> Tensor: 221 | # NOTE: The following code is analytically wrong but aesthetically beautiful 222 | cond_in_original = cond_in.copy() 223 | 224 | def org_func(x:Tensor): 225 | return shared.sd_model.apply_model(x, sigma_in, cond=cond_in_original) 226 | 227 | def repeat_func(x_tile:Tensor, bboxes:List[CustomBBox]): 228 | sigma_in_tile = sigma_in.repeat(len(bboxes)) 229 | cond_out = self.repeat_cond_dict(cond_in_original, bboxes) 230 | x_tile_out = shared.sd_model.apply_model(x_tile, sigma_in_tile, cond=cond_out) 231 | return x_tile_out 232 | 233 | def custom_func(x:Tensor, bbox_id:int, bbox:CustomBBox): 234 | # The negative prompt in custom bbox should not be used for noise inversion 235 | # otherwise the result will be astonishingly bad. 
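# (hence only the bbox's positive cond is reconstructed below; bbox.uncond is not passed to this apply_model call)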
236 | tcond = Condition.reconstruct_cond(bbox.cond, step).unsqueeze_(0) 237 | icond = self.get_icond(cond_in_original) 238 | if icond.shape[2:] == (self.h, self.w): 239 | icond = icond[bbox.slicer] 240 | cond_out = self.make_cond_dict(cond_in, tcond, icond) 241 | return shared.sd_model.apply_model(x, sigma_in, cond=cond_out) 242 | 243 | return self.sample_one_step(x_in, org_func, repeat_func, custom_func) 244 | -------------------------------------------------------------------------------- /tile_utils/attn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file is modified from the sd_hijack_optimizations.py to remove the residual and norm part, 3 | So that the Tiled VAE can support other types of attention. 4 | ''' 5 | import math 6 | import torch 7 | 8 | from modules import shared, sd_hijack 9 | from einops import rearrange 10 | from modules.sd_hijack_optimizations import get_available_vram, get_xformers_flash_attention_op, sub_quad_attention 11 | 12 | try: 13 | import xformers 14 | import xformers.ops 15 | except ImportError: 16 | pass 17 | 18 | 19 | def get_attn_func(): 20 | method = sd_hijack.model_hijack.optimization_method 21 | if method is None: 22 | return attn_forward 23 | method = method.lower() 24 | # The method should be one of the following: 25 | # ['none', 'sdp-no-mem', 'sdp', 'xformers', ''sub-quadratic', 'v1', 'invokeai', 'doggettx'] 26 | if method not in ['none', 'sdp-no-mem', 'sdp', 'xformers', 'sub-quadratic', 'v1', 'invokeai', 'doggettx']: 27 | print(f"[Tiled VAE] Warning: Unknown attention optimization method {method}. Please try to update the extension.") 28 | return attn_forward 29 | 30 | if method == 'none': 31 | return attn_forward 32 | elif method == 'xformers': 33 | return xformers_attnblock_forward 34 | elif method == 'sdp-no-mem': 35 | return sdp_no_mem_attnblock_forward 36 | elif method == 'sdp': 37 | return sdp_attnblock_forward 38 | elif method == 'sub-quadratic': 39 | return sub_quad_attnblock_forward 40 | elif method == 'doggettx': 41 | return cross_attention_attnblock_forward 42 | 43 | return attn_forward 44 | 45 | 46 | # The following functions are all copied from modules.sd_hijack_optimizations 47 | # However, the residual & normalization are removed and computed separately. 
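# Roughly, a stock VAE AttnBlock computes
#     out = x + proj_out(attention(group_norm(x)))
# while the forwards below return only proj_out(attention(h_)). Tiled VAE's task queue
# supplies the missing pieces per tile: the GroupNorm is applied beforehand as a
# 'pre_norm' / 'apply_norm' task (so its mean/var can be aggregated across tiles via
# GroupNormParam), and the residual is re-added afterwards by the 'store_res' / 'add_res'
# task pair.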
48 | 49 | def attn_forward(self, h_): 50 | q = self.q(h_) 51 | k = self.k(h_) 52 | v = self.v(h_) 53 | 54 | # compute attention 55 | b, c, h, w = q.shape 56 | q = q.reshape(b, c, h*w) 57 | q = q.permute(0, 2, 1) # b,hw,c 58 | k = k.reshape(b, c, h*w) # b,c,hw 59 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 60 | w_ = w_ * (int(c)**(-0.5)) 61 | w_ = torch.nn.functional.softmax(w_, dim=2) 62 | 63 | # attend to values 64 | v = v.reshape(b, c, h*w) 65 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 66 | # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 67 | h_ = torch.bmm(v, w_) 68 | h_ = h_.reshape(b, c, h, w) 69 | 70 | h_ = self.proj_out(h_) 71 | 72 | return h_ 73 | 74 | def xformers_attnblock_forward(self, h_): 75 | try: 76 | q = self.q(h_) 77 | k = self.k(h_) 78 | v = self.v(h_) 79 | b, c, h, w = q.shape 80 | q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) 81 | dtype = q.dtype 82 | if shared.opts.upcast_attn: 83 | q, k, v = q.float(), k.float(), v.float() 84 | q = q.contiguous() 85 | k = k.contiguous() 86 | v = v.contiguous() 87 | out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) 88 | out = out.to(dtype) 89 | out = rearrange(out, 'b (h w) c -> b c h w', h=h) 90 | out = self.proj_out(out) 91 | return out 92 | except NotImplementedError: 93 | return cross_attention_attnblock_forward(self, h_) 94 | 95 | def cross_attention_attnblock_forward(self, h_): 96 | q1 = self.q(h_) 97 | k1 = self.k(h_) 98 | v = self.v(h_) 99 | 100 | # compute attention 101 | b, c, h, w = q1.shape 102 | 103 | q2 = q1.reshape(b, c, h*w) 104 | del q1 105 | 106 | q = q2.permute(0, 2, 1) # b,hw,c 107 | del q2 108 | 109 | k = k1.reshape(b, c, h*w) # b,c,hw 110 | del k1 111 | 112 | h_ = torch.zeros_like(k, device=q.device) 113 | 114 | mem_free_total = get_available_vram() 115 | 116 | tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() 117 | mem_required = tensor_size * 2.5 118 | steps = 1 119 | 120 | if mem_required > mem_free_total: 121 | steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) 122 | 123 | slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] 124 | for i in range(0, q.shape[1], slice_size): 125 | end = i + slice_size 126 | 127 | w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 128 | w2 = w1 * (int(c)**(-0.5)) 129 | del w1 130 | w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) 131 | del w2 132 | 133 | # attend to values 134 | v1 = v.reshape(b, c, h*w) 135 | w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 136 | del w3 137 | 138 | h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 139 | del v1, w4 140 | 141 | h2 = h_.reshape(b, c, h, w) 142 | del h_ 143 | 144 | h3 = self.proj_out(h2) 145 | del h2 146 | 147 | return h3 148 | 149 | def sdp_no_mem_attnblock_forward(self, x): 150 | with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): 151 | return sdp_attnblock_forward(self, x) 152 | 153 | def sdp_attnblock_forward(self, h_): 154 | q = self.q(h_) 155 | k = self.k(h_) 156 | v = self.v(h_) 157 | b, c, h, w = q.shape 158 | q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) 159 | dtype = q.dtype 160 | if shared.opts.upcast_attn: 161 | q, k, v = q.float(), k.float(), v.float() 162 | q = q.contiguous() 163 | k = k.contiguous() 164 | v = v.contiguous() 165 | out = 
torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False) 166 | out = out.to(dtype) 167 | out = rearrange(out, 'b (h w) c -> b c h w', h=h) 168 | out = self.proj_out(out) 169 | return out 170 | 171 | def sub_quad_attnblock_forward(self, h_): 172 | q = self.q(h_) 173 | k = self.k(h_) 174 | v = self.v(h_) 175 | b, c, h, w = q.shape 176 | q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) 177 | q = q.contiguous() 178 | k = k.contiguous() 179 | v = v.contiguous() 180 | out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) 181 | out = rearrange(out, 'b (h w) c -> b c h w', h=h) 182 | out = self.proj_out(out) 183 | return out 184 | -------------------------------------------------------------------------------- /tile_utils/typing.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import * 3 | NoType = Any 4 | 5 | from torch import Tensor 6 | from gradio.components import Component 7 | 8 | from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser 9 | from ldm.models.diffusion.ddpm import LatentDiffusion 10 | 11 | from modules.processing import StableDiffusionProcessing as Processing, StableDiffusionProcessingImg2Img as ProcessingImg2Img, Processed 12 | from modules.prompt_parser import MulticondLearnedConditioning, ScheduledPromptConditioning 13 | from modules.extra_networks import ExtraNetworkParams 14 | from modules.sd_samplers_kdiffusion import KDiffusionSampler, CFGDenoiser 15 | # ↓↓↓ backward compatible for v1.5.2 ↓↓↓ 16 | try: 17 | from modules.shared_state import State 18 | except ImportError: 19 | from modules.shared import State 20 | try: 21 | from modules.sd_samplers_kdiffusion import CFGDenoiserKDiffusion 22 | except ImportError: 23 | CFGDenoiserKDiffusion = NoType 24 | try: 25 | from modules.sd_samplers_timesteps import CompVisSampler, CFGDenoiserTimesteps, CompVisTimestepsDenoiser, CompVisTimestepsVDenoiser 26 | except ImportError: 27 | from modules.sd_samplers_compvis import VanillaStableDiffusionSampler 28 | CompVisSampler = VanillaStableDiffusionSampler 29 | CFGDenoiserTimesteps, CompVisTimestepsDenoiser, CompVisTimestepsVDenoiser = NoType, NoType, NoType 30 | # ↑↑↑ backward compatible for v1.5.2 ↑↑↑ 31 | 32 | ModuleType = type(sys) 33 | 34 | Sampler = Union[KDiffusionSampler, CompVisSampler] 35 | Cond = MulticondLearnedConditioning 36 | Uncond = List[List[ScheduledPromptConditioning]] 37 | ExtraNetworkData = DefaultDict[str, List[ExtraNetworkParams]] 38 | 39 | # 'c_crossattn' List[Tensor[B, L=77, D=768]] prompt cond (tcond) 40 | # 'c_concat' List[Tensor[B, C=5, H, W]] latent mask (icond) 41 | # 'c_adm' Tensor[?] 
unclip (icond) 42 | # 'crossattn' Tensor[B, L=77, D=2048] sdxl (tcond) 43 | # 'vector' Tensor[B, D] sdxl (tcond) 44 | CondDict = Dict[str, Tensor] 45 | -------------------------------------------------------------------------------- /tile_utils/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from enum import Enum 3 | from collections import namedtuple 4 | from types import MethodType 5 | 6 | import cv2 7 | import torch 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | from modules import devices, shared, prompt_parser, extra_networks 12 | from modules import sd_samplers_common 13 | from modules.shared import state 14 | from modules.processing import opt_f 15 | 16 | from tile_utils.typing import * 17 | 18 | state: State 19 | 20 | 21 | class ComparableEnum(Enum): 22 | 23 | def __eq__(self, other: Any) -> bool: 24 | if isinstance(other, str): return self.value == other 25 | elif isinstance(other, ComparableEnum): return self.value == other.value 26 | else: raise TypeError(f'unsupported type: {type(other)}') 27 | 28 | class Method(ComparableEnum): 29 | 30 | MULTI_DIFF = 'MultiDiffusion' 31 | MIX_DIFF = 'Mixture of Diffusers' 32 | 33 | class Method_2(ComparableEnum): 34 | DEMO_FU = "DemoFusion" 35 | 36 | class BlendMode(Enum): # i.e. LayerType 37 | 38 | FOREGROUND = 'Foreground' 39 | BACKGROUND = 'Background' 40 | 41 | BBoxSettings = namedtuple('BBoxSettings', ['enable', 'x', 'y', 'w', 'h', 'prompt', 'neg_prompt', 'blend_mode', 'feather_ratio', 'seed']) 42 | NoiseInverseCache = namedtuple('NoiseInversionCache', ['model_hash', 'x0', 'xt', 'noise_inversion_steps', 'retouch', 'prompts']) 43 | DEFAULT_BBOX_SETTINGS = BBoxSettings(False, 0.4, 0.4, 0.2, 0.2, '', '', BlendMode.BACKGROUND.value, 0.2, -1) 44 | NUM_BBOX_PARAMS = len(BBoxSettings._fields) 45 | 46 | 47 | def build_bbox_settings(bbox_control_states:List[Any]) -> Dict[int, BBoxSettings]: 48 | settings = {} 49 | for index, i in enumerate(range(0, len(bbox_control_states), NUM_BBOX_PARAMS)): 50 | setting = BBoxSettings(*bbox_control_states[i:i+NUM_BBOX_PARAMS]) 51 | # for float x, y, w, h, feather_ratio, keeps 4 digits 52 | setting = setting._replace( 53 | x=round(setting.x, 4), 54 | y=round(setting.y, 4), 55 | w=round(setting.w, 4), 56 | h=round(setting.h, 4), 57 | feather_ratio=round(setting.feather_ratio, 4), 58 | seed=int(setting.seed), 59 | ) 60 | # sanity check 61 | if not setting.enable or setting.x > 1.0 or setting.y > 1.0 or setting.w <= 0.0 or setting.h <= 0.0: continue 62 | settings[index] = setting 63 | return settings 64 | 65 | def gr_value(value=None, visible=None): 66 | return {"value": value, "visible": visible, "__type__": "update"} 67 | 68 | 69 | class BBox: 70 | 71 | ''' grid bbox ''' 72 | 73 | def __init__(self, x:int, y:int, w:int, h:int): 74 | self.x = x 75 | self.y = y 76 | self.w = w 77 | self.h = h 78 | self.box = [x, y, x+w, y+h] 79 | self.slicer = slice(None), slice(None), slice(y, y+h), slice(x, x+w) 80 | 81 | def __getitem__(self, idx:int) -> int: 82 | return self.box[idx] 83 | 84 | class CustomBBox(BBox): 85 | 86 | ''' region control bbox ''' 87 | 88 | def __init__(self, x:int, y:int, w:int, h:int, prompt:str, neg_prompt:str, blend_mode:str, feather_radio:float, seed:int): 89 | super().__init__(x, y, w, h) 90 | self.prompt = prompt 91 | self.neg_prompt = neg_prompt 92 | self.blend_mode = BlendMode(blend_mode) 93 | self.feather_ratio = max(min(feather_radio, 1.0), 0.0) 94 | self.seed = seed 95 | # initialize necessary fields 96 | 
65 | def gr_value(value=None, visible=None):
66 |     return {"value": value, "visible": visible, "__type__": "update"}
67 | 
68 | 
69 | class BBox:
70 | 
71 |     ''' grid bbox '''
72 | 
73 |     def __init__(self, x:int, y:int, w:int, h:int):
74 |         self.x = x
75 |         self.y = y
76 |         self.w = w
77 |         self.h = h
78 |         self.box = [x, y, x+w, y+h]
79 |         self.slicer = slice(None), slice(None), slice(y, y+h), slice(x, x+w)
80 | 
81 |     def __getitem__(self, idx:int) -> int:
82 |         return self.box[idx]
83 | 
84 | class CustomBBox(BBox):
85 | 
86 |     ''' region control bbox '''
87 | 
88 |     def __init__(self, x:int, y:int, w:int, h:int, prompt:str, neg_prompt:str, blend_mode:str, feather_ratio:float, seed:int):
89 |         super().__init__(x, y, w, h)
90 |         self.prompt = prompt
91 |         self.neg_prompt = neg_prompt
92 |         self.blend_mode = BlendMode(blend_mode)
93 |         self.feather_ratio = max(min(feather_ratio, 1.0), 0.0)
94 |         self.seed = seed
95 |         # initialize necessary fields
96 |         self.feather_mask = feather_mask(self.w, self.h, self.feather_ratio) if self.blend_mode == BlendMode.FOREGROUND else None
97 |         self.cond: MulticondLearnedConditioning = None
98 |         self.extra_network_data: DefaultDict[str, List[ExtraNetworkParams]] = None
99 |         self.uncond: List[List[ScheduledPromptConditioning]] = None
100 | 
101 | 
102 | class Prompt:
103 | 
104 |     ''' prompts helper '''
105 | 
106 |     @staticmethod
107 |     def apply_styles(prompts:List[str], styles=None) -> List[str]:
108 |         if not styles: return prompts
109 |         return [shared.prompt_styles.apply_styles_to_prompt(p, styles) for p in prompts]
110 | 
111 |     @staticmethod
112 |     def append_prompt(prompts:List[str], prompt:str='') -> List[str]:
113 |         if not prompt: return prompts
114 |         return [f'{p}, {prompt}' for p in prompts]
115 | 
116 | class Condition:
117 | 
118 |     ''' CLIP cond helper '''
119 | 
120 |     @staticmethod
121 |     def get_custom_cond(prompts:List[str], prompt, steps:int, styles=None) -> Tuple[Cond, ExtraNetworkData]:
122 |         prompt = Prompt.apply_styles([prompt], styles)[0]
123 |         _, extra_network_data = extra_networks.parse_prompts([prompt])
124 |         prompts = Prompt.append_prompt(prompts, prompt)
125 |         prompts = Prompt.apply_styles(prompts, styles)
126 |         cond = Condition.get_cond(prompts, steps)
127 |         return cond, extra_network_data
128 | 
129 |     @staticmethod
130 |     def get_cond(prompts, steps:int):
131 |         prompts, _ = extra_networks.parse_prompts(prompts)
132 |         cond = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, steps)
133 |         return cond
134 | 
135 |     @staticmethod
136 |     def get_uncond(neg_prompts:List[str], steps:int, styles=None) -> Uncond:
137 |         neg_prompts = Prompt.apply_styles(neg_prompts, styles)
138 |         uncond = prompt_parser.get_learned_conditioning(shared.sd_model, neg_prompts, steps)
139 |         return uncond
140 | 
141 |     @staticmethod
142 |     def reconstruct_cond(cond:Cond, step:int) -> Tensor:
143 |         list_of_what, tensor = prompt_parser.reconstruct_multicond_batch(cond, step)
144 |         return tensor
145 | 
146 |     def reconstruct_uncond(uncond:Uncond, step:int) -> Tensor:
147 |         tensor = prompt_parser.reconstruct_cond_batch(uncond, step)
148 |         return tensor
149 | 
150 | 
151 | def splitable(w:int, h:int, tile_w:int, tile_h:int, overlap:int=16) -> bool:
152 |     w, h = w // opt_f, h // opt_f
153 |     min_tile_size = min(tile_w, tile_h)
154 |     if overlap >= min_tile_size:
155 |         overlap = min_tile_size - 4
156 |     cols = math.ceil((w - overlap) / (tile_w - overlap))
157 |     rows = math.ceil((h - overlap) / (tile_h - overlap))
158 |     return cols > 1 or rows > 1
159 | 
160 | def split_bboxes(w:int, h:int, tile_w:int, tile_h:int, overlap:int=16, init_weight:Union[Tensor, float]=1.0) -> Tuple[List[BBox], Tensor]:
161 |     cols = math.ceil((w - overlap) / (tile_w - overlap))
162 |     rows = math.ceil((h - overlap) / (tile_h - overlap))
163 |     dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
164 |     dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
165 | 
166 |     bbox_list: List[BBox] = []
167 |     weight = torch.zeros((1, 1, h, w), device=devices.device, dtype=torch.float32)
168 |     for row in range(rows):
169 |         y = min(int(row * dy), h - tile_h)
170 |         for col in range(cols):
171 |             x = min(int(col * dx), w - tile_w)
172 | 
173 |             bbox = BBox(x, y, tile_w, tile_h)
174 |             bbox_list.append(bbox)
175 |             weight[bbox.slicer] += init_weight
176 | 
177 |     return bbox_list, weight
178 | 
179 | 
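# Editor's worked example (illustrative numbers): a 1024x1024 image has a 128x128 latent
# (opt_f == 8). With tile_w = tile_h = 96 and overlap = 48, split_bboxes() computes
#   cols = ceil((128 - 48) / (96 - 48)) = 2,  rows = 2,  dx = dy = (128 - 96) / (2 - 1) = 32
# so it returns four overlapping 96x96 tiles anchored at x, y in {0, 32}, and `weight`
# counts how many tiles cover each latent pixel so the fused result can be normalised.
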
180 | def gaussian_weights(tile_w:int, tile_h:int) -> Tensor:
181 |     '''
182 |     Copied from the original implementation of Mixture of Diffusers:
183 |     https://github.com/albarji/mixture-of-diffusers/blob/master/mixdiff/tiling.py
184 |     This generates Gaussian weights to smooth the noise of each tile.
185 |     This is critical for the method to work.
186 |     '''
187 |     from numpy import pi, exp, sqrt
188 | 
189 |     f = lambda x, midpoint, var=0.01: exp(-(x-midpoint)*(x-midpoint) / (tile_w*tile_w) / (2*var)) / sqrt(2*pi*var)
190 |     x_probs = [f(x, (tile_w - 1) / 2) for x in range(tile_w)]   # -1 because index goes from 0 to latent_width - 1
191 |     y_probs = [f(y, tile_h / 2) for y in range(tile_h)]
192 | 
193 |     w = np.outer(y_probs, x_probs)
194 |     return torch.from_numpy(w).to(devices.device, dtype=torch.float32)
195 | 
196 | def feather_mask(w:int, h:int, ratio:float) -> Tensor:
197 |     '''Generate a feather mask for the bbox'''
198 | 
199 |     mask = np.ones((h, w), dtype=np.float32)
200 |     feather_radius = int(min(w//2, h//2) * ratio)
201 |     # quadratic falloff towards the border: the closer to the edge, the lower the weight
202 |     # pixels farther than feather_radius away from the border keep full weight (1.0)
203 |     # weight = (dist / feather_radius) ** 2
204 |     for i in range(h//2):
205 |         for j in range(w//2):
206 |             dist = min(i, j)
207 |             if dist >= feather_radius: continue
208 |             weight = (dist / feather_radius) ** 2
209 |             mask[i, j] = weight
210 |             mask[i, w-j-1] = weight
211 |             mask[h-i-1, j] = weight
212 |             mask[h-i-1, w-j-1] = weight
213 | 
214 |     return torch.from_numpy(mask).to(devices.device, dtype=torch.float32)
215 | 
216 | def get_retouch_mask(img_input: np.ndarray, kernel_size: int) -> np.ndarray:
217 |     '''
218 |     Return the area where the image is retouched (the difference between a guided-filter smoothed copy and the original).
219 |     Adapted from an article on Zhihu.com.
220 |     '''
221 |     step = 1
222 |     kernel = (kernel_size, kernel_size)
223 | 
224 |     img = img_input.astype(np.float32)/255.0
225 |     sz = img.shape[:2]
226 |     sz1 = (int(round(sz[1] * step)), int(round(sz[0] * step)))
227 |     sz2 = (int(round(kernel[0] * step)), int(round(kernel[0] * step)))
228 |     sI = cv2.resize(img, sz1, interpolation=cv2.INTER_LINEAR)
229 |     sp = cv2.resize(img, sz1, interpolation=cv2.INTER_LINEAR)
230 |     msI = cv2.blur(sI, sz2)
231 |     msp = cv2.blur(sp, sz2)
232 |     msII = cv2.blur(sI*sI, sz2)
233 |     msIp = cv2.blur(sI*sp, sz2)
234 |     vsI = msII - msI*msI
235 |     csIp = msIp - msI*msp
236 |     recA = csIp/(vsI+0.01)
237 |     recB = msp - recA*msI
238 |     mA = cv2.resize(recA, (sz[1],sz[0]), interpolation=cv2.INTER_LINEAR)
239 |     mB = cv2.resize(recB, (sz[1],sz[0]), interpolation=cv2.INTER_LINEAR)
240 | 
241 |     gf = mA * img + mB
242 |     gf -= img
243 |     gf *= 255
244 |     gf = gf.astype(np.uint8)
245 |     gf = gf.clip(0, 255)
246 |     gf = gf.astype(np.float32)/255.0
247 |     return gf
248 | 
249 | 
250 | def null_decorator(fn):
251 |     def wrapper(*args, **kwargs):
252 |         return fn(*args, **kwargs)
253 |     return wrapper
254 | 
255 | keep_signature = null_decorator
256 | controlnet = null_decorator
257 | stablesr = null_decorator
258 | grid_bbox = null_decorator
259 | custom_bbox = null_decorator
260 | noise_inverse = null_decorator
261 | 
--------------------------------------------------------------------------------
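feather_mask() above fills the mask with two nested Python loops over the top-left quadrant, which becomes slow for large regions. Below is a minimal vectorized sketch that should reproduce the same mask, including the quirk that the centre row/column of odd-sized masks stays at full weight; it is an editor's illustration rather than code from the extension, and the device argument stands in for the webui's devices.device:

import numpy as np
import torch

def feather_mask_vectorized(w: int, h: int, ratio: float, device: str = 'cpu') -> torch.Tensor:
    '''Loop-free equivalent of feather_mask(): quadratic falloff towards the borders.'''
    feather_radius = int(min(w // 2, h // 2) * ratio)
    mask = np.ones((h, w), dtype=np.float32)
    if feather_radius <= 0:
        return torch.from_numpy(mask).to(device)
    hh, wh = h // 2, w // 2
    # distance of each top-left-quadrant pixel to its nearest (top or left) border
    ii = np.arange(hh, dtype=np.float32)[:, None]
    jj = np.arange(wh, dtype=np.float32)[None, :]
    dist = np.minimum(ii, jj)
    quad = np.where(dist < feather_radius, (dist / feather_radius) ** 2, 1.0).astype(np.float32)
    # mirror the quadrant into the other three corners
    mask[:hh, :wh]         = quad
    mask[:hh, w - wh:]     = quad[:, ::-1]
    mask[h - hh:, :wh]     = quad[::-1, :]
    mask[h - hh:, w - wh:] = quad[::-1, ::-1]
    return torch.from_numpy(mask).to(device)

The extension itself keeps the loop implementation; the sketch only illustrates the shape of the mask and how it could be computed without per-pixel Python overhead.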