├── .dockerignore ├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── .gitignore ├── .gitmodules ├── LICENSE ├── app.py ├── canvas.py ├── config.yaml ├── convert_checkpoint.py ├── css └── w2ui.min.css ├── docker-compose.yml ├── docker ├── Dockerfile ├── docker-run.sh ├── entrypoint.sh ├── opencv.pc └── run-shell.sh ├── docs ├── run_with_docker.md ├── setup_guide.md └── usage.md ├── environment.yml ├── index.html ├── interrogate.py ├── js ├── fabric.min.js ├── keyboard.js ├── mode.js ├── outpaint.js ├── proceed.js ├── setup.js ├── toolbar.js ├── upload.js ├── w2ui.min.js └── xss.js ├── models ├── v1-inference.yaml └── v1-inpainting-inference.yaml ├── perlin2d.py ├── postprocess.py ├── process.py ├── readme.md ├── stablediffusion_infinity_colab.ipynb └── utils.py /.dockerignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .github/ 3 | .git/ 4 | docs/ 5 | .dockerignore 6 | readme.md 7 | LICENSE -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[Bug]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | For setup problems or dependencies problems, please post in Q&A in Discussions 13 | 14 | **To Reproduce** 15 | Steps to reproduce the behavior: 16 | 1. Go to '...' 17 | 2. Click on '....' 18 | 3. Scroll down to '....' 19 | 4. See error 20 | 21 | **Expected behavior** 22 | A clear and concise description of what you expected to happen. 23 | 24 | **Screenshots** 25 | If applicable, add screenshots to help explain your problem. 26 | 27 | **Desktop (please complete the following information):** 28 | - OS: [e.g. windows] 29 | - Browser [e.g. chrome] 30 | 31 | **Additional context** 32 | Add any other context about the problem here. 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[Feature Request]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | Makefile 3 | .ipynb_checkpoints/ 4 | build/ 5 | csrc/ 6 | .idea/ 7 | travis.sh 8 | *.iml 9 | .token 10 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "glid_3_xl_stable"] 2 | path = glid_3_xl_stable 3 | url = https://github.com/lkwq007/glid_3_xl_stable.git 4 | [submodule "PyPatchMatch"] 5 | path = PyPatchMatch 6 | url = https://github.com/lkwq007/PyPatchMatch.git 7 | [submodule "sd_grpcserver"] 8 | path = sd_grpcserver 9 | url = https://github.com/lkwq007/sd_grpcserver.git 10 | [submodule "blip_model"] 11 | path = blip_model 12 | url = https://github.com/lkwq007/blip_model 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /canvas.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import io 4 | import numpy as np 5 | from PIL import Image 6 | from pyodide import to_js, create_proxy 7 | import gc 8 | from js import ( 9 | console, 10 | document, 11 | devicePixelRatio, 12 | ImageData, 13 | Uint8ClampedArray, 14 | CanvasRenderingContext2D as Context2d, 15 | requestAnimationFrame, 16 | update_overlay, 17 | setup_overlay, 18 | window 19 | ) 20 | 21 | PAINT_SELECTION = "selection" 22 | IMAGE_SELECTION = "canvas" 23 | BRUSH_SELECTION = "eraser" 24 | NOP_MODE = 0 25 | PAINT_MODE = 1 26 | IMAGE_MODE = 2 27 | BRUSH_MODE = 3 28 | 29 | 30 | def hold_canvas(): 31 | pass 32 | 33 | 34 | def prepare_canvas(width, height, canvas) -> Context2d: 35 | ctx = canvas.getContext("2d") 36 | 37 | canvas.style.width = f"{width}px" 38 | canvas.style.height = f"{height}px" 39 | 40 | canvas.width = width 41 | canvas.height = height 42 | 43 | ctx.clearRect(0, 0, width, height) 44 | 45 | return ctx 46 | 47 | 48 | # class MultiCanvas: 49 | # def __init__(self,layer,width=800, height=600) -> None: 50 | # pass 51 | def multi_canvas(layer, width=800, height=600): 52 | lst = [ 53 | CanvasProxy(document.querySelector(f"#canvas{i}"), width, height) 54 | for i in range(layer) 55 | ] 56 | return lst 57 | 58 | 59 | class CanvasProxy: 60 | def __init__(self, canvas, width=800, height=600) -> None: 61 | self.canvas = canvas 62 | self.ctx = prepare_canvas(width, height, canvas) 63 | self.width = width 64 | self.height = height 65 | 66 | def clear_rect(self, x, y, w, h): 67 | self.ctx.clearRect(x, y, w, h) 68 | 69 | def clear(self,): 70 | self.clear_rect(0, 0, self.canvas.width, self.canvas.height) 71 | 72 | def stroke_rect(self, x, y, w, h): 73 | self.ctx.strokeRect(x, y, w, h) 74 | 75 | def fill_rect(self, x, y, w, h): 76 | self.ctx.fillRect(x, y, w, h) 77 | 78 | def put_image_data(self, image, x, y): 79 | data = Uint8ClampedArray.new(to_js(image.tobytes())) 80 | height, width, _ = image.shape 81 | image_data = ImageData.new(data, width, height) 82 | self.ctx.putImageData(image_data, x, y) 83 | del image_data 84 | 85 | # def draw_image(self,canvas, x, y, w, h): 86 | # self.ctx.drawImage(canvas,x,y,w,h) 87 | def draw_image(self,canvas, sx, sy, sWidth, sHeight, dx, dy, dWidth, dHeight): 88 | self.ctx.drawImage(canvas, sx, sy, sWidth, sHeight, dx, dy, dWidth, dHeight) 89 | 90 | @property 91 | def stroke_style(self): 92 | return self.ctx.strokeStyle 93 | 94 | @stroke_style.setter 95 | def stroke_style(self, value): 96 | self.ctx.strokeStyle = value 97 | 98 | @property 99 | def fill_style(self): 100 | return self.ctx.strokeStyle 101 | 102 | @fill_style.setter 103 | def fill_style(self, value): 104 | self.ctx.fillStyle = value 105 | 106 | 107 | # RGBA for masking 108 | class InfCanvas: 109 | def __init__( 110 | self, 111 | width, 112 | height, 113 | selection_size=256, 114 | grid_size=64, 115 | patch_size=4096, 116 | test_mode=False, 117 | ) -> None: 118 | assert selection_size < min(height, width) 119 | self.width = width 120 | self.height = height 121 | self.display_width = width 122 | self.display_height = height 123 | self.canvas = multi_canvas(5, width=width, height=height) 124 | setup_overlay(width,height) 125 | # place at center 126 | self.view_pos = [patch_size//2-width//2, patch_size//2-height//2] 127 | self.cursor = [ 128 | width // 2 - selection_size // 2, 129 | height // 2 - selection_size 
// 2, 130 | ] 131 | self.data = {} 132 | self.grid_size = grid_size 133 | self.selection_size_w = selection_size 134 | self.selection_size_h = selection_size 135 | self.patch_size = patch_size 136 | # note that for image data, the height comes before width 137 | self.buffer = np.zeros((height, width, 4), dtype=np.uint8) 138 | self.sel_buffer = np.zeros((selection_size, selection_size, 4), dtype=np.uint8) 139 | self.sel_buffer_bak = np.zeros( 140 | (selection_size, selection_size, 4), dtype=np.uint8 141 | ) 142 | self.sel_dirty = False 143 | self.buffer_dirty = False 144 | self.mouse_pos = [-1, -1] 145 | self.mouse_state = 0 146 | # self.output = widgets.Output() 147 | self.test_mode = test_mode 148 | self.buffer_updated = False 149 | self.image_move_freq = 1 150 | self.show_brush = False 151 | self.scale=1.0 152 | self.eraser_size=32 153 | 154 | def reset_large_buffer(self): 155 | self.canvas[2].canvas.width=self.width 156 | self.canvas[2].canvas.height=self.height 157 | # self.canvas[2].canvas.style.width=f"{self.display_width}px" 158 | # self.canvas[2].canvas.style.height=f"{self.display_height}px" 159 | self.canvas[2].canvas.style.display="block" 160 | self.canvas[2].clear() 161 | 162 | def draw_eraser(self, x, y): 163 | self.canvas[-2].clear() 164 | self.canvas[-2].fill_style = "#ffffff" 165 | self.canvas[-2].fill_rect(x-self.eraser_size//2,y-self.eraser_size//2,self.eraser_size,self.eraser_size) 166 | self.canvas[-2].stroke_rect(x-self.eraser_size//2,y-self.eraser_size//2,self.eraser_size,self.eraser_size) 167 | 168 | def use_eraser(self,x,y): 169 | if self.sel_dirty: 170 | self.write_selection_to_buffer() 171 | self.draw_buffer() 172 | self.canvas[2].clear() 173 | self.buffer_dirty=True 174 | bx0,by0=int(x)-self.eraser_size//2,int(y)-self.eraser_size//2 175 | bx1,by1=bx0+self.eraser_size,by0+self.eraser_size 176 | bx0,by0=max(0,bx0),max(0,by0) 177 | bx1,by1=min(self.width,bx1),min(self.height,by1) 178 | self.buffer[by0:by1,bx0:bx1,:]*=0 179 | self.draw_buffer() 180 | self.draw_selection_box() 181 | 182 | def setup_mouse(self): 183 | self.image_move_cnt = 0 184 | 185 | def get_mouse_mode(): 186 | mode = document.querySelector("#mode").value 187 | if mode == PAINT_SELECTION: 188 | return PAINT_MODE 189 | elif mode == IMAGE_SELECTION: 190 | return IMAGE_MODE 191 | return BRUSH_MODE 192 | 193 | def get_event_pos(event): 194 | canvas = self.canvas[-1].canvas 195 | rect = canvas.getBoundingClientRect() 196 | x = (canvas.width * (event.clientX - rect.left)) / rect.width 197 | y = (canvas.height * (event.clientY - rect.top)) / rect.height 198 | return x, y 199 | 200 | def handle_mouse_down(event): 201 | self.mouse_state = get_mouse_mode() 202 | if self.mouse_state==BRUSH_MODE: 203 | x,y=get_event_pos(event) 204 | self.use_eraser(x,y) 205 | 206 | def handle_mouse_out(event): 207 | last_state = self.mouse_state 208 | self.mouse_state = NOP_MODE 209 | self.image_move_cnt = 0 210 | if last_state == IMAGE_MODE: 211 | self.update_view_pos(0, 0) 212 | if True: 213 | self.clear_background() 214 | self.draw_buffer() 215 | self.reset_large_buffer() 216 | self.draw_selection_box() 217 | gc.collect() 218 | if self.show_brush: 219 | self.canvas[-2].clear() 220 | self.show_brush = False 221 | 222 | def handle_mouse_up(event): 223 | last_state = self.mouse_state 224 | self.mouse_state = NOP_MODE 225 | self.image_move_cnt = 0 226 | if last_state == IMAGE_MODE: 227 | self.update_view_pos(0, 0) 228 | if True: 229 | self.clear_background() 230 | self.draw_buffer() 231 | self.reset_large_buffer() 232 | 
self.draw_selection_box() 233 | gc.collect() 234 | 235 | async def handle_mouse_move(event): 236 | x, y = get_event_pos(event) 237 | x0, y0 = self.mouse_pos 238 | xo = x - x0 239 | yo = y - y0 240 | if self.mouse_state == PAINT_MODE: 241 | self.update_cursor(int(xo), int(yo)) 242 | if True: 243 | # self.clear_background() 244 | # console.log(self.buffer_updated) 245 | if self.buffer_updated: 246 | self.draw_buffer() 247 | self.buffer_updated = False 248 | self.draw_selection_box() 249 | elif self.mouse_state == IMAGE_MODE: 250 | self.image_move_cnt += 1 251 | if self.image_move_cnt == self.image_move_freq: 252 | self.draw_buffer() 253 | self.canvas[2].clear() 254 | self.draw_selection_box() 255 | self.update_view_pos(int(xo), int(yo)) 256 | self.cached_view_pos=tuple(self.view_pos) 257 | self.canvas[2].canvas.style.display="none" 258 | large_buffer=self.data2array(self.view_pos[0]-self.width//2,self.view_pos[1]-self.height//2,min(self.width*2,self.patch_size),min(self.height*2,self.patch_size)) 259 | self.canvas[2].canvas.width=large_buffer.shape[1] 260 | self.canvas[2].canvas.height=large_buffer.shape[0] 261 | # self.canvas[2].canvas.style.width="" 262 | # self.canvas[2].canvas.style.height="" 263 | self.canvas[2].put_image_data(large_buffer,0,0) 264 | else: 265 | self.update_view_pos(int(xo), int(yo), False) 266 | self.canvas[1].clear() 267 | self.canvas[1].draw_image(self.canvas[2].canvas, 268 | self.width//2+(self.view_pos[0]-self.cached_view_pos[0]),self.height//2+(self.view_pos[1]-self.cached_view_pos[1]), 269 | self.width,self.height, 270 | 0,0,self.width,self.height 271 | ) 272 | self.clear_background() 273 | # self.image_move_cnt = 0 274 | elif self.mouse_state == BRUSH_MODE: 275 | self.use_eraser(x,y) 276 | 277 | mode = document.querySelector("#mode").value 278 | if mode == BRUSH_SELECTION: 279 | self.draw_eraser(x,y) 280 | self.show_brush = True 281 | elif self.show_brush: 282 | self.canvas[-2].clear() 283 | self.show_brush = False 284 | self.mouse_pos[0] = x 285 | self.mouse_pos[1] = y 286 | 287 | self.canvas[-1].canvas.addEventListener( 288 | "mousedown", create_proxy(handle_mouse_down) 289 | ) 290 | self.canvas[-1].canvas.addEventListener( 291 | "mousemove", create_proxy(handle_mouse_move) 292 | ) 293 | self.canvas[-1].canvas.addEventListener( 294 | "mouseup", create_proxy(handle_mouse_up) 295 | ) 296 | self.canvas[-1].canvas.addEventListener( 297 | "mouseout", create_proxy(handle_mouse_out) 298 | ) 299 | async def handle_mouse_wheel(event): 300 | x, y = get_event_pos(event) 301 | self.mouse_pos[0] = x 302 | self.mouse_pos[1] = y 303 | console.log(to_js(self.mouse_pos)) 304 | if event.deltaY>10: 305 | window.postMessage(to_js(["click","zoom_out", self.mouse_pos[0], self.mouse_pos[1]]),"*") 306 | elif event.deltaY<-10: 307 | window.postMessage(to_js(["click","zoom_in", self.mouse_pos[0], self.mouse_pos[1]]),"*") 308 | return False 309 | self.canvas[-1].canvas.addEventListener( 310 | "wheel", create_proxy(handle_mouse_wheel), False 311 | ) 312 | def clear_background(self): 313 | # fake transparent background 314 | h, w, step = self.height, self.width, self.grid_size 315 | stride = step * 2 316 | x0, y0 = self.view_pos 317 | x0 = (-x0) % stride 318 | y0 = (-y0) % stride 319 | if y0>=step: 320 | val0,val1=stride,step 321 | else: 322 | val0,val1=step,stride 323 | # self.canvas.clear() 324 | self.canvas[0].fill_style = "#ffffff" 325 | self.canvas[0].fill_rect(0, 0, w, h) 326 | self.canvas[0].fill_style = "#aaaaaa" 327 | for y in range(y0-stride, h + step, step): 328 | start = (x0 - 
val0) if y // step % 2 == 0 else (x0 - val1) 329 | for x in range(start, w + step, stride): 330 | self.canvas[0].fill_rect(x, y, step, step) 331 | self.canvas[0].stroke_rect(0, 0, w, h) 332 | 333 | def refine_selection(self): 334 | h,w=self.selection_size_h,self.selection_size_w 335 | h=min(h,self.height) 336 | w=min(w,self.width) 337 | self.selection_size_h=h*8//8 338 | self.selection_size_w=w*8//8 339 | self.update_cursor(1,0) 340 | 341 | 342 | def update_scale(self, scale, mx=-1, my=-1): 343 | self.sync_to_data() 344 | scaled_width=int(self.display_width*scale) 345 | scaled_height=int(self.display_height*scale) 346 | if max(scaled_height,scaled_width)>=self.patch_size*2-128: 347 | return 348 | if scaled_height<=self.selection_size_h or scaled_width<=self.selection_size_w: 349 | return 350 | if mx>=0 and my>=0: 351 | scaled_mx=mx/self.scale*scale 352 | scaled_my=my/self.scale*scale 353 | self.view_pos[0]+=int(mx-scaled_mx) 354 | self.view_pos[1]+=int(my-scaled_my) 355 | self.scale=scale 356 | for item in self.canvas: 357 | item.canvas.width=scaled_width 358 | item.canvas.height=scaled_height 359 | item.clear() 360 | update_overlay(scaled_width,scaled_height) 361 | self.width=scaled_width 362 | self.height=scaled_height 363 | self.data2buffer() 364 | self.clear_background() 365 | self.draw_buffer() 366 | self.update_cursor(1,0) 367 | self.draw_selection_box() 368 | 369 | def update_view_pos(self, xo, yo, update=True): 370 | # if abs(xo) + abs(yo) == 0: 371 | # return 372 | if self.sel_dirty: 373 | self.write_selection_to_buffer() 374 | if self.buffer_dirty: 375 | self.buffer2data() 376 | self.view_pos[0] -= xo 377 | self.view_pos[1] -= yo 378 | if update: 379 | self.data2buffer() 380 | # self.read_selection_from_buffer() 381 | 382 | def update_cursor(self, xo, yo): 383 | if abs(xo) + abs(yo) == 0: 384 | return 385 | if self.sel_dirty: 386 | self.write_selection_to_buffer() 387 | self.cursor[0] += xo 388 | self.cursor[1] += yo 389 | self.cursor[0] = max(min(self.width - self.selection_size_w, self.cursor[0]), 0) 390 | self.cursor[1] = max(min(self.height - self.selection_size_h, self.cursor[1]), 0) 391 | # self.read_selection_from_buffer() 392 | 393 | def data2buffer(self): 394 | x, y = self.view_pos 395 | h, w = self.height, self.width 396 | if h!=self.buffer.shape[0] or w!=self.buffer.shape[1]: 397 | self.buffer=np.zeros((self.height, self.width, 4), dtype=np.uint8) 398 | # fill four parts 399 | for i in range(4): 400 | pos_src, pos_dst, data = self.select(x, y, i) 401 | xs0, xs1 = pos_src[0] 402 | ys0, ys1 = pos_src[1] 403 | xd0, xd1 = pos_dst[0] 404 | yd0, yd1 = pos_dst[1] 405 | self.buffer[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :] 406 | 407 | def data2array(self, x, y, w, h): 408 | # x, y = self.view_pos 409 | # h, w = self.height, self.width 410 | ret=np.zeros((h, w, 4), dtype=np.uint8) 411 | # fill four parts 412 | for i in range(4): 413 | pos_src, pos_dst, data = self.select(x, y, i, w, h) 414 | xs0, xs1 = pos_src[0] 415 | ys0, ys1 = pos_src[1] 416 | xd0, xd1 = pos_dst[0] 417 | yd0, yd1 = pos_dst[1] 418 | ret[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :] 419 | return ret 420 | 421 | def buffer2data(self): 422 | x, y = self.view_pos 423 | h, w = self.height, self.width 424 | # fill four parts 425 | for i in range(4): 426 | pos_src, pos_dst, data = self.select(x, y, i) 427 | xs0, xs1 = pos_src[0] 428 | ys0, ys1 = pos_src[1] 429 | xd0, xd1 = pos_dst[0] 430 | yd0, yd1 = pos_dst[1] 431 | data[ys0:ys1, xs0:xs1, :] = self.buffer[yd0:yd1, xd0:xd1, :] 432 | self.buffer_dirty = 
False 433 | 434 | def select(self, x, y, idx, width=0, height=0): 435 | if width==0: 436 | w, h = self.width, self.height 437 | else: 438 | w, h = width, height 439 | lst = [(0, 0), (0, h), (w, 0), (w, h)] 440 | if idx == 0: 441 | x0, y0 = x % self.patch_size, y % self.patch_size 442 | x1 = min(x0 + w, self.patch_size) 443 | y1 = min(y0 + h, self.patch_size) 444 | elif idx == 1: 445 | y += h 446 | x0, y0 = x % self.patch_size, y % self.patch_size 447 | x1 = min(x0 + w, self.patch_size) 448 | y1 = max(y0 - h, 0) 449 | elif idx == 2: 450 | x += w 451 | x0, y0 = x % self.patch_size, y % self.patch_size 452 | x1 = max(x0 - w, 0) 453 | y1 = min(y0 + h, self.patch_size) 454 | else: 455 | x += w 456 | y += h 457 | x0, y0 = x % self.patch_size, y % self.patch_size 458 | x1 = max(x0 - w, 0) 459 | y1 = max(y0 - h, 0) 460 | xi, yi = x // self.patch_size, y // self.patch_size 461 | cur = self.data.setdefault( 462 | (xi, yi), np.zeros((self.patch_size, self.patch_size, 4), dtype=np.uint8) 463 | ) 464 | x0_img, y0_img = lst[idx] 465 | x1_img = x0_img + x1 - x0 466 | y1_img = y0_img + y1 - y0 467 | sort = lambda a, b: ((a, b) if a < b else (b, a)) 468 | return ( 469 | (sort(x0, x1), sort(y0, y1)), 470 | (sort(x0_img, x1_img), sort(y0_img, y1_img)), 471 | cur, 472 | ) 473 | 474 | def draw_buffer(self): 475 | self.canvas[1].clear() 476 | self.canvas[1].put_image_data(self.buffer, 0, 0) 477 | 478 | def fill_selection(self, img): 479 | self.sel_buffer = img 480 | self.sel_dirty = True 481 | 482 | def draw_selection_box(self): 483 | x0, y0 = self.cursor 484 | w, h = self.selection_size_w, self.selection_size_h 485 | if self.sel_dirty: 486 | self.canvas[2].clear() 487 | self.canvas[2].put_image_data(self.sel_buffer, x0, y0) 488 | self.canvas[-1].clear() 489 | self.canvas[-1].stroke_style = "#0a0a0a" 490 | self.canvas[-1].stroke_rect(x0, y0, w, h) 491 | self.canvas[-1].stroke_style = "#ffffff" 492 | offset=round(self.scale) if self.scale>1.0 else 1 493 | self.canvas[-1].stroke_rect(x0 - offset, y0 - offset, w + offset*2, h + offset*2) 494 | self.canvas[-1].stroke_style = "#000000" 495 | self.canvas[-1].stroke_rect(x0 - offset*2, y0 - offset*2, w + offset*4, h + offset*4) 496 | 497 | def write_selection_to_buffer(self): 498 | x0, y0 = self.cursor 499 | x1, y1 = x0 + self.selection_size_w, y0 + self.selection_size_h 500 | self.buffer[y0:y1, x0:x1] = self.sel_buffer 501 | self.sel_dirty = False 502 | self.sel_buffer = np.zeros( 503 | (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8 504 | ) 505 | self.buffer_dirty = True 506 | self.buffer_updated = True 507 | # self.canvas[2].clear() 508 | 509 | def read_selection_from_buffer(self): 510 | x0, y0 = self.cursor 511 | x1, y1 = x0 + self.selection_size_w, y0 + self.selection_size_h 512 | self.sel_buffer = self.buffer[y0:y1, x0:x1] 513 | self.sel_dirty = False 514 | 515 | def base64_to_numpy(self, base64_str): 516 | try: 517 | data = base64.b64decode(str(base64_str)) 518 | pil = Image.open(io.BytesIO(data)) 519 | arr = np.array(pil) 520 | ret = arr 521 | except: 522 | ret = np.tile( 523 | np.array([255, 0, 0, 255], dtype=np.uint8), 524 | (self.selection_size_h, self.selection_size_w, 1), 525 | ) 526 | return ret 527 | 528 | def numpy_to_base64(self, arr): 529 | out_pil = Image.fromarray(arr) 530 | out_buffer = io.BytesIO() 531 | out_pil.save(out_buffer, format="PNG") 532 | out_buffer.seek(0) 533 | base64_bytes = base64.b64encode(out_buffer.read()) 534 | base64_str = base64_bytes.decode("ascii") 535 | return base64_str 536 | 537 | def sync_to_data(self): 
538 | if self.sel_dirty: 539 | self.write_selection_to_buffer() 540 | self.canvas[2].clear() 541 | self.draw_buffer() 542 | if self.buffer_dirty: 543 | self.buffer2data() 544 | 545 | def sync_to_buffer(self): 546 | if self.sel_dirty: 547 | self.canvas[2].clear() 548 | self.write_selection_to_buffer() 549 | self.draw_buffer() 550 | 551 | def resize(self,width,height,scale=None,**kwargs): 552 | self.display_width=width 553 | self.display_height=height 554 | for canvas in self.canvas: 555 | prepare_canvas(width=width,height=height,canvas=canvas.canvas) 556 | setup_overlay(width,height) 557 | if scale is None: 558 | scale=1 559 | self.update_scale(scale) 560 | 561 | 562 | def save(self): 563 | self.sync_to_data() 564 | state={} 565 | state["width"]=self.display_width 566 | state["height"]=self.display_height 567 | state["selection_width"]=self.selection_size_w 568 | state["selection_height"]=self.selection_size_h 569 | state["view_pos"]=self.view_pos[:] 570 | state["cursor"]=self.cursor[:] 571 | state["scale"]=self.scale 572 | keys=list(self.data.keys()) 573 | data={} 574 | for key in keys: 575 | if self.data[key].sum()>0: 576 | data[f"{key[0]},{key[1]}"]=self.numpy_to_base64(self.data[key]) 577 | state["data"]=data 578 | return json.dumps(state) 579 | 580 | def load(self, state_json): 581 | self.reset() 582 | state=json.loads(state_json) 583 | self.display_width=state["width"] 584 | self.display_height=state["height"] 585 | self.selection_size_w=state["selection_width"] 586 | self.selection_size_h=state["selection_height"] 587 | self.view_pos=state["view_pos"][:] 588 | self.cursor=state["cursor"][:] 589 | self.scale=state["scale"] 590 | self.resize(state["width"],state["height"],scale=state["scale"]) 591 | for k,v in state["data"].items(): 592 | key=tuple(map(int,k.split(","))) 593 | self.data[key]=self.base64_to_numpy(v) 594 | self.data2buffer() 595 | self.display() 596 | 597 | def display(self): 598 | self.clear_background() 599 | self.draw_buffer() 600 | self.draw_selection_box() 601 | 602 | def reset(self): 603 | self.data.clear() 604 | self.buffer*=0 605 | self.buffer_dirty=False 606 | self.buffer_updated=False 607 | self.sel_buffer*=0 608 | self.sel_dirty=False 609 | self.view_pos = [0, 0] 610 | self.clear_background() 611 | for i in range(1,len(self.canvas)-1): 612 | self.canvas[i].clear() 613 | 614 | def export(self): 615 | self.sync_to_data() 616 | xmin, xmax, ymin, ymax = 0, 0, 0, 0 617 | if len(self.data.keys()) == 0: 618 | return np.zeros( 619 | (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8 620 | ) 621 | for xi, yi in self.data.keys(): 622 | buf = self.data[(xi, yi)] 623 | if buf.sum() > 0: 624 | xmin = min(xi, xmin) 625 | xmax = max(xi, xmax) 626 | ymin = min(yi, ymin) 627 | ymax = max(yi, ymax) 628 | yn = ymax - ymin + 1 629 | xn = xmax - xmin + 1 630 | image = np.zeros( 631 | (yn * self.patch_size, xn * self.patch_size, 4), dtype=np.uint8 632 | ) 633 | for xi, yi in self.data.keys(): 634 | buf = self.data[(xi, yi)] 635 | if buf.sum() > 0: 636 | y0 = (yi - ymin) * self.patch_size 637 | x0 = (xi - xmin) * self.patch_size 638 | image[y0 : y0 + self.patch_size, x0 : x0 + self.patch_size] = buf 639 | ylst, xlst = image[:, :, -1].nonzero() 640 | if len(ylst) > 0: 641 | yt, xt = ylst.min(), xlst.min() 642 | yb, xb = ylst.max(), xlst.max() 643 | image = image[yt : yb + 1, xt : xb + 1] 644 | return image 645 | else: 646 | return np.zeros( 647 | (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8 648 | ) 649 | 
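For reference, here is a minimal standalone sketch (illustration only; SparsePatches, write_pixel, and read_pixel are hypothetical names that do not appear in this repository) of the patch-keyed storage scheme that canvas.py's InfCanvas relies on: the unbounded canvas is kept as a sparse dict of fixed-size RGBA patches keyed by (x // patch_size, y // patch_size), with patches allocated lazily on first write. The real InfCanvas.select()/data2buffer() operate on whole rectangular regions that may straddle up to four patches rather than on single pixels, but the addressing idea is the same.

import numpy as np

PATCH_SIZE = 4096  # same default as InfCanvas(patch_size=4096)

class SparsePatches:
    """Sparse, lazily allocated grid of fixed-size RGBA patches."""

    def __init__(self, patch_size=PATCH_SIZE):
        self.patch_size = patch_size
        self.data = {}  # (xi, yi) -> uint8 array of shape (patch_size, patch_size, 4)

    def _patch(self, xi, yi):
        # Patches are created on demand, mirroring the setdefault() call in InfCanvas.select()
        return self.data.setdefault(
            (xi, yi),
            np.zeros((self.patch_size, self.patch_size, 4), dtype=np.uint8),
        )

    def write_pixel(self, x, y, rgba):
        xi, yi = x // self.patch_size, y // self.patch_size   # which patch
        px, py = x % self.patch_size, y % self.patch_size     # offset inside that patch
        self._patch(xi, yi)[py, px] = rgba                    # row (y) index comes first, as in the buffers above

    def read_pixel(self, x, y):
        xi, yi = x // self.patch_size, y // self.patch_size
        px, py = x % self.patch_size, y % self.patch_size
        patch = self.data.get((xi, yi))
        if patch is None:                                     # untouched regions read back as transparent
            return np.zeros(4, dtype=np.uint8)
        return patch[py, px]

canvas = SparsePatches()
canvas.write_pixel(5000, 100, np.array([255, 0, 0, 255], dtype=np.uint8))
print(canvas.read_pixel(5000, 100))   # [255   0   0 255] -- stored in patch (1, 0)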
-------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | shortcut: 2 | clear: Escape 3 | load: Ctrl+o 4 | save: Ctrl+s 5 | export: Ctrl+e 6 | upload: Ctrl+u 7 | selection: 1 8 | canvas: 2 9 | eraser: 3 10 | outpaint: d 11 | accept: a 12 | cancel: c 13 | retry: r 14 | prev: q 15 | next: e 16 | zoom_in: z 17 | zoom_out: x 18 | random_seed: s -------------------------------------------------------------------------------- /convert_checkpoint.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py 16 | """ Conversion script for the LDM checkpoints. """ 17 | 18 | import argparse 19 | import os 20 | 21 | import torch 22 | 23 | 24 | try: 25 | from omegaconf import OmegaConf 26 | except ImportError: 27 | raise ImportError( 28 | "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`." 29 | ) 30 | 31 | from diffusers import ( 32 | AutoencoderKL, 33 | DDIMScheduler, 34 | LDMTextToImagePipeline, 35 | LMSDiscreteScheduler, 36 | PNDMScheduler, 37 | StableDiffusionPipeline, 38 | UNet2DConditionModel, 39 | ) 40 | from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel 41 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker 42 | from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer 43 | 44 | 45 | def shave_segments(path, n_shave_prefix_segments=1): 46 | """ 47 | Removes segments. Positive values shave the first segments, negative shave the last segments. 
48 | """ 49 | if n_shave_prefix_segments >= 0: 50 | return ".".join(path.split(".")[n_shave_prefix_segments:]) 51 | else: 52 | return ".".join(path.split(".")[:n_shave_prefix_segments]) 53 | 54 | 55 | def renew_resnet_paths(old_list, n_shave_prefix_segments=0): 56 | """ 57 | Updates paths inside resnets to the new naming scheme (local renaming) 58 | """ 59 | mapping = [] 60 | for old_item in old_list: 61 | new_item = old_item.replace("in_layers.0", "norm1") 62 | new_item = new_item.replace("in_layers.2", "conv1") 63 | 64 | new_item = new_item.replace("out_layers.0", "norm2") 65 | new_item = new_item.replace("out_layers.3", "conv2") 66 | 67 | new_item = new_item.replace("emb_layers.1", "time_emb_proj") 68 | new_item = new_item.replace("skip_connection", "conv_shortcut") 69 | 70 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 71 | 72 | mapping.append({"old": old_item, "new": new_item}) 73 | 74 | return mapping 75 | 76 | 77 | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): 78 | """ 79 | Updates paths inside resnets to the new naming scheme (local renaming) 80 | """ 81 | mapping = [] 82 | for old_item in old_list: 83 | new_item = old_item 84 | 85 | new_item = new_item.replace("nin_shortcut", "conv_shortcut") 86 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 87 | 88 | mapping.append({"old": old_item, "new": new_item}) 89 | 90 | return mapping 91 | 92 | 93 | def renew_attention_paths(old_list, n_shave_prefix_segments=0): 94 | """ 95 | Updates paths inside attentions to the new naming scheme (local renaming) 96 | """ 97 | mapping = [] 98 | for old_item in old_list: 99 | new_item = old_item 100 | 101 | # new_item = new_item.replace('norm.weight', 'group_norm.weight') 102 | # new_item = new_item.replace('norm.bias', 'group_norm.bias') 103 | 104 | # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') 105 | # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') 106 | 107 | # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 108 | 109 | mapping.append({"old": old_item, "new": new_item}) 110 | 111 | return mapping 112 | 113 | 114 | def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): 115 | """ 116 | Updates paths inside attentions to the new naming scheme (local renaming) 117 | """ 118 | mapping = [] 119 | for old_item in old_list: 120 | new_item = old_item 121 | 122 | new_item = new_item.replace("norm.weight", "group_norm.weight") 123 | new_item = new_item.replace("norm.bias", "group_norm.bias") 124 | 125 | new_item = new_item.replace("q.weight", "query.weight") 126 | new_item = new_item.replace("q.bias", "query.bias") 127 | 128 | new_item = new_item.replace("k.weight", "key.weight") 129 | new_item = new_item.replace("k.bias", "key.bias") 130 | 131 | new_item = new_item.replace("v.weight", "value.weight") 132 | new_item = new_item.replace("v.bias", "value.bias") 133 | 134 | new_item = new_item.replace("proj_out.weight", "proj_attn.weight") 135 | new_item = new_item.replace("proj_out.bias", "proj_attn.bias") 136 | 137 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) 138 | 139 | mapping.append({"old": old_item, "new": new_item}) 140 | 141 | return mapping 142 | 143 | 144 | def assign_to_checkpoint( 145 | paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None 146 | ): 147 | """ 148 | This does the final conversion step: take locally converted 
weights and apply a global renaming 149 | to them. It splits attention layers, and takes into account additional replacements 150 | that may arise. 151 | 152 | Assigns the weights to the new checkpoint. 153 | """ 154 | assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." 155 | 156 | # Splits the attention layers into three variables. 157 | if attention_paths_to_split is not None: 158 | for path, path_map in attention_paths_to_split.items(): 159 | old_tensor = old_checkpoint[path] 160 | channels = old_tensor.shape[0] // 3 161 | 162 | target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) 163 | 164 | num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 165 | 166 | old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) 167 | query, key, value = old_tensor.split(channels // num_heads, dim=1) 168 | 169 | checkpoint[path_map["query"]] = query.reshape(target_shape) 170 | checkpoint[path_map["key"]] = key.reshape(target_shape) 171 | checkpoint[path_map["value"]] = value.reshape(target_shape) 172 | 173 | for path in paths: 174 | new_path = path["new"] 175 | 176 | # These have already been assigned 177 | if attention_paths_to_split is not None and new_path in attention_paths_to_split: 178 | continue 179 | 180 | # Global renaming happens here 181 | new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") 182 | new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") 183 | new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") 184 | 185 | if additional_replacements is not None: 186 | for replacement in additional_replacements: 187 | new_path = new_path.replace(replacement["old"], replacement["new"]) 188 | 189 | # proj_attn.weight has to be converted from conv 1D to linear 190 | if "proj_attn.weight" in new_path: 191 | checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] 192 | else: 193 | checkpoint[new_path] = old_checkpoint[path["old"]] 194 | 195 | 196 | def conv_attn_to_linear(checkpoint): 197 | keys = list(checkpoint.keys()) 198 | attn_keys = ["query.weight", "key.weight", "value.weight"] 199 | for key in keys: 200 | if ".".join(key.split(".")[-2:]) in attn_keys: 201 | if checkpoint[key].ndim > 2: 202 | checkpoint[key] = checkpoint[key][:, :, 0, 0] 203 | elif "proj_attn.weight" in key: 204 | if checkpoint[key].ndim > 2: 205 | checkpoint[key] = checkpoint[key][:, :, 0] 206 | 207 | 208 | def create_unet_diffusers_config(original_config): 209 | """ 210 | Creates a config for the diffusers based on the config of the LDM model. 
211 | """ 212 | unet_params = original_config.model.params.unet_config.params 213 | 214 | block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] 215 | 216 | down_block_types = [] 217 | resolution = 1 218 | for i in range(len(block_out_channels)): 219 | block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" 220 | down_block_types.append(block_type) 221 | if i != len(block_out_channels) - 1: 222 | resolution *= 2 223 | 224 | up_block_types = [] 225 | for i in range(len(block_out_channels)): 226 | block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" 227 | up_block_types.append(block_type) 228 | resolution //= 2 229 | 230 | config = dict( 231 | sample_size=unet_params.image_size, 232 | in_channels=unet_params.in_channels, 233 | out_channels=unet_params.out_channels, 234 | down_block_types=tuple(down_block_types), 235 | up_block_types=tuple(up_block_types), 236 | block_out_channels=tuple(block_out_channels), 237 | layers_per_block=unet_params.num_res_blocks, 238 | cross_attention_dim=unet_params.context_dim, 239 | attention_head_dim=unet_params.num_heads, 240 | ) 241 | 242 | return config 243 | 244 | 245 | def create_vae_diffusers_config(original_config): 246 | """ 247 | Creates a config for the diffusers based on the config of the LDM model. 248 | """ 249 | vae_params = original_config.model.params.first_stage_config.params.ddconfig 250 | _ = original_config.model.params.first_stage_config.params.embed_dim 251 | 252 | block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] 253 | down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) 254 | up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) 255 | 256 | config = dict( 257 | sample_size=vae_params.resolution, 258 | in_channels=vae_params.in_channels, 259 | out_channels=vae_params.out_ch, 260 | down_block_types=tuple(down_block_types), 261 | up_block_types=tuple(up_block_types), 262 | block_out_channels=tuple(block_out_channels), 263 | latent_channels=vae_params.z_channels, 264 | layers_per_block=vae_params.num_res_blocks, 265 | ) 266 | return config 267 | 268 | 269 | def create_diffusers_schedular(original_config): 270 | schedular = DDIMScheduler( 271 | num_train_timesteps=original_config.model.params.timesteps, 272 | beta_start=original_config.model.params.linear_start, 273 | beta_end=original_config.model.params.linear_end, 274 | beta_schedule="scaled_linear", 275 | ) 276 | return schedular 277 | 278 | 279 | def create_ldm_bert_config(original_config): 280 | bert_params = original_config.model.parms.cond_stage_config.params 281 | config = LDMBertConfig( 282 | d_model=bert_params.n_embed, 283 | encoder_layers=bert_params.n_layer, 284 | encoder_ffn_dim=bert_params.n_embed * 4, 285 | ) 286 | return config 287 | 288 | 289 | def convert_ldm_unet_checkpoint(checkpoint, config): 290 | """ 291 | Takes a state dict and a config, and returns a converted checkpoint. 292 | """ 293 | 294 | # extract state_dict for UNet 295 | unet_state_dict = {} 296 | unet_key = "model.diffusion_model." 
297 | keys = list(checkpoint.keys()) 298 | for key in keys: 299 | if key.startswith(unet_key): 300 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) 301 | 302 | new_checkpoint = {} 303 | 304 | new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] 305 | new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] 306 | new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] 307 | new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] 308 | 309 | new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] 310 | new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] 311 | 312 | new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] 313 | new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] 314 | new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] 315 | new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] 316 | 317 | # Retrieves the keys for the input blocks only 318 | num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) 319 | input_blocks = { 320 | layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] 321 | for layer_id in range(num_input_blocks) 322 | } 323 | 324 | # Retrieves the keys for the middle blocks only 325 | num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) 326 | middle_blocks = { 327 | layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] 328 | for layer_id in range(num_middle_blocks) 329 | } 330 | 331 | # Retrieves the keys for the output blocks only 332 | num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) 333 | output_blocks = { 334 | layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] 335 | for layer_id in range(num_output_blocks) 336 | } 337 | 338 | for i in range(1, num_input_blocks): 339 | block_id = (i - 1) // (config["layers_per_block"] + 1) 340 | layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) 341 | 342 | resnets = [ 343 | key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key 344 | ] 345 | attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] 346 | 347 | if f"input_blocks.{i}.0.op.weight" in unet_state_dict: 348 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( 349 | f"input_blocks.{i}.0.op.weight" 350 | ) 351 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( 352 | f"input_blocks.{i}.0.op.bias" 353 | ) 354 | 355 | paths = renew_resnet_paths(resnets) 356 | meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} 357 | assign_to_checkpoint( 358 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 359 | ) 360 | 361 | if len(attentions): 362 | paths = renew_attention_paths(attentions) 363 | meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} 364 | assign_to_checkpoint( 365 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 366 | ) 367 | 368 | resnet_0 = middle_blocks[0] 369 | attentions = middle_blocks[1] 370 | resnet_1 = 
middle_blocks[2] 371 | 372 | resnet_0_paths = renew_resnet_paths(resnet_0) 373 | assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) 374 | 375 | resnet_1_paths = renew_resnet_paths(resnet_1) 376 | assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) 377 | 378 | attentions_paths = renew_attention_paths(attentions) 379 | meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} 380 | assign_to_checkpoint( 381 | attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 382 | ) 383 | 384 | for i in range(num_output_blocks): 385 | block_id = i // (config["layers_per_block"] + 1) 386 | layer_in_block_id = i % (config["layers_per_block"] + 1) 387 | output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] 388 | output_block_list = {} 389 | 390 | for layer in output_block_layers: 391 | layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) 392 | if layer_id in output_block_list: 393 | output_block_list[layer_id].append(layer_name) 394 | else: 395 | output_block_list[layer_id] = [layer_name] 396 | 397 | if len(output_block_list) > 1: 398 | resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] 399 | attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] 400 | 401 | resnet_0_paths = renew_resnet_paths(resnets) 402 | paths = renew_resnet_paths(resnets) 403 | 404 | meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} 405 | assign_to_checkpoint( 406 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 407 | ) 408 | 409 | if ["conv.weight", "conv.bias"] in output_block_list.values(): 410 | index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) 411 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ 412 | f"output_blocks.{i}.{index}.conv.weight" 413 | ] 414 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ 415 | f"output_blocks.{i}.{index}.conv.bias" 416 | ] 417 | 418 | # Clear attentions as they have been attributed above. 419 | if len(attentions) == 2: 420 | attentions = [] 421 | 422 | if len(attentions): 423 | paths = renew_attention_paths(attentions) 424 | meta_path = { 425 | "old": f"output_blocks.{i}.1", 426 | "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", 427 | } 428 | assign_to_checkpoint( 429 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config 430 | ) 431 | else: 432 | resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) 433 | for path in resnet_0_paths: 434 | old_path = ".".join(["output_blocks", str(i), path["old"]]) 435 | new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) 436 | 437 | new_checkpoint[new_path] = unet_state_dict[old_path] 438 | 439 | return new_checkpoint 440 | 441 | 442 | def convert_ldm_vae_checkpoint(checkpoint, config): 443 | # extract state dict for VAE 444 | vae_state_dict = {} 445 | vae_key = "first_stage_model." 
446 | keys = list(checkpoint.keys()) 447 | for key in keys: 448 | if key.startswith(vae_key): 449 | vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) 450 | 451 | new_checkpoint = {} 452 | 453 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] 454 | new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] 455 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] 456 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] 457 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] 458 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] 459 | 460 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] 461 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] 462 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] 463 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] 464 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] 465 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] 466 | 467 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] 468 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] 469 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] 470 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] 471 | 472 | # Retrieves the keys for the encoder down blocks only 473 | num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) 474 | down_blocks = { 475 | layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) 476 | } 477 | 478 | # Retrieves the keys for the decoder up blocks only 479 | num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) 480 | up_blocks = { 481 | layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) 482 | } 483 | 484 | for i in range(num_down_blocks): 485 | resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] 486 | 487 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: 488 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( 489 | f"encoder.down.{i}.downsample.conv.weight" 490 | ) 491 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( 492 | f"encoder.down.{i}.downsample.conv.bias" 493 | ) 494 | 495 | paths = renew_vae_resnet_paths(resnets) 496 | meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} 497 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 498 | 499 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] 500 | num_mid_res_blocks = 2 501 | for i in range(1, num_mid_res_blocks + 1): 502 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] 503 | 504 | paths = renew_vae_resnet_paths(resnets) 505 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} 506 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], 
config=config) 507 | 508 | mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] 509 | paths = renew_vae_attention_paths(mid_attentions) 510 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 511 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 512 | conv_attn_to_linear(new_checkpoint) 513 | 514 | for i in range(num_up_blocks): 515 | block_id = num_up_blocks - 1 - i 516 | resnets = [ 517 | key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key 518 | ] 519 | 520 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: 521 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ 522 | f"decoder.up.{block_id}.upsample.conv.weight" 523 | ] 524 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ 525 | f"decoder.up.{block_id}.upsample.conv.bias" 526 | ] 527 | 528 | paths = renew_vae_resnet_paths(resnets) 529 | meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} 530 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 531 | 532 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] 533 | num_mid_res_blocks = 2 534 | for i in range(1, num_mid_res_blocks + 1): 535 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] 536 | 537 | paths = renew_vae_resnet_paths(resnets) 538 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} 539 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 540 | 541 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] 542 | paths = renew_vae_attention_paths(mid_attentions) 543 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} 544 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) 545 | conv_attn_to_linear(new_checkpoint) 546 | return new_checkpoint 547 | 548 | 549 | def convert_ldm_bert_checkpoint(checkpoint, config): 550 | def _copy_attn_layer(hf_attn_layer, pt_attn_layer): 551 | hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight 552 | hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight 553 | hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight 554 | 555 | hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight 556 | hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias 557 | 558 | def _copy_linear(hf_linear, pt_linear): 559 | hf_linear.weight = pt_linear.weight 560 | hf_linear.bias = pt_linear.bias 561 | 562 | def _copy_layer(hf_layer, pt_layer): 563 | # copy layer norms 564 | _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) 565 | _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) 566 | 567 | # copy attn 568 | _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) 569 | 570 | # copy MLP 571 | pt_mlp = pt_layer[1][1] 572 | _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) 573 | _copy_linear(hf_layer.fc2, pt_mlp.net[2]) 574 | 575 | def _copy_layers(hf_layers, pt_layers): 576 | for i, hf_layer in enumerate(hf_layers): 577 | if i != 0: 578 | i += i 579 | pt_layer = pt_layers[i : i + 2] 580 | _copy_layer(hf_layer, pt_layer) 581 | 582 | hf_model = LDMBertModel(config).eval() 583 | 584 | # copy embeds 585 | hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight 586 | 
hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight 587 | 588 | # copy layer norm 589 | _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) 590 | 591 | # copy hidden layers 592 | _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) 593 | 594 | _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) 595 | 596 | return hf_model 597 | 598 | 599 | def convert_ldm_clip_checkpoint(checkpoint): 600 | text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") 601 | 602 | keys = list(checkpoint.keys()) 603 | 604 | text_model_dict = {} 605 | 606 | for key in keys: 607 | if key.startswith("cond_stage_model.transformer"): 608 | text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] 609 | 610 | text_model.load_state_dict(text_model_dict) 611 | 612 | return text_model 613 | 614 | import os 615 | def convert_checkpoint(checkpoint_path, inpainting=False): 616 | parser = argparse.ArgumentParser() 617 | 618 | parser.add_argument( 619 | "--checkpoint_path", default=checkpoint_path, type=str, help="Path to the checkpoint to convert." 620 | ) 621 | # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml 622 | parser.add_argument( 623 | "--original_config_file", 624 | default=None, 625 | type=str, 626 | help="The YAML config file corresponding to the original architecture.", 627 | ) 628 | parser.add_argument( 629 | "--scheduler_type", 630 | default="pndm", 631 | type=str, 632 | help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']", 633 | ) 634 | parser.add_argument("--dump_path", default=None, type=str, help="Path to the output model.") 635 | 636 | args = parser.parse_args([]) 637 | if args.original_config_file is None: 638 | if inpainting: 639 | args.original_config_file = "./models/v1-inpainting-inference.yaml" 640 | else: 641 | args.original_config_file = "./models/v1-inference.yaml" 642 | 643 | original_config = OmegaConf.load(args.original_config_file) 644 | checkpoint = torch.load(args.checkpoint_path)["state_dict"] 645 | 646 | num_train_timesteps = original_config.model.params.timesteps 647 | beta_start = original_config.model.params.linear_start 648 | beta_end = original_config.model.params.linear_end 649 | if args.scheduler_type == "pndm": 650 | scheduler = PNDMScheduler( 651 | beta_end=beta_end, 652 | beta_schedule="scaled_linear", 653 | beta_start=beta_start, 654 | num_train_timesteps=num_train_timesteps, 655 | skip_prk_steps=True, 656 | ) 657 | elif args.scheduler_type == "lms": 658 | scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear") 659 | elif args.scheduler_type == "ddim": 660 | scheduler = DDIMScheduler( 661 | beta_start=beta_start, 662 | beta_end=beta_end, 663 | beta_schedule="scaled_linear", 664 | clip_sample=False, 665 | set_alpha_to_one=False, 666 | ) 667 | else: 668 | raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!") 669 | 670 | # Convert the UNet2DConditionModel model. 671 | unet_config = create_unet_diffusers_config(original_config) 672 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config) 673 | 674 | unet = UNet2DConditionModel(**unet_config) 675 | unet.load_state_dict(converted_unet_checkpoint) 676 | 677 | # Convert the VAE model. 
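# The VAE weights (stored under the "first_stage_model." prefix in the original checkpoint) are extracted and renamed into diffusers' AutoencoderKL layout by convert_ldm_vae_checkpoint() below, mirroring the UNet conversion above.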
678 | vae_config = create_vae_diffusers_config(original_config) 679 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) 680 | 681 | vae = AutoencoderKL(**vae_config) 682 | vae.load_state_dict(converted_vae_checkpoint) 683 | 684 | # Convert the text model. 685 | text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] 686 | if text_model_type == "FrozenCLIPEmbedder": 687 | text_model = convert_ldm_clip_checkpoint(checkpoint) 688 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") 689 | safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") 690 | feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") 691 | pipe = StableDiffusionPipeline( 692 | vae=vae, 693 | text_encoder=text_model, 694 | tokenizer=tokenizer, 695 | unet=unet, 696 | scheduler=scheduler, 697 | safety_checker=safety_checker, 698 | feature_extractor=feature_extractor, 699 | ) 700 | else: 701 | text_config = create_ldm_bert_config(original_config) 702 | text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) 703 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 704 | pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) 705 | 706 | return pipe 707 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | sd-infinity: 3 | build: 4 | context: . 5 | dockerfile: ./docker/Dockerfile 6 | #shm_size: '2gb' # Enable if more shared memory is needed 7 | ports: 8 | - "8888:8888" 9 | volumes: 10 | - user_home:/home/user 11 | - cond_env:/opt/conda/envs 12 | deploy: 13 | resources: 14 | reservations: 15 | devices: 16 | - driver: nvidia 17 | device_ids: ['0'] 18 | capabilities: [gpu] 19 | 20 | volumes: 21 | user_home: {} 22 | cond_env: {} -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM continuumio/miniconda3:4.12.0 2 | 3 | RUN apt-get update && \ 4 | apt install -y \ 5 | fonts-dejavu-core \ 6 | build-essential \ 7 | libopencv-dev \ 8 | cmake \ 9 | vim \ 10 | && apt-get clean 11 | 12 | COPY docker/opencv.pc /usr/lib/pkgconfig/opencv.pc 13 | 14 | RUN useradd -ms /bin/bash user 15 | USER user 16 | 17 | RUN mkdir ~/.huggingface && conda init bash 18 | 19 | COPY --chown=user:user . /app 20 | WORKDIR /app 21 | 22 | EXPOSE 8888 23 | CMD ["/app/docker/entrypoint.sh"] -------------------------------------------------------------------------------- /docker/docker-run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd .. 4 | echo Current dir: "$(pwd)" 5 | 6 | if ! docker version | grep 'linux/amd64' ; then 7 | echo "Could not find docker." 8 | exit 1 9 | fi 10 | 11 | if ! docker-compose version | grep v2 ; then 12 | echo "docker-compose v2.x is not installed" 13 | exit 1 14 | fi 15 | 16 | 17 | if ! docker run -it --gpus=all --rm nvidia/cuda:11.4.2-base-ubuntu20.04 nvidia-smi | grep -e 'NVIDIA.*On' ; then 18 | echo "Docker could not find your NVIDIA gpu" 19 | exit 1 20 | fi 21 | 22 | if ! 
docker compose build ; then 23 | echo "Error while building" 24 | exit 1 25 | fi 26 | docker compose up -------------------------------------------------------------------------------- /docker/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd /app 4 | 5 | set -euxo pipefail 6 | 7 | set -x 8 | 9 | if ! conda env list | grep sd-inf ; then 10 | echo "Creating environment, it may appear to freeze for a few minutes..." 11 | conda env create -f environment.yml 12 | echo "Finished installing." 13 | echo "conda activate sd-inf" >> ~/.bashrc 14 | shasum environment.yml > ~/.environment.sha 15 | fi 16 | 17 | . "/opt/conda/etc/profile.d/conda.sh" 18 | conda activate sd-inf 19 | 20 | if shasum -c ~/.environment.sha > /dev/null 2>&1 ; then 21 | echo "environment.yml is unchanged." 22 | else 23 | echo "environment.yml was changed, please wait a minute until it says 'Done updating'..." 24 | conda env update --file environment.yml 25 | shasum environment.yml > ~/.environment.sha 26 | echo "Done updating." 27 | fi 28 | 29 | python app.py --port=8888 --host=0.0.0.0 30 | -------------------------------------------------------------------------------- /docker/opencv.pc: -------------------------------------------------------------------------------- 1 | prefix=/usr 2 | exec_prefix=${prefix} 3 | includedir=${prefix}/include 4 | libdir=${exec_prefix}/lib 5 | 6 | Name: opencv 7 | Description: The opencv library 8 | Version: 2.x.x 9 | Cflags: -I${includedir}/opencv4 10 | #Cflags: -I${includedir}/opencv -I${includedir}/opencv2 11 | Libs: -L${libdir} -lopencv_calib3d -lopencv_imgproc -lopencv_xobjdetect -lopencv_hdf -lopencv_flann -lopencv_core -lopencv_dpm -lopencv_videoio -lopencv_reg -lopencv_quality -lopencv_tracking -lopencv_dnn_superres -lopencv_objdetect -lopencv_stitching -lopencv_saliency -lopencv_intensity_transform -lopencv_rapid -lopencv_dnn -lopencv_features2d -lopencv_text -lopencv_calib3d -lopencv_line_descriptor -lopencv_superres -lopencv_ml -lopencv_alphamat -lopencv_viz -lopencv_optflow -lopencv_videostab -lopencv_bioinspired -lopencv_highgui -lopencv_img_hash -lopencv_freetype -lopencv_imgcodecs -lopencv_mcc -lopencv_video -lopencv_photo -lopencv_surface_matching -lopencv_rgbd -lopencv_datasets -lopencv_ximgproc -lopencv_plot -lopencv_face -lopencv_stereo -lopencv_aruco -lopencv_dnn_objdetect -lopencv_phase_unwrapping -lopencv_bgsegm -lopencv_ccalib -lopencv_hfs -lopencv_imgproc -lopencv_shape -lopencv_xphoto -lopencv_structured_light -lopencv_fuzzy -------------------------------------------------------------------------------- /docker/run-shell.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd "$(dirname $0)" 4 | 5 | docker-compose run -p 8888:8888 --rm -u root sd-infinity bash 6 | -------------------------------------------------------------------------------- /docs/run_with_docker.md: -------------------------------------------------------------------------------- 1 | 2 | # Running with Docker on Windows or Linux with NVIDIA GPU 3 | On Windows 10 or 11 you can follow this guide to setting up Docker with WSL2 https://www.youtube.com/watch?v=PB7zM3JrgkI 4 | 5 | Native Linux 6 | 7 | ``` 8 | cd stablediffusion-infinity/docker 9 | ./docker-run.sh 10 | ``` 11 | 12 | Windows 10,11 with WSL2 shell: 13 | - open windows Command Prompt, type "bash" 14 | - once in bash, type: 15 | ``` 16 | cd /mnt/c/PATH-TO-YOUR/stablediffusion-infinity/docker 17 | ./docker-run.sh 18 | 
``` 19 | 20 | Open "http://localhost:8888" in your browser (even though the log says http://0.0.0.0:8888). -------------------------------------------------------------------------------- /docs/setup_guide.md: -------------------------------------------------------------------------------- 1 | # Setup Guide 2 | 3 | Please install conda first ([miniconda](https://docs.conda.io/en/latest/miniconda.html) or [anaconda](https://docs.anaconda.com/anaconda/install/)). 4 | 5 | - [Setup with Linux/Nvidia GPU](#setup-with-linuxnvidia-gpu) 6 | - [Setup with Linux/AMD GPU](#setup-with-linuxamd-gpu-untested) 7 | - [Setup with Windows](#setup-with-windows) 8 | - [Setup with MacOS](#setup-with-macos) 9 | - [Upgrade from previous version](#upgrade) 10 | 11 | ## Setup with Linux/Nvidia GPU 12 | 13 | ### conda env 14 | Set up with `environment.yml`: 15 | ``` 16 | git clone --recurse-submodules https://github.com/lkwq007/stablediffusion-infinity 17 | cd stablediffusion-infinity 18 | conda env create -f environment.yml 19 | ``` 20 | 21 | If `environment.yml` doesn't work for you, you can install the dependencies manually: 22 | ``` 23 | conda create -n sd-inf python=3.10 24 | conda activate sd-inf 25 | conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia 26 | conda install scipy scikit-image 27 | conda install -c conda-forge diffusers transformers ftfy accelerate 28 | pip install opencv-python 29 | pip install -U gradio 30 | pip install pytorch-lightning==1.7.7 einops==0.4.1 omegaconf==2.2.3 31 | pip install timm 32 | ``` 33 | 34 | After setting up the environment, you can run stablediffusion-infinity with the following commands: 35 | ``` 36 | conda activate sd-inf 37 | python app.py 38 | ``` 39 | 40 | ## Setup with Linux/AMD GPU (untested) 41 | 42 | ``` 43 | conda create -n sd-inf python=3.10 44 | conda activate sd-inf 45 | pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2 46 | conda install scipy scikit-image 47 | conda install -c conda-forge diffusers transformers ftfy accelerate 48 | pip install opencv-python 49 | pip install -U gradio 50 | pip install pytorch-lightning==1.7.7 einops==0.4.1 omegaconf==2.2.3 51 | pip install timm 52 | ``` 53 | 54 | 55 | ### CPP library (optional) 56 | 57 | Note that the `opencv` library (e.g. `libopencv-dev`/`opencv-devel`; the package name may differ across distributions) is required for `PyPatchMatch`. You may need to install `opencv` yourself. If `opencv` is not installed, the `patch_match` option (which usually gives better quality) won't work. 58 | 59 | ## Setup with Windows 60 | 61 | 62 | ``` 63 | conda create -n sd-inf python=3.10 64 | conda activate sd-inf 65 | conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia 66 | conda install scipy scikit-image 67 | conda install -c conda-forge diffusers transformers ftfy accelerate 68 | pip install opencv-python 69 | pip install -U gradio 70 | pip install pytorch-lightning==1.7.7 einops==0.4.1 omegaconf==2.2.3 71 | pip install timm 72 | ``` 73 | 74 | If you use an AMD GPU, you need to install the ONNX runtime with `pip install onnxruntime-directml` (this only works with the `stablediffusion-inpainting` model and is untested on AMD devices).
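To quickly confirm that the freshly created environment can actually see your GPU before launching the app, you can run the following short check in a Python shell inside the activated `sd-inf` environment (a generic PyTorch sanity check, not part of the project's own scripts):
```
# generic check: verifies that this environment's PyTorch build detects a CUDA GPU
import torch
print(torch.__version__)
print(torch.cuda.is_available())  # should print True on a working NVIDIA/CUDA setup
```
If this prints `False`, PyTorch cannot see a CUDA device; fix the driver/PyTorch installation before running `python app.py`.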
75 | 76 | For Windows, you may need to replace `pip install opencv-python` with `conda install -c conda-forge opencv` 77 | 78 | After setting up the environment, you can run stablediffusion-infinity with the following commands: 79 | ``` 80 | conda activate sd-inf 81 | python app.py 82 | ``` 83 | ## Setup with MacOS 84 | 85 | ### conda env 86 | ``` 87 | conda create -n sd-inf python=3.10 88 | conda activate sd-inf 89 | conda install pytorch torchvision torchaudio -c pytorch-nightly 90 | conda install scipy scikit-image 91 | conda install -c conda-forge diffusers transformers ftfy accelerate 92 | pip install opencv-python 93 | pip install -U gradio 94 | pip install pytorch-lightning==1.7.7 einops==0.4.1 omegaconf==2.2.3 95 | pip install timm 96 | ``` 97 | 98 | After setting up the environment, you can run stablediffusion-infinity with the following commands: 99 | ``` 100 | conda activate sd-inf 101 | python app.py 102 | ``` 103 | ### CPP library (optional) 104 | 105 | Note that the `opencv` library is required for `PyPatchMatch`. You may need to install `opencv` yourself (via `homebrew` or by compiling from source). If `opencv` is not installed, the `patch_match` option (which usually gives better quality) won't work. 106 | 107 | ## Upgrade 108 | 109 | ``` 110 | conda install -c conda-forge diffusers transformers ftfy accelerate 111 | conda update -c conda-forge diffusers transformers ftfy accelerate 112 | pip install -U gradio 113 | ``` -------------------------------------------------------------------------------- /docs/usage.md: -------------------------------------------------------------------------------- 1 | # Usage 2 | 3 | ## Models 4 | 5 | - stablediffusion-inpainting: `runwayml/stable-diffusion-inpainting`, does not support img2img mode 6 | - stablediffusion-inpainting+img2img-v1.5: `runwayml/stable-diffusion-inpainting` + `runwayml/stable-diffusion-v1-5`, supports img2img mode, requires more vRAM 7 | - stablediffusion-v1.5: `runwayml/stable-diffusion-v1-5`, inpainting with `diffusers`'s legacy pipeline, low quality for outpainting, supports img2img mode 8 | - stablediffusion-v1.4: `CompVis/stable-diffusion-v1-4`, inpainting with `diffusers`'s legacy pipeline, low quality for outpainting, supports img2img mode 9 | 10 | ## Loading local model 11 | 12 | Note that when loading a local checkpoint, you have to specify the correct model choice before setup. 13 | ```shell 14 | python app.py --local_model path_to_local_model 15 | # e.g. 16 | # diffusers model weights 17 | python app.py --local_model ./models/runwayml/stable-diffusion-inpainting 18 | python app.py --local_model models/CompVis/stable-diffusion-v1-4/model_index.json 19 | # original model checkpoint 20 | python app.py --local_model /home/user/checkpoint/model.ckpt 21 | ``` 22 | 23 | ## Loading remote model 24 | 25 | Note that when loading a remote model, you have to specify the correct model choice before setup. 26 | ```shell 27 | python app.py --remote_model model_name 28 | # e.g. 29 | python app.py --remote_model hakurei/waifu-diffusion-v1-3 30 | ``` 31 | 32 | ## Using textual inversion embeddings 33 | 34 | Put `*.bin` files inside the `embeddings` directory. 35 | 36 | ## Using a dreambooth finetuned model 37 | 38 | ``` 39 | python app.py --remote_model model_name 40 | # e.g.
41 | python app.py --remote_model sd-dreambooth-library/pikachu 42 | # or download the weight/checkpoint and load with 43 | python app.py --local_model path_to_model 44 | ``` 45 | 46 | ## Model Path for Docker users 47 | 48 | Docker users can specify a local model path or remote model name within the web app. 49 | 50 | ## Using fp32 mode or low vRAM mode (some GPUs might not work well with fp16) 51 | 52 | ```shell 53 | python app.py --fp32 --lowvram 54 | ``` 55 | 56 | ## HTTPS 57 | 58 | ```shell 59 | python app.py --encrypt --ssl_keyfile path_to_ssl_keyfile --ssl_certfile path_to_ssl_certfile 60 | ``` 61 | 62 | ## Keyboard shortcut 63 | 64 | Shortcuts can be configured via `config.yaml`. Currently only `[key]` or `[Ctrl]` + `[key]` combinations are supported. 65 | 66 | Default shortcuts are: 67 | 68 | ```yaml 69 | shortcut: 70 | clear: Escape 71 | load: Ctrl+o 72 | save: Ctrl+s 73 | export: Ctrl+e 74 | upload: Ctrl+u 75 | selection: 1 76 | canvas: 2 77 | eraser: 3 78 | outpaint: d 79 | accept: a 80 | cancel: c 81 | retry: r 82 | prev: q 83 | next: e 84 | zoom_in: z 85 | zoom_out: x 86 | random_seed: s 87 | ``` 88 | 89 | ## Glossary 90 | 91 | (From diffusers' documentation: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) 92 | - prompt: The prompt to guide the image generation. 93 | - step: The number of denoising steps. More denoising steps usually lead to a higher-quality image at the expense of slower inference. 94 | - guidance_scale: Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2 of the [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. 95 | - negative_prompt: The prompt or prompts not to guide the image generation. 96 | - Sample number: The number of images to generate per prompt. 97 | - scheduler: A scheduler is used in combination with `unet` to denoise the encoded image latents. 98 | - eta: Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to the DDIMScheduler and is ignored for other schedulers. 99 | - strength: For img2img only. Conceptually, indicates how much to transform the reference image.
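For readers who want to see how these options line up with the underlying `diffusers` API, the sketch below is a minimal illustrative example of an inpainting call. It is not code from this repository; the image file names (`selection.png`, `mask.png`) and the parameter values are placeholders, not the app's defaults.

```python
# Illustrative only: how the glossary parameters map onto a diffusers inpainting call.
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("selection.png").convert("RGB")  # region selected on the canvas (placeholder file)
mask_image = Image.open("mask.png").convert("RGB")       # white = area to be generated (placeholder file)

result = pipe(
    prompt="a lighthouse on a cliff, oil painting",  # prompt
    negative_prompt="blurry, low quality",           # negative_prompt
    image=init_image,
    mask_image=mask_image,
    num_inference_steps=50,                          # step
    guidance_scale=7.5,                              # guidance_scale
    num_images_per_prompt=2,                         # Sample number
)
result.images[0].save("outpaint_result.png")
```

In stablediffusion-infinity itself these values are set from the UI controls rather than hard-coded; the sketch is only meant to make the parameter names concrete.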
-------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: sd-inf 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - abseil-cpp=20211102.0=h27087fc_1 11 | - accelerate=0.14.0=pyhd8ed1ab_0 12 | - aiohttp=3.8.1=py310h5764c6d_1 13 | - aiosignal=1.3.1=pyhd8ed1ab_0 14 | - arrow-cpp=8.0.0=py310h3098874_0 15 | - async-timeout=4.0.2=pyhd8ed1ab_0 16 | - attrs=22.1.0=pyh71513ae_1 17 | - aws-c-common=0.4.57=he1b5a44_1 18 | - aws-c-event-stream=0.1.6=h72b8ae1_3 19 | - aws-checksums=0.1.9=h346380f_0 20 | - aws-sdk-cpp=1.8.185=hce553d0_0 21 | - backports=1.0=py_2 22 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 23 | - blas=1.0=mkl 24 | - blosc=1.21.0=h4ff587b_1 25 | - boost-cpp=1.78.0=he72f1d9_0 26 | - brotli=1.0.9=h5eee18b_7 27 | - brotli-bin=1.0.9=h5eee18b_7 28 | - brotlipy=0.7.0=py310h7f8727e_1002 29 | - brunsli=0.1=h2531618_0 30 | - bzip2=1.0.8=h7b6447c_0 31 | - c-ares=1.18.1=h7f8727e_0 32 | - ca-certificates=2022.10.11=h06a4308_0 33 | - certifi=2022.9.24=py310h06a4308_0 34 | - cffi=1.15.1=py310h74dc2b5_0 35 | - cfitsio=3.470=h5893167_7 36 | - charls=2.2.0=h2531618_0 37 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 38 | - click=8.1.3=unix_pyhd8ed1ab_2 39 | - cloudpickle=2.0.0=pyhd3eb1b0_0 40 | - colorama=0.4.6=pyhd8ed1ab_0 41 | - cryptography=38.0.1=py310h9ce1e76_0 42 | - cuda=11.6.2=0 43 | - cuda-cccl=11.6.55=hf6102b2_0 44 | - cuda-command-line-tools=11.6.2=0 45 | - cuda-compiler=11.6.2=0 46 | - cuda-cudart=11.6.55=he381448_0 47 | - cuda-cudart-dev=11.6.55=h42ad0f4_0 48 | - cuda-cuobjdump=11.6.124=h2eeebcb_0 49 | - cuda-cupti=11.6.124=h86345e5_0 50 | - cuda-cuxxfilt=11.6.124=hecbf4f6_0 51 | - cuda-driver-dev=11.6.55=0 52 | - cuda-gdb=11.8.86=0 53 | - cuda-libraries=11.6.2=0 54 | - cuda-libraries-dev=11.6.2=0 55 | - cuda-memcheck=11.8.86=0 56 | - cuda-nsight=11.8.86=0 57 | - cuda-nsight-compute=11.8.0=0 58 | - cuda-nvcc=11.6.124=hbba6d2d_0 59 | - cuda-nvdisasm=11.8.86=0 60 | - cuda-nvml-dev=11.6.55=haa9ef22_0 61 | - cuda-nvprof=11.8.87=0 62 | - cuda-nvprune=11.6.124=he22ec0a_0 63 | - cuda-nvrtc=11.6.124=h020bade_0 64 | - cuda-nvrtc-dev=11.6.124=h249d397_0 65 | - cuda-nvtx=11.6.124=h0630a44_0 66 | - cuda-nvvp=11.8.87=0 67 | - cuda-runtime=11.6.2=0 68 | - cuda-samples=11.6.101=h8efea70_0 69 | - cuda-sanitizer-api=11.8.86=0 70 | - cuda-toolkit=11.6.2=0 71 | - cuda-tools=11.6.2=0 72 | - cuda-visual-tools=11.6.2=0 73 | - cytoolz=0.12.0=py310h5eee18b_0 74 | - dask-core=2022.7.0=py310h06a4308_0 75 | - dataclasses=0.8=pyhc8e2a94_3 76 | - datasets=2.7.0=pyhd8ed1ab_0 77 | - diffusers=0.11.1=pyhd8ed1ab_0 78 | - dill=0.3.6=pyhd8ed1ab_1 79 | - ffmpeg=4.3=hf484d3e_0 80 | - fftw=3.3.9=h27cfd23_1 81 | - filelock=3.8.0=pyhd8ed1ab_0 82 | - freetype=2.12.1=h4a9f257_0 83 | - frozenlist=1.3.0=py310h5764c6d_1 84 | - fsspec=2022.10.0=py310h06a4308_0 85 | - ftfy=6.1.1=pyhd8ed1ab_0 86 | - gds-tools=1.4.0.31=0 87 | - gflags=2.2.2=he1b5a44_1004 88 | - giflib=5.2.1=h7b6447c_0 89 | - glog=0.6.0=h6f12383_0 90 | - gmp=6.2.1=h295c915_3 91 | - gnutls=3.6.15=he1e5248_0 92 | - grpc-cpp=1.46.1=h33aed49_0 93 | - huggingface_hub=0.11.0=pyhd8ed1ab_0 94 | - icu=70.1=h27087fc_0 95 | - idna=3.4=py310h06a4308_0 96 | - imagecodecs=2021.8.26=py310hecf7e94_1 97 | - imageio=2.19.3=py310h06a4308_0 98 | - importlib-metadata=5.0.0=pyha770c72_1 99 | - importlib_metadata=5.0.0=hd8ed1ab_1 100 | - 
intel-openmp=2021.4.0=h06a4308_3561 101 | - joblib=1.2.0=pyhd8ed1ab_0 102 | - jpeg=9e=h7f8727e_0 103 | - jxrlib=1.1=h7b6447c_2 104 | - krb5=1.19.2=hac12032_0 105 | - lame=3.100=h7b6447c_0 106 | - lcms2=2.12=h3be6417_0 107 | - ld_impl_linux-64=2.38=h1181459_1 108 | - lerc=3.0=h295c915_0 109 | - libaec=1.0.4=he6710b0_1 110 | - libbrotlicommon=1.0.9=h5eee18b_7 111 | - libbrotlidec=1.0.9=h5eee18b_7 112 | - libbrotlienc=1.0.9=h5eee18b_7 113 | - libcublas=11.11.3.6=0 114 | - libcublas-dev=11.11.3.6=0 115 | - libcufft=10.9.0.58=0 116 | - libcufft-dev=10.9.0.58=0 117 | - libcufile=1.4.0.31=0 118 | - libcufile-dev=1.4.0.31=0 119 | - libcurand=10.3.0.86=0 120 | - libcurand-dev=10.3.0.86=0 121 | - libcurl=7.85.0=h91b91d3_0 122 | - libcusolver=11.4.1.48=0 123 | - libcusolver-dev=11.4.1.48=0 124 | - libcusparse=11.7.5.86=0 125 | - libcusparse-dev=11.7.5.86=0 126 | - libdeflate=1.8=h7f8727e_5 127 | - libedit=3.1.20210910=h7f8727e_0 128 | - libev=4.33=h7f8727e_1 129 | - libevent=2.1.10=h9b69904_4 130 | - libffi=3.3=he6710b0_2 131 | - libgcc-ng=11.2.0=h1234567_1 132 | - libgfortran-ng=11.2.0=h00389a5_1 133 | - libgfortran5=11.2.0=h1234567_1 134 | - libgomp=11.2.0=h1234567_1 135 | - libiconv=1.16=h7f8727e_2 136 | - libidn2=2.3.2=h7f8727e_0 137 | - libnghttp2=1.46.0=hce63b2e_0 138 | - libnpp=11.8.0.86=0 139 | - libnpp-dev=11.8.0.86=0 140 | - libnvjpeg=11.9.0.86=0 141 | - libnvjpeg-dev=11.9.0.86=0 142 | - libpng=1.6.37=hbc83047_0 143 | - libprotobuf=3.20.1=h4ff587b_0 144 | - libssh2=1.10.0=h8f2d780_0 145 | - libstdcxx-ng=11.2.0=h1234567_1 146 | - libtasn1=4.16.0=h27cfd23_0 147 | - libthrift=0.15.0=he6d91bd_0 148 | - libtiff=4.4.0=hecacb30_2 149 | - libunistring=0.9.10=h27cfd23_0 150 | - libuuid=1.41.5=h5eee18b_0 151 | - libwebp=1.2.4=h11a3e52_0 152 | - libwebp-base=1.2.4=h5eee18b_0 153 | - libzopfli=1.0.3=he6710b0_0 154 | - locket=1.0.0=py310h06a4308_0 155 | - lz4-c=1.9.3=h295c915_1 156 | - mkl=2021.4.0=h06a4308_640 157 | - mkl-service=2.4.0=py310h7f8727e_0 158 | - mkl_fft=1.3.1=py310hd6ae3a3_0 159 | - mkl_random=1.2.2=py310h00e6091_0 160 | - multidict=6.0.2=py310h5764c6d_1 161 | - multiprocess=0.70.12.2=py310h5764c6d_2 162 | - ncurses=6.3=h5eee18b_3 163 | - nettle=3.7.3=hbbd107a_1 164 | - networkx=2.8.4=py310h06a4308_0 165 | - nsight-compute=2022.3.0.22=0 166 | - numpy=1.23.4=py310hd5efca6_0 167 | - numpy-base=1.23.4=py310h8e6c178_0 168 | - openh264=2.1.1=h4ff587b_0 169 | - openjpeg=2.4.0=h3ad879b_0 170 | - openssl=1.1.1s=h7f8727e_0 171 | - orc=1.7.4=h07ed6aa_0 172 | - packaging=21.3=pyhd3eb1b0_0 173 | - pandas=1.4.2=py310h769672d_1 174 | - partd=1.2.0=pyhd3eb1b0_1 175 | - pillow=9.2.0=py310hace64e9_1 176 | - pip=22.2.2=py310h06a4308_0 177 | - psutil=5.9.1=py310h5764c6d_0 178 | - pyarrow=8.0.0=py310h468efa6_0 179 | - pycparser=2.21=pyhd3eb1b0_0 180 | - pyopenssl=22.0.0=pyhd3eb1b0_0 181 | - pyparsing=3.0.9=py310h06a4308_0 182 | - pysocks=1.7.1=py310h06a4308_0 183 | - python=3.10.8=haa1d7c7_0 184 | - python-dateutil=2.8.2=pyhd8ed1ab_0 185 | - python-xxhash=3.0.0=py310h5764c6d_1 186 | - python_abi=3.10=2_cp310 187 | - pytorch=1.13.0=py3.10_cuda11.6_cudnn8.3.2_0 188 | - pytorch-cuda=11.6=h867d48c_0 189 | - pytorch-mutex=1.0=cuda 190 | - pytz=2022.6=pyhd8ed1ab_0 191 | - pywavelets=1.3.0=py310h7f8727e_0 192 | - re2=2022.04.01=h27087fc_0 193 | - readline=8.2=h5eee18b_0 194 | - regex=2022.4.24=py310h5764c6d_0 195 | - requests=2.28.1=py310h06a4308_0 196 | - responses=0.18.0=pyhd8ed1ab_0 197 | - sacremoses=0.0.53=pyhd8ed1ab_0 198 | - scikit-image=0.19.2=py310h00e6091_0 199 | - scipy=1.9.3=py310hd5efca6_0 200 | - 
setuptools=65.5.0=py310h06a4308_0 201 | - six=1.16.0=pyhd3eb1b0_1 202 | - snappy=1.1.9=h295c915_0 203 | - sqlite=3.39.3=h5082296_0 204 | - tifffile=2021.7.2=pyhd3eb1b0_2 205 | - tk=8.6.12=h1ccaba5_0 206 | - tokenizers=0.11.4=py310h3dcd8bd_1 207 | - toolz=0.12.0=py310h06a4308_0 208 | - torchaudio=0.13.0=py310_cu116 209 | - torchvision=0.14.0=py310_cu116 210 | - tqdm=4.64.1=pyhd8ed1ab_0 211 | - transformers=4.24.0=pyhd8ed1ab_0 212 | - typing-extensions=4.3.0=py310h06a4308_0 213 | - typing_extensions=4.3.0=py310h06a4308_0 214 | - tzdata=2022f=h04d1e81_0 215 | - urllib3=1.26.12=py310h06a4308_0 216 | - utf8proc=2.6.1=h27cfd23_0 217 | - wcwidth=0.2.5=pyh9f0ad1d_2 218 | - wheel=0.37.1=pyhd3eb1b0_0 219 | - xxhash=0.8.0=h7f98852_3 220 | - xz=5.2.6=h5eee18b_0 221 | - yaml=0.2.5=h7b6447c_0 222 | - yarl=1.7.2=py310h5764c6d_2 223 | - zfp=0.5.5=h295c915_6 224 | - zipp=3.10.0=pyhd8ed1ab_0 225 | - zlib=1.2.13=h5eee18b_0 226 | - zstd=1.5.2=ha4553b6_0 227 | - pip: 228 | - absl-py==1.3.0 229 | - antlr4-python3-runtime==4.9.3 230 | - anyio==3.6.2 231 | - bcrypt==4.0.1 232 | - cachetools==5.2.0 233 | - cmake==3.25.0 234 | - commonmark==0.9.1 235 | - contourpy==1.0.6 236 | - cycler==0.11.0 237 | - einops==0.4.1 238 | - fastapi==0.87.0 239 | - ffmpy==0.3.0 240 | - fonttools==4.38.0 241 | - fpie==0.2.4 242 | - google-auth==2.14.1 243 | - google-auth-oauthlib==0.4.6 244 | - gradio==3.10.1 245 | - grpcio==1.51.0 246 | - h11==0.12.0 247 | - httpcore==0.15.0 248 | - httpx==0.23.1 249 | - jinja2==3.1.2 250 | - kiwisolver==1.4.4 251 | - linkify-it-py==1.0.3 252 | - llvmlite==0.39.1 253 | - markdown==3.4.1 254 | - markdown-it-py==2.1.0 255 | - markupsafe==2.1.1 256 | - matplotlib==3.6.2 257 | - mdit-py-plugins==0.3.1 258 | - mdurl==0.1.2 259 | - numba==0.56.4 260 | - oauthlib==3.2.2 261 | - omegaconf==2.2.3 262 | - opencv-python==4.6.0.66 263 | - opencv-python-headless==4.6.0.66 264 | - orjson==3.8.2 265 | - paramiko==2.12.0 266 | - protobuf==3.20.3 267 | - pyasn1==0.4.8 268 | - pyasn1-modules==0.2.8 269 | - pycryptodome==3.15.0 270 | - pydantic==1.10.2 271 | - pydeprecate==0.3.2 272 | - pydub==0.25.1 273 | - pygments==2.13.0 274 | - pynacl==1.5.0 275 | - python-multipart==0.0.5 276 | - pytorch-lightning==1.7.7 277 | - pyyaml==6.0 278 | - requests-oauthlib==1.3.1 279 | - rfc3986==1.5.0 280 | - rich==12.6.0 281 | - rsa==4.9 282 | - sniffio==1.3.0 283 | - sourceinspect==0.0.4 284 | - starlette==0.21.0 285 | - taichi==1.2.2 286 | - tensorboard==2.11.0 287 | - tensorboard-data-server==0.6.1 288 | - tensorboard-plugin-wit==1.8.1 289 | - timm==0.6.11 290 | - torchmetrics==0.10.3 291 | - uc-micro-py==1.0.1 292 | - uvicorn==0.20.0 293 | - websockets==10.4 294 | - werkzeug==2.2.2 -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | Stablediffusion Infinity 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 31 | 32 | 33 | 34 |
(index.html, original lines 35-77: only bare line numbers remain here; the HTML markup on these lines carried no text content and appears to have been stripped when the file was flattened into this listing. The embedded py-env/PyScript code resumes below.)
78 | 79 | - numpy 80 | - Pillow 81 | - paths: 82 | - ./canvas.py 83 | 84 | 85 | 86 | from pyodide import to_js, create_proxy 87 | from PIL import Image 88 | import io 89 | import time 90 | import base64 91 | from collections import deque 92 | import numpy as np 93 | from js import ( 94 | console, 95 | document, 96 | parent, 97 | devicePixelRatio, 98 | ImageData, 99 | Uint8ClampedArray, 100 | CanvasRenderingContext2D as Context2d, 101 | requestAnimationFrame, 102 | window, 103 | encodeURIComponent, 104 | w2ui, 105 | update_eraser, 106 | update_scale, 107 | adjust_selection, 108 | update_count, 109 | enable_result_lst, 110 | setup_shortcut, 111 | update_undo_redo, 112 | ) 113 | 114 | 115 | from canvas import InfCanvas 116 | 117 | 118 | class History: 119 | def __init__(self,maxlen=10): 120 | self.idx=-1 121 | self.undo_lst=deque([],maxlen=maxlen) 122 | self.redo_lst=deque([],maxlen=maxlen) 123 | self.state=None 124 | 125 | def undo(self): 126 | cur=None 127 | if len(self.undo_lst): 128 | cur=self.undo_lst.pop() 129 | self.redo_lst.appendleft(cur) 130 | return cur 131 | def redo(self): 132 | cur=None 133 | if len(self.redo_lst): 134 | cur=self.redo_lst.popleft() 135 | self.undo_lst.append(cur) 136 | return cur 137 | 138 | def check(self): 139 | return len(self.undo_lst)>0,len(self.redo_lst)>0 140 | 141 | def append(self,state,update=True): 142 | self.redo_lst.clear() 143 | self.undo_lst.append(state) 144 | if update: 145 | update_undo_redo(*self.check()) 146 | 147 | history = History() 148 | 149 | base_lst = [None] 150 | async def draw_canvas() -> None: 151 | width=1024 152 | height=600 153 | canvas=InfCanvas(1024,600) 154 | update_eraser(canvas.eraser_size,min(canvas.selection_size_h,canvas.selection_size_w)) 155 | document.querySelector("#container").style.height= f"{height}px" 156 | document.querySelector("#container").style.width = f"{width}px" 157 | canvas.setup_mouse() 158 | canvas.clear_background() 159 | canvas.draw_buffer() 160 | canvas.draw_selection_box() 161 | base_lst[0]=canvas 162 | 163 | async def draw_canvas_func(event): 164 | try: 165 | app=parent.document.querySelector("gradio-app") 166 | if app.shadowRoot: 167 | app=app.shadowRoot 168 | width=app.querySelector("#canvas_width input").value 169 | height=app.querySelector("#canvas_height input").value 170 | selection_size=app.querySelector("#selection_size input").value 171 | except: 172 | width=1024 173 | height=768 174 | selection_size=384 175 | document.querySelector("#container").style.width = f"{width}px" 176 | document.querySelector("#container").style.height= f"{height}px" 177 | canvas=InfCanvas(int(width),int(height),selection_size=int(selection_size)) 178 | canvas.setup_mouse() 179 | canvas.clear_background() 180 | canvas.draw_buffer() 181 | canvas.draw_selection_box() 182 | base_lst[0]=canvas 183 | 184 | async def export_func(event): 185 | base=base_lst[0] 186 | arr=base.export() 187 | base.draw_buffer() 188 | base.canvas[2].clear() 189 | base64_str = base.numpy_to_base64(arr) 190 | time_str = time.strftime("%Y%m%d_%H%M%S") 191 | link = document.createElement("a") 192 | if len(event.data)>2 and event.data[2]: 193 | filename = event.data[2] 194 | else: 195 | filename = f"outpaint_{time_str}" 196 | # link.download = f"sdinf_state_{time_str}.json" 197 | link.download = f"{filename}.png" 198 | # link.download = f"outpaint_{time_str}.png" 199 | link.href = "data:image/png;base64,"+base64_str 200 | link.click() 201 | console.log(f"Canvas saved to {filename}.png") 202 | 203 | img_candidate_lst=[None,0] 204 | 205 | async def 
outpaint_func(event): 206 | base=base_lst[0] 207 | if len(event.data)==2: 208 | app=parent.document.querySelector("gradio-app") 209 | if app.shadowRoot: 210 | app=app.shadowRoot 211 | base64_str_raw=app.querySelector("#output textarea").value 212 | base64_str_lst=base64_str_raw.split(",") 213 | img_candidate_lst[0]=base64_str_lst 214 | img_candidate_lst[1]=0 215 | elif event.data[2]=="next": 216 | img_candidate_lst[1]+=1 217 | elif event.data[2]=="prev": 218 | img_candidate_lst[1]-=1 219 | enable_result_lst() 220 | if img_candidate_lst[0] is None: 221 | return 222 | lst=img_candidate_lst[0] 223 | idx=img_candidate_lst[1] 224 | update_count(idx%len(lst)+1,len(lst)) 225 | arr=base.base64_to_numpy(lst[idx%len(lst)]) 226 | base.fill_selection(arr) 227 | base.draw_selection_box() 228 | 229 | async def undo_func(event): 230 | base=base_lst[0] 231 | img_candidate_lst[0]=None 232 | if base.sel_dirty: 233 | base.sel_buffer = np.zeros((base.selection_size_h, base.selection_size_w, 4), dtype=np.uint8) 234 | base.sel_dirty = False 235 | base.canvas[2].clear() 236 | 237 | async def commit_func(event): 238 | base=base_lst[0] 239 | img_candidate_lst[0]=None 240 | if base.sel_dirty: 241 | base.write_selection_to_buffer() 242 | base.draw_buffer() 243 | base.canvas[2].clear() 244 | if len(event.data)>2: 245 | history.append(base.save()) 246 | 247 | async def history_undo_func(event): 248 | base=base_lst[0] 249 | if base.buffer_dirty or len(history.redo_lst)>0: 250 | state=history.undo() 251 | else: 252 | history.undo() 253 | state=history.undo() 254 | if state is not None: 255 | base.load(state) 256 | update_undo_redo(*history.check()) 257 | 258 | async def history_setup_func(event): 259 | base=base_lst[0] 260 | history.undo_lst.clear() 261 | history.redo_lst.clear() 262 | history.append(base.save(),update=False) 263 | 264 | async def history_redo_func(event): 265 | base=base_lst[0] 266 | if len(history.undo_lst)>0: 267 | state=history.redo() 268 | else: 269 | history.redo() 270 | state=history.redo() 271 | if state is not None: 272 | base.load(state) 273 | update_undo_redo(*history.check()) 274 | 275 | 276 | async def transfer_func(event): 277 | base=base_lst[0] 278 | base.read_selection_from_buffer() 279 | sel_buffer=base.sel_buffer 280 | sel_buffer_str=base.numpy_to_base64(sel_buffer) 281 | app=parent.document.querySelector("gradio-app") 282 | if app.shadowRoot: 283 | app=app.shadowRoot 284 | app.querySelector("#input textarea").value=sel_buffer_str 285 | app.querySelector("#proceed").click() 286 | 287 | async def upload_func(event): 288 | base=base_lst[0] 289 | # base64_str=event.data[1] 290 | base64_str=document.querySelector("#upload_content").value 291 | base64_str=base64_str.split(",")[-1] 292 | # base64_str=parent.document.querySelector("gradio-app").shadowRoot.querySelector("#upload textarea").value 293 | arr=base.base64_to_numpy(base64_str) 294 | h,w,c=base.buffer.shape 295 | base.sync_to_buffer() 296 | base.buffer_dirty=True 297 | mask=arr[:,:,3:4].repeat(4,axis=2) 298 | base.buffer[mask>0]=0 299 | # in case mismatch 300 | base.buffer[0:h,0:w,:]+=arr 301 | #base.buffer[yo:yo+h,xo:xo+w,0:3]=arr[:,:,0:3] 302 | #base.buffer[yo:yo+h,xo:xo+w,-1]=arr[:,:,-1] 303 | base.draw_buffer() 304 | if len(event.data)>2: 305 | history.append(base.save()) 306 | 307 | async def setup_shortcut_func(event): 308 | setup_shortcut(event.data[1]) 309 | 310 | 311 | document.querySelector("#export").addEventListener("click",create_proxy(export_func)) 312 | 
document.querySelector("#undo").addEventListener("click",create_proxy(undo_func)) 313 | document.querySelector("#commit").addEventListener("click",create_proxy(commit_func)) 314 | document.querySelector("#outpaint").addEventListener("click",create_proxy(outpaint_func)) 315 | document.querySelector("#upload").addEventListener("click",create_proxy(upload_func)) 316 | 317 | document.querySelector("#transfer").addEventListener("click",create_proxy(transfer_func)) 318 | document.querySelector("#draw").addEventListener("click",create_proxy(draw_canvas_func)) 319 | 320 | async def setup_func(): 321 | document.querySelector("#setup").value="1" 322 | 323 | async def reset_func(event): 324 | base=base_lst[0] 325 | base.reset() 326 | 327 | async def load_func(event): 328 | base=base_lst[0] 329 | base.load(event.data[1]) 330 | 331 | async def save_func(event): 332 | base=base_lst[0] 333 | json_str=base.save() 334 | time_str = time.strftime("%Y%m%d_%H%M%S") 335 | link = document.createElement("a") 336 | if len(event.data)>2 and event.data[2]: 337 | filename = str(event.data[2]).strip() 338 | else: 339 | filename = f"outpaint_{time_str}" 340 | # link.download = f"sdinf_state_{time_str}.json" 341 | link.download = f"{filename}.sdinf" 342 | link.href = "data:text/json;charset=utf-8,"+encodeURIComponent(json_str) 343 | link.click() 344 | 345 | async def prev_result_func(event): 346 | base=base_lst[0] 347 | base.reset() 348 | 349 | async def next_result_func(event): 350 | base=base_lst[0] 351 | base.reset() 352 | 353 | async def zoom_in_func(event): 354 | base=base_lst[0] 355 | scale=base.scale 356 | if scale>=0.2: 357 | scale-=0.1 358 | if len(event.data)>2: 359 | base.update_scale(scale,int(event.data[2]),int(event.data[3])) 360 | else: 361 | base.update_scale(scale) 362 | scale=base.scale 363 | update_scale(f"{base.width}x{base.height} ({round(100/scale)}%)") 364 | 365 | async def zoom_out_func(event): 366 | base=base_lst[0] 367 | scale=base.scale 368 | if scale<10: 369 | scale+=0.1 370 | console.log(len(event.data)) 371 | if len(event.data)>2: 372 | base.update_scale(scale,int(event.data[2]),int(event.data[3])) 373 | else: 374 | base.update_scale(scale) 375 | scale=base.scale 376 | update_scale(f"{base.width}x{base.height} ({round(100/scale)}%)") 377 | 378 | async def sync_func(event): 379 | base=base_lst[0] 380 | base.sync_to_buffer() 381 | base.canvas[2].clear() 382 | 383 | async def eraser_size_func(event): 384 | base=base_lst[0] 385 | eraser_size=min(int(event.data[1]),min(base.selection_size_h,base.selection_size_w)) 386 | eraser_size=max(8,eraser_size) 387 | base.eraser_size=eraser_size 388 | 389 | async def resize_selection_func(event): 390 | base=base_lst[0] 391 | cursor=base.cursor 392 | if len(event.data)>3: 393 | console.log(event.data) 394 | base.cursor[0]=int(event.data[1]) 395 | base.cursor[1]=int(event.data[2]) 396 | base.selection_size_w=int(event.data[3])//8*8 397 | base.selection_size_h=int(event.data[4])//8*8 398 | base.refine_selection() 399 | base.draw_selection_box() 400 | elif len(event.data)>2: 401 | base.draw_selection_box() 402 | else: 403 | base.canvas[-1].clear() 404 | adjust_selection(cursor[0],cursor[1],base.selection_size_w,base.selection_size_h) 405 | 406 | async def eraser_func(event): 407 | base=base_lst[0] 408 | if event.data[1]!="eraser": 409 | base.canvas[-2].clear() 410 | else: 411 | x,y=base.mouse_pos 412 | base.draw_eraser(x,y) 413 | 414 | async def resize_func(event): 415 | base=base_lst[0] 416 | width=int(event.data[1]) 417 | height=int(event.data[2]) 418 | if 
width>=256 and height>=256: 419 | if max(base.selection_size_h,base.selection_size_w)>min(width,height): 420 | base.selection_size_h=256 421 | base.selection_size_w=256 422 | base.resize(width,height) 423 | 424 | async def message_func(event): 425 | if event.data[0]=="click": 426 | if event.data[1]=="clear": 427 | await reset_func(event) 428 | elif event.data[1]=="save": 429 | await save_func(event) 430 | elif event.data[1]=="export": 431 | await export_func(event) 432 | elif event.data[1]=="accept": 433 | await commit_func(event) 434 | elif event.data[1]=="cancel": 435 | await undo_func(event) 436 | elif event.data[1]=="zoom_in": 437 | await zoom_in_func(event) 438 | elif event.data[1]=="zoom_out": 439 | await zoom_out_func(event) 440 | elif event.data[1]=="redo": 441 | await history_redo_func(event) 442 | elif event.data[1]=="undo": 443 | await history_undo_func(event) 444 | elif event.data[1]=="history": 445 | await history_setup_func(event) 446 | elif event.data[0]=="sync": 447 | await sync_func(event) 448 | elif event.data[0]=="load": 449 | await load_func(event) 450 | elif event.data[0]=="upload": 451 | await upload_func(event) 452 | elif event.data[0]=="outpaint": 453 | await outpaint_func(event) 454 | elif event.data[0]=="mode": 455 | if event.data[1]!="selection": 456 | await sync_func(event) 457 | await eraser_func(event) 458 | document.querySelector("#mode").value=event.data[1] 459 | elif event.data[0]=="transfer": 460 | await transfer_func(event) 461 | elif event.data[0]=="setup": 462 | await draw_canvas_func(event) 463 | elif event.data[0]=="eraser_size": 464 | await eraser_size_func(event) 465 | elif event.data[0]=="resize_selection": 466 | await resize_selection_func(event) 467 | elif event.data[0]=="shortcut": 468 | await setup_shortcut_func(event) 469 | elif event.data[0]=="resize": 470 | await resize_func(event) 471 | 472 | window.addEventListener("message",create_proxy(message_func)) 473 | 474 | import asyncio 475 | 476 | _ = await asyncio.gather( 477 | setup_func() 478 | ) 479 | 480 | 481 | 482 | 483 | -------------------------------------------------------------------------------- /interrogate.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2022 pharmapsychotic 5 | https://github.com/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb 6 | """ 7 | 8 | import numpy as np 9 | import os 10 | import torch 11 | import torchvision.transforms as T 12 | import torchvision.transforms.functional as TF 13 | 14 | from torch import nn 15 | from torch.nn import functional as F 16 | from torchvision import transforms 17 | from torchvision.transforms.functional import InterpolationMode 18 | from transformers import CLIPTokenizer, CLIPModel 19 | from transformers import CLIPProcessor, CLIPModel 20 | 21 | data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "blip_model", "data") 22 | def load_list(filename): 23 | with open(filename, 'r', encoding='utf-8', errors='replace') as f: 24 | items = [line.strip() for line in f.readlines()] 25 | return items 26 | 27 | artists = load_list(os.path.join(data_path, 'artists.txt')) 28 | flavors = load_list(os.path.join(data_path, 'flavors.txt')) 29 | mediums = load_list(os.path.join(data_path, 'mediums.txt')) 30 | movements = load_list(os.path.join(data_path, 'movements.txt')) 31 | 32 | sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 
'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central'] 33 | trending_list = [site for site in sites] 34 | trending_list.extend(["trending on "+site for site in sites]) 35 | trending_list.extend(["featured on "+site for site in sites]) 36 | trending_list.extend([site+" contest winner" for site in sites]) 37 | 38 | device="cpu" 39 | blip_image_eval_size = 384 40 | clip_name="openai/clip-vit-large-patch14" 41 | 42 | blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth' 43 | 44 | def generate_caption(blip_model, pil_image, device="cpu"): 45 | gpu_image = transforms.Compose([ 46 | transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), 47 | transforms.ToTensor(), 48 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 49 | ])(pil_image).unsqueeze(0).to(device) 50 | 51 | with torch.no_grad(): 52 | caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5) 53 | return caption[0] 54 | 55 | def rank(text_features, image_features, text_array, top_count=1): 56 | top_count = min(top_count, len(text_array)) 57 | similarity = torch.zeros((1, len(text_array))) 58 | for i in range(image_features.shape[0]): 59 | similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) 60 | similarity /= image_features.shape[0] 61 | 62 | top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1) 63 | return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)] 64 | 65 | class Interrogator: 66 | def __init__(self) -> None: 67 | self.tokenizer = CLIPTokenizer.from_pretrained(clip_name) 68 | try: 69 | self.get_blip() 70 | except: 71 | self.blip_model = None 72 | self.model = CLIPModel.from_pretrained(clip_name) 73 | self.processor = CLIPProcessor.from_pretrained(clip_name) 74 | self.text_feature_lst = [torch.load(os.path.join(data_path, f"{i}.pth")) for i in range(5)] 75 | 76 | def get_blip(self): 77 | from blip_model.blip import blip_decoder 78 | blip_model = blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base') 79 | blip_model.eval() 80 | self.blip_model = blip_model 81 | 82 | 83 | def interrogate(self,image,use_caption=False): 84 | if self.blip_model: 85 | caption = generate_caption(self.blip_model, image) 86 | else: 87 | caption = "" 88 | model,processor=self.model,self.processor 89 | bests = [[('',0)]]*5 90 | if True: 91 | print(f"Interrogating with {clip_name}...") 92 | 93 | inputs = processor(images=image, return_tensors="pt") 94 | with torch.no_grad(): 95 | image_features = model.get_image_features(**inputs) 96 | image_features /= image_features.norm(dim=-1, keepdim=True) 97 | ranks = [ 98 | rank(self.text_feature_lst[0], image_features, mediums), 99 | rank(self.text_feature_lst[1], image_features, ["by "+artist for artist in artists]), 100 | rank(self.text_feature_lst[2], image_features, trending_list), 101 | rank(self.text_feature_lst[3], image_features, movements), 102 | rank(self.text_feature_lst[4], image_features, flavors, top_count=3) 103 | ] 104 | 105 | for i in range(len(ranks)): 106 | confidence_sum = 0 107 | for ci in range(len(ranks[i])): 108 | confidence_sum += ranks[i][ci][1] 109 | if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))): 110 | bests[i] = ranks[i] 111 | 112 | flaves = ', '.join([f"{x[0]}" for x in bests[4]]) 113 | medium = bests[0][0][0] 114 | 
print(ranks) 115 | if caption.startswith(medium): 116 | return f"{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}" 117 | else: 118 | return f"{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}" 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /js/keyboard.js: -------------------------------------------------------------------------------- 1 | 2 | window.my_setup_keyboard=setInterval(function(){ 3 | let app=document.querySelector("gradio-app"); 4 | app=app.shadowRoot??app; 5 | let frame=app.querySelector("#sdinfframe").contentWindow; 6 | console.log("Check iframe..."); 7 | if(frame.setup_shortcut) 8 | { 9 | frame.setup_shortcut(json); 10 | clearInterval(window.my_setup_keyboard); 11 | } 12 | }, 1000); 13 | var config=JSON.parse(json); 14 | var key_map={}; 15 | Object.keys(config.shortcut).forEach(k=>{ 16 | key_map[config.shortcut[k]]=k; 17 | }); 18 | document.addEventListener("keydown", e => { 19 | if(e.target.tagName!="INPUT"&&e.target.tagName!="GRADIO-APP"&&e.target.tagName!="TEXTAREA") 20 | { 21 | let key=e.key; 22 | if(e.ctrlKey) 23 | { 24 | key="Ctrl+"+e.key; 25 | if(key in key_map) 26 | { 27 | e.preventDefault(); 28 | } 29 | } 30 | let app=document.querySelector("gradio-app"); 31 | app=app.shadowRoot??app; 32 | let frame=app.querySelector("#sdinfframe").contentDocument; 33 | frame.dispatchEvent( 34 | new KeyboardEvent("keydown", {key: e.key, ctrlKey: e.ctrlKey}) 35 | ); 36 | } 37 | }) -------------------------------------------------------------------------------- /js/mode.js: -------------------------------------------------------------------------------- 1 | function(mode){ 2 | let app=document.querySelector("gradio-app").shadowRoot; 3 | let frame=app.querySelector("#sdinfframe").contentWindow.document; 4 | frame.querySelector("#mode").value=mode; 5 | return mode; 6 | } -------------------------------------------------------------------------------- /js/outpaint.js: -------------------------------------------------------------------------------- 1 | function(a){ 2 | if(!window.my_observe_outpaint) 3 | { 4 | console.log("setup outpaint here"); 5 | window.my_observe_outpaint = new MutationObserver(function (event) { 6 | console.log(event); 7 | let app=document.querySelector("gradio-app"); 8 | app=app.shadowRoot??app; 9 | let frame=app.querySelector("#sdinfframe").contentWindow; 10 | frame.postMessage(["outpaint", ""], "*"); 11 | }); 12 | var app=document.querySelector("gradio-app"); 13 | app=app.shadowRoot??app; 14 | window.my_observe_outpaint_target=app.querySelector("#output span"); 15 | window.my_observe_outpaint.observe(window.my_observe_outpaint_target, { 16 | attributes: false, 17 | subtree: true, 18 | childList: true, 19 | characterData: true 20 | }); 21 | } 22 | return a; 23 | } -------------------------------------------------------------------------------- /js/proceed.js: -------------------------------------------------------------------------------- 1 | function(sel_buffer_str, 2 | prompt_text, 3 | negative_prompt_text, 4 | strength, 5 | guidance, 6 | step, 7 | resize_check, 8 | fill_mode, 9 | enable_safety, 10 | use_correction, 11 | enable_img2img, 12 | use_seed, 13 | seed_val, 14 | generate_num, 15 | scheduler, 16 | scheduler_eta, 17 | interrogate_mode, 18 | state){ 19 | let app=document.querySelector("gradio-app"); 20 | app=app.shadowRoot??app; 21 | sel_buffer=app.querySelector("#input textarea").value; 22 | let 
use_correction_bak=false; 23 | ({resize_check,enable_safety,enable_img2img,use_seed,seed_val,interrogate_mode}=window.config_obj); 24 | seed_val=Number(seed_val); 25 | return [ 26 | sel_buffer, 27 | prompt_text, 28 | negative_prompt_text, 29 | strength, 30 | guidance, 31 | step, 32 | resize_check, 33 | fill_mode, 34 | enable_safety, 35 | use_correction, 36 | enable_img2img, 37 | use_seed, 38 | seed_val, 39 | generate_num, 40 | scheduler, 41 | scheduler_eta, 42 | interrogate_mode, 43 | state, 44 | ] 45 | } -------------------------------------------------------------------------------- /js/setup.js: -------------------------------------------------------------------------------- 1 | function(token_val, width, height, size, model_choice, model_path){ 2 | let app=document.querySelector("gradio-app"); 3 | app=app.shadowRoot??app; 4 | app.querySelector("#sdinfframe").style.height=80+Number(height)+"px"; 5 | // app.querySelector("#setup_row").style.display="none"; 6 | app.querySelector("#model_path_input").style.display="none"; 7 | let frame=app.querySelector("#sdinfframe").contentWindow.document; 8 | 9 | if(frame.querySelector("#setup").value=="0") 10 | { 11 | window.my_setup=setInterval(function(){ 12 | let app=document.querySelector("gradio-app"); 13 | app=app.shadowRoot??app; 14 | let frame=app.querySelector("#sdinfframe").contentWindow.document; 15 | console.log("Check PyScript...") 16 | if(frame.querySelector("#setup").value=="1") 17 | { 18 | frame.querySelector("#draw").click(); 19 | clearInterval(window.my_setup); 20 | } 21 | }, 100) 22 | } 23 | else 24 | { 25 | frame.querySelector("#draw").click(); 26 | } 27 | return [token_val, width, height, size, model_choice, model_path]; 28 | } -------------------------------------------------------------------------------- /js/toolbar.js: -------------------------------------------------------------------------------- 1 | // import { w2ui,w2toolbar,w2field,query,w2alert, w2utils,w2confirm} from "https://rawgit.com/vitmalina/w2ui/master/dist/w2ui.es6.min.js" 2 | // import { w2ui,w2toolbar,w2field,query,w2alert, w2utils,w2confirm} from "https://cdn.jsdelivr.net/gh/vitmalina/w2ui@master/dist/w2ui.es6.min.js" 3 | 4 | // https://stackoverflow.com/questions/36280818/how-to-convert-file-to-base64-in-javascript 5 | function getBase64(file) { 6 | var reader = new FileReader(); 7 | reader.readAsDataURL(file); 8 | reader.onload = function () { 9 | add_image(reader.result); 10 | // console.log(reader.result); 11 | }; 12 | reader.onerror = function (error) { 13 | console.log("Error: ", error); 14 | }; 15 | } 16 | 17 | function getText(file) { 18 | var reader = new FileReader(); 19 | reader.readAsText(file); 20 | reader.onload = function () { 21 | window.postMessage(["load",reader.result],"*") 22 | // console.log(reader.result); 23 | }; 24 | reader.onerror = function (error) { 25 | console.log("Error: ", error); 26 | }; 27 | } 28 | 29 | document.querySelector("#upload_file").addEventListener("change", (event)=>{ 30 | console.log(event); 31 | let file = document.querySelector("#upload_file").files[0]; 32 | getBase64(file); 33 | }) 34 | 35 | document.querySelector("#upload_state").addEventListener("change", (event)=>{ 36 | console.log(event); 37 | let file = document.querySelector("#upload_state").files[0]; 38 | getText(file); 39 | }) 40 | 41 | open_setting = function() { 42 | if (!w2ui.foo) { 43 | new w2form({ 44 | name: "foo", 45 | style: "border: 0px; background-color: transparent;", 46 | fields: [{ 47 | field: "canvas_width", 48 | type: "int", 49 | required: 
true, 50 | html: { 51 | label: "Canvas Width" 52 | } 53 | }, 54 | { 55 | field: "canvas_height", 56 | type: "int", 57 | required: true, 58 | html: { 59 | label: "Canvas Height" 60 | } 61 | }, 62 | ], 63 | record: { 64 | canvas_width: 1200, 65 | canvas_height: 600, 66 | }, 67 | actions: { 68 | Save() { 69 | this.validate(); 70 | let record = this.getCleanRecord(); 71 | window.postMessage(["resize",record.canvas_width,record.canvas_height],"*"); 72 | w2popup.close(); 73 | }, 74 | custom: { 75 | text: "Cancel", 76 | style: "text-transform: uppercase", 77 | onClick(event) { 78 | w2popup.close(); 79 | } 80 | } 81 | } 82 | }); 83 | } 84 | w2popup.open({ 85 | title: "Form in a Popup", 86 | body: "
", 87 | style: "padding: 15px 0px 0px 0px", 88 | width: 500, 89 | height: 280, 90 | showMax: true, 91 | async onToggle(event) { 92 | await event.complete 93 | w2ui.foo.resize(); 94 | } 95 | }) 96 | .then((event) => { 97 | w2ui.foo.render("#form") 98 | }); 99 | } 100 | 101 | var button_lst=["clear", "load", "save", "export", "upload", "selection", "canvas", "eraser", "outpaint", "accept", "cancel", "retry", "prev", "current", "next", "eraser_size_btn", "eraser_size", "resize_selection", "scale", "zoom_in", "zoom_out", "help"]; 102 | var upload_button_lst=['clear', 'load', 'save', "upload", 'export', 'outpaint', 'resize_selection', 'help', "setting", "interrogate"]; 103 | var resize_button_lst=['clear', 'load', 'save', "upload", 'export', "selection", "canvas", "eraser", 'outpaint', 'resize_selection',"zoom_in", "zoom_out", 'help', "setting", "interrogate"]; 104 | var outpaint_button_lst=['clear', 'load', 'save', "canvas", "eraser", "upload", 'export', 'resize_selection', "zoom_in", "zoom_out",'help', "setting", "interrogate", "undo", "redo"]; 105 | var outpaint_result_lst=["accept", "cancel", "retry", "prev", "current", "next"]; 106 | var outpaint_result_func_lst=["accept", "retry", "prev", "current", "next"]; 107 | 108 | function check_button(id,text="",checked=true,tooltip="") 109 | { 110 | return { type: "check", id: id, text: text, icon: checked?"fa-solid fa-square-check":"fa-regular fa-square", checked: checked, tooltip: tooltip }; 111 | } 112 | 113 | var toolbar=new w2toolbar({ 114 | box: "#toolbar", 115 | name: "toolbar", 116 | tooltip: "top", 117 | items: [ 118 | { type: "button", id: "clear", text: "Reset", tooltip: "Reset Canvas", icon: "fa-solid fa-rectangle-xmark" }, 119 | { type: "break" }, 120 | { type: "button", id: "load", tooltip: "Load Canvas", icon: "fa-solid fa-file-import" }, 121 | { type: "button", id: "save", tooltip: "Save Canvas", icon: "fa-solid fa-file-export" }, 122 | { type: "button", id: "export", tooltip: "Export Image", icon: "fa-solid fa-floppy-disk" }, 123 | { type: "break" }, 124 | { type: "button", id: "upload", text: "Upload Image", icon: "fa-solid fa-upload" }, 125 | { type: "break" }, 126 | { type: "radio", id: "selection", group: "1", tooltip: "Selection", icon: "fa-solid fa-arrows-up-down-left-right", checked: true }, 127 | { type: "radio", id: "canvas", group: "1", tooltip: "Canvas", icon: "fa-solid fa-image" }, 128 | { type: "radio", id: "eraser", group: "1", tooltip: "Eraser", icon: "fa-solid fa-eraser" }, 129 | { type: "break" }, 130 | { type: "button", id: "outpaint", text: "Outpaint", tooltip: "Run Outpainting", icon: "fa-solid fa-brush" }, 131 | { type: "button", id: "interrogate", text: "Interrogate", tooltip: "Get a prompt with Clip Interrogator ", icon: "fa-solid fa-magnifying-glass" }, 132 | { type: "break" }, 133 | { type: "button", id: "accept", text: "Accept", tooltip: "Accept current result", icon: "fa-solid fa-check", hidden: true, disabled:true,}, 134 | { type: "button", id: "cancel", text: "Cancel", tooltip: "Cancel current outpainting/error", icon: "fa-solid fa-ban", hidden: true}, 135 | { type: "button", id: "retry", text: "Retry", tooltip: "Retry", icon: "fa-solid fa-rotate", hidden: true, disabled:true,}, 136 | { type: "button", id: "prev", tooltip: "Prev Result", icon: "fa-solid fa-caret-left", hidden: true, disabled:true,}, 137 | { type: "html", id: "current", hidden: true, disabled:true, 138 | async onRefresh(event) { 139 | await event.complete 140 | let fragment = query.html(` 141 |
142 |
143 | ${this.sel_value ?? "1/1"} 144 |
`) 145 | query(this.box).find("#tb_toolbar_item_current").append(fragment) 146 | } 147 | }, 148 | { type: "button", id: "next", tooltip: "Next Result", icon: "fa-solid fa-caret-right", hidden: true,disabled:true,}, 149 | { type: "button", id: "add_image", text: "Add Image", icon: "fa-solid fa-file-circle-plus", hidden: true,disabled:true,}, 150 | { type: "button", id: "delete_image", text: "Delete Image", icon: "fa-solid fa-trash-can", hidden: true,disabled:true,}, 151 | { type: "button", id: "confirm", text: "Confirm", icon: "fa-solid fa-check", hidden: true,disabled:true,}, 152 | { type: "button", id: "cancel_overlay", text: "Cancel", icon: "fa-solid fa-ban", hidden: true,disabled:true,}, 153 | { type: "break" }, 154 | { type: "spacer" }, 155 | { type: "break" }, 156 | { type: "button", id: "eraser_size_btn", tooltip: "Eraser Size", text:"Size", icon: "fa-solid fa-eraser", hidden: true, count: 32}, 157 | { type: "html", id: "eraser_size", hidden: true, 158 | async onRefresh(event) { 159 | await event.complete 160 | // let fragment = query.html(` 161 | // 162 | // `) 163 | let fragment = query.html(` 164 | 165 | `) 166 | fragment.filter("input").on("change", event => { 167 | this.eraser_size = event.target.value; 168 | window.overlay.freeDrawingBrush.width=this.eraser_size; 169 | this.setCount("eraser_size_btn", event.target.value); 170 | window.postMessage(["eraser_size", event.target.value],"*") 171 | this.refresh(); 172 | }) 173 | query(this.box).find("#tb_toolbar_item_eraser_size").append(fragment) 174 | } 175 | }, 176 | // { type: "button", id: "resize_eraser", tooltip: "Resize Eraser", icon: "fa-solid fa-sliders" }, 177 | { type: "button", id: "resize_selection", text: "Resize Selection", tooltip: "Resize Selection", icon: "fa-solid fa-expand" }, 178 | { type: "break" }, 179 | { type: "html", id: "scale", 180 | async onRefresh(event) { 181 | await event.complete 182 | let fragment = query.html(` 183 |
184 |
185 | ${this.scale_value ?? "100%"} 186 |
`) 187 | query(this.box).find("#tb_toolbar_item_scale").append(fragment) 188 | } 189 | }, 190 | { type: "button", id: "zoom_in", tooltip: "Zoom In", icon: "fa-solid fa-magnifying-glass-plus" }, 191 | { type: "button", id: "zoom_out", tooltip: "Zoom Out", icon: "fa-solid fa-magnifying-glass-minus" }, 192 | { type: "break" }, 193 | { type: "button", id: "help", tooltip: "Help", icon: "fa-solid fa-circle-info" }, 194 | { type: "new-line"}, 195 | { type: "button", id: "setting", text: "Canvas Setting", tooltip: "Resize Canvas Here", icon: "fa-solid fa-sliders" }, 196 | { type: "break" }, 197 | check_button("enable_history","Enable History:",false, "Enable Canvas History"), 198 | { type: "button", id: "undo", tooltip: "Undo last erasing/last outpainting", icon: "fa-solid fa-rotate-left", disabled: true }, 199 | { type: "button", id: "redo", tooltip: "Redo", icon: "fa-solid fa-rotate-right", disabled: true }, 200 | { type: "break" }, 201 | check_button("enable_img2img","Enable Img2Img",false), 202 | // check_button("use_correction","Photometric Correction",false), 203 | check_button("resize_check","Resize Small Input",true), 204 | check_button("enable_safety","Enable Safety Checker",true), 205 | check_button("square_selection","Square Selection Only",false), 206 | {type: "break"}, 207 | check_button("use_seed","Use Seed:",false), 208 | { type: "html", id: "seed_val", 209 | async onRefresh(event) { 210 | await event.complete 211 | let fragment = query.html(` 212 | `) 213 | fragment.filter("input").on("change", event => { 214 | this.config_obj.seed_val = event.target.value; 215 | parent.config_obj=this.config_obj; 216 | this.refresh(); 217 | }) 218 | query(this.box).find("#tb_toolbar_item_seed_val").append(fragment) 219 | } 220 | }, 221 | { type: "button", id: "random_seed", tooltip: "Set a random seed", icon: "fa-solid fa-dice" }, 222 | ], 223 | onClick(event) { 224 | switch(event.target){ 225 | case "setting": 226 | open_setting(); 227 | break; 228 | case "upload": 229 | this.upload_mode=true 230 | document.querySelector("#overlay_container").style.pointerEvents="auto"; 231 | this.click("canvas"); 232 | this.click("selection"); 233 | this.show("confirm","cancel_overlay","add_image","delete_image"); 234 | this.enable("confirm","cancel_overlay","add_image","delete_image"); 235 | this.disable(...upload_button_lst); 236 | this.disable("undo","redo") 237 | query("#upload_file").click(); 238 | if(this.upload_tip) 239 | { 240 | this.upload_tip=false; 241 | w2utils.notify("Note that only visible images will be added to canvas",{timeout:10000,where:query("#container")}) 242 | } 243 | break; 244 | case "resize_selection": 245 | this.resize_mode=true; 246 | this.disable(...resize_button_lst); 247 | this.enable("confirm","cancel_overlay"); 248 | this.show("confirm","cancel_overlay"); 249 | window.postMessage(["resize_selection",""],"*"); 250 | document.querySelector("#overlay_container").style.pointerEvents="auto"; 251 | break; 252 | case "confirm": 253 | if(this.upload_mode) 254 | { 255 | export_image(); 256 | } 257 | else 258 | { 259 | let sel_box=this.selection_box; 260 | if(sel_box.width*sel_box.height>512*512) 261 | { 262 | w2utils.notify("Note that the outpainting will be much slower when the area of selection is larger than 512x512",{timeout:2000,where:query("#container")}) 263 | } 264 | window.postMessage(["resize_selection",sel_box.x,sel_box.y,sel_box.width,sel_box.height],"*"); 265 | } 266 | case "cancel_overlay": 267 | end_overlay(); 268 | 
this.hide("confirm","cancel_overlay","add_image","delete_image"); 269 | if(this.upload_mode){ 270 | this.enable(...upload_button_lst); 271 | } 272 | else 273 | { 274 | this.enable(...resize_button_lst); 275 | window.postMessage(["resize_selection","",""],"*"); 276 | if(event.target=="cancel_overlay") 277 | { 278 | this.selection_box=this.selection_box_bak; 279 | } 280 | } 281 | if(this.selection_box) 282 | { 283 | this.setCount("resize_selection",`${Math.floor(this.selection_box.width/8)*8}x${Math.floor(this.selection_box.height/8)*8}`); 284 | } 285 | this.disable("confirm","cancel_overlay","add_image","delete_image"); 286 | this.upload_mode=false; 287 | this.resize_mode=false; 288 | this.click("selection"); 289 | window.update_undo_redo(window.undo_redo_state.undo, window.undo_redo_state.redo); 290 | break; 291 | case "add_image": 292 | query("#upload_file").click(); 293 | break; 294 | case "delete_image": 295 | let active_obj = window.overlay.getActiveObject(); 296 | if(active_obj) 297 | { 298 | window.overlay.remove(active_obj); 299 | window.overlay.renderAll(); 300 | } 301 | else 302 | { 303 | w2utils.notify("You need to select an image first",{error:true,timeout:2000,where:query("#container")}) 304 | } 305 | break; 306 | case "load": 307 | query("#upload_state").click(); 308 | this.selection_box=null; 309 | this.setCount("resize_selection",""); 310 | break; 311 | case "next": 312 | case "prev": 313 | window.postMessage(["outpaint", "", event.target], "*"); 314 | break; 315 | case "outpaint": 316 | this.click("selection"); 317 | this.disable(...outpaint_button_lst); 318 | this.show(...outpaint_result_lst); 319 | this.disable("undo","redo"); 320 | if(this.outpaint_tip) 321 | { 322 | this.outpaint_tip=false; 323 | w2utils.notify("The canvas stays locked until you accept/cancel current outpainting. 
You can modify the 'sample number' to get multiple results; you can resize the canvas/selection with 'canvas setting'/'resize selection'; you can use 'photometric correction' to help preserve existing contents",{timeout:15000,where:query("#container")}) 324 | } 325 | document.querySelector("#container").style.pointerEvents="none"; 326 | case "retry": 327 | this.disable(...outpaint_result_func_lst); 328 | parent.config_obj["interrogate_mode"]=false; 329 | window.postMessage(["transfer",""],"*") 330 | break; 331 | case "interrogate": 332 | if(this.interrogate_tip) 333 | { 334 | this.interrogate_tip=false; 335 | w2utils.notify("ClipInterrogator v1 will be dynamically loaded when run at the first time, which may take a while",{timeout:10000,where:query("#container")}) 336 | } 337 | parent.config_obj["interrogate_mode"]=true; 338 | window.postMessage(["transfer",""],"*") 339 | break 340 | case "accept": 341 | case "cancel": 342 | this.hide(...outpaint_result_lst); 343 | this.disable(...outpaint_result_func_lst); 344 | this.enable(...outpaint_button_lst); 345 | document.querySelector("#container").style.pointerEvents="auto"; 346 | if(this.config_obj.enable_history) 347 | { 348 | window.postMessage(["click", event.target, ""],"*"); 349 | } 350 | else 351 | { 352 | window.postMessage(["click", event.target],"*"); 353 | } 354 | let app=parent.document.querySelector("gradio-app"); 355 | app=app.shadowRoot??app; 356 | app.querySelector("#cancel").click(); 357 | window.update_undo_redo(window.undo_redo_state.undo, window.undo_redo_state.redo); 358 | break; 359 | case "eraser": 360 | case "selection": 361 | case "canvas": 362 | if(event.target=="eraser") 363 | { 364 | this.show("eraser_size","eraser_size_btn"); 365 | window.overlay.freeDrawingBrush.width=this.eraser_size; 366 | window.overlay.isDrawingMode = true; 367 | } 368 | else 369 | { 370 | this.hide("eraser_size","eraser_size_btn"); 371 | window.overlay.isDrawingMode = false; 372 | } 373 | if(this.upload_mode) 374 | { 375 | if(event.target=="canvas") 376 | { 377 | window.postMessage(["mode", event.target],"*") 378 | document.querySelector("#overlay_container").style.pointerEvents="none"; 379 | document.querySelector("#overlay_container").style.opacity = 0.5; 380 | } 381 | else 382 | { 383 | document.querySelector("#overlay_container").style.pointerEvents="auto"; 384 | document.querySelector("#overlay_container").style.opacity = 1.0; 385 | } 386 | } 387 | else 388 | { 389 | window.postMessage(["mode", event.target],"*") 390 | } 391 | break; 392 | case "help": 393 | w2popup.open({ 394 | title: "Document", 395 | body: "Usage: https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/usage.md" 396 | }) 397 | break; 398 | case "clear": 399 | w2confirm("Reset canvas?").yes(() => { 400 | window.postMessage(["click", event.target],"*"); 401 | }).no(() => {}) 402 | break; 403 | case "random_seed": 404 | this.config_obj.seed_val=Math.floor(Math.random() * 3000000000); 405 | parent.config_obj=this.config_obj; 406 | this.refresh(); 407 | break; 408 | case "enable_history": 409 | case "enable_img2img": 410 | case "use_correction": 411 | case "resize_check": 412 | case "enable_safety": 413 | case "use_seed": 414 | case "square_selection": 415 | let target=this.get(event.target); 416 | if(event.target=="enable_history") 417 | { 418 | if(!target.checked) 419 | { 420 | w2utils.notify("Enable canvas history might increase resource usage / slow down the canvas ", {error:true,timeout:3000,where:query("#container")}) 421 | 
window.postMessage(["click","history"],"*"); 422 | } 423 | else 424 | { 425 | window.undo_redo_state.undo=false; 426 | window.undo_redo_state.redo=false; 427 | this.disable("undo","redo"); 428 | } 429 | } 430 | target.icon=target.checked?"fa-regular fa-square":"fa-solid fa-square-check"; 431 | this.config_obj[event.target]=!target.checked; 432 | parent.config_obj=this.config_obj; 433 | this.refresh(); 434 | break; 435 | case "save": 436 | case "export": 437 | ask_filename(event.target); 438 | break; 439 | default: 440 | // clear, save, export, outpaint, retry 441 | // break, save, export, accept, retry, outpaint 442 | window.postMessage(["click", event.target],"*") 443 | } 444 | console.log("Target: "+ event.target, event) 445 | } 446 | }) 447 | window.w2ui=w2ui; 448 | w2ui.toolbar.config_obj={ 449 | resize_check: true, 450 | enable_safety: true, 451 | use_correction: false, 452 | enable_img2img: false, 453 | use_seed: false, 454 | seed_val: 0, 455 | square_selection: false, 456 | enable_history: false, 457 | }; 458 | w2ui.toolbar.outpaint_tip=true; 459 | w2ui.toolbar.upload_tip=true; 460 | w2ui.toolbar.interrogate_tip=true; 461 | window.update_count=function(cur,total){ 462 | w2ui.toolbar.sel_value=`${cur}/${total}`; 463 | w2ui.toolbar.refresh(); 464 | } 465 | window.update_eraser=function(val,max_val){ 466 | w2ui.toolbar.eraser_size=`${val}`; 467 | w2ui.toolbar.eraser_max=`${max_val}`; 468 | w2ui.toolbar.setCount("eraser_size_btn", `${val}`); 469 | w2ui.toolbar.refresh(); 470 | } 471 | window.update_scale=function(val){ 472 | w2ui.toolbar.scale_value=`${val}`; 473 | w2ui.toolbar.refresh(); 474 | } 475 | window.enable_result_lst=function(){ 476 | w2ui.toolbar.enable(...outpaint_result_lst); 477 | } 478 | function onObjectScaled(e) 479 | { 480 | let object = e.target; 481 | if(object.isType("rect")) 482 | { 483 | let width=object.getScaledWidth(); 484 | let height=object.getScaledHeight(); 485 | object.scale(1); 486 | width=Math.max(Math.min(width,window.overlay.width-object.left),256); 487 | height=Math.max(Math.min(height,window.overlay.height-object.top),256); 488 | let l=Math.max(Math.min(object.left,window.overlay.width-width-object.strokeWidth),0); 489 | let t=Math.max(Math.min(object.top,window.overlay.height-height-object.strokeWidth),0); 490 | if(window.w2ui.toolbar.config_obj.square_selection) 491 | { 492 | let max_val = Math.min(Math.max(width,height),window.overlay.width,window.overlay.height); 493 | width=max_val; 494 | height=max_val; 495 | } 496 | object.set({ width: width, height: height, left:l,top:t}) 497 | window.w2ui.toolbar.selection_box={width: width, height: height, x:object.left, y:object.top}; 498 | window.w2ui.toolbar.setCount("resize_selection",`${Math.floor(width/8)*8}x${Math.floor(height/8)*8}`); 499 | window.w2ui.toolbar.refresh(); 500 | } 501 | } 502 | function onObjectMoved(e) 503 | { 504 | let object = e.target; 505 | if(object.isType("rect")) 506 | { 507 | let l=Math.max(Math.min(object.left,window.overlay.width-object.width-object.strokeWidth),0); 508 | let t=Math.max(Math.min(object.top,window.overlay.height-object.height-object.strokeWidth),0); 509 | object.set({left:l,top:t}); 510 | window.w2ui.toolbar.selection_box={width: object.width, height: object.height, x:object.left, y:object.top}; 511 | } 512 | } 513 | window.setup_overlay=function(width,height) 514 | { 515 | if(window.overlay) 516 | { 517 | window.overlay.setDimensions({width:width,height:height}); 518 | let app=parent.document.querySelector("gradio-app"); 519 | app=app.shadowRoot??app; 520 | 
app.querySelector("#sdinfframe").style.height=80+Number(height)+"px"; 521 | document.querySelector("#container").style.height= height+"px"; 522 | document.querySelector("#container").style.width = width+"px"; 523 | } 524 | else 525 | { 526 | canvas=new fabric.Canvas("overlay_canvas"); 527 | canvas.setDimensions({width:width,height:height}); 528 | let app=parent.document.querySelector("gradio-app"); 529 | app=app.shadowRoot??app; 530 | app.querySelector("#sdinfframe").style.height=80+Number(height)+"px"; 531 | canvas.freeDrawingBrush = new fabric.EraserBrush(canvas); 532 | canvas.on("object:scaling", onObjectScaled); 533 | canvas.on("object:moving", onObjectMoved); 534 | window.overlay=canvas; 535 | } 536 | document.querySelector("#overlay_container").style.pointerEvents="none"; 537 | } 538 | window.update_overlay=function(width,height) 539 | { 540 | window.overlay.setDimensions({width:width,height:height},{backstoreOnly:true}); 541 | // document.querySelector("#overlay_container").style.pointerEvents="none"; 542 | } 543 | window.adjust_selection=function(x,y,width,height) 544 | { 545 | var rect = new fabric.Rect({ 546 | left: x, 547 | top: y, 548 | fill: "rgba(0,0,0,0)", 549 | strokeWidth: 3, 550 | stroke: "rgba(0,0,0,0.7)", 551 | cornerColor: "red", 552 | cornerStrokeColor: "red", 553 | borderColor: "rgba(255, 0, 0, 1.0)", 554 | width: width, 555 | height: height, 556 | lockRotation: true, 557 | }); 558 | rect.setControlsVisibility({ mtr: false }); 559 | window.overlay.add(rect); 560 | window.overlay.setActiveObject(window.overlay.item(0)); 561 | window.w2ui.toolbar.selection_box={width: width, height: height, x:x, y:y}; 562 | window.w2ui.toolbar.selection_box_bak={width: width, height: height, x:x, y:y}; 563 | } 564 | function add_image(url) 565 | { 566 | fabric.Image.fromURL(url,function(img){ 567 | window.overlay.add(img); 568 | window.overlay.setActiveObject(img); 569 | },{left:100,top:100}); 570 | } 571 | function export_image() 572 | { 573 | data=window.overlay.toDataURL(); 574 | document.querySelector("#upload_content").value=data; 575 | if(window.w2ui.toolbar.config_obj.enable_history) 576 | { 577 | window.postMessage(["upload","",""],"*"); 578 | window.w2ui.toolbar.enable("undo"); 579 | window.w2ui.toolbar.disable("redo"); 580 | } 581 | else 582 | { 583 | window.postMessage(["upload",""],"*"); 584 | } 585 | end_overlay(); 586 | } 587 | function end_overlay() 588 | { 589 | window.overlay.clear(); 590 | document.querySelector("#overlay_container").style.opacity = 1.0; 591 | document.querySelector("#overlay_container").style.pointerEvents="none"; 592 | } 593 | function ask_filename(target) 594 | { 595 | w2prompt({ 596 | label: "Enter filename", 597 | value: `outpaint_${((new Date(Date.now() -(new Date()).getTimezoneOffset() * 60000))).toISOString().replace("T","_").replace(/[^0-9_]/g, "").substring(0,15)}`, 598 | }) 599 | .change((event) => { 600 | console.log("change", event.detail.originalEvent.target.value); 601 | }) 602 | .ok((event) => { 603 | console.log("value=", event.detail.value); 604 | window.postMessage(["click",target,event.detail.value],"*"); 605 | }) 606 | .cancel((event) => { 607 | console.log("cancel"); 608 | }); 609 | } 610 | 611 | document.querySelector("#container").addEventListener("wheel",(e)=>{e.preventDefault()}) 612 | window.setup_shortcut=function(json) 613 | { 614 | var config=JSON.parse(json); 615 | var key_map={}; 616 | Object.keys(config.shortcut).forEach(k=>{ 617 | key_map[config.shortcut[k]]=k; 618 | }) 619 | document.addEventListener("keydown",(e)=>{ 
620 | if(e.target.tagName!="INPUT") 621 | { 622 | let key=e.key; 623 | if(e.ctrlKey) 624 | { 625 | key="Ctrl+"+e.key; 626 | if(key in key_map) 627 | { 628 | e.preventDefault(); 629 | } 630 | } 631 | if(key in key_map) 632 | { 633 | w2ui.toolbar.click(key_map[key]); 634 | } 635 | } 636 | }) 637 | } 638 | window.undo_redo_state={undo:false,redo:false}; 639 | window.update_undo_redo=function(s0,s1) 640 | { 641 | if(s0) 642 | { 643 | w2ui.toolbar.enable("undo"); 644 | } 645 | else 646 | { 647 | w2ui.toolbar.disable("undo"); 648 | } 649 | if(s1) 650 | { 651 | w2ui.toolbar.enable("redo"); 652 | } 653 | else 654 | { 655 | w2ui.toolbar.disable("redo"); 656 | } 657 | window.undo_redo_state.undo=s0; 658 | window.undo_redo_state.redo=s1; 659 | } -------------------------------------------------------------------------------- /js/upload.js: -------------------------------------------------------------------------------- 1 | function(a,b){ 2 | if(!window.my_observe_upload) 3 | { 4 | console.log("setup upload here"); 5 | window.my_observe_upload = new MutationObserver(function (event) { 6 | console.log(event); 7 | var frame=document.querySelector("gradio-app").shadowRoot.querySelector("#sdinfframe").contentWindow.document; 8 | frame.querySelector("#upload").click(); 9 | }); 10 | window.my_observe_upload_target = document.querySelector("gradio-app").shadowRoot.querySelector("#upload span"); 11 | window.my_observe_upload.observe(window.my_observe_upload_target, { 12 | attributes: false, 13 | subtree: true, 14 | childList: true, 15 | characterData: true 16 | }); 17 | } 18 | return [a,b]; 19 | } -------------------------------------------------------------------------------- /js/xss.js: -------------------------------------------------------------------------------- 1 | var setup_outpaint=function(){ 2 | if(!window.my_observe_outpaint) 3 | { 4 | console.log("setup outpaint here"); 5 | window.my_observe_outpaint = new MutationObserver(function (event) { 6 | console.log(event); 7 | let app=document.querySelector("gradio-app"); 8 | app=app.shadowRoot??app; 9 | let frame=app.querySelector("#sdinfframe").contentWindow; 10 | frame.postMessage(["outpaint", ""], "*"); 11 | }); 12 | var app=document.querySelector("gradio-app"); 13 | app=app.shadowRoot??app; 14 | window.my_observe_outpaint_target=app.querySelector("#output span"); 15 | window.my_observe_outpaint.observe(window.my_observe_outpaint_target, { 16 | attributes: false, 17 | subtree: true, 18 | childList: true, 19 | characterData: true 20 | }); 21 | } 22 | }; 23 | window.config_obj={ 24 | resize_check: true, 25 | enable_safety: true, 26 | use_correction: false, 27 | enable_img2img: false, 28 | use_seed: false, 29 | seed_val: 0, 30 | interrogate_mode: false, 31 | }; 32 | setup_outpaint(); -------------------------------------------------------------------------------- /models/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: 
ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. ] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /models/v1-inpainting-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 7.5e-05 3 | target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid # important 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | finetune_keys: null 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 9 # 4 data + 4 downscaled image + 1 mask 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /perlin2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | ########## 4 | # https://stackoverflow.com/questions/42147776/producing-2d-perlin-noise-with-numpy/42154921#42154921 5 | def perlin(x, y, seed=0): 6 | # permutation table 7 | np.random.seed(seed) 8 | p = np.arange(256, dtype=int) 9 | np.random.shuffle(p) 10 | p = np.stack([p, p]).flatten() 11 | # coordinates of the top-left 12 | xi, yi = x.astype(int), y.astype(int) 13 | # internal coordinates 14 | xf, yf = x - xi, y - yi 15 | # fade factors 16 | u, v = fade(xf), fade(yf) 17 | # noise components 18 | n00 = gradient(p[p[xi] + yi], xf, yf) 19 | n01 = gradient(p[p[xi] + yi + 1], xf, yf - 1) 20 | n11 = gradient(p[p[xi + 1] + yi + 1], xf - 1, yf - 1) 21 | n10 = gradient(p[p[xi + 1] + yi], xf - 1, yf) 22 | # combine noises 23 | x1 = lerp(n00, n10, u) 24 | x2 = lerp(n01, n11, u) # FIX1: I was using n10 instead of n01 25 | return lerp(x1, x2, v) # FIX2: I also had to reverse x1 and x2 here 26 | 27 | 28 | def lerp(a, b, x): 29 | "linear interpolation" 30 | return a + x * (b - a) 31 | 32 | 33 | def fade(t): 34 | "6t^5 - 15t^4 + 10t^3" 35 | return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3 36 | 37 | 38 | def gradient(h, x, y): 39 | "grad converts h to the right gradient vector and return the dot product with (x,y)" 40 | vectors = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]]) 41 | g = vectors[h % 4] 42 | return g[:, :, 0] * x + g[:, :, 1] * y 43 | 44 | 45 | ########## -------------------------------------------------------------------------------- /postprocess.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/Trinkle23897/Fast-Poisson-Image-Editing 3 | MIT License 4 | 5 | Copyright (c) 2022 Jiayi Weng 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 
16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | """ 25 | 26 | import time 27 | import argparse 28 | import os 29 | import fpie 30 | from process import ALL_BACKEND, CPU_COUNT, DEFAULT_BACKEND 31 | from fpie.io import read_images, write_image 32 | from process import BaseProcessor, EquProcessor, GridProcessor 33 | 34 | from PIL import Image 35 | import numpy as np 36 | import skimage 37 | import skimage.measure 38 | import scipy 39 | import scipy.signal 40 | 41 | 42 | class PhotometricCorrection: 43 | def __init__(self,quite=False): 44 | self.get_parser("cli") 45 | args=self.parser.parse_args(["--method","grid","-g","src","-s","a","-t","a","-o","a"]) 46 | args.mpi_sync_interval = getattr(args, "mpi_sync_interval", 0) 47 | self.backend=args.backend 48 | self.args=args 49 | self.quite=quite 50 | proc: BaseProcessor 51 | proc = GridProcessor( 52 | args.gradient, 53 | args.backend, 54 | args.cpu, 55 | args.mpi_sync_interval, 56 | args.block_size, 57 | args.grid_x, 58 | args.grid_y, 59 | ) 60 | print( 61 | f"[PIE]Successfully initialize PIE {args.method} solver " 62 | f"with {args.backend} backend" 63 | ) 64 | self.proc=proc 65 | 66 | def run(self, original_image, inpainted_image, mode="mask_mode"): 67 | print(f"[PIE] start") 68 | if mode=="disabled": 69 | return inpainted_image 70 | input_arr=np.array(original_image) 71 | if input_arr[:,:,-1].sum()<1: 72 | return inpainted_image 73 | output_arr=np.array(inpainted_image) 74 | mask=input_arr[:,:,-1] 75 | mask=255-mask 76 | if mask.sum()<1 and mode=="mask_mode": 77 | mode="" 78 | if mode=="mask_mode": 79 | mask = skimage.measure.block_reduce(mask, (8, 8), np.max) 80 | mask = mask.repeat(8, axis=0).repeat(8, axis=1) 81 | else: 82 | mask[8:-9,8:-9]=255 83 | mask = mask[:,:,np.newaxis].repeat(3,axis=2) 84 | nmask=mask.copy() 85 | output_arr2=output_arr[:,:,0:3].copy() 86 | input_arr2=input_arr[:,:,0:3].copy() 87 | output_arr2[nmask<128]=0 88 | input_arr2[nmask>=128]=0 89 | output_arr2+=input_arr2 90 | src = output_arr2[:,:,0:3] 91 | tgt = src.copy() 92 | proc=self.proc 93 | args=self.args 94 | if proc.root: 95 | n = proc.reset(src, mask, tgt, (args.h0, args.w0), (args.h1, args.w1)) 96 | proc.sync() 97 | if proc.root: 98 | result = tgt 99 | t = time.time() 100 | if args.p == 0: 101 | args.p = args.n 102 | 103 | for i in range(0, args.n, args.p): 104 | if proc.root: 105 | result, err = proc.step(args.p) # type: ignore 106 | print(f"[PIE] Iter {i + args.p}, abs_err {err}") 107 | else: 108 | proc.step(args.p) 109 | 110 | if proc.root: 111 | dt = time.time() - t 112 | print(f"[PIE] Time elapsed: {dt:.4f}s") 113 | # make sure consistent with dummy process 114 | return Image.fromarray(result) 115 | 116 | 117 | def get_parser(self,gen_type: str) -> argparse.Namespace: 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument( 120 | "-v", "--version", action="store_true", help="show the version and exit" 121 | ) 122 | parser.add_argument( 123 | "--check-backend", action="store_true", help="print all available backends" 124 | ) 125 | if gen_type == "gui" and "mpi" in ALL_BACKEND: 126 | # gui doesn't 
support MPI backend 127 | ALL_BACKEND.remove("mpi") 128 | parser.add_argument( 129 | "-b", 130 | "--backend", 131 | type=str, 132 | choices=ALL_BACKEND, 133 | default=DEFAULT_BACKEND, 134 | help="backend choice", 135 | ) 136 | parser.add_argument( 137 | "-c", 138 | "--cpu", 139 | type=int, 140 | default=CPU_COUNT, 141 | help="number of CPU used", 142 | ) 143 | parser.add_argument( 144 | "-z", 145 | "--block-size", 146 | type=int, 147 | default=1024, 148 | help="cuda block size (only for equ solver)", 149 | ) 150 | parser.add_argument( 151 | "--method", 152 | type=str, 153 | choices=["equ", "grid"], 154 | default="equ", 155 | help="how to parallelize computation", 156 | ) 157 | parser.add_argument("-s", "--source", type=str, help="source image filename") 158 | if gen_type == "cli": 159 | parser.add_argument( 160 | "-m", 161 | "--mask", 162 | type=str, 163 | help="mask image filename (default is to use the whole source image)", 164 | default="", 165 | ) 166 | parser.add_argument("-t", "--target", type=str, help="target image filename") 167 | parser.add_argument("-o", "--output", type=str, help="output image filename") 168 | if gen_type == "cli": 169 | parser.add_argument( 170 | "-h0", type=int, help="mask position (height) on source image", default=0 171 | ) 172 | parser.add_argument( 173 | "-w0", type=int, help="mask position (width) on source image", default=0 174 | ) 175 | parser.add_argument( 176 | "-h1", type=int, help="mask position (height) on target image", default=0 177 | ) 178 | parser.add_argument( 179 | "-w1", type=int, help="mask position (width) on target image", default=0 180 | ) 181 | parser.add_argument( 182 | "-g", 183 | "--gradient", 184 | type=str, 185 | choices=["max", "src", "avg"], 186 | default="max", 187 | help="how to calculate gradient for PIE", 188 | ) 189 | parser.add_argument( 190 | "-n", 191 | type=int, 192 | help="how many iteration would you perfer, the more the better", 193 | default=5000, 194 | ) 195 | if gen_type == "cli": 196 | parser.add_argument( 197 | "-p", type=int, help="output result every P iteration", default=0 198 | ) 199 | if "mpi" in ALL_BACKEND: 200 | parser.add_argument( 201 | "--mpi-sync-interval", 202 | type=int, 203 | help="MPI sync iteration interval", 204 | default=100, 205 | ) 206 | parser.add_argument( 207 | "--grid-x", type=int, help="x axis stride for grid solver", default=8 208 | ) 209 | parser.add_argument( 210 | "--grid-y", type=int, help="y axis stride for grid solver", default=8 211 | ) 212 | self.parser=parser 213 | 214 | if __name__ =="__main__": 215 | import sys 216 | import io 217 | import base64 218 | from PIL import Image 219 | def base64_to_pil(base64_str): 220 | data = base64.b64decode(str(base64_str)) 221 | pil = Image.open(io.BytesIO(data)) 222 | return pil 223 | 224 | def pil_to_base64(out_pil): 225 | out_buffer = io.BytesIO() 226 | out_pil.save(out_buffer, format="PNG") 227 | out_buffer.seek(0) 228 | base64_bytes = base64.b64encode(out_buffer.read()) 229 | base64_str = base64_bytes.decode("ascii") 230 | return base64_str 231 | correction_func=PhotometricCorrection(quite=True) 232 | while True: 233 | buffer = sys.stdin.readline() 234 | print(f"[PIE] suprocess {len(buffer)} {type(buffer)} ") 235 | if len(buffer)==0: 236 | break 237 | if isinstance(buffer,str): 238 | lst=buffer.strip().split(",") 239 | else: 240 | lst=buffer.decode("ascii").strip().split(",") 241 | img0=base64_to_pil(lst[0]) 242 | img1=base64_to_pil(lst[1]) 243 | ret=correction_func.run(img0,img1,mode=lst[2]) 244 | ret_base64=pil_to_base64(ret) 245 | if 
isinstance(buffer,str): 246 | sys.stdout.write(f"{ret_base64}\n") 247 | else: 248 | sys.stdout.write(f"{ret_base64}\n".encode()) 249 | sys.stdout.flush() -------------------------------------------------------------------------------- /process.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/Trinkle23897/Fast-Poisson-Image-Editing 3 | MIT License 4 | 5 | Copyright (c) 2022 Jiayi Weng 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | """ 25 | import os 26 | from abc import ABC, abstractmethod 27 | from typing import Any, Optional, Tuple 28 | 29 | import numpy as np 30 | 31 | from fpie import np_solver 32 | 33 | import scipy 34 | import scipy.signal 35 | 36 | CPU_COUNT = os.cpu_count() or 1 37 | DEFAULT_BACKEND = "numpy" 38 | ALL_BACKEND = ["numpy"] 39 | 40 | try: 41 | from fpie import numba_solver 42 | ALL_BACKEND += ["numba"] 43 | DEFAULT_BACKEND = "numba" 44 | except ImportError: 45 | numba_solver = None # type: ignore 46 | 47 | try: 48 | from fpie import taichi_solver 49 | ALL_BACKEND += ["taichi-cpu", "taichi-gpu"] 50 | DEFAULT_BACKEND = "taichi-cpu" 51 | except ImportError: 52 | taichi_solver = None # type: ignore 53 | 54 | # try: 55 | # from fpie import core_gcc # type: ignore 56 | # DEFAULT_BACKEND = "gcc" 57 | # ALL_BACKEND.append("gcc") 58 | # except ImportError: 59 | # core_gcc = None 60 | 61 | # try: 62 | # from fpie import core_openmp # type: ignore 63 | # DEFAULT_BACKEND = "openmp" 64 | # ALL_BACKEND.append("openmp") 65 | # except ImportError: 66 | # core_openmp = None 67 | 68 | # try: 69 | # from mpi4py import MPI 70 | 71 | # from fpie import core_mpi # type: ignore 72 | # ALL_BACKEND.append("mpi") 73 | # except ImportError: 74 | # MPI = None # type: ignore 75 | # core_mpi = None 76 | 77 | try: 78 | from fpie import core_cuda # type: ignore 79 | DEFAULT_BACKEND = "cuda" 80 | ALL_BACKEND.append("cuda") 81 | except ImportError: 82 | core_cuda = None 83 | 84 | 85 | class BaseProcessor(ABC): 86 | """API definition for processor class.""" 87 | 88 | def __init__( 89 | self, gradient: str, rank: int, backend: str, core: Optional[Any] 90 | ): 91 | if core is None: 92 | error_msg = { 93 | "numpy": 94 | "Please run `pip install numpy`.", 95 | "numba": 96 | "Please run `pip install numba`.", 97 | "gcc": 98 | "Please install cmake and gcc in your operating system.", 99 | "openmp": 100 | "Please make sure your gcc is compatible with `-fopenmp` option.", 101 | "mpi": 102 | "Please 
install MPI and run `pip install mpi4py`.", 103 | "cuda": 104 | "Please make sure nvcc and cuda-related libraries are available.", 105 | "taichi": 106 | "Please run `pip install taichi`.", 107 | } 108 | print(error_msg[backend.split("-")[0]]) 109 | 110 | raise AssertionError(f"Invalid backend {backend}.") 111 | 112 | self.gradient = gradient 113 | self.rank = rank 114 | self.backend = backend 115 | self.core = core 116 | self.root = rank == 0 117 | 118 | def mixgrad(self, a: np.ndarray, b: np.ndarray) -> np.ndarray: 119 | if self.gradient == "src": 120 | return a 121 | if self.gradient == "avg": 122 | return (a + b) / 2 123 | # mix gradient, see Equ. 12 in PIE paper 124 | mask = np.abs(a) < np.abs(b) 125 | a[mask] = b[mask] 126 | return a 127 | 128 | @abstractmethod 129 | def reset( 130 | self, 131 | src: np.ndarray, 132 | mask: np.ndarray, 133 | tgt: np.ndarray, 134 | mask_on_src: Tuple[int, int], 135 | mask_on_tgt: Tuple[int, int], 136 | ) -> int: 137 | pass 138 | 139 | def sync(self) -> None: 140 | self.core.sync() 141 | 142 | @abstractmethod 143 | def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]: 144 | pass 145 | 146 | 147 | class EquProcessor(BaseProcessor): 148 | """PIE Jacobi equation processor.""" 149 | 150 | def __init__( 151 | self, 152 | gradient: str = "max", 153 | backend: str = DEFAULT_BACKEND, 154 | n_cpu: int = CPU_COUNT, 155 | min_interval: int = 100, 156 | block_size: int = 1024, 157 | ): 158 | core: Optional[Any] = None 159 | rank = 0 160 | 161 | if backend == "numpy": 162 | core = np_solver.EquSolver() 163 | elif backend == "numba" and numba_solver is not None: 164 | core = numba_solver.EquSolver() 165 | elif backend == "gcc": 166 | core = core_gcc.EquSolver() 167 | elif backend == "openmp" and core_openmp is not None: 168 | core = core_openmp.EquSolver(n_cpu) 169 | elif backend == "mpi" and core_mpi is not None: 170 | core = core_mpi.EquSolver(min_interval) 171 | rank = MPI.COMM_WORLD.Get_rank() 172 | elif backend == "cuda" and core_cuda is not None: 173 | core = core_cuda.EquSolver(block_size) 174 | elif backend.startswith("taichi") and taichi_solver is not None: 175 | core = taichi_solver.EquSolver(backend, n_cpu, block_size) 176 | 177 | super().__init__(gradient, rank, backend, core) 178 | 179 | def mask2index( 180 | self, mask: np.ndarray 181 | ) -> Tuple[np.ndarray, int, np.ndarray, np.ndarray]: 182 | x, y = np.nonzero(mask) 183 | max_id = x.shape[0] + 1 184 | index = np.zeros((max_id, 3)) 185 | ids = self.core.partition(mask) 186 | ids[mask == 0] = 0 # reserve id=0 for constant 187 | index = ids[x, y].argsort() 188 | return ids, max_id, x[index], y[index] 189 | 190 | def reset( 191 | self, 192 | src: np.ndarray, 193 | mask: np.ndarray, 194 | tgt: np.ndarray, 195 | mask_on_src: Tuple[int, int], 196 | mask_on_tgt: Tuple[int, int], 197 | ) -> int: 198 | assert self.root 199 | # check validity 200 | # assert 0 <= mask_on_src[0] and 0 <= mask_on_src[1] 201 | # assert mask_on_src[0] + mask.shape[0] <= src.shape[0] 202 | # assert mask_on_src[1] + mask.shape[1] <= src.shape[1] 203 | # assert mask_on_tgt[0] + mask.shape[0] <= tgt.shape[0] 204 | # assert mask_on_tgt[1] + mask.shape[1] <= tgt.shape[1] 205 | 206 | if len(mask.shape) == 3: 207 | mask = mask.mean(-1) 208 | mask = (mask >= 128).astype(np.int32) 209 | 210 | # zero-out edge 211 | mask[0] = 0 212 | mask[-1] = 0 213 | mask[:, 0] = 0 214 | mask[:, -1] = 0 215 | 216 | x, y = np.nonzero(mask) 217 | x0, x1 = x.min() - 1, x.max() + 2 218 | y0, y1 = y.min() - 1, y.max() + 2 219 | mask_on_src = 
(x0 + mask_on_src[0], y0 + mask_on_src[1]) 220 | mask_on_tgt = (x0 + mask_on_tgt[0], y0 + mask_on_tgt[1]) 221 | mask = mask[x0:x1, y0:y1] 222 | ids, max_id, index_x, index_y = self.mask2index(mask) 223 | 224 | src_x, src_y = index_x + mask_on_src[0], index_y + mask_on_src[1] 225 | tgt_x, tgt_y = index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] 226 | 227 | src_C = src[src_x, src_y].astype(np.float32) 228 | src_U = src[src_x - 1, src_y].astype(np.float32) 229 | src_D = src[src_x + 1, src_y].astype(np.float32) 230 | src_L = src[src_x, src_y - 1].astype(np.float32) 231 | src_R = src[src_x, src_y + 1].astype(np.float32) 232 | tgt_C = tgt[tgt_x, tgt_y].astype(np.float32) 233 | tgt_U = tgt[tgt_x - 1, tgt_y].astype(np.float32) 234 | tgt_D = tgt[tgt_x + 1, tgt_y].astype(np.float32) 235 | tgt_L = tgt[tgt_x, tgt_y - 1].astype(np.float32) 236 | tgt_R = tgt[tgt_x, tgt_y + 1].astype(np.float32) 237 | 238 | grad = self.mixgrad(src_C - src_L, tgt_C - tgt_L) \ 239 | + self.mixgrad(src_C - src_R, tgt_C - tgt_R) \ 240 | + self.mixgrad(src_C - src_U, tgt_C - tgt_U) \ 241 | + self.mixgrad(src_C - src_D, tgt_C - tgt_D) 242 | 243 | A = np.zeros((max_id, 4), np.int32) 244 | X = np.zeros((max_id, 3), np.float32) 245 | B = np.zeros((max_id, 3), np.float32) 246 | 247 | X[1:] = tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1]] 248 | # four-way 249 | A[1:, 0] = ids[index_x - 1, index_y] 250 | A[1:, 1] = ids[index_x + 1, index_y] 251 | A[1:, 2] = ids[index_x, index_y - 1] 252 | A[1:, 3] = ids[index_x, index_y + 1] 253 | B[1:] = grad 254 | m = (mask[index_x - 1, index_y] == 0).astype(float).reshape(-1, 1) 255 | B[1:] += m * tgt[index_x + mask_on_tgt[0] - 1, index_y + mask_on_tgt[1]] 256 | m = (mask[index_x, index_y - 1] == 0).astype(float).reshape(-1, 1) 257 | B[1:] += m * tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] - 1] 258 | m = (mask[index_x, index_y + 1] == 0).astype(float).reshape(-1, 1) 259 | B[1:] += m * tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] + 1] 260 | m = (mask[index_x + 1, index_y] == 0).astype(float).reshape(-1, 1) 261 | B[1:] += m * tgt[index_x + mask_on_tgt[0] + 1, index_y + mask_on_tgt[1]] 262 | 263 | self.tgt = tgt.copy() 264 | self.tgt_index = (index_x + mask_on_tgt[0], index_y + mask_on_tgt[1]) 265 | self.core.reset(max_id, A, X, B) 266 | return max_id 267 | 268 | def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]: 269 | result = self.core.step(iteration) 270 | if self.root: 271 | x, err = result 272 | self.tgt[self.tgt_index] = x[1:] 273 | return self.tgt, err 274 | return None 275 | 276 | 277 | class GridProcessor(BaseProcessor): 278 | """PIE grid processor.""" 279 | 280 | def __init__( 281 | self, 282 | gradient: str = "max", 283 | backend: str = DEFAULT_BACKEND, 284 | n_cpu: int = CPU_COUNT, 285 | min_interval: int = 100, 286 | block_size: int = 1024, 287 | grid_x: int = 8, 288 | grid_y: int = 8, 289 | ): 290 | core: Optional[Any] = None 291 | rank = 0 292 | 293 | if backend == "numpy": 294 | core = np_solver.GridSolver() 295 | elif backend == "numba" and numba_solver is not None: 296 | core = numba_solver.GridSolver() 297 | elif backend == "gcc": 298 | core = core_gcc.GridSolver(grid_x, grid_y) 299 | elif backend == "openmp" and core_openmp is not None: 300 | core = core_openmp.GridSolver(grid_x, grid_y, n_cpu) 301 | elif backend == "mpi" and core_mpi is not None: 302 | core = core_mpi.GridSolver(min_interval) 303 | rank = MPI.COMM_WORLD.Get_rank() 304 | elif backend == "cuda" and core_cuda is not None: 305 | core = 
core_cuda.GridSolver(grid_x, grid_y) 306 | elif backend.startswith("taichi") and taichi_solver is not None: 307 | core = taichi_solver.GridSolver( 308 | grid_x, grid_y, backend, n_cpu, block_size 309 | ) 310 | 311 | super().__init__(gradient, rank, backend, core) 312 | 313 | def reset( 314 | self, 315 | src: np.ndarray, 316 | mask: np.ndarray, 317 | tgt: np.ndarray, 318 | mask_on_src: Tuple[int, int], 319 | mask_on_tgt: Tuple[int, int], 320 | ) -> int: 321 | assert self.root 322 | # check validity 323 | # assert 0 <= mask_on_src[0] and 0 <= mask_on_src[1] 324 | # assert mask_on_src[0] + mask.shape[0] <= src.shape[0] 325 | # assert mask_on_src[1] + mask.shape[1] <= src.shape[1] 326 | # assert mask_on_tgt[0] + mask.shape[0] <= tgt.shape[0] 327 | # assert mask_on_tgt[1] + mask.shape[1] <= tgt.shape[1] 328 | 329 | if len(mask.shape) == 3: 330 | mask = mask.mean(-1) 331 | mask = (mask >= 128).astype(np.int32) 332 | 333 | # zero-out edge 334 | mask[0] = 0 335 | mask[-1] = 0 336 | mask[:, 0] = 0 337 | mask[:, -1] = 0 338 | 339 | x, y = np.nonzero(mask) 340 | x0, x1 = x.min() - 1, x.max() + 2 341 | y0, y1 = y.min() - 1, y.max() + 2 342 | mask = mask[x0:x1, y0:y1] 343 | max_id = np.prod(mask.shape) 344 | 345 | src_crop = src[mask_on_src[0] + x0:mask_on_src[0] + x1, 346 | mask_on_src[1] + y0:mask_on_src[1] + y1].astype(np.float32) 347 | tgt_crop = tgt[mask_on_tgt[0] + x0:mask_on_tgt[0] + x1, 348 | mask_on_tgt[1] + y0:mask_on_tgt[1] + y1].astype(np.float32) 349 | grad = np.zeros([*mask.shape, 3], np.float32) 350 | grad[1:] += self.mixgrad( 351 | src_crop[1:] - src_crop[:-1], tgt_crop[1:] - tgt_crop[:-1] 352 | ) 353 | grad[:-1] += self.mixgrad( 354 | src_crop[:-1] - src_crop[1:], tgt_crop[:-1] - tgt_crop[1:] 355 | ) 356 | grad[:, 1:] += self.mixgrad( 357 | src_crop[:, 1:] - src_crop[:, :-1], tgt_crop[:, 1:] - tgt_crop[:, :-1] 358 | ) 359 | grad[:, :-1] += self.mixgrad( 360 | src_crop[:, :-1] - src_crop[:, 1:], tgt_crop[:, :-1] - tgt_crop[:, 1:] 361 | ) 362 | 363 | grad[mask == 0] = 0 364 | if True: 365 | kernel = [[1] * 3 for _ in range(3)] 366 | nmask = mask.copy() 367 | nmask[nmask > 0] = 1 368 | res = scipy.signal.convolve2d( 369 | nmask, kernel, mode="same", boundary="fill", fillvalue=1 370 | ) 371 | res[nmask < 1] = 0 372 | res[res == 9] = 0 373 | res[res > 0] = 1 374 | grad[res>0]=0 375 | # ylst, xlst = res.nonzero() 376 | # for y, x in zip(ylst, xlst): 377 | # grad[y,x]=0 378 | # for yi in range(-1,2): 379 | # for xi in range(-1,2): 380 | # grad[y+yi,x+xi]=0 381 | self.x0 = mask_on_tgt[0] + x0 382 | self.x1 = mask_on_tgt[0] + x1 383 | self.y0 = mask_on_tgt[1] + y0 384 | self.y1 = mask_on_tgt[1] + y1 385 | self.tgt = tgt.copy() 386 | self.core.reset(max_id, mask, tgt_crop, grad) 387 | return max_id 388 | 389 | def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]: 390 | result = self.core.step(iteration) 391 | if self.root: 392 | tgt, err = result 393 | self.tgt[self.x0:self.x1, self.y0:self.y1] = tgt 394 | return self.tgt, err 395 | return None 396 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # stablediffusion-infinity 2 | 3 | Outpainting with Stable Diffusion on an infinite canvas. 
4 | 5 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lkwq007/stablediffusion-infinity/blob/master/stablediffusion_infinity_colab.ipynb) 6 | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/lnyan/stablediffusion-infinity) 7 | [![Setup Locally](https://img.shields.io/badge/%F0%9F%96%A5%EF%B8%8F%20Setup-Locally-blue)](https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/setup_guide.md) 8 | 9 | ![outpaint](https://user-images.githubusercontent.com/1665437/197257616-82c1e58f-7463-4896-8345-6750a828c844.png) 10 | 11 | https://user-images.githubusercontent.com/1665437/197244111-51884b3b-dffe-4dcf-a82a-fa5117c79934.mp4 12 | 13 | ## Status 14 | 15 | Powered by the Stable Diffusion inpainting model, this project now works well. However, the quality of results is still not guaranteed. 16 | You may need to do prompt engineering, change the size of the selection, or reduce the size of the outpainting region to get better outpainting results. 17 | 18 | The project is now a web app based on PyScript and Gradio. For the Jupyter Notebook version, please check out the [ipycanvas](https://github.com/lkwq007/stablediffusion-infinity/tree/ipycanvas) branch. 19 | 20 | Pull requests are welcome for better UI control, ideas to achieve better results, or any other improvements. 21 | 22 | Update: the project added photometric correction to suppress seams; to use this feature, you need to install [fpie](https://github.com/Trinkle23897/Fast-Poisson-Image-Editing): `pip install fpie` (Linux/MacOS only) 23 | 24 | ## Docs 25 | 26 | ### Get Started 27 | 28 | - Setup for Windows: [setup_guide](./docs/setup_guide.md#windows) 29 | - Setup for Linux: [setup_guide](./docs/setup_guide.md#linux) 30 | - Setup for MacOS: [setup_guide](./docs/setup_guide.md#macos) 31 | - Running with Docker on Windows or Linux with NVIDIA GPU: [run_with_docker](./docs/run_with_docker.md) 32 | - Usage: [usage](./docs/usage.md) 33 | 34 | ### FAQs 35 | 36 | - The result is a black square: 37 | - The false-positive rate of the safety checker is relatively high; you may disable the safety_checker 38 | - Some GPUs might not work with `fp16`: `python app.py --fp32 --lowvram` 39 | - What is the init_mode? 40 | - init_mode indicates how to fill the empty/masked region; usually `patch_match` gives better results than the other options 41 | - Why not use `postMessage` for iframe interaction? 42 | - The iframe and the gradio app share the same origin. For a `postMessage` version, check out the [gradio-space](https://github.com/lkwq007/stablediffusion-infinity/tree/gradio-space) version 43 | 44 | ### Known issues 45 | 46 | - The canvas is implemented with `NumPy` + `PyScript` (the project was originally implemented with `ipycanvas` inside a Jupyter notebook), which is relatively inefficient compared with pure frontend solutions. 47 | - By design, the canvas is infinite. However, the canvas size is **finite** in practice: your RAM and browser limit the canvas size. The canvas might crash or behave strangely when zoomed out beyond a certain scale. 48 | - The canvas requires internet access: you can deploy and serve PyScript, Pyodide, and other JS/CSS assets with a local HTTP server and modify `index.html` accordingly. 49 | - Photometric correction might not work (`taichi` does not support the multithreading environment). A dirty hack (quite unreliable) is implemented to move the related computation inside a subprocess. 
50 | - Stable Diffusion inpainting model is much slower when selection size is larger than 512x512 51 | 52 | ## Credit 53 | 54 | The code of `perlin2d.py` is from https://stackoverflow.com/questions/42147776/producing-2d-perlin-noise-with-numpy/42154921#42154921 and is **not** included in the scope of LICENSE used in this repo. 55 | 56 | The submodule `glid_3_xl_stable` is based on https://github.com/Jack000/glid-3-xl-stable 57 | 58 | The submodule `PyPatchMatch` is based on https://github.com/vacancy/PyPatchMatch 59 | 60 | The code of `postprocess.py` and `process.py` is modified based on https://github.com/Trinkle23897/Fast-Poisson-Image-Editing 61 | 62 | The code of `convert_checkpoint.py` is modified based on https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py 63 | 64 | The submodule `sd_grpcserver` and `handleImageAdjustment()` in `utils.py` are based on https://github.com/hafriedlander/stable-diffusion-grpcserver and https://github.com/parlance-zz/g-diffuser-bot 65 | 66 | `w2ui.min.js` and `w2ui.min.css` is from https://github.com/vitmalina/w2ui. `fabric.min.js` is a custom build of https://github.com/fabricjs/fabric.js 67 | 68 | `interrogate.py` is based on https://github.com/pharmapsychotic/clip-interrogator v1, the submodule `blip_model` is based on https://github.com/salesforce/BLIP 69 | -------------------------------------------------------------------------------- /stablediffusion_infinity_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "collapsed_sections": [] 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | }, 13 | "language_info": { 14 | "name": "python" 15 | }, 16 | "accelerator": "GPU" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "source": [ 22 | "# stablediffusion-infinity\n", 23 | "\n", 24 | "https://github.com/lkwq007/stablediffusion-infinity\n", 25 | "\n", 26 | "Outpainting with Stable Diffusion on an infinite canvas" 27 | ], 28 | "metadata": { 29 | "id": "IgN1jqV_DemW" 30 | } 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": { 36 | "id": "JvbfNNSJDTW5" 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "#@title setup libs\n", 41 | "!nvidia-smi -L\n", 42 | "!pip install -qq -U diffusers==0.11.1 transformers ftfy accelerate\n", 43 | "!pip install -q gradio==3.11.0\n", 44 | "!pip install -q fpie timm\n", 45 | "!pip uninstall taichi -y" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "source": [ 51 | "#@title setup stablediffusion-infinity\n", 52 | "!git clone --recurse-submodules https://github.com/lkwq007/stablediffusion-infinity\n", 53 | "%cd stablediffusion-infinity\n", 54 | "!cp -r PyPatchMatch/csrc .\n", 55 | "!cp PyPatchMatch/Makefile .\n", 56 | "!cp PyPatchMatch/Makefile_fallback .\n", 57 | "!cp PyPatchMatch/travis.sh .\n", 58 | "!cp PyPatchMatch/patch_match.py . 
" 59 | ], 60 | "metadata": { 61 | "id": "D1BDhQCJDilE" 62 | }, 63 | "execution_count": null, 64 | "outputs": [] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "source": [ 69 | "#@title start stablediffusion-infinity (first setup may takes about two minutes for downloading models)\n", 70 | "!python app.py --share" 71 | ], 72 | "metadata": { 73 | "id": "UGotC5ckDlmO" 74 | }, 75 | "execution_count": null, 76 | "outputs": [] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "source": [], 81 | "metadata": { 82 | "id": "R1-E07CMFZoj" 83 | }, 84 | "execution_count": null, 85 | "outputs": [] 86 | } 87 | ] 88 | } 89 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from PIL import ImageFilter 3 | import cv2 4 | import numpy as np 5 | import scipy 6 | import scipy.signal 7 | from scipy.spatial import cKDTree 8 | 9 | import os 10 | from perlin2d import * 11 | 12 | patch_match_compiled = True 13 | 14 | try: 15 | from PyPatchMatch import patch_match 16 | except Exception as e: 17 | try: 18 | import patch_match 19 | except Exception as e: 20 | patch_match_compiled = False 21 | 22 | try: 23 | patch_match 24 | except NameError: 25 | print("patch_match compiling failed, will fall back to edge_pad") 26 | patch_match_compiled = False 27 | 28 | 29 | 30 | 31 | def edge_pad(img, mask, mode=1): 32 | if mode == 0: 33 | nmask = mask.copy() 34 | nmask[nmask > 0] = 1 35 | res0 = 1 - nmask 36 | res1 = nmask 37 | p0 = np.stack(res0.nonzero(), axis=0).transpose() 38 | p1 = np.stack(res1.nonzero(), axis=0).transpose() 39 | min_dists, min_dist_idx = cKDTree(p1).query(p0, 1) 40 | loc = p1[min_dist_idx] 41 | for (a, b), (c, d) in zip(p0, loc): 42 | img[a, b] = img[c, d] 43 | elif mode == 1: 44 | record = {} 45 | kernel = [[1] * 3 for _ in range(3)] 46 | nmask = mask.copy() 47 | nmask[nmask > 0] = 1 48 | res = scipy.signal.convolve2d( 49 | nmask, kernel, mode="same", boundary="fill", fillvalue=1 50 | ) 51 | res[nmask < 1] = 0 52 | res[res == 9] = 0 53 | res[res > 0] = 1 54 | ylst, xlst = res.nonzero() 55 | queue = [(y, x) for y, x in zip(ylst, xlst)] 56 | # bfs here 57 | cnt = res.astype(np.float32) 58 | acc = img.astype(np.float32) 59 | step = 1 60 | h = acc.shape[0] 61 | w = acc.shape[1] 62 | offset = [(1, 0), (-1, 0), (0, 1), (0, -1)] 63 | while queue: 64 | target = [] 65 | for y, x in queue: 66 | val = acc[y][x] 67 | for yo, xo in offset: 68 | yn = y + yo 69 | xn = x + xo 70 | if 0 <= yn < h and 0 <= xn < w and nmask[yn][xn] < 1: 71 | if record.get((yn, xn), step) == step: 72 | acc[yn][xn] = acc[yn][xn] * cnt[yn][xn] + val 73 | cnt[yn][xn] += 1 74 | acc[yn][xn] /= cnt[yn][xn] 75 | if (yn, xn) not in record: 76 | record[(yn, xn)] = step 77 | target.append((yn, xn)) 78 | step += 1 79 | queue = target 80 | img = acc.astype(np.uint8) 81 | else: 82 | nmask = mask.copy() 83 | ylst, xlst = nmask.nonzero() 84 | yt, xt = ylst.min(), xlst.min() 85 | yb, xb = ylst.max(), xlst.max() 86 | content = img[yt : yb + 1, xt : xb + 1] 87 | img = np.pad( 88 | content, 89 | ((yt, mask.shape[0] - yb - 1), (xt, mask.shape[1] - xb - 1), (0, 0)), 90 | mode="edge", 91 | ) 92 | return img, mask 93 | 94 | 95 | def perlin_noise(img, mask): 96 | lin_x = np.linspace(0, 5, mask.shape[1], endpoint=False) 97 | lin_y = np.linspace(0, 5, mask.shape[0], endpoint=False) 98 | x, y = np.meshgrid(lin_x, lin_y) 99 | avg = img.mean(axis=0).mean(axis=0) 100 | # noise=[((perlin(x, y)+1)*128+avg[i]).astype(np.uint8) 
for i in range(3)] 101 | noise = [((perlin(x, y) + 1) * 0.5 * 255).astype(np.uint8) for i in range(3)] 102 | noise = np.stack(noise, axis=-1) 103 | # mask=skimage.measure.block_reduce(mask,(8,8),np.min) 104 | # mask=mask.repeat(8, axis=0).repeat(8, axis=1) 105 | # mask_image=Image.fromarray(mask) 106 | # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 4)) 107 | # mask=np.array(mask_image) 108 | nmask = mask.copy() 109 | # nmask=nmask/255.0 110 | nmask[mask > 0] = 1 111 | img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise 112 | # img=img.astype(np.uint8) 113 | return img, mask 114 | 115 | 116 | def gaussian_noise(img, mask): 117 | noise = np.random.randn(mask.shape[0], mask.shape[1], 3) 118 | noise = (noise + 1) / 2 * 255 119 | noise = noise.astype(np.uint8) 120 | nmask = mask.copy() 121 | nmask[mask > 0] = 1 122 | img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise 123 | return img, mask 124 | 125 | 126 | def cv2_telea(img, mask): 127 | ret = cv2.inpaint(img, 255 - mask, 5, cv2.INPAINT_TELEA) 128 | return ret, mask 129 | 130 | 131 | def cv2_ns(img, mask): 132 | ret = cv2.inpaint(img, 255 - mask, 5, cv2.INPAINT_NS) 133 | return ret, mask 134 | 135 | 136 | def patch_match_func(img, mask): 137 | ret = patch_match.inpaint(img, mask=255 - mask, patch_size=3) 138 | return ret, mask 139 | 140 | 141 | def mean_fill(img, mask): 142 | avg = img.mean(axis=0).mean(axis=0) 143 | img[mask < 1] = avg 144 | return img, mask 145 | 146 | """ 147 | Apache-2.0 license 148 | https://github.com/hafriedlander/stable-diffusion-grpcserver/blob/main/sdgrpcserver/services/generate.py 149 | https://github.com/parlance-zz/g-diffuser-bot/tree/g-diffuser-bot-beta2 150 | _handleImageAdjustment 151 | """ 152 | try: 153 | from sd_grpcserver.sdgrpcserver import images 154 | import torch 155 | from math import sqrt 156 | def handleImageAdjustment(array, adjustments): 157 | tensor = images.fromPIL(Image.fromarray(array)) 158 | for adjustment in adjustments: 159 | which = adjustment[0] 160 | 161 | if which == "blur": 162 | sigma = adjustment[1] 163 | direction = adjustment[2] 164 | 165 | if direction == "DOWN" or direction == "UP": 166 | orig = tensor 167 | repeatCount=256 168 | sigma /= sqrt(repeatCount) 169 | 170 | for _ in range(repeatCount): 171 | tensor = images.gaussianblur(tensor, sigma) 172 | if direction == "DOWN": 173 | tensor = torch.minimum(tensor, orig) 174 | else: 175 | tensor = torch.maximum(tensor, orig) 176 | else: 177 | tensor = images.gaussianblur(tensor, adjustment.blur.sigma) 178 | elif which == "invert": 179 | tensor = images.invert(tensor) 180 | elif which == "levels": 181 | tensor = images.levels(tensor, adjustment[1], adjustment[2], adjustment[3], adjustment[4]) 182 | elif which == "channels": 183 | tensor = images.channelmap(tensor, [adjustment.channels.r, adjustment.channels.g, adjustment.channels.b, adjustment.channels.a]) 184 | elif which == "rescale": 185 | self.unimp("Rescale") 186 | elif which == "crop": 187 | tensor = images.crop(tensor, adjustment.crop.top, adjustment.crop.left, adjustment.crop.height, adjustment.crop.width) 188 | return np.array(images.toPIL(tensor)[0]) 189 | 190 | def g_diffuser(img,mask): 191 | adjustments=[["blur",32,"UP"],["level",0,0.05,0,1]] 192 | mask=handleImageAdjustment(mask,adjustments) 193 | out_mask=handleImageAdjustment(mask,adjustments) 194 | return img, mask 195 | except: 196 | def g_diffuser(img,mask): 197 | return img,mask 198 | 199 | def dummy_fill(img,mask): 200 | return img,mask 201 | functbl 
= { 202 | "gaussian": gaussian_noise, 203 | "perlin": perlin_noise, 204 | "edge_pad": edge_pad, 205 | "patchmatch": patch_match_func if patch_match_compiled else edge_pad, 206 | "cv2_ns": cv2_ns, 207 | "cv2_telea": cv2_telea, 208 | "g_diffuser": g_diffuser, 209 | "g_diffuser_lib": dummy_fill, 210 | } 211 | 212 | try: 213 | from postprocess import PhotometricCorrection 214 | correction_func = PhotometricCorrection() 215 | except Exception as e: 216 | print(e, "so PhotometricCorrection is disabled") 217 | class DummyCorrection: 218 | def __init__(self): 219 | self.backend="" 220 | pass 221 | def run(self,a,b,**kwargs): 222 | return b 223 | correction_func=DummyCorrection() 224 | 225 | class DummyInterrogator: 226 | def __init__(self) -> None: 227 | pass 228 | def interrogate(self,pil): 229 | return "Interrogator init failed" 230 | 231 | if "taichi" in correction_func.backend: 232 | import sys 233 | import io 234 | import base64 235 | from PIL import Image 236 | def base64_to_pil(base64_str): 237 | data = base64.b64decode(str(base64_str)) 238 | pil = Image.open(io.BytesIO(data)) 239 | return pil 240 | 241 | def pil_to_base64(out_pil): 242 | out_buffer = io.BytesIO() 243 | out_pil.save(out_buffer, format="PNG") 244 | out_buffer.seek(0) 245 | base64_bytes = base64.b64encode(out_buffer.read()) 246 | base64_str = base64_bytes.decode("ascii") 247 | return base64_str 248 | from subprocess import Popen, PIPE, STDOUT 249 | class SubprocessCorrection: 250 | def __init__(self): 251 | self.backend=correction_func.backend 252 | self.child= Popen(["python", "postprocess.py"], stdin=PIPE, stdout=PIPE, stderr=STDOUT) 253 | def run(self,img_input,img_inpainted,mode): 254 | if mode=="disabled": 255 | return img_inpainted 256 | base64_str_input = pil_to_base64(img_input) 257 | base64_str_inpainted = pil_to_base64(img_inpainted) 258 | try: 259 | if self.child.poll(): 260 | self.child= Popen(["python", "postprocess.py"], stdin=PIPE, stdout=PIPE, stderr=STDOUT) 261 | self.child.stdin.write(f"{base64_str_input},{base64_str_inpainted},{mode}\n".encode()) 262 | self.child.stdin.flush() 263 | out = self.child.stdout.readline() 264 | base64_str=out.decode().strip() 265 | while base64_str and base64_str[0]=="[": 266 | print(base64_str) 267 | out = self.child.stdout.readline() 268 | base64_str=out.decode().strip() 269 | ret=base64_to_pil(base64_str) 270 | except: 271 | print("[PIE] not working, photometric correction is disabled") 272 | ret=img_inpainted 273 | return ret 274 | correction_func = SubprocessCorrection() 275 | --------------------------------------------------------------------------------
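For context on how the utilities in `utils.py` fit together at runtime, here is a minimal, hypothetical sketch of chaining a `functbl` fill with `correction_func`. The real orchestration lives in `app.py` (not shown above), so the function name, the argument handling, and the `mode=` keyword below are assumptions rather than the app's actual API.

```python
# Hypothetical usage sketch only -- the real wiring lives in app.py (not shown above).
import numpy as np
from PIL import Image

from utils import functbl, correction_func

def outpaint_step(img: np.ndarray, mask: np.ndarray,
                  init_mode: str = "patchmatch",
                  correction_mode: str = "disabled") -> Image.Image:
    # 1) Pre-fill the empty region (mask == 0) so the inpainting model starts
    #    from something more plausible than raw zeros.
    filled, mask = functbl[init_mode](img, mask)

    # 2) Run the Stable Diffusion inpainting pipeline on `filled` here; the
    #    placeholder below merely stands in for the model output.
    inpainted = Image.fromarray(filled.astype(np.uint8))

    # 3) Optionally suppress seams with photometric correction; correction_func
    #    is PhotometricCorrection, its subprocess wrapper, or a no-op fallback.
    return correction_func.run(Image.fromarray(img), inpainted, mode=correction_mode)
```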