├── .github
│   └── FUNDING.yml
├── .gitignore
├── LICENSE
├── README.md
├── javascript
│   └── bboxHint.js
├── scripts
│   ├── tilediffusion.py
│   ├── tileglobal.py
│   └── tilevae.py
├── tile_methods
│   ├── abstractdiffusion.py
│   ├── demofusion.py
│   ├── mixtureofdiffusers.py
│   └── multidiffusion.py
└── tile_utils
    ├── attn.py
    ├── typing.py
    └── utils.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: pkuliyi2015 # Replace with a single Ko-fi username
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | otechie: # Replace with a single Otechie username
12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
14 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # meta
2 | .vscode/
3 | __pycache__/
4 | .DS_Store
5 |
6 | # settings
7 | region_configs/
8 |
9 | # test images
10 | deflicker/input_frames/*
11 |
12 | # test features
13 | deflicker/*
14 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial-ShareAlike 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
58 | Public License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial-ShareAlike 4.0 International Public License
63 | ("Public License"). To the extent this Public License may be
64 | interpreted as a contract, You are granted the Licensed Rights in
65 | consideration of Your acceptance of these terms and conditions, and the
66 | Licensor grants You such rights in consideration of benefits the
67 | Licensor receives from making the Licensed Material available under
68 | these terms and conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. BY-NC-SA Compatible License means a license listed at
88 | creativecommons.org/compatiblelicenses, approved by Creative
89 | Commons as essentially the equivalent of this Public License.
90 |
91 | d. Copyright and Similar Rights means copyright and/or similar rights
92 | closely related to copyright including, without limitation,
93 | performance, broadcast, sound recording, and Sui Generis Database
94 | Rights, without regard to how the rights are labeled or
95 | categorized. For purposes of this Public License, the rights
96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
97 | Rights.
98 |
99 | e. Effective Technological Measures means those measures that, in the
100 | absence of proper authority, may not be circumvented under laws
101 | fulfilling obligations under Article 11 of the WIPO Copyright
102 | Treaty adopted on December 20, 1996, and/or similar international
103 | agreements.
104 |
105 | f. Exceptions and Limitations means fair use, fair dealing, and/or
106 | any other exception or limitation to Copyright and Similar Rights
107 | that applies to Your use of the Licensed Material.
108 |
109 | g. License Elements means the license attributes listed in the name
110 | of a Creative Commons Public License. The License Elements of this
111 | Public License are Attribution, NonCommercial, and ShareAlike.
112 |
113 | h. Licensed Material means the artistic or literary work, database,
114 | or other material to which the Licensor applied this Public
115 | License.
116 |
117 | i. Licensed Rights means the rights granted to You subject to the
118 | terms and conditions of this Public License, which are limited to
119 | all Copyright and Similar Rights that apply to Your use of the
120 | Licensed Material and that the Licensor has authority to license.
121 |
122 | j. Licensor means the individual(s) or entity(ies) granting rights
123 | under this Public License.
124 |
125 | k. NonCommercial means not primarily intended for or directed towards
126 | commercial advantage or monetary compensation. For purposes of
127 | this Public License, the exchange of the Licensed Material for
128 | other material subject to Copyright and Similar Rights by digital
129 | file-sharing or similar means is NonCommercial provided there is
130 | no payment of monetary compensation in connection with the
131 | exchange.
132 |
133 | l. Share means to provide material to the public by any means or
134 | process that requires permission under the Licensed Rights, such
135 | as reproduction, public display, public performance, distribution,
136 | dissemination, communication, or importation, and to make material
137 | available to the public including in ways that members of the
138 | public may access the material from a place and at a time
139 | individually chosen by them.
140 |
141 | m. Sui Generis Database Rights means rights other than copyright
142 | resulting from Directive 96/9/EC of the European Parliament and of
143 | the Council of 11 March 1996 on the legal protection of databases,
144 | as amended and/or succeeded, as well as other essentially
145 | equivalent rights anywhere in the world.
146 |
147 | n. You means the individual or entity exercising the Licensed Rights
148 | under this Public License. Your has a corresponding meaning.
149 |
150 |
151 | Section 2 -- Scope.
152 |
153 | a. License grant.
154 |
155 | 1. Subject to the terms and conditions of this Public License,
156 | the Licensor hereby grants You a worldwide, royalty-free,
157 | non-sublicensable, non-exclusive, irrevocable license to
158 | exercise the Licensed Rights in the Licensed Material to:
159 |
160 | a. reproduce and Share the Licensed Material, in whole or
161 | in part, for NonCommercial purposes only; and
162 |
163 | b. produce, reproduce, and Share Adapted Material for
164 | NonCommercial purposes only.
165 |
166 | 2. Exceptions and Limitations. For the avoidance of doubt, where
167 | Exceptions and Limitations apply to Your use, this Public
168 | License does not apply, and You do not need to comply with
169 | its terms and conditions.
170 |
171 | 3. Term. The term of this Public License is specified in Section
172 | 6(a).
173 |
174 | 4. Media and formats; technical modifications allowed. The
175 | Licensor authorizes You to exercise the Licensed Rights in
176 | all media and formats whether now known or hereafter created,
177 | and to make technical modifications necessary to do so. The
178 | Licensor waives and/or agrees not to assert any right or
179 | authority to forbid You from making technical modifications
180 | necessary to exercise the Licensed Rights, including
181 | technical modifications necessary to circumvent Effective
182 | Technological Measures. For purposes of this Public License,
183 | simply making modifications authorized by this Section 2(a)
184 | (4) never produces Adapted Material.
185 |
186 | 5. Downstream recipients.
187 |
188 | a. Offer from the Licensor -- Licensed Material. Every
189 | recipient of the Licensed Material automatically
190 | receives an offer from the Licensor to exercise the
191 | Licensed Rights under the terms and conditions of this
192 | Public License.
193 |
194 | b. Additional offer from the Licensor -- Adapted Material.
195 | Every recipient of Adapted Material from You
196 | automatically receives an offer from the Licensor to
197 | exercise the Licensed Rights in the Adapted Material
198 | under the conditions of the Adapter's License You apply.
199 |
200 | c. No downstream restrictions. You may not offer or impose
201 | any additional or different terms or conditions on, or
202 | apply any Effective Technological Measures to, the
203 | Licensed Material if doing so restricts exercise of the
204 | Licensed Rights by any recipient of the Licensed
205 | Material.
206 |
207 | 6. No endorsement. Nothing in this Public License constitutes or
208 | may be construed as permission to assert or imply that You
209 | are, or that Your use of the Licensed Material is, connected
210 | with, or sponsored, endorsed, or granted official status by,
211 | the Licensor or others designated to receive attribution as
212 | provided in Section 3(a)(1)(A)(i).
213 |
214 | b. Other rights.
215 |
216 | 1. Moral rights, such as the right of integrity, are not
217 | licensed under this Public License, nor are publicity,
218 | privacy, and/or other similar personality rights; however, to
219 | the extent possible, the Licensor waives and/or agrees not to
220 | assert any such rights held by the Licensor to the limited
221 | extent necessary to allow You to exercise the Licensed
222 | Rights, but not otherwise.
223 |
224 | 2. Patent and trademark rights are not licensed under this
225 | Public License.
226 |
227 | 3. To the extent possible, the Licensor waives any right to
228 | collect royalties from You for the exercise of the Licensed
229 | Rights, whether directly or through a collecting society
230 | under any voluntary or waivable statutory or compulsory
231 | licensing scheme. In all other cases the Licensor expressly
232 | reserves any right to collect such royalties, including when
233 | the Licensed Material is used other than for NonCommercial
234 | purposes.
235 |
236 |
237 | Section 3 -- License Conditions.
238 |
239 | Your exercise of the Licensed Rights is expressly made subject to the
240 | following conditions.
241 |
242 | a. Attribution.
243 |
244 | 1. If You Share the Licensed Material (including in modified
245 | form), You must:
246 |
247 | a. retain the following if it is supplied by the Licensor
248 | with the Licensed Material:
249 |
250 | i. identification of the creator(s) of the Licensed
251 | Material and any others designated to receive
252 | attribution, in any reasonable manner requested by
253 | the Licensor (including by pseudonym if
254 | designated);
255 |
256 | ii. a copyright notice;
257 |
258 | iii. a notice that refers to this Public License;
259 |
260 | iv. a notice that refers to the disclaimer of
261 | warranties;
262 |
263 | v. a URI or hyperlink to the Licensed Material to the
264 | extent reasonably practicable;
265 |
266 | b. indicate if You modified the Licensed Material and
267 | retain an indication of any previous modifications; and
268 |
269 | c. indicate the Licensed Material is licensed under this
270 | Public License, and include the text of, or the URI or
271 | hyperlink to, this Public License.
272 |
273 | 2. You may satisfy the conditions in Section 3(a)(1) in any
274 | reasonable manner based on the medium, means, and context in
275 | which You Share the Licensed Material. For example, it may be
276 | reasonable to satisfy the conditions by providing a URI or
277 | hyperlink to a resource that includes the required
278 | information.
279 | 3. If requested by the Licensor, You must remove any of the
280 | information required by Section 3(a)(1)(A) to the extent
281 | reasonably practicable.
282 |
283 | b. ShareAlike.
284 |
285 | In addition to the conditions in Section 3(a), if You Share
286 | Adapted Material You produce, the following conditions also apply.
287 |
288 | 1. The Adapter's License You apply must be a Creative Commons
289 | license with the same License Elements, this version or
290 | later, or a BY-NC-SA Compatible License.
291 |
292 | 2. You must include the text of, or the URI or hyperlink to, the
293 | Adapter's License You apply. You may satisfy this condition
294 | in any reasonable manner based on the medium, means, and
295 | context in which You Share Adapted Material.
296 |
297 | 3. You may not offer or impose any additional or different terms
298 | or conditions on, or apply any Effective Technological
299 | Measures to, Adapted Material that restrict exercise of the
300 | rights granted under the Adapter's License You apply.
301 |
302 |
303 | Section 4 -- Sui Generis Database Rights.
304 |
305 | Where the Licensed Rights include Sui Generis Database Rights that
306 | apply to Your use of the Licensed Material:
307 |
308 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
309 | to extract, reuse, reproduce, and Share all or a substantial
310 | portion of the contents of the database for NonCommercial purposes
311 | only;
312 |
313 | b. if You include all or a substantial portion of the database
314 | contents in a database in which You have Sui Generis Database
315 | Rights, then the database in which You have Sui Generis Database
316 | Rights (but not its individual contents) is Adapted Material,
317 | including for purposes of Section 3(b); and
318 |
319 | c. You must comply with the conditions in Section 3(a) if You Share
320 | all or a substantial portion of the contents of the database.
321 |
322 | For the avoidance of doubt, this Section 4 supplements and does not
323 | replace Your obligations under this Public License where the Licensed
324 | Rights include other Copyright and Similar Rights.
325 |
326 |
327 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
328 |
329 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
330 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
331 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
332 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
333 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
334 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
335 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
336 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
337 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
338 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
339 |
340 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
341 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
342 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
343 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
344 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
345 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
346 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
347 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
348 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
349 |
350 | c. The disclaimer of warranties and limitation of liability provided
351 | above shall be interpreted in a manner that, to the extent
352 | possible, most closely approximates an absolute disclaimer and
353 | waiver of all liability.
354 |
355 |
356 | Section 6 -- Term and Termination.
357 |
358 | a. This Public License applies for the term of the Copyright and
359 | Similar Rights licensed here. However, if You fail to comply with
360 | this Public License, then Your rights under this Public License
361 | terminate automatically.
362 |
363 | b. Where Your right to use the Licensed Material has terminated under
364 | Section 6(a), it reinstates:
365 |
366 | 1. automatically as of the date the violation is cured, provided
367 | it is cured within 30 days of Your discovery of the
368 | violation; or
369 |
370 | 2. upon express reinstatement by the Licensor.
371 |
372 | For the avoidance of doubt, this Section 6(b) does not affect any
373 | right the Licensor may have to seek remedies for Your violations
374 | of this Public License.
375 |
376 | c. For the avoidance of doubt, the Licensor may also offer the
377 | Licensed Material under separate terms or conditions or stop
378 | distributing the Licensed Material at any time; however, doing so
379 | will not terminate this Public License.
380 |
381 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
382 | License.
383 |
384 |
385 | Section 7 -- Other Terms and Conditions.
386 |
387 | a. The Licensor shall not be bound by any additional or different
388 | terms or conditions communicated by You unless expressly agreed.
389 |
390 | b. Any arrangements, understandings, or agreements regarding the
391 | Licensed Material not stated herein are separate from and
392 | independent of the terms and conditions of this Public License.
393 |
394 |
395 | Section 8 -- Interpretation.
396 |
397 | a. For the avoidance of doubt, this Public License does not, and
398 | shall not be interpreted to, reduce, limit, restrict, or impose
399 | conditions on any use of the Licensed Material that could lawfully
400 | be made without permission under this Public License.
401 |
402 | b. To the extent possible, if any provision of this Public License is
403 | deemed unenforceable, it shall be automatically reformed to the
404 | minimum extent necessary to make it enforceable. If the provision
405 | cannot be reformed, it shall be severed from this Public License
406 | without affecting the enforceability of the remaining terms and
407 | conditions.
408 |
409 | c. No term or condition of this Public License will be waived and no
410 | failure to comply consented to unless expressly agreed to by the
411 | Licensor.
412 |
413 | d. Nothing in this Public License constitutes or may be interpreted
414 | as a limitation upon, or waiver of, any privileges and immunities
415 | that apply to the Licensor or You, including from the legal
416 | processes of any jurisdiction or authority.
417 |
418 | =======================================================================
419 |
420 | Creative Commons is not a party to its public
421 | licenses. Notwithstanding, Creative Commons may elect to apply one of
422 | its public licenses to material it publishes and in those instances
423 | will be considered the “Licensor.” The text of the Creative Commons
424 | public licenses is dedicated to the public domain under the CC0 Public
425 | Domain Dedication. Except for the limited purpose of indicating that
426 | material is shared under a Creative Commons public license or as
427 | otherwise permitted by the Creative Commons policies published at
428 | creativecommons.org/policies, Creative Commons does not authorize the
429 | use of the trademark "Creative Commons" or any other trademark or logo
430 | of Creative Commons without its prior written consent including,
431 | without limitation, in connection with any unauthorized modifications
432 | to any of its public licenses or any other arrangements,
433 | understandings, or agreements concerning use of licensed material. For
434 | the avoidance of doubt, this paragraph does not form part of the
435 | public licenses.
436 |
437 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tiled Diffusion & VAE extension for sd-webui
2 |
3 | [![CC BY-NC-SA 4.0][cc-by-nc-sa-shield]][cc-by-nc-sa]
4 |
5 | This extension is licensed under [CC BY-NC-SA](https://creativecommons.org/licenses/by-nc-sa/4.0/): everyone is free to access, use, modify, and redistribute it under the same license.
6 | **Versions after AOE 2023.3.28 may not be used for commercial sale (this applies only to the code of this repo; derived artworks are NOT restricted).**
7 | 
8 | Because some unscrupulous sellers were bundling this extension with WebUI packages as a selling point and charging an "IQ tax" for it, the license of this repository has been changed to [CC BY-NC-SA](https://creativecommons.org/licenses/by-nc-sa/4.0/): anyone is free to obtain, use, modify, and redistribute this extension under the same license.
9 | **As of the date of the license change (AOE 2023.3.28), later versions may not be sold commercially (this covers only the code of this repo; derived artistic works are NOT restricted).**
10 |
11 | If you like the project, please give me a star! ⭐
12 |
13 | [](https://ko-fi.com/pkuliyi2015)
14 |
15 | ****
16 |
17 |
18 | The extension helps you **generate or upscale large images (≥2K) with limited VRAM (≤6GB)** via the following techniques (a toy sketch of the shared tile-fusion idea follows the list):
19 |
20 | - Reproduced SOTA Tiled Diffusion methods
21 | - [Mixture of Diffusers](https://github.com/albarji/mixture-of-diffusers)
22 | - [MultiDiffusion](https://multidiffusion.github.io)
23 |   - [DemoFusion](https://github.com/PRIS-CV/DemoFusion)
24 | - Our original Tiled VAE method
25 | - Our original Tiled Noise Inversion method
26 |
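
To give a flavor of how these tiled methods work under the hood, here is a toy sketch of the tile-fusion idea they share (illustrative only; the function and its names are ours, not this extension's API): denoise overlapping latent tiles independently, accumulate them on a canvas, and divide by the per-pixel overlap count so the seams average out.

```python
# Toy sketch of MultiDiffusion-style tile fusion (illustrative, not the real code).
# `denoise_tile` stands in for one UNet denoising call on a latent crop.
import torch

def fuse_tiles(latent, denoise_tile, tile=96, overlap=32):
    _, _, H, W = latent.shape                    # assumes H, W >= tile
    stride = tile - overlap
    ys = list(range(0, H - tile + 1, stride))
    xs = list(range(0, W - tile + 1, stride))
    if ys[-1] != H - tile: ys.append(H - tile)   # cover the bottom border
    if xs[-1] != W - tile: xs.append(W - tile)   # cover the right border
    out, weight = torch.zeros_like(latent), torch.zeros_like(latent)
    for y in ys:
        for x in xs:
            out[:, :, y:y+tile, x:x+tile] += denoise_tile(latent[:, :, y:y+tile, x:x+tile])
            weight[:, :, y:y+tile, x:x+tile] += 1
    return out / weight                          # overlaps are averaged -> no visible seams
```

Mixture of Diffusers follows the same recipe but blends tiles with Gaussian windows instead of a uniform average, which further suppresses seam artifacts.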
27 |
28 | ### Features
29 |
30 | - Core
31 | - [x] [Tiled VAE](#tiled-vae)
32 |   - [x] [Tiled Diffusion: txt2img generation for ultra-large images](#tiled-diff-txt2img)
33 | - [x] [Tiled Diffusion: img2img upscaling for image detail enhancement](#tiled-diff-img2img)
34 | - [x] [Regional Prompt Control](#region-prompt-control)
35 | - [x] [Tiled Noise Inversion](#tiled-noise-inversion)
36 | - Advanced
37 |   - [x] ControlNet support
38 |   - [x] [StableSR support](https://github.com/pkuliyi2015/sd-webui-stablesr)
39 |   - [x] SDXL support (experimental)
40 |   - [x] DemoFusion support
41 |
42 | 👉 Find detailed documentation & examples at our [wiki](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/wiki), and the quickstart [Tutorial](https://civitai.com/models/34726) by [@PotatoBananaApple](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/discussions/120) 🎉
43 | 
44 |
45 |
46 | ### Examples
47 |
48 | ⚪ Txt2img: generating ultra-large images
49 |
50 | `prompt: masterpiece, best quality, highres, city skyline, night.`
51 |
52 | 
53 |
54 | ⚪ Img2img: upscaling for detail enhancement
55 |
56 | | original | x4 upscale |
57 | | :-: | :-: |
58 | |  |  |
59 |
60 | ⚪ Regional Prompt Control
61 |
62 | | region setting | output1 | output2 |
63 | | :-: | :-: | :-: |
64 | |  |  |  |
65 | |  |  |  |
66 |
67 | ⚪ ControlNet support
68 |
69 | | original | with canny |
70 | | :-: | :-: |
71 | |  |  |
72 |
73 | | | Repainting "Along the River During the Qingming Festival" |
74 | | :-: | :-: |
75 | | original |  |
76 | | processed |  |
77 |
78 | ⚪ DemoFusion support
79 |
80 | | original | x3 upscale |
81 | | :-: | :-: |
82 | |  |  |
83 |
84 |
85 | ### License
86 |
87 | Great thanks to all the contributors! 🎉🎉🎉
88 | This work is licensed under [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License][cc-by-nc-sa].
89 |
90 | [![CC BY-NC-SA 4.0][cc-by-nc-sa-image]][cc-by-nc-sa]
91 | [![CC BY-NC-SA 4.0][cc-by-nc-sa-shield]][cc-by-nc-sa]
92 |
93 | [cc-by-nc-sa]: http://creativecommons.org/licenses/by-nc-sa/4.0/
94 | [cc-by-nc-sa-image]: https://licensebuttons.net/l/by-nc-sa/4.0/88x31.png
95 | [cc-by-nc-sa-shield]: https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg
96 |
--------------------------------------------------------------------------------
/javascript/bboxHint.js:
--------------------------------------------------------------------------------
1 | const BBOX_MAX_NUM = 16;
2 | const BBOX_WARNING_SIZE = 1280;
3 | const DEFAULT_X = 0.4;
4 | const DEFAULT_Y = 0.4;
5 | const DEFAULT_H = 0.2;
6 | const DEFAULT_W = 0.2;
7 |
8 | // ref: https://html-color.codes/
9 | const COLOR_MAP = [
10 | ['#ff0000', 'rgba(255, 0, 0, 0.3)'], // red
11 | ['#ff9900', 'rgba(255, 153, 0, 0.3)'], // orange
12 | ['#ffff00', 'rgba(255, 255, 0, 0.3)'], // yellow
13 | ['#33cc33', 'rgba(51, 204, 51, 0.3)'], // green
14 |     ['#33cccc', 'rgba(51, 204, 204, 0.3)'],    // cyan
15 | ['#0066ff', 'rgba(0, 102, 255, 0.3)'], // blue
16 | ['#6600ff', 'rgba(102, 0, 255, 0.3)'], // purple
17 | ['#cc00cc', 'rgba(204, 0, 204, 0.3)'], // dark pink
18 | ['#ff6666', 'rgba(255, 102, 102, 0.3)'], // light red
19 | ['#ffcc66', 'rgba(255, 204, 102, 0.3)'], // light orange
20 | ['#99cc00', 'rgba(153, 204, 0, 0.3)'], // lime green
21 | ['#00cc99', 'rgba(0, 204, 153, 0.3)'], // teal
22 | ['#0099cc', 'rgba(0, 153, 204, 0.3)'], // steel blue
23 | ['#9933cc', 'rgba(153, 51, 204, 0.3)'], // lavender
24 | ['#ff3399', 'rgba(255, 51, 153, 0.3)'], // hot pink
25 | ['#996633', 'rgba(153, 102, 51, 0.3)'], // brown
26 | ];
27 |
28 | const RESIZE_BORDER = 5;
29 | const MOVE_BORDER = 5;
30 |
31 | const t2i_bboxes = new Array(BBOX_MAX_NUM).fill(null);
32 | const i2i_bboxes = new Array(BBOX_MAX_NUM).fill(null);
33 |
34 | // ↓↓↓ called from gradio ↓↓↓
35 |
36 | function onCreateT2IRefClick(overwrite) {
37 | let width, height;
38 | if (overwrite) {
39 | const overwriteInputs = gradioApp().querySelectorAll('#MD-overwrite-width-t2i input, #MD-overwrite-height-t2i input');
40 | width = parseInt(overwriteInputs[0].value);
41 | height = parseInt(overwriteInputs[2].value);
42 | } else {
43 | const sizeInputs = gradioApp().querySelectorAll('#txt2img_width input, #txt2img_height input');
44 | width = parseInt(sizeInputs[0].value);
45 | height = parseInt(sizeInputs[2].value);
46 | }
47 |
48 | if (isNaN(width)) width = 512;
49 | if (isNaN(height)) height = 512;
50 |
51 | // Concat it to string to bypass the gradio bug
52 |     // (bowing our heads to the dark forces)
53 | return width.toString() + 'x' + height.toString();
54 | }
55 |
56 | function onCreateI2IRefClick() {
57 | const canvas = gradioApp().querySelector('#img2img_image img');
58 | return canvas.src;
59 | }
60 |
61 | function onBoxEnableClick(is_t2i, idx, enable) {
62 | let canvas = null;
63 | let bboxes = null;
64 | let locator = null;
65 | if (is_t2i) {
66 | locator = () => gradioApp().querySelector('#MD-bbox-ref-t2i');
67 | bboxes = t2i_bboxes;
68 | } else {
69 | locator = () => gradioApp().querySelector('#MD-bbox-ref-i2i');
70 | bboxes = i2i_bboxes;
71 | }
72 |     const ref_div = locator();
73 | canvas = ref_div.querySelector('img');
74 | if (!canvas) { return false; }
75 |
76 | if (enable) {
77 | // Check if the bounding box already exists
78 | if (!bboxes[idx]) {
79 | // Initialize bounding box
80 | const bbox = [DEFAULT_X, DEFAULT_Y, DEFAULT_W, DEFAULT_H];
81 | const colorMap = COLOR_MAP[idx % COLOR_MAP.length];
82 | const div = document.createElement('div');
83 | div.id = 'MD-bbox-' + (is_t2i ? 't2i-' : 'i2i-') + idx;
84 | div.style.left = '0px';
85 | div.style.top = '0px';
86 | div.style.width = '0px';
87 | div.style.height = '0px';
88 | div.style.position = 'absolute';
89 | div.style.border = '2px solid ' + colorMap[0];
90 | div.style.background = colorMap[1];
91 | div.style.zIndex = '900';
92 | div.style.display = 'none';
93 | // A text tip to warn the user if bbox is too large
94 | const tip = document.createElement('span');
95 | tip.id = 'MD-tip-' + (is_t2i ? 't2i-' : 'i2i-') + idx;
96 | tip.style.left = '50%';
97 | tip.style.top = '50%';
98 | tip.style.position = 'absolute';
99 | tip.style.transform = 'translate(-50%, -50%)';
100 | tip.style.fontSize = '12px';
101 | tip.style.fontWeight = 'bold';
102 | tip.style.textAlign = 'center';
103 | tip.style.color = colorMap[0];
104 | tip.style.zIndex = '901';
105 | tip.style.display = 'none';
106 |             tip.innerHTML = 'Warning: Region very large!<br>Take care of VRAM usage!';
107 | div.appendChild(tip);
108 |
109 | div.addEventListener('mousedown', function (e) {
110 | if (e.button === 0) { onBoxMouseDown(e, is_t2i, idx); }
111 | });
112 | div.addEventListener('mousemove', function (e) {
113 | updateCursorStyle(e, is_t2i, idx);
114 | });
115 |
116 | const shower = function() { // insert to DOM if necessary
117 | if (!gradioApp().querySelector('#' + div.id)) {
118 | locator().appendChild(div);
119 | }
120 | }
121 | bboxes[idx] = [div, bbox, shower];
122 | }
123 |
124 | // Show the bounding box
125 | displayBox(canvas, is_t2i, bboxes[idx]);
126 | return true;
127 | } else {
128 | if (!bboxes[idx]) { return false; }
129 | const [div, bbox, shower] = bboxes[idx];
130 | div.style.display = 'none';
131 | }
132 | return false;
133 | }
134 |
135 | function onBoxChange(is_t2i, idx, what, v) {
136 | // This function handles all the changes of the bounding box
137 | // Including the rendering and python slider update
138 | let bboxes = null;
139 | let canvas = null;
140 | if (is_t2i) {
141 | bboxes = t2i_bboxes;
142 | canvas = gradioApp().querySelector('#MD-bbox-ref-t2i img');
143 | } else {
144 | bboxes = i2i_bboxes;
145 | canvas = gradioApp().querySelector('#MD-bbox-ref-i2i img');
146 | }
147 | if (!bboxes[idx] || !canvas) {
148 | switch (what) {
149 | case 'x': return DEFAULT_X;
150 | case 'y': return DEFAULT_Y;
151 | case 'w': return DEFAULT_W;
152 | case 'h': return DEFAULT_H;
153 | }
154 | }
155 | const [div, bbox, shower] = bboxes[idx];
156 | if (div.style.display === 'none') { return v; }
157 |
158 | // parse trigger
159 | switch (what) {
160 | case 'x': bbox[0] = v; break;
161 | case 'y': bbox[1] = v; break;
162 | case 'w': bbox[2] = v; break;
163 | case 'h': bbox[3] = v; break;
164 | }
165 | displayBox(canvas, is_t2i, bboxes[idx]);
166 | return v;
167 | }
168 |
169 | // ↓↓↓ called from js ↓↓↓
170 |
171 | function getSeedInfo(is_t2i, id, current_seed) {
172 | const info_id = is_t2i ? '#html_info_txt2img' : '#html_info_img2img';
173 | const info_div = gradioApp().querySelector(info_id);
174 |     // parseInt never throws; it yields NaN on bad input
175 |     current_seed = parseInt(current_seed);
176 |     if (isNaN(current_seed)) {
177 |         current_seed = -1;
178 |     }
179 | if (!info_div) return current_seed;
180 | let info = info_div.innerHTML;
181 | if (!info) return current_seed;
182 | // remove all html tags
183 | info = info.replace(/<[^>]*>/g, '');
184 |     // Find the 'Region control' json string in the info
185 |     // get its index
186 |     let idx = info.indexOf('Region control');
187 | if (idx == -1) return current_seed;
188 | // get the json string (detect the bracket)
189 | // find the first '{'
190 | let start_idx = info.indexOf('{', idx);
191 | let bracket = 1;
192 | let end_idx = start_idx + 1;
193 | while (bracket > 0 && end_idx < info.length) {
194 | if (info[end_idx] == '{') bracket++;
195 | if (info[end_idx] == '}') bracket--;
196 | end_idx++;
197 | }
198 | if (bracket > 0) {
199 | return current_seed;
200 | }
201 | // get the json string
202 | let json_str = info.substring(start_idx, end_idx);
203 | // replace the single quote to double quote
204 | json_str = json_str.replace(/'/g, '"');
205 |     // replace python True/False with javascript true/false
206 |     json_str = json_str.replace(/True/g, 'true').replace(/False/g, 'false');
207 | // parse the json string
208 | let json = JSON.parse(json_str);
209 | // get the seed if the region id is in the json
210 | const region_id = 'Region ' + id.toString();
211 | if (!(region_id in json)) return current_seed;
212 | const region = json[region_id];
213 | if (!('seed' in region)) return current_seed;
214 | let seed = region['seed'];
215 |     seed = parseInt(seed);
216 |     // guard against NaN (parseInt does not throw)
217 |     if (isNaN(seed)) {
218 |         return current_seed;
219 |     }
220 | return seed;
221 | }
222 |
223 | function displayBox(canvas, is_t2i, bbox_info) {
224 | // check null input
225 | const [div, bbox, shower] = bbox_info;
226 | const [x, y, w, h] = bbox;
227 | if (!canvas || !div || x == null || y == null || w == null || h == null) { return; }
228 |
229 | // client: canvas widget display size
230 | // natural: content image real size
231 | let vpScale = Math.min(canvas.clientWidth / canvas.naturalWidth, canvas.clientHeight / canvas.naturalHeight);
232 | let canvasCenterX = canvas.clientWidth / 2;
233 | let canvasCenterY = canvas.clientHeight / 2;
234 | let scaledX = canvas.naturalWidth * vpScale;
235 | let scaledY = canvas.naturalHeight * vpScale;
236 | let viewRectLeft = canvasCenterX - scaledX / 2;
237 | let viewRectRight = canvasCenterX + scaledX / 2;
238 | let viewRectTop = canvasCenterY - scaledY / 2;
239 | let viewRectDown = canvasCenterY + scaledY / 2;
240 |
241 | let xDiv = viewRectLeft + scaledX * x;
242 | let yDiv = viewRectTop + scaledY * y;
243 | let wDiv = Math.min(scaledX * w, viewRectRight - xDiv);
244 | let hDiv = Math.min(scaledY * h, viewRectDown - yDiv);
245 |
246 | // Calculate warning bbox size
247 | let upscalerFactor = 1.0;
248 | if (!is_t2i) {
249 | const upscalerInput = parseFloat(gradioApp().querySelector('#MD-i2i-upscaler-factor input').value);
250 | if (!isNaN(upscalerInput)) upscalerFactor = upscalerInput;
251 | }
252 | let maxSize = BBOX_WARNING_SIZE / upscalerFactor * vpScale;
253 | let maxW = maxSize / scaledX;
254 | let maxH = maxSize / scaledY;
255 | if (w > maxW || h > maxH) {
256 | div.querySelector('span').style.display = 'block';
257 | } else {
258 | div.querySelector('span').style.display = 'none';
259 | }
260 |
261 | // update
Please test on small images before actual upscale. Default params require denoise <= 0.6
') 143 | with gr.Row(variant='compact'): 144 | noise_inverse_retouch = gr.Slider(minimum=1, maximum=100, step=0.1, label='Retouch', value=1, elem_id=uid('noise-inverse-retouch')) 145 | noise_inverse_renoise_strength = gr.Slider(minimum=0, maximum=2, step=0.01, label='Renoise strength', value=1, elem_id=uid('noise-inverse-renoise-strength')) 146 | noise_inverse_renoise_kernel = gr.Slider(minimum=2, maximum=512, step=1, label='Renoise kernel size', value=64, elem_id=uid('noise-inverse-renoise-kernel')) 147 | 148 | # The control includes txt2img and img2img, we use t2i and i2i to distinguish them 149 | with gr.Group(elem_id=f'MD-bbox-control-{tab}') as tab_bbox: 150 | with gr.Accordion('Region Prompt Control', open=False): 151 | with gr.Row(variant='compact'): 152 | enable_bbox_control = gr.Checkbox(label='Enable Control', value=False, elem_id=uid('enable-bbox-control')) 153 | draw_background = gr.Checkbox(label='Draw full canvas background', value=False, elem_id=uid('draw-background')) 154 | causal_layers = gr.Checkbox(label='Causalize layers', value=False, visible=False, elem_id='MD-causal-layers') # NOTE: currently not used 155 | 156 | with gr.Row(variant='compact'): 157 | create_button = gr.Button(value="Create txt2img canvas" if not is_img2img else "From img2img", elem_id='MD-create-canvas') 158 | 159 | bbox_controls: List[Component] = [] # control set for each bbox 160 | with gr.Row(variant='compact'): 161 | ref_image = gr.Image(label='Ref image (for conviently locate regions)', image_mode=None, elem_id=f'MD-bbox-ref-{tab}', interactive=True) 162 | if not is_img2img: 163 | # gradio has a serious bug: it cannot accept multiple inputs when you use both js and fn. 164 | # to workaround this, we concat the inputs into a single string and parse it in js 165 | def create_t2i_ref(string): 166 | w, h = [int(x) for x in string.split('x')] 167 | w = max(w, opt_f) 168 | h = max(h, opt_f) 169 | return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255 170 | create_button.click( 171 | fn=create_t2i_ref, 172 | inputs=overwrite_size, 173 | outputs=ref_image, 174 | _js='onCreateT2IRefClick', 175 | show_progress=False) 176 | else: 177 | create_button.click(fn=None, outputs=ref_image, _js='onCreateI2IRefClick', show_progress=False) 178 | 179 | with gr.Row(variant='compact'): 180 | cfg_name = gr.Textbox(label='Custom Config File', value='config.json', elem_id=uid('cfg-name')) 181 | cfg_dump = gr.Button(value='💾 Save', variant='tool') 182 | cfg_load = gr.Button(value='⚙️ Load', variant='tool') 183 | 184 | with gr.Row(variant='compact'): 185 | cfg_tip = gr.HTML(value='', visible=False) 186 | 187 | for i in range(BBOX_MAX_NUM): 188 | # Only when displaying & png generate info we use index i+1, in other cases we use i 189 | with gr.Accordion(f'Region {i+1}', open=False, elem_id=f'MD-accordion-{tab}-{i}'): 190 | with gr.Row(variant='compact'): 191 | e = gr.Checkbox(label=f'Enable Region {i+1}', value=False, elem_id=f'MD-bbox-{tab}-{i}-enable') 192 | e.change(fn=None, inputs=e, outputs=e, _js=f'e => onBoxEnableClick({is_t2i}, {i}, e)', show_progress=False) 193 | 194 | blend_mode = gr.Dropdown(label='Type', choices=[e.value for e in BlendMode], value=BlendMode.BACKGROUND.value, elem_id=f'MD-{tab}-{i}-blend-mode') 195 | feather_ratio = gr.Slider(label='Feather', value=0.2, minimum=0, maximum=1, step=0.05, visible=False, elem_id=f'MD-{tab}-{i}-feather') 196 | 197 | blend_mode.change(fn=lambda x: gr_show(x==BlendMode.FOREGROUND.value), inputs=blend_mode, outputs=feather_ratio, show_progress=False) 198 | 199 | with 
gr.Row(variant='compact'): 200 | x = gr.Slider(label='x', value=0.4, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-{tab}-{i}-x') 201 | y = gr.Slider(label='y', value=0.4, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-{tab}-{i}-y') 202 | 203 | with gr.Row(variant='compact'): 204 | w = gr.Slider(label='w', value=0.2, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-{tab}-{i}-w') 205 | h = gr.Slider(label='h', value=0.2, minimum=0.0, maximum=1.0, step=0.0001, elem_id=f'MD-{tab}-{i}-h') 206 | 207 | x.change(fn=None, inputs=x, outputs=x, _js=f'v => onBoxChange({is_t2i}, {i}, "x", v)', show_progress=False) 208 | y.change(fn=None, inputs=y, outputs=y, _js=f'v => onBoxChange({is_t2i}, {i}, "y", v)', show_progress=False) 209 | w.change(fn=None, inputs=w, outputs=w, _js=f'v => onBoxChange({is_t2i}, {i}, "w", v)', show_progress=False) 210 | h.change(fn=None, inputs=h, outputs=h, _js=f'v => onBoxChange({is_t2i}, {i}, "h", v)', show_progress=False) 211 | 212 | prompt = gr.Text(show_label=False, placeholder=f'Prompt, will append to your {tab} prompt', max_lines=2, elem_id=f'MD-{tab}-{i}-prompt') 213 | neg_prompt = gr.Text(show_label=False, placeholder='Negative Prompt, will also be appended', max_lines=1, elem_id=f'MD-{tab}-{i}-neg-prompt') 214 | with gr.Row(variant='compact'): 215 | seed = gr.Number(label='Seed', value=-1, visible=True, elem_id=f'MD-{tab}-{i}-seed') 216 | random_seed = gr.Button(value='🎲', variant='tool', elem_id=f'MD-{tab}-{i}-random_seed') 217 | reuse_seed = gr.Button(value='♻️', variant='tool', elem_id=f'MD-{tab}-{i}-reuse_seed') 218 | random_seed.click(fn=lambda: -1, outputs=seed, show_progress=False) 219 | reuse_seed.click(fn=None, inputs=seed, outputs=seed, _js=f'e => getSeedInfo({is_t2i}, {i+1}, e)', show_progress=False) 220 | 221 | control = [e, x, y, w, h, prompt, neg_prompt, blend_mode, feather_ratio, seed] 222 | assert len(control) == NUM_BBOX_PARAMS 223 | bbox_controls.extend(control) 224 | 225 | # NOTE: dynamically hard coded!! 226 | load_regions_js = ''' 227 | function onBoxChangeAll(ref_image, cfg_name, ...args) { 228 | const is_t2i = %s; 229 | const n_bbox = %d; 230 | const n_ctrl = %d; 231 | for (let i=0; iPlease test on small images before actual upscale. Default params require denoise <= 0.6
') 108 | with gr.Row(variant='compact'): 109 | noise_inverse_retouch = gr.Slider(minimum=1, maximum=100, step=0.1, label='Retouch', value=1, elem_id=uid('noise-inverse-retouch')) 110 | noise_inverse_renoise_strength = gr.Slider(minimum=0, maximum=2, step=0.01, label='Renoise strength', value=1, elem_id=uid('noise-inverse-renoise-strength')) 111 | noise_inverse_renoise_kernel = gr.Slider(minimum=2, maximum=512, step=1, label='Renoise kernel size', value=64, elem_id=uid('noise-inverse-renoise-kernel')) 112 | 113 | # The control includes txt2img and img2img, we use t2i and i2i to distinguish them 114 | 115 | return [ 116 | enabled, method, 117 | keep_input_size, 118 | window_size, overlap, batch_size, 119 | scale_factor, 120 | noise_inverse, noise_inverse_steps, noise_inverse_retouch, noise_inverse_renoise_strength, noise_inverse_renoise_kernel, 121 | control_tensor_cpu, 122 | random_jitter, 123 | c1,c2,c3,gaussian_filter,strength,sigma,batch_size_g,mixture_mode 124 | ] 125 | 126 | 127 | def process(self, p: Processing, 128 | enabled: bool, method: str, 129 | keep_input_size: bool, 130 | window_size:int, overlap: int, tile_batch_size: int, 131 | scale_factor: float, 132 | noise_inverse: bool, noise_inverse_steps: int, noise_inverse_retouch: float, noise_inverse_renoise_strength: float, noise_inverse_renoise_kernel: int, 133 | control_tensor_cpu: bool, 134 | random_jitter:bool, 135 | c1,c2,c3,gaussian_filter,strength,sigma,batch_size_g,mixture_mode 136 | ): 137 | 138 | # unhijack & unhook, in case it broke at last time 139 | self.reset() 140 | p.mixture = mixture_mode 141 | if not mixture_mode: 142 | sigma = sigma/2 143 | if not enabled: return 144 | 145 | ''' upscale ''' 146 | # store canvas size settings 147 | if hasattr(p, "init_images"): 148 | p.init_images_original_md = [img.copy() for img in p.init_images] 149 | p.width_original_md = p.width 150 | p.height_original_md = p.height 151 | p.current_scale_num = 1 152 | p.gaussian_filter = gaussian_filter 153 | p.scale_factor = int(scale_factor) 154 | 155 | is_img2img = hasattr(p, "init_images") and len(p.init_images) > 0 156 | if is_img2img: 157 | init_img = p.init_images[0] 158 | init_img = images.flatten(init_img, opts.img2img_background_color) 159 | image = init_img 160 | if keep_input_size: 161 | p.width = image.width 162 | p.height = image.height 163 | p.width_original_md = p.width 164 | p.height_original_md = p.height 165 | else: #XXX:To adapt to noise inversion, we do not multiply the scale factor here 166 | p.width = p.width_original_md 167 | p.height = p.height_original_md 168 | else: # txt2img 169 | p.width = p.width_original_md 170 | p.height = p.height_original_md 171 | 172 | if 'png info': 173 | info = {} 174 | p.extra_generation_params["Tiled Diffusion"] = info 175 | 176 | info['Method'] = method 177 | info['Window Size'] = window_size 178 | info['Tile Overlap'] = overlap 179 | info['Tile batch size'] = tile_batch_size 180 | info["Global batch size"] = batch_size_g 181 | 182 | if is_img2img: 183 | info['Upscale factor'] = scale_factor 184 | if keep_input_size: 185 | info['Keep input size'] = keep_input_size 186 | if noise_inverse: 187 | info['NoiseInv'] = noise_inverse 188 | info['NoiseInv Steps'] = noise_inverse_steps 189 | info['NoiseInv Retouch'] = noise_inverse_retouch 190 | info['NoiseInv Renoise strength'] = noise_inverse_renoise_strength 191 | info['NoiseInv Kernel size'] = noise_inverse_renoise_kernel 192 | 193 | ''' ControlNet hackin ''' 194 | try: 195 | from scripts.cldm import ControlNet 196 | 197 | for script in 
p.scripts.scripts + p.scripts.alwayson_scripts: 198 | if hasattr(script, "latest_network") and script.title().lower() == "controlnet": 199 | self.controlnet_script = script 200 | print("[Demo Fusion] ControlNet found, support is enabled.") 201 | break 202 | except ImportError: 203 | pass 204 | 205 | ''' StableSR hackin ''' 206 | for script in p.scripts.scripts: 207 | if hasattr(script, "stablesr_model") and script.title().lower() == "stablesr": 208 | if script.stablesr_model is not None: 209 | self.stablesr_script = script 210 | print("[Demo Fusion] StableSR found, support is enabled.") 211 | break 212 | 213 | ''' hijack inner APIs, see unhijack in reset() ''' 214 | Script.create_sampler_original_md = sd_samplers.create_sampler 215 | 216 | sd_samplers.create_sampler = lambda name, model: self.create_sampler_hijack( 217 | name, model, p, Method_2(method), control_tensor_cpu,window_size, noise_inverse, noise_inverse_steps, noise_inverse_retouch, 218 | noise_inverse_renoise_strength, noise_inverse_renoise_kernel, overlap, tile_batch_size,random_jitter,batch_size_g 219 | ) 220 | 221 | 222 | p.sample = lambda conditioning, unconditional_conditioning,seeds, subseeds, subseed_strength, prompts: self.sample_hijack( 223 | conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts,p, is_img2img, 224 | window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength,sigma,batch_size_g) 225 | 226 | processing.create_infotext_ori = processing.create_infotext 227 | 228 | p.width_list = [p.height] 229 | p.height_list = [p.height] 230 | 231 | processing.create_infotext = create_infotext_hijack 232 | ## end 233 | 234 | 235 | def postprocess_batch(self, p: Processing, enabled, *args, **kwargs): 236 | if not enabled: return 237 | 238 | if self.delegate is not None: self.delegate.reset_controlnet_tensors() 239 | 240 | def postprocess_batch_list(self, p, pp, enabled, *args, **kwargs): 241 | if not enabled: return 242 | for idx,image in enumerate(pp.images): 243 | idx_b = idx//p.batch_size 244 | pp.images[idx] = image[:,:image.shape[1]//(p.scale_factor)*(idx_b+1),:image.shape[2]//(p.scale_factor)*(idx_b+1)] 245 | p.seeds = [item for _ in range(p.scale_factor) for item in p.seeds] 246 | p.prompts = [item for _ in range(p.scale_factor) for item in p.prompts] 247 | p.all_negative_prompts = [item for _ in range(p.scale_factor) for item in p.all_negative_prompts] 248 | p.negative_prompts = [item for _ in range(p.scale_factor) for item in p.negative_prompts] 249 | if p.color_corrections != None: 250 | p.color_corrections = [item for _ in range(p.scale_factor) for item in p.color_corrections] 251 | p.width_list = [item*(idx+1) for idx in range(p.scale_factor) for item in [p.width for _ in range(p.batch_size)]] 252 | p.height_list = [item*(idx+1) for idx in range(p.scale_factor) for item in [p.height for _ in range(p.batch_size)]] 253 | return 254 | 255 | def postprocess(self, p: Processing, processed, enabled, *args): 256 | if not enabled: return 257 | # unhijack & unhook 258 | self.reset() 259 | 260 | # restore canvas size settings 261 | if hasattr(p, 'init_images') and hasattr(p, 'init_images_original_md'): 262 | p.init_images.clear() # NOTE: do NOT change the list object, compatible with shallow copy of XYZ-plot 263 | p.init_images.extend(p.init_images_original_md) 264 | del p.init_images_original_md 265 | p.width = p.width_original_md ; del p.width_original_md 266 | p.height = p.height_original_md ; del p.height_original_md 267 | 268 | # clean up noise inverse latent for folder-based 
processing 269 | if hasattr(p, 'noise_inverse_latent'): 270 | del p.noise_inverse_latent 271 | 272 | ''' ↓↓↓ inner API hijack ↓↓↓ ''' 273 | @torch.no_grad() 274 | def sample_hijack(self, conditioning, unconditional_conditioning,seeds, subseeds, subseed_strength, prompts,p,image_ori,window_size, overlap, tile_batch_size,random_jitter,c1,c2,c3,strength,sigma,batch_size_g): 275 | ################################################## Phase Initialization ###################################################### 276 | 277 | if not image_ori: 278 | p.current_step = 0 279 | p.denoising_strength = strength 280 | # p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model) #NOTE:Wrong but very useful. If corrected, please replace with the content with the following lines 281 | # latents = p.rng.next() 282 | 283 | p.sampler = Script.create_sampler_original_md(p.sampler_name, p.sd_model) #scale 284 | x = p.rng.next() 285 | print("### Phase 1 Denoising ###") 286 | latents = p.sampler.sample(p, x, conditioning, unconditional_conditioning, image_conditioning=p.txt2img_image_conditioning(x)) 287 | latents_ = F.pad(latents, (0, latents.shape[3]*(p.scale_factor-1), 0, latents.shape[2]*(p.scale_factor-1))) 288 | res = latents_ 289 | del x 290 | p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model) 291 | starting_scale = 2 292 | else: # img2img 293 | print("### Encoding Real Image ###") 294 | latents = p.init_latent 295 | starting_scale = 1 296 | 297 | 298 | anchor_mean = latents.mean() 299 | anchor_std = latents.std() 300 | 301 | devices.torch_gc() 302 | 303 | ####################################################### Phase Upscaling ##################################################### 304 | p.cosine_scale_1 = c1 305 | p.cosine_scale_2 = c2 306 | p.cosine_scale_3 = c3 307 | self.delegate.sig = sigma 308 | p.latents = latents 309 | for current_scale_num in range(starting_scale, p.scale_factor+1): 310 | p.current_scale_num = current_scale_num 311 | print("### Phase {} Denoising ###".format(current_scale_num)) 312 | p.current_height = p.height_original_md * current_scale_num 313 | p.current_width = p.width_original_md * current_scale_num 314 | 315 | 316 | p.latents = F.interpolate(p.latents, size=(int(p.current_height / opt_f), int(p.current_width / opt_f)), mode='bicubic') 317 | p.rng = rng.ImageRNG(p.latents.shape[1:], p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w) 318 | 319 | 320 | self.delegate.w = int(p.current_width / opt_f) 321 | self.delegate.h = int(p.current_height / opt_f) 322 | self.delegate.get_views(overlap, tile_batch_size,batch_size_g) 323 | 324 | info = ', '.join([ 325 | # f"{method.value} hooked into {name!r} sampler", 326 | f"Tile size: {self.delegate.window_size}", 327 | f"Tile count: {self.delegate.num_tiles}", 328 | f"Batch size: {self.delegate.tile_bs}", 329 | f"Tile batches: {len(self.delegate.batched_bboxes)}", 330 | f"Global batch size: {self.delegate.global_tile_bs}", 331 | f"Global batches: {len(self.delegate.global_batched_bboxes)}", 332 | ]) 333 | 334 | print(info) 335 | 336 | noise = p.rng.next() 337 | if hasattr(p,'initial_noise_multiplier'): 338 | if p.initial_noise_multiplier != 1.0: 339 | p.extra_generation_params["Noise multiplier"] = p.initial_noise_multiplier 340 | noise *= p.initial_noise_multiplier 341 | else: 342 | p.image_conditioning = p.txt2img_image_conditioning(noise) 343 | 344 | p.noise = noise 345 | p.x = p.latents.clone() 346 | p.current_step=0 347 
| 348 | p.latents = p.sampler.sample_img2img(p,p.latents, noise , conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning) 349 | if self.flag_noise_inverse: 350 | self.delegate.sampler_raw.sample_img2img = self.delegate.sample_img2img_original 351 | self.flag_noise_inverse = False 352 | 353 | p.latents = (p.latents - p.latents.mean()) / p.latents.std() * anchor_std + anchor_mean 354 | latents_ = F.pad(p.latents, (0, p.latents.shape[3]//current_scale_num*(p.scale_factor-current_scale_num), 0, p.latents.shape[2]//current_scale_num*(p.scale_factor-current_scale_num))) 355 | if current_scale_num==1: 356 | res = latents_ 357 | else: 358 | res = torch.concatenate((res,latents_),axis=0) 359 | 360 | ######################################################################################################################################### 361 | 362 | return res 363 | 364 | @staticmethod 365 | def callback_hijack(self_sampler,d,p): 366 | p.current_step = d['i'] 367 | 368 | if self_sampler.stop_at is not None and p.current_step > self_sampler.stop_at: 369 | raise InterruptedException 370 | 371 | state.sampling_step = p.current_step 372 | shared.total_tqdm.update() 373 | p.current_step += 1 374 | 375 | 376 | def create_sampler_hijack( 377 | self, name: str, model: LatentDiffusion, p: Processing, method: Method_2, control_tensor_cpu:bool,window_size, noise_inverse: bool, noise_inverse_steps: int, noise_inverse_retouch:float, 378 | noise_inverse_renoise_strength: float, noise_inverse_renoise_kernel: int, overlap:int, tile_batch_size:int, random_jitter:bool,batch_size_g:int 379 | ): 380 | if self.delegate is not None: 381 | # samplers are stateless, we reuse it if possible 382 | if self.delegate.sampler_name == name: 383 | # before we reuse the sampler, we refresh the control tensor 384 | # so that we are compatible with ControlNet batch processing 385 | if self.controlnet_script: 386 | self.delegate.prepare_controlnet_tensors(refresh=True) 387 | return self.delegate.sampler_raw 388 | else: 389 | self.reset() 390 | sd_samplers_common.Sampler.callback_ori = sd_samplers_common.Sampler.callback_state 391 | sd_samplers_common.Sampler.callback_state = lambda self_sampler,d:Script.callback_hijack(self_sampler,d,p) 392 | 393 | self.flag_noise_inverse = hasattr(p, "init_images") and len(p.init_images) > 0 and noise_inverse 394 | flag_noise_inverse = self.flag_noise_inverse 395 | if flag_noise_inverse: 396 | print('warn: noise inversion only supports the "Euler" sampler, switch to it sliently...') 397 | name = 'Euler' 398 | p.sampler_name = 'Euler' 399 | if name is None: print('>> name is empty') 400 | if model is None: print('>> model is empty') 401 | sampler = Script.create_sampler_original_md(name, model) 402 | if method ==Method_2.DEMO_FU: delegate_cls = DemoFusion 403 | else: raise NotImplementedError(f"Method {method} not implemented.") 404 | 405 | delegate = delegate_cls(p, sampler) 406 | delegate.window_size = min(min(window_size,p.width//8),p.height//8) 407 | p.random_jitter = random_jitter 408 | 409 | if flag_noise_inverse: 410 | get_cache_callback = self.noise_inverse_get_cache 411 | set_cache_callback = lambda x0, xt, prompts: self.noise_inverse_set_cache(p, x0, xt, prompts, noise_inverse_steps, noise_inverse_retouch) 412 | delegate.init_noise_inverse(noise_inverse_steps, noise_inverse_retouch, get_cache_callback, set_cache_callback, noise_inverse_renoise_strength, noise_inverse_renoise_kernel) 413 | 414 | # delegate.get_views(overlap,tile_batch_size,batch_size_g) 415 | if 
self.controlnet_script: 416 | delegate.init_controlnet(self.controlnet_script, control_tensor_cpu) 417 | if self.stablesr_script: 418 | delegate.init_stablesr(self.stablesr_script) 419 | 420 | # init everything done, perform sanity check & pre-computations 421 | # hijack the behaviours 422 | delegate.hook() 423 | 424 | self.delegate = delegate 425 | 426 | exts = [ 427 | "ContrlNet" if self.controlnet_script else None, 428 | "StableSR" if self.stablesr_script else None, 429 | ] 430 | ext_info = ', '.join([e for e in exts if e]) 431 | if ext_info: ext_info = f' (ext: {ext_info})' 432 | print(ext_info) 433 | 434 | return delegate.sampler_raw 435 | 436 | def create_random_tensors_hijack( 437 | self, bbox_settings: Dict, region_info: Dict, 438 | shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None, 439 | ): 440 | org_random_tensors = Script.create_random_tensors_original_md(shape, seeds, subseeds, subseed_strength, seed_resize_from_h, seed_resize_from_w, p) 441 | height, width = shape[1], shape[2] 442 | background_noise = torch.zeros_like(org_random_tensors) 443 | background_noise_count = torch.zeros((1, 1, height, width), device=org_random_tensors.device) 444 | foreground_noise = torch.zeros_like(org_random_tensors) 445 | foreground_noise_count = torch.zeros((1, 1, height, width), device=org_random_tensors.device) 446 | 447 | for i, v in bbox_settings.items(): 448 | seed = get_fixed_seed(v.seed) 449 | x, y, w, h = v.x, v.y, v.w, v.h 450 | # convert to pixel 451 | x = int(x * width) 452 | y = int(y * height) 453 | w = math.ceil(w * width) 454 | h = math.ceil(h * height) 455 | # clamp 456 | x = max(0, x) 457 | y = max(0, y) 458 | w = min(width - x, w) 459 | h = min(height - y, h) 460 | # create random tensor 461 | torch.manual_seed(seed) 462 | rand_tensor = torch.randn((1, org_random_tensors.shape[1], h, w), device=devices.cpu) 463 | if BlendMode(v.blend_mode) == BlendMode.BACKGROUND: 464 | background_noise [:, :, y:y+h, x:x+w] += rand_tensor.to(background_noise.device) 465 | background_noise_count[:, :, y:y+h, x:x+w] += 1 466 | elif BlendMode(v.blend_mode) == BlendMode.FOREGROUND: 467 | foreground_noise [:, :, y:y+h, x:x+w] += rand_tensor.to(foreground_noise.device) 468 | foreground_noise_count[:, :, y:y+h, x:x+w] += 1 469 | else: 470 | raise NotImplementedError 471 | region_info['Region ' + str(i+1)]['seed'] = seed 472 | 473 | # average 474 | background_noise = torch.where(background_noise_count > 1, background_noise / background_noise_count, background_noise) 475 | foreground_noise = torch.where(foreground_noise_count > 1, foreground_noise / foreground_noise_count, foreground_noise) 476 | # paste two layers to original random tensor 477 | org_random_tensors = torch.where(background_noise_count > 0, background_noise, org_random_tensors) 478 | org_random_tensors = torch.where(foreground_noise_count > 0, foreground_noise, org_random_tensors) 479 | return org_random_tensors 480 | 481 | ''' ↓↓↓ helper methods ↓↓↓ ''' 482 | 483 | def dump_regions(self, cfg_name, *bbox_controls): 484 | if not cfg_name: return gr_value(f'Config file name cannot be empty.', visible=True) 485 | 486 | bbox_settings = build_bbox_settings(bbox_controls) 487 | data = {'bbox_controls': [v._asdict() for v in bbox_settings.values()]} 488 | 489 | if not os.path.exists(CFG_PATH): os.makedirs(CFG_PATH) 490 | fp = os.path.join(CFG_PATH, cfg_name) 491 | with open(fp, 'w', encoding='utf-8') as fh: 492 | json.dump(data, fh, indent=2, ensure_ascii=False) 493 | 494 | return 
gr_value(f'Config saved to {fp}.', visible=True) 495 | 496 | def load_regions(self, ref_image, cfg_name, *bbox_controls): 497 | if ref_image is None: 498 | return [gr_value(v) for v in bbox_controls] + [gr_value(f'Please create or upload a ref image first.', visible=True)] 499 | fp = os.path.join(CFG_PATH, cfg_name) 500 | if not os.path.exists(fp): 501 | return [gr_value(v) for v in bbox_controls] + [gr_value(f'Config {fp} not found.', visible=True)] 502 | 503 | try: 504 | with open(fp, 'r', encoding='utf-8') as fh: 505 | data = json.load(fh) 506 | except Exception as e: 507 | return [gr_value(v) for v in bbox_controls] + [gr_value(f'Failed to load config {fp}: {e}', visible=True)] 508 | 509 | num_boxes = len(data['bbox_controls']) 510 | data_list = [] 511 | for i in range(BBOX_MAX_NUM): 512 | if i < num_boxes: 513 | for k in BBoxSettings._fields: 514 | if k in data['bbox_controls'][i]: 515 | data_list.append(data['bbox_controls'][i][k]) 516 | else: 517 | data_list.append(None) 518 | else: 519 | data_list.extend(DEFAULT_BBOX_SETTINGS) 520 | 521 | return [gr_value(v) for v in data_list] + [gr_value(f'Config loaded from {fp}.', visible=True)] 522 | 523 | 524 | def noise_inverse_set_cache(self, p: ProcessingImg2Img, x0: Tensor, xt: Tensor, prompts: List[str], steps: int, retouch:float): 525 | self.noise_inverse_cache = NoiseInverseCache(p.sd_model.sd_model_hash, x0, xt, steps, retouch, prompts) 526 | 527 | def noise_inverse_get_cache(self): 528 | return self.noise_inverse_cache 529 | 530 | 531 | def reset(self): 532 | ''' unhijack inner APIs, see hijack in process() ''' 533 | if hasattr(Script, "create_sampler_original_md"): 534 | sd_samplers.create_sampler = Script.create_sampler_original_md 535 | del Script.create_sampler_original_md 536 | if hasattr(Script, "create_random_tensors_original_md"): 537 | processing.create_random_tensors = Script.create_random_tensors_original_md 538 | del Script.create_random_tensors_original_md 539 | if hasattr(sd_samplers_common.Sampler, "callback_ori"): 540 | sd_samplers_common.Sampler.callback_state = sd_samplers_common.Sampler.callback_ori 541 | del sd_samplers_common.Sampler.callback_ori 542 | if hasattr(processing, "create_infotext_ori"): 543 | processing.create_infotext = processing.create_infotext_ori 544 | del processing.create_infotext_ori 545 | DemoFusion.unhook() 546 | self.delegate = None 547 | 548 | def reset_and_gc(self): 549 | self.reset() 550 | self.noise_inverse_cache = None 551 | 552 | import gc; gc.collect() 553 | devices.torch_gc() 554 | 555 | try: 556 | import os 557 | import psutil 558 | mem = psutil.Process(os.getpid()).memory_info() 559 | print(f'[Mem] rss: {mem.rss/2**30:.3f} GB, vms: {mem.vms/2**30:.3f} GB') 560 | from modules.shared import mem_mon as vram_mon 561 | from modules.memmon import MemUsageMonitor 562 | vram_mon: MemUsageMonitor 563 | free, total = vram_mon.cuda_mem_get_info() 564 | print(f'[VRAM] free: {free/2**30:.3f} GB, total: {total/2**30:.3f} GB') 565 | except: 566 | pass 567 | -------------------------------------------------------------------------------- /scripts/tilevae.py: -------------------------------------------------------------------------------- 1 | ''' 2 | # ------------------------------------------------------------------------ 3 | # 4 | # Tiled VAE 5 | # 6 | # Introducing a revolutionary new optimization designed to make 7 | # the VAE work with giant images on limited VRAM! 8 | # Say goodbye to the frustration of OOM and hello to seamless output! 
9 | #
10 | #   ------------------------------------------------------------------------
11 | #
12 | #   This script is a wild hack that splits the image into tiles,
13 | #   encodes each tile separately, and merges the result back together.
14 | #
15 | #   Advantages:
16 | #   - The VAE can now work with giant images on limited VRAM
17 | #       (~10 GB for 8K images!)
18 | #   - The merged output is completely seamless without any post-processing.
19 | #
20 | #   Drawbacks:
21 | #   - NaNs always appear for 8k images when you use the fp16 (half) VAE.
22 | #       You must use --no-half-vae to disable the half VAE for such giant images.
23 | #   - The gradient calculation is not compatible with this hack. It
24 | #       will break any backward() or torch.autograd.grad() that passes through the VAE.
25 | #       (But you can still use the VAE to generate training data.)
26 | #
27 | #   How it works:
28 | #   1. The image is split into tiles, which are then padded with 11/32 pixels in the decoder/encoder.
29 | #   2. When Fast Mode is disabled:
30 | #       1. The original VAE forward is decomposed into a task queue and a task worker, which starts to process each tile.
31 | #       2. When GroupNorm is needed, it suspends, stores the current GroupNorm mean and var, sends everything to RAM, and turns to the next tile.
32 | #       3. After all GroupNorm means and vars are summarized, it applies group norm to the tiles and continues.
33 | #       4. A zigzag execution order is used to reduce unnecessary data transfer.
34 | #   3. When Fast Mode is enabled:
35 | #       1. The original input is downsampled and passed to a separate task queue.
36 | #       2. Its group norm parameters are recorded and used by all tiles' task queues.
37 | #       3. Each tile is separately processed without any RAM-VRAM data transfer.
38 | #   4. After all tiles are processed, tiles are written to a result buffer and returned.
39 | #   Encoder color fix = only estimate GroupNorm before downsampling, i.e., run in a semi-fast mode.
40 | #
41 | #   Enjoy!
42 | #
43 | #   @Author: LI YI @ Nanyang Technological University - Singapore
44 | #   @Date: 2023-03-02
45 | #   @License: CC BY-NC-SA 4.0
46 | #
47 | #   Please give me a star if you like this project!
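#
#   (Added illustration) A minimal sketch of the task-queue idea described above,
#   assuming a hypothetical `run_tile` helper (not a name used in this file):
#   a queue is just a list of (name, callable) pairs, executed one tile at a
#   time, and a 'pre_norm' task suspends the tile until every tile has reported
#   its GroupNorm statistics.
#
#       queue = [('conv_in', net.conv_in), ('pre_norm', norm), ('silu', act), ...]
#       def run_tile(tile, queue):
#           while queue:
#               name, fn = queue.pop(0)
#               if name == 'pre_norm':
#                   return 'suspended', tile    # resume once global stats are ready
#               tile = fn(tile)
#           return 'done', tile
#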
48 | # 49 | # ------------------------------------------------------------------------- 50 | ''' 51 | 52 | import gc 53 | import math 54 | from time import time 55 | from tqdm import tqdm 56 | 57 | import torch 58 | import torch.version 59 | import torch.nn.functional as F 60 | import gradio as gr 61 | 62 | import modules.scripts as scripts 63 | import modules.devices as devices 64 | from modules.shared import state, opts 65 | from modules.ui import gr_show 66 | from modules.processing import opt_f 67 | from modules.sd_vae_approx import cheap_approximation 68 | from ldm.modules.diffusionmodules.model import AttnBlock, MemoryEfficientAttnBlock 69 | 70 | from tile_utils.attn import get_attn_func 71 | from tile_utils.typing import Processing 72 | 73 | if hasattr(opts, 'hypertile_enable_unet'): # webui >= 1.7 74 | from modules.ui_components import InputAccordion 75 | else: 76 | InputAccordion = None 77 | 78 | 79 | def get_rcmd_enc_tsize(): 80 | if torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]: 81 | total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20 82 | if total_memory > 16*1000: ENCODER_TILE_SIZE = 3072 83 | elif total_memory > 12*1000: ENCODER_TILE_SIZE = 2048 84 | elif total_memory > 8*1000: ENCODER_TILE_SIZE = 1536 85 | else: ENCODER_TILE_SIZE = 960 86 | else: ENCODER_TILE_SIZE = 512 87 | return ENCODER_TILE_SIZE 88 | 89 | 90 | def get_rcmd_dec_tsize(): 91 | if torch.cuda.is_available() and devices.device not in ['cpu', devices.cpu]: 92 | total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20 93 | if total_memory > 30*1000: DECODER_TILE_SIZE = 256 94 | elif total_memory > 16*1000: DECODER_TILE_SIZE = 192 95 | elif total_memory > 12*1000: DECODER_TILE_SIZE = 128 96 | elif total_memory > 8*1000: DECODER_TILE_SIZE = 96 97 | else: DECODER_TILE_SIZE = 64 98 | else: DECODER_TILE_SIZE = 64 99 | return DECODER_TILE_SIZE 100 | 101 | 102 | def inplace_nonlinearity(x): 103 | # Test: fix for Nans 104 | return F.silu(x, inplace=True) 105 | 106 | 107 | def attn2task(task_queue, net): 108 | attn_forward = get_attn_func() 109 | task_queue.append(('store_res', lambda x: x)) 110 | task_queue.append(('pre_norm', net.norm)) 111 | task_queue.append(('attn', lambda x, net=net: attn_forward(net, x))) 112 | task_queue.append(['add_res', None]) 113 | 114 | 115 | def resblock2task(queue, block): 116 | """ 117 | Turn a ResNetBlock into a sequence of tasks and append to the task queue 118 | 119 | @param queue: the target task queue 120 | @param block: ResNetBlock 121 | 122 | """ 123 | if block.in_channels != block.out_channels: 124 | if block.use_conv_shortcut: 125 | queue.append(('store_res', block.conv_shortcut)) 126 | else: 127 | queue.append(('store_res', block.nin_shortcut)) 128 | else: 129 | queue.append(('store_res', lambda x: x)) 130 | queue.append(('pre_norm', block.norm1)) 131 | queue.append(('silu', inplace_nonlinearity)) 132 | queue.append(('conv1', block.conv1)) 133 | queue.append(('pre_norm', block.norm2)) 134 | queue.append(('silu', inplace_nonlinearity)) 135 | queue.append(('conv2', block.conv2)) 136 | queue.append(['add_res', None]) 137 | 138 | 139 | def build_sampling(task_queue, net, is_decoder): 140 | """ 141 | Build the sampling part of a task queue 142 | @param task_queue: the target task queue 143 | @param net: the network 144 | @param is_decoder: currently building decoder or encoder 145 | """ 146 | if is_decoder: 147 | resblock2task(task_queue, net.mid.block_1) 148 | attn2task(task_queue, 
net.mid.attn_1) 149 | resblock2task(task_queue, net.mid.block_2) 150 | resolution_iter = reversed(range(net.num_resolutions)) 151 | block_ids = net.num_res_blocks + 1 152 | condition = 0 153 | module = net.up 154 | func_name = 'upsample' 155 | else: 156 | resolution_iter = range(net.num_resolutions) 157 | block_ids = net.num_res_blocks 158 | condition = net.num_resolutions - 1 159 | module = net.down 160 | func_name = 'downsample' 161 | 162 | for i_level in resolution_iter: 163 | for i_block in range(block_ids): 164 | resblock2task(task_queue, module[i_level].block[i_block]) 165 | if i_level != condition: 166 | task_queue.append((func_name, getattr(module[i_level], func_name))) 167 | 168 | if not is_decoder: 169 | resblock2task(task_queue, net.mid.block_1) 170 | attn2task(task_queue, net.mid.attn_1) 171 | resblock2task(task_queue, net.mid.block_2) 172 | 173 | 174 | def build_task_queue(net, is_decoder): 175 | """ 176 | Build a single task queue for the encoder or decoder 177 | @param net: the VAE decoder or encoder network 178 | @param is_decoder: currently building decoder or encoder 179 | @return: the task queue 180 | """ 181 | task_queue = [] 182 | task_queue.append(('conv_in', net.conv_in)) 183 | 184 | # construct the sampling part of the task queue 185 | # because encoder and decoder share the same architecture, we extract the sampling part 186 | build_sampling(task_queue, net, is_decoder) 187 | 188 | if not is_decoder or not net.give_pre_end: 189 | task_queue.append(('pre_norm', net.norm_out)) 190 | task_queue.append(('silu', inplace_nonlinearity)) 191 | task_queue.append(('conv_out', net.conv_out)) 192 | if is_decoder and net.tanh_out: 193 | task_queue.append(('tanh', torch.tanh)) 194 | 195 | return task_queue 196 | 197 | 198 | def clone_task_queue(task_queue): 199 | """ 200 | Clone a task queue 201 | @param task_queue: the task queue to be cloned 202 | @return: the cloned task queue 203 | """ 204 | return [[item for item in task] for task in task_queue] 205 | 206 | 207 | def get_var_mean(input, num_groups, eps=1e-6): 208 | """ 209 | Get mean and var for group norm 210 | """ 211 | b, c = input.size(0), input.size(1) 212 | channel_in_group = int(c/num_groups) 213 | input_reshaped = input.contiguous().view(1, int(b * num_groups), channel_in_group, *input.size()[2:]) 214 | var, mean = torch.var_mean(input_reshaped, dim=[0, 2, 3, 4], unbiased=False) 215 | return var, mean 216 | 217 | 218 | def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6): 219 | """ 220 | Custom group norm with fixed mean and var 221 | 222 | @param input: input tensor 223 | @param num_groups: number of groups. 
by default, num_groups = 32
224 |     @param mean: mean, must be pre-calculated by get_var_mean
225 |     @param var: var, must be pre-calculated by get_var_mean
226 |     @param weight: weight, should be fetched from the original group norm
227 |     @param bias: bias, should be fetched from the original group norm
228 |     @param eps: epsilon, by default, eps = 1e-6 to match the original group norm
229 | 
230 |     @return: normalized tensor
231 |     """
232 |     b, c = input.size(0), input.size(1)
233 |     channel_in_group = int(c/num_groups)
234 |     input_reshaped = input.contiguous().view(
235 |         1, int(b * num_groups), channel_in_group, *input.size()[2:])
236 | 
237 |     out = F.batch_norm(input_reshaped, mean.to(input), var.to(input), weight=None, bias=None, training=False, momentum=0, eps=eps)
238 |     out = out.view(b, c, *input.size()[2:])
239 | 
240 |     # post affine transform
241 |     if weight is not None:
242 |         out *= weight.view(1, -1, 1, 1)
243 |     if bias is not None:
244 |         out += bias.view(1, -1, 1, 1)
245 |     return out
246 | 
247 | 
248 | def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
249 |     """
250 |     Crop the valid region from the tile
251 |     @param x: input tile
252 |     @param input_bbox: original input bounding box
253 |     @param target_bbox: output bounding box
254 |     @param is_decoder: whether the tile comes from the decoder (bboxes are then scaled to pixel space, else to latent space)
255 |     @return: cropped tile
256 |     """
257 |     padded_bbox = [i * 8 if is_decoder else i//8 for i in input_bbox]
258 |     margin = [target_bbox[i] - padded_bbox[i] for i in range(4)]
259 |     return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]]
260 | 
261 | 
262 | # ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓
263 | 
264 | def perfcount(fn):
265 |     def wrapper(*args, **kwargs):
266 |         ts = time()
267 | 
268 |         if torch.cuda.is_available():
269 |             torch.cuda.reset_peak_memory_stats(devices.device)
270 |         devices.torch_gc()
271 |         gc.collect()
272 | 
273 |         ret = fn(*args, **kwargs)
274 | 
275 |         devices.torch_gc()
276 |         gc.collect()
277 |         if torch.cuda.is_available():
278 |             vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
279 |             print(f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
280 |         else:
281 |             print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')
282 | 
283 |         return ret
284 |     return wrapper
285 | 
286 | # ↑↑↑ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↑↑↑
287 | 
288 | 
289 | class GroupNormParam:
290 | 
291 |     def __init__(self):
292 |         self.var_list = []
293 |         self.mean_list = []
294 |         self.pixel_list = []
295 |         self.weight = None
296 |         self.bias = None
297 | 
298 |     def add_tile(self, tile, layer):
299 |         var, mean = get_var_mean(tile, 32)
300 |         # For giant images, the variance can be larger than max float16.
301 |         # In this case we create a copy in float32.
302 |         if var.dtype == torch.float16 and var.isinf().any():
303 |             fp32_tile = tile.float()
304 |             var, mean = get_var_mean(fp32_tile, 32)
305 |             # ============= DEBUG: test for infinite =============
306 |             # if torch.isinf(var).any():
307 |             #     print('var: ', var)
308 |             # ====================================================
309 |         self.var_list.append(var)
310 |         self.mean_list.append(mean)
311 |         self.pixel_list.append(
312 |             tile.shape[2]*tile.shape[3])
313 |         if hasattr(layer, 'weight'):
314 |             self.weight = layer.weight
315 |             self.bias = layer.bias
316 |         else:
317 |             self.weight = None
318 |             self.bias = None
319 | 
320 |     def summary(self):
321 |         """
322 |         summarize the mean and var and return a function
323 |         that applies group norm on each tile
324 |         """
325 |         if len(self.var_list) == 0: return
None 326 | 327 | var = torch.vstack(self.var_list) 328 | mean = torch.vstack(self.mean_list) 329 | max_value = max(self.pixel_list) 330 | pixels = torch.tensor(self.pixel_list, dtype=torch.float32, device=devices.device) / max_value 331 | sum_pixels = torch.sum(pixels) 332 | pixels = pixels.unsqueeze(1) / sum_pixels 333 | var = torch.sum(var * pixels, dim=0) 334 | mean = torch.sum(mean * pixels, dim=0) 335 | return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias) 336 | 337 | @staticmethod 338 | def from_tile(tile, norm): 339 | """ 340 | create a function from a single tile without summary 341 | """ 342 | var, mean = get_var_mean(tile, 32) 343 | if var.dtype == torch.float16 and var.isinf().any(): 344 | fp32_tile = tile.float() 345 | var, mean = get_var_mean(fp32_tile, 32) 346 | # if it is a macbook, we need to convert back to float16 347 | if var.device.type == 'mps': 348 | # clamp to avoid overflow 349 | var = torch.clamp(var, 0, 60000) 350 | var = var.half() 351 | mean = mean.half() 352 | if hasattr(norm, 'weight'): 353 | weight = norm.weight 354 | bias = norm.bias 355 | else: 356 | weight = None 357 | bias = None 358 | 359 | def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias): 360 | return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6) 361 | return group_norm_func 362 | 363 | 364 | class VAEHook: 365 | 366 | def __init__(self, net, tile_size, is_decoder:bool, fast_decoder:bool, fast_encoder:bool, color_fix:bool, to_gpu:bool=False): 367 | self.net = net # encoder | decoder 368 | self.tile_size = tile_size 369 | self.is_decoder = is_decoder 370 | self.fast_mode = (fast_encoder and not is_decoder) or (fast_decoder and is_decoder) 371 | self.color_fix = color_fix and not is_decoder 372 | self.to_gpu = to_gpu 373 | self.pad = 11 if is_decoder else 32 # FIXME: magic number 374 | 375 | def __call__(self, x): 376 | original_device = next(self.net.parameters()).device 377 | try: 378 | if self.to_gpu: 379 | self.net = self.net.to(devices.get_optimal_device()) 380 | 381 | B, C, H, W = x.shape 382 | if max(H, W) <= self.pad * 2 + self.tile_size: 383 | print("[Tiled VAE]: the input size is tiny and unnecessary to tile.") 384 | return self.net.original_forward(x) 385 | else: 386 | return self.vae_tile_forward(x) 387 | finally: 388 | self.net = self.net.to(original_device) 389 | 390 | def get_best_tile_size(self, lowerbound, upperbound): 391 | """ 392 | Get the best tile size for GPU memory 393 | """ 394 | divider = 32 395 | while divider >= 2: 396 | remainer = lowerbound % divider 397 | if remainer == 0: 398 | return lowerbound 399 | candidate = lowerbound - remainer + divider 400 | if candidate <= upperbound: 401 | return candidate 402 | divider //= 2 403 | return lowerbound 404 | 405 | def split_tiles(self, h, w): 406 | """ 407 | Tool function to split the image into tiles 408 | @param h: height of the image 409 | @param w: width of the image 410 | @return: tile_input_bboxes, tile_output_bboxes 411 | """ 412 | tile_input_bboxes, tile_output_bboxes = [], [] 413 | tile_size = self.tile_size 414 | pad = self.pad 415 | num_height_tiles = math.ceil((h - 2 * pad) / tile_size) 416 | num_width_tiles = math.ceil((w - 2 * pad) / tile_size) 417 | # If any of the numbers are 0, we let it be 1 418 | # This is to deal with long and thin images 419 | num_height_tiles = max(num_height_tiles, 1) 420 | num_width_tiles = max(num_width_tiles, 1) 421 | 422 | # Suggestions from https://github.com/Kahsolt: auto shrink the tile size 423 | real_tile_height = math.ceil((h - 2 * 
pad) / num_height_tiles)
424 |         real_tile_width  = math.ceil((w - 2 * pad) / num_width_tiles)
425 |         real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
426 |         real_tile_width  = self.get_best_tile_size(real_tile_width,  tile_size)
427 | 
428 |         print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
429 |               f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')
430 | 
431 |         for i in range(num_height_tiles):
432 |             for j in range(num_width_tiles):
433 |                 # bbox: [x1, x2, y1, y2]
434 |                 # the padding is unnecessary at the image borders, so we directly start from (pad, pad)
435 |                 input_bbox = [
436 |                     pad + j * real_tile_width,
437 |                     min(pad + (j + 1) * real_tile_width, w),
438 |                     pad + i * real_tile_height,
439 |                     min(pad + (i + 1) * real_tile_height, h),
440 |                 ]
441 | 
442 |                 # if the output bbox is close to the image boundary, we extend it to the image boundary
443 |                 output_bbox = [
444 |                     input_bbox[0] if input_bbox[0] > pad else 0,
445 |                     input_bbox[1] if input_bbox[1] < w - pad else w,
446 |                     input_bbox[2] if input_bbox[2] > pad else 0,
447 |                     input_bbox[3] if input_bbox[3] < h - pad else h,
448 |                 ]
449 | 
450 |                 # scale to get the final output bbox
451 |                 output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]
452 |                 tile_output_bboxes.append(output_bbox)
453 | 
454 |                 # unconditionally expand the input bbox by pad pixels
455 |                 tile_input_bboxes.append([
456 |                     max(0, input_bbox[0] - pad),
457 |                     min(w, input_bbox[1] + pad),
458 |                     max(0, input_bbox[2] - pad),
459 |                     min(h, input_bbox[3] + pad),
460 |                 ])
461 | 
462 |         return tile_input_bboxes, tile_output_bboxes
463 | 
464 |     @torch.no_grad()
465 |     def estimate_group_norm(self, z, task_queue, color_fix):
466 |         device = z.device
467 |         tile = z
468 |         last_id = len(task_queue) - 1
469 |         while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
470 |             last_id -= 1
471 |         if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
472 |             raise ValueError('No group norm found in the task queue')
473 |         # estimate until the last group norm
474 |         for i in range(last_id + 1):
475 |             task = task_queue[i]
476 |             if task[0] == 'pre_norm':
477 |                 group_norm_func = GroupNormParam.from_tile(tile, task[1])
478 |                 task_queue[i] = ('apply_norm', group_norm_func)
479 |                 if i == last_id:
480 |                     return True
481 |                 tile = group_norm_func(tile)
482 |             elif task[0] == 'store_res':
483 |                 task_id = i + 1
484 |                 while task_id < last_id and task_queue[task_id][0] != 'add_res':
485 |                     task_id += 1
486 |                 if task_id >= last_id:
487 |                     continue
488 |                 task_queue[task_id][1] = task[1](tile)
489 |             elif task[0] == 'add_res':
490 |                 tile += task[1].to(device)
491 |                 task[1] = None
492 |             elif color_fix and task[0] == 'downsample':
493 |                 for j in range(i, last_id + 1):
494 |                     if task_queue[j][0] == 'store_res':
495 |                         task_queue[j] = ('store_res_cpu', task_queue[j][1])
496 |                 return True
497 |             else:
498 |                 tile = task[1](tile)
499 |                 try:
500 |                     devices.test_for_nans(tile, "vae")
501 |                 except:
502 |                     print('NaN detected in fast mode estimation. Fast mode disabled.')
503 |                     return False
504 | 
505 |         raise IndexError('Should not reach here')
506 | 
507 |     @perfcount
508 |     @torch.no_grad()
509 |     def vae_tile_forward(self, z):
510 |         """
511 |         Decode a latent vector z into an image in a tiled manner.
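        (Note: the same worker also runs the encoder path; in that case z is an image tensor and the result is a latent.)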
512 |         @param z: latent vector
513 |         @return: image
514 |         """
515 |         device = next(self.net.parameters()).device
516 |         dtype = next(self.net.parameters()).dtype
517 |         net = self.net
518 |         tile_size = self.tile_size
519 |         is_decoder = self.is_decoder
520 | 
521 |         z = z.detach() # detach the input to avoid backprop
522 | 
523 |         N, height, width = z.shape[0], z.shape[2], z.shape[3]
524 |         net.last_z_shape = z.shape
525 | 
526 |         # Split the input into tiles and build a task queue for each tile
527 |         print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')
528 | 
529 |         in_bboxes, out_bboxes = self.split_tiles(height, width)
530 | 
531 |         # Prepare tiles by splitting the input latents
532 |         tiles = []
533 |         for input_bbox in in_bboxes:
534 |             tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu()
535 |             tiles.append(tile)
536 | 
537 |         num_tiles = len(tiles)
538 |         num_completed = 0
539 | 
540 |         # Build task queues
541 |         single_task_queue = build_task_queue(net, is_decoder)
542 |         if self.fast_mode:
543 |             # Fast mode: downsample the input image to the tile size,
544 |             # then estimate the group norm parameters on the downsampled image
545 |             scale_factor = tile_size / max(height, width)
546 |             z = z.to(device)
547 |             downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')
548 |             # use nearest-exact to keep the statistics as close as possible
549 |             print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')
550 | 
551 |             # ======= Special thanks to @Kahsolt for the distribution shift issue ======= #
552 |             # The downsampling heavily distorts the mean and std, so we need to recover them.
553 |             std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
554 |             std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
555 |             downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
556 |             del std_old, mean_old, std_new, mean_new
557 |             # occasionally std_new is too small or too large, which exceeds the range of float16,
558 |             # so we need to clamp the result back into z's range.
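            # (added note) float16 overflows above ~65504 and loses precision near zero;
            # if std_new underflows, the rescale above can blow up, so clamping back into
            # z's observed min/max keeps the estimate finite with minimal statistical change.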
559 | downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max()) 560 | estimate_task_queue = clone_task_queue(single_task_queue) 561 | if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix): 562 | single_task_queue = estimate_task_queue 563 | del downsampled_z 564 | 565 | task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)] 566 | 567 | # Dummy result 568 | result = None 569 | result_approx = None 570 | try: 571 | with devices.autocast(): 572 | result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu() 573 | except: pass 574 | # Free memory of input latent tensor 575 | del z 576 | 577 | # Task queue execution 578 | pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ") 579 | 580 | # execute the task back and forth when switch tiles so that we always 581 | # keep one tile on the GPU to reduce unnecessary data transfer 582 | forward = True 583 | interrupted = False 584 | #state.interrupted = interrupted 585 | while True: 586 | if state.interrupted: interrupted = True ; break 587 | 588 | group_norm_param = GroupNormParam() 589 | for i in range(num_tiles) if forward else reversed(range(num_tiles)): 590 | if state.interrupted: interrupted = True ; break 591 | 592 | tile = tiles[i].to(device) 593 | input_bbox = in_bboxes[i] 594 | task_queue = task_queues[i] 595 | 596 | interrupted = False 597 | while len(task_queue) > 0: 598 | if state.interrupted: interrupted = True ; break 599 | 600 | # DEBUG: current task 601 | # print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape) 602 | task = task_queue.pop(0) 603 | if task[0] == 'pre_norm': 604 | group_norm_param.add_tile(tile, task[1]) 605 | break 606 | elif task[0] == 'store_res' or task[0] == 'store_res_cpu': 607 | task_id = 0 608 | res = task[1](tile) 609 | if not self.fast_mode or task[0] == 'store_res_cpu': 610 | res = res.cpu() 611 | while task_queue[task_id][0] != 'add_res': 612 | task_id += 1 613 | task_queue[task_id][1] = res 614 | elif task[0] == 'add_res': 615 | tile += task[1].to(device) 616 | task[1] = None 617 | else: 618 | tile = task[1](tile) 619 | pbar.update(1) 620 | 621 | if interrupted: break 622 | 623 | # check for NaNs in the tile. 
624 |                 # If there are NaNs, we abort the process to save the user's time
625 |                 devices.test_for_nans(tile, "vae")
626 | 
627 |                 if len(task_queue) == 0:
628 |                     tiles[i] = None
629 |                     num_completed += 1
630 |                     if result is None:      # NOTE: dim C varies across cases, so it can only be inited dynamically
631 |                         result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)
632 |                     result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
633 |                     del tile
634 |                 elif i == num_tiles - 1 and forward:
635 |                     forward = False
636 |                     tiles[i] = tile
637 |                 elif i == 0 and not forward:
638 |                     forward = True
639 |                     tiles[i] = tile
640 |                 else:
641 |                     tiles[i] = tile.cpu()
642 |                     del tile
643 | 
644 |             if interrupted: break
645 |             if num_completed == num_tiles: break
646 | 
647 |             # insert the group norm task at the head of each task queue
648 |             group_norm_func = group_norm_param.summary()
649 |             if group_norm_func is not None:
650 |                 for i in range(num_tiles):
651 |                     task_queue = task_queues[i]
652 |                     task_queue.insert(0, ('apply_norm', group_norm_func))
653 | 
654 |         # Done!
655 |         pbar.close()
656 |         return result.to(dtype) if result is not None else result_approx.to(device, dtype=dtype)
657 | 
658 | 
659 | class Script(scripts.Script):
660 | 
661 |     def __init__(self):
662 |         self.hooked = False
663 | 
664 |     def title(self):
665 |         return "Tiled VAE"
666 | 
667 |     def show(self, is_img2img):
668 |         return scripts.AlwaysVisible
669 | 
670 |     def ui(self, is_img2img):
671 |         tab = 't2i' if not is_img2img else 'i2i'
672 |         uid = lambda name: f'MD-{tab}-{name}'
673 | 
674 |         with (
675 |             InputAccordion(False, label='Tiled VAE', elem_id=f'MDV-{tab}-enabled') if InputAccordion
676 |             else gr.Accordion('Tiled VAE', open=False, elem_id=f'MDV-{tab}')
677 |             as enabled
678 |         ):
679 |             with gr.Row() as tab_enable:
680 |                 if not InputAccordion:
681 |                     enabled = gr.Checkbox(label='Enable Tiled VAE', value=False, elem_id=uid('enable'))
682 |                 vae_to_gpu = gr.Checkbox(label='Move VAE to GPU (if possible)', value=True, elem_id=uid('vae2gpu'))
683 | 
684 |             gr.HTML('Recommended: set the tile sizes as large as possible; shrink them only if you hit "CUDA out of memory" errors.
') 685 | with gr.Row() as tab_size: 686 | encoder_tile_size = gr.Slider(label='Encoder Tile Size', minimum=256, maximum=4096, step=16, value=get_rcmd_enc_tsize(), elem_id=uid('enc-size')) 687 | decoder_tile_size = gr.Slider(label='Decoder Tile Size', minimum=48, maximum=512, step=16, value=get_rcmd_dec_tsize(), elem_id=uid('dec-size')) 688 | reset = gr.Button(value='↻ Reset', variant='tool') 689 | reset.click(fn=lambda: [get_rcmd_enc_tsize(), get_rcmd_dec_tsize()], outputs=[encoder_tile_size, decoder_tile_size], show_progress=False) 690 | 691 | with gr.Row() as tab_param: 692 | fast_encoder = gr.Checkbox(label='Fast Encoder', value=True, elem_id=uid('fastenc')) 693 | color_fix = gr.Checkbox(label='Fast Encoder Color Fix', value=False, visible=True, elem_id=uid('fastenc-colorfix')) 694 | fast_decoder = gr.Checkbox(label='Fast Decoder', value=True, elem_id=uid('fastdec')) 695 | 696 | fast_encoder.change(fn=gr_show, inputs=fast_encoder, outputs=color_fix, show_progress=False) 697 | 698 | return [ 699 | enabled, 700 | encoder_tile_size, decoder_tile_size, 701 | vae_to_gpu, fast_decoder, fast_encoder, color_fix, 702 | ] 703 | 704 | def process(self, p:Processing, 705 | enabled:bool, 706 | encoder_tile_size:int, decoder_tile_size:int, 707 | vae_to_gpu:bool, fast_decoder:bool, fast_encoder:bool, color_fix:bool 708 | ): 709 | 710 | # for shorthand 711 | vae = p.sd_model.first_stage_model 712 | encoder = vae.encoder 713 | decoder = vae.decoder 714 | 715 | # undo hijack if disabled (in cases last time crashed) 716 | if not enabled: 717 | if self.hooked: 718 | if isinstance(encoder.forward, VAEHook): 719 | encoder.forward.net = None 720 | encoder.forward = encoder.original_forward 721 | if isinstance(decoder.forward, VAEHook): 722 | decoder.forward.net = None 723 | decoder.forward = decoder.original_forward 724 | self.hooked = False 725 | return 726 | 727 | if devices.get_optimal_device_name().startswith('cuda') and vae.device == devices.cpu and not vae_to_gpu: 728 | print("[Tiled VAE] warn: VAE is not on GPU, check 'Move VAE to GPU' if possible.") 729 | 730 | # do hijack 731 | kwargs = { 732 | 'fast_decoder': fast_decoder, 733 | 'fast_encoder': fast_encoder, 734 | 'color_fix': color_fix, 735 | 'to_gpu': vae_to_gpu, 736 | } 737 | 738 | # save original forward (only once) 739 | if not hasattr(encoder, 'original_forward'): setattr(encoder, 'original_forward', encoder.forward) 740 | if not hasattr(decoder, 'original_forward'): setattr(decoder, 'original_forward', decoder.forward) 741 | 742 | self.hooked = True 743 | 744 | encoder.forward = VAEHook(encoder, encoder_tile_size, is_decoder=False, **kwargs) 745 | decoder.forward = VAEHook(decoder, decoder_tile_size, is_decoder=True, **kwargs) 746 | 747 | def postprocess(self, p:Processing, processed, enabled:bool, *args): 748 | if not enabled: return 749 | 750 | vae = p.sd_model.first_stage_model 751 | encoder = vae.encoder 752 | decoder = vae.decoder 753 | if isinstance(encoder.forward, VAEHook): 754 | encoder.forward.net = None 755 | encoder.forward = encoder.original_forward 756 | if isinstance(decoder.forward, VAEHook): 757 | decoder.forward.net = None 758 | decoder.forward = decoder.original_forward 759 | -------------------------------------------------------------------------------- /tile_methods/demofusion.py: -------------------------------------------------------------------------------- 1 | from tile_methods.abstractdiffusion import AbstractDiffusion 2 | from tile_utils.utils import * 3 | import torch.nn.functional as F 4 | import random 5 | from 
copy import deepcopy
6 | import inspect
7 | from modules import sd_samplers_common
8 | 
9 | 
10 | class DemoFusion(AbstractDiffusion):
11 |     """
12 |     DemoFusion Implementation
13 |     https://arxiv.org/abs/2311.16973
14 |     """
15 | 
16 |     def __init__(self, p:Processing, *args, **kwargs):
17 |         super().__init__(p, *args, **kwargs)
18 |         assert p.sampler_name != 'UniPC', 'DemoFusion is not compatible with UniPC!'
19 | 
20 | 
21 |     def hook(self):
22 |         steps, self.t_enc = sd_samplers_common.setup_img2img_steps(self.p, None)
23 | 
24 |         self.sampler.model_wrap_cfg.forward_ori = self.sampler.model_wrap_cfg.forward
25 |         self.sampler_forward = self.sampler.model_wrap_cfg.inner_model.forward
26 |         self.sampler.model_wrap_cfg.forward = self.forward_one_step
27 |         if self.is_kdiff:
28 |             self.sampler: KDiffusionSampler
29 |             self.sampler.model_wrap_cfg: CFGDenoiserKDiffusion
30 |             self.sampler.model_wrap_cfg.inner_model: Union[CompVisDenoiser, CompVisVDenoiser]
31 |         else:
32 |             self.sampler: CompVisSampler
33 |             self.sampler.model_wrap_cfg: CFGDenoiserTimesteps
34 |             self.sampler.model_wrap_cfg.inner_model: Union[CompVisTimestepsDenoiser, CompVisTimestepsVDenoiser]
35 |         self.timesteps = self.sampler.get_timesteps(self.p, steps)
36 | 
37 |     @staticmethod
38 |     def unhook():
39 |         if hasattr(shared.sd_model, 'apply_model_ori'):
40 |             shared.sd_model.apply_model = shared.sd_model.apply_model_ori
41 |             del shared.sd_model.apply_model_ori
42 | 
43 |     def reset_buffer(self, x_in:Tensor):
44 |         super().reset_buffer(x_in)
45 | 
46 | 
47 | 
48 |     def repeat_tensor(self, x:Tensor, n:int) -> Tensor:
49 |         ''' repeat the tensor on its first dim '''
50 |         if n == 1: return x
51 |         B = x.shape[0]
52 |         r_dims = len(x.shape) - 1
53 |         if B == 1:      # batch_size = 1 (not `tile_batch_size`)
54 |             shape = [n] + [-1] * r_dims     # [N, -1, ...]
55 |             return x.expand(shape)          # `expand` is much lighter than `tile`
56 |         else:
57 |             shape = [n] + [1] * r_dims      # [N, 1, ...]
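            # (added note) unlike `expand`, `repeat` materializes real copies;
            # with B > 1 each batch item must be duplicated, so a broadcast view is not enough here.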
58 |             return x.repeat(shape)
59 | 
60 |     def repeat_cond_dict(self, cond_in:CondDict, bboxes, mode) -> CondDict:
61 |         ''' repeat all tensors in cond_dict on their first dim (for a batch of tiles), returns a new object '''
62 |         # n_repeat
63 |         n_rep = len(bboxes)
64 |         # txt cond
65 |         tcond = self.get_tcond(cond_in)             # [B=1, L, D] => [B*N, L, D]
66 |         tcond = self.repeat_tensor(tcond, n_rep)
67 |         # img cond
68 |         icond = self.get_icond(cond_in)
69 |         if icond.shape[2:] == (self.h, self.w):     # img2img, [B=1, C, H, W]
70 |             if mode == 0:
71 |                 if self.p.random_jitter:
72 |                     jitter_range = self.jitter_range
73 |                     icond = F.pad(icond, (jitter_range, jitter_range, jitter_range, jitter_range), 'constant', value=0)
74 |                 icond = torch.cat([icond[bbox.slicer] for bbox in bboxes], dim=0)
75 |             else:
76 |                 icond = torch.cat([icond[:, :, bbox[1]::self.p.current_scale_num, bbox[0]::self.p.current_scale_num] for bbox in bboxes], dim=0)
77 |         else:                                       # txt2img, [B=1, C=5, H=1, W=1]
78 |             icond = self.repeat_tensor(icond, n_rep)
79 | 
80 |         # vec cond (SDXL)
81 |         vcond = self.get_vcond(cond_in)             # [B=1, D]
82 |         if vcond is not None:
83 |             vcond = self.repeat_tensor(vcond, n_rep)    # [B*N, D]
84 |         return self.make_cond_dict(cond_in, tcond, icond, vcond)
85 | 
86 | 
87 |     def global_split_bboxes(self):
88 |         cols = self.p.current_scale_num
89 |         rows = cols
90 | 
91 |         bbox_list = []
92 |         for row in range(rows):
93 |             y = row
94 |             for col in range(cols):
95 |                 x = col
96 |                 bbox = (x, y)
97 |                 bbox_list.append(bbox)
98 | 
99 |         return bbox_list + bbox_list if self.p.mixture else bbox_list
100 | 
101 |     def split_bboxes_jitter(self, w_l:int, h_l:int, tile_w:int, tile_h:int, overlap:int=16, init_weight:Union[Tensor, float]=1.0) -> Tuple[List[BBox], Tensor]:
102 |         cols = math.ceil((w_l - overlap) / (tile_w - overlap))
103 |         rows = math.ceil((h_l - overlap) / (tile_h - overlap))
104 |         if rows == 0:
105 |             rows = 1
106 |         if cols == 0:
107 |             cols = 1
108 |         dx = (w_l - tile_w) / (cols - 1) if cols > 1 else 0
109 |         dy = (h_l - tile_h) / (rows - 1) if rows > 1 else 0
110 |         bbox_list: List[BBox] = []
111 |         self.jitter_range = 0
112 |         for row in range(rows):
113 |             for col in range(cols):
114 |                 h = min(int(row * dy), h_l - tile_h)
115 |                 w = min(int(col * dx), w_l - tile_w)
116 |                 if self.p.random_jitter:
117 |                     self.jitter_range = min(max((min(self.w, self.h) - self.stride) // 4, 0), min(int(self.window_size / 2), int(self.overlap / 2)))
118 |                     jitter_range = self.jitter_range
119 |                     w_jitter = 0
120 |                     h_jitter = 0
121 |                     if (w != 0) and (w + tile_w != w_l):
122 |                         w_jitter = random.randint(-jitter_range, jitter_range)
123 |                     elif (w == 0) and (w + tile_w != w_l):
124 |                         w_jitter = random.randint(-jitter_range, 0)
125 |                     elif (w != 0) and (w + tile_w == w_l):
126 |                         w_jitter = random.randint(0, jitter_range)
127 |                     if (h != 0) and (h + tile_h != h_l):
128 |                         h_jitter = random.randint(-jitter_range, jitter_range)
129 |                     elif (h == 0) and (h + tile_h != h_l):
130 |                         h_jitter = random.randint(-jitter_range, 0)
131 |                     elif (h != 0) and (h + tile_h == h_l):
132 |                         h_jitter = random.randint(0, jitter_range)
133 |                     h += (h_jitter + jitter_range)
134 |                     w += (w_jitter + jitter_range)
135 | 
136 |                 bbox = BBox(w, h, tile_w, tile_h)
137 |                 bbox_list.append(bbox)
138 |         return bbox_list, None
139 | 
140 |     @grid_bbox
141 |     def get_views(self, overlap:int, tile_bs:int, tile_bs_g:int):
142 |         self.enable_grid_bbox = True
143 |         self.tile_w = self.window_size
144 |         self.tile_h = self.window_size
145 | 
146 |         self.overlap = max(0, min(overlap, self.window_size - 4))
147 | 
148 |         self.stride = max(4, self.window_size - self.overlap)
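        # (added worked example, with illustrative values) for window_size=96 and overlap=16,
        # stride = 96 - 16 = 80; a 176-px latent edge then needs ceil((176-16)/(96-16)) = 2
        # tiles per axis in split_bboxes_jitter above.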
149 | 150 | # split the latent into overlapped tiles, then batching 151 | # weights basically indicate how many times a pixel is painted 152 | bboxes, _ = self.split_bboxes_jitter(self.w, self.h, self.tile_w, self.tile_h, self.overlap, self.get_tile_weights()) 153 | self.num_tiles = len(bboxes) 154 | self.num_batches = math.ceil(self.num_tiles / tile_bs) 155 | self.tile_bs = math.ceil(len(bboxes) / self.num_batches) # optimal_batch_size 156 | self.batched_bboxes = [bboxes[i*self.tile_bs:(i+1)*self.tile_bs] for i in range(self.num_batches)] 157 | 158 | global_bboxes = self.global_split_bboxes() 159 | self.global_num_tiles = len(global_bboxes) 160 | self.global_num_batches = math.ceil(self.global_num_tiles / tile_bs_g) 161 | self.global_tile_bs = math.ceil(len(global_bboxes) / self.global_num_batches) 162 | self.global_batched_bboxes = [global_bboxes[i*self.global_tile_bs:(i+1)*self.global_tile_bs] for i in range(self.global_num_batches)] 163 | 164 | def gaussian_kernel(self,kernel_size=3, sigma=1.0, channels=3): 165 | x_coord = torch.arange(kernel_size, device=devices.device) 166 | gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2)) 167 | gaussian_1d = gaussian_1d / gaussian_1d.sum() 168 | gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :] 169 | kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1) 170 | 171 | return kernel 172 | 173 | def gaussian_filter(self,latents, kernel_size=3, sigma=1.0): 174 | channels = latents.shape[1] 175 | kernel = self.gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype) 176 | blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels) 177 | 178 | return blurred_latents 179 | 180 | 181 | 182 | ''' ↓↓↓ kernel hijacks ↓↓↓ ''' 183 | @torch.no_grad() 184 | @keep_signature 185 | def forward_one_step(self, x_in, sigma, **kwarg): 186 | if self.is_kdiff: 187 | x_noisy = self.p.x + self.p.noise * sigma[0] 188 | else: 189 | alphas_cumprod = self.p.sd_model.alphas_cumprod 190 | sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[self.timesteps[self.t_enc-self.p.current_step]]) 191 | sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[self.timesteps[self.t_enc-self.p.current_step]]) 192 | x_noisy = self.p.x*sqrt_alpha_cumprod + self.p.noise * sqrt_one_minus_alpha_cumprod 193 | 194 | self.cosine_factor = 0.5 * (1 + torch.cos(torch.pi *torch.tensor(((self.p.current_step + 1) / (self.t_enc+1))))) 195 | 196 | c1 = self.cosine_factor ** self.p.cosine_scale_1 197 | 198 | x_in = x_in*(1 - c1) + x_noisy * c1 199 | 200 | if self.p.random_jitter: 201 | jitter_range = self.jitter_range 202 | else: 203 | jitter_range = 0 204 | x_in_ = F.pad(x_in,(jitter_range, jitter_range, jitter_range, jitter_range),'constant',value=0) 205 | _,_,H,W = x_in.shape 206 | 207 | self.sampler.model_wrap_cfg.inner_model.forward = self.sample_one_step 208 | self.repeat_3 = False 209 | 210 | x_out = self.sampler.model_wrap_cfg.forward_ori(x_in_,sigma, **kwarg) 211 | self.sampler.model_wrap_cfg.inner_model.forward = self.sampler_forward 212 | x_out = x_out[:,:,jitter_range:jitter_range+H,jitter_range:jitter_range+W] 213 | 214 | return x_out 215 | 216 | 217 | @torch.no_grad() 218 | @keep_signature 219 | def sample_one_step(self, x_in, sigma, cond): 220 | assert LatentDiffusion.apply_model 221 | def repeat_func_1(x_tile:Tensor, bboxes,mode=0) -> Tensor: 222 | sigma_tile = self.repeat_tensor(sigma, len(bboxes)) 223 | cond_tile = self.repeat_cond_dict(cond, bboxes,mode) 224 | return 
self.sampler_forward(x_tile, sigma_tile, cond=cond_tile) 225 | 226 | def repeat_func_2(x_tile:Tensor, bboxes,mode=0) -> Tuple[Tensor, Tensor]: 227 | n_rep = len(bboxes) 228 | ts_tile = self.repeat_tensor(sigma, n_rep) 229 | if isinstance(cond, dict): # FIXME: when will enter this branch? 230 | cond_tile = self.repeat_cond_dict(cond, bboxes,mode) 231 | else: 232 | cond_tile = self.repeat_tensor(cond, n_rep) 233 | return self.sampler_forward(x_tile, ts_tile, cond=cond_tile) 234 | 235 | def repeat_func_3(x_tile:Tensor, bboxes,mode=0): 236 | sigma_in_tile = sigma.repeat(len(bboxes)) 237 | cond_out = self.repeat_cond_dict(cond, bboxes,mode) 238 | x_tile_out = shared.sd_model.apply_model(x_tile, sigma_in_tile, cond=cond_out) 239 | return x_tile_out 240 | 241 | if self.repeat_3: 242 | repeat_func = repeat_func_3 243 | self.repeat_3 = False 244 | elif self.is_kdiff: 245 | repeat_func = repeat_func_1 246 | else: 247 | repeat_func = repeat_func_2 248 | N,_,_,_ = x_in.shape 249 | 250 | 251 | self.x_buffer = torch.zeros_like(x_in) 252 | self.weights = torch.zeros_like(x_in) 253 | 254 | for batch_id, bboxes in enumerate(self.batched_bboxes): 255 | if state.interrupted: return x_in 256 | x_tile = torch.cat([x_in[bbox.slicer] for bbox in bboxes], dim=0) 257 | x_tile_out = repeat_func(x_tile, bboxes) 258 | # de-batching 259 | for i, bbox in enumerate(bboxes): 260 | self.x_buffer[bbox.slicer] += x_tile_out[i*N:(i+1)*N, :, :, :] 261 | self.weights[bbox.slicer] += 1 262 | self.weights = torch.where(self.weights == 0, torch.tensor(1), self.weights) #Prevent NaN from appearing in random_jitter mode 263 | 264 | x_local = self.x_buffer/self.weights 265 | 266 | self.x_buffer = torch.zeros_like(self.x_buffer) 267 | self.weights = torch.zeros_like(self.weights) 268 | 269 | std_, mean_ = x_in.std(), x_in.mean() 270 | c3 = 0.99 * self.cosine_factor ** self.p.cosine_scale_3 + 1e-2 271 | if self.p.gaussian_filter: 272 | x_in_g = self.gaussian_filter(x_in, kernel_size=(2*self.p.current_scale_num-1), sigma=self.sig*c3) 273 | x_in_g = (x_in_g - x_in_g.mean()) / x_in_g.std() * std_ + mean_ 274 | 275 | if not hasattr(self.p.sd_model, 'apply_model_ori'): 276 | self.p.sd_model.apply_model_ori = self.p.sd_model.apply_model 277 | self.p.sd_model.apply_model = self.apply_model_hijack 278 | x_global = torch.zeros_like(x_local) 279 | jitter_range = self.jitter_range 280 | end = x_global.shape[3]-jitter_range 281 | 282 | current_num = 0 283 | if self.p.mixture: 284 | for batch_id, bboxes in enumerate(self.global_batched_bboxes): 285 | current_num += len(bboxes) 286 | if current_num > (self.global_num_tiles//2) and (current_num-self.global_tile_bs) < (self.global_num_tiles//2): 287 | res = len(bboxes) - (current_num - self.global_num_tiles//2) 288 | x_in_i = torch.cat([x_in[:,:,bbox[1]+jitter_range:end:self.p.current_scale_num,bbox[0]+jitter_range:end:self.p.current_scale_num] if idx