├── .gitattributes
├── .gitignore
├── LICENSE.txt
├── README.md
├── app.py
├── doc
│   └── StableNormal-Teaser.png
├── gradio_cached_examples
│   └── examples_image
│       └── log.csv
├── hubconf.py
├── requirements.txt
├── requirements_min.txt
├── setup.py
└── stablenormal
    ├── __init__.py
    ├── metrics
    │   ├── compute_metric.py
    │   └── compute_variance.py
    ├── pipeline_stablenormal.py
    ├── pipeline_yoso_normal.py
    ├── scheduler
    │   ├── __init__.py
    │   └── heuristics_ddimsampler.py
    └── stablecontrolnet.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.7z filter=lfs diff=lfs merge=lfs -text
2 | *.arrow filter=lfs diff=lfs merge=lfs -text
3 | *.bin filter=lfs diff=lfs merge=lfs -text
4 | *.bz2 filter=lfs diff=lfs merge=lfs -text
5 | *.ckpt filter=lfs diff=lfs merge=lfs -text
6 | *.ftz filter=lfs diff=lfs merge=lfs -text
7 | *.gz filter=lfs diff=lfs merge=lfs -text
8 | *.h5 filter=lfs diff=lfs merge=lfs -text
9 | *.joblib filter=lfs diff=lfs merge=lfs -text
10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text
11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text
12 | *.model filter=lfs diff=lfs merge=lfs -text
13 | *.msgpack filter=lfs diff=lfs merge=lfs -text
14 | *.npy filter=lfs diff=lfs merge=lfs -text
15 | *.npz filter=lfs diff=lfs merge=lfs -text
16 | *.onnx filter=lfs diff=lfs merge=lfs -text
17 | *.ot filter=lfs diff=lfs merge=lfs -text
18 | *.parquet filter=lfs diff=lfs merge=lfs -text
19 | *.pb filter=lfs diff=lfs merge=lfs -text
20 | *.pickle filter=lfs diff=lfs merge=lfs -text
21 | *.pkl filter=lfs diff=lfs merge=lfs -text
22 | *.pt filter=lfs diff=lfs merge=lfs -text
23 | *.pth filter=lfs diff=lfs merge=lfs -text
24 | *.rar filter=lfs diff=lfs merge=lfs -text
25 | *.safetensors filter=lfs diff=lfs merge=lfs -text
26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27 | *.tar.* filter=lfs diff=lfs merge=lfs -text
28 | *.tar filter=lfs diff=lfs merge=lfs -text
29 | *.tflite filter=lfs diff=lfs merge=lfs -text
30 | *.tgz filter=lfs diff=lfs merge=lfs -text
31 | *.wasm filter=lfs diff=lfs merge=lfs -text
32 | *.xz filter=lfs diff=lfs merge=lfs -text
33 | *.zip filter=lfs diff=lfs merge=lfs -text
34 | *.zst filter=lfs diff=lfs merge=lfs -text
35 | *tfevents* filter=lfs diff=lfs merge=lfs -text
36 | *.stl filter=lfs diff=lfs merge=lfs -text
37 | *.glb filter=lfs diff=lfs merge=lfs -text
38 | *.jpg filter=lfs diff=lfs merge=lfs -text
39 | *.jpeg filter=lfs diff=lfs merge=lfs -text
40 | *.png filter=lfs diff=lfs merge=lfs -text
41 | *.mp4 filter=lfs diff=lfs merge=lfs -text
42 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .DS_Store
3 | __pycache__
4 | weights
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # **StableNormal: Reducing Diffusion Variance for Stable and Sharp Normal**
2 |
3 | [Chongjie Ye*](https://github.com/hugoycj), [Lingteng Qiu*](https://lingtengqiu.github.io/), [Xiaodong Gu](https://github.com/gxd1994), [Qi Zuo](https://github.com/hitsz-zuoqi), [Yushuang Wu](https://scholar.google.com/citations?hl=zh-TW&user=x5gpN0sAAAAJ), [Zilong Dong](https://scholar.google.com/citations?user=GHOQKCwAAAAJ), [Liefeng Bo](https://research.cs.washington.edu/istc/lfb/), [Yuliang Xiu#](https://xiuyuliang.cn/), [Xiaoguang Han#](https://gaplab.cuhk.edu.cn/)
4 |
5 | \* Equal contribution
6 | \# Corresponding Author
7 |
8 |
9 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
265 | """
266 | )
267 |
268 | with gr.Tabs(elem_classes=["tabs"]):
269 | with gr.Tab("Image"):
270 | with gr.Row():
271 | with gr.Column():
272 | image_input = gr.Image(
273 | label="Input Image",
274 | type="filepath",
275 | )
276 | with gr.Row():
277 | image_submit_btn = gr.Button(
278 | value="Compute Normal", variant="primary"
279 | )
280 | image_reset_btn = gr.Button(value="Reset")
281 | with gr.Column():
282 | image_output_slider = ImageSlider(
283 | label="Normal outputs",
284 | type="filepath",
285 | show_download_button=True,
286 | show_share_button=True,
287 | interactive=False,
288 | elem_classes="slider",
289 | position=0.25,
290 | )
291 |
292 | Examples(
293 | fn=process_pipe_image,
294 | examples=sorted([
295 | os.path.join("files", "image", name)
296 | for name in os.listdir(os.path.join("files", "image"))
297 | ]),
298 | inputs=[image_input],
299 | outputs=[image_output_slider],
300 | cache_examples=True,
301 | directory_name="examples_image",
302 | )
303 |
304 | with gr.Tab("Video"):
305 | with gr.Row():
306 | with gr.Column():
307 | video_input = gr.Video(
308 | label="Input Video",
309 | sources=["upload", "webcam"],
310 | )
311 | with gr.Row():
312 | video_submit_btn = gr.Button(
313 | value="Compute Normal", variant="primary"
314 | )
315 | video_reset_btn = gr.Button(value="Reset")
316 | with gr.Column():
317 | processed_frames = ImageSlider(
318 | label="Realtime Visualization",
319 | type="filepath",
320 | show_download_button=True,
321 | show_share_button=True,
322 | interactive=False,
323 | elem_classes="slider",
324 | position=0.25,
325 | )
326 | video_output_files = gr.Files(
327 | label="Normal outputs",
328 | elem_id="download",
329 | interactive=False,
330 | )
331 | Examples(
332 | fn=process_pipe_video,
333 | examples=sorted([
334 | os.path.join("files", "video", name)
335 | for name in os.listdir(os.path.join("files", "video"))
336 | ]),
337 | inputs=[video_input],
338 | outputs=[processed_frames, video_output_files],
339 | directory_name="examples_video",
340 | cache_examples=False,
341 | )
342 |
343 | with gr.Tab("Panorama"):
344 | with gr.Column():
345 | gr.Markdown("Functionality coming soon on June 10th")
346 |
347 | with gr.Tab("4K Image"):
348 | with gr.Column():
349 | gr.Markdown("Functionality coming soon on June 17th")
350 |
351 | with gr.Tab("Normal Mapping"):
352 | with gr.Column():
353 | gr.Markdown("Functionality coming soon on June 24th")
354 |
355 | with gr.Tab("Normal SuperResolution"):
356 | with gr.Column():
357 | gr.Markdown("Functionality coming soon on June 30th")
358 |
359 | ### Image tab
360 | image_submit_btn.click(
361 | fn=process_image_check,
362 | inputs=image_input,
363 | outputs=None,
364 | preprocess=False,
365 | queue=False,
366 | ).success(
367 | fn=process_pipe_image,
368 | inputs=[
369 | image_input,
370 | ],
371 | outputs=[image_output_slider],
372 | concurrency_limit=1,
373 | )
374 |
375 | image_reset_btn.click(
376 | # return one value per output component (image_input, image_output_slider)
377 | fn=lambda: (
378 | None,
379 | None,
380 | ),
381 | inputs=[],
382 | outputs=[
383 | image_input,
384 | image_output_slider,
385 | ],
386 | queue=False,
387 | )
388 |
389 | ### Video tab
390 |
391 | video_submit_btn.click(
392 | fn=process_pipe_video,
393 | inputs=[video_input],
394 | outputs=[processed_frames, video_output_files],
395 | concurrency_limit=1,
396 | )
397 |
398 | video_reset_btn.click(
399 | fn=lambda: (None, None, None),
400 | inputs=[],
401 | outputs=[video_input, processed_frames, video_output_files],
402 | concurrency_limit=1,
403 | )
404 |
405 | ### Server launch
406 |
407 | demo.queue(
408 | api_open=False,
409 | ).launch(
410 | server_name="0.0.0.0",
411 | server_port=7860,
412 | )
413 |
414 |
415 | def main():
416 | os.system("pip freeze")
417 |
418 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
419 |
420 | x_start_pipeline = YOSONormalsPipeline.from_pretrained(
421 | 'Stable-X/yoso-normal-v0-2', trust_remote_code=True, variant="fp16", torch_dtype=torch.float16).to(device)
422 | pipe = StableNormalPipeline.from_pretrained('Stable-X/stable-normal-v0-1', trust_remote_code=True,
423 | variant="fp16", torch_dtype=torch.float16,
424 | scheduler=HEURI_DDIMScheduler(prediction_type='sample',
425 | beta_start=0.00085, beta_end=0.0120,
426 | beta_schedule="scaled_linear"))
427 | pipe.x_start_pipeline = x_start_pipeline
428 | pipe.to(device)
429 | pipe.prior.to(device, torch.float16)
430 |
431 | try:
432 | import xformers
433 | pipe.enable_xformers_memory_efficient_attention()
434 | except Exception:
435 | pass  # run without xformers
436 |
437 | run_demo_server(pipe)
438 |
439 |
440 | if __name__ == "__main__":
441 | main()
442 |
--------------------------------------------------------------------------------
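For orientation: `main()` above loads a one-step YOSO pipeline as the x_start predictor and a StableNormal refinement pipeline driven by the heuristic DDIM scheduler, then hands both to the Gradio demo. Below is a minimal sketch of driving the same two-stage setup directly, without the web UI. The input path `example.jpg` is a placeholder, and the inference call (`match_input_resolution`, the `prediction` field) is assumed to match the `Predictor` wrapper in hubconf.py further down.

```python
# Sketch: run the two-stage StableNormal pipelines outside Gradio, mirroring main().
# "example.jpg" is a placeholder input path.
import numpy as np
import torch
from PIL import Image

from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline
from stablenormal.pipeline_stablenormal import StableNormalPipeline
from stablenormal.scheduler.heuristics_ddimsampler import HEURI_DDIMScheduler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One-step initializer that predicts the starting normals ("x_start").
x_start_pipeline = YOSONormalsPipeline.from_pretrained(
    "Stable-X/yoso-normal-v0-2", trust_remote_code=True,
    variant="fp16", torch_dtype=torch.float16).to(device)

# Refinement pipeline driven by the heuristic DDIM scheduler.
pipe = StableNormalPipeline.from_pretrained(
    "Stable-X/stable-normal-v0-1", trust_remote_code=True,
    variant="fp16", torch_dtype=torch.float16,
    scheduler=HEURI_DDIMScheduler(prediction_type="sample",
                                  beta_start=0.00085, beta_end=0.0120,
                                  beta_schedule="scaled_linear"))
pipe.x_start_pipeline = x_start_pipeline
pipe.to(device)
pipe.prior.to(device, torch.float16)

image = Image.open("example.jpg").convert("RGB")
with torch.inference_mode():
    out = pipe(image, match_input_resolution=True)

# Map predictions from [-1, 1] to an 8-bit color image, as hubconf.py does.
normal = ((out.prediction.clip(-1, 1) + 1) / 2 * 255)[0].astype(np.uint8)
Image.fromarray(normal).save("example_normal.png")
```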
/doc/StableNormal-Teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stable-X/StableNormal/417b8f569ac58be2c71a8cce6fe549cc989a3e94/doc/StableNormal-Teaser.png
--------------------------------------------------------------------------------
/gradio_cached_examples/examples_image/log.csv:
--------------------------------------------------------------------------------
1 | Normal outputs,flag,username,timestamp
2 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/2d4f7f127cd7a9edc084/image.png"", ""url"": ""/file=/tmp/gradio/7be1a00df43e3503a62a56854aa4a6ba77a1ea44/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/472ed8041e180f69c4c5/001-pokemon_normal_colored.png"", ""url"": ""/file=/tmp/gradio/103753422ac5dee2bf5d5acb4b6bb61347940a4e/001-pokemon_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:05.968762
3 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/93cc4728f3ddd2e73674/image.png"", ""url"": ""/file=/tmp/gradio/a29e4d24479969a6716cdcc81399136e1198577f/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/e7852f7fcb3af59944df/002-pokemon_normal_colored.png"", ""url"": ""/file=/tmp/gradio/97bb02ae84152b298d5630074f7ffce5bcb468a8/002-pokemon_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:07.054492
4 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/2b13e352a2f64559fb45/image.png"", ""url"": ""/file=/tmp/gradio/52a84baab6b90942e4e52893b05d46ebb07ab5d6/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/8140f9d60ff69e5c758e/003-i16_normal_colored.png"", ""url"": ""/file=/tmp/gradio/70621cdafd90cf33161abbc4443bcf8848572200/003-i16_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:09.418150
5 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/eee50f57e70688e0b41e/image.png"", ""url"": ""/file=/tmp/gradio/97f7024ce31891524b2df35ae1c264de6f837d03/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/177821eb343c3a4bb763/004-i9_normal_colored.png"", ""url"": ""/file=/tmp/gradio/f4057759958bfabf93ec3af965127679f3e0d9c1/004-i9_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:11.299376
6 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/b1c6af08cda3c3c93fbd/image.png"", ""url"": ""/file=/tmp/gradio/e583eb461f205b2ac80b66ee08bc4840ad768bed/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/946a134e6042d6696d00/005-portrait_2_normal_colored.png"", ""url"": ""/file=/tmp/gradio/024c038ee0ca204fb73d490daa21732b9fa75d0f/005-portrait_2_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:12.289625
7 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/9c36ae7dba67a7fc17b0/image.png"", ""url"": ""/file=/tmp/gradio/e59529ce13ca4ebd62d403eb4536def99ea3c682/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/ec927aca4136969901c3/006-portrait_1_normal_colored.png"", ""url"": ""/file=/tmp/gradio/8241a84d99e983cd31f579594cfacb9b61dabafa/006-portrait_1_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:13.321568
8 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/a796f358306e9cc82ab1/image.png"", ""url"": ""/file=/tmp/gradio/74d88aba046b3e3e4794350b99649c6052a765db/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/302ed1aa45194d1cd79b/007-452de0d99803510d10f7c8a741c5cd35_normal_colored.png"", ""url"": ""/file=/tmp/gradio/1d39e2403f6891500a59cafb013d3fe8423f06f5/007-452de0d99803510d10f7c8a741c5cd35_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:14.417873
9 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/81fa0ed014675728f7dd/image.png"", ""url"": ""/file=/tmp/gradio/5472ceecadedaea6dd9cb860bbcca45c79e51747/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/0e7d005fd390fe0e7260/008-basketball_dog_normal_colored.png"", ""url"": ""/file=/tmp/gradio/01f21bf8932b5ffca5dac621ef59badbe373bfd7/008-basketball_dog_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:15.448725
10 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/55a588daa0227a76b6df/image.png"", ""url"": ""/file=/tmp/gradio/9e483db34f266478b2f80b77cae61f2b3c1a2efa/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/6c897c1bcc85c9741da9/009-GBmIySaXAAA1lkr_normal_colored.png"", ""url"": ""/file=/tmp/gradio/08f3057b4d4da125b3542b32052ac6e683a31575/009-GBmIySaXAAA1lkr_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:16.507773
11 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/35aea69d821208b6a791/image.png"", ""url"": ""/file=/tmp/gradio/229a96f57da136e033d2f1f3392102fa50d69884/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/1da94cbc991f7c1e6702/010-i14_normal_colored.png"", ""url"": ""/file=/tmp/gradio/161fa04c208ad1b16d8299647e883f6d1abab34d/010-i14_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:17.787002
12 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/381436a171532188df95/image.png"", ""url"": ""/file=/tmp/gradio/a98b48da707f9f87531487864c87d34c3064ec67/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/0b684f4d7f2bc405b6b4/011-book2_normal_colored.png"", ""url"": ""/file=/tmp/gradio/f83718f2f8916494abf4f2b6c626a659f0dd365a/011-book2_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:19.545135
13 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/613fbe7b9e8b9f8d1cf5/image.png"", ""url"": ""/file=/tmp/gradio/3f6a24ffb651ffe7f4865895ebb82a29fa42b862/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/cc3c976d01907060062d/012-books_normal_colored.png"", ""url"": ""/file=/tmp/gradio/0bb4938dcc4ef68bf1fee02a109a61c5ebea4858/012-books_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:21.329191
14 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/103c3a711b1fabdb66f7/image.png"", ""url"": ""/file=/tmp/gradio/64a0e56adbdce8b1b1733af87b12e158992b8329/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/cd99dc25736099290414/013-indoor_normal_colored.png"", ""url"": ""/file=/tmp/gradio/4162921c1b940758e23af8cfedf73fef29d4b90f/013-indoor_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:22.812387
15 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/f949f78ce9f0c2fce4dd/image.png"", ""url"": ""/file=/tmp/gradio/d4e79fdbac1a90d8cbe96c56adea57e4f1a55a8b/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/c1d9b5fa931deaf603d2/014-libary_normal_colored.png"", ""url"": ""/file=/tmp/gradio/ac82bd67f629978fff24d5ddb9fc5826a488407d/014-libary_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:24.159711
16 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/73c074dd42825f6d687f/image.png"", ""url"": ""/file=/tmp/gradio/a346ded75bb42809e985864aed506e08c2849c40/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/68c460788450d0e31431/015-libary2_normal_colored.png"", ""url"": ""/file=/tmp/gradio/50c9ae87f5ffb240e771314f579d3f4fcbbd3dba/015-libary2_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:25.899696
17 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/fe32222ddab3097f11ae/image.png"", ""url"": ""/file=/tmp/gradio/7335b00e631ed3d1cb39532607a2d9b24c1ce07c/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/47c27db82eb7e0250e83/016-living_room_normal_colored.png"", ""url"": ""/file=/tmp/gradio/214e82dfdb98aac3f634129af54f5ccb26735dbb/016-living_room_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:27.518310
18 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/9d3101f14a5044b29b52/image.png"", ""url"": ""/file=/tmp/gradio/974e4b18a1856ee5eceb8b9545e09df28d38cca5/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/2f35aee961855817702d/017-libary3_normal_colored.png"", ""url"": ""/file=/tmp/gradio/afc0f3e03e2786b1cbf075bdf5e83759a3633b97/017-libary3_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:28.918460
19 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/a56b6306e78ada5d33ec/image.png"", ""url"": ""/file=/tmp/gradio/72967c4b27304da6714b3984dd129b18718f37a0/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/ae0ded229f94b3f66401/018-643D6C85FD2F7353A812464DA0197FDEABB7B6E57F2AAA2E8CC2DD680B8E788B_normal_colored.png"", ""url"": ""/file=/tmp/gradio/3bd4bed062c80b4ba5a82e23a004dfa97ee8a723/018-643D6C85FD2F7353A812464DA0197FDEABB7B6E57F2AAA2E8CC2DD680B8E788B_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:30.663220
20 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/3e53f519db0dedd74592/image.png"", ""url"": ""/file=/tmp/gradio/0135dd399ffcfc16c9fa96c2f9f0760ecacb6a85/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/1100f812d5f2f1f8b5d6/019-1183760519562665485_normal_colored.png"", ""url"": ""/file=/tmp/gradio/022a973e1490cc64e02fa9f2efab97724fbc38dc/019-1183760519562665485_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:32.124338
21 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/342aa230760098db6c01/image.png"", ""url"": ""/file=/tmp/gradio/3b3f3fb7c2d6d1c4161eb4a1573fec7c025c6b95/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/dcf9e7244f25007aeb4f/020-A393A6C7E43B432C684E7EEA6FFB384BCA0479E19ED148066B5A51ECFB18BA43_normal_colored.png"", ""url"": ""/file=/tmp/gradio/3d0d183f6f4d437999319bf5ab292da129db2028/020-A393A6C7E43B432C684E7EEA6FFB384BCA0479E19ED148066B5A51ECFB18BA43_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:33.153153
22 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/c1b97b8c0ae18780ee33/image.png"", ""url"": ""/file=/tmp/gradio/a82c447fe6c653bc64d9af720e6c07218bd3000c/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/50b2a45487422c6b3c12/021-engine_normal_colored.png"", ""url"": ""/file=/tmp/gradio/3310e12c797f9cecdd99059cca52e4b76267600d/021-engine_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:35.016476
23 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/da3ecdaeb365cdd41531/image.png"", ""url"": ""/file=/tmp/gradio/7428129f20bee93a4c58e36b1ac8db7d0efac1d8/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/55e82f125ae6ba97d8bb/022-doughnuts_normal_colored.png"", ""url"": ""/file=/tmp/gradio/b8be712b8775af4fa3f6eb076d5c7aa9aba87fef/022-doughnuts_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:36.685128
24 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/6a9ecd2bd27a536b66db/image.png"", ""url"": ""/file=/tmp/gradio/c2acd919b2b0ae24e5457a78995aefa69df963b2/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/b2f3576350a86f63cc45/023-pumpkins_normal_colored.png"", ""url"": ""/file=/tmp/gradio/7693590345004ecfd411683814bf2e54bcd01276/023-pumpkins_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:38.482288
25 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/fabdf10a517be751ce40/image.png"", ""url"": ""/file=/tmp/gradio/072bc4024de7d4c01665d8e7c5ac345abd57c0c4/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/f362987b1effa2cacd59/024-try2_normal_colored.png"", ""url"": ""/file=/tmp/gradio/3e358def11f37d7e463ddeb51e9191b582b97393/024-try2_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:39.490890
26 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/6d2944994ae0d8f9a6ae/image.png"", ""url"": ""/file=/tmp/gradio/128f06cabe94c4b8043eb68500a04ed3f1804286/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/2190f91490fe86ef9685/025-try3_normal_colored.png"", ""url"": ""/file=/tmp/gradio/2a718542fa1e63857355b7c108a35c7d2bb60274/025-try3_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:40.531666
27 | "[{""path"": ""gradio_cached_examples/examples_image/Normal outputs/6c33edcfef8449c7985d/image.png"", ""url"": ""/file=/tmp/gradio/551e7223b83406494210c874bc52352dcd8f984b/image.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""gradio_cached_examples/examples_image/Normal outputs/0a1402aeb93330ce3021/026-try4_normal_colored.png"", ""url"": ""/file=/tmp/gradio/74d592ee5342dc2f421c572fba60098de93fb4bd/026-try4_normal_colored.png"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,,2024-06-12 18:35:41.528918
28 |
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional, Tuple
3 | import torch
4 | import numpy as np
5 | from torchvision import transforms
6 | from PIL import Image, ImageOps
7 | from torch.nn.functional import interpolate
8 |
9 |
10 | dependencies = ["torch", "numpy", "diffusers", "PIL"]
11 |
12 | from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline
13 | from stablenormal.pipeline_stablenormal import StableNormalPipeline
14 | from stablenormal.scheduler.heuristics_ddimsampler import HEURI_DDIMScheduler
15 |
16 | def pad_to_square(image: Image.Image) -> Tuple[Image.Image, Tuple[int, int], Tuple[int, int, int, int]]:
17 | """Pad the input image to make it square."""
18 | width, height = image.size
19 | size = max(width, height)
20 |
21 | delta_w = size - width
22 | delta_h = size - height
23 | padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
24 |
25 | padded_image = ImageOps.expand(image, padding)
26 | return padded_image, image.size, padding
27 |
28 | def resize_image(image: Image.Image, resolution: int) -> Tuple[Image.Image, Tuple[int, int], Tuple[float, float]]:
29 | """Resize the image, approximately preserving the aspect ratio, so that each side is a multiple of 64."""
30 | if not isinstance(image, Image.Image):
31 | raise ValueError("Expected a PIL Image object")
32 |
33 | np_image = np.array(image)
34 | height, width = np_image.shape[:2]
35 |
36 | scale = resolution / min(height, width)
37 | new_height = int(np.round(height * scale / 64.0)) * 64
38 | new_width = int(np.round(width * scale / 64.0)) * 64
39 |
40 | resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
41 | return resized_image, (height, width), (new_height / height, new_width / width)
42 |
43 | def center_crop(image: Image.Image) -> Tuple[Image.Image, Tuple[int, int], Tuple[float, float, float, float]]:
44 | """Crop the center of the image to make it square."""
45 | width, height = image.size
46 | crop_size = min(width, height)
47 |
48 | left = (width - crop_size) / 2
49 | top = (height - crop_size) / 2
50 | right = (width + crop_size) / 2
51 | bottom = (height + crop_size) / 2
52 |
53 | cropped_image = image.crop((left, top, right, bottom))
54 | return cropped_image, image.size, (left, top, right, bottom)
55 |
56 | class Predictor:
57 | """Wrapper that runs a normal-estimation diffusion pipeline on a single PIL image."""
58 |
59 | def __init__(self, model):
60 | self.model = model
61 | try:
62 | import xformers
63 | self.model.enable_xformers_memory_efficient_attention()
64 | except ImportError:
65 | pass
66 |
67 | def to(self, device, dtype=torch.float16):
68 | self.model.to(device, dtype)
69 | return self
70 |
71 | @torch.no_grad()
72 | def __call__(self, img: Image.Image, image_resolution=768, mode='stable', preprocess='pad') -> Image.Image:
73 | if img.mode == 'RGBA':
74 | img = img.convert('RGB')
75 |
76 | if preprocess == 'pad':
77 | img, original_size, padding_info = pad_to_square(img)
78 | elif preprocess == 'crop':
79 | img, original_size, crop_info = center_crop(img)
80 | else:
81 | raise ValueError("Invalid preprocessing mode. Choose 'pad' or 'crop'.")
82 |
83 | img, original_dims, scaling_factors = resize_image(img, image_resolution)
84 |
85 | if mode == 'stable':
86 | init_latents = torch.zeros([1, 4, image_resolution // 8, image_resolution // 8],
87 | device=self.model.device, dtype=self.model.dtype)  # follow the pipeline's device/dtype
88 | else:
89 | init_latents = None
90 |
91 | pipe_out = self.model(img, match_input_resolution=True, latents=init_latents)
92 | pred_normal = (pipe_out.prediction.clip(-1, 1) + 1) / 2
93 | pred_normal = (pred_normal[0] * 255).astype(np.uint8)
94 | pred_normal = Image.fromarray(pred_normal)
95 |
96 | new_dims = (int(original_dims[1]), int(original_dims[0]))  # original_dims is (height, width); PIL resize expects (width, height)
97 | pred_normal = pred_normal.resize(new_dims, Image.Resampling.LANCZOS)
98 |
99 | if preprocess == 'pad':
100 | left, top, right, bottom = padding_info[0], padding_info[1], original_dims[0] - padding_info[2], original_dims[1] - padding_info[3]
101 | pred_normal = pred_normal.crop((left, top, right, bottom))
102 | return pred_normal
103 | else:
104 | left, top, right, bottom = crop_info
105 | pred_normal_with_bg = Image.new("RGB", original_size)
106 | pred_normal_with_bg.paste(pred_normal, (int(left), int(top)))
107 | return pred_normal_with_bg
108 |
109 | def __repr__(self):
110 | return f"Predictor(model={self.model})"
111 |
112 | def StableNormal(local_cache_dir: Optional[str] = None, device="cuda:0",
113 | yoso_version='yoso-normal-v0-3', diffusion_version='stable-normal-v0-1') -> Predictor:
114 | """Load the StableNormal pipeline and return a Predictor instance."""
115 | yoso_weight_path = os.path.join(local_cache_dir if local_cache_dir else "Stable-X", yoso_version)
116 | diffusion_weight_path = os.path.join(local_cache_dir if local_cache_dir else "Stable-X", diffusion_version)
117 |
118 | x_start_pipeline = YOSONormalsPipeline.from_pretrained(
119 | yoso_weight_path, trust_remote_code=True, safety_checker=None,
120 | variant="fp16", torch_dtype=torch.float16).to(device)
121 |
122 | pipe = StableNormalPipeline.from_pretrained(diffusion_weight_path, trust_remote_code=True, safety_checker=None,
123 | variant="fp16", torch_dtype=torch.float16,
124 | scheduler=HEURI_DDIMScheduler(prediction_type='sample',
125 | beta_start=0.00085, beta_end=0.0120,
126 | beta_schedule="scaled_linear"))
127 |
128 | pipe.x_start_pipeline = x_start_pipeline
129 | pipe.to(device)
130 | pipe.prior.to(device, torch.float16)
131 |
132 | return Predictor(pipe)
133 |
134 | def StableNormal_turbo(local_cache_dir: Optional[str] = None, device="cuda:0", yoso_version='yoso-normal-v1-0') -> Predictor:
135 | """Load the StableNormal_turbo pipeline for faster inference."""
136 |
137 | yoso_weight_path = os.path.join(local_cache_dir if local_cache_dir else "Stable-X", yoso_version)
138 | pipe = YOSONormalsPipeline.from_pretrained(yoso_weight_path,
139 | trust_remote_code=True, safety_checker=None, variant="fp16",
140 | torch_dtype=torch.float16, t_start=0).to(device)
141 |
142 | return Predictor(pipe)
143 |
144 | def _test_run():
145 | import argparse
146 |
147 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
148 | parser.add_argument("--input", "-i", type=str, required=True, help="Input image file")
149 | parser.add_argument("--output", "-o", type=str, required=True, help="Output image file")
150 | parser.add_argument("--mode", type=str, default="StableNormal_turbo", help="Mode of operation")
151 |
152 | args = parser.parse_args()
153 |
154 | predictor_func = StableNormal_turbo if args.mode == "StableNormal_turbo" else StableNormal
155 | predictor = predictor_func(local_cache_dir='./weights', device="cuda:0")
156 |
157 | image = Image.open(args.input)
158 | with torch.inference_mode():
159 | normal_image = predictor(image)
160 | normal_image.save(args.output)
161 |
162 | if __name__ == "__main__":
163 | _test_run()
164 |
--------------------------------------------------------------------------------
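Because hubconf.py declares `StableNormal` and `StableNormal_turbo` as torch.hub entrypoints (with `dependencies = ["torch", "numpy", "diffusers", "PIL"]`), the predictor can also be pulled without cloning the repository. A minimal usage sketch, assuming the default GitHub source and a CUDA device; `input.jpg` is a placeholder path, and weights come from the Stable-X Hugging Face models unless `local_cache_dir` points at local copies.

```python
# Sketch: call the torch.hub entrypoints defined in hubconf.py above.
# "input.jpg" is a placeholder; extra kwargs are forwarded to the entrypoint.
import torch
from PIL import Image

# StableNormal_turbo is the faster single-pipeline variant; use "StableNormal"
# for the full two-stage model.
predictor = torch.hub.load("Stable-X/StableNormal", "StableNormal_turbo",
                           trust_repo=True, device="cuda:0")

image = Image.open("input.jpg")
with torch.inference_mode():
    normal_image = predictor(image)  # returns a PIL.Image of colored normals
normal_image.save("input_normal.png")
```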
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.30.1
2 | aiofiles==23.2.1
3 | aiohttp==3.9.5
4 | aiosignal==1.3.1
5 | altair==5.3.0
6 | annotated-types==0.7.0
7 | anyio==4.4.0
8 | async-timeout==4.0.3
9 | attrs==23.2.0
10 | Authlib==1.3.0
11 | certifi==2024.2.2
12 | cffi==1.16.0
13 | charset-normalizer==3.3.2
14 | click==8.0.4
15 | contourpy==1.2.1
16 | cryptography==42.0.7
17 | cycler==0.12.1
18 | dataclasses-json==0.6.6
19 | datasets==2.19.1
20 | Deprecated==1.2.14
21 | diffusers==0.28.0
22 | dill==0.3.8
23 | dnspython==2.6.1
24 | email_validator==2.1.1
25 | exceptiongroup==1.2.1
26 | fastapi==0.111.0
27 | fastapi-cli==0.0.4
28 | ffmpy==0.3.2
29 | filelock==3.14.0
30 | fonttools==4.53.0
31 | frozenlist==1.4.1
32 | fsspec==2024.3.1
33 | gradio==4.32.2
34 | gradio_client==0.17.0
35 | gradio_imageslider==0.0.20
36 | h11==0.14.0
37 | httpcore==1.0.5
38 | httptools==0.6.1
39 | httpx==0.27.0
40 | huggingface-hub==0.23.0
41 | idna==3.7
42 | imageio==2.34.1
43 | imageio-ffmpeg==0.5.0
44 | importlib_metadata==7.1.0
45 | importlib_resources==6.4.0
46 | itsdangerous==2.2.0
47 | Jinja2==3.1.4
48 | jsonschema==4.22.0
49 | jsonschema-specifications==2023.12.1
50 | kiwisolver==1.4.5
51 | markdown-it-py==3.0.0
52 | MarkupSafe==2.1.5
53 | marshmallow==3.21.2
54 | matplotlib==3.8.2
55 | mdurl==0.1.2
56 | mpmath==1.3.0
57 | multidict==6.0.5
58 | multiprocess==0.70.16
59 | mypy-extensions==1.0.0
60 | networkx==3.3
61 | numpy==1.26.4
62 | nvidia-cublas-cu12==12.1.3.1
63 | nvidia-cuda-cupti-cu12==12.1.105
64 | nvidia-cuda-nvrtc-cu12==12.1.105
65 | nvidia-cuda-runtime-cu12==12.1.105
66 | nvidia-cudnn-cu12==8.9.2.26
67 | nvidia-cufft-cu12==11.0.2.54
68 | nvidia-curand-cu12==10.3.2.106
69 | nvidia-cusolver-cu12==11.4.5.107
70 | nvidia-cusparse-cu12==12.1.0.106
71 | nvidia-nccl-cu12==2.19.3
72 | nvidia-nvjitlink-cu12==12.5.40
73 | nvidia-nvtx-cu12==12.1.105
74 | orjson==3.10.3
75 | packaging==24.0
76 | pandas==2.2.2
77 | pillow==10.3.0
78 | protobuf==3.20.3
79 | psutil==5.9.8
80 | pyarrow==16.0.0
81 | pyarrow-hotfix==0.6
82 | pycparser==2.22
83 | pydantic==2.7.2
84 | pydantic_core==2.18.3
85 | pydub==0.25.1
86 | pygltflib==1.16.1
87 | Pygments==2.18.0
88 | pyparsing==3.1.2
89 | python-dateutil==2.9.0.post0
90 | python-dotenv==1.0.1
91 | python-multipart==0.0.9
92 | pytz==2024.1
93 | PyYAML==6.0.1
94 | referencing==0.35.1
95 | regex==2024.5.15
96 | requests==2.31.0
97 | rich==13.7.1
98 | rpds-py==0.18.1
99 | ruff==0.4.7
100 | safetensors==0.4.3
101 | scipy==1.11.4
102 | semantic-version==2.10.0
103 | shellingham==1.5.4
104 | six==1.16.0
105 | sniffio==1.3.1
106 | spaces==0.28.3
107 | starlette==0.37.2
108 | sympy==1.12.1
109 | tokenizers==0.15.2
110 | tomlkit==0.12.0
111 | toolz==0.12.1
112 | torch==2.2.0
113 | tqdm==4.66.4
114 | transformers==4.36.1
115 | trimesh==4.0.5
116 | triton==2.2.0
117 | typer==0.12.3
118 | typing-inspect==0.9.0
119 | typing_extensions==4.11.0
120 | tzdata==2024.1
121 | ujson==5.10.0
122 | urllib3==2.2.1
123 | uvicorn==0.30.0
124 | uvloop==0.19.0
125 | watchfiles==0.22.0
126 | websockets==11.0.3
127 | wrapt==1.16.0
128 | xformers==0.0.24
129 | xxhash==3.4.1
130 | yarl==1.9.4
131 | zipp==3.19.1
132 | einops==0.7.0
--------------------------------------------------------------------------------
/requirements_min.txt:
--------------------------------------------------------------------------------
1 | gradio>=4.32.1
2 | gradio-imageslider>=0.0.20
3 | pygltflib==1.16.1
4 | trimesh==4.0.5
5 | imageio
6 | imageio-ffmpeg
7 | Pillow
8 | einops==0.7.0
9 |
10 | spaces
11 | accelerate
12 | diffusers>=0.28.0
13 | matplotlib==3.8.2
14 | scipy==1.11.4
15 | torch==2.0.1
16 | transformers==4.36.1
17 | xformers==0.0.21
18 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from setuptools import setup, find_packages
3 |
4 | setup_path = Path(__file__).parent
5 | README = (setup_path / "README.md").read_text(encoding="utf-8")
6 |
7 | # README was already read above via pathlib; reuse it for the long description.
8 | long_description = README
9 |
10 | def split_requirements(requirements):
11 | install_requires = []
12 | dependency_links = []
13 | for requirement in requirements:
14 | if requirement.startswith("git+"):
15 | dependency_links.append(requirement)
16 | else:
17 | install_requires.append(requirement)
18 |
19 | return install_requires, dependency_links
20 |
21 | with open("./requirements.txt", "r") as f:
22 | requirements = f.read().splitlines()
23 |
24 | install_requires, dependency_links = split_requirements(requirements)
25 |
26 | setup(
27 | name="stablenormal",
28 | packages=find_packages(),
29 | description="StableNormal: Reducing Diffusion Variance for Stable and Sharp Normal",
30 | long_description=long_description,
31 | install_requires=install_requires,
32 | )
33 |
--------------------------------------------------------------------------------
/stablenormal/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stable-X/StableNormal/417b8f569ac58be2c71a8cce6fe549cc989a3e94/stablenormal/__init__.py
--------------------------------------------------------------------------------
/stablenormal/metrics/compute_metric.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Organization : Alibaba XR-Lab, CUHK-SZ
3 | # @Author : Lingteng Qiu
4 | # @Email : 220019047@link.cuhk.edu.cn
5 | # @Time : 2024-01-23 11:21:30
6 | # @Function : An example to compute metrics of normal prediction.
7 |
8 |
9 | import argparse
10 | import csv
11 | import multiprocessing
12 | import os
13 | import time
14 | from collections import defaultdict
15 |
16 | import cv2
17 | import numpy as np
18 | import torch
19 |
20 |
21 | def dot(x, y):
22 | """dot product (along the last dim).
23 |
24 | Args:
25 | x (Union[Tensor, ndarray]): x, [..., C]
26 | y (Union[Tensor, ndarray]): y, [..., C]
27 |
28 | Returns:
29 | Union[Tensor, ndarray]: x dot y, [..., 1]
30 | """
31 | if isinstance(x, np.ndarray):
32 | return np.sum(x * y, -1, keepdims=True)
33 | else:
34 | return torch.sum(x * y, -1, keepdim=True)
35 |
36 |
37 | def is_format(f, format):
38 | """Check whether a file's extension is in a given set of formats.
39 |
40 | Args:
41 | f (str): file name.
42 | format (Sequence[str]): set of extensions (both '.jpg' or 'jpg' is ok).
43 |
44 | Returns:
45 | bool: if the file's extension is in the set.
46 | """
47 | ext = os.path.splitext(f)[1].lower() # include the dot
48 | return ext in format or ext[1:] in format
49 |
50 |
51 | def is_img(input_list):
52 | return list(filter(lambda x: is_format(x, [".jpg", ".jpeg", ".png"]), input_list))
53 |
54 |
55 | def length(x, eps=1e-20):
56 | """length of an array (along the last dim).
57 |
58 | Args:
59 | x (Union[Tensor, ndarray]): x, [..., C]
60 | eps (float, optional): eps. Defaults to 1e-20.
61 |
62 | Returns:
63 | Union[Tensor, ndarray]: length, [..., 1]
64 | """
65 | if isinstance(x, np.ndarray):
66 | return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))
67 | else:
68 | return torch.sqrt(torch.clamp(dot(x, x), min=eps))
69 |
70 |
71 | def safe_normalize(x, eps=1e-20):
72 | """normalize an array (along the last dim).
73 |
74 | Args:
75 | x (Union[Tensor, ndarray]): x, [..., C]
76 | eps (float, optional): eps. Defaults to 1e-20.
77 |
78 | Returns:
79 | Union[Tensor, ndarray]: normalized x, [..., C]
80 | """
81 |
82 | return x / length(x, eps)
83 |
84 |
85 | def strip(s):
86 | if s[-1] == "/":
87 | return s[:-1]
88 | else:
89 | return s
90 |
91 |
92 | def obtain_states(img_list):
93 | all_states = defaultdict(list)
94 | for img in img_list:
95 | states = os.path.basename(img)
96 | states = os.path.splitext(states)[0].split("_")[-1]
97 |
98 | all_states[states].append(img)
99 |
100 | for key in all_states.keys():
101 | all_states[key] = sorted(all_states[key])
102 |
103 | return all_states
104 |
105 |
106 | def writer_csv(filename, data):
107 | with open(filename, "w", newline="") as file:
108 | writer = csv.writer(file)
109 | writer.writerows(data)
110 |
111 |
112 | def worker(gt_result, cur_state_list):
113 |
114 | angles = []
115 | rmses = []
116 |
117 | normal_gt = cv2.imread(gt_result)
118 | normal_gt = normal_gt / 255 * 2 - 1
119 |
120 | normal_gt_norm = np.linalg.norm(normal_gt, axis=-1)
121 |
122 | for target in cur_state_list:
123 |
124 | normal_pred = cv2.imread(target)
125 | normal_pred = cv2.resize(normal_pred, (normal_gt.shape[1], normal_gt.shape[0]))
126 | normal_pred = normal_pred / 255 * 2 - 1
127 |
128 | normal_pred_norm = np.linalg.norm(normal_pred, axis=-1)
129 | normal_pred = safe_normalize(normal_pred)
130 |
131 | fg_mask_pred = (normal_pred_norm > 0.5) & (normal_pred_norm < 1.5)
132 | fg_mask_gt = (normal_gt_norm > 0.5) & (normal_gt_norm < 1.5)
133 |
134 | # fg_mask = fg_mask_gt & fg_mask_pred
135 | fg_mask = fg_mask_gt
136 |
137 | rmse = np.sqrt(((normal_pred - normal_gt) ** 2)[fg_mask].sum(axis=-1).mean())
138 | dot_product = (normal_pred * normal_gt).sum(axis=-1)
139 |
140 | dot_product = np.clip(dot_product, -1, 1)
141 | dot_product = dot_product[fg_mask]
142 |
143 | angle = np.arccos(dot_product) / np.pi * 180
144 |
145 | # Create an error map visualization
146 | error_map = np.zeros_like(normal_gt[:, :, 0])
147 | error_map[fg_mask] = angle
148 | error_map = np.clip(
149 | error_map, 0, 90
150 | ) # Clipping the values to [0, 90] for better visualization
151 | error_map = cv2.applyColorMap(np.uint8(error_map * 255 / 90), cv2.COLORMAP_JET)
152 |
153 | # Save the error map
154 | # cv2.imwrite(f"{root_dir}/{os.path.basename(source).replace('_gt.png', f'_{method}_error.png')}", error_map)
155 |
156 | angles.append(angle)
157 | rmses.append(rmse.item())
158 |
159 | print(f"processing {gt_result}")
160 |
161 | return gt_result, angles, rmses
162 |
163 |
164 | if __name__ == "__main__":
165 | parser = argparse.ArgumentParser(description="")
166 | parser.add_argument("--dataset_name", default="DIODE", type=str, choices=["DIODE"])
167 | parser.add_argument("--input", "-i", required=True, type=str)
168 |
169 | save_metric_path = "./eval_results/metrics"
170 |
171 | opt = parser.parse_args()
172 |
173 | save_path = strip(opt.input)
174 | model_name = save_path.split("/")[-2]
175 | sampling_name = os.path.basename(save_path)
176 |
177 | root_dir = f"{opt.input}"
178 | save_metric_path = os.path.join(save_metric_path, f"{model_name}_{sampling_name}")
179 |
180 | os.makedirs(save_metric_path, exist_ok=True)
181 |
182 | img_list = [os.path.join(root_dir, x) for x in os.listdir(root_dir)]
183 | img_list = is_img(img_list)
184 |
185 | data_states = obtain_states(img_list)
186 |
187 | gt_results = data_states.pop("gt")
188 | ref_results = data_states.pop("ref")
189 |
190 | num_cpus = multiprocessing.cpu_count()
191 |
192 | states = data_states.keys()
193 | states = sorted(states, key=lambda x: int(x.replace("step", "")))
194 |
195 | start = time.time()
196 |
197 | print(f"using cpu: {num_cpus}")
198 |
199 | pool = multiprocessing.Pool(processes=num_cpus)
200 | metrics_results = []
201 |
202 | for idx, gt_result in enumerate(gt_results):
203 | cur_state_list = [data_states[state][idx] for state in states]
204 | metrics_results.append(pool.apply_async(worker, (gt_result, cur_state_list)))
205 |
206 | pool.close()
207 | pool.join()
208 |
209 | times = time.time() - start
210 | print(f"All processes completed using time {times:.4f} s...")
211 |
212 | metrics_results = [metrics_result.get() for metrics_result in metrics_results]
213 |
214 | angles_csv = [["name", *states]]
215 | rmse_csv = [["name", *states]]
216 |
217 | angle_arr = []
218 | rmse_arr = []
219 |
220 | for metrics in metrics_results:
221 | name, angle, rmse = metrics
222 |
223 | angles_csv.append([name, *angle])
224 |
225 | angle_arr.append(angle)
226 |
227 | print(angles_csv[0])
228 |
229 | tokens = [[] for _ in range(len(angles_csv[0]))]
230 |
231 | for angles in angles_csv[1:]:
232 | for token_idx, angle in enumerate(angles):
233 | tokens[token_idx].append(angle)
234 |
235 | new_tokens = [[] for _ in range(len(angles_csv[0]))]
236 | for token_idx, token in enumerate(tokens):
237 |
238 | if token_idx == 0:
239 | new_tokens[token_idx] = np.asarray(token)
240 | else:
241 | new_tokens[token_idx] = np.concatenate(token)
242 |
243 | for i in range(1, len(new_tokens)):
244 | angle_arr = new_tokens[i]
245 |
246 |         pct_lt_11_25 = 100.0 * np.sum(angle_arr < 11.25, axis=0) / angle_arr.shape[0]
247 |         pct_lt_22_5 = 100.0 * np.sum(angle_arr < 22.5, axis=0) / angle_arr.shape[0]
248 |         pct_lt_30 = 100.0 * np.sum(angle_arr < 30, axis=0) / angle_arr.shape[0]
249 |         median = np.median(angle_arr)
250 |         mean = np.mean(angle_arr)
251 | 
252 |         print("*" * 10)
253 |         print(("{:.3f}\t" * 5).format(mean, median, pct_lt_11_25, pct_lt_22_5, pct_lt_30))  # mean, median, %<11.25, %<22.5, %<30 (degrees)
254 | print("*" * 10)
255 |
--------------------------------------------------------------------------------
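A minimal, self-contained sketch of the angular-error metrics printed by the script above (mean, median, and the percentage of pixels below the 11.25° / 22.5° / 30° thresholds), using synthetic normals in place of real predictions; all names and values here are illustrative only.

import numpy as np

rng = np.random.default_rng(0)
normal_gt = rng.uniform(-1, 1, size=(64, 64, 3))                   # toy ground-truth normals
normal_pred = normal_gt + rng.normal(scale=0.1, size=(64, 64, 3))  # toy noisy prediction

# normalize both to unit length (mirrors safe_normalize)
normal_gt /= np.clip(np.linalg.norm(normal_gt, axis=-1, keepdims=True), 1e-20, None)
normal_pred /= np.clip(np.linalg.norm(normal_pred, axis=-1, keepdims=True), 1e-20, None)

# per-pixel angular error in degrees
cos_sim = np.clip((normal_pred * normal_gt).sum(axis=-1), -1, 1)
angle = np.degrees(np.arccos(cos_sim)).reshape(-1)

print(f"mean: {angle.mean():.3f}  median: {np.median(angle):.3f}")
for thr in (11.25, 22.5, 30.0):
    print(f"% < {thr}: {100.0 * (angle < thr).mean():.3f}")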
/stablenormal/metrics/compute_variance.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Organization : Alibaba XR-Lab, CUHK-SZ
3 | # @Author : Lingteng Qiu
4 | # @Email : 220019047@link.cuhk.edu.cn
5 | # @Time : 2024-01-23 11:21:30
6 | # @Function    : An example script to compute the variance metric of normal predictions.
7 |
8 | import argparse
9 | import csv
10 | import glob
11 | import multiprocessing
12 | import os
13 | import time
14 | from collections import defaultdict
15 |
16 | import cv2
17 | import numpy as np
18 | import torch
19 |
20 |
21 | def dot(x, y):
22 | """dot product (along the last dim).
23 |
24 | Args:
25 | x (Union[Tensor, ndarray]): x, [..., C]
26 | y (Union[Tensor, ndarray]): y, [..., C]
27 |
28 | Returns:
29 | Union[Tensor, ndarray]: x dot y, [..., 1]
30 | """
31 | if isinstance(x, np.ndarray):
32 | return np.sum(x * y, -1, keepdims=True)
33 | else:
34 | return torch.sum(x * y, -1, keepdim=True)
35 |
36 |
37 | def is_format(f, format):
38 | """if a file's extension is in a set of format
39 |
40 | Args:
41 | f (str): file name.
42 | format (Sequence[str]): set of extensions (both '.jpg' or 'jpg' is ok).
43 |
44 | Returns:
45 | bool: if the file's extension is in the set.
46 | """
47 | ext = os.path.splitext(f)[1].lower() # include the dot
48 | return ext in format or ext[1:] in format
49 |
50 |
51 | def is_img(input_list):
52 | return list(filter(lambda x: is_format(x, [".jpg", ".jpeg", ".png"]), input_list))
53 |
54 |
55 | def length(x, eps=1e-20):
56 | """length of an array (along the last dim).
57 |
58 | Args:
59 | x (Union[Tensor, ndarray]): x, [..., C]
60 | eps (float, optional): eps. Defaults to 1e-20.
61 |
62 | Returns:
63 | Union[Tensor, ndarray]: length, [..., 1]
64 | """
65 | if isinstance(x, np.ndarray):
66 | return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))
67 | else:
68 | return torch.sqrt(torch.clamp(dot(x, x), min=eps))
69 |
70 |
71 | def safe_normalize(x, eps=1e-20):
72 | """normalize an array (along the last dim).
73 |
74 | Args:
75 | x (Union[Tensor, ndarray]): x, [..., C]
76 | eps (float, optional): eps. Defaults to 1e-20.
77 |
78 | Returns:
79 | Union[Tensor, ndarray]: normalized x, [..., C]
80 | """
81 |
82 | return x / length(x, eps)
83 |
84 |
85 | def strip(s):
86 | if s[-1] == "/":
87 | return s[:-1]
88 | else:
89 | return s
90 |
91 |
92 | def obtain_states(img_list):
93 | all_states = defaultdict(list)
94 | for img in img_list:
95 | states = os.path.basename(img)
96 | states = os.path.splitext(states)[0].split("_")[-1]
97 |
98 | all_states[states].append(img)
99 |
100 | for key in all_states.keys():
101 | all_states[key] = sorted(all_states[key])
102 |
103 | return all_states
104 |
105 |
106 | def writer_csv(filename, data):
107 | with open(filename, "w", newline="") as file:
108 | writer = csv.writer(file)
109 | writer.writerows(data)
110 |
111 |
112 | def worker(gt_result, ref_image, cur_state_list, high_frequency=False):
113 |
114 | angles = []
115 | rmses = []
116 |
117 | normal_gt = cv2.imread(gt_result)
118 | ref_image = cv2.imread(ref_image)
119 | normal_gt = normal_gt / 255 * 2 - 1
120 |
121 | # normal_gt = cv2.resize(normal_gt, (512, 512))
122 |
123 | normal_gt_norm = np.linalg.norm(normal_gt, axis=-1)
124 | fg_mask_gt = (normal_gt_norm > 0.5) & (normal_gt_norm < 1.5)
125 |
126 | if high_frequency:
127 |
128 | edges = cv2.Canny(ref_image, 0, 50)
129 | kernel = np.ones((3, 3), np.uint8)
130 | fg_mask_gt = cv2.dilate(edges, kernel, iterations=1) / 255
131 |         fg_mask_gt = edges / 255  # NOTE: overrides the dilated mask above; only the raw Canny edges are evaluated
132 | fg_mask_gt = fg_mask_gt == 1.0
133 |
134 | angles = []
135 | for target in cur_state_list:
136 |
137 | normal_pred = cv2.imread(target)
138 | normal_pred = cv2.resize(normal_pred, (normal_gt.shape[1], normal_gt.shape[0]))
139 | normal_pred = normal_pred / 255 * 2 - 1
140 |
141 | # normal_pred_norm = np.linalg.norm(normal_pred, axis=-1)
142 | normal_pred = safe_normalize(normal_pred)
143 |
144 | # fg_mask_pred = (normal_pred_norm > 0.5) & (normal_pred_norm < 1.5)
145 | # fg_mask = fg_mask_gt & fg_mask_pred
146 |
147 | fg_mask = fg_mask_gt
148 | dot_product = (normal_pred * normal_gt).sum(axis=-1)
149 | dot_product = np.clip(dot_product, -1, 1)
150 | dot_product = dot_product[fg_mask]
151 |
152 | angle = np.arccos(dot_product) / np.pi * 180
153 |
154 | angle = angle.mean().item()
155 |
156 | angles.append(angle)
157 |
158 | print(f"processing {gt_result}")
159 |
160 | return angles
161 |
162 |
163 | if __name__ == "__main__":
164 | parser = argparse.ArgumentParser(description="")
165 | parser.add_argument("--input", "-i", required=True, type=str)
166 | parser.add_argument("--model_name", "-m", type=str, default="geowizard")
167 | parser.add_argument("--hf", action="store_true", help="high frequency error map")
168 |
169 | opt = parser.parse_args()
170 |     save_metric_path = f"./eval_results/metrics_variance/{opt.model_name}"
171 |
172 | save_path = strip(opt.input)
173 | model_name = save_path.split("/")[-2]
174 | sampling_name = os.path.basename(save_path)
175 |
176 | root_dir = f"{opt.input}"
177 |
178 | seed_model_list = sorted(
179 | glob.glob(os.path.join(opt.input, f"{opt.model_name}_seed*"))
180 | )
181 | # seed_model_list = sorted(glob.glob(os.path.join(opt.input, f'seed*')))
182 | seed_model_list = [
183 | is_img(sorted(glob.glob(os.path.join(seed_model_path, "*.png"))))
184 | for seed_model_path in seed_model_list
185 | ]
186 |
187 | seed_states_list = []
188 |
189 |     expected_len = None  # renamed so it does not shadow the length() helper used inside worker()
190 |     for seed_idx, seed_model in enumerate(seed_model_list):
191 |         data_states = obtain_states(seed_model)
192 |         gt_results = data_states.pop("gt")
193 |         ref_results = data_states.pop("ref")
194 | 
195 |         keys = data_states.keys()
196 |         last_key = sorted(keys, key=lambda x: int(x.replace("step", "")))[-1]
197 | 
198 |         # skip seeds with an inconsistent number of predictions
199 |         if expected_len is None:
200 |             expected_len = len(data_states[last_key])
201 |         elif expected_len != len(data_states[last_key]):
202 |             print(f"skipping seed {seed_idx}: inconsistent number of predictions")
203 |             continue
204 | 
205 |
206 | seed_states_list.append(data_states[last_key])
207 |
208 | num_cpus = multiprocessing.cpu_count()
209 |
210 | states = data_states.keys()
211 |
212 | start = time.time()
213 |
214 | print(f"using cpu: {num_cpus}")
215 |
216 | pool = multiprocessing.Pool(processes=num_cpus)
217 | metrics_results = []
218 |
219 | for idx, gt_result in enumerate(gt_results):
220 | ref_result = ref_results[idx]
221 |
222 | cur_seed_states = [
223 | seed_states_list[_][idx] for _ in range(len(seed_states_list))
224 | ]
225 |
226 | metrics_results.append(
227 | pool.apply_async(worker, (gt_result, ref_result, cur_seed_states, opt.hf))
228 | )
229 |
230 | pool.close()
231 | pool.join()
232 |
233 | times = time.time() - start
234 | print(f"All processes completed using time {times:.4f} s...")
235 |
236 | metrics_results = [metrics_result.get() for metrics_result in metrics_results]
237 |
238 | metrics_results = np.asarray(metrics_results)
239 |
240 | print("*" * 10)
241 | print("variance: {}".format(metrics_results.var(axis=-1).mean()))
242 |
243 | print("*" * 10)
244 |
--------------------------------------------------------------------------------
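A minimal sketch of the variance metric reported above: one mean angular error per (image, seed) pair, then the variance across seeds averaged over images. The numbers below are synthetic placeholders, not measured results.

import numpy as np

rng = np.random.default_rng(0)
# shape [num_images, num_seeds]: mean angular error (degrees) of each image under each random seed
errors = 20.0 + rng.normal(scale=1.5, size=(100, 10))

# mirrors metrics_results.var(axis=-1).mean() in the script above
print(f"variance: {errors.var(axis=-1).mean():.4f}")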
/stablenormal/pipeline_stablenormal.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved.
2 | # Copyright 2024 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # --------------------------------------------------------------------------
16 | # More information and citation instructions are available on the Marigold project website: https://marigoldmonodepth.github.io
17 | # --------------------------------------------------------------------------
18 | from dataclasses import dataclass
19 | from typing import Any, Dict, List, Optional, Tuple, Union
20 |
21 | import numpy as np
22 | import torch
23 | from PIL import Image
24 | from tqdm.auto import tqdm
25 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
26 |
27 |
28 | from diffusers.image_processor import PipelineImageInput
29 | from diffusers.models import (
30 | AutoencoderKL,
31 | UNet2DConditionModel,
32 | ControlNetModel,
33 | )
34 | from diffusers.schedulers import (
35 | DDIMScheduler
36 | )
37 |
38 | from diffusers.utils import (
39 | BaseOutput,
40 | logging,
41 | replace_example_docstring,
42 | )
43 |
44 | from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
45 |
46 | from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
47 |
48 |
49 |
50 | from diffusers.utils.torch_utils import randn_tensor
51 | from diffusers.pipelines.controlnet import StableDiffusionControlNetPipeline
52 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
53 | from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor
54 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
55 | import torch.nn.functional as F
56 |
57 | import pdb
58 |
59 |
60 |
61 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
62 |
63 |
64 | EXAMPLE_DOC_STRING = """
65 | Examples:
66 | ```py
67 | >>> import diffusers
68 | >>> import torch
69 |
70 | >>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(
71 | ... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16
72 | ... ).to("cuda")
73 |
74 | >>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
75 | >>> normals = pipe(image)
76 |
77 | >>> vis = pipe.image_processor.visualize_normals(normals.prediction)
78 | >>> vis[0].save("einstein_normals.png")
79 | ```
80 | """
81 |
82 |
83 | @dataclass
84 | class StableNormalOutput(BaseOutput):
85 | """
86 |     Output class for the StableNormal monocular normals prediction pipeline.
87 | 
88 |     Args:
89 |         prediction (`np.ndarray`, `torch.Tensor`):
90 |             Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height
91 |             \times width$, regardless of whether the images were passed as a 4D array or a list.
92 |         latent (`None`, `torch.Tensor`):
93 |             Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
94 |             The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
95 |         gaus_noise (`None`, `torch.Tensor`):
96 |             The Gaussian noise used to initialize the prediction latent (see `prepare_latents`).
97 |     """
98 | 
99 |
100 | prediction: Union[np.ndarray, torch.Tensor]
101 | latent: Union[None, torch.Tensor]
102 | gaus_noise: Union[None, torch.Tensor]
103 |
104 | from einops import rearrange
105 | class DINOv2_Encoder(torch.nn.Module):
106 | IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
107 | IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
108 |
109 | def __init__(
110 | self,
111 | model_name = 'dinov2_vitl14',
112 | freeze = True,
113 | antialias=True,
114 | device="cuda",
115 | size = 448,
116 | ):
117 | super(DINOv2_Encoder, self).__init__()
118 |
119 | self.model = torch.hub.load('facebookresearch/dinov2', model_name)
120 | self.model.eval().to(device)
121 | self.device = device
122 | self.antialias = antialias
123 | self.dtype = torch.float32
124 |
125 | self.mean = torch.Tensor(self.IMAGENET_DEFAULT_MEAN)
126 | self.std = torch.Tensor(self.IMAGENET_DEFAULT_STD)
127 | self.size = size
128 | if freeze:
129 | self.freeze()
130 |
131 |
132 | def freeze(self):
133 | for param in self.model.parameters():
134 | param.requires_grad = False
135 |
136 | @torch.no_grad()
137 | def encoder(self, x):
138 |         '''
139 |         x: [b, c, h, w], values in (-1, 1), RGB
140 |         '''
141 |
142 | x = self.preprocess(x).to(self.device, self.dtype)
143 |
144 | b, c, h, w = x.shape
145 | patch_h, patch_w = h // 14, w // 14
146 |
147 | embeddings = self.model.forward_features(x)['x_norm_patchtokens']
148 | embeddings = rearrange(embeddings, 'b (h w) c -> b h w c', h = patch_h, w = patch_w)
149 |
150 | return rearrange(embeddings, 'b h w c -> b c h w')
151 |
152 | def preprocess(self, x):
153 |         '''Resize to the DINOv2 input size and remap from (-1, 1) to ImageNet statistics.
154 |         '''
155 |         # resize to the square input resolution expected by the backbone
156 | x = torch.nn.functional.interpolate(
157 | x,
158 | size=(self.size, self.size),
159 | mode='bicubic',
160 | align_corners=True,
161 | antialias=self.antialias,
162 | )
163 |
164 | x = (x + 1.0) / 2.0
165 | # renormalize according to dino
166 | mean = self.mean.view(1, 3, 1, 1).to(x.device)
167 | std = self.std.view(1, 3, 1, 1).to(x.device)
168 | x = (x - mean) / std
169 |
170 | return x
171 |
172 |     def to(self, device, dtype=None):
173 |         if dtype is not None:
174 |             self.dtype = dtype
175 |             self.model.to(device, dtype)
176 |             self.mean = self.mean.to(device, dtype)  # Tensor.to is not in-place, so reassign
177 |             self.std = self.std.to(device, dtype)
178 |         else:
179 |             self.model.to(device)
180 |             self.mean = self.mean.to(device)
181 |             self.std = self.std.to(device)
182 |         return self
183 |
184 | def __call__(self, x, **kwargs):
185 | return self.encoder(x, **kwargs)
186 |
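A small, self-contained sketch of the preprocessing done in DINOv2_Encoder.preprocess: pipeline images in (-1, 1) are resized to the square DINOv2 input, mapped back to [0, 1], and renormalized with ImageNet statistics; for a ViT-L/14 backbone at 448 px this yields a 448/14 = 32 x 32 grid of 1024-dim patch tokens. The tensor shapes below are illustrative.

import torch
import torch.nn.functional as F

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

x = torch.rand(1, 3, 512, 768) * 2 - 1          # pipeline images live in (-1, 1)
x = F.interpolate(x, size=(448, 448), mode="bicubic", align_corners=True, antialias=True)
x = (x + 1.0) / 2.0                             # back to [0, 1]
x = (x - IMAGENET_MEAN) / IMAGENET_STD          # DINOv2 expects ImageNet-normalized inputs

print(x.shape)  # torch.Size([1, 3, 448, 448]); the backbone then returns a [1, 1024, 32, 32] feature map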
187 | class StableNormalPipeline(StableDiffusionControlNetPipeline):
188 | """ Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io.
189 | Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
190 |
191 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
192 | implemented for all pipelines (downloading, saving, running on a particular device, etc.).
193 |
194 | The pipeline also inherits the following loading methods:
195 | - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
196 | - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
197 | - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
198 | - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
199 | - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
200 |
201 | Args:
202 | vae ([`AutoencoderKL`]):
203 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
204 | text_encoder ([`~transformers.CLIPTextModel`]):
205 | Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
206 | tokenizer ([`~transformers.CLIPTokenizer`]):
207 | A `CLIPTokenizer` to tokenize text.
208 | unet ([`UNet2DConditionModel`]):
209 | A `UNet2DConditionModel` to denoise the encoded image latents.
210 | controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
211 | Provides additional conditioning to the `unet` during the denoising process. If you set multiple
212 | ControlNets as a list, the outputs from each ControlNet are added together to create one combined
213 | additional conditioning.
214 | scheduler ([`SchedulerMixin`]):
215 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
216 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
217 | safety_checker ([`StableDiffusionSafetyChecker`]):
218 | Classification module that estimates whether generated images could be considered offensive or harmful.
219 | Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
220 | about a model's potential harms.
221 | feature_extractor ([`~transformers.CLIPImageProcessor`]):
222 | A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
223 | """
224 |
225 | model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
226 | _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
227 | _exclude_from_cpu_offload = ["safety_checker"]
228 | _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
229 |
230 |
231 |
232 | def __init__(
233 | self,
234 | vae: AutoencoderKL,
235 | text_encoder: CLIPTextModel,
236 | tokenizer: CLIPTokenizer,
237 | unet: UNet2DConditionModel,
238 | controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]],
239 | dino_controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]],
240 | scheduler: Union[DDIMScheduler],
241 | safety_checker: StableDiffusionSafetyChecker,
242 | feature_extractor: CLIPImageProcessor,
243 | image_encoder: CLIPVisionModelWithProjection = None,
244 | requires_safety_checker: bool = True,
245 | default_denoising_steps: Optional[int] = 10,
246 | default_processing_resolution: Optional[int] = 768,
247 | prompt="The normal map",
248 | empty_text_embedding=None,
249 | ):
250 | super().__init__(
251 | vae,
252 | text_encoder,
253 | tokenizer,
254 | unet,
255 | controlnet,
256 | scheduler,
257 | safety_checker,
258 | feature_extractor,
259 | image_encoder,
260 | requires_safety_checker,
261 | )
262 |
263 | self.register_modules(
264 | dino_controlnet=dino_controlnet,
265 | )
266 |
267 | self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
268 |         self.dino_image_processor = lambda x: x / 127.5 - 1.0
269 |
270 | self.default_denoising_steps = default_denoising_steps
271 | self.default_processing_resolution = default_processing_resolution
272 | self.prompt = prompt
273 | self.prompt_embeds = None
274 | self.empty_text_embedding = empty_text_embedding
275 | self.prior = DINOv2_Encoder(size=672)
276 |
277 | def check_inputs(
278 | self,
279 | image: PipelineImageInput,
280 | num_inference_steps: int,
281 | ensemble_size: int,
282 | processing_resolution: int,
283 | resample_method_input: str,
284 | resample_method_output: str,
285 | batch_size: int,
286 | ensembling_kwargs: Optional[Dict[str, Any]],
287 | latents: Optional[torch.Tensor],
288 | generator: Optional[Union[torch.Generator, List[torch.Generator]]],
289 | output_type: str,
290 | output_uncertainty: bool,
291 | ) -> int:
292 | if num_inference_steps is None:
293 | raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.")
294 | if num_inference_steps < 1:
295 | raise ValueError("`num_inference_steps` must be positive.")
296 | if ensemble_size < 1:
297 | raise ValueError("`ensemble_size` must be positive.")
298 | if ensemble_size == 2:
299 | logger.warning(
300 | "`ensemble_size` == 2 results are similar to no ensembling (1); "
301 | "consider increasing the value to at least 3."
302 | )
303 | if ensemble_size == 1 and output_uncertainty:
304 | raise ValueError(
305 | "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` "
306 | "greater than 1."
307 | )
308 | if processing_resolution is None:
309 | raise ValueError(
310 | "`processing_resolution` is not specified and could not be resolved from the model config."
311 | )
312 | if processing_resolution < 0:
313 | raise ValueError(
314 | "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for "
315 | "downsampled processing."
316 | )
317 | if processing_resolution % self.vae_scale_factor != 0:
318 | raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.")
319 | if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
320 | raise ValueError(
321 | "`resample_method_input` takes string values compatible with PIL library: "
322 | "nearest, nearest-exact, bilinear, bicubic, area."
323 | )
324 | if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
325 | raise ValueError(
326 | "`resample_method_output` takes string values compatible with PIL library: "
327 | "nearest, nearest-exact, bilinear, bicubic, area."
328 | )
329 | if batch_size < 1:
330 | raise ValueError("`batch_size` must be positive.")
331 | if output_type not in ["pt", "np"]:
332 | raise ValueError("`output_type` must be one of `pt` or `np`.")
333 | if latents is not None and generator is not None:
334 | raise ValueError("`latents` and `generator` cannot be used together.")
335 | if ensembling_kwargs is not None:
336 | if not isinstance(ensembling_kwargs, dict):
337 | raise ValueError("`ensembling_kwargs` must be a dictionary.")
338 | if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"):
339 | raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.")
340 |
341 | # image checks
342 | num_images = 0
343 | W, H = None, None
344 | if not isinstance(image, list):
345 | image = [image]
346 | for i, img in enumerate(image):
347 | if isinstance(img, np.ndarray) or torch.is_tensor(img):
348 | if img.ndim not in (2, 3, 4):
349 | raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.")
350 | H_i, W_i = img.shape[-2:]
351 | N_i = 1
352 | if img.ndim == 4:
353 | N_i = img.shape[0]
354 | elif isinstance(img, Image.Image):
355 | W_i, H_i = img.size
356 | N_i = 1
357 | else:
358 | raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.")
359 | if W is None:
360 | W, H = W_i, H_i
361 | elif (W, H) != (W_i, H_i):
362 | raise ValueError(
363 | f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}"
364 | )
365 | num_images += N_i
366 |
367 | # latents checks
368 | if latents is not None:
369 | if not torch.is_tensor(latents):
370 | raise ValueError("`latents` must be a torch.Tensor.")
371 | if latents.dim() != 4:
372 | raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.")
373 |
374 | if processing_resolution > 0:
375 | max_orig = max(H, W)
376 | new_H = H * processing_resolution // max_orig
377 | new_W = W * processing_resolution // max_orig
378 | if new_H == 0 or new_W == 0:
379 | raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]")
380 | W, H = new_W, new_H
381 | w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor
382 | h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor
383 | shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w)
384 |
385 | if latents.shape != shape_expected:
386 | raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.")
387 |
388 | # generator checks
389 | if generator is not None:
390 | if isinstance(generator, list):
391 | if len(generator) != num_images * ensemble_size:
392 | raise ValueError(
393 | "The number of generators must match the total number of ensemble members for all input images."
394 | )
395 | if not all(g.device.type == generator[0].device.type for g in generator):
396 | raise ValueError("`generator` device placement is not consistent in the list.")
397 | elif not isinstance(generator, torch.Generator):
398 | raise ValueError(f"Unsupported generator type: {type(generator)}.")
399 |
400 | return num_images
401 |
402 | def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
403 | if not hasattr(self, "_progress_bar_config"):
404 | self._progress_bar_config = {}
405 | elif not isinstance(self._progress_bar_config, dict):
406 | raise ValueError(
407 | f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
408 | )
409 |
410 | progress_bar_config = dict(**self._progress_bar_config)
411 | progress_bar_config["desc"] = progress_bar_config.get("desc", desc)
412 | progress_bar_config["leave"] = progress_bar_config.get("leave", leave)
413 | if iterable is not None:
414 | return tqdm(iterable, **progress_bar_config)
415 | elif total is not None:
416 | return tqdm(total=total, **progress_bar_config)
417 | else:
418 | raise ValueError("Either `total` or `iterable` has to be defined.")
419 |
420 | @torch.no_grad()
421 | @replace_example_docstring(EXAMPLE_DOC_STRING)
422 | def __call__(
423 | self,
424 | image: PipelineImageInput,
425 | prompt: Union[str, List[str]] = None,
426 | negative_prompt: Optional[Union[str, List[str]]] = None,
427 | num_inference_steps: Optional[int] = None,
428 | ensemble_size: int = 1,
429 | processing_resolution: Optional[int] = None,
430 | match_input_resolution: bool = True,
431 | resample_method_input: str = "bilinear",
432 | resample_method_output: str = "bilinear",
433 | batch_size: int = 1,
434 | ensembling_kwargs: Optional[Dict[str, Any]] = None,
435 | latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
436 | prompt_embeds: Optional[torch.Tensor] = None,
437 | negative_prompt_embeds: Optional[torch.Tensor] = None,
438 | num_images_per_prompt: Optional[int] = 1,
439 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
440 | controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
441 | output_type: str = "np",
442 | output_uncertainty: bool = False,
443 | output_latent: bool = False,
444 | return_dict: bool = True,
445 | ):
446 | """
447 | Function invoked when calling the pipeline.
448 |
449 | Args:
450 | image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`),
451 | `List[torch.Tensor]`: An input image or images used as an input for the normals estimation task. For
452 | arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible
453 | by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or
454 | three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the
455 | same width and height.
456 | num_inference_steps (`int`, *optional*, defaults to `None`):
457 | Number of denoising diffusion steps during inference. The default value `None` results in automatic
458 | selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
459 | for Marigold-LCM models.
460 | ensemble_size (`int`, defaults to `1`):
461 | Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for
462 | faster inference.
463 | processing_resolution (`int`, *optional*, defaults to `None`):
464 | Effective processing resolution. When set to `0`, matches the larger input image dimension. This
465 | produces crisper predictions, but may also lead to the overall loss of global context. The default
466 | value `None` resolves to the optimal value from the model config.
467 | match_input_resolution (`bool`, *optional*, defaults to `True`):
468 | When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
469 | side of the output will equal to `processing_resolution`.
470 | resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
471 | Resampling method used to resize input images to `processing_resolution`. The accepted values are:
472 | `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
473 | resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
474 | Resampling method used to resize output predictions to match the input resolution. The accepted values
475 | are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
476 | batch_size (`int`, *optional*, defaults to `1`):
477 | Batch size; only matters when setting `ensemble_size` or passing a tensor of images.
478 | ensembling_kwargs (`dict`, *optional*, defaults to `None`)
479 | Extra dictionary with arguments for precise ensembling control. The following options are available:
480 | - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in
481 | every pixel location, can be either `"closest"` or `"mean"`.
482 | latents (`torch.Tensor`, *optional*, defaults to `None`):
483 | Latent noise tensors to replace the random initialization. These can be taken from the previous
484 | function call's output.
485 | generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`):
486 | Random number generator object to ensure reproducibility.
487 | output_type (`str`, *optional*, defaults to `"np"`):
488 | Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted
489 | values are: `"np"` (numpy array) or `"pt"` (torch tensor).
490 | output_uncertainty (`bool`, *optional*, defaults to `False`):
491 | When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that
492 | the `ensemble_size` argument is set to a value above 2.
493 | output_latent (`bool`, *optional*, defaults to `False`):
494 | When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
495 | within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
496 | `latents` argument.
497 | return_dict (`bool`, *optional*, defaults to `True`):
498 | Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple.
499 |
500 | Examples:
501 |
502 | Returns:
503 | [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`:
504 | If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a
505 | `tuple` is returned where the first element is the prediction, the second element is the uncertainty
506 | (or `None`), and the third is the latent (or `None`).
507 | """
508 |
509 | # 0. Resolving variables.
510 | device = self._execution_device
511 | dtype = self.dtype
512 |
513 | # Model-specific optimal default values leading to fast and reasonable results.
514 | if num_inference_steps is None:
515 | num_inference_steps = self.default_denoising_steps
516 | if processing_resolution is None:
517 | processing_resolution = self.default_processing_resolution
518 |
519 |
520 | image, padding, original_resolution = self.image_processor.preprocess(
521 | image, processing_resolution, resample_method_input, device, dtype
522 | ) # [N,3,PPH,PPW]
523 |
524 | image_latent, gaus_noise = self.prepare_latents(
525 | image, latents, generator, ensemble_size, batch_size
526 | ) # [N,4,h,w], [N,4,h,w]
527 |
528 | # 0. X_start latent obtain
529 | predictor = self.x_start_pipeline(image, latents=gaus_noise,
530 | processing_resolution=processing_resolution, skip_preprocess=True)
531 | x_start_latent = predictor.latent
532 |
533 | # 1. Check inputs.
534 | num_images = self.check_inputs(
535 | image,
536 | num_inference_steps,
537 | ensemble_size,
538 | processing_resolution,
539 | resample_method_input,
540 | resample_method_output,
541 | batch_size,
542 | ensembling_kwargs,
543 | latents,
544 | generator,
545 | output_type,
546 | output_uncertainty,
547 | )
548 |
549 |
550 | # 2. Prepare empty text conditioning.
551 | # Model invocation: self.tokenizer, self.text_encoder.
552 | if self.empty_text_embedding is None:
553 | prompt = ""
554 | text_inputs = self.tokenizer(
555 | prompt,
556 | padding="do_not_pad",
557 | max_length=self.tokenizer.model_max_length,
558 | truncation=True,
559 | return_tensors="pt",
560 | )
561 | text_input_ids = text_inputs.input_ids.to(device)
562 | self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024]
563 |
564 |
565 |
566 | # 3. prepare prompt
567 | if self.prompt_embeds is None:
568 | prompt_embeds, negative_prompt_embeds = self.encode_prompt(
569 | self.prompt,
570 | device,
571 | num_images_per_prompt,
572 | False,
573 | negative_prompt,
574 | prompt_embeds=prompt_embeds,
575 | negative_prompt_embeds=None,
576 | lora_scale=None,
577 | clip_skip=None,
578 | )
579 | self.prompt_embeds = prompt_embeds
580 | self.negative_prompt_embeds = negative_prompt_embeds
581 |
582 |
583 |
584 | # 5. dino guider features obtaining
585 | ## TODO different case-1
586 | dino_features = self.prior(image)
587 | dino_features = self.dino_controlnet.dino_controlnet_cond_embedding(dino_features)
588 | dino_features = self.match_noisy(dino_features, x_start_latent)
589 |
590 | del (
591 | image,
592 | )
593 |
594 |         # 7. Denoise sampling, using the heuristic sampling proposed by Ye.
595 |
596 | t_start = self.x_start_pipeline.t_start
597 |         self.scheduler.set_timesteps(num_inference_steps, t_start=t_start, device=device)
598 | 
599 |         cond_scale = controlnet_conditioning_scale
600 | pred_latent = x_start_latent
601 |
602 | cur_step = 0
603 |
604 | # dino controlnet
605 | dino_down_block_res_samples, dino_mid_block_res_sample = self.dino_controlnet(
606 | dino_features.detach(),
607 |             0,  # the DINO ControlNet conditioning does not depend on the timestep
608 | encoder_hidden_states=self.prompt_embeds,
609 | conditioning_scale=cond_scale,
610 | guess_mode=False,
611 | return_dict=False,
612 | )
613 |         assert dino_mid_block_res_sample is None
614 |
615 | pred_latents = []
616 |
617 | last_pred_latent = pred_latent
618 |         for (t, prev_t) in self.progress_bar(zip(self.scheduler.timesteps, self.scheduler.prev_timesteps), leave=False, desc="Diffusion steps..."):
619 |
620 |             _dino_down_block_res_samples = list(dino_down_block_res_samples)  # shallow copy, so the pops inside dino_unet_forward do not consume the originals
621 |
622 | # controlnet
623 | down_block_res_samples, mid_block_res_sample = self.controlnet(
624 | image_latent.detach(),
625 | t,
626 | encoder_hidden_states=self.prompt_embeds,
627 | conditioning_scale=cond_scale,
628 | guess_mode=False,
629 | return_dict=False,
630 | )
631 |
632 | # SG-DRN
633 | noise = self.dino_unet_forward(
634 | self.unet,
635 | pred_latent,
636 | t,
637 | encoder_hidden_states=self.prompt_embeds,
638 | down_block_additional_residuals=down_block_res_samples,
639 | mid_block_additional_residual=mid_block_res_sample,
640 | dino_down_block_additional_residuals= _dino_down_block_res_samples,
641 | return_dict=False,
642 | )[0] # [B,4,h,w]
643 |
644 | pred_latents.append(noise)
645 | # ddim steps
646 | out = self.scheduler.step(
647 |                 noise, t, prev_t, pred_latent, gaus_noise=gaus_noise, generator=generator, cur_step=cur_step + 1  # NOTE: cur_step refers to the next step
648 | )# [B,4,h,w]
649 | pred_latent = out.prev_sample
650 |
651 | cur_step += 1
652 |
653 | del (
654 | image_latent,
655 | dino_features,
656 | )
657 | pred_latent = pred_latents[-1] # using x0
658 |
659 | # decoder
660 | prediction = self.decode_prediction(pred_latent)
661 | prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW]
662 | 
663 |         # resize back to the input resolution only when requested, then renormalize to unit length
664 |         if match_input_resolution:
665 |             prediction = self.image_processor.resize_antialias(
666 |                 prediction, original_resolution, resample_method_output, is_aa=False
667 |             ) # [N,3,H,W]
668 | 
669 |         prediction = self.normalize_normals(prediction) # [N,3,H,W]
674 |
675 | if output_type == "np":
676 | prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3]
677 | prediction = prediction.clip(min=-1, max=1)
678 |
679 | # 11. Offload all models
680 | self.maybe_free_model_hooks()
681 |
682 | return StableNormalOutput(
683 | prediction=prediction,
684 | latent=pred_latent,
685 | gaus_noise=gaus_noise
686 | )
687 |
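A hypothetical end-to-end usage sketch. The hub repository name, entry point, and file paths below are assumptions for illustration (the published weights are typically consumed through torch.hub rather than by instantiating StableNormalPipeline and its paired x_start pipeline by hand); adjust them to the actual release.

import torch
from PIL import Image

# assumed hub entry point; not defined in this file
predictor = torch.hub.load("Stable-X/StableNormal", "StableNormal", trust_repo=True)

image = Image.open("input.jpg").convert("RGB")   # placeholder path
normal_image = predictor(image)                  # normals mapped to an RGB image
normal_image.save("normal_map.png")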
688 | # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents
689 | def prepare_latents(
690 | self,
691 | image: torch.Tensor,
692 | latents: Optional[torch.Tensor],
693 | generator: Optional[torch.Generator],
694 | ensemble_size: int,
695 | batch_size: int,
696 | ) -> Tuple[torch.Tensor, torch.Tensor]:
697 | def retrieve_latents(encoder_output):
698 | if hasattr(encoder_output, "latent_dist"):
699 | return encoder_output.latent_dist.mode()
700 | elif hasattr(encoder_output, "latents"):
701 | return encoder_output.latents
702 | else:
703 | raise AttributeError("Could not access latents of provided encoder_output")
704 |
705 |
706 |
707 | image_latent = torch.cat(
708 | [
709 | retrieve_latents(self.vae.encode(image[i : i + batch_size]))
710 | for i in range(0, image.shape[0], batch_size)
711 | ],
712 | dim=0,
713 | ) # [N,4,h,w]
714 | image_latent = image_latent * self.vae.config.scaling_factor
715 | image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w]
716 |
717 | pred_latent = latents
718 | if pred_latent is None:
719 |
720 |
721 | pred_latent = randn_tensor(
722 | image_latent.shape,
723 | generator=generator,
724 | device=image_latent.device,
725 | dtype=image_latent.dtype,
726 | ) # [N*E,4,h,w]
727 |
728 | return image_latent, pred_latent
729 |
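A shape-only sketch of what `prepare_latents` returns, assuming the usual Stable Diffusion VAE geometry (4 latent channels, spatial downscale 8); the VAE encoding is replaced by random tensors since only the bookkeeping is of interest here.

import torch

N, E = 2, 3                    # images, ensemble size
H, W = 768, 576                # preprocessed resolution (multiple of the VAE scale factor 8)
h, w = H // 8, W // 8          # latent resolution

image_latent = torch.randn(N, 4, h, w)                    # stand-in for the scaled VAE encoding
image_latent = image_latent.repeat_interleave(E, dim=0)   # [N*E, 4, h, w]
pred_latent = torch.randn_like(image_latent)              # random init when `latents` is None

print(image_latent.shape, pred_latent.shape)  # both torch.Size([6, 4, 96, 72])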
730 | def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor:
731 | if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels:
732 | raise ValueError(
733 | f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}."
734 | )
735 |
736 | prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W]
737 |
738 | return prediction # [B,3,H,W]
739 |
740 | @staticmethod
741 | def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
742 | if normals.dim() != 4 or normals.shape[1] != 3:
743 | raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
744 |
745 | norm = torch.norm(normals, dim=1, keepdim=True)
746 | normals /= norm.clamp(min=eps)
747 |
748 | return normals
749 |
750 | @staticmethod
751 | def match_noisy(dino, noisy):
752 | _, __, dino_h, dino_w = dino.shape
753 | _, __, h, w = noisy.shape
754 |
755 | if h == dino_h and w == dino_w:
756 | return dino
757 | else:
758 | return F.interpolate(dino, (h, w), mode='bilinear')
759 |
760 |
761 |
762 |
763 |
764 |
765 |
766 |
767 |
768 |
769 | @staticmethod
770 | def dino_unet_forward(
771 | self, # NOTE that repurpose to UNet
772 | sample: torch.Tensor,
773 | timestep: Union[torch.Tensor, float, int],
774 | encoder_hidden_states: torch.Tensor,
775 | class_labels: Optional[torch.Tensor] = None,
776 | timestep_cond: Optional[torch.Tensor] = None,
777 | attention_mask: Optional[torch.Tensor] = None,
778 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
779 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
780 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
781 | mid_block_additional_residual: Optional[torch.Tensor] = None,
782 | dino_down_block_additional_residuals: Optional[torch.Tensor] = None,
783 | down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
784 | encoder_attention_mask: Optional[torch.Tensor] = None,
785 | return_dict: bool = True,
786 | ) -> Union[UNet2DConditionOutput, Tuple]:
787 | r"""
788 | The [`UNet2DConditionModel`] forward method.
789 |
790 | Args:
791 | sample (`torch.Tensor`):
792 | The noisy input tensor with the following shape `(batch, channel, height, width)`.
793 | timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
794 | encoder_hidden_states (`torch.Tensor`):
795 | The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
796 | class_labels (`torch.Tensor`, *optional*, defaults to `None`):
797 | Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
798 | timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
799 | Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
800 | through the `self.time_embedding` layer to obtain the timestep embeddings.
801 | attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
802 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
803 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
804 | negative values to the attention scores corresponding to "discard" tokens.
805 | cross_attention_kwargs (`dict`, *optional*):
806 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
807 | `self.processor` in
808 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
809 | added_cond_kwargs: (`dict`, *optional*):
810 | A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
811 | are passed along to the UNet blocks.
812 | down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
813 | A tuple of tensors that if specified are added to the residuals of down unet blocks.
814 | mid_block_additional_residual: (`torch.Tensor`, *optional*):
815 | A tensor that if specified is added to the residual of the middle unet block.
816 | down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
817 | additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
818 | encoder_attention_mask (`torch.Tensor`):
819 | A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
820 | `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
821 | which adds large negative values to the attention scores corresponding to "discard" tokens.
822 | return_dict (`bool`, *optional*, defaults to `True`):
823 | Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
824 | tuple.
825 |
826 | Returns:
827 | [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
828 | If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
829 | otherwise a `tuple` is returned where the first element is the sample tensor.
830 | """
831 |         # By default samples have to be at least a multiple of the overall upsampling factor.
832 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
833 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
834 | # on the fly if necessary.
835 |
836 |
837 | default_overall_up_factor = 2**self.num_upsamplers
838 |
839 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
840 | forward_upsample_size = False
841 | upsample_size = None
842 |
843 | for dim in sample.shape[-2:]:
844 | if dim % default_overall_up_factor != 0:
845 | # Forward upsample size to force interpolation output size.
846 | forward_upsample_size = True
847 | break
848 |
849 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
850 | # expects mask of shape:
851 | # [batch, key_tokens]
852 | # adds singleton query_tokens dimension:
853 | # [batch, 1, key_tokens]
854 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
855 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
856 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
857 | if attention_mask is not None:
858 | # assume that mask is expressed as:
859 | # (1 = keep, 0 = discard)
860 | # convert mask into a bias that can be added to attention scores:
861 | # (keep = +0, discard = -10000.0)
862 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
863 | attention_mask = attention_mask.unsqueeze(1)
864 |
865 | # convert encoder_attention_mask to a bias the same way we do for attention_mask
866 | if encoder_attention_mask is not None:
867 | encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
868 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
869 |
870 | # 0. center input if necessary
871 | if self.config.center_input_sample:
872 | sample = 2 * sample - 1.0
873 |
874 | # 1. time
875 | t_emb = self.get_time_embed(sample=sample, timestep=timestep)
876 | emb = self.time_embedding(t_emb, timestep_cond)
877 | aug_emb = None
878 |
879 | class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
880 | if class_emb is not None:
881 | if self.config.class_embeddings_concat:
882 | emb = torch.cat([emb, class_emb], dim=-1)
883 | else:
884 | emb = emb + class_emb
885 |
886 | aug_emb = self.get_aug_embed(
887 | emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
888 | )
889 | if self.config.addition_embed_type == "image_hint":
890 | aug_emb, hint = aug_emb
891 | sample = torch.cat([sample, hint], dim=1)
892 |
893 | emb = emb + aug_emb if aug_emb is not None else emb
894 |
895 | if self.time_embed_act is not None:
896 | emb = self.time_embed_act(emb)
897 |
898 | encoder_hidden_states = self.process_encoder_hidden_states(
899 | encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
900 | )
901 |
902 | # 2. pre-process
903 | sample = self.conv_in(sample)
904 |
905 | # 2.5 GLIGEN position net
906 | if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
907 | cross_attention_kwargs = cross_attention_kwargs.copy()
908 | gligen_args = cross_attention_kwargs.pop("gligen")
909 | cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
910 |
911 | # 3. down
912 | # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
913 | # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
914 | if cross_attention_kwargs is not None:
915 | cross_attention_kwargs = cross_attention_kwargs.copy()
916 | lora_scale = cross_attention_kwargs.pop("scale", 1.0)
917 | else:
918 | lora_scale = 1.0
919 |
920 | if USE_PEFT_BACKEND:
921 | # weight the lora layers by setting `lora_scale` for each PEFT layer
922 | scale_lora_layers(self, lora_scale)
923 |
924 | is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
925 | # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
926 | is_adapter = down_intrablock_additional_residuals is not None
927 | # maintain backward compatibility for legacy usage, where
928 | # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
929 | # but can only use one or the other
930 | if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
931 | deprecate(
932 | "T2I should not use down_block_additional_residuals",
933 | "1.3.0",
934 | "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
935 | and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
936 | for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
937 | standard_warn=False,
938 | )
939 | down_intrablock_additional_residuals = down_block_additional_residuals
940 | is_adapter = True
941 |
942 |
943 |
944 | def residual_downforward(
945 | self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None,
946 | additional_residuals: Optional[torch.Tensor] = None,
947 | *args, **kwargs,
948 | ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
949 | if len(args) > 0 or kwargs.get("scale", None) is not None:
950 | deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
951 | deprecate("scale", "1.0.0", deprecation_message)
952 |
953 | output_states = ()
954 |
955 | for resnet in self.resnets:
956 | if self.training and self.gradient_checkpointing:
957 |
958 | def create_custom_forward(module):
959 | def custom_forward(*inputs):
960 | return module(*inputs)
961 |
962 | return custom_forward
963 |
964 | if is_torch_version(">=", "1.11.0"):
965 | hidden_states = torch.utils.checkpoint.checkpoint(
966 | create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
967 | )
968 | else:
969 | hidden_states = torch.utils.checkpoint.checkpoint(
970 | create_custom_forward(resnet), hidden_states, temb
971 | )
972 | else:
973 | hidden_states = resnet(hidden_states, temb)
974 | hidden_states += additional_residuals.pop(0)
975 |
976 |
977 | output_states = output_states + (hidden_states,)
978 |
979 | if self.downsamplers is not None:
980 | for downsampler in self.downsamplers:
981 | hidden_states = downsampler(hidden_states)
982 | hidden_states += additional_residuals.pop(0)
983 |
984 | output_states = output_states + (hidden_states,)
985 |
986 | return hidden_states, output_states
987 |
988 |
989 | def residual_blockforward(
990 | self, ## NOTE that repurpose to unet_blocks
991 | hidden_states: torch.Tensor,
992 | temb: Optional[torch.Tensor] = None,
993 | encoder_hidden_states: Optional[torch.Tensor] = None,
994 | attention_mask: Optional[torch.Tensor] = None,
995 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
996 | encoder_attention_mask: Optional[torch.Tensor] = None,
997 | additional_residuals: Optional[torch.Tensor] = None,
998 | ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
999 | if cross_attention_kwargs is not None:
1000 | if cross_attention_kwargs.get("scale", None) is not None:
1001 | logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
1002 |
1003 |
1004 |
1005 | output_states = ()
1006 |
1007 | blocks = list(zip(self.resnets, self.attentions))
1008 |
1009 | for i, (resnet, attn) in enumerate(blocks):
1010 | if self.training and self.gradient_checkpointing:
1011 |
1012 | def create_custom_forward(module, return_dict=None):
1013 | def custom_forward(*inputs):
1014 | if return_dict is not None:
1015 | return module(*inputs, return_dict=return_dict)
1016 | else:
1017 | return module(*inputs)
1018 |
1019 | return custom_forward
1020 |
1021 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1022 | hidden_states = torch.utils.checkpoint.checkpoint(
1023 | create_custom_forward(resnet),
1024 | hidden_states,
1025 | temb,
1026 | **ckpt_kwargs,
1027 | )
1028 | hidden_states = attn(
1029 | hidden_states,
1030 | encoder_hidden_states=encoder_hidden_states,
1031 | cross_attention_kwargs=cross_attention_kwargs,
1032 | attention_mask=attention_mask,
1033 | encoder_attention_mask=encoder_attention_mask,
1034 | return_dict=False,
1035 | )[0]
1036 | else:
1037 | hidden_states = resnet(hidden_states, temb)
1038 | hidden_states = attn(
1039 | hidden_states,
1040 | encoder_hidden_states=encoder_hidden_states,
1041 | cross_attention_kwargs=cross_attention_kwargs,
1042 | attention_mask=attention_mask,
1043 | encoder_attention_mask=encoder_attention_mask,
1044 | return_dict=False,
1045 | )[0]
1046 |
1047 | hidden_states += additional_residuals.pop(0)
1048 |
1049 | output_states = output_states + (hidden_states,)
1050 |
1051 | if self.downsamplers is not None:
1052 | for downsampler in self.downsamplers:
1053 | hidden_states = downsampler(hidden_states)
1054 | hidden_states += additional_residuals.pop(0)
1055 |
1056 | output_states = output_states + (hidden_states,)
1057 |
1058 | return hidden_states, output_states
1059 |
1060 |
1061 | down_intrablock_additional_residuals = dino_down_block_additional_residuals
1062 |
1063 | sample += down_intrablock_additional_residuals.pop(0)
1064 | down_block_res_samples = (sample,)
1065 |
1066 | for downsample_block in self.down_blocks:
1067 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1068 |
1069 | sample, res_samples = residual_blockforward(
1070 | downsample_block,
1071 | hidden_states=sample,
1072 | temb=emb,
1073 | encoder_hidden_states=encoder_hidden_states,
1074 | attention_mask=attention_mask,
1075 | cross_attention_kwargs=cross_attention_kwargs,
1076 | encoder_attention_mask=encoder_attention_mask,
1077 | additional_residuals = down_intrablock_additional_residuals,
1078 | )
1079 |
1080 | else:
1081 | sample, res_samples = residual_downforward(
1082 | downsample_block,
1083 | hidden_states=sample,
1084 | temb=emb,
1085 | additional_residuals = down_intrablock_additional_residuals,
1086 | )
1087 |
1088 |
1089 | down_block_res_samples += res_samples
1090 |
1091 |
1092 | if is_controlnet:
1093 | new_down_block_res_samples = ()
1094 |
1095 | for down_block_res_sample, down_block_additional_residual in zip(
1096 | down_block_res_samples, down_block_additional_residuals
1097 | ):
1098 | down_block_res_sample = down_block_res_sample + down_block_additional_residual
1099 | new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1100 |
1101 | down_block_res_samples = new_down_block_res_samples
1102 |
1103 | # 4. mid
1104 | if self.mid_block is not None:
1105 | if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1106 | sample = self.mid_block(
1107 | sample,
1108 | emb,
1109 | encoder_hidden_states=encoder_hidden_states,
1110 | attention_mask=attention_mask,
1111 | cross_attention_kwargs=cross_attention_kwargs,
1112 | encoder_attention_mask=encoder_attention_mask,
1113 | )
1114 | else:
1115 | sample = self.mid_block(sample, emb)
1116 |
1117 | # To support T2I-Adapter-XL
1118 | if (
1119 | is_adapter
1120 | and len(down_intrablock_additional_residuals) > 0
1121 | and sample.shape == down_intrablock_additional_residuals[0].shape
1122 | ):
1123 | sample += down_intrablock_additional_residuals.pop(0)
1124 |
1125 | if is_controlnet:
1126 | sample = sample + mid_block_additional_residual
1127 |
1128 | # 5. up
1129 | for i, upsample_block in enumerate(self.up_blocks):
1130 | is_final_block = i == len(self.up_blocks) - 1
1131 |
1132 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1133 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1134 |
1135 | # if we have not reached the final block and need to forward the
1136 | # upsample size, we do it here
1137 | if not is_final_block and forward_upsample_size:
1138 | upsample_size = down_block_res_samples[-1].shape[2:]
1139 |
1140 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1141 | sample = upsample_block(
1142 | hidden_states=sample,
1143 | temb=emb,
1144 | res_hidden_states_tuple=res_samples,
1145 | encoder_hidden_states=encoder_hidden_states,
1146 | cross_attention_kwargs=cross_attention_kwargs,
1147 | upsample_size=upsample_size,
1148 | attention_mask=attention_mask,
1149 | encoder_attention_mask=encoder_attention_mask,
1150 | )
1151 | else:
1152 | sample = upsample_block(
1153 | hidden_states=sample,
1154 | temb=emb,
1155 | res_hidden_states_tuple=res_samples,
1156 | upsample_size=upsample_size,
1157 | )
1158 |
1159 | # 6. post-process
1160 | if self.conv_norm_out:
1161 | sample = self.conv_norm_out(sample)
1162 | sample = self.conv_act(sample)
1163 | sample = self.conv_out(sample)
1164 |
1165 | if USE_PEFT_BACKEND:
1166 | # remove `lora_scale` from each PEFT layer
1167 | unscale_lora_layers(self, lora_scale)
1168 |
1169 | if not return_dict:
1170 | return (sample,)
1171 |
1172 | return UNet2DConditionOutput(sample=sample)
1173 |
1174 |
1175 |
1176 | @staticmethod
1177 | def ensemble_normals(
1178 | normals: torch.Tensor, output_uncertainty: bool, reduction: str = "closest"
1179 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
1180 | """
1181 | Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is
1182 | the number of ensemble members for a given prediction of size `(H x W)`.
1183 |
1184 | Args:
1185 | normals (`torch.Tensor`):
1186 | Input ensemble normals maps.
1187 | output_uncertainty (`bool`, *optional*, defaults to `False`):
1188 | Whether to output uncertainty map.
1189 | reduction (`str`, *optional*, defaults to `"closest"`):
1190 | Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and
1191 | `"mean"`.
1192 |
1193 | Returns:
1194 | A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of
1195 | uncertainties of shape `(1, 1, H, W)`.
1196 | """
1197 | if normals.dim() != 4 or normals.shape[1] != 3:
1198 | raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
1199 | if reduction not in ("closest", "mean"):
1200 | raise ValueError(f"Unrecognized reduction method: {reduction}.")
1201 |
1202 | mean_normals = normals.mean(dim=0, keepdim=True) # [1,3,H,W]
1203 | mean_normals = MarigoldNormalsPipeline.normalize_normals(mean_normals) # [1,3,H,W]
1204 |
1205 | sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True) # [E,1,H,W]
1206 | sim_cos = sim_cos.clamp(-1, 1) # required to avoid NaN in uncertainty with fp16
1207 |
1208 | uncertainty = None
1209 | if output_uncertainty:
1210 | uncertainty = sim_cos.arccos() # [E,1,H,W]
1211 | uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi # [1,1,H,W]
1212 |
1213 | if reduction == "mean":
1214 | return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W]
1215 |
1216 | closest_indices = sim_cos.argmax(dim=0, keepdim=True) # [1,1,H,W]
1217 | closest_indices = closest_indices.repeat(1, 3, 1, 1) # [1,3,H,W]
1218 | closest_normals = torch.gather(normals, 0, closest_indices) # [1,3,H,W]
1219 |
1220 | return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W]
1221 |
1222 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
1223 | def retrieve_timesteps(
1224 | scheduler,
1225 | num_inference_steps: Optional[int] = None,
1226 | device: Optional[Union[str, torch.device]] = None,
1227 | timesteps: Optional[List[int]] = None,
1228 | sigmas: Optional[List[float]] = None,
1229 | **kwargs,
1230 | ):
1231 | """
1232 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
1233 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
1234 |
1235 | Args:
1236 | scheduler (`SchedulerMixin`):
1237 | The scheduler to get timesteps from.
1238 | num_inference_steps (`int`):
1239 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
1240 | must be `None`.
1241 | device (`str` or `torch.device`, *optional*):
1242 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
1243 | timesteps (`List[int]`, *optional*):
1244 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
1245 | `num_inference_steps` and `sigmas` must be `None`.
1246 | sigmas (`List[float]`, *optional*):
1247 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
1248 | `num_inference_steps` and `timesteps` must be `None`.
1249 |
1250 | Returns:
1251 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
1252 | second element is the number of inference steps.
1253 | """
1254 | if timesteps is not None and sigmas is not None:
1255 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
1256 | if timesteps is not None:
1257 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
1258 | if not accepts_timesteps:
1259 | raise ValueError(
1260 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
1261 | f" timestep schedules. Please check whether you are using the correct scheduler."
1262 | )
1263 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
1264 | timesteps = scheduler.timesteps
1265 | num_inference_steps = len(timesteps)
1266 | elif sigmas is not None:
1267 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
1268 | if not accept_sigmas:
1269 | raise ValueError(
1270 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
1271 | f" sigmas schedules. Please check whether you are using the correct scheduler."
1272 | )
1273 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
1274 | timesteps = scheduler.timesteps
1275 | num_inference_steps = len(timesteps)
1276 | else:
1277 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
1278 | timesteps = scheduler.timesteps
1279 | return timesteps, num_inference_steps
--------------------------------------------------------------------------------
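
The `ensemble_normals` helper above reduces an ensemble of per-pixel normal predictions either to the member closest to the (re-normalized) ensemble mean or to the mean itself, optionally returning an angular uncertainty map. The following standalone sketch mirrors that reduction on random data; it is illustrative only and not part of the repository (only `torch` is assumed):

```py
import math
import torch

# Build a toy ensemble of E unit-normal maps, shape [E, 3, H, W].
E, H, W = 4, 8, 8
normals = torch.randn(E, 3, H, W)
normals = normals / normals.norm(dim=1, keepdim=True).clamp(min=1e-6)

# "mean" reduction: average the members, then re-normalize to unit length.
mean_normals = normals.mean(dim=0, keepdim=True)                          # [1,3,H,W]
mean_normals = mean_normals / mean_normals.norm(dim=1, keepdim=True).clamp(min=1e-6)

# Per-pixel cosine similarity between every member and the mean.
sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True).clamp(-1, 1)  # [E,1,H,W]

# Uncertainty: mean angular deviation from the ensemble mean, scaled to [0, 1].
uncertainty = sim_cos.arccos().mean(dim=0, keepdim=True) / math.pi        # [1,1,H,W]

# "closest" reduction: at every pixel, keep the member most aligned with the mean.
closest_indices = sim_cos.argmax(dim=0, keepdim=True).repeat(1, 3, 1, 1)  # [1,3,H,W]
closest_normals = torch.gather(normals, 0, closest_indices)               # [1,3,H,W]

print(closest_normals.shape, uncertainty.shape)
```
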
/stablenormal/pipeline_yoso_normal.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved.
2 | # Copyright 2024 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # --------------------------------------------------------------------------
16 | # More information and citation instructions are available on the
17 | # Marigold project website: https://marigoldmonodepth.github.io
18 | from dataclasses import dataclass
19 | from typing import Any, Dict, List, Optional, Tuple, Union
20 |
21 | import numpy as np
22 | import torch
23 | from PIL import Image
24 | from tqdm.auto import tqdm
25 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
26 |
27 |
28 | from diffusers.image_processor import PipelineImageInput
29 | from diffusers.models import (
30 | AutoencoderKL,
31 | UNet2DConditionModel,
32 | ControlNetModel,
33 | )
34 | from diffusers.schedulers import (
35 | DDIMScheduler
36 | )
37 |
38 | from diffusers.utils import (
39 | BaseOutput,
40 | logging,
41 | replace_example_docstring,
42 | )
43 |
44 |
45 | from diffusers.utils.torch_utils import randn_tensor
46 | from diffusers.pipelines.controlnet import StableDiffusionControlNetPipeline
47 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
48 | from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor
49 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
50 |
51 | import inspect
52 |
53 |
54 |
55 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56 |
57 |
58 | EXAMPLE_DOC_STRING = """
59 | Examples:
60 | ```py
61 | >>> import diffusers
62 | >>> import torch
63 |
64 | >>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(
65 | ... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16
66 | ... ).to("cuda")
67 |
68 | >>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
69 | >>> normals = pipe(image)
70 |
71 | >>> vis = pipe.image_processor.visualize_normals(normals.prediction)
72 | >>> vis[0].save("einstein_normals.png")
73 | ```
74 | """
75 |
76 |
77 | @dataclass
78 | class YosoNormalsOutput(BaseOutput):
79 | """
80 |     Output class for the YOSO monocular normals prediction pipeline.
81 |
82 |     Args:
83 |         prediction (`np.ndarray`, `torch.Tensor`):
84 |             Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height
85 |             \times width$, regardless of whether the images were passed as a 4D array or a list.
86 |         latent (`None`, `torch.Tensor`):
87 |             Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
88 |             The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
89 |         gaus_noise (`None`, `torch.Tensor`):
90 |             The Gaussian noise used to initialize the prediction latents when no `latents` argument is supplied;
91 |             it has the same shape as `latent`.
92 | """
93 |
94 | prediction: Union[np.ndarray, torch.Tensor]
95 | latent: Union[None, torch.Tensor]
96 | gaus_noise: Union[None, torch.Tensor]
97 |
98 |
99 | class YOSONormalsPipeline(StableDiffusionControlNetPipeline):
100 | """ Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io.
101 | Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
102 |
103 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
104 | implemented for all pipelines (downloading, saving, running on a particular device, etc.).
105 |
106 | The pipeline also inherits the following loading methods:
107 | - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
108 | - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
109 | - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
110 | - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
111 | - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
112 |
113 | Args:
114 | vae ([`AutoencoderKL`]):
115 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
116 | text_encoder ([`~transformers.CLIPTextModel`]):
117 | Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
118 | tokenizer ([`~transformers.CLIPTokenizer`]):
119 | A `CLIPTokenizer` to tokenize text.
120 | unet ([`UNet2DConditionModel`]):
121 | A `UNet2DConditionModel` to denoise the encoded image latents.
122 | controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
123 | Provides additional conditioning to the `unet` during the denoising process. If you set multiple
124 | ControlNets as a list, the outputs from each ControlNet are added together to create one combined
125 | additional conditioning.
126 | scheduler ([`SchedulerMixin`]):
127 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
128 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
129 | safety_checker ([`StableDiffusionSafetyChecker`]):
130 | Classification module that estimates whether generated images could be considered offensive or harmful.
131 | Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
132 | about a model's potential harms.
133 | feature_extractor ([`~transformers.CLIPImageProcessor`]):
134 | A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
135 | """
136 |
137 | model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
138 | _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
139 | _exclude_from_cpu_offload = ["safety_checker"]
140 | _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
141 |
142 |
143 |
144 | def __init__(
145 | self,
146 | vae: AutoencoderKL,
147 | text_encoder: CLIPTextModel,
148 | tokenizer: CLIPTokenizer,
149 | unet: UNet2DConditionModel,
150 | controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]],
151 |         scheduler: DDIMScheduler,
152 | safety_checker: StableDiffusionSafetyChecker,
153 | feature_extractor: CLIPImageProcessor,
154 | image_encoder: CLIPVisionModelWithProjection = None,
155 | requires_safety_checker: bool = True,
156 | default_denoising_steps: Optional[int] = 1,
157 | default_processing_resolution: Optional[int] = 768,
158 | prompt="",
159 | empty_text_embedding=None,
160 | t_start: Optional[int] = 401,
161 | ):
162 | super().__init__(
163 | vae,
164 | text_encoder,
165 | tokenizer,
166 | unet,
167 | controlnet,
168 | scheduler,
169 | safety_checker,
170 | feature_extractor,
171 | image_encoder,
172 | requires_safety_checker,
173 | )
174 |
175 | # TODO yoso ImageProcessor
176 | self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
177 | self.control_image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
178 | self.default_denoising_steps = default_denoising_steps
179 | self.default_processing_resolution = default_processing_resolution
180 | self.prompt = prompt
181 | self.prompt_embeds = None
182 | self.empty_text_embedding = empty_text_embedding
183 |         self.t_start = t_start  # timestep at which the single-step (YOSO) prediction is performed
184 |
185 | def check_inputs(
186 | self,
187 | image: PipelineImageInput,
188 | num_inference_steps: int,
189 | ensemble_size: int,
190 | processing_resolution: int,
191 | resample_method_input: str,
192 | resample_method_output: str,
193 | batch_size: int,
194 | ensembling_kwargs: Optional[Dict[str, Any]],
195 | latents: Optional[torch.Tensor],
196 | generator: Optional[Union[torch.Generator, List[torch.Generator]]],
197 | output_type: str,
198 | output_uncertainty: bool,
199 | ) -> int:
200 | if num_inference_steps is None:
201 | raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.")
202 | if num_inference_steps < 1:
203 | raise ValueError("`num_inference_steps` must be positive.")
204 | if ensemble_size < 1:
205 | raise ValueError("`ensemble_size` must be positive.")
206 | if ensemble_size == 2:
207 | logger.warning(
208 | "`ensemble_size` == 2 results are similar to no ensembling (1); "
209 | "consider increasing the value to at least 3."
210 | )
211 | if ensemble_size == 1 and output_uncertainty:
212 | raise ValueError(
213 | "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` "
214 | "greater than 1."
215 | )
216 | if processing_resolution is None:
217 | raise ValueError(
218 | "`processing_resolution` is not specified and could not be resolved from the model config."
219 | )
220 | if processing_resolution < 0:
221 | raise ValueError(
222 | "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for "
223 | "downsampled processing."
224 | )
225 | if processing_resolution % self.vae_scale_factor != 0:
226 | raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.")
227 | if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
228 | raise ValueError(
229 | "`resample_method_input` takes string values compatible with PIL library: "
230 | "nearest, nearest-exact, bilinear, bicubic, area."
231 | )
232 | if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
233 | raise ValueError(
234 | "`resample_method_output` takes string values compatible with PIL library: "
235 | "nearest, nearest-exact, bilinear, bicubic, area."
236 | )
237 | if batch_size < 1:
238 | raise ValueError("`batch_size` must be positive.")
239 | if output_type not in ["pt", "np"]:
240 | raise ValueError("`output_type` must be one of `pt` or `np`.")
241 | if latents is not None and generator is not None:
242 | raise ValueError("`latents` and `generator` cannot be used together.")
243 | if ensembling_kwargs is not None:
244 | if not isinstance(ensembling_kwargs, dict):
245 | raise ValueError("`ensembling_kwargs` must be a dictionary.")
246 | if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"):
247 | raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.")
248 |
249 | # image checks
250 | num_images = 0
251 | W, H = None, None
252 | if not isinstance(image, list):
253 | image = [image]
254 | for i, img in enumerate(image):
255 | if isinstance(img, np.ndarray) or torch.is_tensor(img):
256 | if img.ndim not in (2, 3, 4):
257 | raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.")
258 | H_i, W_i = img.shape[-2:]
259 | N_i = 1
260 | if img.ndim == 4:
261 | N_i = img.shape[0]
262 | elif isinstance(img, Image.Image):
263 | W_i, H_i = img.size
264 | N_i = 1
265 | else:
266 | raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.")
267 | if W is None:
268 | W, H = W_i, H_i
269 | elif (W, H) != (W_i, H_i):
270 | raise ValueError(
271 | f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}"
272 | )
273 | num_images += N_i
274 |
275 | # latents checks
276 | if latents is not None:
277 | if not torch.is_tensor(latents):
278 | raise ValueError("`latents` must be a torch.Tensor.")
279 | if latents.dim() != 4:
280 | raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.")
281 |
282 | if processing_resolution > 0:
283 | max_orig = max(H, W)
284 | new_H = H * processing_resolution // max_orig
285 | new_W = W * processing_resolution // max_orig
286 | if new_H == 0 or new_W == 0:
287 | raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]")
288 | W, H = new_W, new_H
289 | w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor
290 | h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor
291 | shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w)
292 |
293 | if latents.shape != shape_expected:
294 | raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.")
295 |
296 | # generator checks
297 | if generator is not None:
298 | if isinstance(generator, list):
299 | if len(generator) != num_images * ensemble_size:
300 | raise ValueError(
301 | "The number of generators must match the total number of ensemble members for all input images."
302 | )
303 | if not all(g.device.type == generator[0].device.type for g in generator):
304 | raise ValueError("`generator` device placement is not consistent in the list.")
305 | elif not isinstance(generator, torch.Generator):
306 | raise ValueError(f"Unsupported generator type: {type(generator)}.")
307 |
308 | return num_images
309 |
310 | def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
311 | if not hasattr(self, "_progress_bar_config"):
312 | self._progress_bar_config = {}
313 | elif not isinstance(self._progress_bar_config, dict):
314 | raise ValueError(
315 | f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
316 | )
317 |
318 | progress_bar_config = dict(**self._progress_bar_config)
319 | progress_bar_config["desc"] = progress_bar_config.get("desc", desc)
320 | progress_bar_config["leave"] = progress_bar_config.get("leave", leave)
321 | if iterable is not None:
322 | return tqdm(iterable, **progress_bar_config)
323 | elif total is not None:
324 | return tqdm(total=total, **progress_bar_config)
325 | else:
326 | raise ValueError("Either `total` or `iterable` has to be defined.")
327 |
328 | @torch.no_grad()
329 | @replace_example_docstring(EXAMPLE_DOC_STRING)
330 | def __call__(
331 | self,
332 | image: PipelineImageInput,
333 | prompt: Union[str, List[str]] = None,
334 | negative_prompt: Optional[Union[str, List[str]]] = None,
335 | num_inference_steps: Optional[int] = None,
336 | ensemble_size: int = 1,
337 | processing_resolution: Optional[int] = None,
338 | match_input_resolution: bool = True,
339 | resample_method_input: str = "bilinear",
340 | resample_method_output: str = "bilinear",
341 | batch_size: int = 1,
342 | ensembling_kwargs: Optional[Dict[str, Any]] = None,
343 | latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
344 | prompt_embeds: Optional[torch.Tensor] = None,
345 | negative_prompt_embeds: Optional[torch.Tensor] = None,
346 | num_images_per_prompt: Optional[int] = 1,
347 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
348 | controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
349 | output_type: str = "np",
350 | output_uncertainty: bool = False,
351 | output_latent: bool = False,
352 | skip_preprocess: bool = False,
353 | return_dict: bool = True,
354 | **kwargs,
355 | ):
356 | """
357 | Function invoked when calling the pipeline.
358 |
359 | Args:
360 |             image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
361 |                 `List[torch.Tensor]`): An input image or images used as an input for the normals estimation task. For
362 | arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible
363 | by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or
364 | three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the
365 | same width and height.
366 | num_inference_steps (`int`, *optional*, defaults to `None`):
367 | Number of denoising diffusion steps during inference. The default value `None` results in automatic
368 | selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
369 | for Marigold-LCM models.
370 | ensemble_size (`int`, defaults to `1`):
371 | Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for
372 | faster inference.
373 | processing_resolution (`int`, *optional*, defaults to `None`):
374 | Effective processing resolution. When set to `0`, matches the larger input image dimension. This
375 | produces crisper predictions, but may also lead to the overall loss of global context. The default
376 | value `None` resolves to the optimal value from the model config.
377 | match_input_resolution (`bool`, *optional*, defaults to `True`):
378 | When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
379 |                 side of the output will be equal to `processing_resolution`.
380 | resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
381 | Resampling method used to resize input images to `processing_resolution`. The accepted values are:
382 | `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
383 | resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
384 | Resampling method used to resize output predictions to match the input resolution. The accepted values
385 | are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
386 | batch_size (`int`, *optional*, defaults to `1`):
387 | Batch size; only matters when setting `ensemble_size` or passing a tensor of images.
388 |             ensembling_kwargs (`dict`, *optional*, defaults to `None`):
389 | Extra dictionary with arguments for precise ensembling control. The following options are available:
390 | - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in
391 | every pixel location, can be either `"closest"` or `"mean"`.
392 | latents (`torch.Tensor`, *optional*, defaults to `None`):
393 | Latent noise tensors to replace the random initialization. These can be taken from the previous
394 | function call's output.
395 | generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`):
396 | Random number generator object to ensure reproducibility.
397 | output_type (`str`, *optional*, defaults to `"np"`):
398 | Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted
399 | values are: `"np"` (numpy array) or `"pt"` (torch tensor).
400 | output_uncertainty (`bool`, *optional*, defaults to `False`):
401 | When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that
402 | the `ensemble_size` argument is set to a value above 2.
403 | output_latent (`bool`, *optional*, defaults to `False`):
404 | When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
405 | within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
406 | `latents` argument.
407 | return_dict (`bool`, *optional*, defaults to `True`):
408 |                 Whether or not to return a [`YosoNormalsOutput`] instead of a plain tuple.
409 |
410 | Examples:
411 |
412 | Returns:
413 |             [`YosoNormalsOutput`] or `tuple`:
414 |                 If `return_dict` is `True`, [`YosoNormalsOutput`] is returned, otherwise a `tuple` is returned where
415 |                 the first element is the prediction, the second element is the latent, and the third is the Gaussian
416 |                 noise used to initialize the prediction latents.
417 | """
418 |
419 | # 0. Resolving variables.
420 | device = self._execution_device
421 | dtype = self.dtype
422 |
423 | # Model-specific optimal default values leading to fast and reasonable results.
424 | if num_inference_steps is None:
425 | num_inference_steps = self.default_denoising_steps
426 | if processing_resolution is None:
427 | processing_resolution = self.default_processing_resolution
428 |
429 | # 1. Check inputs.
430 | num_images = self.check_inputs(
431 | image,
432 | num_inference_steps,
433 | ensemble_size,
434 | processing_resolution,
435 | resample_method_input,
436 | resample_method_output,
437 | batch_size,
438 | ensembling_kwargs,
439 | latents,
440 | generator,
441 | output_type,
442 | output_uncertainty,
443 | )
444 |
445 |
446 | # 2. Prepare empty text conditioning.
447 | # Model invocation: self.tokenizer, self.text_encoder.
448 | if self.empty_text_embedding is None:
449 | prompt = ""
450 | text_inputs = self.tokenizer(
451 | prompt,
452 | padding="do_not_pad",
453 | max_length=self.tokenizer.model_max_length,
454 | truncation=True,
455 | return_tensors="pt",
456 | )
457 | text_input_ids = text_inputs.input_ids.to(device)
458 | self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024]
459 |
460 |
461 |
462 | # 3. prepare prompt
463 | if self.prompt_embeds is None:
464 | prompt_embeds, negative_prompt_embeds = self.encode_prompt(
465 | self.prompt,
466 | device,
467 | num_images_per_prompt,
468 | False,
469 | negative_prompt,
470 | prompt_embeds=prompt_embeds,
471 | negative_prompt_embeds=None,
472 | lora_scale=None,
473 | clip_skip=None,
474 | )
475 | self.prompt_embeds = prompt_embeds
476 | self.negative_prompt_embeds = negative_prompt_embeds
477 |
478 |
479 |
480 | # 4. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`,
481 | # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where
482 | # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are
483 | # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None`
484 | # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of
485 | # operation and leads to the most reasonable results. Using the native image resolution or any other processing
486 | # resolution can lead to loss of either fine details or global context in the output predictions.
487 | if not skip_preprocess:
488 | image, padding, original_resolution = self.image_processor.preprocess(
489 | image, processing_resolution, resample_method_input, device, dtype
490 | ) # [N,3,PPH,PPW]
491 | else:
492 | padding = (0, 0)
493 | original_resolution = image.shape[2:]
494 | # 5. Encode input image into latent space. At this step, each of the `N` input images is represented with `E`
495 | # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently.
496 | # Latents of each such predictions across all input images and all ensemble members are represented in the
497 | # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded
498 | # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure
499 | # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline
500 | # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken
501 | # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled
502 | # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space
503 | # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`.
504 | # Model invocation: self.vae.encoder.
505 | image_latent, pred_latent = self.prepare_latents(
506 | image, latents, generator, ensemble_size, batch_size
507 | ) # [N*E,4,h,w], [N*E,4,h,w]
508 |
509 | gaus_noise = pred_latent.detach().clone()
510 | del image
511 |
512 |
513 | # 6. obtain control_output
514 |
515 |         cond_scale = controlnet_conditioning_scale
516 | down_block_res_samples, mid_block_res_sample = self.controlnet(
517 | image_latent.detach(),
518 | self.t_start,
519 | encoder_hidden_states=self.prompt_embeds,
520 | conditioning_scale=cond_scale,
521 | guess_mode=False,
522 | return_dict=False,
523 | )
524 |
525 | # 7. YOSO sampling
526 | latent_x_t = self.unet(
527 | pred_latent,
528 | self.t_start,
529 | encoder_hidden_states=self.prompt_embeds,
530 | down_block_additional_residuals=down_block_res_samples,
531 | mid_block_additional_residual=mid_block_res_sample,
532 | return_dict=False,
533 | )[0]
534 |
535 |
536 | del (
537 | pred_latent,
538 | image_latent,
539 | )
540 |
541 | # decoder
542 | prediction = self.decode_prediction(latent_x_t)
543 | prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW]
544 |
545 | prediction = self.image_processor.resize_antialias(
546 | prediction, original_resolution, resample_method_output, is_aa=False
547 | ) # [N,3,H,W]
548 | prediction = self.normalize_normals(prediction) # [N,3,H,W]
549 |
550 | if output_type == "np":
551 | prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3]
552 |
553 |         # 8. Offload all models
554 | self.maybe_free_model_hooks()
555 |
556 | return YosoNormalsOutput(
557 | prediction=prediction,
558 | latent=latent_x_t,
559 | gaus_noise=gaus_noise,
560 | )
561 |
562 | # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents
563 | def prepare_latents(
564 | self,
565 | image: torch.Tensor,
566 | latents: Optional[torch.Tensor],
567 | generator: Optional[torch.Generator],
568 | ensemble_size: int,
569 | batch_size: int,
570 | ) -> Tuple[torch.Tensor, torch.Tensor]:
571 | def retrieve_latents(encoder_output):
572 | if hasattr(encoder_output, "latent_dist"):
573 | return encoder_output.latent_dist.mode()
574 | elif hasattr(encoder_output, "latents"):
575 | return encoder_output.latents
576 | else:
577 | raise AttributeError("Could not access latents of provided encoder_output")
578 |
579 |
580 |
581 | image_latent = torch.cat(
582 | [
583 | retrieve_latents(self.vae.encode(image[i : i + batch_size]))
584 | for i in range(0, image.shape[0], batch_size)
585 | ],
586 | dim=0,
587 | ) # [N,4,h,w]
588 | image_latent = image_latent * self.vae.config.scaling_factor
589 | image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w]
590 |
591 | pred_latent = latents
592 | if pred_latent is None:
593 | pred_latent = randn_tensor(
594 | image_latent.shape,
595 | generator=generator,
596 | device=image_latent.device,
597 | dtype=image_latent.dtype,
598 | ) # [N*E,4,h,w]
599 |
600 | return image_latent, pred_latent
601 |
602 | def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor:
603 | if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels:
604 | raise ValueError(
605 | f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}."
606 | )
607 |
608 | prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W]
609 |
610 | prediction = self.normalize_normals(prediction) # [B,3,H,W]
611 |
612 | return prediction # [B,3,H,W]
613 |
614 | @staticmethod
615 | def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
616 | if normals.dim() != 4 or normals.shape[1] != 3:
617 | raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
618 |
619 | norm = torch.norm(normals, dim=1, keepdim=True)
620 | normals /= norm.clamp(min=eps)
621 |
622 | return normals
623 |
624 | @staticmethod
625 | def ensemble_normals(
626 | normals: torch.Tensor, output_uncertainty: bool, reduction: str = "closest"
627 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
628 | """
629 | Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is
630 | the number of ensemble members for a given prediction of size `(H x W)`.
631 |
632 | Args:
633 | normals (`torch.Tensor`):
634 | Input ensemble normals maps.
635 | output_uncertainty (`bool`, *optional*, defaults to `False`):
636 | Whether to output uncertainty map.
637 | reduction (`str`, *optional*, defaults to `"closest"`):
638 | Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and
639 | `"mean"`.
640 |
641 | Returns:
642 | A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of
643 | uncertainties of shape `(1, 1, H, W)`.
644 | """
645 | if normals.dim() != 4 or normals.shape[1] != 3:
646 | raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
647 | if reduction not in ("closest", "mean"):
648 | raise ValueError(f"Unrecognized reduction method: {reduction}.")
649 |
650 | mean_normals = normals.mean(dim=0, keepdim=True) # [1,3,H,W]
651 |         mean_normals = YOSONormalsPipeline.normalize_normals(mean_normals)  # [1,3,H,W]
652 |
653 | sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True) # [E,1,H,W]
654 | sim_cos = sim_cos.clamp(-1, 1) # required to avoid NaN in uncertainty with fp16
655 |
656 | uncertainty = None
657 | if output_uncertainty:
658 | uncertainty = sim_cos.arccos() # [E,1,H,W]
659 | uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi # [1,1,H,W]
660 |
661 | if reduction == "mean":
662 | return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W]
663 |
664 | closest_indices = sim_cos.argmax(dim=0, keepdim=True) # [1,1,H,W]
665 | closest_indices = closest_indices.repeat(1, 3, 1, 1) # [1,3,H,W]
666 | closest_normals = torch.gather(normals, 0, closest_indices) # [1,3,H,W]
667 |
668 | return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W]
669 |
670 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
671 | def retrieve_timesteps(
672 | scheduler,
673 | num_inference_steps: Optional[int] = None,
674 | device: Optional[Union[str, torch.device]] = None,
675 | timesteps: Optional[List[int]] = None,
676 | sigmas: Optional[List[float]] = None,
677 | **kwargs,
678 | ):
679 | """
680 | Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
681 | custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
682 |
683 | Args:
684 | scheduler (`SchedulerMixin`):
685 | The scheduler to get timesteps from.
686 | num_inference_steps (`int`):
687 | The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
688 | must be `None`.
689 | device (`str` or `torch.device`, *optional*):
690 | The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
691 | timesteps (`List[int]`, *optional*):
692 | Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
693 | `num_inference_steps` and `sigmas` must be `None`.
694 | sigmas (`List[float]`, *optional*):
695 | Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
696 | `num_inference_steps` and `timesteps` must be `None`.
697 |
698 | Returns:
699 | `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
700 | second element is the number of inference steps.
701 | """
702 | if timesteps is not None and sigmas is not None:
703 | raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
704 | if timesteps is not None:
705 | accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
706 | if not accepts_timesteps:
707 | raise ValueError(
708 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
709 | f" timestep schedules. Please check whether you are using the correct scheduler."
710 | )
711 | scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
712 | timesteps = scheduler.timesteps
713 | num_inference_steps = len(timesteps)
714 | elif sigmas is not None:
715 | accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
716 | if not accept_sigmas:
717 | raise ValueError(
718 | f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
719 | f" sigmas schedules. Please check whether you are using the correct scheduler."
720 | )
721 | scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
722 | timesteps = scheduler.timesteps
723 | num_inference_steps = len(timesteps)
724 | else:
725 | scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
726 | timesteps = scheduler.timesteps
727 | return timesteps, num_inference_steps
--------------------------------------------------------------------------------
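
For orientation, here is a minimal sketch of how the single-step `YOSONormalsPipeline.__call__` documented above could be driven. It is not taken from the repository: the checkpoint path is a placeholder (the actual weights and loading entry points are defined in the project README and `hubconf.py`), and only arguments documented in the docstring are used.

```py
import torch
from PIL import Image

from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline

# NOTE: "path/to/yoso-normal-weights" is a placeholder, not a published model id.
pipe = YOSONormalsPipeline.from_pretrained(
    "path/to/yoso-normal-weights", torch_dtype=torch.float16
).to("cuda")

image = Image.open("input.jpg").convert("RGB")

# One-step (YOSO) prediction; output_type="np" yields an [N, H, W, 3] array in [-1, 1].
out = pipe(image, output_type="np")
prediction = out.prediction   # predicted normal map(s)
latent = out.latent           # latent features of the prediction
gaus_noise = out.gaus_noise   # noise used to initialize the prediction latents
```

As the step-5 comments in `__call__` describe, `out.latent` can also be fed back through the `latents` argument of a later call (for example, blended with fresh noise for temporally smoother video predictions).
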
/stablenormal/scheduler/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Stable-X/StableNormal/417b8f569ac58be2c71a8cce6fe549cc989a3e94/stablenormal/scheduler/__init__.py
--------------------------------------------------------------------------------
/stablenormal/scheduler/heuristics_ddimsampler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from dataclasses import dataclass
3 | from typing import List, Optional, Tuple, Union
4 |
5 | import numpy as np
6 | import torch
7 | from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
8 | from diffusers.schedulers.scheduling_utils import SchedulerMixin
9 | from diffusers.configuration_utils import register_to_config, ConfigMixin
10 | import pdb
11 |
12 |
13 | class HEURI_DDIMScheduler(DDIMScheduler, SchedulerMixin, ConfigMixin):
14 |
15 | def set_timesteps(self, num_inference_steps: int, t_start: int, device: Union[str, torch.device] = None):
16 | """
17 | Sets the discrete timesteps used for the diffusion chain (to be run before inference).
18 |
19 | Args:
20 | num_inference_steps (`int`):
21 | The number of diffusion steps used when generating samples with a pre-trained model.
22 | """
23 |
24 | if num_inference_steps > self.config.num_train_timesteps:
25 | raise ValueError(
26 | f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
27 | f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
28 | f" maximal {self.config.num_train_timesteps} timesteps."
29 | )
30 |
31 | self.num_inference_steps = num_inference_steps
32 |
33 | # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
34 | if self.config.timestep_spacing == "linspace":
35 | timesteps = (
36 | np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
37 | .round()[::-1]
38 | .copy()
39 | .astype(np.int64)
40 | )
41 | elif self.config.timestep_spacing == "leading":
42 | step_ratio = self.config.num_train_timesteps // self.num_inference_steps
43 | # creates integer timesteps by multiplying by ratio
44 | # casting to int to avoid issues when num_inference_step is power of 3
45 | timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
46 | timesteps += self.config.steps_offset
47 | elif self.config.timestep_spacing == "trailing":
48 | step_ratio = self.config.num_train_timesteps / self.num_inference_steps
49 | # creates integer timesteps by multiplying by ratio
50 | # casting to int to avoid issues when num_inference_step is power of 3
51 | timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
52 | timesteps -= 1
53 | else:
54 | raise ValueError(
55 | f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
56 | )
57 |
58 | timesteps = torch.from_numpy(timesteps).to(device)
59 |
60 |
61 |         naive_sampling_step = num_inference_steps // 2
62 |
63 | # TODO for debug
64 | # naive_sampling_step = 0
65 |
66 | self.naive_sampling_step = naive_sampling_step
67 |
68 | timesteps[:naive_sampling_step] = timesteps[naive_sampling_step] # refine on step 5 for 5 steps, then backward from step 6
69 |
70 | timesteps = [timestep + 1 for timestep in timesteps]
71 |
72 | self.timesteps = timesteps
73 | self.gap = self.config.num_train_timesteps // self.num_inference_steps
74 | self.prev_timesteps = [timestep for timestep in self.timesteps[1:]]
75 | self.prev_timesteps.append(torch.zeros_like(self.prev_timesteps[-1]))
76 |
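    # Illustration (assuming the DDIM defaults num_train_timesteps=1000 and
    # timestep_spacing="trailing"): set_timesteps(10, t_start) first builds
    # [999, 899, 799, 699, 599, 499, 399, 299, 199, 99], then clamps the first
    # num_inference_steps // 2 = 5 entries to the middle value and shifts all entries by +1:
    #   timesteps      = [500, 500, 500, 500, 500, 500, 400, 300, 200, 100]
    #   prev_timesteps = [500, 500, 500, 500, 500, 400, 300, 200, 100, 0]
    # so the first half of the sampling loop repeatedly refines at one noise level before
    # the usual DDIM descent begins.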
77 | def step(
78 | self,
79 | model_output: torch.Tensor,
80 | timestep: int,
81 | prev_timestep: int,
82 | sample: torch.Tensor,
83 | eta: float = 0.0,
84 | use_clipped_model_output: bool = False,
85 | generator=None,
86 | cur_step=None,
87 | variance_noise: Optional[torch.Tensor] = None,
88 | gaus_noise: Optional[torch.Tensor] = None,
89 | return_dict: bool = True,
90 | ) -> Union[DDIMSchedulerOutput, Tuple]:
91 | """
92 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
93 | process from the learned model outputs (most often the predicted noise).
94 |
95 | Args:
96 | model_output (`torch.Tensor`):
97 | The direct output from learned diffusion model.
98 | timestep (`float`):
99 | The current discrete timestep in the diffusion chain.
100 |             prev_timestep (`float`):
101 |                 The timestep that the sample is propagated to (the next timestep in the sampling order).
102 | sample (`torch.Tensor`):
103 | A current instance of a sample created by the diffusion process.
104 | eta (`float`):
105 | The weight of noise for added noise in diffusion step.
106 | use_clipped_model_output (`bool`, defaults to `False`):
107 | If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
108 | because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
109 | clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
110 | `use_clipped_model_output` has no effect.
111 | generator (`torch.Generator`, *optional*):
112 | A random number generator.
113 | variance_noise (`torch.Tensor`):
114 | Alternative to generating noise with `generator` by directly providing the noise for the variance
115 | itself. Useful for methods such as [`CycleDiffusion`].
116 | return_dict (`bool`, *optional*, defaults to `True`):
117 | Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
118 |
119 | Returns:
120 | [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
121 | If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
122 | tuple is returned where the first element is the sample tensor.
123 |
124 | """
125 | if self.num_inference_steps is None:
126 | raise ValueError(
127 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
128 | )
129 |
130 | # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
131 | # Ideally, read DDIM paper in-detail understanding
132 |
133 | # Notation (