├── .dockerignore
├── .github
│   └── ISSUE_TEMPLATE
│       ├── bug_report.md
│       └── feature_request.md
├── .gitignore
├── .gitmodules
├── LICENSE
├── app.py
├── canvas.py
├── config.yaml
├── convert_checkpoint.py
├── css
│   └── w2ui.min.css
├── docker-compose.yml
├── docker
│   ├── Dockerfile
│   ├── docker-run.sh
│   ├── entrypoint.sh
│   ├── opencv.pc
│   └── run-shell.sh
├── docs
│   ├── run_with_docker.md
│   ├── setup_guide.md
│   └── usage.md
├── environment.yml
├── index.html
├── interrogate.py
├── js
│   ├── fabric.min.js
│   ├── keyboard.js
│   ├── mode.js
│   ├── outpaint.js
│   ├── proceed.js
│   ├── setup.js
│   ├── toolbar.js
│   ├── upload.js
│   ├── w2ui.min.js
│   └── xss.js
├── models
│   ├── v1-inference.yaml
│   └── v1-inpainting-inference.yaml
├── perlin2d.py
├── postprocess.py
├── process.py
├── readme.md
├── stablediffusion_infinity_colab.ipynb
└── utils.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | .github/
3 | .git/
4 | docs/
5 | .dockerignore
6 | readme.md
7 | LICENSE
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: "[Bug]"
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 | For setup or dependency problems, please post in the Q&A section of Discussions
13 |
14 | **To Reproduce**
15 | Steps to reproduce the behavior:
16 | 1. Go to '...'
17 | 2. Click on '....'
18 | 3. Scroll down to '....'
19 | 4. See error
20 |
21 | **Expected behavior**
22 | A clear and concise description of what you expected to happen.
23 |
24 | **Screenshots**
25 | If applicable, add screenshots to help explain your problem.
26 |
27 | **Desktop (please complete the following information):**
28 | - OS: [e.g. windows]
29 | - Browser: [e.g. chrome]
30 |
31 | **Additional context**
32 | Add any other context about the problem here.
33 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: "[Feature Request]"
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is.
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | Makefile
3 | .ipynb_checkpoints/
4 | build/
5 | csrc/
6 | .idea/
7 | travis.sh
8 | *.iml
9 | .token
10 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "glid_3_xl_stable"]
2 | path = glid_3_xl_stable
3 | url = https://github.com/lkwq007/glid_3_xl_stable.git
4 | [submodule "PyPatchMatch"]
5 | path = PyPatchMatch
6 | url = https://github.com/lkwq007/PyPatchMatch.git
7 | [submodule "sd_grpcserver"]
8 | path = sd_grpcserver
9 | url = https://github.com/lkwq007/sd_grpcserver.git
10 | [submodule "blip_model"]
11 | path = blip_model
12 | url = https://github.com/lkwq007/blip_model
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/canvas.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | import io
4 | import numpy as np
5 | from PIL import Image
6 | from pyodide import to_js, create_proxy
7 | import gc
8 | from js import (
9 | console,
10 | document,
11 | devicePixelRatio,
12 | ImageData,
13 | Uint8ClampedArray,
14 | CanvasRenderingContext2D as Context2d,
15 | requestAnimationFrame,
16 | update_overlay,
17 | setup_overlay,
18 | window
19 | )
20 |
21 | PAINT_SELECTION = "selection"
22 | IMAGE_SELECTION = "canvas"
23 | BRUSH_SELECTION = "eraser"
24 | NOP_MODE = 0
25 | PAINT_MODE = 1
26 | IMAGE_MODE = 2
27 | BRUSH_MODE = 3
28 |
29 |
30 | def hold_canvas():
31 | pass
32 |
33 |
34 | def prepare_canvas(width, height, canvas) -> Context2d:
35 | ctx = canvas.getContext("2d")
36 |
37 | canvas.style.width = f"{width}px"
38 | canvas.style.height = f"{height}px"
39 |
40 | canvas.width = width
41 | canvas.height = height
42 |
43 | ctx.clearRect(0, 0, width, height)
44 |
45 | return ctx
46 |
47 |
48 | # class MultiCanvas:
49 | # def __init__(self,layer,width=800, height=600) -> None:
50 | # pass
51 | def multi_canvas(layer, width=800, height=600):
52 | lst = [
53 | CanvasProxy(document.querySelector(f"#canvas{i}"), width, height)
54 | for i in range(layer)
55 | ]
56 | return lst
57 |
58 |
59 | class CanvasProxy:
60 | def __init__(self, canvas, width=800, height=600) -> None:
61 | self.canvas = canvas
62 | self.ctx = prepare_canvas(width, height, canvas)
63 | self.width = width
64 | self.height = height
65 |
66 | def clear_rect(self, x, y, w, h):
67 | self.ctx.clearRect(x, y, w, h)
68 |
69 | def clear(self,):
70 | self.clear_rect(0, 0, self.canvas.width, self.canvas.height)
71 |
72 | def stroke_rect(self, x, y, w, h):
73 | self.ctx.strokeRect(x, y, w, h)
74 |
75 | def fill_rect(self, x, y, w, h):
76 | self.ctx.fillRect(x, y, w, h)
77 |
78 | def put_image_data(self, image, x, y):
79 | data = Uint8ClampedArray.new(to_js(image.tobytes()))
80 | height, width, _ = image.shape
81 | image_data = ImageData.new(data, width, height)
82 | self.ctx.putImageData(image_data, x, y)
83 | del image_data
84 |
85 | # def draw_image(self,canvas, x, y, w, h):
86 | # self.ctx.drawImage(canvas,x,y,w,h)
87 | def draw_image(self,canvas, sx, sy, sWidth, sHeight, dx, dy, dWidth, dHeight):
88 | self.ctx.drawImage(canvas, sx, sy, sWidth, sHeight, dx, dy, dWidth, dHeight)
89 |
90 | @property
91 | def stroke_style(self):
92 | return self.ctx.strokeStyle
93 |
94 | @stroke_style.setter
95 | def stroke_style(self, value):
96 | self.ctx.strokeStyle = value
97 |
98 | @property
99 | def fill_style(self):
100 | return self.ctx.fillStyle
101 |
102 | @fill_style.setter
103 | def fill_style(self, value):
104 | self.ctx.fillStyle = value
105 |
106 |
107 | # RGBA for masking
108 | class InfCanvas:
109 | def __init__(
110 | self,
111 | width,
112 | height,
113 | selection_size=256,
114 | grid_size=64,
115 | patch_size=4096,
116 | test_mode=False,
117 | ) -> None:
118 | assert selection_size < min(height, width)
119 | self.width = width
120 | self.height = height
121 | self.display_width = width
122 | self.display_height = height
123 | self.canvas = multi_canvas(5, width=width, height=height)
124 | setup_overlay(width,height)
125 | # place at center
126 | self.view_pos = [patch_size//2-width//2, patch_size//2-height//2]
127 | self.cursor = [
128 | width // 2 - selection_size // 2,
129 | height // 2 - selection_size // 2,
130 | ]
131 | self.data = {}
132 | self.grid_size = grid_size
133 | self.selection_size_w = selection_size
134 | self.selection_size_h = selection_size
135 | self.patch_size = patch_size
136 | # note that for image data, the height comes before width
137 | self.buffer = np.zeros((height, width, 4), dtype=np.uint8)
138 | self.sel_buffer = np.zeros((selection_size, selection_size, 4), dtype=np.uint8)
139 | self.sel_buffer_bak = np.zeros(
140 | (selection_size, selection_size, 4), dtype=np.uint8
141 | )
142 | self.sel_dirty = False
143 | self.buffer_dirty = False
144 | self.mouse_pos = [-1, -1]
145 | self.mouse_state = 0
146 | # self.output = widgets.Output()
147 | self.test_mode = test_mode
148 | self.buffer_updated = False
149 | self.image_move_freq = 1
150 | self.show_brush = False
151 | self.scale=1.0
152 | self.eraser_size=32
153 |
154 | def reset_large_buffer(self):
155 | self.canvas[2].canvas.width=self.width
156 | self.canvas[2].canvas.height=self.height
157 | # self.canvas[2].canvas.style.width=f"{self.display_width}px"
158 | # self.canvas[2].canvas.style.height=f"{self.display_height}px"
159 | self.canvas[2].canvas.style.display="block"
160 | self.canvas[2].clear()
161 |
162 | def draw_eraser(self, x, y):
163 | self.canvas[-2].clear()
164 | self.canvas[-2].fill_style = "#ffffff"
165 | self.canvas[-2].fill_rect(x-self.eraser_size//2,y-self.eraser_size//2,self.eraser_size,self.eraser_size)
166 | self.canvas[-2].stroke_rect(x-self.eraser_size//2,y-self.eraser_size//2,self.eraser_size,self.eraser_size)
167 |
168 | def use_eraser(self,x,y):
169 | if self.sel_dirty:
170 | self.write_selection_to_buffer()
171 | self.draw_buffer()
172 | self.canvas[2].clear()
173 | self.buffer_dirty=True
174 | bx0,by0=int(x)-self.eraser_size//2,int(y)-self.eraser_size//2
175 | bx1,by1=bx0+self.eraser_size,by0+self.eraser_size
176 | bx0,by0=max(0,bx0),max(0,by0)
177 | bx1,by1=min(self.width,bx1),min(self.height,by1)
178 | self.buffer[by0:by1,bx0:bx1,:]*=0
179 | self.draw_buffer()
180 | self.draw_selection_box()
181 |
182 | def setup_mouse(self):
183 | self.image_move_cnt = 0
184 |
185 | def get_mouse_mode():
186 | mode = document.querySelector("#mode").value
187 | if mode == PAINT_SELECTION:
188 | return PAINT_MODE
189 | elif mode == IMAGE_SELECTION:
190 | return IMAGE_MODE
191 | return BRUSH_MODE
192 |
193 | def get_event_pos(event):
194 | canvas = self.canvas[-1].canvas
195 | rect = canvas.getBoundingClientRect()
196 | x = (canvas.width * (event.clientX - rect.left)) / rect.width
197 | y = (canvas.height * (event.clientY - rect.top)) / rect.height
198 | return x, y
199 |
200 | def handle_mouse_down(event):
201 | self.mouse_state = get_mouse_mode()
202 | if self.mouse_state==BRUSH_MODE:
203 | x,y=get_event_pos(event)
204 | self.use_eraser(x,y)
205 |
206 | def handle_mouse_out(event):
207 | last_state = self.mouse_state
208 | self.mouse_state = NOP_MODE
209 | self.image_move_cnt = 0
210 | if last_state == IMAGE_MODE:
211 | self.update_view_pos(0, 0)
212 | if True:
213 | self.clear_background()
214 | self.draw_buffer()
215 | self.reset_large_buffer()
216 | self.draw_selection_box()
217 | gc.collect()
218 | if self.show_brush:
219 | self.canvas[-2].clear()
220 | self.show_brush = False
221 |
222 | def handle_mouse_up(event):
223 | last_state = self.mouse_state
224 | self.mouse_state = NOP_MODE
225 | self.image_move_cnt = 0
226 | if last_state == IMAGE_MODE:
227 | self.update_view_pos(0, 0)
228 | if True:
229 | self.clear_background()
230 | self.draw_buffer()
231 | self.reset_large_buffer()
232 | self.draw_selection_box()
233 | gc.collect()
234 |
235 | async def handle_mouse_move(event):
236 | x, y = get_event_pos(event)
237 | x0, y0 = self.mouse_pos
238 | xo = x - x0
239 | yo = y - y0
240 | if self.mouse_state == PAINT_MODE:
241 | self.update_cursor(int(xo), int(yo))
242 | if True:
243 | # self.clear_background()
244 | # console.log(self.buffer_updated)
245 | if self.buffer_updated:
246 | self.draw_buffer()
247 | self.buffer_updated = False
248 | self.draw_selection_box()
249 | elif self.mouse_state == IMAGE_MODE:
250 | self.image_move_cnt += 1
251 | if self.image_move_cnt == self.image_move_freq:
252 | self.draw_buffer()
253 | self.canvas[2].clear()
254 | self.draw_selection_box()
255 | self.update_view_pos(int(xo), int(yo))
256 | self.cached_view_pos=tuple(self.view_pos)
257 | self.canvas[2].canvas.style.display="none"
258 | large_buffer=self.data2array(self.view_pos[0]-self.width//2,self.view_pos[1]-self.height//2,min(self.width*2,self.patch_size),min(self.height*2,self.patch_size))
259 | self.canvas[2].canvas.width=large_buffer.shape[1]
260 | self.canvas[2].canvas.height=large_buffer.shape[0]
261 | # self.canvas[2].canvas.style.width=""
262 | # self.canvas[2].canvas.style.height=""
263 | self.canvas[2].put_image_data(large_buffer,0,0)
264 | else:
265 | self.update_view_pos(int(xo), int(yo), False)
266 | self.canvas[1].clear()
267 | self.canvas[1].draw_image(self.canvas[2].canvas,
268 | self.width//2+(self.view_pos[0]-self.cached_view_pos[0]),self.height//2+(self.view_pos[1]-self.cached_view_pos[1]),
269 | self.width,self.height,
270 | 0,0,self.width,self.height
271 | )
272 | self.clear_background()
273 | # self.image_move_cnt = 0
274 | elif self.mouse_state == BRUSH_MODE:
275 | self.use_eraser(x,y)
276 |
277 | mode = document.querySelector("#mode").value
278 | if mode == BRUSH_SELECTION:
279 | self.draw_eraser(x,y)
280 | self.show_brush = True
281 | elif self.show_brush:
282 | self.canvas[-2].clear()
283 | self.show_brush = False
284 | self.mouse_pos[0] = x
285 | self.mouse_pos[1] = y
286 |
287 | self.canvas[-1].canvas.addEventListener(
288 | "mousedown", create_proxy(handle_mouse_down)
289 | )
290 | self.canvas[-1].canvas.addEventListener(
291 | "mousemove", create_proxy(handle_mouse_move)
292 | )
293 | self.canvas[-1].canvas.addEventListener(
294 | "mouseup", create_proxy(handle_mouse_up)
295 | )
296 | self.canvas[-1].canvas.addEventListener(
297 | "mouseout", create_proxy(handle_mouse_out)
298 | )
299 | async def handle_mouse_wheel(event):
300 | x, y = get_event_pos(event)
301 | self.mouse_pos[0] = x
302 | self.mouse_pos[1] = y
303 | console.log(to_js(self.mouse_pos))
304 | if event.deltaY>10:
305 | window.postMessage(to_js(["click","zoom_out", self.mouse_pos[0], self.mouse_pos[1]]),"*")
306 | elif event.deltaY<-10:
307 | window.postMessage(to_js(["click","zoom_in", self.mouse_pos[0], self.mouse_pos[1]]),"*")
308 | return False
309 | self.canvas[-1].canvas.addEventListener(
310 | "wheel", create_proxy(handle_mouse_wheel), False
311 | )
312 | def clear_background(self):
313 | # fake transparent background
314 | h, w, step = self.height, self.width, self.grid_size
315 | stride = step * 2
316 | x0, y0 = self.view_pos
317 | x0 = (-x0) % stride
318 | y0 = (-y0) % stride
319 | if y0>=step:
320 | val0,val1=stride,step
321 | else:
322 | val0,val1=step,stride
323 | # self.canvas.clear()
324 | self.canvas[0].fill_style = "#ffffff"
325 | self.canvas[0].fill_rect(0, 0, w, h)
326 | self.canvas[0].fill_style = "#aaaaaa"
327 | for y in range(y0-stride, h + step, step):
328 | start = (x0 - val0) if y // step % 2 == 0 else (x0 - val1)
329 | for x in range(start, w + step, stride):
330 | self.canvas[0].fill_rect(x, y, step, step)
331 | self.canvas[0].stroke_rect(0, 0, w, h)
332 |
333 | def refine_selection(self):
334 | h,w=self.selection_size_h,self.selection_size_w
335 | h=min(h,self.height)
336 | w=min(w,self.width)
337 | self.selection_size_h=h*8//8
338 | self.selection_size_w=w*8//8
339 | self.update_cursor(1,0)
340 |
341 |
342 | def update_scale(self, scale, mx=-1, my=-1):
343 | self.sync_to_data()
344 | scaled_width=int(self.display_width*scale)
345 | scaled_height=int(self.display_height*scale)
346 | if max(scaled_height,scaled_width)>=self.patch_size*2-128:
347 | return
348 | if scaled_height<=self.selection_size_h or scaled_width<=self.selection_size_w:
349 | return
350 | if mx>=0 and my>=0:
351 | scaled_mx=mx/self.scale*scale
352 | scaled_my=my/self.scale*scale
353 | self.view_pos[0]+=int(mx-scaled_mx)
354 | self.view_pos[1]+=int(my-scaled_my)
355 | self.scale=scale
356 | for item in self.canvas:
357 | item.canvas.width=scaled_width
358 | item.canvas.height=scaled_height
359 | item.clear()
360 | update_overlay(scaled_width,scaled_height)
361 | self.width=scaled_width
362 | self.height=scaled_height
363 | self.data2buffer()
364 | self.clear_background()
365 | self.draw_buffer()
366 | self.update_cursor(1,0)
367 | self.draw_selection_box()
368 |
369 | def update_view_pos(self, xo, yo, update=True):
370 | # if abs(xo) + abs(yo) == 0:
371 | # return
372 | if self.sel_dirty:
373 | self.write_selection_to_buffer()
374 | if self.buffer_dirty:
375 | self.buffer2data()
376 | self.view_pos[0] -= xo
377 | self.view_pos[1] -= yo
378 | if update:
379 | self.data2buffer()
380 | # self.read_selection_from_buffer()
381 |
382 | def update_cursor(self, xo, yo):
383 | if abs(xo) + abs(yo) == 0:
384 | return
385 | if self.sel_dirty:
386 | self.write_selection_to_buffer()
387 | self.cursor[0] += xo
388 | self.cursor[1] += yo
389 | self.cursor[0] = max(min(self.width - self.selection_size_w, self.cursor[0]), 0)
390 | self.cursor[1] = max(min(self.height - self.selection_size_h, self.cursor[1]), 0)
391 | # self.read_selection_from_buffer()
392 |
393 | def data2buffer(self):
394 | x, y = self.view_pos
395 | h, w = self.height, self.width
396 | if h!=self.buffer.shape[0] or w!=self.buffer.shape[1]:
397 | self.buffer=np.zeros((self.height, self.width, 4), dtype=np.uint8)
398 | # fill four parts
399 | for i in range(4):
400 | pos_src, pos_dst, data = self.select(x, y, i)
401 | xs0, xs1 = pos_src[0]
402 | ys0, ys1 = pos_src[1]
403 | xd0, xd1 = pos_dst[0]
404 | yd0, yd1 = pos_dst[1]
405 | self.buffer[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :]
406 |
407 | def data2array(self, x, y, w, h):
408 | # x, y = self.view_pos
409 | # h, w = self.height, self.width
410 | ret=np.zeros((h, w, 4), dtype=np.uint8)
411 | # fill four parts
412 | for i in range(4):
413 | pos_src, pos_dst, data = self.select(x, y, i, w, h)
414 | xs0, xs1 = pos_src[0]
415 | ys0, ys1 = pos_src[1]
416 | xd0, xd1 = pos_dst[0]
417 | yd0, yd1 = pos_dst[1]
418 | ret[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :]
419 | return ret
420 |
421 | def buffer2data(self):
422 | x, y = self.view_pos
423 | h, w = self.height, self.width
424 | # fill four parts
425 | for i in range(4):
426 | pos_src, pos_dst, data = self.select(x, y, i)
427 | xs0, xs1 = pos_src[0]
428 | ys0, ys1 = pos_src[1]
429 | xd0, xd1 = pos_dst[0]
430 | yd0, yd1 = pos_dst[1]
431 | data[ys0:ys1, xs0:xs1, :] = self.buffer[yd0:yd1, xd0:xd1, :]
432 | self.buffer_dirty = False
433 |
434 | def select(self, x, y, idx, width=0, height=0):
435 | if width==0:
436 | w, h = self.width, self.height
437 | else:
438 | w, h = width, height
439 | lst = [(0, 0), (0, h), (w, 0), (w, h)]
440 | if idx == 0:
441 | x0, y0 = x % self.patch_size, y % self.patch_size
442 | x1 = min(x0 + w, self.patch_size)
443 | y1 = min(y0 + h, self.patch_size)
444 | elif idx == 1:
445 | y += h
446 | x0, y0 = x % self.patch_size, y % self.patch_size
447 | x1 = min(x0 + w, self.patch_size)
448 | y1 = max(y0 - h, 0)
449 | elif idx == 2:
450 | x += w
451 | x0, y0 = x % self.patch_size, y % self.patch_size
452 | x1 = max(x0 - w, 0)
453 | y1 = min(y0 + h, self.patch_size)
454 | else:
455 | x += w
456 | y += h
457 | x0, y0 = x % self.patch_size, y % self.patch_size
458 | x1 = max(x0 - w, 0)
459 | y1 = max(y0 - h, 0)
460 | xi, yi = x // self.patch_size, y // self.patch_size
461 | cur = self.data.setdefault(
462 | (xi, yi), np.zeros((self.patch_size, self.patch_size, 4), dtype=np.uint8)
463 | )
464 | x0_img, y0_img = lst[idx]
465 | x1_img = x0_img + x1 - x0
466 | y1_img = y0_img + y1 - y0
467 | sort = lambda a, b: ((a, b) if a < b else (b, a))
468 | return (
469 | (sort(x0, x1), sort(y0, y1)),
470 | (sort(x0_img, x1_img), sort(y0_img, y1_img)),
471 | cur,
472 | )
473 |
474 | def draw_buffer(self):
475 | self.canvas[1].clear()
476 | self.canvas[1].put_image_data(self.buffer, 0, 0)
477 |
478 | def fill_selection(self, img):
479 | self.sel_buffer = img
480 | self.sel_dirty = True
481 |
482 | def draw_selection_box(self):
483 | x0, y0 = self.cursor
484 | w, h = self.selection_size_w, self.selection_size_h
485 | if self.sel_dirty:
486 | self.canvas[2].clear()
487 | self.canvas[2].put_image_data(self.sel_buffer, x0, y0)
488 | self.canvas[-1].clear()
489 | self.canvas[-1].stroke_style = "#0a0a0a"
490 | self.canvas[-1].stroke_rect(x0, y0, w, h)
491 | self.canvas[-1].stroke_style = "#ffffff"
492 | offset=round(self.scale) if self.scale>1.0 else 1
493 | self.canvas[-1].stroke_rect(x0 - offset, y0 - offset, w + offset*2, h + offset*2)
494 | self.canvas[-1].stroke_style = "#000000"
495 | self.canvas[-1].stroke_rect(x0 - offset*2, y0 - offset*2, w + offset*4, h + offset*4)
496 |
497 | def write_selection_to_buffer(self):
498 | x0, y0 = self.cursor
499 | x1, y1 = x0 + self.selection_size_w, y0 + self.selection_size_h
500 | self.buffer[y0:y1, x0:x1] = self.sel_buffer
501 | self.sel_dirty = False
502 | self.sel_buffer = np.zeros(
503 | (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
504 | )
505 | self.buffer_dirty = True
506 | self.buffer_updated = True
507 | # self.canvas[2].clear()
508 |
509 | def read_selection_from_buffer(self):
510 | x0, y0 = self.cursor
511 | x1, y1 = x0 + self.selection_size_w, y0 + self.selection_size_h
512 | self.sel_buffer = self.buffer[y0:y1, x0:x1]
513 | self.sel_dirty = False
514 |
515 | def base64_to_numpy(self, base64_str):
516 | try:
517 | data = base64.b64decode(str(base64_str))
518 | pil = Image.open(io.BytesIO(data))
519 | arr = np.array(pil)
520 | ret = arr
521 | except:
522 | ret = np.tile(
523 | np.array([255, 0, 0, 255], dtype=np.uint8),
524 | (self.selection_size_h, self.selection_size_w, 1),
525 | )
526 | return ret
527 |
528 | def numpy_to_base64(self, arr):
529 | out_pil = Image.fromarray(arr)
530 | out_buffer = io.BytesIO()
531 | out_pil.save(out_buffer, format="PNG")
532 | out_buffer.seek(0)
533 | base64_bytes = base64.b64encode(out_buffer.read())
534 | base64_str = base64_bytes.decode("ascii")
535 | return base64_str
536 |
537 | def sync_to_data(self):
538 | if self.sel_dirty:
539 | self.write_selection_to_buffer()
540 | self.canvas[2].clear()
541 | self.draw_buffer()
542 | if self.buffer_dirty:
543 | self.buffer2data()
544 |
545 | def sync_to_buffer(self):
546 | if self.sel_dirty:
547 | self.canvas[2].clear()
548 | self.write_selection_to_buffer()
549 | self.draw_buffer()
550 |
551 | def resize(self,width,height,scale=None,**kwargs):
552 | self.display_width=width
553 | self.display_height=height
554 | for canvas in self.canvas:
555 | prepare_canvas(width=width,height=height,canvas=canvas.canvas)
556 | setup_overlay(width,height)
557 | if scale is None:
558 | scale=1
559 | self.update_scale(scale)
560 |
561 |
562 | def save(self):
563 | self.sync_to_data()
564 | state={}
565 | state["width"]=self.display_width
566 | state["height"]=self.display_height
567 | state["selection_width"]=self.selection_size_w
568 | state["selection_height"]=self.selection_size_h
569 | state["view_pos"]=self.view_pos[:]
570 | state["cursor"]=self.cursor[:]
571 | state["scale"]=self.scale
572 | keys=list(self.data.keys())
573 | data={}
574 | for key in keys:
575 | if self.data[key].sum()>0:
576 | data[f"{key[0]},{key[1]}"]=self.numpy_to_base64(self.data[key])
577 | state["data"]=data
578 | return json.dumps(state)
579 |
580 | def load(self, state_json):
581 | self.reset()
582 | state=json.loads(state_json)
583 | self.display_width=state["width"]
584 | self.display_height=state["height"]
585 | self.selection_size_w=state["selection_width"]
586 | self.selection_size_h=state["selection_height"]
587 | self.view_pos=state["view_pos"][:]
588 | self.cursor=state["cursor"][:]
589 | self.scale=state["scale"]
590 | self.resize(state["width"],state["height"],scale=state["scale"])
591 | for k,v in state["data"].items():
592 | key=tuple(map(int,k.split(",")))
593 | self.data[key]=self.base64_to_numpy(v)
594 | self.data2buffer()
595 | self.display()
596 |
597 | def display(self):
598 | self.clear_background()
599 | self.draw_buffer()
600 | self.draw_selection_box()
601 |
602 | def reset(self):
603 | self.data.clear()
604 | self.buffer*=0
605 | self.buffer_dirty=False
606 | self.buffer_updated=False
607 | self.sel_buffer*=0
608 | self.sel_dirty=False
609 | self.view_pos = [0, 0]
610 | self.clear_background()
611 | for i in range(1,len(self.canvas)-1):
612 | self.canvas[i].clear()
613 |
614 | def export(self):
615 | self.sync_to_data()
616 | xmin, xmax, ymin, ymax = 0, 0, 0, 0
617 | if len(self.data.keys()) == 0:
618 | return np.zeros(
619 | (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
620 | )
621 | for xi, yi in self.data.keys():
622 | buf = self.data[(xi, yi)]
623 | if buf.sum() > 0:
624 | xmin = min(xi, xmin)
625 | xmax = max(xi, xmax)
626 | ymin = min(yi, ymin)
627 | ymax = max(yi, ymax)
628 | yn = ymax - ymin + 1
629 | xn = xmax - xmin + 1
630 | image = np.zeros(
631 | (yn * self.patch_size, xn * self.patch_size, 4), dtype=np.uint8
632 | )
633 | for xi, yi in self.data.keys():
634 | buf = self.data[(xi, yi)]
635 | if buf.sum() > 0:
636 | y0 = (yi - ymin) * self.patch_size
637 | x0 = (xi - xmin) * self.patch_size
638 | image[y0 : y0 + self.patch_size, x0 : x0 + self.patch_size] = buf
639 | ylst, xlst = image[:, :, -1].nonzero()
640 | if len(ylst) > 0:
641 | yt, xt = ylst.min(), xlst.min()
642 | yb, xb = ylst.max(), xlst.max()
643 | image = image[yt : yb + 1, xt : xb + 1]
644 | return image
645 | else:
646 | return np.zeros(
647 | (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
648 | )
649 |
--------------------------------------------------------------------------------
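
Note on canvas.py: `InfCanvas` stores the unbounded canvas as a dictionary of fixed-size RGBA patches keyed by `(xi, yi)` tile indices; `select()` decomposes a view rectangle into up to four patch-aligned pieces so that `data2buffer()`/`data2array()` can stitch together a window that crosses patch boundaries. The snippet below is only a minimal sketch of that tiling idea under simplified assumptions — the `read_window` helper is made up for illustration and is not the class's actual API.

```python
# Sketch only: illustrates the patch-dictionary lookup used by InfCanvas,
# with a hypothetical helper name; the real class splits the window via select().
import numpy as np

PATCH = 4096  # matches InfCanvas's default patch_size

def read_window(tiles, x, y, w, h):
    """Copy an h x w RGBA window starting at world coords (x, y) out of the tile dict."""
    out = np.zeros((h, w, 4), dtype=np.uint8)
    for yi in range(y // PATCH, (y + h - 1) // PATCH + 1):
        for xi in range(x // PATCH, (x + w - 1) // PATCH + 1):
            tile = tiles.get((xi, yi))
            if tile is None:
                continue  # untouched tiles stay transparent
            # overlap of the requested window with this tile, in world coordinates
            ox0, oy0 = max(x, xi * PATCH), max(y, yi * PATCH)
            ox1, oy1 = min(x + w, (xi + 1) * PATCH), min(y + h, (yi + 1) * PATCH)
            out[oy0 - y:oy1 - y, ox0 - x:ox1 - x] = \
                tile[oy0 - yi * PATCH:oy1 - yi * PATCH, ox0 - xi * PATCH:ox1 - xi * PATCH]
    return out

# Example: one tile at (0, 0); a view straddling its top-left corner mixes real
# tile data with transparent padding from the missing neighbouring tiles.
tiles = {(0, 0): np.full((PATCH, PATCH, 4), 255, dtype=np.uint8)}
view = read_window(tiles, -128, -128, 512, 512)
```

The real `select()` also handles the write-back direction (`buffer2data`), which this sketch omits.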
/config.yaml:
--------------------------------------------------------------------------------
1 | shortcut:
2 | clear: Escape
3 | load: Ctrl+o
4 | save: Ctrl+s
5 | export: Ctrl+e
6 | upload: Ctrl+u
7 | selection: 1
8 | canvas: 2
9 | eraser: 3
10 | outpaint: d
11 | accept: a
12 | cancel: c
13 | retry: r
14 | prev: q
15 | next: e
16 | zoom_in: z
17 | zoom_out: x
18 | random_seed: s
--------------------------------------------------------------------------------
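
Note on config.yaml: the `shortcut` block is a flat action-to-key table consumed by the UI (app.py and keyboard.js). As a rough illustration only — the actual wiring in app.py may differ — such a file can be read in Python with PyYAML:

```python
# Illustrative only: parse the shortcut table from config.yaml with PyYAML.
# How app.py actually forwards these bindings to the front end is not shown here.
import yaml

with open("config.yaml") as f:
    shortcut = yaml.safe_load(f)["shortcut"]

print(shortcut["outpaint"])   # 'd'
print(shortcut["selection"])  # parsed as the integer 1 by YAML
```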
/convert_checkpoint.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py
16 | """ Conversion script for the LDM checkpoints. """
17 |
18 | import argparse
19 | import os
20 |
21 | import torch
22 |
23 |
24 | try:
25 | from omegaconf import OmegaConf
26 | except ImportError:
27 | raise ImportError(
28 | "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
29 | )
30 |
31 | from diffusers import (
32 | AutoencoderKL,
33 | DDIMScheduler,
34 | LDMTextToImagePipeline,
35 | LMSDiscreteScheduler,
36 | PNDMScheduler,
37 | StableDiffusionPipeline,
38 | UNet2DConditionModel,
39 | )
40 | from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
41 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
42 | from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
43 |
44 |
45 | def shave_segments(path, n_shave_prefix_segments=1):
46 | """
47 | Removes segments. Positive values shave the first segments, negative shave the last segments.
48 | """
49 | if n_shave_prefix_segments >= 0:
50 | return ".".join(path.split(".")[n_shave_prefix_segments:])
51 | else:
52 | return ".".join(path.split(".")[:n_shave_prefix_segments])
53 |
54 |
55 | def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
56 | """
57 | Updates paths inside resnets to the new naming scheme (local renaming)
58 | """
59 | mapping = []
60 | for old_item in old_list:
61 | new_item = old_item.replace("in_layers.0", "norm1")
62 | new_item = new_item.replace("in_layers.2", "conv1")
63 |
64 | new_item = new_item.replace("out_layers.0", "norm2")
65 | new_item = new_item.replace("out_layers.3", "conv2")
66 |
67 | new_item = new_item.replace("emb_layers.1", "time_emb_proj")
68 | new_item = new_item.replace("skip_connection", "conv_shortcut")
69 |
70 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
71 |
72 | mapping.append({"old": old_item, "new": new_item})
73 |
74 | return mapping
75 |
76 |
77 | def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
78 | """
79 | Updates paths inside resnets to the new naming scheme (local renaming)
80 | """
81 | mapping = []
82 | for old_item in old_list:
83 | new_item = old_item
84 |
85 | new_item = new_item.replace("nin_shortcut", "conv_shortcut")
86 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
87 |
88 | mapping.append({"old": old_item, "new": new_item})
89 |
90 | return mapping
91 |
92 |
93 | def renew_attention_paths(old_list, n_shave_prefix_segments=0):
94 | """
95 | Updates paths inside attentions to the new naming scheme (local renaming)
96 | """
97 | mapping = []
98 | for old_item in old_list:
99 | new_item = old_item
100 |
101 | # new_item = new_item.replace('norm.weight', 'group_norm.weight')
102 | # new_item = new_item.replace('norm.bias', 'group_norm.bias')
103 |
104 | # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
105 | # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
106 |
107 | # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
108 |
109 | mapping.append({"old": old_item, "new": new_item})
110 |
111 | return mapping
112 |
113 |
114 | def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
115 | """
116 | Updates paths inside attentions to the new naming scheme (local renaming)
117 | """
118 | mapping = []
119 | for old_item in old_list:
120 | new_item = old_item
121 |
122 | new_item = new_item.replace("norm.weight", "group_norm.weight")
123 | new_item = new_item.replace("norm.bias", "group_norm.bias")
124 |
125 | new_item = new_item.replace("q.weight", "query.weight")
126 | new_item = new_item.replace("q.bias", "query.bias")
127 |
128 | new_item = new_item.replace("k.weight", "key.weight")
129 | new_item = new_item.replace("k.bias", "key.bias")
130 |
131 | new_item = new_item.replace("v.weight", "value.weight")
132 | new_item = new_item.replace("v.bias", "value.bias")
133 |
134 | new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
135 | new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
136 |
137 | new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
138 |
139 | mapping.append({"old": old_item, "new": new_item})
140 |
141 | return mapping
142 |
143 |
144 | def assign_to_checkpoint(
145 | paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
146 | ):
147 | """
148 | This does the final conversion step: take locally converted weights and apply a global renaming
149 | to them. It splits attention layers, and takes into account additional replacements
150 | that may arise.
151 |
152 | Assigns the weights to the new checkpoint.
153 | """
154 | assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
155 |
156 | # Splits the attention layers into three variables.
157 | if attention_paths_to_split is not None:
158 | for path, path_map in attention_paths_to_split.items():
159 | old_tensor = old_checkpoint[path]
160 | channels = old_tensor.shape[0] // 3
161 |
162 | target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
163 |
164 | num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
165 |
166 | old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
167 | query, key, value = old_tensor.split(channels // num_heads, dim=1)
168 |
169 | checkpoint[path_map["query"]] = query.reshape(target_shape)
170 | checkpoint[path_map["key"]] = key.reshape(target_shape)
171 | checkpoint[path_map["value"]] = value.reshape(target_shape)
172 |
173 | for path in paths:
174 | new_path = path["new"]
175 |
176 | # These have already been assigned
177 | if attention_paths_to_split is not None and new_path in attention_paths_to_split:
178 | continue
179 |
180 | # Global renaming happens here
181 | new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
182 | new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
183 | new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
184 |
185 | if additional_replacements is not None:
186 | for replacement in additional_replacements:
187 | new_path = new_path.replace(replacement["old"], replacement["new"])
188 |
189 | # proj_attn.weight has to be converted from conv 1D to linear
190 | if "proj_attn.weight" in new_path:
191 | checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
192 | else:
193 | checkpoint[new_path] = old_checkpoint[path["old"]]
194 |
195 |
196 | def conv_attn_to_linear(checkpoint):
197 | keys = list(checkpoint.keys())
198 | attn_keys = ["query.weight", "key.weight", "value.weight"]
199 | for key in keys:
200 | if ".".join(key.split(".")[-2:]) in attn_keys:
201 | if checkpoint[key].ndim > 2:
202 | checkpoint[key] = checkpoint[key][:, :, 0, 0]
203 | elif "proj_attn.weight" in key:
204 | if checkpoint[key].ndim > 2:
205 | checkpoint[key] = checkpoint[key][:, :, 0]
206 |
207 |
208 | def create_unet_diffusers_config(original_config):
209 | """
210 | Creates a config for the diffusers based on the config of the LDM model.
211 | """
212 | unet_params = original_config.model.params.unet_config.params
213 |
214 | block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
215 |
216 | down_block_types = []
217 | resolution = 1
218 | for i in range(len(block_out_channels)):
219 | block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
220 | down_block_types.append(block_type)
221 | if i != len(block_out_channels) - 1:
222 | resolution *= 2
223 |
224 | up_block_types = []
225 | for i in range(len(block_out_channels)):
226 | block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
227 | up_block_types.append(block_type)
228 | resolution //= 2
229 |
230 | config = dict(
231 | sample_size=unet_params.image_size,
232 | in_channels=unet_params.in_channels,
233 | out_channels=unet_params.out_channels,
234 | down_block_types=tuple(down_block_types),
235 | up_block_types=tuple(up_block_types),
236 | block_out_channels=tuple(block_out_channels),
237 | layers_per_block=unet_params.num_res_blocks,
238 | cross_attention_dim=unet_params.context_dim,
239 | attention_head_dim=unet_params.num_heads,
240 | )
241 |
242 | return config
243 |
244 |
245 | def create_vae_diffusers_config(original_config):
246 | """
247 | Creates a config for the diffusers based on the config of the LDM model.
248 | """
249 | vae_params = original_config.model.params.first_stage_config.params.ddconfig
250 | _ = original_config.model.params.first_stage_config.params.embed_dim
251 |
252 | block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
253 | down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
254 | up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
255 |
256 | config = dict(
257 | sample_size=vae_params.resolution,
258 | in_channels=vae_params.in_channels,
259 | out_channels=vae_params.out_ch,
260 | down_block_types=tuple(down_block_types),
261 | up_block_types=tuple(up_block_types),
262 | block_out_channels=tuple(block_out_channels),
263 | latent_channels=vae_params.z_channels,
264 | layers_per_block=vae_params.num_res_blocks,
265 | )
266 | return config
267 |
268 |
269 | def create_diffusers_schedular(original_config):
270 | schedular = DDIMScheduler(
271 | num_train_timesteps=original_config.model.params.timesteps,
272 | beta_start=original_config.model.params.linear_start,
273 | beta_end=original_config.model.params.linear_end,
274 | beta_schedule="scaled_linear",
275 | )
276 | return schedular
277 |
278 |
279 | def create_ldm_bert_config(original_config):
280 | bert_params = original_config.model.params.cond_stage_config.params
281 | config = LDMBertConfig(
282 | d_model=bert_params.n_embed,
283 | encoder_layers=bert_params.n_layer,
284 | encoder_ffn_dim=bert_params.n_embed * 4,
285 | )
286 | return config
287 |
288 |
289 | def convert_ldm_unet_checkpoint(checkpoint, config):
290 | """
291 | Takes a state dict and a config, and returns a converted checkpoint.
292 | """
293 |
294 | # extract state_dict for UNet
295 | unet_state_dict = {}
296 | unet_key = "model.diffusion_model."
297 | keys = list(checkpoint.keys())
298 | for key in keys:
299 | if key.startswith(unet_key):
300 | unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
301 |
302 | new_checkpoint = {}
303 |
304 | new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
305 | new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
306 | new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
307 | new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
308 |
309 | new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
310 | new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
311 |
312 | new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
313 | new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
314 | new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
315 | new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
316 |
317 | # Retrieves the keys for the input blocks only
318 | num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
319 | input_blocks = {
320 | layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
321 | for layer_id in range(num_input_blocks)
322 | }
323 |
324 | # Retrieves the keys for the middle blocks only
325 | num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
326 | middle_blocks = {
327 | layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
328 | for layer_id in range(num_middle_blocks)
329 | }
330 |
331 | # Retrieves the keys for the output blocks only
332 | num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
333 | output_blocks = {
334 | layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
335 | for layer_id in range(num_output_blocks)
336 | }
337 |
338 | for i in range(1, num_input_blocks):
339 | block_id = (i - 1) // (config["layers_per_block"] + 1)
340 | layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
341 |
342 | resnets = [
343 | key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
344 | ]
345 | attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
346 |
347 | if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
348 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
349 | f"input_blocks.{i}.0.op.weight"
350 | )
351 | new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
352 | f"input_blocks.{i}.0.op.bias"
353 | )
354 |
355 | paths = renew_resnet_paths(resnets)
356 | meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
357 | assign_to_checkpoint(
358 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
359 | )
360 |
361 | if len(attentions):
362 | paths = renew_attention_paths(attentions)
363 | meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
364 | assign_to_checkpoint(
365 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
366 | )
367 |
368 | resnet_0 = middle_blocks[0]
369 | attentions = middle_blocks[1]
370 | resnet_1 = middle_blocks[2]
371 |
372 | resnet_0_paths = renew_resnet_paths(resnet_0)
373 | assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
374 |
375 | resnet_1_paths = renew_resnet_paths(resnet_1)
376 | assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
377 |
378 | attentions_paths = renew_attention_paths(attentions)
379 | meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
380 | assign_to_checkpoint(
381 | attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
382 | )
383 |
384 | for i in range(num_output_blocks):
385 | block_id = i // (config["layers_per_block"] + 1)
386 | layer_in_block_id = i % (config["layers_per_block"] + 1)
387 | output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
388 | output_block_list = {}
389 |
390 | for layer in output_block_layers:
391 | layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
392 | if layer_id in output_block_list:
393 | output_block_list[layer_id].append(layer_name)
394 | else:
395 | output_block_list[layer_id] = [layer_name]
396 |
397 | if len(output_block_list) > 1:
398 | resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
399 | attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
400 |
401 | resnet_0_paths = renew_resnet_paths(resnets)
402 | paths = renew_resnet_paths(resnets)
403 |
404 | meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
405 | assign_to_checkpoint(
406 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
407 | )
408 |
409 | if ["conv.weight", "conv.bias"] in output_block_list.values():
410 | index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
411 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
412 | f"output_blocks.{i}.{index}.conv.weight"
413 | ]
414 | new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
415 | f"output_blocks.{i}.{index}.conv.bias"
416 | ]
417 |
418 | # Clear attentions as they have been attributed above.
419 | if len(attentions) == 2:
420 | attentions = []
421 |
422 | if len(attentions):
423 | paths = renew_attention_paths(attentions)
424 | meta_path = {
425 | "old": f"output_blocks.{i}.1",
426 | "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
427 | }
428 | assign_to_checkpoint(
429 | paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
430 | )
431 | else:
432 | resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
433 | for path in resnet_0_paths:
434 | old_path = ".".join(["output_blocks", str(i), path["old"]])
435 | new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
436 |
437 | new_checkpoint[new_path] = unet_state_dict[old_path]
438 |
439 | return new_checkpoint
440 |
441 |
442 | def convert_ldm_vae_checkpoint(checkpoint, config):
443 | # extract state dict for VAE
444 | vae_state_dict = {}
445 | vae_key = "first_stage_model."
446 | keys = list(checkpoint.keys())
447 | for key in keys:
448 | if key.startswith(vae_key):
449 | vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
450 |
451 | new_checkpoint = {}
452 |
453 | new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
454 | new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
455 | new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
456 | new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
457 | new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
458 | new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
459 |
460 | new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
461 | new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
462 | new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
463 | new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
464 | new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
465 | new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
466 |
467 | new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
468 | new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
469 | new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
470 | new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
471 |
472 | # Retrieves the keys for the encoder down blocks only
473 | num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
474 | down_blocks = {
475 | layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
476 | }
477 |
478 | # Retrieves the keys for the decoder up blocks only
479 | num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
480 | up_blocks = {
481 | layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
482 | }
483 |
484 | for i in range(num_down_blocks):
485 | resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
486 |
487 | if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
488 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
489 | f"encoder.down.{i}.downsample.conv.weight"
490 | )
491 | new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
492 | f"encoder.down.{i}.downsample.conv.bias"
493 | )
494 |
495 | paths = renew_vae_resnet_paths(resnets)
496 | meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
497 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
498 |
499 | mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
500 | num_mid_res_blocks = 2
501 | for i in range(1, num_mid_res_blocks + 1):
502 | resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
503 |
504 | paths = renew_vae_resnet_paths(resnets)
505 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
506 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
507 |
508 | mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
509 | paths = renew_vae_attention_paths(mid_attentions)
510 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
511 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
512 | conv_attn_to_linear(new_checkpoint)
513 |
514 | for i in range(num_up_blocks):
515 | block_id = num_up_blocks - 1 - i
516 | resnets = [
517 | key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
518 | ]
519 |
520 | if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
521 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
522 | f"decoder.up.{block_id}.upsample.conv.weight"
523 | ]
524 | new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
525 | f"decoder.up.{block_id}.upsample.conv.bias"
526 | ]
527 |
528 | paths = renew_vae_resnet_paths(resnets)
529 | meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
530 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
531 |
532 | mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
533 | num_mid_res_blocks = 2
534 | for i in range(1, num_mid_res_blocks + 1):
535 | resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
536 |
537 | paths = renew_vae_resnet_paths(resnets)
538 | meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
539 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
540 |
541 | mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
542 | paths = renew_vae_attention_paths(mid_attentions)
543 | meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
544 | assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
545 | conv_attn_to_linear(new_checkpoint)
546 | return new_checkpoint
547 |
548 |
549 | def convert_ldm_bert_checkpoint(checkpoint, config):
550 | def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
551 | hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
552 | hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
553 | hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
554 |
555 | hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
556 | hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
557 |
558 | def _copy_linear(hf_linear, pt_linear):
559 | hf_linear.weight = pt_linear.weight
560 | hf_linear.bias = pt_linear.bias
561 |
562 | def _copy_layer(hf_layer, pt_layer):
563 | # copy layer norms
564 | _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
565 | _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
566 |
567 | # copy attn
568 | _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
569 |
570 | # copy MLP
571 | pt_mlp = pt_layer[1][1]
572 | _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
573 | _copy_linear(hf_layer.fc2, pt_mlp.net[2])
574 |
575 | def _copy_layers(hf_layers, pt_layers):
576 | for i, hf_layer in enumerate(hf_layers):
577 | if i != 0:
578 | i += i
579 | pt_layer = pt_layers[i : i + 2]
580 | _copy_layer(hf_layer, pt_layer)
581 |
582 | hf_model = LDMBertModel(config).eval()
583 |
584 | # copy embeds
585 | hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
586 | hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
587 |
588 | # copy layer norm
589 | _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
590 |
591 | # copy hidden layers
592 | _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
593 |
594 | _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
595 |
596 | return hf_model
597 |
598 |
599 | def convert_ldm_clip_checkpoint(checkpoint):
600 | text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
601 |
602 | keys = list(checkpoint.keys())
603 |
604 | text_model_dict = {}
605 |
606 | for key in keys:
607 | if key.startswith("cond_stage_model.transformer"):
608 | text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
609 |
610 | text_model.load_state_dict(text_model_dict)
611 |
612 | return text_model
613 |
614 | import os
615 | def convert_checkpoint(checkpoint_path, inpainting=False):
616 | parser = argparse.ArgumentParser()
617 |
618 | parser.add_argument(
619 | "--checkpoint_path", default=checkpoint_path, type=str, help="Path to the checkpoint to convert."
620 | )
621 | # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
622 | parser.add_argument(
623 | "--original_config_file",
624 | default=None,
625 | type=str,
626 | help="The YAML config file corresponding to the original architecture.",
627 | )
628 | parser.add_argument(
629 | "--scheduler_type",
630 | default="pndm",
631 | type=str,
632 | help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
633 | )
634 | parser.add_argument("--dump_path", default=None, type=str, help="Path to the output model.")
635 |
636 | args = parser.parse_args([])
637 | if args.original_config_file is None:
638 | if inpainting:
639 | args.original_config_file = "./models/v1-inpainting-inference.yaml"
640 | else:
641 | args.original_config_file = "./models/v1-inference.yaml"
642 |
643 | original_config = OmegaConf.load(args.original_config_file)
644 | checkpoint = torch.load(args.checkpoint_path)["state_dict"]
645 |
646 | num_train_timesteps = original_config.model.params.timesteps
647 | beta_start = original_config.model.params.linear_start
648 | beta_end = original_config.model.params.linear_end
649 | if args.scheduler_type == "pndm":
650 | scheduler = PNDMScheduler(
651 | beta_end=beta_end,
652 | beta_schedule="scaled_linear",
653 | beta_start=beta_start,
654 | num_train_timesteps=num_train_timesteps,
655 | skip_prk_steps=True,
656 | )
657 | elif args.scheduler_type == "lms":
658 | scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
659 | elif args.scheduler_type == "ddim":
660 | scheduler = DDIMScheduler(
661 | beta_start=beta_start,
662 | beta_end=beta_end,
663 | beta_schedule="scaled_linear",
664 | clip_sample=False,
665 | set_alpha_to_one=False,
666 | )
667 | else:
668 | raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
669 |
670 | # Convert the UNet2DConditionModel model.
671 | unet_config = create_unet_diffusers_config(original_config)
672 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)
673 |
674 | unet = UNet2DConditionModel(**unet_config)
675 | unet.load_state_dict(converted_unet_checkpoint)
676 |
677 | # Convert the VAE model.
678 | vae_config = create_vae_diffusers_config(original_config)
679 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
680 |
681 | vae = AutoencoderKL(**vae_config)
682 | vae.load_state_dict(converted_vae_checkpoint)
683 |
684 | # Convert the text model.
685 | text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
686 | if text_model_type == "FrozenCLIPEmbedder":
687 | text_model = convert_ldm_clip_checkpoint(checkpoint)
688 | tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
689 | safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
690 | feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
691 | pipe = StableDiffusionPipeline(
692 | vae=vae,
693 | text_encoder=text_model,
694 | tokenizer=tokenizer,
695 | unet=unet,
696 | scheduler=scheduler,
697 | safety_checker=safety_checker,
698 | feature_extractor=feature_extractor,
699 | )
700 | else:
701 | text_config = create_ldm_bert_config(original_config)
702 | text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
703 | tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
704 | pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
705 |
706 | return pipe
707 |
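708 | # Minimal usage sketch (illustrative, not part of the original script): the paths
709 | # below are hypothetical. `save_pretrained` is the standard diffusers method for
710 | # writing the converted pipeline to disk.
711 | #
712 | # if __name__ == "__main__":
713 | #     pipe = convert_checkpoint("/path/to/model.ckpt", inpainting=False)
714 | #     pipe.save_pretrained("./converted-model")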
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | services:
2 | sd-infinity:
3 | build:
4 | context: .
5 | dockerfile: ./docker/Dockerfile
6 | #shm_size: '2gb' # Enable if more shared memory is needed
7 | ports:
8 | - "8888:8888"
9 | volumes:
10 | - user_home:/home/user
11 | - cond_env:/opt/conda/envs
12 | deploy:
13 | resources:
14 | reservations:
15 | devices:
16 | - driver: nvidia
17 | device_ids: ['0']
18 | capabilities: [gpu]
19 |
20 | volumes:
21 | user_home: {}
22 | cond_env: {}
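23 |
24 | # Hypothetical tweak: to expose more than one GPU to the container, extend
25 | # device_ids above (e.g. device_ids: ['0', '1']) or replace it with `count: all`.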
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM continuumio/miniconda3:4.12.0
2 |
3 | RUN apt-get update && \
4 | apt install -y \
5 | fonts-dejavu-core \
6 | build-essential \
7 | libopencv-dev \
8 | cmake \
9 | vim \
10 | && apt-get clean
11 |
12 | COPY docker/opencv.pc /usr/lib/pkgconfig/opencv.pc
13 |
14 | RUN useradd -ms /bin/bash user
15 | USER user
16 |
17 | RUN mkdir ~/.huggingface && conda init bash
18 |
19 | COPY --chown=user:user . /app
20 | WORKDIR /app
21 |
22 | EXPOSE 8888
23 | CMD ["/app/docker/entrypoint.sh"]
--------------------------------------------------------------------------------
/docker/docker-run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ..
4 | echo Current dir: "$(pwd)"
5 |
6 | if ! docker version | grep 'linux/amd64' ; then
7 | echo "Could not find docker."
8 | exit 1
9 | fi
10 |
11 | if ! docker-compose version | grep v2 ; then
12 | echo "docker-compose v2.x is not installed"
13 | exit 1
14 | fi
15 |
16 |
17 | if ! docker run -it --gpus=all --rm nvidia/cuda:11.4.2-base-ubuntu20.04 nvidia-smi | grep -e 'NVIDIA.*On' ; then
18 | echo "Docker could not find your NVIDIA gpu"
19 | exit 1
20 | fi
21 |
22 | if ! docker compose build ; then
23 | echo "Error while building"
24 | exit 1
25 | fi
26 | docker compose up
--------------------------------------------------------------------------------
/docker/entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd /app
4 |
5 | set -euxo pipefail
6 |
7 | set -x
8 |
9 | if ! conda env list | grep sd-inf ; then
10 | echo "Creating environment, it may appear to freeze for a few minutes..."
11 | conda env create -f environment.yml
12 | echo "Finished installing."
13 | echo "conda activate sd-inf" >> ~/.bashrc
14 | shasum environment.yml > ~/.environment.sha
15 | fi
16 |
17 | . "/opt/conda/etc/profile.d/conda.sh"
18 | conda activate sd-inf
19 |
20 | if shasum -c ~/.environment.sha > /dev/null 2>&1 ; then
21 | echo "environment.yml is unchanged."
22 | else
23 | echo "environment.yml was changed, please wait a minute until it says 'Done updating'..."
24 | conda env update --file environment.yml
25 | shasum environment.yml > ~/.environment.sha
26 | echo "Done updating."
27 | fi
28 |
29 | python app.py --port=8888 --host=0.0.0.0
30 |
--------------------------------------------------------------------------------
/docker/opencv.pc:
--------------------------------------------------------------------------------
1 | prefix=/usr
2 | exec_prefix=${prefix}
3 | includedir=${prefix}/include
4 | libdir=${exec_prefix}/lib
5 |
6 | Name: opencv
7 | Description: The opencv library
8 | Version: 2.x.x
9 | Cflags: -I${includedir}/opencv4
10 | #Cflags: -I${includedir}/opencv -I${includedir}/opencv2
11 | Libs: -L${libdir} -lopencv_calib3d -lopencv_imgproc -lopencv_xobjdetect -lopencv_hdf -lopencv_flann -lopencv_core -lopencv_dpm -lopencv_videoio -lopencv_reg -lopencv_quality -lopencv_tracking -lopencv_dnn_superres -lopencv_objdetect -lopencv_stitching -lopencv_saliency -lopencv_intensity_transform -lopencv_rapid -lopencv_dnn -lopencv_features2d -lopencv_text -lopencv_calib3d -lopencv_line_descriptor -lopencv_superres -lopencv_ml -lopencv_alphamat -lopencv_viz -lopencv_optflow -lopencv_videostab -lopencv_bioinspired -lopencv_highgui -lopencv_img_hash -lopencv_freetype -lopencv_imgcodecs -lopencv_mcc -lopencv_video -lopencv_photo -lopencv_surface_matching -lopencv_rgbd -lopencv_datasets -lopencv_ximgproc -lopencv_plot -lopencv_face -lopencv_stereo -lopencv_aruco -lopencv_dnn_objdetect -lopencv_phase_unwrapping -lopencv_bgsegm -lopencv_ccalib -lopencv_hfs -lopencv_imgproc -lopencv_shape -lopencv_xphoto -lopencv_structured_light -lopencv_fuzzy
--------------------------------------------------------------------------------
/docker/run-shell.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd "$(dirname $0)"
4 |
5 | docker-compose run -p 8888:8888 --rm -u root sd-infinity bash
6 |
--------------------------------------------------------------------------------
/docs/run_with_docker.md:
--------------------------------------------------------------------------------
1 |
2 | # Running with Docker on Windows or Linux with NVIDIA GPU
3 | On Windows 10 or 11, you can follow this guide to set up Docker with WSL2: https://www.youtube.com/watch?v=PB7zM3JrgkI
4 |
5 | Native Linux:
6 |
7 | ```
8 | cd stablediffusion-infinity/docker
9 | ./docker-run.sh
10 | ```
11 |
12 | Windows 10/11 with a WSL2 shell:
13 | - open a Windows Command Prompt and type "bash"
14 | - once in bash, type:
15 | ```
16 | cd /mnt/c/PATH-TO-YOUR/stablediffusion-infinity/docker
17 | ./docker-run.sh
18 | ```
19 |
20 | Open "http://localhost:8888" in your browser ( even though the log says http://0.0.0.0:8888 )
--------------------------------------------------------------------------------
/docs/setup_guide.md:
--------------------------------------------------------------------------------
1 | # Setup Guide
2 |
3 | Please install conda first ([miniconda](https://docs.conda.io/en/latest/miniconda.html) or [anaconda](https://docs.anaconda.com/anaconda/install/)).
4 |
5 | - [Setup with Linux/Nvidia GPU](#setup-with-linuxnvidia-gpu)
6 | - [Setup with Linux/AMD GPU](#setup-with-linuxamd-gpu-untested)
7 | - [Setup with Windows](#setup-with-windows)
8 | - [Setup with MacOS](#setup-with-macos)
9 | - [Upgrade from previous version](#upgrade)
10 |
11 | ## Setup with Linux/Nvidia GPU
12 |
13 | ### conda env
14 | setup with `environment.yml`
15 | ```
16 | git clone --recurse-submodules https://github.com/lkwq007/stablediffusion-infinity
17 | cd stablediffusion-infinity
18 | conda env create -f environment.yml
19 | ```
20 |
21 | If `environment.yml` doesn't work for you, you may install the dependencies manually:
22 | ```
23 | conda create -n sd-inf python=3.10
24 | conda activate sd-inf
25 | conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia
26 | conda install scipy scikit-image
27 | conda install -c conda-forge diffusers transformers ftfy accelerate
28 | pip install opencv-python
29 | pip install -U gradio
30 | pip install pytorch-lightning==1.7.7 einops==0.4.1 omegaconf==2.2.3
31 | pip install timm
32 | ```
33 |
34 | After setting up the environment, you can run stablediffusion-infinity with the following commands:
35 | ```
36 | conda activate sd-inf
37 | python app.py
38 | ```
39 |
40 | ## Setup with Linux/AMD GPU (untested)
41 |
42 | ```
43 | conda create -n sd-inf python=3.10
44 | conda activate sd-inf
45 | pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2
46 | conda install scipy scikit-image
47 | conda install -c conda-forge diffusers transformers ftfy accelerate
48 | pip install opencv-python
49 | pip install -U gradio
50 | pip install pytorch-lightning==1.7.7 einops==0.4.1 omegaconf==2.2.3
51 | pip install timm
52 | ```
53 |
54 |
55 | ### CPP library (optional)
56 |
57 | Note that the `opencv` library (e.g. `libopencv-dev`/`opencv-devel`; the package name may differ across distributions) is required for `PyPatchMatch`. You may need to install `opencv` yourself. If `opencv` is not installed, the `patch_match` option (usually better quality) won't work.
58 |
59 | ## Setup with Windows
60 |
61 |
62 | ```
63 | conda create -n sd-inf python=3.10
64 | conda activate sd-inf
65 | conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia
66 | conda install scipy scikit-image
67 | conda install -c conda-forge diffusers transformers ftfy accelerate
68 | pip install opencv-python
69 | pip install -U gradio
70 | pip install pytorch-lightning==1.7.7 einops==0.4.1 omegaconf==2.2.3
71 | pip install timm
72 | ```
73 |
74 | If you use an AMD GPU, you need to install the ONNX runtime with `pip install onnxruntime-directml` (this only works with the `stablediffusion-inpainting` model and is untested on AMD devices).
75 |
76 | On Windows, you may need to replace `pip install opencv-python` with `conda install -c conda-forge opencv`.
77 |
78 | After setting up the environment, you can run stablediffusion-infinity with the following commands:
79 | ```
80 | conda activate sd-inf
81 | python app.py
82 | ```
83 | ## Setup with MacOS
84 |
85 | ### conda env
86 | ```
87 | conda create -n sd-inf python=3.10
88 | conda activate sd-inf
89 | conda install pytorch torchvision torchaudio -c pytorch-nightly
90 | conda install scipy scikit-image
91 | conda install -c conda-forge diffusers transformers ftfy accelerate
92 | pip install opencv-python
93 | pip install -U gradio
94 | pip install pytorch-lightning==1.7.7 einops==0.4.1 omegaconf==2.2.3
95 | pip install timm
96 | ```
97 |
98 | After setting up the environment, you can run stablediffusion-infinity with the following commands:
99 | ```
100 | conda activate sd-inf
101 | python app.py
102 | ```
103 | ### CPP library (optional)
104 |
105 | Note that the `opencv` library is required for `PyPatchMatch`. You may need to install `opencv` yourself (e.g. via `brew install opencv`, or by compiling from source). If `opencv` is not installed, the `patch_match` option (usually better quality) won't work.
106 |
107 | ## Upgrade
108 |
109 | ```
110 | conda install -c conda-forge diffusers transformers ftfy accelerate
111 | conda update -c conda-forge diffusers transformers ftfy accelerate
112 | pip install -U gradio
113 | ```
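114 |
115 | After upgrading, an optional sanity check (for CUDA setups) to confirm that PyTorch still detects your GPU:
116 | ```
117 | conda activate sd-inf
118 | python -c "import torch; print(torch.cuda.is_available())"
119 | ```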
--------------------------------------------------------------------------------
/docs/usage.md:
--------------------------------------------------------------------------------
1 | # Usage
2 |
3 | ## Models
4 |
5 | - stablediffusion-inpainting: `runwayml/stable-diffusion-inpainting`, does not support img2img mode
6 | - stablediffusion-inpainting+img2img-v1.5: `runwayml/stable-diffusion-inpainting` + `runwayml/stable-diffusion-v1-5`, supports img2img mode, requires larger vRAM
7 | - stablediffusion-v1.5: `runwayml/stable-diffusion-v1-5`, inpainting with `diffusers`'s legacy pipeline, low quality for outpainting, supports img2img mode
8 | - stablediffusion-v1.4: `CompVis/stable-diffusion-v1-4`, inpainting with `diffusers`'s legacy pipeline, low quality for outpainting, supports img2img mode
9 |
10 | ## Loading local model
11 |
12 | Note that when loading a local checkpoint, you have to specify the correct model choice before setup.
13 | ```shell
14 | python app.py --local_model path_to_local_model
15 | # e.g.
16 | # diffusers model weights
17 | python app.py --local_model ./models/runwayml/stable-diffusion-inpainting
18 | python app.py --local_model models/CompVis/stable-diffusion-v1-4/model_index.json
19 | # original model checkpoint
20 | python app.py --local_model /home/user/checkpoint/model.ckpt
21 | ```
22 |
23 | ## Loading remote model
24 |
25 | Note that when loading a remote model, you have to specify the correct model choice before setup.
26 | ```shell
27 | python app.py --remote_model model_name
28 | # e.g.
29 | python app.py --remote_model hakurei/waifu-diffusion-v1-3
30 | ```
31 |
32 | ## Using textual inversion embeddings
33 |
34 | Put `*.bin` embedding files inside the `embeddings` directory.
35 |
36 | ## Using a dreambooth finetuned model
37 |
38 | ```
39 | python app.py --remote_model model_name
40 | # e.g.
41 | python app.py --remote_model sd-dreambooth-library/pikachu
42 | # or download the weight/checkpoint and load with
43 | python app.py --local_model path_to_model
44 | ```
45 |
46 | ## Model Path for Docker users
47 |
48 | Docker users can specify a local model path or remote model name within the web app.
49 |
50 | ## Using fp32 mode or low vRAM mode (some GPUs might not work well with fp16)
51 |
52 | ```shell
53 | python app.py --fp32 --lowvram
54 | ```
55 |
56 | ## HTTPS
57 |
58 | ```shell
59 | python app.py --encrypt --ssl_keyfile path_to_ssl_keyfile --ssl_certfile path_to_ssl_certfile
60 | ```
61 |
62 | ## Keyboard shortcut
63 |
64 | Shortcuts can be configured via `config.yaml`. Currently only `[key]` or `[Ctrl]` + `[key]` combinations are supported.
65 |
66 | Default shortcuts are:
67 |
68 | ```yaml
69 | shortcut:
70 | clear: Escape
71 | load: Ctrl+o
72 | save: Ctrl+s
73 | export: Ctrl+e
74 | upload: Ctrl+u
75 | selection: 1
76 | canvas: 2
77 | eraser: 3
78 | outpaint: d
79 | accept: a
80 | cancel: c
81 | retry: r
82 | prev: q
83 | next: e
84 | zoom_in: z
85 | zoom_out: x
86 | random_seed: s
87 | ```
88 |
89 | ## Glossary
90 |
91 | (From the diffusers documentation: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py)
92 | - prompt: The prompt to guide the image generation.
93 | - step: The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
94 | - guidance_scale: Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2 of the [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
95 | - negative_prompt: The prompt or prompts not to guide the image generation.
96 | - Sample number: The number of images to generate per prompt.
97 | - scheduler: A scheduler is used in combination with `unet` to denoise the encoded image latents.
98 | - eta: Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to DDIMScheduler, will be ignored for others.
99 | - strength: for img2img only. Conceptually, indicates how much to transform the reference image.
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: sd-inf
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - _openmp_mutex=5.1=1_gnu
10 | - abseil-cpp=20211102.0=h27087fc_1
11 | - accelerate=0.14.0=pyhd8ed1ab_0
12 | - aiohttp=3.8.1=py310h5764c6d_1
13 | - aiosignal=1.3.1=pyhd8ed1ab_0
14 | - arrow-cpp=8.0.0=py310h3098874_0
15 | - async-timeout=4.0.2=pyhd8ed1ab_0
16 | - attrs=22.1.0=pyh71513ae_1
17 | - aws-c-common=0.4.57=he1b5a44_1
18 | - aws-c-event-stream=0.1.6=h72b8ae1_3
19 | - aws-checksums=0.1.9=h346380f_0
20 | - aws-sdk-cpp=1.8.185=hce553d0_0
21 | - backports=1.0=py_2
22 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
23 | - blas=1.0=mkl
24 | - blosc=1.21.0=h4ff587b_1
25 | - boost-cpp=1.78.0=he72f1d9_0
26 | - brotli=1.0.9=h5eee18b_7
27 | - brotli-bin=1.0.9=h5eee18b_7
28 | - brotlipy=0.7.0=py310h7f8727e_1002
29 | - brunsli=0.1=h2531618_0
30 | - bzip2=1.0.8=h7b6447c_0
31 | - c-ares=1.18.1=h7f8727e_0
32 | - ca-certificates=2022.10.11=h06a4308_0
33 | - certifi=2022.9.24=py310h06a4308_0
34 | - cffi=1.15.1=py310h74dc2b5_0
35 | - cfitsio=3.470=h5893167_7
36 | - charls=2.2.0=h2531618_0
37 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
38 | - click=8.1.3=unix_pyhd8ed1ab_2
39 | - cloudpickle=2.0.0=pyhd3eb1b0_0
40 | - colorama=0.4.6=pyhd8ed1ab_0
41 | - cryptography=38.0.1=py310h9ce1e76_0
42 | - cuda=11.6.2=0
43 | - cuda-cccl=11.6.55=hf6102b2_0
44 | - cuda-command-line-tools=11.6.2=0
45 | - cuda-compiler=11.6.2=0
46 | - cuda-cudart=11.6.55=he381448_0
47 | - cuda-cudart-dev=11.6.55=h42ad0f4_0
48 | - cuda-cuobjdump=11.6.124=h2eeebcb_0
49 | - cuda-cupti=11.6.124=h86345e5_0
50 | - cuda-cuxxfilt=11.6.124=hecbf4f6_0
51 | - cuda-driver-dev=11.6.55=0
52 | - cuda-gdb=11.8.86=0
53 | - cuda-libraries=11.6.2=0
54 | - cuda-libraries-dev=11.6.2=0
55 | - cuda-memcheck=11.8.86=0
56 | - cuda-nsight=11.8.86=0
57 | - cuda-nsight-compute=11.8.0=0
58 | - cuda-nvcc=11.6.124=hbba6d2d_0
59 | - cuda-nvdisasm=11.8.86=0
60 | - cuda-nvml-dev=11.6.55=haa9ef22_0
61 | - cuda-nvprof=11.8.87=0
62 | - cuda-nvprune=11.6.124=he22ec0a_0
63 | - cuda-nvrtc=11.6.124=h020bade_0
64 | - cuda-nvrtc-dev=11.6.124=h249d397_0
65 | - cuda-nvtx=11.6.124=h0630a44_0
66 | - cuda-nvvp=11.8.87=0
67 | - cuda-runtime=11.6.2=0
68 | - cuda-samples=11.6.101=h8efea70_0
69 | - cuda-sanitizer-api=11.8.86=0
70 | - cuda-toolkit=11.6.2=0
71 | - cuda-tools=11.6.2=0
72 | - cuda-visual-tools=11.6.2=0
73 | - cytoolz=0.12.0=py310h5eee18b_0
74 | - dask-core=2022.7.0=py310h06a4308_0
75 | - dataclasses=0.8=pyhc8e2a94_3
76 | - datasets=2.7.0=pyhd8ed1ab_0
77 | - diffusers=0.11.1=pyhd8ed1ab_0
78 | - dill=0.3.6=pyhd8ed1ab_1
79 | - ffmpeg=4.3=hf484d3e_0
80 | - fftw=3.3.9=h27cfd23_1
81 | - filelock=3.8.0=pyhd8ed1ab_0
82 | - freetype=2.12.1=h4a9f257_0
83 | - frozenlist=1.3.0=py310h5764c6d_1
84 | - fsspec=2022.10.0=py310h06a4308_0
85 | - ftfy=6.1.1=pyhd8ed1ab_0
86 | - gds-tools=1.4.0.31=0
87 | - gflags=2.2.2=he1b5a44_1004
88 | - giflib=5.2.1=h7b6447c_0
89 | - glog=0.6.0=h6f12383_0
90 | - gmp=6.2.1=h295c915_3
91 | - gnutls=3.6.15=he1e5248_0
92 | - grpc-cpp=1.46.1=h33aed49_0
93 | - huggingface_hub=0.11.0=pyhd8ed1ab_0
94 | - icu=70.1=h27087fc_0
95 | - idna=3.4=py310h06a4308_0
96 | - imagecodecs=2021.8.26=py310hecf7e94_1
97 | - imageio=2.19.3=py310h06a4308_0
98 | - importlib-metadata=5.0.0=pyha770c72_1
99 | - importlib_metadata=5.0.0=hd8ed1ab_1
100 | - intel-openmp=2021.4.0=h06a4308_3561
101 | - joblib=1.2.0=pyhd8ed1ab_0
102 | - jpeg=9e=h7f8727e_0
103 | - jxrlib=1.1=h7b6447c_2
104 | - krb5=1.19.2=hac12032_0
105 | - lame=3.100=h7b6447c_0
106 | - lcms2=2.12=h3be6417_0
107 | - ld_impl_linux-64=2.38=h1181459_1
108 | - lerc=3.0=h295c915_0
109 | - libaec=1.0.4=he6710b0_1
110 | - libbrotlicommon=1.0.9=h5eee18b_7
111 | - libbrotlidec=1.0.9=h5eee18b_7
112 | - libbrotlienc=1.0.9=h5eee18b_7
113 | - libcublas=11.11.3.6=0
114 | - libcublas-dev=11.11.3.6=0
115 | - libcufft=10.9.0.58=0
116 | - libcufft-dev=10.9.0.58=0
117 | - libcufile=1.4.0.31=0
118 | - libcufile-dev=1.4.0.31=0
119 | - libcurand=10.3.0.86=0
120 | - libcurand-dev=10.3.0.86=0
121 | - libcurl=7.85.0=h91b91d3_0
122 | - libcusolver=11.4.1.48=0
123 | - libcusolver-dev=11.4.1.48=0
124 | - libcusparse=11.7.5.86=0
125 | - libcusparse-dev=11.7.5.86=0
126 | - libdeflate=1.8=h7f8727e_5
127 | - libedit=3.1.20210910=h7f8727e_0
128 | - libev=4.33=h7f8727e_1
129 | - libevent=2.1.10=h9b69904_4
130 | - libffi=3.3=he6710b0_2
131 | - libgcc-ng=11.2.0=h1234567_1
132 | - libgfortran-ng=11.2.0=h00389a5_1
133 | - libgfortran5=11.2.0=h1234567_1
134 | - libgomp=11.2.0=h1234567_1
135 | - libiconv=1.16=h7f8727e_2
136 | - libidn2=2.3.2=h7f8727e_0
137 | - libnghttp2=1.46.0=hce63b2e_0
138 | - libnpp=11.8.0.86=0
139 | - libnpp-dev=11.8.0.86=0
140 | - libnvjpeg=11.9.0.86=0
141 | - libnvjpeg-dev=11.9.0.86=0
142 | - libpng=1.6.37=hbc83047_0
143 | - libprotobuf=3.20.1=h4ff587b_0
144 | - libssh2=1.10.0=h8f2d780_0
145 | - libstdcxx-ng=11.2.0=h1234567_1
146 | - libtasn1=4.16.0=h27cfd23_0
147 | - libthrift=0.15.0=he6d91bd_0
148 | - libtiff=4.4.0=hecacb30_2
149 | - libunistring=0.9.10=h27cfd23_0
150 | - libuuid=1.41.5=h5eee18b_0
151 | - libwebp=1.2.4=h11a3e52_0
152 | - libwebp-base=1.2.4=h5eee18b_0
153 | - libzopfli=1.0.3=he6710b0_0
154 | - locket=1.0.0=py310h06a4308_0
155 | - lz4-c=1.9.3=h295c915_1
156 | - mkl=2021.4.0=h06a4308_640
157 | - mkl-service=2.4.0=py310h7f8727e_0
158 | - mkl_fft=1.3.1=py310hd6ae3a3_0
159 | - mkl_random=1.2.2=py310h00e6091_0
160 | - multidict=6.0.2=py310h5764c6d_1
161 | - multiprocess=0.70.12.2=py310h5764c6d_2
162 | - ncurses=6.3=h5eee18b_3
163 | - nettle=3.7.3=hbbd107a_1
164 | - networkx=2.8.4=py310h06a4308_0
165 | - nsight-compute=2022.3.0.22=0
166 | - numpy=1.23.4=py310hd5efca6_0
167 | - numpy-base=1.23.4=py310h8e6c178_0
168 | - openh264=2.1.1=h4ff587b_0
169 | - openjpeg=2.4.0=h3ad879b_0
170 | - openssl=1.1.1s=h7f8727e_0
171 | - orc=1.7.4=h07ed6aa_0
172 | - packaging=21.3=pyhd3eb1b0_0
173 | - pandas=1.4.2=py310h769672d_1
174 | - partd=1.2.0=pyhd3eb1b0_1
175 | - pillow=9.2.0=py310hace64e9_1
176 | - pip=22.2.2=py310h06a4308_0
177 | - psutil=5.9.1=py310h5764c6d_0
178 | - pyarrow=8.0.0=py310h468efa6_0
179 | - pycparser=2.21=pyhd3eb1b0_0
180 | - pyopenssl=22.0.0=pyhd3eb1b0_0
181 | - pyparsing=3.0.9=py310h06a4308_0
182 | - pysocks=1.7.1=py310h06a4308_0
183 | - python=3.10.8=haa1d7c7_0
184 | - python-dateutil=2.8.2=pyhd8ed1ab_0
185 | - python-xxhash=3.0.0=py310h5764c6d_1
186 | - python_abi=3.10=2_cp310
187 | - pytorch=1.13.0=py3.10_cuda11.6_cudnn8.3.2_0
188 | - pytorch-cuda=11.6=h867d48c_0
189 | - pytorch-mutex=1.0=cuda
190 | - pytz=2022.6=pyhd8ed1ab_0
191 | - pywavelets=1.3.0=py310h7f8727e_0
192 | - re2=2022.04.01=h27087fc_0
193 | - readline=8.2=h5eee18b_0
194 | - regex=2022.4.24=py310h5764c6d_0
195 | - requests=2.28.1=py310h06a4308_0
196 | - responses=0.18.0=pyhd8ed1ab_0
197 | - sacremoses=0.0.53=pyhd8ed1ab_0
198 | - scikit-image=0.19.2=py310h00e6091_0
199 | - scipy=1.9.3=py310hd5efca6_0
200 | - setuptools=65.5.0=py310h06a4308_0
201 | - six=1.16.0=pyhd3eb1b0_1
202 | - snappy=1.1.9=h295c915_0
203 | - sqlite=3.39.3=h5082296_0
204 | - tifffile=2021.7.2=pyhd3eb1b0_2
205 | - tk=8.6.12=h1ccaba5_0
206 | - tokenizers=0.11.4=py310h3dcd8bd_1
207 | - toolz=0.12.0=py310h06a4308_0
208 | - torchaudio=0.13.0=py310_cu116
209 | - torchvision=0.14.0=py310_cu116
210 | - tqdm=4.64.1=pyhd8ed1ab_0
211 | - transformers=4.24.0=pyhd8ed1ab_0
212 | - typing-extensions=4.3.0=py310h06a4308_0
213 | - typing_extensions=4.3.0=py310h06a4308_0
214 | - tzdata=2022f=h04d1e81_0
215 | - urllib3=1.26.12=py310h06a4308_0
216 | - utf8proc=2.6.1=h27cfd23_0
217 | - wcwidth=0.2.5=pyh9f0ad1d_2
218 | - wheel=0.37.1=pyhd3eb1b0_0
219 | - xxhash=0.8.0=h7f98852_3
220 | - xz=5.2.6=h5eee18b_0
221 | - yaml=0.2.5=h7b6447c_0
222 | - yarl=1.7.2=py310h5764c6d_2
223 | - zfp=0.5.5=h295c915_6
224 | - zipp=3.10.0=pyhd8ed1ab_0
225 | - zlib=1.2.13=h5eee18b_0
226 | - zstd=1.5.2=ha4553b6_0
227 | - pip:
228 | - absl-py==1.3.0
229 | - antlr4-python3-runtime==4.9.3
230 | - anyio==3.6.2
231 | - bcrypt==4.0.1
232 | - cachetools==5.2.0
233 | - cmake==3.25.0
234 | - commonmark==0.9.1
235 | - contourpy==1.0.6
236 | - cycler==0.11.0
237 | - einops==0.4.1
238 | - fastapi==0.87.0
239 | - ffmpy==0.3.0
240 | - fonttools==4.38.0
241 | - fpie==0.2.4
242 | - google-auth==2.14.1
243 | - google-auth-oauthlib==0.4.6
244 | - gradio==3.10.1
245 | - grpcio==1.51.0
246 | - h11==0.12.0
247 | - httpcore==0.15.0
248 | - httpx==0.23.1
249 | - jinja2==3.1.2
250 | - kiwisolver==1.4.4
251 | - linkify-it-py==1.0.3
252 | - llvmlite==0.39.1
253 | - markdown==3.4.1
254 | - markdown-it-py==2.1.0
255 | - markupsafe==2.1.1
256 | - matplotlib==3.6.2
257 | - mdit-py-plugins==0.3.1
258 | - mdurl==0.1.2
259 | - numba==0.56.4
260 | - oauthlib==3.2.2
261 | - omegaconf==2.2.3
262 | - opencv-python==4.6.0.66
263 | - opencv-python-headless==4.6.0.66
264 | - orjson==3.8.2
265 | - paramiko==2.12.0
266 | - protobuf==3.20.3
267 | - pyasn1==0.4.8
268 | - pyasn1-modules==0.2.8
269 | - pycryptodome==3.15.0
270 | - pydantic==1.10.2
271 | - pydeprecate==0.3.2
272 | - pydub==0.25.1
273 | - pygments==2.13.0
274 | - pynacl==1.5.0
275 | - python-multipart==0.0.5
276 | - pytorch-lightning==1.7.7
277 | - pyyaml==6.0
278 | - requests-oauthlib==1.3.1
279 | - rfc3986==1.5.0
280 | - rich==12.6.0
281 | - rsa==4.9
282 | - sniffio==1.3.0
283 | - sourceinspect==0.0.4
284 | - starlette==0.21.0
285 | - taichi==1.2.2
286 | - tensorboard==2.11.0
287 | - tensorboard-data-server==0.6.1
288 | - tensorboard-plugin-wit==1.8.1
289 | - timm==0.6.11
290 | - torchmetrics==0.10.3
291 | - uc-micro-py==1.0.1
292 | - uvicorn==0.20.0
293 | - websockets==10.4
294 | - werkzeug==2.2.2
--------------------------------------------------------------------------------
/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 | Stablediffusion Infinity
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
31 |
32 |
33 |
34 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
77 |
78 |
79 | - numpy
80 | - Pillow
81 | - paths:
82 | - ./canvas.py
83 |
84 |
85 |
86 | from pyodide import to_js, create_proxy
87 | from PIL import Image
88 | import io
89 | import time
90 | import base64
91 | from collections import deque
92 | import numpy as np
93 | from js import (
94 | console,
95 | document,
96 | parent,
97 | devicePixelRatio,
98 | ImageData,
99 | Uint8ClampedArray,
100 | CanvasRenderingContext2D as Context2d,
101 | requestAnimationFrame,
102 | window,
103 | encodeURIComponent,
104 | w2ui,
105 | update_eraser,
106 | update_scale,
107 | adjust_selection,
108 | update_count,
109 | enable_result_lst,
110 | setup_shortcut,
111 | update_undo_redo,
112 | )
113 |
114 |
115 | from canvas import InfCanvas
116 |
117 |
118 | class History:
119 | def __init__(self,maxlen=10):
120 | self.idx=-1
121 | self.undo_lst=deque([],maxlen=maxlen)
122 | self.redo_lst=deque([],maxlen=maxlen)
123 | self.state=None
124 |
125 | def undo(self):
126 | cur=None
127 | if len(self.undo_lst):
128 | cur=self.undo_lst.pop()
129 | self.redo_lst.appendleft(cur)
130 | return cur
131 | def redo(self):
132 | cur=None
133 | if len(self.redo_lst):
134 | cur=self.redo_lst.popleft()
135 | self.undo_lst.append(cur)
136 | return cur
137 |
138 | def check(self):
139 | return len(self.undo_lst)>0,len(self.redo_lst)>0
140 |
141 | def append(self,state,update=True):
142 | self.redo_lst.clear()
143 | self.undo_lst.append(state)
144 | if update:
145 | update_undo_redo(*self.check())
146 |
147 | history = History()
148 |
149 | base_lst = [None]
150 | async def draw_canvas() -> None:
151 | width=1024
152 | height=600
153 | canvas=InfCanvas(1024,600)
154 | update_eraser(canvas.eraser_size,min(canvas.selection_size_h,canvas.selection_size_w))
155 | document.querySelector("#container").style.height= f"{height}px"
156 | document.querySelector("#container").style.width = f"{width}px"
157 | canvas.setup_mouse()
158 | canvas.clear_background()
159 | canvas.draw_buffer()
160 | canvas.draw_selection_box()
161 | base_lst[0]=canvas
162 |
163 | async def draw_canvas_func(event):
164 | try:
165 | app=parent.document.querySelector("gradio-app")
166 | if app.shadowRoot:
167 | app=app.shadowRoot
168 | width=app.querySelector("#canvas_width input").value
169 | height=app.querySelector("#canvas_height input").value
170 | selection_size=app.querySelector("#selection_size input").value
171 | except:
172 | width=1024
173 | height=768
174 | selection_size=384
175 | document.querySelector("#container").style.width = f"{width}px"
176 | document.querySelector("#container").style.height= f"{height}px"
177 | canvas=InfCanvas(int(width),int(height),selection_size=int(selection_size))
178 | canvas.setup_mouse()
179 | canvas.clear_background()
180 | canvas.draw_buffer()
181 | canvas.draw_selection_box()
182 | base_lst[0]=canvas
183 |
184 | async def export_func(event):
185 | base=base_lst[0]
186 | arr=base.export()
187 | base.draw_buffer()
188 | base.canvas[2].clear()
189 | base64_str = base.numpy_to_base64(arr)
190 | time_str = time.strftime("%Y%m%d_%H%M%S")
191 | link = document.createElement("a")
192 | if len(event.data)>2 and event.data[2]:
193 | filename = event.data[2]
194 | else:
195 | filename = f"outpaint_{time_str}"
196 | # link.download = f"sdinf_state_{time_str}.json"
197 | link.download = f"{filename}.png"
198 | # link.download = f"outpaint_{time_str}.png"
199 | link.href = "data:image/png;base64,"+base64_str
200 | link.click()
201 | console.log(f"Canvas saved to {filename}.png")
202 |
203 | img_candidate_lst=[None,0]
204 |
205 | async def outpaint_func(event):
206 | base=base_lst[0]
207 | if len(event.data)==2:
208 | app=parent.document.querySelector("gradio-app")
209 | if app.shadowRoot:
210 | app=app.shadowRoot
211 | base64_str_raw=app.querySelector("#output textarea").value
212 | base64_str_lst=base64_str_raw.split(",")
213 | img_candidate_lst[0]=base64_str_lst
214 | img_candidate_lst[1]=0
215 | elif event.data[2]=="next":
216 | img_candidate_lst[1]+=1
217 | elif event.data[2]=="prev":
218 | img_candidate_lst[1]-=1
219 | enable_result_lst()
220 | if img_candidate_lst[0] is None:
221 | return
222 | lst=img_candidate_lst[0]
223 | idx=img_candidate_lst[1]
224 | update_count(idx%len(lst)+1,len(lst))
225 | arr=base.base64_to_numpy(lst[idx%len(lst)])
226 | base.fill_selection(arr)
227 | base.draw_selection_box()
228 |
229 | async def undo_func(event):
230 | base=base_lst[0]
231 | img_candidate_lst[0]=None
232 | if base.sel_dirty:
233 | base.sel_buffer = np.zeros((base.selection_size_h, base.selection_size_w, 4), dtype=np.uint8)
234 | base.sel_dirty = False
235 | base.canvas[2].clear()
236 |
237 | async def commit_func(event):
238 | base=base_lst[0]
239 | img_candidate_lst[0]=None
240 | if base.sel_dirty:
241 | base.write_selection_to_buffer()
242 | base.draw_buffer()
243 | base.canvas[2].clear()
244 | if len(event.data)>2:
245 | history.append(base.save())
246 |
247 | async def history_undo_func(event):
248 | base=base_lst[0]
249 | if base.buffer_dirty or len(history.redo_lst)>0:
250 | state=history.undo()
251 | else:
252 | history.undo()
253 | state=history.undo()
254 | if state is not None:
255 | base.load(state)
256 | update_undo_redo(*history.check())
257 |
258 | async def history_setup_func(event):
259 | base=base_lst[0]
260 | history.undo_lst.clear()
261 | history.redo_lst.clear()
262 | history.append(base.save(),update=False)
263 |
264 | async def history_redo_func(event):
265 | base=base_lst[0]
266 | if len(history.undo_lst)>0:
267 | state=history.redo()
268 | else:
269 | history.redo()
270 | state=history.redo()
271 | if state is not None:
272 | base.load(state)
273 | update_undo_redo(*history.check())
274 |
275 |
276 | async def transfer_func(event):
277 | base=base_lst[0]
278 | base.read_selection_from_buffer()
279 | sel_buffer=base.sel_buffer
280 | sel_buffer_str=base.numpy_to_base64(sel_buffer)
281 | app=parent.document.querySelector("gradio-app")
282 | if app.shadowRoot:
283 | app=app.shadowRoot
284 | app.querySelector("#input textarea").value=sel_buffer_str
285 | app.querySelector("#proceed").click()
286 |
287 | async def upload_func(event):
288 | base=base_lst[0]
289 | # base64_str=event.data[1]
290 | base64_str=document.querySelector("#upload_content").value
291 | base64_str=base64_str.split(",")[-1]
292 | # base64_str=parent.document.querySelector("gradio-app").shadowRoot.querySelector("#upload textarea").value
293 | arr=base.base64_to_numpy(base64_str)
294 | h,w,c=base.buffer.shape
295 | base.sync_to_buffer()
296 | base.buffer_dirty=True
297 | mask=arr[:,:,3:4].repeat(4,axis=2)
298 | base.buffer[mask>0]=0
299 | # in case mismatch
300 | base.buffer[0:h,0:w,:]+=arr
301 | #base.buffer[yo:yo+h,xo:xo+w,0:3]=arr[:,:,0:3]
302 | #base.buffer[yo:yo+h,xo:xo+w,-1]=arr[:,:,-1]
303 | base.draw_buffer()
304 | if len(event.data)>2:
305 | history.append(base.save())
306 |
307 | async def setup_shortcut_func(event):
308 | setup_shortcut(event.data[1])
309 |
310 |
311 | document.querySelector("#export").addEventListener("click",create_proxy(export_func))
312 | document.querySelector("#undo").addEventListener("click",create_proxy(undo_func))
313 | document.querySelector("#commit").addEventListener("click",create_proxy(commit_func))
314 | document.querySelector("#outpaint").addEventListener("click",create_proxy(outpaint_func))
315 | document.querySelector("#upload").addEventListener("click",create_proxy(upload_func))
316 |
317 | document.querySelector("#transfer").addEventListener("click",create_proxy(transfer_func))
318 | document.querySelector("#draw").addEventListener("click",create_proxy(draw_canvas_func))
319 |
320 | async def setup_func():
321 | document.querySelector("#setup").value="1"
322 |
323 | async def reset_func(event):
324 | base=base_lst[0]
325 | base.reset()
326 |
327 | async def load_func(event):
328 | base=base_lst[0]
329 | base.load(event.data[1])
330 |
331 | async def save_func(event):
332 | base=base_lst[0]
333 | json_str=base.save()
334 | time_str = time.strftime("%Y%m%d_%H%M%S")
335 | link = document.createElement("a")
336 | if len(event.data)>2 and event.data[2]:
337 | filename = str(event.data[2]).strip()
338 | else:
339 | filename = f"outpaint_{time_str}"
340 | # link.download = f"sdinf_state_{time_str}.json"
341 | link.download = f"{filename}.sdinf"
342 | link.href = "data:text/json;charset=utf-8,"+encodeURIComponent(json_str)
343 | link.click()
344 |
345 | async def prev_result_func(event):
346 | base=base_lst[0]
347 | base.reset()
348 |
349 | async def next_result_func(event):
350 | base=base_lst[0]
351 | base.reset()
352 |
353 | async def zoom_in_func(event):
354 | base=base_lst[0]
355 | scale=base.scale
356 | if scale>=0.2:
357 | scale-=0.1
358 | if len(event.data)>2:
359 | base.update_scale(scale,int(event.data[2]),int(event.data[3]))
360 | else:
361 | base.update_scale(scale)
362 | scale=base.scale
363 | update_scale(f"{base.width}x{base.height} ({round(100/scale)}%)")
364 |
365 | async def zoom_out_func(event):
366 | base=base_lst[0]
367 | scale=base.scale
368 | if scale<10:
369 | scale+=0.1
370 | console.log(len(event.data))
371 | if len(event.data)>2:
372 | base.update_scale(scale,int(event.data[2]),int(event.data[3]))
373 | else:
374 | base.update_scale(scale)
375 | scale=base.scale
376 | update_scale(f"{base.width}x{base.height} ({round(100/scale)}%)")
377 |
378 | async def sync_func(event):
379 | base=base_lst[0]
380 | base.sync_to_buffer()
381 | base.canvas[2].clear()
382 |
383 | async def eraser_size_func(event):
384 | base=base_lst[0]
385 | eraser_size=min(int(event.data[1]),min(base.selection_size_h,base.selection_size_w))
386 | eraser_size=max(8,eraser_size)
387 | base.eraser_size=eraser_size
388 |
389 | async def resize_selection_func(event):
390 | base=base_lst[0]
391 | cursor=base.cursor
392 | if len(event.data)>3:
393 | console.log(event.data)
394 | base.cursor[0]=int(event.data[1])
395 | base.cursor[1]=int(event.data[2])
396 | base.selection_size_w=int(event.data[3])//8*8
397 | base.selection_size_h=int(event.data[4])//8*8
398 | base.refine_selection()
399 | base.draw_selection_box()
400 | elif len(event.data)>2:
401 | base.draw_selection_box()
402 | else:
403 | base.canvas[-1].clear()
404 | adjust_selection(cursor[0],cursor[1],base.selection_size_w,base.selection_size_h)
405 |
406 | async def eraser_func(event):
407 | base=base_lst[0]
408 | if event.data[1]!="eraser":
409 | base.canvas[-2].clear()
410 | else:
411 | x,y=base.mouse_pos
412 | base.draw_eraser(x,y)
413 |
414 | async def resize_func(event):
415 | base=base_lst[0]
416 | width=int(event.data[1])
417 | height=int(event.data[2])
418 | if width>=256 and height>=256:
419 | if max(base.selection_size_h,base.selection_size_w)>min(width,height):
420 | base.selection_size_h=256
421 | base.selection_size_w=256
422 | base.resize(width,height)
423 |
424 | async def message_func(event):
425 | if event.data[0]=="click":
426 | if event.data[1]=="clear":
427 | await reset_func(event)
428 | elif event.data[1]=="save":
429 | await save_func(event)
430 | elif event.data[1]=="export":
431 | await export_func(event)
432 | elif event.data[1]=="accept":
433 | await commit_func(event)
434 | elif event.data[1]=="cancel":
435 | await undo_func(event)
436 | elif event.data[1]=="zoom_in":
437 | await zoom_in_func(event)
438 | elif event.data[1]=="zoom_out":
439 | await zoom_out_func(event)
440 | elif event.data[1]=="redo":
441 | await history_redo_func(event)
442 | elif event.data[1]=="undo":
443 | await history_undo_func(event)
444 | elif event.data[1]=="history":
445 | await history_setup_func(event)
446 | elif event.data[0]=="sync":
447 | await sync_func(event)
448 | elif event.data[0]=="load":
449 | await load_func(event)
450 | elif event.data[0]=="upload":
451 | await upload_func(event)
452 | elif event.data[0]=="outpaint":
453 | await outpaint_func(event)
454 | elif event.data[0]=="mode":
455 | if event.data[1]!="selection":
456 | await sync_func(event)
457 | await eraser_func(event)
458 | document.querySelector("#mode").value=event.data[1]
459 | elif event.data[0]=="transfer":
460 | await transfer_func(event)
461 | elif event.data[0]=="setup":
462 | await draw_canvas_func(event)
463 | elif event.data[0]=="eraser_size":
464 | await eraser_size_func(event)
465 | elif event.data[0]=="resize_selection":
466 | await resize_selection_func(event)
467 | elif event.data[0]=="shortcut":
468 | await setup_shortcut_func(event)
469 | elif event.data[0]=="resize":
470 | await resize_func(event)
471 |
472 | window.addEventListener("message",create_proxy(message_func))
473 |
474 | import asyncio
475 |
476 | _ = await asyncio.gather(
477 | setup_func()
478 | )
479 |
480 |
481 |
482 |
483 |
--------------------------------------------------------------------------------
/interrogate.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2022 pharmapsychotic
5 | https://github.com/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb
6 | """
7 |
8 | import numpy as np
9 | import os
10 | import torch
11 | import torchvision.transforms as T
12 | import torchvision.transforms.functional as TF
13 |
14 | from torch import nn
15 | from torch.nn import functional as F
16 | from torchvision import transforms
17 | from torchvision.transforms.functional import InterpolationMode
18 | from transformers import CLIPTokenizer, CLIPModel
19 | from transformers import CLIPProcessor, CLIPModel
20 |
21 | data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "blip_model", "data")
22 | def load_list(filename):
23 | with open(filename, 'r', encoding='utf-8', errors='replace') as f:
24 | items = [line.strip() for line in f.readlines()]
25 | return items
26 |
27 | artists = load_list(os.path.join(data_path, 'artists.txt'))
28 | flavors = load_list(os.path.join(data_path, 'flavors.txt'))
29 | mediums = load_list(os.path.join(data_path, 'mediums.txt'))
30 | movements = load_list(os.path.join(data_path, 'movements.txt'))
31 |
32 | sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
33 | trending_list = [site for site in sites]
34 | trending_list.extend(["trending on "+site for site in sites])
35 | trending_list.extend(["featured on "+site for site in sites])
36 | trending_list.extend([site+" contest winner" for site in sites])
37 |
38 | device="cpu"
39 | blip_image_eval_size = 384
40 | clip_name="openai/clip-vit-large-patch14"
41 |
42 | blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
43 |
44 | def generate_caption(blip_model, pil_image, device="cpu"):
45 | gpu_image = transforms.Compose([
46 | transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
47 | transforms.ToTensor(),
48 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
49 | ])(pil_image).unsqueeze(0).to(device)
50 |
51 | with torch.no_grad():
52 | caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
53 | return caption[0]
54 |
55 | def rank(text_features, image_features, text_array, top_count=1):
56 | top_count = min(top_count, len(text_array))
57 | similarity = torch.zeros((1, len(text_array)))
58 | for i in range(image_features.shape[0]):
59 | similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
60 | similarity /= image_features.shape[0]
61 |
62 | top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
63 | return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
64 |
65 | class Interrogator:
66 | def __init__(self) -> None:
67 | self.tokenizer = CLIPTokenizer.from_pretrained(clip_name)
68 | try:
69 | self.get_blip()
70 | except Exception:
71 | self.blip_model = None
72 | self.model = CLIPModel.from_pretrained(clip_name)
73 | self.processor = CLIPProcessor.from_pretrained(clip_name)
74 | self.text_feature_lst = [torch.load(os.path.join(data_path, f"{i}.pth")) for i in range(5)]
75 |
76 | def get_blip(self):
77 | from blip_model.blip import blip_decoder
78 | blip_model = blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base')
79 | blip_model.eval()
80 | self.blip_model = blip_model
81 |
82 |
83 | def interrogate(self,image,use_caption=False):
84 | if self.blip_model:
85 | caption = generate_caption(self.blip_model, image)
86 | else:
87 | caption = ""
88 | model,processor=self.model,self.processor
89 | bests = [[('',0)]]*5
90 | if True:
91 | print(f"Interrogating with {clip_name}...")
92 |
93 | inputs = processor(images=image, return_tensors="pt")
94 | with torch.no_grad():
95 | image_features = model.get_image_features(**inputs)
96 | image_features /= image_features.norm(dim=-1, keepdim=True)
97 | ranks = [
98 | rank(self.text_feature_lst[0], image_features, mediums),
99 | rank(self.text_feature_lst[1], image_features, ["by "+artist for artist in artists]),
100 | rank(self.text_feature_lst[2], image_features, trending_list),
101 | rank(self.text_feature_lst[3], image_features, movements),
102 | rank(self.text_feature_lst[4], image_features, flavors, top_count=3)
103 | ]
104 |
105 | for i in range(len(ranks)):
106 | confidence_sum = 0
107 | for ci in range(len(ranks[i])):
108 | confidence_sum += ranks[i][ci][1]
109 | if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
110 | bests[i] = ranks[i]
111 |
112 | flaves = ', '.join([f"{x[0]}" for x in bests[4]])
113 | medium = bests[0][0][0]
114 | print(ranks)
115 | if caption.startswith(medium):
116 | return f"{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}"
117 | else:
118 | return f"{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}"
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
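128 | # Minimal usage sketch (illustrative only; assumes the precomputed prompt-feature
129 | # files under blip_model/data exist, as loaded in Interrogator.__init__ above):
130 | #
131 | # from PIL import Image
132 | # interrogator = Interrogator()
133 | # print(interrogator.interrogate(Image.open("example.png")))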
--------------------------------------------------------------------------------
/js/keyboard.js:
--------------------------------------------------------------------------------
1 |
2 | window.my_setup_keyboard=setInterval(function(){
3 | let app=document.querySelector("gradio-app");
4 | app=app.shadowRoot??app;
5 | let frame=app.querySelector("#sdinfframe").contentWindow;
6 | console.log("Check iframe...");
7 | if(frame.setup_shortcut)
8 | {
9 | frame.setup_shortcut(json);
10 | clearInterval(window.my_setup_keyboard);
11 | }
12 | }, 1000);
13 | var config=JSON.parse(json);
14 | var key_map={};
15 | Object.keys(config.shortcut).forEach(k=>{
16 | key_map[config.shortcut[k]]=k;
17 | });
18 | document.addEventListener("keydown", e => {
19 | if(e.target.tagName!="INPUT"&&e.target.tagName!="GRADIO-APP"&&e.target.tagName!="TEXTAREA")
20 | {
21 | let key=e.key;
22 | if(e.ctrlKey)
23 | {
24 | key="Ctrl+"+e.key;
25 | if(key in key_map)
26 | {
27 | e.preventDefault();
28 | }
29 | }
30 | let app=document.querySelector("gradio-app");
31 | app=app.shadowRoot??app;
32 | let frame=app.querySelector("#sdinfframe").contentDocument;
33 | frame.dispatchEvent(
34 | new KeyboardEvent("keydown", {key: e.key, ctrlKey: e.ctrlKey})
35 | );
36 | }
37 | })
--------------------------------------------------------------------------------
/js/mode.js:
--------------------------------------------------------------------------------
1 | function(mode){
2 | let app=document.querySelector("gradio-app").shadowRoot;
3 | let frame=app.querySelector("#sdinfframe").contentWindow.document;
4 | frame.querySelector("#mode").value=mode;
5 | return mode;
6 | }
--------------------------------------------------------------------------------
/js/outpaint.js:
--------------------------------------------------------------------------------
1 | function(a){
2 | if(!window.my_observe_outpaint)
3 | {
4 | console.log("setup outpaint here");
5 | window.my_observe_outpaint = new MutationObserver(function (event) {
6 | console.log(event);
7 | let app=document.querySelector("gradio-app");
8 | app=app.shadowRoot??app;
9 | let frame=app.querySelector("#sdinfframe").contentWindow;
10 | frame.postMessage(["outpaint", ""], "*");
11 | });
12 | var app=document.querySelector("gradio-app");
13 | app=app.shadowRoot??app;
14 | window.my_observe_outpaint_target=app.querySelector("#output span");
15 | window.my_observe_outpaint.observe(window.my_observe_outpaint_target, {
16 | attributes: false,
17 | subtree: true,
18 | childList: true,
19 | characterData: true
20 | });
21 | }
22 | return a;
23 | }
--------------------------------------------------------------------------------
/js/proceed.js:
--------------------------------------------------------------------------------
1 | function(sel_buffer_str,
2 | prompt_text,
3 | negative_prompt_text,
4 | strength,
5 | guidance,
6 | step,
7 | resize_check,
8 | fill_mode,
9 | enable_safety,
10 | use_correction,
11 | enable_img2img,
12 | use_seed,
13 | seed_val,
14 | generate_num,
15 | scheduler,
16 | scheduler_eta,
17 | interrogate_mode,
18 | state){
19 | let app=document.querySelector("gradio-app");
20 | app=app.shadowRoot??app;
21 | sel_buffer=app.querySelector("#input textarea").value;
22 | let use_correction_bak=false;
23 | ({resize_check,enable_safety,enable_img2img,use_seed,seed_val,interrogate_mode}=window.config_obj);
24 | seed_val=Number(seed_val);
25 | return [
26 | sel_buffer,
27 | prompt_text,
28 | negative_prompt_text,
29 | strength,
30 | guidance,
31 | step,
32 | resize_check,
33 | fill_mode,
34 | enable_safety,
35 | use_correction,
36 | enable_img2img,
37 | use_seed,
38 | seed_val,
39 | generate_num,
40 | scheduler,
41 | scheduler_eta,
42 | interrogate_mode,
43 | state,
44 | ]
45 | }
--------------------------------------------------------------------------------
/js/setup.js:
--------------------------------------------------------------------------------
1 | function(token_val, width, height, size, model_choice, model_path){
2 | let app=document.querySelector("gradio-app");
3 | app=app.shadowRoot??app;
4 | app.querySelector("#sdinfframe").style.height=80+Number(height)+"px";
5 | // app.querySelector("#setup_row").style.display="none";
6 | app.querySelector("#model_path_input").style.display="none";
7 | let frame=app.querySelector("#sdinfframe").contentWindow.document;
8 |
9 | if(frame.querySelector("#setup").value=="0")
10 | {
11 | window.my_setup=setInterval(function(){
12 | let app=document.querySelector("gradio-app");
13 | app=app.shadowRoot??app;
14 | let frame=app.querySelector("#sdinfframe").contentWindow.document;
15 | console.log("Check PyScript...")
16 | if(frame.querySelector("#setup").value=="1")
17 | {
18 | frame.querySelector("#draw").click();
19 | clearInterval(window.my_setup);
20 | }
21 | }, 100)
22 | }
23 | else
24 | {
25 | frame.querySelector("#draw").click();
26 | }
27 | return [token_val, width, height, size, model_choice, model_path];
28 | }
--------------------------------------------------------------------------------
/js/toolbar.js:
--------------------------------------------------------------------------------
1 | // import { w2ui,w2toolbar,w2field,query,w2alert, w2utils,w2confirm} from "https://rawgit.com/vitmalina/w2ui/master/dist/w2ui.es6.min.js"
2 | // import { w2ui,w2toolbar,w2field,query,w2alert, w2utils,w2confirm} from "https://cdn.jsdelivr.net/gh/vitmalina/w2ui@master/dist/w2ui.es6.min.js"
3 |
4 | // https://stackoverflow.com/questions/36280818/how-to-convert-file-to-base64-in-javascript
5 | function getBase64(file) {
6 | var reader = new FileReader();
7 | reader.readAsDataURL(file);
8 | reader.onload = function () {
9 | add_image(reader.result);
10 | // console.log(reader.result);
11 | };
12 | reader.onerror = function (error) {
13 | console.log("Error: ", error);
14 | };
15 | }
16 |
17 | function getText(file) {
18 | var reader = new FileReader();
19 | reader.readAsText(file);
20 | reader.onload = function () {
21 | window.postMessage(["load",reader.result],"*")
22 | // console.log(reader.result);
23 | };
24 | reader.onerror = function (error) {
25 | console.log("Error: ", error);
26 | };
27 | }
28 |
29 | document.querySelector("#upload_file").addEventListener("change", (event)=>{
30 | console.log(event);
31 | let file = document.querySelector("#upload_file").files[0];
32 | getBase64(file);
33 | })
34 |
35 | document.querySelector("#upload_state").addEventListener("change", (event)=>{
36 | console.log(event);
37 | let file = document.querySelector("#upload_state").files[0];
38 | getText(file);
39 | })
40 |
41 | open_setting = function() {
42 | if (!w2ui.foo) {
43 | new w2form({
44 | name: "foo",
45 | style: "border: 0px; background-color: transparent;",
46 | fields: [{
47 | field: "canvas_width",
48 | type: "int",
49 | required: true,
50 | html: {
51 | label: "Canvas Width"
52 | }
53 | },
54 | {
55 | field: "canvas_height",
56 | type: "int",
57 | required: true,
58 | html: {
59 | label: "Canvas Height"
60 | }
61 | },
62 | ],
63 | record: {
64 | canvas_width: 1200,
65 | canvas_height: 600,
66 | },
67 | actions: {
68 | Save() {
69 | this.validate();
70 | let record = this.getCleanRecord();
71 | window.postMessage(["resize",record.canvas_width,record.canvas_height],"*");
72 | w2popup.close();
73 | },
74 | custom: {
75 | text: "Cancel",
76 | style: "text-transform: uppercase",
77 | onClick(event) {
78 | w2popup.close();
79 | }
80 | }
81 | }
82 | });
83 | }
84 | w2popup.open({
85 | title: "Form in a Popup",
86 | body: "",
87 | style: "padding: 15px 0px 0px 0px",
88 | width: 500,
89 | height: 280,
90 | showMax: true,
91 | async onToggle(event) {
92 | await event.complete
93 | w2ui.foo.resize();
94 | }
95 | })
96 | .then((event) => {
97 | w2ui.foo.render("#form")
98 | });
99 | }
100 |
101 | var button_lst=["clear", "load", "save", "export", "upload", "selection", "canvas", "eraser", "outpaint", "accept", "cancel", "retry", "prev", "current", "next", "eraser_size_btn", "eraser_size", "resize_selection", "scale", "zoom_in", "zoom_out", "help"];
102 | var upload_button_lst=['clear', 'load', 'save', "upload", 'export', 'outpaint', 'resize_selection', 'help', "setting", "interrogate"];
103 | var resize_button_lst=['clear', 'load', 'save', "upload", 'export', "selection", "canvas", "eraser", 'outpaint', 'resize_selection',"zoom_in", "zoom_out", 'help', "setting", "interrogate"];
104 | var outpaint_button_lst=['clear', 'load', 'save', "canvas", "eraser", "upload", 'export', 'resize_selection', "zoom_in", "zoom_out",'help', "setting", "interrogate", "undo", "redo"];
105 | var outpaint_result_lst=["accept", "cancel", "retry", "prev", "current", "next"];
106 | var outpaint_result_func_lst=["accept", "retry", "prev", "current", "next"];
107 |
108 | function check_button(id,text="",checked=true,tooltip="")
109 | {
110 | return { type: "check", id: id, text: text, icon: checked?"fa-solid fa-square-check":"fa-regular fa-square", checked: checked, tooltip: tooltip };
111 | }
112 |
113 | var toolbar=new w2toolbar({
114 | box: "#toolbar",
115 | name: "toolbar",
116 | tooltip: "top",
117 | items: [
118 | { type: "button", id: "clear", text: "Reset", tooltip: "Reset Canvas", icon: "fa-solid fa-rectangle-xmark" },
119 | { type: "break" },
120 | { type: "button", id: "load", tooltip: "Load Canvas", icon: "fa-solid fa-file-import" },
121 | { type: "button", id: "save", tooltip: "Save Canvas", icon: "fa-solid fa-file-export" },
122 | { type: "button", id: "export", tooltip: "Export Image", icon: "fa-solid fa-floppy-disk" },
123 | { type: "break" },
124 | { type: "button", id: "upload", text: "Upload Image", icon: "fa-solid fa-upload" },
125 | { type: "break" },
126 | { type: "radio", id: "selection", group: "1", tooltip: "Selection", icon: "fa-solid fa-arrows-up-down-left-right", checked: true },
127 | { type: "radio", id: "canvas", group: "1", tooltip: "Canvas", icon: "fa-solid fa-image" },
128 | { type: "radio", id: "eraser", group: "1", tooltip: "Eraser", icon: "fa-solid fa-eraser" },
129 | { type: "break" },
130 | { type: "button", id: "outpaint", text: "Outpaint", tooltip: "Run Outpainting", icon: "fa-solid fa-brush" },
131 | { type: "button", id: "interrogate", text: "Interrogate", tooltip: "Get a prompt with Clip Interrogator ", icon: "fa-solid fa-magnifying-glass" },
132 | { type: "break" },
133 | { type: "button", id: "accept", text: "Accept", tooltip: "Accept current result", icon: "fa-solid fa-check", hidden: true, disabled:true,},
134 | { type: "button", id: "cancel", text: "Cancel", tooltip: "Cancel current outpainting/error", icon: "fa-solid fa-ban", hidden: true},
135 | { type: "button", id: "retry", text: "Retry", tooltip: "Retry", icon: "fa-solid fa-rotate", hidden: true, disabled:true,},
136 | { type: "button", id: "prev", tooltip: "Prev Result", icon: "fa-solid fa-caret-left", hidden: true, disabled:true,},
137 | { type: "html", id: "current", hidden: true, disabled:true,
138 | async onRefresh(event) {
139 | await event.complete
140 | let fragment = query.html(`
141 |
142 |
143 | ${this.sel_value ?? "1/1"}
144 | `)
145 | query(this.box).find("#tb_toolbar_item_current").append(fragment)
146 | }
147 | },
148 | { type: "button", id: "next", tooltip: "Next Result", icon: "fa-solid fa-caret-right", hidden: true,disabled:true,},
149 | { type: "button", id: "add_image", text: "Add Image", icon: "fa-solid fa-file-circle-plus", hidden: true,disabled:true,},
150 | { type: "button", id: "delete_image", text: "Delete Image", icon: "fa-solid fa-trash-can", hidden: true,disabled:true,},
151 | { type: "button", id: "confirm", text: "Confirm", icon: "fa-solid fa-check", hidden: true,disabled:true,},
152 | { type: "button", id: "cancel_overlay", text: "Cancel", icon: "fa-solid fa-ban", hidden: true,disabled:true,},
153 | { type: "break" },
154 | { type: "spacer" },
155 | { type: "break" },
156 | { type: "button", id: "eraser_size_btn", tooltip: "Eraser Size", text:"Size", icon: "fa-solid fa-eraser", hidden: true, count: 32},
157 | { type: "html", id: "eraser_size", hidden: true,
158 | async onRefresh(event) {
159 | await event.complete
160 | // let fragment = query.html(`
161 | //
162 | // `)
163 | let fragment = query.html(`
164 |
165 | `)
166 | fragment.filter("input").on("change", event => {
167 | this.eraser_size = event.target.value;
168 | window.overlay.freeDrawingBrush.width=this.eraser_size;
169 | this.setCount("eraser_size_btn", event.target.value);
170 | window.postMessage(["eraser_size", event.target.value],"*")
171 | this.refresh();
172 | })
173 | query(this.box).find("#tb_toolbar_item_eraser_size").append(fragment)
174 | }
175 | },
176 | // { type: "button", id: "resize_eraser", tooltip: "Resize Eraser", icon: "fa-solid fa-sliders" },
177 | { type: "button", id: "resize_selection", text: "Resize Selection", tooltip: "Resize Selection", icon: "fa-solid fa-expand" },
178 | { type: "break" },
179 | { type: "html", id: "scale",
180 | async onRefresh(event) {
181 | await event.complete
182 | let fragment = query.html(`
183 |
184 |
185 | ${this.scale_value ?? "100%"}
186 | `)
187 | query(this.box).find("#tb_toolbar_item_scale").append(fragment)
188 | }
189 | },
190 | { type: "button", id: "zoom_in", tooltip: "Zoom In", icon: "fa-solid fa-magnifying-glass-plus" },
191 | { type: "button", id: "zoom_out", tooltip: "Zoom Out", icon: "fa-solid fa-magnifying-glass-minus" },
192 | { type: "break" },
193 | { type: "button", id: "help", tooltip: "Help", icon: "fa-solid fa-circle-info" },
194 | { type: "new-line"},
195 | { type: "button", id: "setting", text: "Canvas Setting", tooltip: "Resize Canvas Here", icon: "fa-solid fa-sliders" },
196 | { type: "break" },
197 | check_button("enable_history","Enable History:",false, "Enable Canvas History"),
198 | { type: "button", id: "undo", tooltip: "Undo last erasing/last outpainting", icon: "fa-solid fa-rotate-left", disabled: true },
199 | { type: "button", id: "redo", tooltip: "Redo", icon: "fa-solid fa-rotate-right", disabled: true },
200 | { type: "break" },
201 | check_button("enable_img2img","Enable Img2Img",false),
202 | // check_button("use_correction","Photometric Correction",false),
203 | check_button("resize_check","Resize Small Input",true),
204 | check_button("enable_safety","Enable Safety Checker",true),
205 | check_button("square_selection","Square Selection Only",false),
206 | {type: "break"},
207 | check_button("use_seed","Use Seed:",false),
208 | { type: "html", id: "seed_val",
209 | async onRefresh(event) {
210 | await event.complete
211 | let fragment = query.html(`
212 | `)
213 | fragment.filter("input").on("change", event => {
214 | this.config_obj.seed_val = event.target.value;
215 | parent.config_obj=this.config_obj;
216 | this.refresh();
217 | })
218 | query(this.box).find("#tb_toolbar_item_seed_val").append(fragment)
219 | }
220 | },
221 | { type: "button", id: "random_seed", tooltip: "Set a random seed", icon: "fa-solid fa-dice" },
222 | ],
223 | onClick(event) {
224 | switch(event.target){
225 | case "setting":
226 | open_setting();
227 | break;
228 | case "upload":
229 | this.upload_mode=true
230 | document.querySelector("#overlay_container").style.pointerEvents="auto";
231 | this.click("canvas");
232 | this.click("selection");
233 | this.show("confirm","cancel_overlay","add_image","delete_image");
234 | this.enable("confirm","cancel_overlay","add_image","delete_image");
235 | this.disable(...upload_button_lst);
236 | this.disable("undo","redo")
237 | query("#upload_file").click();
238 | if(this.upload_tip)
239 | {
240 | this.upload_tip=false;
241 | w2utils.notify("Note that only visible images will be added to canvas",{timeout:10000,where:query("#container")})
242 | }
243 | break;
244 | case "resize_selection":
245 | this.resize_mode=true;
246 | this.disable(...resize_button_lst);
247 | this.enable("confirm","cancel_overlay");
248 | this.show("confirm","cancel_overlay");
249 | window.postMessage(["resize_selection",""],"*");
250 | document.querySelector("#overlay_container").style.pointerEvents="auto";
251 | break;
252 | case "confirm":
253 | if(this.upload_mode)
254 | {
255 | export_image();
256 | }
257 | else
258 | {
259 | let sel_box=this.selection_box;
260 | if(sel_box.width*sel_box.height>512*512)
261 | {
262 | w2utils.notify("Note that the outpainting will be much slower when the area of selection is larger than 512x512",{timeout:2000,where:query("#container")})
263 | }
264 | window.postMessage(["resize_selection",sel_box.x,sel_box.y,sel_box.width,sel_box.height],"*");
265 | }
266 | case "cancel_overlay":
267 | end_overlay();
268 | this.hide("confirm","cancel_overlay","add_image","delete_image");
269 | if(this.upload_mode){
270 | this.enable(...upload_button_lst);
271 | }
272 | else
273 | {
274 | this.enable(...resize_button_lst);
275 | window.postMessage(["resize_selection","",""],"*");
276 | if(event.target=="cancel_overlay")
277 | {
278 | this.selection_box=this.selection_box_bak;
279 | }
280 | }
281 | if(this.selection_box)
282 | {
283 | this.setCount("resize_selection",`${Math.floor(this.selection_box.width/8)*8}x${Math.floor(this.selection_box.height/8)*8}`);
284 | }
285 | this.disable("confirm","cancel_overlay","add_image","delete_image");
286 | this.upload_mode=false;
287 | this.resize_mode=false;
288 | this.click("selection");
289 | window.update_undo_redo(window.undo_redo_state.undo, window.undo_redo_state.redo);
290 | break;
291 | case "add_image":
292 | query("#upload_file").click();
293 | break;
294 | case "delete_image":
295 | let active_obj = window.overlay.getActiveObject();
296 | if(active_obj)
297 | {
298 | window.overlay.remove(active_obj);
299 | window.overlay.renderAll();
300 | }
301 | else
302 | {
303 | w2utils.notify("You need to select an image first",{error:true,timeout:2000,where:query("#container")})
304 | }
305 | break;
306 | case "load":
307 | query("#upload_state").click();
308 | this.selection_box=null;
309 | this.setCount("resize_selection","");
310 | break;
311 | case "next":
312 | case "prev":
313 | window.postMessage(["outpaint", "", event.target], "*");
314 | break;
315 | case "outpaint":
316 | this.click("selection");
317 | this.disable(...outpaint_button_lst);
318 | this.show(...outpaint_result_lst);
319 | this.disable("undo","redo");
320 | if(this.outpaint_tip)
321 | {
322 | this.outpaint_tip=false;
323 | w2utils.notify("The canvas stays locked until you accept/cancel current outpainting. You can modify the 'sample number' to get multiple results; you can resize the canvas/selection with 'canvas setting'/'resize selection'; you can use 'photometric correction' to help preserve existing contents",{timeout:15000,where:query("#container")})
324 | }
325 | document.querySelector("#container").style.pointerEvents="none";
326 | case "retry":
327 | this.disable(...outpaint_result_func_lst);
328 | parent.config_obj["interrogate_mode"]=false;
329 | window.postMessage(["transfer",""],"*")
330 | break;
331 | case "interrogate":
332 | if(this.interrogate_tip)
333 | {
334 | this.interrogate_tip=false;
335 | w2utils.notify("ClipInterrogator v1 will be dynamically loaded when run at the first time, which may take a while",{timeout:10000,where:query("#container")})
336 | }
337 | parent.config_obj["interrogate_mode"]=true;
338 | window.postMessage(["transfer",""],"*")
339 | break
340 | case "accept":
341 | case "cancel":
342 | this.hide(...outpaint_result_lst);
343 | this.disable(...outpaint_result_func_lst);
344 | this.enable(...outpaint_button_lst);
345 | document.querySelector("#container").style.pointerEvents="auto";
346 | if(this.config_obj.enable_history)
347 | {
348 | window.postMessage(["click", event.target, ""],"*");
349 | }
350 | else
351 | {
352 | window.postMessage(["click", event.target],"*");
353 | }
354 | let app=parent.document.querySelector("gradio-app");
355 | app=app.shadowRoot??app;
356 | app.querySelector("#cancel").click();
357 | window.update_undo_redo(window.undo_redo_state.undo, window.undo_redo_state.redo);
358 | break;
359 | case "eraser":
360 | case "selection":
361 | case "canvas":
362 | if(event.target=="eraser")
363 | {
364 | this.show("eraser_size","eraser_size_btn");
365 | window.overlay.freeDrawingBrush.width=this.eraser_size;
366 | window.overlay.isDrawingMode = true;
367 | }
368 | else
369 | {
370 | this.hide("eraser_size","eraser_size_btn");
371 | window.overlay.isDrawingMode = false;
372 | }
373 | if(this.upload_mode)
374 | {
375 | if(event.target=="canvas")
376 | {
377 | window.postMessage(["mode", event.target],"*")
378 | document.querySelector("#overlay_container").style.pointerEvents="none";
379 | document.querySelector("#overlay_container").style.opacity = 0.5;
380 | }
381 | else
382 | {
383 | document.querySelector("#overlay_container").style.pointerEvents="auto";
384 | document.querySelector("#overlay_container").style.opacity = 1.0;
385 | }
386 | }
387 | else
388 | {
389 | window.postMessage(["mode", event.target],"*")
390 | }
391 | break;
392 | case "help":
393 | w2popup.open({
394 | title: "Document",
395 | body: "Usage: https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/usage.md"
396 | })
397 | break;
398 | case "clear":
399 | w2confirm("Reset canvas?").yes(() => {
400 | window.postMessage(["click", event.target],"*");
401 | }).no(() => {})
402 | break;
403 | case "random_seed":
404 | this.config_obj.seed_val=Math.floor(Math.random() * 3000000000);
405 | parent.config_obj=this.config_obj;
406 | this.refresh();
407 | break;
408 | case "enable_history":
409 | case "enable_img2img":
410 | case "use_correction":
411 | case "resize_check":
412 | case "enable_safety":
413 | case "use_seed":
414 | case "square_selection":
415 | let target=this.get(event.target);
416 | if(event.target=="enable_history")
417 | {
418 | if(!target.checked)
419 | {
420 | w2utils.notify("Enable canvas history might increase resource usage / slow down the canvas ", {error:true,timeout:3000,where:query("#container")})
421 | window.postMessage(["click","history"],"*");
422 | }
423 | else
424 | {
425 | window.undo_redo_state.undo=false;
426 | window.undo_redo_state.redo=false;
427 | this.disable("undo","redo");
428 | }
429 | }
430 | target.icon=target.checked?"fa-regular fa-square":"fa-solid fa-square-check";
431 | this.config_obj[event.target]=!target.checked;
432 | parent.config_obj=this.config_obj;
433 | this.refresh();
434 | break;
435 | case "save":
436 | case "export":
437 | ask_filename(event.target);
438 | break;
439 | default:
440 | // clear, save, export, outpaint, retry
441 | // break, save, export, accept, retry, outpaint
442 | window.postMessage(["click", event.target],"*")
443 | }
444 | console.log("Target: "+ event.target, event)
445 | }
446 | })
447 | window.w2ui=w2ui;
448 | w2ui.toolbar.config_obj={
449 | resize_check: true,
450 | enable_safety: true,
451 | use_correction: false,
452 | enable_img2img: false,
453 | use_seed: false,
454 | seed_val: 0,
455 | square_selection: false,
456 | enable_history: false,
457 | };
458 | w2ui.toolbar.outpaint_tip=true;
459 | w2ui.toolbar.upload_tip=true;
460 | w2ui.toolbar.interrogate_tip=true;
461 | window.update_count=function(cur,total){
462 | w2ui.toolbar.sel_value=`${cur}/${total}`;
463 | w2ui.toolbar.refresh();
464 | }
465 | window.update_eraser=function(val,max_val){
466 | w2ui.toolbar.eraser_size=`${val}`;
467 | w2ui.toolbar.eraser_max=`${max_val}`;
468 | w2ui.toolbar.setCount("eraser_size_btn", `${val}`);
469 | w2ui.toolbar.refresh();
470 | }
471 | window.update_scale=function(val){
472 | w2ui.toolbar.scale_value=`${val}`;
473 | w2ui.toolbar.refresh();
474 | }
475 | window.enable_result_lst=function(){
476 | w2ui.toolbar.enable(...outpaint_result_lst);
477 | }
478 | function onObjectScaled(e)
479 | {
480 | let object = e.target;
481 | if(object.isType("rect"))
482 | {
483 | let width=object.getScaledWidth();
484 | let height=object.getScaledHeight();
485 | object.scale(1);
486 | width=Math.max(Math.min(width,window.overlay.width-object.left),256);
487 | height=Math.max(Math.min(height,window.overlay.height-object.top),256);
488 | let l=Math.max(Math.min(object.left,window.overlay.width-width-object.strokeWidth),0);
489 | let t=Math.max(Math.min(object.top,window.overlay.height-height-object.strokeWidth),0);
490 | if(window.w2ui.toolbar.config_obj.square_selection)
491 | {
492 | let max_val = Math.min(Math.max(width,height),window.overlay.width,window.overlay.height);
493 | width=max_val;
494 | height=max_val;
495 | }
496 | object.set({ width: width, height: height, left:l,top:t})
497 | window.w2ui.toolbar.selection_box={width: width, height: height, x:object.left, y:object.top};
498 | window.w2ui.toolbar.setCount("resize_selection",`${Math.floor(width/8)*8}x${Math.floor(height/8)*8}`);
499 | window.w2ui.toolbar.refresh();
500 | }
501 | }
502 | function onObjectMoved(e)
503 | {
504 | let object = e.target;
505 | if(object.isType("rect"))
506 | {
507 | let l=Math.max(Math.min(object.left,window.overlay.width-object.width-object.strokeWidth),0);
508 | let t=Math.max(Math.min(object.top,window.overlay.height-object.height-object.strokeWidth),0);
509 | object.set({left:l,top:t});
510 | window.w2ui.toolbar.selection_box={width: object.width, height: object.height, x:object.left, y:object.top};
511 | }
512 | }
513 | window.setup_overlay=function(width,height)
514 | {
515 | if(window.overlay)
516 | {
517 | window.overlay.setDimensions({width:width,height:height});
518 | let app=parent.document.querySelector("gradio-app");
519 | app=app.shadowRoot??app;
520 | app.querySelector("#sdinfframe").style.height=80+Number(height)+"px";
521 | document.querySelector("#container").style.height= height+"px";
522 | document.querySelector("#container").style.width = width+"px";
523 | }
524 | else
525 | {
526 | let canvas=new fabric.Canvas("overlay_canvas");
527 | canvas.setDimensions({width:width,height:height});
528 | let app=parent.document.querySelector("gradio-app");
529 | app=app.shadowRoot??app;
530 | app.querySelector("#sdinfframe").style.height=80+Number(height)+"px";
531 | canvas.freeDrawingBrush = new fabric.EraserBrush(canvas);
532 | canvas.on("object:scaling", onObjectScaled);
533 | canvas.on("object:moving", onObjectMoved);
534 | window.overlay=canvas;
535 | }
536 | document.querySelector("#overlay_container").style.pointerEvents="none";
537 | }
538 | window.update_overlay=function(width,height)
539 | {
540 | window.overlay.setDimensions({width:width,height:height},{backstoreOnly:true});
541 | // document.querySelector("#overlay_container").style.pointerEvents="none";
542 | }
543 | window.adjust_selection=function(x,y,width,height)
544 | {
545 | var rect = new fabric.Rect({
546 | left: x,
547 | top: y,
548 | fill: "rgba(0,0,0,0)",
549 | strokeWidth: 3,
550 | stroke: "rgba(0,0,0,0.7)",
551 | cornerColor: "red",
552 | cornerStrokeColor: "red",
553 | borderColor: "rgba(255, 0, 0, 1.0)",
554 | width: width,
555 | height: height,
556 | lockRotation: true,
557 | });
558 | rect.setControlsVisibility({ mtr: false });
559 | window.overlay.add(rect);
560 | window.overlay.setActiveObject(window.overlay.item(0));
561 | window.w2ui.toolbar.selection_box={width: width, height: height, x:x, y:y};
562 | window.w2ui.toolbar.selection_box_bak={width: width, height: height, x:x, y:y};
563 | }
564 | function add_image(url)
565 | {
566 | fabric.Image.fromURL(url,function(img){
567 | window.overlay.add(img);
568 | window.overlay.setActiveObject(img);
569 | },{left:100,top:100});
570 | }
571 | function export_image()
572 | {
573 | let data=window.overlay.toDataURL();
574 | document.querySelector("#upload_content").value=data;
575 | if(window.w2ui.toolbar.config_obj.enable_history)
576 | {
577 | window.postMessage(["upload","",""],"*");
578 | window.w2ui.toolbar.enable("undo");
579 | window.w2ui.toolbar.disable("redo");
580 | }
581 | else
582 | {
583 | window.postMessage(["upload",""],"*");
584 | }
585 | end_overlay();
586 | }
587 | function end_overlay()
588 | {
589 | window.overlay.clear();
590 | document.querySelector("#overlay_container").style.opacity = 1.0;
591 | document.querySelector("#overlay_container").style.pointerEvents="none";
592 | }
593 | function ask_filename(target)
594 | {
595 | w2prompt({
596 | label: "Enter filename",
597 | value: `outpaint_${((new Date(Date.now() -(new Date()).getTimezoneOffset() * 60000))).toISOString().replace("T","_").replace(/[^0-9_]/g, "").substring(0,15)}`,
598 | })
599 | .change((event) => {
600 | console.log("change", event.detail.originalEvent.target.value);
601 | })
602 | .ok((event) => {
603 | console.log("value=", event.detail.value);
604 | window.postMessage(["click",target,event.detail.value],"*");
605 | })
606 | .cancel((event) => {
607 | console.log("cancel");
608 | });
609 | }
610 |
611 | document.querySelector("#container").addEventListener("wheel",(e)=>{e.preventDefault()})
612 | window.setup_shortcut=function(json)
613 | {
614 | var config=JSON.parse(json);
615 | var key_map={};
616 | Object.keys(config.shortcut).forEach(k=>{
617 | key_map[config.shortcut[k]]=k;
618 | })
619 | document.addEventListener("keydown",(e)=>{
620 | if(e.target.tagName!="INPUT")
621 | {
622 | let key=e.key;
623 | if(e.ctrlKey)
624 | {
625 | key="Ctrl+"+e.key;
626 | if(key in key_map)
627 | {
628 | e.preventDefault();
629 | }
630 | }
631 | if(key in key_map)
632 | {
633 | w2ui.toolbar.click(key_map[key]);
634 | }
635 | }
636 | })
637 | }
638 | window.undo_redo_state={undo:false,redo:false};
639 | window.update_undo_redo=function(s0,s1)
640 | {
641 | if(s0)
642 | {
643 | w2ui.toolbar.enable("undo");
644 | }
645 | else
646 | {
647 | w2ui.toolbar.disable("undo");
648 | }
649 | if(s1)
650 | {
651 | w2ui.toolbar.enable("redo");
652 | }
653 | else
654 | {
655 | w2ui.toolbar.disable("redo");
656 | }
657 | window.undo_redo_state.undo=s0;
658 | window.undo_redo_state.redo=s1;
659 | }
--------------------------------------------------------------------------------
/js/upload.js:
--------------------------------------------------------------------------------
1 | function(a,b){
2 | if(!window.my_observe_upload)
3 | {
4 | console.log("setup upload here");
5 | window.my_observe_upload = new MutationObserver(function (event) {
6 | console.log(event);
7 | var frame=document.querySelector("gradio-app").shadowRoot.querySelector("#sdinfframe").contentWindow.document;
8 | frame.querySelector("#upload").click();
9 | });
10 | window.my_observe_upload_target = document.querySelector("gradio-app").shadowRoot.querySelector("#upload span");
11 | window.my_observe_upload.observe(window.my_observe_upload_target, {
12 | attributes: false,
13 | subtree: true,
14 | childList: true,
15 | characterData: true
16 | });
17 | }
18 | return [a,b];
19 | }
--------------------------------------------------------------------------------
/js/xss.js:
--------------------------------------------------------------------------------
1 | var setup_outpaint=function(){
2 | if(!window.my_observe_outpaint)
3 | {
4 | console.log("setup outpaint here");
5 | window.my_observe_outpaint = new MutationObserver(function (event) {
6 | console.log(event);
7 | let app=document.querySelector("gradio-app");
8 | app=app.shadowRoot??app;
9 | let frame=app.querySelector("#sdinfframe").contentWindow;
10 | frame.postMessage(["outpaint", ""], "*");
11 | });
12 | var app=document.querySelector("gradio-app");
13 | app=app.shadowRoot??app;
14 | window.my_observe_outpaint_target=app.querySelector("#output span");
15 | window.my_observe_outpaint.observe(window.my_observe_outpaint_target, {
16 | attributes: false,
17 | subtree: true,
18 | childList: true,
19 | characterData: true
20 | });
21 | }
22 | };
23 | window.config_obj={
24 | resize_check: true,
25 | enable_safety: true,
26 | use_correction: false,
27 | enable_img2img: false,
28 | use_seed: false,
29 | seed_val: 0,
30 | interrogate_mode: false,
31 | };
32 | setup_outpaint();
--------------------------------------------------------------------------------
/models/v1-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71 |
--------------------------------------------------------------------------------
/models/v1-inpainting-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 7.5e-05
3 | target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: hybrid # important
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | finetune_keys: null
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 9 # 4 data + 4 downscaled image + 1 mask
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71 |
--------------------------------------------------------------------------------
/perlin2d.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | ##########
4 | # https://stackoverflow.com/questions/42147776/producing-2d-perlin-noise-with-numpy/42154921#42154921
5 | def perlin(x, y, seed=0):
6 | # permutation table
7 | np.random.seed(seed)
8 | p = np.arange(256, dtype=int)
9 | np.random.shuffle(p)
10 | p = np.stack([p, p]).flatten()
11 | # coordinates of the top-left
12 | xi, yi = x.astype(int), y.astype(int)
13 | # internal coordinates
14 | xf, yf = x - xi, y - yi
15 | # fade factors
16 | u, v = fade(xf), fade(yf)
17 | # noise components
18 | n00 = gradient(p[p[xi] + yi], xf, yf)
19 | n01 = gradient(p[p[xi] + yi + 1], xf, yf - 1)
20 | n11 = gradient(p[p[xi + 1] + yi + 1], xf - 1, yf - 1)
21 | n10 = gradient(p[p[xi + 1] + yi], xf - 1, yf)
22 | # combine noises
23 | x1 = lerp(n00, n10, u)
24 | x2 = lerp(n01, n11, u) # FIX1: I was using n10 instead of n01
25 | return lerp(x1, x2, v) # FIX2: I also had to reverse x1 and x2 here
26 |
27 |
28 | def lerp(a, b, x):
29 | "linear interpolation"
30 | return a + x * (b - a)
31 |
32 |
33 | def fade(t):
34 | "6t^5 - 15t^4 + 10t^3"
35 | return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3
36 |
37 |
38 | def gradient(h, x, y):
39 | "grad converts h to the right gradient vector and return the dot product with (x,y)"
40 | vectors = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
41 | g = vectors[h % 4]
42 | return g[:, :, 0] * x + g[:, :, 1] * y
43 |
44 |
45 | ##########
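46 |
47 | # Editor's usage sketch (not part of the original file): perlin() expects
48 | # 2-D coordinate grids, e.g. built with np.meshgrid, and returns a noise
49 | # array of the same shape with values roughly in [-1, 1].
50 | if __name__ == "__main__":
51 |     lin = np.linspace(0, 5, 256, endpoint=False)
52 |     x, y = np.meshgrid(lin, lin)
53 |     noise = perlin(x, y, seed=2)
54 |     print(noise.shape, float(noise.min()), float(noise.max()))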
--------------------------------------------------------------------------------
/postprocess.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/Trinkle23897/Fast-Poisson-Image-Editing
3 | MIT License
4 |
5 | Copyright (c) 2022 Jiayi Weng
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 | """
25 |
26 | import time
27 | import argparse
28 | import os
29 | import fpie
30 | from process import ALL_BACKEND, CPU_COUNT, DEFAULT_BACKEND
31 | from fpie.io import read_images, write_image
32 | from process import BaseProcessor, EquProcessor, GridProcessor
33 |
34 | from PIL import Image
35 | import numpy as np
36 | import skimage
37 | import skimage.measure
38 | import scipy
39 | import scipy.signal
40 |
41 |
42 | class PhotometricCorrection:
43 | def __init__(self,quite=False):
44 | self.get_parser("cli")
45 | args=self.parser.parse_args(["--method","grid","-g","src","-s","a","-t","a","-o","a"])
46 | args.mpi_sync_interval = getattr(args, "mpi_sync_interval", 0)
47 | self.backend=args.backend
48 | self.args=args
49 | self.quite=quite
50 | proc: BaseProcessor
51 | proc = GridProcessor(
52 | args.gradient,
53 | args.backend,
54 | args.cpu,
55 | args.mpi_sync_interval,
56 | args.block_size,
57 | args.grid_x,
58 | args.grid_y,
59 | )
60 | print(
61 | f"[PIE]Successfully initialize PIE {args.method} solver "
62 | f"with {args.backend} backend"
63 | )
64 | self.proc=proc
65 |
66 | def run(self, original_image, inpainted_image, mode="mask_mode"):
67 | print(f"[PIE] start")
68 | if mode=="disabled":
69 | return inpainted_image
70 | input_arr=np.array(original_image)
71 | if input_arr[:,:,-1].sum()<1:
72 | return inpainted_image
73 | output_arr=np.array(inpainted_image)
74 | mask=input_arr[:,:,-1]
75 | mask=255-mask
76 | if mask.sum()<1 and mode=="mask_mode":
77 | mode=""
78 | if mode=="mask_mode":
79 | mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
80 | mask = mask.repeat(8, axis=0).repeat(8, axis=1)
81 | else:
82 | mask[8:-9,8:-9]=255
83 | mask = mask[:,:,np.newaxis].repeat(3,axis=2)
84 | nmask=mask.copy()
85 | output_arr2=output_arr[:,:,0:3].copy()
86 | input_arr2=input_arr[:,:,0:3].copy()
87 | output_arr2[nmask<128]=0
88 | input_arr2[nmask>=128]=0
89 | output_arr2+=input_arr2
90 | src = output_arr2[:,:,0:3]
91 | tgt = src.copy()
92 | proc=self.proc
93 | args=self.args
94 | if proc.root:
95 | n = proc.reset(src, mask, tgt, (args.h0, args.w0), (args.h1, args.w1))
96 | proc.sync()
97 | if proc.root:
98 | result = tgt
99 | t = time.time()
100 | if args.p == 0:
101 | args.p = args.n
102 |
103 | for i in range(0, args.n, args.p):
104 | if proc.root:
105 | result, err = proc.step(args.p) # type: ignore
106 | print(f"[PIE] Iter {i + args.p}, abs_err {err}")
107 | else:
108 | proc.step(args.p)
109 |
110 | if proc.root:
111 | dt = time.time() - t
112 | print(f"[PIE] Time elapsed: {dt:.4f}s")
113 | # make sure consistent with dummy process
114 | return Image.fromarray(result)
115 |
116 |
117 | def get_parser(self, gen_type: str) -> None:
118 | parser = argparse.ArgumentParser()
119 | parser.add_argument(
120 | "-v", "--version", action="store_true", help="show the version and exit"
121 | )
122 | parser.add_argument(
123 | "--check-backend", action="store_true", help="print all available backends"
124 | )
125 | if gen_type == "gui" and "mpi" in ALL_BACKEND:
126 | # gui doesn't support MPI backend
127 | ALL_BACKEND.remove("mpi")
128 | parser.add_argument(
129 | "-b",
130 | "--backend",
131 | type=str,
132 | choices=ALL_BACKEND,
133 | default=DEFAULT_BACKEND,
134 | help="backend choice",
135 | )
136 | parser.add_argument(
137 | "-c",
138 | "--cpu",
139 | type=int,
140 | default=CPU_COUNT,
141 | help="number of CPU used",
142 | )
143 | parser.add_argument(
144 | "-z",
145 | "--block-size",
146 | type=int,
147 | default=1024,
148 | help="cuda block size (only for equ solver)",
149 | )
150 | parser.add_argument(
151 | "--method",
152 | type=str,
153 | choices=["equ", "grid"],
154 | default="equ",
155 | help="how to parallelize computation",
156 | )
157 | parser.add_argument("-s", "--source", type=str, help="source image filename")
158 | if gen_type == "cli":
159 | parser.add_argument(
160 | "-m",
161 | "--mask",
162 | type=str,
163 | help="mask image filename (default is to use the whole source image)",
164 | default="",
165 | )
166 | parser.add_argument("-t", "--target", type=str, help="target image filename")
167 | parser.add_argument("-o", "--output", type=str, help="output image filename")
168 | if gen_type == "cli":
169 | parser.add_argument(
170 | "-h0", type=int, help="mask position (height) on source image", default=0
171 | )
172 | parser.add_argument(
173 | "-w0", type=int, help="mask position (width) on source image", default=0
174 | )
175 | parser.add_argument(
176 | "-h1", type=int, help="mask position (height) on target image", default=0
177 | )
178 | parser.add_argument(
179 | "-w1", type=int, help="mask position (width) on target image", default=0
180 | )
181 | parser.add_argument(
182 | "-g",
183 | "--gradient",
184 | type=str,
185 | choices=["max", "src", "avg"],
186 | default="max",
187 | help="how to calculate gradient for PIE",
188 | )
189 | parser.add_argument(
190 | "-n",
191 | type=int,
192 | help="how many iteration would you perfer, the more the better",
193 | default=5000,
194 | )
195 | if gen_type == "cli":
196 | parser.add_argument(
197 | "-p", type=int, help="output result every P iteration", default=0
198 | )
199 | if "mpi" in ALL_BACKEND:
200 | parser.add_argument(
201 | "--mpi-sync-interval",
202 | type=int,
203 | help="MPI sync iteration interval",
204 | default=100,
205 | )
206 | parser.add_argument(
207 | "--grid-x", type=int, help="x axis stride for grid solver", default=8
208 | )
209 | parser.add_argument(
210 | "--grid-y", type=int, help="y axis stride for grid solver", default=8
211 | )
212 | self.parser=parser
213 |
214 | if __name__ =="__main__":
215 | import sys
216 | import io
217 | import base64
218 | from PIL import Image
219 | def base64_to_pil(base64_str):
220 | data = base64.b64decode(str(base64_str))
221 | pil = Image.open(io.BytesIO(data))
222 | return pil
223 |
224 | def pil_to_base64(out_pil):
225 | out_buffer = io.BytesIO()
226 | out_pil.save(out_buffer, format="PNG")
227 | out_buffer.seek(0)
228 | base64_bytes = base64.b64encode(out_buffer.read())
229 | base64_str = base64_bytes.decode("ascii")
230 | return base64_str
231 | correction_func=PhotometricCorrection(quite=True)
232 | while True:
233 | buffer = sys.stdin.readline()
234 | print(f"[PIE] suprocess {len(buffer)} {type(buffer)} ")
235 | if len(buffer)==0:
236 | break
237 | if isinstance(buffer,str):
238 | lst=buffer.strip().split(",")
239 | else:
240 | lst=buffer.decode("ascii").strip().split(",")
241 | img0=base64_to_pil(lst[0])
242 | img1=base64_to_pil(lst[1])
243 | ret=correction_func.run(img0,img1,mode=lst[2])
244 | ret_base64=pil_to_base64(ret)
245 | if isinstance(buffer,str):
246 | sys.stdout.write(f"{ret_base64}\n")
247 | else:
248 | sys.stdout.write(f"{ret_base64}\n".encode())
249 | sys.stdout.flush()
--------------------------------------------------------------------------------
/process.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/Trinkle23897/Fast-Poisson-Image-Editing
3 | MIT License
4 |
5 | Copyright (c) 2022 Jiayi Weng
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 | """
25 | import os
26 | from abc import ABC, abstractmethod
27 | from typing import Any, Optional, Tuple
28 |
29 | import numpy as np
30 |
31 | from fpie import np_solver
32 |
33 | import scipy
34 | import scipy.signal
35 |
36 | CPU_COUNT = os.cpu_count() or 1
37 | DEFAULT_BACKEND = "numpy"
38 | ALL_BACKEND = ["numpy"]
39 |
40 | try:
41 | from fpie import numba_solver
42 | ALL_BACKEND += ["numba"]
43 | DEFAULT_BACKEND = "numba"
44 | except ImportError:
45 | numba_solver = None # type: ignore
46 |
47 | try:
48 | from fpie import taichi_solver
49 | ALL_BACKEND += ["taichi-cpu", "taichi-gpu"]
50 | DEFAULT_BACKEND = "taichi-cpu"
51 | except ImportError:
52 | taichi_solver = None # type: ignore
53 |
54 | # try:
55 | # from fpie import core_gcc # type: ignore
56 | # DEFAULT_BACKEND = "gcc"
57 | # ALL_BACKEND.append("gcc")
58 | # except ImportError:
59 | # core_gcc = None
60 |
61 | # try:
62 | # from fpie import core_openmp # type: ignore
63 | # DEFAULT_BACKEND = "openmp"
64 | # ALL_BACKEND.append("openmp")
65 | # except ImportError:
66 | # core_openmp = None
67 |
68 | # try:
69 | # from mpi4py import MPI
70 |
71 | # from fpie import core_mpi # type: ignore
72 | # ALL_BACKEND.append("mpi")
73 | # except ImportError:
74 | # MPI = None # type: ignore
75 | # core_mpi = None
76 |
77 | try:
78 | from fpie import core_cuda # type: ignore
79 | DEFAULT_BACKEND = "cuda"
80 | ALL_BACKEND.append("cuda")
81 | except ImportError:
82 | core_cuda = None
83 |
84 |
85 | class BaseProcessor(ABC):
86 | """API definition for processor class."""
87 |
88 | def __init__(
89 | self, gradient: str, rank: int, backend: str, core: Optional[Any]
90 | ):
91 | if core is None:
92 | error_msg = {
93 | "numpy":
94 | "Please run `pip install numpy`.",
95 | "numba":
96 | "Please run `pip install numba`.",
97 | "gcc":
98 | "Please install cmake and gcc in your operating system.",
99 | "openmp":
100 | "Please make sure your gcc is compatible with `-fopenmp` option.",
101 | "mpi":
102 | "Please install MPI and run `pip install mpi4py`.",
103 | "cuda":
104 | "Please make sure nvcc and cuda-related libraries are available.",
105 | "taichi":
106 | "Please run `pip install taichi`.",
107 | }
108 | print(error_msg[backend.split("-")[0]])
109 |
110 | raise AssertionError(f"Invalid backend {backend}.")
111 |
112 | self.gradient = gradient
113 | self.rank = rank
114 | self.backend = backend
115 | self.core = core
116 | self.root = rank == 0
117 |
118 | def mixgrad(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
119 | if self.gradient == "src":
120 | return a
121 | if self.gradient == "avg":
122 | return (a + b) / 2
123 | # mix gradient, see Equ. 12 in PIE paper
124 | mask = np.abs(a) < np.abs(b)
125 | a[mask] = b[mask]
126 | return a
127 |
128 | @abstractmethod
129 | def reset(
130 | self,
131 | src: np.ndarray,
132 | mask: np.ndarray,
133 | tgt: np.ndarray,
134 | mask_on_src: Tuple[int, int],
135 | mask_on_tgt: Tuple[int, int],
136 | ) -> int:
137 | pass
138 |
139 | def sync(self) -> None:
140 | self.core.sync()
141 |
142 | @abstractmethod
143 | def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
144 | pass
145 |
146 |
147 | class EquProcessor(BaseProcessor):
148 | """PIE Jacobi equation processor."""
149 |
150 | def __init__(
151 | self,
152 | gradient: str = "max",
153 | backend: str = DEFAULT_BACKEND,
154 | n_cpu: int = CPU_COUNT,
155 | min_interval: int = 100,
156 | block_size: int = 1024,
157 | ):
158 | core: Optional[Any] = None
159 | rank = 0
160 |
161 | if backend == "numpy":
162 | core = np_solver.EquSolver()
163 | elif backend == "numba" and numba_solver is not None:
164 | core = numba_solver.EquSolver()
165 | elif backend == "gcc":
166 | core = core_gcc.EquSolver()
167 | elif backend == "openmp" and core_openmp is not None:
168 | core = core_openmp.EquSolver(n_cpu)
169 | elif backend == "mpi" and core_mpi is not None:
170 | core = core_mpi.EquSolver(min_interval)
171 | rank = MPI.COMM_WORLD.Get_rank()
172 | elif backend == "cuda" and core_cuda is not None:
173 | core = core_cuda.EquSolver(block_size)
174 | elif backend.startswith("taichi") and taichi_solver is not None:
175 | core = taichi_solver.EquSolver(backend, n_cpu, block_size)
176 |
177 | super().__init__(gradient, rank, backend, core)
178 |
179 | def mask2index(
180 | self, mask: np.ndarray
181 | ) -> Tuple[np.ndarray, int, np.ndarray, np.ndarray]:
182 | x, y = np.nonzero(mask)
183 | max_id = x.shape[0] + 1
184 | index = np.zeros((max_id, 3))
185 | ids = self.core.partition(mask)
186 | ids[mask == 0] = 0 # reserve id=0 for constant
187 | index = ids[x, y].argsort()
188 | return ids, max_id, x[index], y[index]
189 |
190 | def reset(
191 | self,
192 | src: np.ndarray,
193 | mask: np.ndarray,
194 | tgt: np.ndarray,
195 | mask_on_src: Tuple[int, int],
196 | mask_on_tgt: Tuple[int, int],
197 | ) -> int:
198 | assert self.root
199 | # check validity
200 | # assert 0 <= mask_on_src[0] and 0 <= mask_on_src[1]
201 | # assert mask_on_src[0] + mask.shape[0] <= src.shape[0]
202 | # assert mask_on_src[1] + mask.shape[1] <= src.shape[1]
203 | # assert mask_on_tgt[0] + mask.shape[0] <= tgt.shape[0]
204 | # assert mask_on_tgt[1] + mask.shape[1] <= tgt.shape[1]
205 |
206 | if len(mask.shape) == 3:
207 | mask = mask.mean(-1)
208 | mask = (mask >= 128).astype(np.int32)
209 |
210 | # zero-out edge
211 | mask[0] = 0
212 | mask[-1] = 0
213 | mask[:, 0] = 0
214 | mask[:, -1] = 0
215 |
216 | x, y = np.nonzero(mask)
217 | x0, x1 = x.min() - 1, x.max() + 2
218 | y0, y1 = y.min() - 1, y.max() + 2
219 | mask_on_src = (x0 + mask_on_src[0], y0 + mask_on_src[1])
220 | mask_on_tgt = (x0 + mask_on_tgt[0], y0 + mask_on_tgt[1])
221 | mask = mask[x0:x1, y0:y1]
222 | ids, max_id, index_x, index_y = self.mask2index(mask)
223 |
224 | src_x, src_y = index_x + mask_on_src[0], index_y + mask_on_src[1]
225 | tgt_x, tgt_y = index_x + mask_on_tgt[0], index_y + mask_on_tgt[1]
226 |
227 | src_C = src[src_x, src_y].astype(np.float32)
228 | src_U = src[src_x - 1, src_y].astype(np.float32)
229 | src_D = src[src_x + 1, src_y].astype(np.float32)
230 | src_L = src[src_x, src_y - 1].astype(np.float32)
231 | src_R = src[src_x, src_y + 1].astype(np.float32)
232 | tgt_C = tgt[tgt_x, tgt_y].astype(np.float32)
233 | tgt_U = tgt[tgt_x - 1, tgt_y].astype(np.float32)
234 | tgt_D = tgt[tgt_x + 1, tgt_y].astype(np.float32)
235 | tgt_L = tgt[tgt_x, tgt_y - 1].astype(np.float32)
236 | tgt_R = tgt[tgt_x, tgt_y + 1].astype(np.float32)
237 |
238 | grad = self.mixgrad(src_C - src_L, tgt_C - tgt_L) \
239 | + self.mixgrad(src_C - src_R, tgt_C - tgt_R) \
240 | + self.mixgrad(src_C - src_U, tgt_C - tgt_U) \
241 | + self.mixgrad(src_C - src_D, tgt_C - tgt_D)
242 |
243 | A = np.zeros((max_id, 4), np.int32)
244 | X = np.zeros((max_id, 3), np.float32)
245 | B = np.zeros((max_id, 3), np.float32)
246 |
247 | X[1:] = tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1]]
248 | # four-way
249 | A[1:, 0] = ids[index_x - 1, index_y]
250 | A[1:, 1] = ids[index_x + 1, index_y]
251 | A[1:, 2] = ids[index_x, index_y - 1]
252 | A[1:, 3] = ids[index_x, index_y + 1]
253 | B[1:] = grad
254 | m = (mask[index_x - 1, index_y] == 0).astype(float).reshape(-1, 1)
255 | B[1:] += m * tgt[index_x + mask_on_tgt[0] - 1, index_y + mask_on_tgt[1]]
256 | m = (mask[index_x, index_y - 1] == 0).astype(float).reshape(-1, 1)
257 | B[1:] += m * tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] - 1]
258 | m = (mask[index_x, index_y + 1] == 0).astype(float).reshape(-1, 1)
259 | B[1:] += m * tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] + 1]
260 | m = (mask[index_x + 1, index_y] == 0).astype(float).reshape(-1, 1)
261 | B[1:] += m * tgt[index_x + mask_on_tgt[0] + 1, index_y + mask_on_tgt[1]]
262 |
263 | self.tgt = tgt.copy()
264 | self.tgt_index = (index_x + mask_on_tgt[0], index_y + mask_on_tgt[1])
265 | self.core.reset(max_id, A, X, B)
266 | return max_id
267 |
268 | def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
269 | result = self.core.step(iteration)
270 | if self.root:
271 | x, err = result
272 | self.tgt[self.tgt_index] = x[1:]
273 | return self.tgt, err
274 | return None
275 |
276 |
277 | class GridProcessor(BaseProcessor):
278 | """PIE grid processor."""
279 |
280 | def __init__(
281 | self,
282 | gradient: str = "max",
283 | backend: str = DEFAULT_BACKEND,
284 | n_cpu: int = CPU_COUNT,
285 | min_interval: int = 100,
286 | block_size: int = 1024,
287 | grid_x: int = 8,
288 | grid_y: int = 8,
289 | ):
290 | core: Optional[Any] = None
291 | rank = 0
292 |
293 | if backend == "numpy":
294 | core = np_solver.GridSolver()
295 | elif backend == "numba" and numba_solver is not None:
296 | core = numba_solver.GridSolver()
297 | elif backend == "gcc":
298 | core = core_gcc.GridSolver(grid_x, grid_y)
299 | elif backend == "openmp" and core_openmp is not None:
300 | core = core_openmp.GridSolver(grid_x, grid_y, n_cpu)
301 | elif backend == "mpi" and core_mpi is not None:
302 | core = core_mpi.GridSolver(min_interval)
303 | rank = MPI.COMM_WORLD.Get_rank()
304 | elif backend == "cuda" and core_cuda is not None:
305 | core = core_cuda.GridSolver(grid_x, grid_y)
306 | elif backend.startswith("taichi") and taichi_solver is not None:
307 | core = taichi_solver.GridSolver(
308 | grid_x, grid_y, backend, n_cpu, block_size
309 | )
310 |
311 | super().__init__(gradient, rank, backend, core)
312 |
313 | def reset(
314 | self,
315 | src: np.ndarray,
316 | mask: np.ndarray,
317 | tgt: np.ndarray,
318 | mask_on_src: Tuple[int, int],
319 | mask_on_tgt: Tuple[int, int],
320 | ) -> int:
321 | assert self.root
322 | # check validity
323 | # assert 0 <= mask_on_src[0] and 0 <= mask_on_src[1]
324 | # assert mask_on_src[0] + mask.shape[0] <= src.shape[0]
325 | # assert mask_on_src[1] + mask.shape[1] <= src.shape[1]
326 | # assert mask_on_tgt[0] + mask.shape[0] <= tgt.shape[0]
327 | # assert mask_on_tgt[1] + mask.shape[1] <= tgt.shape[1]
328 |
329 | if len(mask.shape) == 3:
330 | mask = mask.mean(-1)
331 | mask = (mask >= 128).astype(np.int32)
332 |
333 | # zero-out edge
334 | mask[0] = 0
335 | mask[-1] = 0
336 | mask[:, 0] = 0
337 | mask[:, -1] = 0
338 |
339 | x, y = np.nonzero(mask)
340 | x0, x1 = x.min() - 1, x.max() + 2
341 | y0, y1 = y.min() - 1, y.max() + 2
342 | mask = mask[x0:x1, y0:y1]
343 | max_id = np.prod(mask.shape)
344 |
345 | src_crop = src[mask_on_src[0] + x0:mask_on_src[0] + x1,
346 | mask_on_src[1] + y0:mask_on_src[1] + y1].astype(np.float32)
347 | tgt_crop = tgt[mask_on_tgt[0] + x0:mask_on_tgt[0] + x1,
348 | mask_on_tgt[1] + y0:mask_on_tgt[1] + y1].astype(np.float32)
349 | grad = np.zeros([*mask.shape, 3], np.float32)
350 | grad[1:] += self.mixgrad(
351 | src_crop[1:] - src_crop[:-1], tgt_crop[1:] - tgt_crop[:-1]
352 | )
353 | grad[:-1] += self.mixgrad(
354 | src_crop[:-1] - src_crop[1:], tgt_crop[:-1] - tgt_crop[1:]
355 | )
356 | grad[:, 1:] += self.mixgrad(
357 | src_crop[:, 1:] - src_crop[:, :-1], tgt_crop[:, 1:] - tgt_crop[:, :-1]
358 | )
359 | grad[:, :-1] += self.mixgrad(
360 | src_crop[:, :-1] - src_crop[:, 1:], tgt_crop[:, :-1] - tgt_crop[:, 1:]
361 | )
362 |
363 | grad[mask == 0] = 0
364 | if True:
365 | kernel = [[1] * 3 for _ in range(3)]
366 | nmask = mask.copy()
367 | nmask[nmask > 0] = 1
368 | res = scipy.signal.convolve2d(
369 | nmask, kernel, mode="same", boundary="fill", fillvalue=1
370 | )
371 | res[nmask < 1] = 0
372 | res[res == 9] = 0
373 | res[res > 0] = 1
374 | grad[res>0]=0
375 | # ylst, xlst = res.nonzero()
376 | # for y, x in zip(ylst, xlst):
377 | # grad[y,x]=0
378 | # for yi in range(-1,2):
379 | # for xi in range(-1,2):
380 | # grad[y+yi,x+xi]=0
381 | self.x0 = mask_on_tgt[0] + x0
382 | self.x1 = mask_on_tgt[0] + x1
383 | self.y0 = mask_on_tgt[1] + y0
384 | self.y1 = mask_on_tgt[1] + y1
385 | self.tgt = tgt.copy()
386 | self.core.reset(max_id, mask, tgt_crop, grad)
387 | return max_id
388 |
389 | def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
390 | result = self.core.step(iteration)
391 | if self.root:
392 | tgt, err = result
393 | self.tgt[self.x0:self.x1, self.y0:self.y1] = tgt
394 | return self.tgt, err
395 | return None
396 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # stablediffusion-infinity
2 |
3 | Outpainting with Stable Diffusion on an infinite canvas.
4 |
5 | [](https://colab.research.google.com/github/lkwq007/stablediffusion-infinity/blob/master/stablediffusion_infinity_colab.ipynb)
6 | [](https://huggingface.co/spaces/lnyan/stablediffusion-infinity)
7 | [](https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/setup_guide.md)
8 |
9 | 
10 |
11 | https://user-images.githubusercontent.com/1665437/197244111-51884b3b-dffe-4dcf-a82a-fa5117c79934.mp4
12 |
13 | ## Status
14 |
15 | Powered by the Stable Diffusion inpainting model, this project now works well. However, the quality of results is still not guaranteed.
16 | You may need to do prompt engineering, change the size of the selection, or reduce the size of the outpainting region to get better outpainting results.
17 |
18 | The project is now a web app based on PyScript and Gradio. For the Jupyter Notebook version, please check out the [ipycanvas](https://github.com/lkwq007/stablediffusion-infinity/tree/ipycanvas) branch.
19 |
20 | Pull requests are welcome for better UI control, ideas to achieve better results, or any other improvements.
21 |
22 | Update: the project adds photometric correction to suppress seams. To use this feature, you need to install [fpie](https://github.com/Trinkle23897/Fast-Poisson-Image-Editing): `pip install fpie` (Linux/MacOS only)
23 |
24 | ## Docs
25 |
26 | ### Get Started
27 |
28 | - Setup for Windows: [setup_guide](./docs/setup_guide.md#windows)
29 | - Setup for Linux: [setup_guide](./docs/setup_guide.md#linux)
30 | - Setup for MacOS: [setup_guide](./docs/setup_guide.md#macos)
31 | - Running with Docker on Windows or Linux with NVIDIA GPU: [run_with_docker](./docs/run_with_docker.md)
32 | - Usages: [usage](./docs/usage.md)
33 |
34 | ### FAQs
35 |
36 | - The result is a black square:
37 | - False positive rate of safety checker is relatively high, you may disable the safety_checker
38 | - Some GPUs might not work with `fp16`: `python app.py --fp32 --lowvram`
39 | - What is `init_mode`?
40 |   - `init_mode` indicates how to fill the empty/masked region; `patch_match` usually works better than the other options (see the second sketch after this list)
41 | - Why not use `postMessage` for iframe interaction?
42 |   - The iframe and the Gradio app share the same origin. For a `postMessage`-based version, check out the [gradio-space](https://github.com/lkwq007/stablediffusion-infinity/tree/gradio-space) branch
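
A possible way to disable the safety checker at the `diffusers` level (a generic sketch, not a switch exposed by this project; `pipe` stands for an already-loaded pipeline, and the safety-checker signature may vary across `diffusers` versions):

```python
def null_safety_checker(images, **kwargs):
    # pass images through unchanged and report "not NSFW" for every image
    return images, [False] * len(images)

pipe.safety_checker = null_safety_checker
```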
43 |
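The `init_mode` options map to the fill functions in `functbl` inside `utils.py`. A minimal sketch of how such a filler is applied (shapes and values below are illustrative; the actual call site lives in `app.py` and may differ):

```python
import numpy as np
from utils import functbl

img = np.zeros((512, 512, 3), dtype=np.uint8)   # canvas crop to outpaint
mask = np.zeros((512, 512), dtype=np.uint8)
mask[128:384, 128:384] = 255                    # mask > 0 marks existing content, 0 the region to fill

fill = functbl["patchmatch"]                    # falls back to edge_pad if PyPatchMatch is not compiled
filled_img, mask = fill(img, mask)
```
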
44 | ### Known issues
45 |
46 | - The canvas is implemented with `NumPy` + `PyScript` (the project was originally implemented with `ipycanvas` inside a Jupyter notebook), which is relatively inefficient compared with pure frontend solutions.
47 | - By design, the canvas is infinite. However, the canvas size is **finite** in practice: your RAM and browser limit the canvas size, and the canvas might crash or behave strangely when zoomed out beyond a certain scale.
48 | - The canvas requires an internet connection: you can serve PyScript, Pyodide, and other JS/CSS assets from a local HTTP server and modify `index.html` accordingly.
49 | - Photometric correction might not work (`taichi` does not support multithreaded environments). A dirty hack (quite unreliable) moves the related computation into a subprocess; a rough sketch of the round trip is given after this list.
50 | - The Stable Diffusion inpainting model is much slower when the selection size is larger than 512x512.
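
For reference, the subprocess workaround talks to `postprocess.py` over stdin/stdout with base64-encoded PNGs, one request per line. A rough sketch of the round trip, simplified from `SubprocessCorrection.run` in `utils.py` (`img_input`, `img_inpainted`, and `correction_mode` are placeholders for values prepared by the caller):

```python
import base64, io
from subprocess import Popen, PIPE, STDOUT
from PIL import Image

def pil_to_base64(pil_img):
    buf = io.BytesIO()
    pil_img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("ascii")

child = Popen(["python", "postprocess.py"], stdin=PIPE, stdout=PIPE, stderr=STDOUT)
# request line: "<input_base64>,<inpainted_base64>,<mode>"
req = f"{pil_to_base64(img_input)},{pil_to_base64(img_inpainted)},{correction_mode}\n"
child.stdin.write(req.encode())
child.stdin.flush()

line = child.stdout.readline().decode().strip()
while line and line[0] == "[":      # lines starting with "[" are progress/log output
    line = child.stdout.readline().decode().strip()
corrected = Image.open(io.BytesIO(base64.b64decode(line)))   # corrected result as a PIL image
```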
51 |
52 | ## Credit
53 |
54 | The code of `perlin2d.py` is from https://stackoverflow.com/questions/42147776/producing-2d-perlin-noise-with-numpy/42154921#42154921 and is **not** included in the scope of LICENSE used in this repo.
55 |
56 | The submodule `glid_3_xl_stable` is based on https://github.com/Jack000/glid-3-xl-stable
57 |
58 | The submodule `PyPatchMatch` is based on https://github.com/vacancy/PyPatchMatch
59 |
60 | The code of `postprocess.py` and `process.py` is modified based on https://github.com/Trinkle23897/Fast-Poisson-Image-Editing
61 |
62 | The code of `convert_checkpoint.py` is modified based on https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py
63 |
64 | The submodule `sd_grpcserver` and `handleImageAdjustment()` in `utils.py` are based on https://github.com/hafriedlander/stable-diffusion-grpcserver and https://github.com/parlance-zz/g-diffuser-bot
65 |
66 | `w2ui.min.js` and `w2ui.min.css` are from https://github.com/vitmalina/w2ui. `fabric.min.js` is a custom build of https://github.com/fabricjs/fabric.js
67 |
68 | `interrogate.py` is based on https://github.com/pharmapsychotic/clip-interrogator v1, the submodule `blip_model` is based on https://github.com/salesforce/BLIP
69 |
--------------------------------------------------------------------------------
/stablediffusion_infinity_colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "collapsed_sections": []
8 | },
9 | "kernelspec": {
10 | "name": "python3",
11 | "display_name": "Python 3"
12 | },
13 | "language_info": {
14 | "name": "python"
15 | },
16 | "accelerator": "GPU"
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "source": [
22 | "# stablediffusion-infinity\n",
23 | "\n",
24 | "https://github.com/lkwq007/stablediffusion-infinity\n",
25 | "\n",
26 | "Outpainting with Stable Diffusion on an infinite canvas"
27 | ],
28 | "metadata": {
29 | "id": "IgN1jqV_DemW"
30 | }
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": null,
35 | "metadata": {
36 | "id": "JvbfNNSJDTW5"
37 | },
38 | "outputs": [],
39 | "source": [
40 | "#@title setup libs\n",
41 | "!nvidia-smi -L\n",
42 | "!pip install -qq -U diffusers==0.11.1 transformers ftfy accelerate\n",
43 | "!pip install -q gradio==3.11.0\n",
44 | "!pip install -q fpie timm\n",
45 | "!pip uninstall taichi -y"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "source": [
51 | "#@title setup stablediffusion-infinity\n",
52 | "!git clone --recurse-submodules https://github.com/lkwq007/stablediffusion-infinity\n",
53 | "%cd stablediffusion-infinity\n",
54 | "!cp -r PyPatchMatch/csrc .\n",
55 | "!cp PyPatchMatch/Makefile .\n",
56 | "!cp PyPatchMatch/Makefile_fallback .\n",
57 | "!cp PyPatchMatch/travis.sh .\n",
58 | "!cp PyPatchMatch/patch_match.py . "
59 | ],
60 | "metadata": {
61 | "id": "D1BDhQCJDilE"
62 | },
63 | "execution_count": null,
64 | "outputs": []
65 | },
66 | {
67 | "cell_type": "code",
68 | "source": [
69 | "#@title start stablediffusion-infinity (the first setup may take about two minutes to download models)\n",
70 | "!python app.py --share"
71 | ],
72 | "metadata": {
73 | "id": "UGotC5ckDlmO"
74 | },
75 | "execution_count": null,
76 | "outputs": []
77 | },
78 | {
79 | "cell_type": "code",
80 | "source": [],
81 | "metadata": {
82 | "id": "R1-E07CMFZoj"
83 | },
84 | "execution_count": null,
85 | "outputs": []
86 | }
87 | ]
88 | }
89 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from PIL import ImageFilter
3 | import cv2
4 | import numpy as np
5 | import scipy
6 | import scipy.signal
7 | from scipy.spatial import cKDTree
8 |
9 | import os
10 | from perlin2d import *
11 |
12 | patch_match_compiled = True
13 |
14 | try:
15 | from PyPatchMatch import patch_match
16 | except Exception as e:
17 | try:
18 | import patch_match
19 | except Exception as e:
20 | patch_match_compiled = False
21 |
22 | try:
23 | patch_match
24 | except NameError:
25 |     print("patch_match compilation failed, will fall back to edge_pad")
26 | patch_match_compiled = False
27 |
28 |
29 |
30 |
31 | def edge_pad(img, mask, mode=1):
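    """Fill the region where mask == 0 using pixels from the known (mask > 0) region.

    mode 0: copy each unknown pixel from its nearest known pixel (cKDTree lookup).
    mode 1 (default): breadth-first expansion from the mask boundary, averaging
        the values propagated from already-filled neighbors as the front grows.
    otherwise: crop the known bounding box and extend it with np.pad(mode="edge").
    """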
32 | if mode == 0:
33 | nmask = mask.copy()
34 | nmask[nmask > 0] = 1
35 | res0 = 1 - nmask
36 | res1 = nmask
37 | p0 = np.stack(res0.nonzero(), axis=0).transpose()
38 | p1 = np.stack(res1.nonzero(), axis=0).transpose()
39 | min_dists, min_dist_idx = cKDTree(p1).query(p0, 1)
40 | loc = p1[min_dist_idx]
41 | for (a, b), (c, d) in zip(p0, loc):
42 | img[a, b] = img[c, d]
43 | elif mode == 1:
44 | record = {}
45 | kernel = [[1] * 3 for _ in range(3)]
46 | nmask = mask.copy()
47 | nmask[nmask > 0] = 1
48 | res = scipy.signal.convolve2d(
49 | nmask, kernel, mode="same", boundary="fill", fillvalue=1
50 | )
51 | res[nmask < 1] = 0
52 | res[res == 9] = 0
53 | res[res > 0] = 1
54 | ylst, xlst = res.nonzero()
55 | queue = [(y, x) for y, x in zip(ylst, xlst)]
56 | # bfs here
57 | cnt = res.astype(np.float32)
58 | acc = img.astype(np.float32)
59 | step = 1
60 | h = acc.shape[0]
61 | w = acc.shape[1]
62 | offset = [(1, 0), (-1, 0), (0, 1), (0, -1)]
63 | while queue:
64 | target = []
65 | for y, x in queue:
66 | val = acc[y][x]
67 | for yo, xo in offset:
68 | yn = y + yo
69 | xn = x + xo
70 | if 0 <= yn < h and 0 <= xn < w and nmask[yn][xn] < 1:
71 | if record.get((yn, xn), step) == step:
72 | acc[yn][xn] = acc[yn][xn] * cnt[yn][xn] + val
73 | cnt[yn][xn] += 1
74 | acc[yn][xn] /= cnt[yn][xn]
75 | if (yn, xn) not in record:
76 | record[(yn, xn)] = step
77 | target.append((yn, xn))
78 | step += 1
79 | queue = target
80 | img = acc.astype(np.uint8)
81 | else:
82 | nmask = mask.copy()
83 | ylst, xlst = nmask.nonzero()
84 | yt, xt = ylst.min(), xlst.min()
85 | yb, xb = ylst.max(), xlst.max()
86 | content = img[yt : yb + 1, xt : xb + 1]
87 | img = np.pad(
88 | content,
89 | ((yt, mask.shape[0] - yb - 1), (xt, mask.shape[1] - xb - 1), (0, 0)),
90 | mode="edge",
91 | )
92 | return img, mask
93 |
94 |
95 | def perlin_noise(img, mask):
96 | lin_x = np.linspace(0, 5, mask.shape[1], endpoint=False)
97 | lin_y = np.linspace(0, 5, mask.shape[0], endpoint=False)
98 | x, y = np.meshgrid(lin_x, lin_y)
99 | avg = img.mean(axis=0).mean(axis=0)
100 | # noise=[((perlin(x, y)+1)*128+avg[i]).astype(np.uint8) for i in range(3)]
101 | noise = [((perlin(x, y) + 1) * 0.5 * 255).astype(np.uint8) for i in range(3)]
102 | noise = np.stack(noise, axis=-1)
103 | # mask=skimage.measure.block_reduce(mask,(8,8),np.min)
104 | # mask=mask.repeat(8, axis=0).repeat(8, axis=1)
105 | # mask_image=Image.fromarray(mask)
106 | # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 4))
107 | # mask=np.array(mask_image)
108 | nmask = mask.copy()
109 | # nmask=nmask/255.0
110 | nmask[mask > 0] = 1
111 | img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise
112 | # img=img.astype(np.uint8)
113 | return img, mask
114 |
115 |
116 | def gaussian_noise(img, mask):
117 | noise = np.random.randn(mask.shape[0], mask.shape[1], 3)
118 | noise = (noise + 1) / 2 * 255
119 | noise = noise.astype(np.uint8)
120 | nmask = mask.copy()
121 | nmask[mask > 0] = 1
122 | img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise
123 | return img, mask
124 |
125 |
126 | def cv2_telea(img, mask):
127 | ret = cv2.inpaint(img, 255 - mask, 5, cv2.INPAINT_TELEA)
128 | return ret, mask
129 |
130 |
131 | def cv2_ns(img, mask):
132 | ret = cv2.inpaint(img, 255 - mask, 5, cv2.INPAINT_NS)
133 | return ret, mask
134 |
135 |
136 | def patch_match_func(img, mask):
137 | ret = patch_match.inpaint(img, mask=255 - mask, patch_size=3)
138 | return ret, mask
139 |
140 |
141 | def mean_fill(img, mask):
142 | avg = img.mean(axis=0).mean(axis=0)
143 | img[mask < 1] = avg
144 | return img, mask
145 |
146 | """
147 | Apache-2.0 license
148 | https://github.com/hafriedlander/stable-diffusion-grpcserver/blob/main/sdgrpcserver/services/generate.py
149 | https://github.com/parlance-zz/g-diffuser-bot/tree/g-diffuser-bot-beta2
150 | _handleImageAdjustment
151 | """
152 | try:
153 | from sd_grpcserver.sdgrpcserver import images
154 | import torch
155 | from math import sqrt
156 | def handleImageAdjustment(array, adjustments):
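        # Each adjustment is a list such as ["blur", sigma, direction] or
        # ["levels", in_low, in_high, out_low, out_high]; see g_diffuser below for usage.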
157 | tensor = images.fromPIL(Image.fromarray(array))
158 | for adjustment in adjustments:
159 | which = adjustment[0]
160 |
161 | if which == "blur":
162 | sigma = adjustment[1]
163 | direction = adjustment[2]
164 |
165 | if direction == "DOWN" or direction == "UP":
166 | orig = tensor
167 | repeatCount=256
168 | sigma /= sqrt(repeatCount)
169 |
170 | for _ in range(repeatCount):
171 | tensor = images.gaussianblur(tensor, sigma)
172 | if direction == "DOWN":
173 | tensor = torch.minimum(tensor, orig)
174 | else:
175 | tensor = torch.maximum(tensor, orig)
176 | else:
177 |                     tensor = images.gaussianblur(tensor, sigma)  # list-style adjustment: reuse the sigma parsed above
178 | elif which == "invert":
179 | tensor = images.invert(tensor)
180 | elif which == "levels":
181 | tensor = images.levels(tensor, adjustment[1], adjustment[2], adjustment[3], adjustment[4])
182 | elif which == "channels":
183 |                 tensor = images.channelmap(tensor, [adjustment[1], adjustment[2], adjustment[3], adjustment[4]])  # assumes list layout ["channels", r, g, b, a]
184 | elif which == "rescale":
185 |                 raise NotImplementedError("Rescale adjustment is not implemented")
186 | elif which == "crop":
187 |                 tensor = images.crop(tensor, adjustment[1], adjustment[2], adjustment[3], adjustment[4])  # assumes list layout ["crop", top, left, height, width]
188 | return np.array(images.toPIL(tensor)[0])
189 |
190 | def g_diffuser(img,mask):
191 |         adjustments = [["blur", 32, "UP"], ["levels", 0, 0.05, 0, 1]]
192 |         mask = handleImageAdjustment(mask, adjustments)
193 |         out_mask = handleImageAdjustment(mask, adjustments)  # note: result currently unused
194 | return img, mask
195 | except:
196 | def g_diffuser(img,mask):
197 | return img,mask
198 |
199 | def dummy_fill(img,mask):
200 | return img,mask
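
# Mapping from init_mode names to fill functions: each entry takes (img, mask)
# and returns (filled_img, mask), where mask > 0 marks existing content and
# mask == 0 marks the region to be filled.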
201 | functbl = {
202 | "gaussian": gaussian_noise,
203 | "perlin": perlin_noise,
204 | "edge_pad": edge_pad,
205 | "patchmatch": patch_match_func if patch_match_compiled else edge_pad,
206 | "cv2_ns": cv2_ns,
207 | "cv2_telea": cv2_telea,
208 | "g_diffuser": g_diffuser,
209 | "g_diffuser_lib": dummy_fill,
210 | }
211 |
212 | try:
213 | from postprocess import PhotometricCorrection
214 | correction_func = PhotometricCorrection()
215 | except Exception as e:
216 | print(e, "so PhotometricCorrection is disabled")
217 | class DummyCorrection:
218 | def __init__(self):
219 | self.backend=""
220 | pass
221 | def run(self,a,b,**kwargs):
222 | return b
223 | correction_func=DummyCorrection()
224 |
225 | class DummyInterrogator:
226 | def __init__(self) -> None:
227 | pass
228 | def interrogate(self,pil):
229 | return "Interrogator init failed"
230 |
231 | if "taichi" in correction_func.backend:
232 | import sys
233 | import io
234 | import base64
235 | from PIL import Image
236 | def base64_to_pil(base64_str):
237 | data = base64.b64decode(str(base64_str))
238 | pil = Image.open(io.BytesIO(data))
239 | return pil
240 |
241 | def pil_to_base64(out_pil):
242 | out_buffer = io.BytesIO()
243 | out_pil.save(out_buffer, format="PNG")
244 | out_buffer.seek(0)
245 | base64_bytes = base64.b64encode(out_buffer.read())
246 | base64_str = base64_bytes.decode("ascii")
247 | return base64_str
248 | from subprocess import Popen, PIPE, STDOUT
249 | class SubprocessCorrection:
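        """Run photometric correction in a child process (works around taichi not
        supporting multithreaded environments).

        Protocol: write one request per line to postprocess.py's stdin as
        "<input_base64>,<inpainted_base64>,<mode>\n"; read back a base64-encoded
        PNG from stdout, skipping any progress lines that start with "[".
        """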
250 | def __init__(self):
251 | self.backend=correction_func.backend
252 | self.child= Popen(["python", "postprocess.py"], stdin=PIPE, stdout=PIPE, stderr=STDOUT)
253 | def run(self,img_input,img_inpainted,mode):
254 | if mode=="disabled":
255 | return img_inpainted
256 | base64_str_input = pil_to_base64(img_input)
257 | base64_str_inpainted = pil_to_base64(img_inpainted)
258 | try:
259 |                 if self.child.poll() is not None:
260 | self.child= Popen(["python", "postprocess.py"], stdin=PIPE, stdout=PIPE, stderr=STDOUT)
261 | self.child.stdin.write(f"{base64_str_input},{base64_str_inpainted},{mode}\n".encode())
262 | self.child.stdin.flush()
263 | out = self.child.stdout.readline()
264 | base64_str=out.decode().strip()
265 | while base64_str and base64_str[0]=="[":
266 | print(base64_str)
267 | out = self.child.stdout.readline()
268 | base64_str=out.decode().strip()
269 | ret=base64_to_pil(base64_str)
270 | except:
271 | print("[PIE] not working, photometric correction is disabled")
272 | ret=img_inpainted
273 | return ret
274 | correction_func = SubprocessCorrection()
275 |
--------------------------------------------------------------------------------